mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
Remove rewind, use clip tokenizer
This commit is contained in:
@@ -1,528 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Inference script for ReWiND Reward Model.
|
||||
|
||||
This script loads a trained ReWiND model and runs inference on a dataset episode,
|
||||
generating visualizations of the predicted task progression over time.
|
||||
|
||||
Example usage:
|
||||
python scripts/visualize_rewind_predictions.py \
|
||||
--model-id username/rewind-model \
|
||||
--dataset-repo lerobot/aloha_sim_insertion_human \
|
||||
--episode-index 0 \
|
||||
--output-dir outputs/rewind_viz \
|
||||
--task-description "insert the peg into the socket"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.gridspec as gridspec
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.rewind.modeling_rewind import ReWiNDRewardModel
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Run ReWiND inference and visualize predictions")
|
||||
|
||||
# Model arguments
|
||||
parser.add_argument(
|
||||
"--model-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="HuggingFace model ID or local path to trained ReWiND model"
|
||||
)
|
||||
|
||||
# Dataset arguments
|
||||
parser.add_argument(
|
||||
"--dataset-repo",
|
||||
type=str,
|
||||
required=True,
|
||||
help="HuggingFace dataset repository ID (e.g., lerobot/aloha_sim_insertion_human)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episode-index",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Index of the episode to visualize (default: 0)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task-description",
|
||||
type=str,
|
||||
default="perform the task",
|
||||
help="Task description for the reward model (default: 'perform the task')"
|
||||
)
|
||||
|
||||
# Output arguments
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="outputs/rewind_inference",
|
||||
help="Directory to save visualization outputs (default: outputs/rewind_inference)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Key for images in dataset (e.g., observation.images.image for jaco_play). If not specified, uses model config's image_key"
|
||||
)
|
||||
|
||||
# Visualization options
|
||||
parser.add_argument(
|
||||
"--show-frames",
|
||||
action="store_true",
|
||||
help="Include sample frames in the visualization"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-sample-frames",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of sample frames to show (default: 8)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--figsize",
|
||||
type=int,
|
||||
nargs=2,
|
||||
default=[12, 6],
|
||||
help="Figure size as width height (default: 12 6)"
|
||||
)
|
||||
|
||||
# Device
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Device to run inference on (cuda/cpu, default: auto-detect)"
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_episode_data(
|
||||
dataset: LeRobotDataset,
|
||||
episode_index: int,
|
||||
image_key: str
|
||||
) -> tuple[np.ndarray, int, int, str]:
|
||||
"""
|
||||
Load all frames from a specific episode.
|
||||
|
||||
Args:
|
||||
dataset: LeRobotDataset instance
|
||||
episode_index: Index of the episode to load
|
||||
image_key: Key for accessing images in the dataset
|
||||
|
||||
Returns:
|
||||
Tuple of (frames, start_index, end_index, task_description)
|
||||
"""
|
||||
# Get episode boundaries
|
||||
episode_data = dataset.meta.episodes
|
||||
start_idx = episode_data["dataset_from_index"][episode_index]
|
||||
end_idx = episode_data["dataset_to_index"][episode_index]
|
||||
|
||||
logger.info(f"Loading episode {episode_index}: frames {start_idx} to {end_idx} ({end_idx - start_idx} frames)")
|
||||
|
||||
# Get task description from the dataset if available
|
||||
task_description = None
|
||||
first_item = dataset[start_idx]
|
||||
if "task" in first_item:
|
||||
task_description = first_item["task"]
|
||||
print(f"✓ Extracted task from episode {episode_index}: '{task_description}'")
|
||||
|
||||
# Load all frames from the episode
|
||||
frames = []
|
||||
for idx in tqdm(range(start_idx, end_idx), desc="Loading frames"):
|
||||
item = dataset[idx]
|
||||
# Get image from the item
|
||||
img = item[image_key]
|
||||
|
||||
# Convert to numpy if needed
|
||||
if isinstance(img, torch.Tensor):
|
||||
img = img.cpu().numpy()
|
||||
|
||||
# Handle different image formats (C, H, W) or (H, W, C)
|
||||
if img.shape[0] in [1, 3]: # Channel first
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
|
||||
# Convert to uint8 if needed
|
||||
if img.dtype != np.uint8:
|
||||
if img.max() <= 1.0:
|
||||
img = (img * 255).astype(np.uint8)
|
||||
else:
|
||||
img = img.astype(np.uint8)
|
||||
|
||||
frames.append(img)
|
||||
|
||||
frames = np.array(frames)
|
||||
logger.info(f"Loaded {len(frames)} frames with shape {frames[0].shape}")
|
||||
|
||||
return frames, start_idx, end_idx, task_description
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def run_inference(
|
||||
model: ReWiNDRewardModel,
|
||||
frames: np.ndarray,
|
||||
task_description: str,
|
||||
batch_size: int = 32
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Run ReWiND inference on video frames using the original ReWiND approach.
|
||||
|
||||
This function creates video slices for all frames at once (similar to the original
|
||||
metaworld_label_reward.py), where each slice contains frames from start up to that point.
|
||||
|
||||
Progress Normalization (from original ReWiND dataset.py):
|
||||
- Training: progress = [1, 2, ..., N] / remaining_length
|
||||
where remaining_length = episode_end - sequence_start
|
||||
- Inference: Starting from frame 0, remaining_length = total_episode_length
|
||||
So expected progress for frame i = (i + 1) / total_episode_length
|
||||
|
||||
This function computes both:
|
||||
1. Model predictions (what the model actually predicts)
|
||||
2. Expected progress (ground truth based on frame position)
|
||||
|
||||
Args:
|
||||
model: ReWiND model
|
||||
frames: Video frames (num_frames, H, W, C)
|
||||
task_description: Task description text
|
||||
batch_size: Batch size for processing slices
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- Model predictions for each frame (num_frames,)
|
||||
- Expected progress for each frame (num_frames,)
|
||||
"""
|
||||
total_frames = len(frames)
|
||||
|
||||
logger.info("Encoding video frames with DINO...")
|
||||
video_embeddings = model.encode_images(frames)
|
||||
|
||||
logger.info("Encoding task description with MiniLM...")
|
||||
text_embedding = model.encode_text(task_description)
|
||||
|
||||
logger.info("Creating video slices (original ReWiND approach)...")
|
||||
# Convert to tensors
|
||||
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
|
||||
text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
|
||||
|
||||
# Create video slices: for each frame i, create a sequence of frames [0:i+1]
|
||||
# This matches the original ReWiND inference approach
|
||||
video_slices = []
|
||||
for i in tqdm(range(len(video_embeddings)), desc="Creating slices"):
|
||||
# Slice from start to current frame (inclusive)
|
||||
video_slice = video_embeddings[:i + 1]
|
||||
|
||||
# Pad or subsample to max_length
|
||||
if model.config.subsample_video:
|
||||
video_slice = model.padding_video(video_slice, model.config.max_length)
|
||||
|
||||
video_slices.append(video_slice)
|
||||
|
||||
video_slices = torch.stack(video_slices) # (num_frames, max_length, 768)
|
||||
|
||||
# Create last_index_mask to extract the relevant prediction for each slice
|
||||
# For slice i, the last valid frame is at position min(i, max_length-1)
|
||||
max_length = model.config.max_length
|
||||
last_index_mask = torch.zeros((len(video_slices), max_length), dtype=torch.bool)
|
||||
|
||||
for i in range(len(video_slices)):
|
||||
last_frame_idx = min(i, max_length - 1)
|
||||
last_index_mask[i, last_frame_idx] = 1
|
||||
|
||||
logger.info("Running ReWiND inference on all slices...")
|
||||
# Process in batches
|
||||
all_progress = []
|
||||
for i in tqdm(range(0, len(video_slices), batch_size), desc="Inference"):
|
||||
batch_video = video_slices[i:i + batch_size].to(model.device)
|
||||
batch_mask = last_index_mask[i:i + batch_size].to(model.device)
|
||||
batch_size_actual = batch_video.shape[0]
|
||||
|
||||
# Replicate text embedding for batch
|
||||
batch_text = text_embedding.unsqueeze(0).repeat(batch_size_actual, 1).to(model.device)
|
||||
|
||||
# Get predictions for all frames in batch
|
||||
progress_preds = model.rewind_transformer(batch_video, batch_text) # (batch, max_length, 1)
|
||||
progress_preds = progress_preds.squeeze(-1) # (batch, max_length)
|
||||
|
||||
# Extract predictions using the last_index_mask
|
||||
# This gets the prediction for the last valid frame in each slice
|
||||
batch_progress = progress_preds[batch_mask].cpu().numpy()
|
||||
all_progress.extend(batch_progress)
|
||||
|
||||
predictions = np.array(all_progress)
|
||||
|
||||
# Compute expected progress based on original ReWiND normalization
|
||||
# When starting from frame 0, remaining_length = total_episode_length
|
||||
# Expected progress for frame i = (i + 1) / total_frames
|
||||
expected_progress = np.arange(1, total_frames + 1, dtype=np.float32) / total_frames
|
||||
|
||||
logger.info(f"Inference complete. Predicted progress range: [{predictions.min():.3f}, {predictions.max():.3f}]")
|
||||
logger.info(f"Expected progress range: [{expected_progress.min():.3f}, {expected_progress.max():.3f}]")
|
||||
|
||||
return predictions, expected_progress
|
||||
|
||||
|
||||
def visualize_predictions(
|
||||
frames: np.ndarray,
|
||||
predictions: np.ndarray,
|
||||
expected_progress: np.ndarray,
|
||||
task_description: str,
|
||||
output_path: Path,
|
||||
show_frames: bool = False,
|
||||
num_sample_frames: int = 8,
|
||||
figsize: tuple = (12, 6)
|
||||
):
|
||||
"""
|
||||
Create visualization of ReWiND predictions with expected progress comparison.
|
||||
|
||||
Args:
|
||||
frames: Video frames (num_frames, H, W, C)
|
||||
predictions: Model progress predictions (num_frames,)
|
||||
expected_progress: Expected progress based on frame position (num_frames,)
|
||||
task_description: Task description
|
||||
output_path: Path to save the figure
|
||||
show_frames: Whether to include sample frames
|
||||
num_sample_frames: Number of frames to show
|
||||
figsize: Figure size (width, height)
|
||||
"""
|
||||
if show_frames:
|
||||
# Create figure with progress plot and sample frames
|
||||
fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
|
||||
gs = gridspec.GridSpec(2, 1, height_ratios=[2, 1], hspace=0.3)
|
||||
|
||||
# Progress plot
|
||||
ax_progress = fig.add_subplot(gs[0])
|
||||
else:
|
||||
# Just progress plot
|
||||
fig, ax_progress = plt.subplots(1, 1, figsize=figsize)
|
||||
|
||||
# Plot progress over time
|
||||
frame_indices = np.arange(len(predictions))
|
||||
|
||||
# Plot expected progress (ground truth)
|
||||
ax_progress.plot(frame_indices, expected_progress, linewidth=2, color='#A8DADC',
|
||||
linestyle='--', label='Expected Progress (Linear)', alpha=0.7)
|
||||
|
||||
# Plot model predictions
|
||||
ax_progress.plot(frame_indices, predictions, linewidth=2.5, color='#2E86AB',
|
||||
label='Model Predictions')
|
||||
ax_progress.fill_between(frame_indices, 0, predictions, alpha=0.2, color='#2E86AB')
|
||||
|
||||
# Add reference line at 1.0
|
||||
ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
|
||||
|
||||
# Styling
|
||||
ax_progress.set_xlabel('Frame Index', fontsize=12)
|
||||
ax_progress.set_ylabel('Task Progress', fontsize=12)
|
||||
ax_progress.set_title(f'ReWiND Task Progress Prediction\nTask: "{task_description}"',
|
||||
fontsize=14, fontweight='bold')
|
||||
ax_progress.grid(True, alpha=0.3)
|
||||
ax_progress.set_ylim(-0.05, 1.1)
|
||||
ax_progress.legend(loc='upper left')
|
||||
|
||||
# Compute alignment metrics
|
||||
mae = np.mean(np.abs(predictions - expected_progress))
|
||||
rmse = np.sqrt(np.mean((predictions - expected_progress) ** 2))
|
||||
|
||||
# Add statistics box
|
||||
stats_text = (
|
||||
f'Frames: {len(predictions)}\n'
|
||||
f'Model Final: {predictions[-1]:.3f}\n'
|
||||
f'Model Max: {predictions.max():.3f}\n'
|
||||
f'Model Mean: {predictions.mean():.3f}\n'
|
||||
f'MAE: {mae:.3f}\n'
|
||||
f'RMSE: {rmse:.3f}'
|
||||
)
|
||||
ax_progress.text(0.98, 0.02, stats_text, transform=ax_progress.transAxes,
|
||||
fontsize=10, verticalalignment='bottom', horizontalalignment='right',
|
||||
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
|
||||
|
||||
# Show sample frames if requested
|
||||
if show_frames:
|
||||
# Select evenly spaced frames
|
||||
frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int)
|
||||
|
||||
# Create subplot for frames
|
||||
ax_frames = fig.add_subplot(gs[1])
|
||||
ax_frames.axis('off')
|
||||
|
||||
# Create grid for frames
|
||||
frame_height = frames[0].shape[0]
|
||||
frame_width = frames[0].shape[1]
|
||||
|
||||
combined_width = frame_width * num_sample_frames
|
||||
combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8)
|
||||
|
||||
for i, frame_idx in enumerate(frame_indices_to_show):
|
||||
frame = frames[frame_idx]
|
||||
if frame.shape[-1] == 1:
|
||||
frame = np.repeat(frame, 3, axis=-1)
|
||||
|
||||
# Add frame to combined image
|
||||
x_start = i * frame_width
|
||||
x_end = (i + 1) * frame_width
|
||||
combined_image[:, x_start:x_end] = frame
|
||||
|
||||
# Add frame number and progress value
|
||||
progress_val = predictions[frame_idx]
|
||||
label = f'Frame {frame_idx}\nProgress: {progress_val:.3f}'
|
||||
|
||||
# Draw label on image
|
||||
ax_frames.text(x_start + frame_width / 2, -10, label,
|
||||
ha='center', va='top', fontsize=8,
|
||||
bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
|
||||
|
||||
ax_frames.imshow(combined_image)
|
||||
ax_frames.set_title('Sample Frames', fontsize=12, pad=20)
|
||||
|
||||
# Save figure
|
||||
plt.tight_layout()
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
plt.savefig(output_path, dpi=150, bbox_inches='tight')
|
||||
logger.info(f"Saved visualization to {output_path}")
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Setup device
|
||||
if args.device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
else:
|
||||
device = args.device
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Load model
|
||||
logger.info(f"Loading ReWiND model from {args.model_id}...")
|
||||
model = ReWiNDRewardModel.from_pretrained(args.model_id)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
logger.info("Model loaded successfully")
|
||||
|
||||
# Load dataset
|
||||
logger.info(f"Loading dataset {args.dataset_repo}...")
|
||||
dataset = LeRobotDataset(args.dataset_repo)
|
||||
logger.info(f"Dataset loaded: {len(dataset.meta.episodes)} episodes, {len(dataset)} frames")
|
||||
|
||||
# Validate episode index
|
||||
if args.episode_index >= len(dataset.meta.episodes):
|
||||
raise ValueError(
|
||||
f"Episode index {args.episode_index} out of range. "
|
||||
f"Dataset has {len(dataset.meta.episodes)} episodes."
|
||||
)
|
||||
|
||||
# Determine which image key to use
|
||||
image_key = args.image_key if args.image_key is not None else model.config.image_key
|
||||
logger.info(f"Using image key: {image_key}")
|
||||
|
||||
# Load episode data (this also extracts the task description from the episode)
|
||||
frames, start_idx, end_idx, dataset_task = load_episode_data(dataset, args.episode_index, image_key)
|
||||
|
||||
# Use task description from dataset if available, otherwise use command-line argument
|
||||
task_description = dataset_task if dataset_task is not None else args.task_description
|
||||
logger.info(f"Using task description: '{task_description}'")
|
||||
|
||||
# Run inference
|
||||
predictions, expected_progress = run_inference(model, frames, task_description)
|
||||
|
||||
# Create visualization
|
||||
output_dir = Path(args.output_dir)
|
||||
output_path = output_dir / f"rewind_prediction_ep{args.episode_index}.png"
|
||||
|
||||
visualize_predictions(
|
||||
frames,
|
||||
predictions,
|
||||
expected_progress,
|
||||
task_description,
|
||||
output_path,
|
||||
show_frames=args.show_frames,
|
||||
num_sample_frames=args.num_sample_frames,
|
||||
figsize=tuple(args.figsize)
|
||||
)
|
||||
|
||||
# Save predictions and expected progress as numpy arrays
|
||||
predictions_path = output_dir / f"predictions_ep{args.episode_index}.npy"
|
||||
expected_path = output_dir / f"expected_progress_ep{args.episode_index}.npy"
|
||||
np.save(predictions_path, predictions)
|
||||
np.save(expected_path, expected_progress)
|
||||
logger.info(f"Saved predictions array to {predictions_path}")
|
||||
logger.info(f"Saved expected progress to {expected_path}")
|
||||
|
||||
# Compute alignment metrics
|
||||
mae = np.mean(np.abs(predictions - expected_progress))
|
||||
rmse = np.sqrt(np.mean((predictions - expected_progress) ** 2))
|
||||
correlation = np.corrcoef(predictions, expected_progress)[0, 1]
|
||||
|
||||
# Print summary
|
||||
logger.info("\n" + "="*60)
|
||||
logger.info("INFERENCE SUMMARY")
|
||||
logger.info("="*60)
|
||||
logger.info(f"Model: {args.model_id}")
|
||||
logger.info(f"Dataset: {args.dataset_repo}")
|
||||
logger.info(f"Episode: {args.episode_index}")
|
||||
logger.info(f"Task: {task_description}")
|
||||
logger.info(f"Frames: {len(frames)}")
|
||||
logger.info(f"\nModel Predictions:")
|
||||
logger.info(f" Final: {predictions[-1]:.3f}")
|
||||
logger.info(f" Max: {predictions.max():.3f}")
|
||||
logger.info(f" Mean: {predictions.mean():.3f}")
|
||||
logger.info(f" Std: {predictions.std():.3f}")
|
||||
logger.info(f"\nExpected Progress (Linear):")
|
||||
logger.info(f" Final: {expected_progress[-1]:.3f}")
|
||||
logger.info(f" Mean: {expected_progress.mean():.3f}")
|
||||
logger.info(f"\nAlignment Metrics:")
|
||||
logger.info(f" MAE: {mae:.3f}")
|
||||
logger.info(f" RMSE: {rmse:.3f}")
|
||||
logger.info(f" Correlation: {correlation:.3f}")
|
||||
logger.info(f"\nOutput:")
|
||||
logger.info(f" Visualization: {output_path}")
|
||||
logger.info("="*60)
|
||||
|
||||
# Diagnostic warnings
|
||||
if predictions.std() < 0.05:
|
||||
logger.warning("\n⚠ WARNING: Mode collapse detected (std < 0.05)")
|
||||
logger.warning(" Model predictions show very low variance.")
|
||||
logger.warning(" This indicates the model was likely trained with incorrect")
|
||||
logger.warning(" progress normalization (absolute indices instead of remaining length).")
|
||||
elif mae > 0.3:
|
||||
logger.warning("\n⚠ WARNING: High prediction error (MAE > 0.3)")
|
||||
logger.warning(" Model predictions deviate significantly from expected linear progress.")
|
||||
logger.warning(" Consider retraining with correct progress normalization.")
|
||||
elif correlation < 0.5:
|
||||
logger.warning("\n⚠ WARNING: Low correlation with expected progress (< 0.5)")
|
||||
logger.warning(" Model predictions don't align well with linear task progression.")
|
||||
else:
|
||||
logger.info("\n✓ Model predictions show healthy progression!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -244,7 +244,7 @@ def run_inference(
|
||||
logger.info("Encoding video frames with CLIP...")
|
||||
video_embeddings = model.encode_images(frames)
|
||||
|
||||
logger.info("Encoding task description with MiniLM...")
|
||||
logger.info("Encoding task description with CLIP...")
|
||||
text_embedding = model.encode_text(task_description)
|
||||
|
||||
# Get config values
|
||||
|
||||
@@ -1,128 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
ReWiND Sampler for temporal sequence loading.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Iterator, Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Sampler
|
||||
import random
|
||||
|
||||
|
||||
class ReWiNDTemporalSampler(Sampler):
|
||||
"""
|
||||
Sampler for ReWiND that samples random temporal windows from episodes.
|
||||
|
||||
Matches original ReWiND sampling:
|
||||
- Samples random start and end points within episodes
|
||||
- Minimum window size of 3 frames
|
||||
- Can sample from beginning, middle, or end of episodes
|
||||
|
||||
Args:
|
||||
dataset_from_index: Start indices of episodes
|
||||
dataset_to_index: End indices of episodes
|
||||
sequence_length: Maximum sequence length (for padding/subsampling)
|
||||
stride: Not used (kept for API compatibility)
|
||||
shuffle: Whether to shuffle sampling order
|
||||
seed: Random seed
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_from_index: np.ndarray,
|
||||
dataset_to_index: np.ndarray,
|
||||
sequence_length: int = 32,
|
||||
stride: int = 1,
|
||||
shuffle: bool = True,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
self.dataset_from_index = np.array(dataset_from_index)
|
||||
self.dataset_to_index = np.array(dataset_to_index)
|
||||
self.sequence_length = sequence_length
|
||||
self.shuffle = shuffle
|
||||
|
||||
if seed is not None:
|
||||
self.seed = seed
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
self.generator = torch.Generator().manual_seed(seed)
|
||||
else:
|
||||
self.generator = torch.Generator()
|
||||
|
||||
# Compute valid episodes (those with at least 3 frames)
|
||||
self._compute_valid_episodes()
|
||||
|
||||
# Number of samples per epoch (matching original ReWiND)
|
||||
self.samples_per_epoch = 100 * 64 # 100 batches of 64
|
||||
|
||||
logging.info(
|
||||
f"ReWiNDTemporalSampler: {len(self.valid_episodes)} valid episodes, "
|
||||
f"{self.samples_per_epoch} samples per epoch"
|
||||
)
|
||||
|
||||
def _compute_valid_episodes(self):
|
||||
"""Compute valid episodes (those with at least 3 frames)."""
|
||||
self.valid_episodes = []
|
||||
|
||||
for ep_idx in range(len(self.dataset_from_index)):
|
||||
ep_start = self.dataset_from_index[ep_idx]
|
||||
ep_end = self.dataset_to_index[ep_idx]
|
||||
episode_length = ep_end - ep_start
|
||||
|
||||
if episode_length >= 3: # Minimum 3 frames
|
||||
self.valid_episodes.append((ep_idx, ep_start, ep_end))
|
||||
|
||||
self.valid_episodes = np.array(self.valid_episodes)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.samples_per_epoch
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
"""
|
||||
Yields ONE index per sample (the end of a random window).
|
||||
|
||||
Matches original ReWiND behavior:
|
||||
1. Pick random episode
|
||||
2. Pick random end frame (at least 3 frames from start)
|
||||
3. Yield that end frame index
|
||||
4. Dataset/processor loads from episode start to this end frame
|
||||
5. Model pads/subsamples to sequence_length (32)
|
||||
|
||||
This allows sampling from anywhere in episodes:
|
||||
- Early frames → short sequences (mostly padding) → low progress
|
||||
- Middle frames → medium sequences (some subsampling) → medium progress
|
||||
- End frames → long sequences (full subsampling) → high progress approaching 1.0
|
||||
"""
|
||||
for _ in range(self.samples_per_epoch):
|
||||
# Randomly select an episode
|
||||
ep_idx, ep_start, ep_end = self.valid_episodes[
|
||||
np.random.randint(0, len(self.valid_episodes))
|
||||
]
|
||||
|
||||
episode_length = ep_end - ep_start
|
||||
|
||||
# Sample a random end point (must be at least 3 frames from start)
|
||||
# This matches original: random.randint(start_idx+3, len(progress_dataset))
|
||||
end_offset = np.random.randint(3, episode_length + 1)
|
||||
end_idx = ep_start + end_offset
|
||||
|
||||
# Yield ONLY the end index
|
||||
# The dataset will load all frames from ep_start to end_idx
|
||||
yield int(end_idx - 1) # -1 because end_idx is exclusive
|
||||
@@ -15,12 +15,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Temporal Sequence Sampler for reward models and temporal policies.
|
||||
SARM Temporal Sampler for reward model training.
|
||||
|
||||
Supports multiple sampling modes:
|
||||
- "rewind": ReWiND-style sampling (random windows from episode start)
|
||||
- "sarm": SARM-style sampling (9-frame sequences with specific pattern)
|
||||
- "custom": Custom temporal sampling
|
||||
Samples frames from episodes ensuring sufficient temporal history for SARM's
|
||||
9-frame pattern (1 initial + 8 consecutive with frame_gap spacing).
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -31,24 +29,23 @@ from torch.utils.data import Sampler
|
||||
import random
|
||||
|
||||
|
||||
class TemporalSequenceSampler(Sampler):
|
||||
class SARMTemporalSampler(Sampler):
|
||||
"""
|
||||
Generalized temporal sampler for reward models.
|
||||
Temporal sampler for SARM reward model training.
|
||||
|
||||
Supports multiple sampling modes:
|
||||
- "rewind": Consecutive frames from episode start to random end point (ReWiND: 32 consecutive frames)
|
||||
- "sarm": 9-frame sequences with 1 initial + 8 consecutive (SARM)
|
||||
- "custom": Custom temporal sampling
|
||||
SARM uses 9 frames per sample:
|
||||
- Frame 0: Initial frame of the episode (always frame 0)
|
||||
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
|
||||
|
||||
This sampler ensures we only sample from positions that have enough
|
||||
temporal history (at least 7 * frame_gap frames from episode start).
|
||||
|
||||
Args:
|
||||
dataset_from_index: Start indices of episodes
|
||||
dataset_to_index: End indices of episodes
|
||||
sequence_length: Maximum sequence length (for padding/subsampling)
|
||||
stride: Frame stride for consecutive sampling (SARM mode)
|
||||
dataset_from_index: Start indices of episodes (global dataset indices)
|
||||
dataset_to_index: End indices of episodes (global dataset indices)
|
||||
frame_gap: Gap between consecutive frames (default: 30 = 1 second at 30fps)
|
||||
shuffle: Whether to shuffle sampling order
|
||||
seed: Random seed
|
||||
sampling_mode: Sampling mode ("rewind", "sarm", or "custom")
|
||||
min_frames: Minimum frames per episode (default: 3)
|
||||
seed: Random seed for reproducibility
|
||||
samples_per_epoch: Number of samples per epoch (default: 6400)
|
||||
"""
|
||||
|
||||
@@ -56,25 +53,21 @@ class TemporalSequenceSampler(Sampler):
|
||||
self,
|
||||
dataset_from_index: np.ndarray,
|
||||
dataset_to_index: np.ndarray,
|
||||
sequence_length: int = 32,
|
||||
stride: int = 1,
|
||||
frame_gap: int = 30,
|
||||
shuffle: bool = True,
|
||||
seed: Optional[int] = None,
|
||||
sampling_mode: str = "rewind",
|
||||
min_frames: int = 3,
|
||||
samples_per_epoch: int = 6400,
|
||||
):
|
||||
self.dataset_from_index = np.array(dataset_from_index)
|
||||
self.dataset_to_index = np.array(dataset_to_index)
|
||||
self.sequence_length = sequence_length
|
||||
self.stride = stride
|
||||
self.frame_gap = frame_gap
|
||||
self.shuffle = shuffle
|
||||
self.sampling_mode = sampling_mode
|
||||
self.min_frames = min_frames
|
||||
self.samples_per_epoch = samples_per_epoch
|
||||
|
||||
if sampling_mode not in ["rewind", "sarm", "custom"]:
|
||||
raise ValueError(f"sampling_mode must be 'rewind', 'sarm', or 'custom', got {sampling_mode}")
|
||||
# Minimum frames needed for SARM pattern:
|
||||
# 8 consecutive frames with frame_gap spacing = 7 * frame_gap + 1
|
||||
# (Plus the initial frame which is always available)
|
||||
self.min_frames_needed = 7 * frame_gap + 1
|
||||
|
||||
if seed is not None:
|
||||
self.seed = seed
|
||||
@@ -84,98 +77,68 @@ class TemporalSequenceSampler(Sampler):
|
||||
else:
|
||||
self.generator = torch.Generator()
|
||||
|
||||
# Compute valid episodes
|
||||
self._compute_valid_episodes()
|
||||
# Compute valid episodes and sampling positions
|
||||
self._compute_valid_positions()
|
||||
|
||||
logging.info(
|
||||
f"TemporalSequenceSampler ({sampling_mode} mode): "
|
||||
f"{len(self.valid_episodes)} valid episodes, "
|
||||
f"{self.samples_per_epoch} samples per epoch"
|
||||
f"SARMTemporalSampler: {len(self.valid_episodes)} valid episodes, "
|
||||
f"{len(self.all_valid_positions)} valid positions, "
|
||||
f"{self.samples_per_epoch} samples per epoch, "
|
||||
f"frame_gap={frame_gap}"
|
||||
)
|
||||
|
||||
def _compute_valid_episodes(self):
|
||||
"""Compute valid episodes based on minimum frame requirement."""
|
||||
def _compute_valid_positions(self):
|
||||
"""Compute valid episodes and all valid sampling positions."""
|
||||
self.valid_episodes = []
|
||||
self.all_valid_positions = []
|
||||
|
||||
for ep_idx in range(len(self.dataset_from_index)):
|
||||
ep_start = self.dataset_from_index[ep_idx]
|
||||
ep_end = self.dataset_to_index[ep_idx]
|
||||
episode_length = ep_end - ep_start
|
||||
|
||||
# For SARM mode, need enough frames for the sequence pattern
|
||||
if self.sampling_mode == "sarm":
|
||||
# Need at least sequence_length * stride frames
|
||||
min_required = self.sequence_length * self.stride
|
||||
if episode_length >= min_required:
|
||||
self.valid_episodes.append((ep_idx, ep_start, ep_end))
|
||||
else:
|
||||
# For rewind mode, use min_frames
|
||||
if episode_length >= self.min_frames:
|
||||
self.valid_episodes.append((ep_idx, ep_start, ep_end))
|
||||
# Episode must have enough frames for SARM pattern
|
||||
if episode_length >= self.min_frames_needed:
|
||||
self.valid_episodes.append((ep_idx, ep_start, ep_end))
|
||||
|
||||
# Valid positions: from min_frames_needed to episode end
|
||||
# These are global dataset indices
|
||||
for pos in range(ep_start + self.min_frames_needed - 1, ep_end):
|
||||
self.all_valid_positions.append(pos)
|
||||
|
||||
self.valid_episodes = np.array(self.valid_episodes)
|
||||
self.all_valid_positions = np.array(self.all_valid_positions)
|
||||
|
||||
if len(self.all_valid_positions) == 0:
|
||||
raise ValueError(
|
||||
f"No valid sampling positions found! "
|
||||
f"Episodes need at least {self.min_frames_needed} frames "
|
||||
f"(7 * frame_gap + 1 = 7 * {self.frame_gap} + 1)."
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.samples_per_epoch
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
"""
|
||||
Yields ONE index per sample.
|
||||
Yields global dataset indices for sampling.
|
||||
|
||||
Sampling behavior depends on mode:
|
||||
|
||||
ReWiND mode:
|
||||
1. Pick random episode
|
||||
2. Pick random end frame (at least min_frames from start)
|
||||
3. Yield that end frame index
|
||||
4. Dataset loads from episode start to this end frame
|
||||
|
||||
SARM mode:
|
||||
1. Pick random episode
|
||||
2. Pick random end frame (must allow sequence_length frames with stride)
|
||||
3. Yield that end frame index
|
||||
4. Dataset loads sequence_length frames with stride spacing ending at this frame
|
||||
Each yielded index represents the "current frame" position.
|
||||
The dataset's observation_delta_indices then handles loading:
|
||||
- Frame 0: Episode initial frame (via large negative delta clamping)
|
||||
- Frames 1-8: Consecutive frames ending at the yielded index
|
||||
"""
|
||||
for _ in range(self.samples_per_epoch):
|
||||
# Randomly select an episode
|
||||
ep_idx, ep_start, ep_end = self.valid_episodes[
|
||||
np.random.randint(0, len(self.valid_episodes))
|
||||
]
|
||||
|
||||
episode_length = ep_end - ep_start
|
||||
|
||||
if self.sampling_mode == "rewind":
|
||||
# ReWiND: Sample random end point (at least min_frames from start)
|
||||
end_offset = np.random.randint(self.min_frames, episode_length + 1)
|
||||
end_idx = ep_start + end_offset
|
||||
|
||||
# Yield the end index (dataset will load from start to this point)
|
||||
yield int(end_idx - 1) # -1 because end_idx is exclusive
|
||||
|
||||
elif self.sampling_mode == "sarm":
|
||||
# SARM: Sample end point that allows full sequence
|
||||
# We need sequence_length frames with stride spacing
|
||||
min_end_offset = self.sequence_length * self.stride
|
||||
|
||||
if episode_length >= min_end_offset:
|
||||
# Can sample anywhere from min_end_offset to episode_length
|
||||
end_offset = np.random.randint(min_end_offset, episode_length + 1)
|
||||
else:
|
||||
# Episode is exactly the minimum length
|
||||
end_offset = episode_length
|
||||
|
||||
end_idx = ep_start + end_offset
|
||||
|
||||
# Yield the end index (dataset will load sequence with stride)
|
||||
yield int(end_idx - 1) # -1 because end_idx is exclusive
|
||||
|
||||
else: # custom mode
|
||||
# Default to rewind-style sampling
|
||||
end_offset = np.random.randint(self.min_frames, episode_length + 1)
|
||||
end_idx = ep_start + end_offset
|
||||
yield int(end_idx - 1)
|
||||
if self.shuffle:
|
||||
# Randomly sample from all valid positions
|
||||
for _ in range(self.samples_per_epoch):
|
||||
idx = np.random.randint(0, len(self.all_valid_positions))
|
||||
yield int(self.all_valid_positions[idx])
|
||||
else:
|
||||
# Sequential sampling with wrap-around
|
||||
for i in range(self.samples_per_epoch):
|
||||
idx = i % len(self.all_valid_positions)
|
||||
yield int(self.all_valid_positions[idx])
|
||||
|
||||
|
||||
# Backwards compatibility alias
|
||||
ReWiNDTemporalSampler = TemporalSequenceSampler
|
||||
|
||||
TemporalSequenceSampler = SARMTemporalSampler
|
||||
|
||||
@@ -1,383 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Video sampling utilities for temporal data augmentation and frame selection.
|
||||
|
||||
This module provides utilities for sampling and augmenting video sequences, particularly
|
||||
for reward model training. It includes functions for:
|
||||
- Padding/sampling videos to fixed lengths
|
||||
- Video rewind augmentation for learning to decrease rewards
|
||||
"""
|
||||
|
||||
import random
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def sample_video_feature(
|
||||
video_feature: torch.Tensor,
|
||||
max_length: int = 32,
|
||||
random_sample: bool = True,
|
||||
remaining_length: int = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Sample or pad video features to a fixed length with progress targets.
|
||||
|
||||
Progress normalization matches original ReWiND implementation:
|
||||
- Progress = (position_in_sequence + 1) / remaining_trajectory_length
|
||||
- remaining_trajectory_length = frames from first sampled frame to episode end
|
||||
|
||||
Original ReWiND logic (dataset.py lines 12493-12499):
|
||||
video_frames = frames[start_idx:end_idx]
|
||||
full_frames = frames[start_idx:] # All frames from start to episode end
|
||||
progress = [1, 2, ..., len(video_frames)] / len(full_frames)
|
||||
|
||||
This ensures all sequences show increasing progress from near-zero, regardless
|
||||
of where they're sampled from in the episode.
|
||||
|
||||
Uses original ReWiND sampling: random start/end points with minimum 3 frames.
|
||||
|
||||
Args:
|
||||
video_feature: Video features tensor (num_frames, feature_dim)
|
||||
max_length: Target sequence length
|
||||
random_sample: If True, randomly sample frames. If False, uniformly sample consecutive frames.
|
||||
remaining_length: Remaining trajectory length from first frame to episode end
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- Sampled/padded video features (max_length, feature_dim)
|
||||
- Progress targets for each frame (max_length,)
|
||||
"""
|
||||
video_length = len(video_feature)
|
||||
|
||||
# Original ReWiND sampling: random start/end with minimum 3 frames
|
||||
if video_length > 3:
|
||||
# Sample random start index (ensuring we can get at least 3 frames)
|
||||
start_idx = random.randint(0, max(0, video_length - 3))
|
||||
# Sample random end index (at least 3 frames after start, up to video_length)
|
||||
end_idx = random.randint(min(start_idx + 3, video_length), video_length)
|
||||
|
||||
# Extract the sampled segment
|
||||
video_feature = video_feature[start_idx:end_idx]
|
||||
|
||||
# Update video_length for the sampled segment
|
||||
video_length = len(video_feature)
|
||||
|
||||
# Adjust remaining_length to be from start_idx to episode end
|
||||
if remaining_length is not None:
|
||||
# The remaining length should be from start_idx to episode end
|
||||
# If we started at start_idx, we've already consumed start_idx frames
|
||||
remaining_length = remaining_length - start_idx if remaining_length > start_idx else video_length
|
||||
|
||||
# Generate progress targets using ORIGINAL ReWiND formula
|
||||
# Progress = (position_in_sequence + 1) / remaining_trajectory_length
|
||||
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
|
||||
progress_targets = progress_indices / remaining_length
|
||||
|
||||
if video_length < max_length:
|
||||
# Pad with last frame
|
||||
padding_length = max_length - video_length
|
||||
last_frame = video_feature[-1].unsqueeze(0)
|
||||
padding_frames = last_frame.repeat(padding_length, 1)
|
||||
video_feature = torch.cat([video_feature, padding_frames], dim=0)
|
||||
|
||||
# Pad progress with last progress value
|
||||
last_progress = progress_targets[-1]
|
||||
padding_progress = torch.full((padding_length,), last_progress)
|
||||
progress_targets = torch.cat([progress_targets, padding_progress])
|
||||
|
||||
elif video_length > max_length:
|
||||
if random_sample:
|
||||
# Random sampling (maintains temporal order via sorted indices)
|
||||
frame_idx = sorted(random.sample(range(video_length), max_length))
|
||||
else:
|
||||
# Uniform sampling (consecutive frames with even spacing)
|
||||
frame_idx = np.linspace(0, video_length - 1, max_length, dtype=int)
|
||||
video_feature = video_feature[frame_idx]
|
||||
progress_targets = progress_targets[frame_idx]
|
||||
|
||||
return video_feature, progress_targets
|
||||
|
||||
|
||||
def sample_reverse_video_feature(
|
||||
video_feature: torch.Tensor,
|
||||
max_length: int = 32,
|
||||
random_sample: bool = True,
|
||||
remaining_length: int = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Sample video with reverse augmentation (video rewind) - ORIGINAL REWIND LOGIC.
|
||||
|
||||
This implements the EXACT video rewind augmentation from the original ReWiND paper:
|
||||
1. Take forward sequence (sampled with random start/end, min 3 frames)
|
||||
2. Append reversed frames from the END backwards
|
||||
3. Progress increases then decreases (simulating task completion then failure)
|
||||
|
||||
Progress normalization matches original ReWiND (same as sample_video_feature).
|
||||
Original ReWiND logic (dataset.py lines 12526-12541):
|
||||
progress = [1, 2, ..., len(video_frames)] / len(full_frames)
|
||||
reverse_progress = progress[::-1][1:selected_end_point]
|
||||
|
||||
Args:
|
||||
video_feature: Video features tensor (num_frames, feature_dim)
|
||||
max_length: Target sequence length
|
||||
random_sample: If True, use random sampling for frame selection
|
||||
remaining_length: Remaining trajectory length from first frame to episode end
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- Rewound video features (max_length, feature_dim)
|
||||
- Progress targets for each frame (max_length,)
|
||||
"""
|
||||
video_length = len(video_feature)
|
||||
|
||||
# Original logic: start from first half, end in second half, ensure min 3 frames
|
||||
if video_length > 3:
|
||||
# Sample start from first half
|
||||
start_idx = random.randint(0, video_length // 2)
|
||||
# Sample end from second half
|
||||
end_idx = random.randint(video_length // 2, video_length)
|
||||
|
||||
# Ensure minimum 3 frames difference (original uses while loop)
|
||||
while end_idx - start_idx < 3:
|
||||
start_idx = random.randint(0, video_length // 2)
|
||||
end_idx = random.randint(video_length // 2, video_length)
|
||||
|
||||
# Extract the forward segment
|
||||
video_feature = video_feature[start_idx:end_idx]
|
||||
video_length = len(video_feature)
|
||||
|
||||
# Adjust remaining_length
|
||||
if remaining_length is not None:
|
||||
remaining_length = remaining_length - start_idx if remaining_length > start_idx else video_length
|
||||
|
||||
# Generate forward progress targets using ORIGINAL ReWiND formula
|
||||
# Progress = (position_in_sequence + 1) / remaining_trajectory_length
|
||||
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
|
||||
forward_progress = progress_indices / remaining_length
|
||||
|
||||
# ORIGINAL LOGIC: Reverse from END backwards, then append to forward sequence
|
||||
# Example: video=[A,B,C,D,E] -> reversed=[E,D,C,B,A] -> take some from reversed (skip first)
|
||||
# Result: [A,B,C,D,E] + [D,C,B] = progress increases then decreases
|
||||
|
||||
# Randomly select how many frames to reverse and append
|
||||
selected_end_point = random.randint(2, min(video_length, max_length))
|
||||
|
||||
# Reverse the entire video and its progress
|
||||
reversed_video = video_feature.flip(0)
|
||||
reversed_progress = forward_progress.flip(0)
|
||||
|
||||
# Take frames from reversed (skip the first frame which is the last frame of original)
|
||||
reverse_frames = reversed_video[1:selected_end_point]
|
||||
reverse_progress = reversed_progress[1:selected_end_point]
|
||||
|
||||
# Concatenate forward + reversed (creates rewind effect)
|
||||
rewound_video = torch.cat([video_feature, reverse_frames], dim=0)
|
||||
progress_targets = torch.cat([forward_progress, reverse_progress], dim=0)
|
||||
|
||||
# Pad or sample to target length
|
||||
if len(rewound_video) < max_length:
|
||||
# Pad with last frame
|
||||
padding_length = max_length - len(rewound_video)
|
||||
last_frame = rewound_video[-1].unsqueeze(0)
|
||||
padding_frames = last_frame.repeat(padding_length, 1)
|
||||
rewound_video = torch.cat([rewound_video, padding_frames], dim=0)
|
||||
|
||||
# Pad progress with last progress value
|
||||
last_progress = progress_targets[-1]
|
||||
padding_progress = torch.full((padding_length,), last_progress)
|
||||
progress_targets = torch.cat([progress_targets, padding_progress])
|
||||
|
||||
elif len(rewound_video) > max_length:
|
||||
# Sample frames to fit max_length
|
||||
if random_sample:
|
||||
frame_idx = sorted(random.sample(range(len(rewound_video)), max_length))
|
||||
else:
|
||||
frame_idx = np.linspace(0, len(rewound_video) - 1, max_length, dtype=int)
|
||||
rewound_video = rewound_video[frame_idx]
|
||||
progress_targets = progress_targets[frame_idx]
|
||||
|
||||
return rewound_video, progress_targets
|
||||
|
||||
|
||||
def sample_sarm_video_feature(
|
||||
video_feature: torch.Tensor,
|
||||
num_frames: int = 9,
|
||||
frame_gap: int = 30,
|
||||
random_sample: bool = True,
|
||||
absolute_indices: torch.Tensor = None,
|
||||
episode_length: int = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Sample video features for SARM (Stage-Aware Reward Modeling).
|
||||
|
||||
SARM uses a specific pattern:
|
||||
- 1 initial frame (from episode start)
|
||||
- 8 consecutive frames with frame_gap spacing
|
||||
|
||||
Progress normalization matches SARM implementation:
|
||||
- Progress = absolute_frame_index / total_episode_length
|
||||
|
||||
Args:
|
||||
video_feature: Video features tensor (num_frames_available, feature_dim)
|
||||
num_frames: Target number of frames (default: 9)
|
||||
frame_gap: Gap between consecutive frames (default: 30, i.e., 1 second at 30fps)
|
||||
random_sample: If True, use random sampling (not used for SARM's fixed pattern)
|
||||
absolute_indices: Absolute frame indices in the episode (num_frames_available,)
|
||||
episode_length: Total length of the episode
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- Sampled video features (num_frames, feature_dim)
|
||||
- Progress targets for each frame (num_frames,)
|
||||
"""
|
||||
video_length = len(video_feature)
|
||||
|
||||
# Generate progress targets based on relative position within sampled sequence
|
||||
# Note: SARM paper uses subtask annotations (Equation 2: yt = Pk−1 + ᾱk * τt)
|
||||
# Without annotations, we use linear progress relative to sequence position
|
||||
if absolute_indices is not None and episode_length is not None:
|
||||
# Compute relative progress: position within sequence / remaining trajectory
|
||||
# This ensures progress starts near 0 and increases, not starting at 0.8 if sampled from end
|
||||
first_frame_idx = absolute_indices[0].item() if isinstance(absolute_indices[0], torch.Tensor) else absolute_indices[0]
|
||||
remaining_length = episode_length - first_frame_idx
|
||||
|
||||
# Progress = (position_in_sequence + 1) / remaining_trajectory_length
|
||||
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
|
||||
progress_targets = progress_indices / remaining_length
|
||||
else:
|
||||
# Fallback: linear progress
|
||||
progress_targets = torch.linspace(1.0/video_length, 1.0, video_length)
|
||||
|
||||
# SARM pattern: first frame + (num_frames-1) consecutive frames with frame_gap
|
||||
# The first frame should be from the beginning of the sequence
|
||||
# The remaining frames are sampled with frame_gap spacing
|
||||
|
||||
if video_length < num_frames:
|
||||
# Not enough frames, pad with last frame
|
||||
sampled_video = video_feature
|
||||
sampled_progress = progress_targets
|
||||
|
||||
padding_length = num_frames - video_length
|
||||
last_frame = sampled_video[-1].unsqueeze(0)
|
||||
padding_frames = last_frame.repeat(padding_length, 1)
|
||||
sampled_video = torch.cat([sampled_video, padding_frames], dim=0)
|
||||
|
||||
last_progress = sampled_progress[-1]
|
||||
padding_progress = torch.full((padding_length,), last_progress)
|
||||
sampled_progress = torch.cat([sampled_progress, padding_progress])
|
||||
|
||||
else:
|
||||
# Sample frames: first frame + (num_frames-1) with frame_gap
|
||||
# The indices should represent: [0, gap, 2*gap, 3*gap, ..., (num_frames-1)*gap]
|
||||
# But we need to ensure we don't exceed video_length
|
||||
|
||||
frame_indices = [0] # First frame
|
||||
for i in range(1, num_frames):
|
||||
idx = i * frame_gap
|
||||
if idx >= video_length:
|
||||
idx = video_length - 1
|
||||
frame_indices.append(idx)
|
||||
|
||||
frame_indices = torch.tensor(frame_indices, dtype=torch.long)
|
||||
sampled_video = video_feature[frame_indices]
|
||||
sampled_progress = progress_targets[frame_indices]
|
||||
|
||||
return sampled_video, sampled_progress
|
||||
|
||||
|
||||
def sample_sarm_reverse_video_feature(
|
||||
video_feature: torch.Tensor,
|
||||
num_frames: int = 9,
|
||||
frame_gap: int = 30,
|
||||
random_sample: bool = True,
|
||||
absolute_indices: torch.Tensor = None,
|
||||
episode_length: int = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Sample video with reverse augmentation for SARM (rewind augmentation).
|
||||
|
||||
Similar to ReWiND's rewind augmentation but adapted for SARM's frame pattern:
|
||||
1. Take forward sequence (1 initial + 8 consecutive)
|
||||
2. Append some reversed frames from the end backwards
|
||||
3. Progress increases then decreases
|
||||
|
||||
Args:
|
||||
video_feature: Video features tensor (num_frames_available, feature_dim)
|
||||
num_frames: Target number of frames (default: 9)
|
||||
frame_gap: Gap between consecutive frames (default: 30)
|
||||
random_sample: If True, use random sampling for reverse section
|
||||
absolute_indices: Absolute frame indices in the episode
|
||||
episode_length: Total length of the episode
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- Rewound video features (num_frames, feature_dim)
|
||||
- Progress targets for each frame (num_frames,)
|
||||
"""
|
||||
video_length = len(video_feature)
|
||||
|
||||
# Generate forward progress targets (relative to sequence, not absolute)
|
||||
if absolute_indices is not None and episode_length is not None:
|
||||
# Use same relative progress as normal sampling
|
||||
first_frame_idx = absolute_indices[0].item() if isinstance(absolute_indices[0], torch.Tensor) else absolute_indices[0]
|
||||
remaining_length = episode_length - first_frame_idx
|
||||
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
|
||||
forward_progress = progress_indices / remaining_length
|
||||
else:
|
||||
forward_progress = torch.linspace(1.0/video_length, 1.0, video_length)
|
||||
|
||||
# Sample forward sequence first
|
||||
forward_video, forward_progress_sampled = sample_sarm_video_feature(
|
||||
video_feature, num_frames, frame_gap, random_sample, absolute_indices, episode_length
|
||||
)
|
||||
|
||||
# Randomly select how many frames to reverse and append
|
||||
# For SARM, we append 2-4 reversed frames
|
||||
num_reverse = random.randint(2, min(4, num_frames - 1))
|
||||
|
||||
# Reverse the video and progress
|
||||
reversed_video = video_feature.flip(0)
|
||||
reversed_progress = forward_progress.flip(0)
|
||||
|
||||
# Take frames from reversed (skip the first frame which is the last frame of original)
|
||||
reverse_frames = reversed_video[1:num_reverse+1]
|
||||
reverse_progress = reversed_progress[1:num_reverse+1]
|
||||
|
||||
# Concatenate forward + reversed (creates rewind effect)
|
||||
rewound_video = torch.cat([forward_video, reverse_frames], dim=0)
|
||||
progress_targets = torch.cat([forward_progress_sampled, reverse_progress], dim=0)
|
||||
|
||||
# Trim to num_frames if necessary
|
||||
if len(rewound_video) > num_frames:
|
||||
# Keep the first num_frames
|
||||
rewound_video = rewound_video[:num_frames]
|
||||
progress_targets = progress_targets[:num_frames]
|
||||
elif len(rewound_video) < num_frames:
|
||||
# Pad if necessary
|
||||
padding_length = num_frames - len(rewound_video)
|
||||
last_frame = rewound_video[-1].unsqueeze(0)
|
||||
padding_frames = last_frame.repeat(padding_length, 1)
|
||||
rewound_video = torch.cat([rewound_video, padding_frames], dim=0)
|
||||
|
||||
last_progress = progress_targets[-1]
|
||||
padding_progress = torch.full((padding_length,), last_progress)
|
||||
progress_targets = torch.cat([progress_targets, padding_progress])
|
||||
|
||||
return rewound_video, progress_targets
|
||||
|
||||
@@ -132,41 +132,6 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("cosine_with_min_lr")
|
||||
@dataclass
|
||||
class CosineWithMinLRSchedulerConfig(LRSchedulerConfig):
|
||||
"""Cosine learning rate scheduler with minimum learning rate floor.
|
||||
|
||||
Used by ReWiND for reward model training. Includes linear warmup phase
|
||||
followed by cosine annealing with a minimum learning rate.
|
||||
"""
|
||||
|
||||
num_warmup_steps: int
|
||||
min_lr: float = 0.0
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
def lr_lambda(current_step):
|
||||
# Get base learning rate from optimizer
|
||||
base_lr = optimizer.param_groups[0]['lr']
|
||||
|
||||
if current_step <= self.num_warmup_steps:
|
||||
# Linear warmup
|
||||
if self.num_warmup_steps == 0:
|
||||
return 1.0
|
||||
return float(current_step) / float(max(1, self.num_warmup_steps))
|
||||
else:
|
||||
# Cosine annealing with minimum learning rate
|
||||
progress = (current_step - self.num_warmup_steps) / float(
|
||||
max(1, num_training_steps - self.num_warmup_steps)
|
||||
)
|
||||
cosine_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
|
||||
# Scale between min_lr and base_lr
|
||||
min_factor = self.min_lr / base_lr if base_lr > 0 else 0.0
|
||||
return min_factor + (1.0 - min_factor) * cosine_factor
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
|
||||
state_dict = scheduler.state_dict()
|
||||
write_json(state_dict, save_dir / SCHEDULER_STATE)
|
||||
|
||||
@@ -34,7 +34,6 @@ from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rewind.configuration_rewind import ReWiNDConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
@@ -105,10 +104,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "rewind":
|
||||
from lerobot.policies.rewind.modeling_rewind import ReWiNDRewardModel
|
||||
|
||||
return ReWiNDRewardModel
|
||||
elif name == "sarm":
|
||||
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
@@ -332,15 +327,6 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, ReWiNDConfig):
|
||||
from lerobot.policies.rewind.processor_rewind import make_rewind_pre_post_processors
|
||||
|
||||
processors = make_rewind_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SARMConfig):
|
||||
from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
|
||||
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.policies.rewind.configuration_rewind import ReWiNDConfig
|
||||
from lerobot.policies.rewind.modeling_rewind import (
|
||||
ReWiNDRewardModel,
|
||||
ReWiNDTransformer,
|
||||
)
|
||||
from lerobot.policies.rewind.processor_rewind import (
|
||||
ReWiNDEncodingProcessorStep,
|
||||
make_rewind_pre_post_processors,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ReWiNDConfig",
|
||||
"ReWiNDRewardModel",
|
||||
"ReWiNDTransformer",
|
||||
"ReWiNDEncodingProcessorStep",
|
||||
"make_rewind_pre_post_processors",
|
||||
]
|
||||
|
||||
@@ -1,138 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("rewind")
|
||||
@dataclass
|
||||
class ReWiNDConfig(PreTrainedConfig):
|
||||
"""Configuration class for ReWiND Reward Model.
|
||||
|
||||
ReWiND (Reward from Video and Natural language Descriptions) is a reward model
|
||||
that computes task completion/progress rewards from video observations and
|
||||
language task descriptions.
|
||||
"""
|
||||
|
||||
# Model architecture parameters
|
||||
video_dim: int = 768 # DINO embedding dimension
|
||||
text_dim: int = 384 # MiniLM embedding dimension
|
||||
hidden_dim: int = 512
|
||||
num_heads: int = 8
|
||||
num_layers: int = 4
|
||||
|
||||
# Temporal parameters
|
||||
max_length: int = 32 # Maximum video sequence length, ORIGINAL: 16!
|
||||
subsample_video: bool = True # Whether to pad/subsample videos to max_length
|
||||
use_temporal_sampler: bool = True # Always enable temporal sequence loading
|
||||
sequence_stride: int = 1 # Stride between frames when using temporal sampler
|
||||
rewind_ratio: float = 0.8 # Probability of applying rewind augmentation (original: 0.8)
|
||||
|
||||
# Training parameters
|
||||
batch_size: int = 64
|
||||
dino_batch_size: int = 64
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
|
||||
# Model loading
|
||||
pretrained_model_path: str | None = None
|
||||
|
||||
# Device settings
|
||||
device: str | None = None
|
||||
|
||||
# Dropout
|
||||
dropout: float = 0.1 # Dropout rate for transformer
|
||||
|
||||
# Processor settings (for automatic preprocessing)
|
||||
image_key: str = "observation.images.top" # Key for images in dataset
|
||||
task_description: str = "perform the task" # Default task description (used if no task field in data)
|
||||
encode_on_the_fly: bool = True # Encode images/text during training
|
||||
use_dataset_task: bool = True # Use task descriptions from dataset (per-episode)
|
||||
|
||||
# Features (required by PreTrainedPolicy)
|
||||
input_features: dict = field(default_factory=lambda: {
|
||||
"video_features": {"shape": [768], "dtype": "float32"},
|
||||
"text_features": {"shape": [384], "dtype": "float32"}
|
||||
})
|
||||
output_features: dict = field(default_factory=lambda: {
|
||||
"progress": {"shape": [1], "dtype": "float32"}
|
||||
})
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# Validate configuration
|
||||
if self.hidden_dim % self.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"hidden_dim ({self.hidden_dim}) must be divisible by num_heads ({self.num_heads})"
|
||||
)
|
||||
|
||||
if self.max_length <= 0:
|
||||
raise ValueError(f"max_length must be positive, got {self.max_length}")
|
||||
|
||||
if self.dropout < 0 or self.dropout >= 1:
|
||||
raise ValueError(f"dropout must be in [0, 1), got {self.dropout}")
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
"""Get default optimizer configuration for ReWiND training."""
|
||||
return AdamWConfig(
|
||||
lr=1e-4,
|
||||
weight_decay=1e-4,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
"""Get default learning rate scheduler configuration."""
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=1e-4,
|
||||
decay_lr=1e-5,
|
||||
num_warmup_steps=1000,
|
||||
num_decay_steps=100000,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
"""Load all frames from episode start up to current frame.
|
||||
|
||||
The sampler yields a random end point in each episode.
|
||||
This property tells the dataset to load all frames from -(end_idx - start_idx) to 0.
|
||||
|
||||
Since we don't know the exact window size in advance, we load up to max_length frames.
|
||||
The dataset will automatically clamp to episode boundaries.
|
||||
|
||||
Returns:
|
||||
Indices for loading history: [-31, -30, ..., -1, 0] for max_length=32
|
||||
"""
|
||||
# Load the last max_length frames (or up to episode start)
|
||||
return list(range(-(self.max_length - 1), 1))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> None:
|
||||
"""ReWiND is a reward model, not an action policy."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
"""ReWiND doesn't use delta rewards."""
|
||||
return None
|
||||
|
||||
@@ -1,711 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import List, Union, Dict, Optional
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
import torchvision.transforms as T
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.rewind.configuration_rewind import ReWiNDConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.datasets.video_sampler import sample_video_feature, sample_reverse_video_feature
|
||||
|
||||
|
||||
# Helper functions for encoding
|
||||
def dino_load_image(img: np.ndarray) -> torch.Tensor:
|
||||
"""
|
||||
Load an image and return a tensor that can be used as an input to DINOv2.
|
||||
|
||||
Args:
|
||||
img: Input image as numpy array (H, W, C) in uint8 format.
|
||||
|
||||
Returns:
|
||||
Transformed image tensor ready for DINO encoder (1, 3, 224, 224).
|
||||
"""
|
||||
# Define transform: center crop to 224x224, normalize to [-1, 1]
|
||||
dino_transform = T.Compose([
|
||||
T.ToTensor(),
|
||||
T.CenterCrop(224),
|
||||
T.Normalize([0.5], [0.5])
|
||||
])
|
||||
|
||||
img_pil = Image.fromarray(img)
|
||||
transformed_img = dino_transform(img_pil)[:3].unsqueeze(0)
|
||||
|
||||
return transformed_img
|
||||
|
||||
|
||||
def mean_pooling(model_output, attention_mask):
|
||||
"""
|
||||
Mean pooling - take attention mask into account for correct averaging.
|
||||
|
||||
Args:
|
||||
model_output: Model output containing token embeddings.
|
||||
attention_mask: Attention mask for the tokens.
|
||||
|
||||
Returns:
|
||||
Mean-pooled embeddings.
|
||||
"""
|
||||
token_embeddings = model_output[0] # First element contains all token embeddings
|
||||
input_mask_expanded = (
|
||||
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
||||
)
|
||||
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
|
||||
input_mask_expanded.sum(1), min=1e-9
|
||||
)
|
||||
|
||||
|
||||
class ReWiNDTransformer(nn.Module):
|
||||
"""
|
||||
ReWiND Transformer model for predicting task progress from video and text.
|
||||
|
||||
This model takes video frame embeddings and text embeddings as input,
|
||||
and predicts a progress score (0-1) for each frame indicating how much
|
||||
of the task has been completed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
video_dim: int = 768,
|
||||
text_dim: int = 384,
|
||||
hidden_dim: int = 512,
|
||||
num_heads: int = 8,
|
||||
num_layers: int = 4,
|
||||
max_length: int = 32,
|
||||
dropout: float = 0.1
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
self.max_length = max_length
|
||||
|
||||
# Project video and text to common dimension
|
||||
self.video_proj = nn.Linear(video_dim, hidden_dim)
|
||||
self.text_proj = nn.Linear(text_dim, hidden_dim)
|
||||
|
||||
# Position embeddings for video sequence
|
||||
# We only add positional embedding to the first frame as in the original
|
||||
self.first_pos_embed = nn.Parameter(torch.randn(1, hidden_dim))
|
||||
|
||||
# Transformer encoder
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=hidden_dim,
|
||||
nhead=num_heads,
|
||||
dim_feedforward=hidden_dim * 4,
|
||||
dropout=dropout,
|
||||
batch_first=True
|
||||
)
|
||||
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
||||
|
||||
# Progress prediction head (applied to each frame)
|
||||
self.progress_head = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim // 2),
|
||||
nn.LayerNorm(hidden_dim // 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_dim // 2, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
# Attention mask for causal self-attention
|
||||
# Will be created on-demand based on sequence length
|
||||
self.register_buffer("attention_mask", None, persistent=False)
|
||||
|
||||
def _get_attention_mask(self, seq_length: int, device: torch.device) -> torch.Tensor:
|
||||
"""Generate or retrieve cached causal attention mask."""
|
||||
if self.attention_mask is None or self.attention_mask.shape[0] != seq_length:
|
||||
# Create causal mask (upper triangular with -inf)
|
||||
mask = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device)
|
||||
self.attention_mask = mask
|
||||
return self.attention_mask
|
||||
|
||||
def forward(self, video_frames: torch.Tensor, text_embed: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass through the ReWiND transformer.
|
||||
|
||||
Args:
|
||||
video_frames: Video frame embeddings (batch_size, seq_len, video_dim)
|
||||
text_embed: Text embeddings (batch_size, text_dim)
|
||||
|
||||
Returns:
|
||||
Progress predictions for each frame (batch_size, seq_len, 1)
|
||||
"""
|
||||
batch_size = video_frames.shape[0]
|
||||
|
||||
# Project inputs to common dimension
|
||||
video_embed = self.video_proj(video_frames) # [batch_size, seq_len, hidden_dim]
|
||||
text_embed = self.text_proj(text_embed).unsqueeze(1) # [batch_size, 1, hidden_dim]
|
||||
|
||||
# Add positional embedding to first video frame
|
||||
video_embed[:, 0] += self.first_pos_embed
|
||||
|
||||
# Combine sequence: [text, video_frames]
|
||||
sequence = torch.cat([text_embed, video_embed], dim=1)
|
||||
|
||||
# Get causal attention mask
|
||||
seq_length = sequence.shape[1]
|
||||
attention_mask = self._get_attention_mask(seq_length, sequence.device)
|
||||
|
||||
# Pass through transformer with causal masking
|
||||
transformed = self.transformer(sequence, mask=attention_mask, is_causal=True)
|
||||
|
||||
# Get progress predictions for each frame (exclude text token)
|
||||
progress_preds = self.progress_head(transformed[:, 1:])
|
||||
|
||||
return progress_preds
|
||||
|
||||
|
||||
class ReWiNDRewardModel(PreTrainedPolicy):
|
||||
"""
|
||||
ReWiND Reward Model for computing task completion rewards from video and text.
|
||||
|
||||
This model combines:
|
||||
- DINO (DINOv2) for encoding video frames
|
||||
- MiniLM for encoding text descriptions
|
||||
- ReWiNDTransformer for predicting task progress
|
||||
"""
|
||||
|
||||
name = "rewind"
|
||||
config_class = ReWiNDConfig
|
||||
|
||||
def __init__(self, config: ReWiNDConfig, dataset_stats: dict | None = None):
|
||||
super().__init__(config, dataset_stats)
|
||||
self.config = config
|
||||
self.dataset_stats = dataset_stats
|
||||
self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Initialize DINO encoder for images
|
||||
logging.info("Loading DINO encoder...")
|
||||
self.dino_encoder = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
|
||||
self.dino_encoder.to(self.device)
|
||||
self.dino_encoder.eval()
|
||||
|
||||
# Initialize MiniLM encoder for text
|
||||
logging.info("Loading MiniLM encoder...")
|
||||
self.minilm_tokenizer = AutoTokenizer.from_pretrained(
|
||||
"sentence-transformers/all-MiniLM-L12-v2"
|
||||
)
|
||||
self.minilm_model = AutoModel.from_pretrained(
|
||||
"sentence-transformers/all-MiniLM-L12-v2"
|
||||
)
|
||||
self.minilm_model.to(self.device)
|
||||
self.minilm_model.eval()
|
||||
|
||||
# Initialize ReWiND transformer with explicit architecture parameters
|
||||
self.rewind_transformer = ReWiNDTransformer(
|
||||
video_dim=config.video_dim,
|
||||
text_dim=config.text_dim,
|
||||
hidden_dim=config.hidden_dim,
|
||||
num_heads=config.num_heads,
|
||||
num_layers=config.num_layers,
|
||||
max_length=config.max_length,
|
||||
dropout=config.dropout
|
||||
)
|
||||
self.rewind_transformer.to(self.device)
|
||||
|
||||
logging.info(f"ReWiND Reward Model initialized on {self.device}")
|
||||
|
||||
def to(self, device):
|
||||
"""Override to method to ensure all components move together."""
|
||||
super().to(device)
|
||||
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
||||
self.dino_encoder.to(device)
|
||||
self.minilm_model.to(device)
|
||||
self.rewind_transformer.to(device)
|
||||
return self
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_images(self, images: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Encode video frames using DINO.
|
||||
|
||||
Args:
|
||||
images: Video frames with shape (num_videos, num_frames, H, W, C) in uint8.
|
||||
Can also be (num_frames, H, W, C) for a single video.
|
||||
|
||||
Returns:
|
||||
Encoded image features (num_videos, num_frames, 768) or (num_frames, 768).
|
||||
"""
|
||||
# Handle single video case
|
||||
single_video = False
|
||||
if len(images.shape) == 4:
|
||||
images = images[np.newaxis, ...]
|
||||
single_video = True
|
||||
|
||||
assert len(images.shape) == 5, f"Expected 5D input (num_videos, num_frames, H, W, C), got {images.shape}"
|
||||
|
||||
# Ensure channels are in correct position
|
||||
if images.shape[-1] == 3 and images.shape[2] != 3:
|
||||
images = np.transpose(images, (0, 1, 4, 2, 3))
|
||||
|
||||
all_embeddings = []
|
||||
|
||||
for video in images:
|
||||
# Process each video
|
||||
video_embeddings = []
|
||||
|
||||
# Convert frames to list of numpy arrays
|
||||
frames = [frame.transpose(1, 2, 0).astype(np.uint8) if frame.shape[0] == 3 else frame for frame in video]
|
||||
|
||||
# Batch process frames with DINO
|
||||
episode_images_dino = [dino_load_image(frame) for frame in frames]
|
||||
|
||||
# Process in batches
|
||||
for i in range(0, len(episode_images_dino), self.config.dino_batch_size):
|
||||
batch = torch.cat(episode_images_dino[i:i + self.config.dino_batch_size])
|
||||
batch = batch.to(self.device)
|
||||
embeddings = self.dino_encoder(batch).squeeze().detach().cpu()
|
||||
|
||||
# Handle single frame case
|
||||
if embeddings.dim() == 1:
|
||||
embeddings = embeddings.unsqueeze(0)
|
||||
|
||||
video_embeddings.append(embeddings)
|
||||
|
||||
video_embeddings = torch.cat(video_embeddings)
|
||||
all_embeddings.append(video_embeddings)
|
||||
|
||||
result = torch.stack(all_embeddings).numpy()
|
||||
|
||||
if single_video:
|
||||
result = result[0]
|
||||
|
||||
return result
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
|
||||
"""
|
||||
Encode text using MiniLM.
|
||||
|
||||
Args:
|
||||
text: Text string or list of text strings.
|
||||
|
||||
Returns:
|
||||
Encoded text features (batch_size, 384) or (384,) for single text.
|
||||
"""
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
single_text = True
|
||||
else:
|
||||
single_text = False
|
||||
|
||||
# Process in batches
|
||||
all_embeddings = []
|
||||
for i in range(0, len(text), self.config.batch_size):
|
||||
batch_text = text[i:i + self.config.batch_size]
|
||||
|
||||
encoded_input = self.minilm_tokenizer(
|
||||
batch_text, padding=True, truncation=True, return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
model_output = self.minilm_model(**encoded_input)
|
||||
text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
|
||||
|
||||
all_embeddings.append(text_embeddings.cpu())
|
||||
|
||||
result = torch.cat(all_embeddings).numpy()
|
||||
|
||||
if single_text:
|
||||
result = result[0]
|
||||
|
||||
return result
|
||||
|
||||
def padding_video(self, video_frames: torch.Tensor, max_length: int) -> torch.Tensor:
|
||||
"""
|
||||
Pad or subsample video frames to a fixed length.
|
||||
|
||||
Args:
|
||||
video_frames: Video frames tensor (num_frames, embedding_dim)
|
||||
max_length: Target sequence length
|
||||
|
||||
Returns:
|
||||
Padded/subsampled video frames (max_length, embedding_dim)
|
||||
"""
|
||||
video_length = len(video_frames)
|
||||
|
||||
if isinstance(video_frames, np.ndarray):
|
||||
video_frames = torch.tensor(video_frames)
|
||||
|
||||
if video_length < max_length:
|
||||
# Pad with last frame
|
||||
padding_length = max_length - video_length
|
||||
last_frame = video_frames[-1].unsqueeze(0)
|
||||
padding_frames = last_frame.repeat(padding_length, 1)
|
||||
video_frames = torch.cat([video_frames, padding_frames], dim=0)
|
||||
|
||||
elif video_length > max_length:
|
||||
# Subsample uniformly
|
||||
frame_idx = np.linspace(0, video_length - 1, max_length).astype(int)
|
||||
video_frames = video_frames[frame_idx]
|
||||
|
||||
return video_frames
|
||||
|
||||
@torch.no_grad()
|
||||
def calculate_rewards(
|
||||
self,
|
||||
text_embeddings: Union[np.ndarray, torch.Tensor],
|
||||
video_embeddings: Union[np.ndarray, torch.Tensor],
|
||||
return_all_frames: bool = False
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Calculate rewards for given text and video representations.
|
||||
|
||||
Args:
|
||||
text_embeddings: Encoded text representations (batch_size, 384)
|
||||
video_embeddings: Encoded video representations (batch_size, num_frames, 768)
|
||||
return_all_frames: If True, return rewards for all frames. If False, return only last frame.
|
||||
|
||||
Returns:
|
||||
Reward values (batch_size,) or (batch_size, num_frames) if return_all_frames=True
|
||||
"""
|
||||
# Convert to tensors if needed
|
||||
if isinstance(text_embeddings, np.ndarray):
|
||||
text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
|
||||
if isinstance(video_embeddings, np.ndarray):
|
||||
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
|
||||
|
||||
# Handle single sample case
|
||||
if text_embeddings.dim() == 1:
|
||||
text_embeddings = text_embeddings.unsqueeze(0)
|
||||
video_embeddings = video_embeddings.unsqueeze(0)
|
||||
single_sample = True
|
||||
else:
|
||||
single_sample = False
|
||||
|
||||
# Process in batches
|
||||
all_rewards = []
|
||||
for i in range(0, len(video_embeddings), self.config.batch_size):
|
||||
batch_texts = text_embeddings[i:i + self.config.batch_size].to(self.device)
|
||||
batch_videos = video_embeddings[i:i + self.config.batch_size].to(self.device)
|
||||
|
||||
# Pad/subsample videos if needed
|
||||
if self.config.subsample_video:
|
||||
padded_videos = []
|
||||
for video in batch_videos:
|
||||
padded_video = self.padding_video(video, self.config.max_length)
|
||||
padded_videos.append(padded_video)
|
||||
batch_videos = torch.stack(padded_videos).to(self.device)
|
||||
|
||||
# Get progress predictions
|
||||
rewards = self.rewind_transformer(batch_videos.float(), batch_texts.float())
|
||||
|
||||
if return_all_frames:
|
||||
all_rewards.append(rewards.squeeze(-1).cpu())
|
||||
else:
|
||||
# Return only last frame reward
|
||||
all_rewards.append(rewards[:, -1, 0].cpu())
|
||||
|
||||
result = torch.cat(all_rewards).numpy()
|
||||
|
||||
if single_sample:
|
||||
result = result[0] if not return_all_frames else result[0]
|
||||
|
||||
return result
|
||||
|
||||
def load_pretrained_checkpoint(self, checkpoint_path: str, strict: bool = False):
|
||||
"""
|
||||
Load pretrained model weights from a checkpoint file.
|
||||
|
||||
Args:
|
||||
checkpoint_path: Path to the .pth checkpoint file
|
||||
strict: Whether to strictly enforce that the keys match
|
||||
"""
|
||||
logging.info(f"Loading pretrained checkpoint from {checkpoint_path}")
|
||||
|
||||
checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
|
||||
|
||||
# Handle different checkpoint formats
|
||||
if "model_state_dict" in checkpoint:
|
||||
state_dict = checkpoint["model_state_dict"]
|
||||
|
||||
# Check for architecture parameters in checkpoint
|
||||
if "args" in checkpoint:
|
||||
args = checkpoint["args"]
|
||||
logging.info(f"Checkpoint was trained with: max_length={args.max_length}")
|
||||
|
||||
# Warn if max_length differs
|
||||
if hasattr(args, 'max_length') and args.max_length != self.config.max_length:
|
||||
logging.warning(
|
||||
f"Checkpoint max_length ({args.max_length}) differs from config ({self.config.max_length}). "
|
||||
"This may cause issues if sequence lengths don't match."
|
||||
)
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
# Load only the ReWiNDTransformer weights
|
||||
missing_keys, unexpected_keys = self.rewind_transformer.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
if missing_keys:
|
||||
logging.warning(f"Missing keys when loading checkpoint: {missing_keys}")
|
||||
if unexpected_keys:
|
||||
logging.warning(f"Unexpected keys when loading checkpoint: {unexpected_keys}")
|
||||
|
||||
logging.info("Checkpoint loaded successfully")
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
"""Set training mode. Note: DINO and MiniLM encoders always stay in eval mode."""
|
||||
super().train(mode)
|
||||
# Keep encoders in eval mode
|
||||
self.dino_encoder.eval()
|
||||
self.minilm_model.eval()
|
||||
# Only transformer can be trained
|
||||
self.rewind_transformer.train(mode)
|
||||
return self
|
||||
|
||||
def eval(self):
|
||||
"""Set evaluation mode."""
|
||||
return self.train(False)
|
||||
|
||||
def parameters(self):
|
||||
"""Return trainable parameters (only ReWiND transformer, not encoders)."""
|
||||
return self.rewind_transformer.parameters()
|
||||
|
||||
def get_optim_params(self):
|
||||
"""Return optimizer parameters for the policy."""
|
||||
return self.parameters()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward models.
|
||||
The reward model does not maintain state between episodes.
|
||||
"""
|
||||
pass
|
||||
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward models.
|
||||
The rewind model is not an actor and does not produce action chunks.
|
||||
"""
|
||||
raise NotImplementedError("Rewind model does not predict action chunks")
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for rewind.
|
||||
The rewind model is not an actor and does not select actions.
|
||||
"""
|
||||
raise NotImplementedError("Rewind model does not select actions")
|
||||
|
||||
def forward(self, batch):
|
||||
"""
|
||||
Forward pass compatible with lerobot training pipeline.
|
||||
|
||||
Args:
|
||||
batch: Dictionary containing observation with:
|
||||
- 'video_features': Pre-encoded video features (B, 768) or (B, T, 768)
|
||||
- 'text_features': Pre-encoded text features (B, 384)
|
||||
|
||||
Returns:
|
||||
loss: Total training loss
|
||||
output_dict: Dictionary of loss components for logging
|
||||
"""
|
||||
# Extract from observation dict
|
||||
observation = batch.get('observation', batch)
|
||||
video_features = observation['video_features'].to(self.device)
|
||||
text_features = observation['text_features'].to(self.device)
|
||||
|
||||
batch_size = video_features.shape[0]
|
||||
max_length = self.config.max_length
|
||||
|
||||
# Handle both single frames (B, 768) and sequences (B, T, 768)
|
||||
if video_features.dim() == 2:
|
||||
# Single frames: replicate to create pseudo-sequences
|
||||
video_features = video_features.unsqueeze(1).repeat(1, max_length, 1) # (B, max_length, 768)
|
||||
|
||||
# Now video_features is (B, T, 768) where T might be > max_length
|
||||
|
||||
# Process videos (with potential rewind augmentation)
|
||||
import random
|
||||
from lerobot.datasets.video_sampler import sample_video_feature, sample_reverse_video_feature
|
||||
|
||||
processed_videos = []
|
||||
progress_targets = []
|
||||
|
||||
# Extract episode metadata for correct progress normalization
|
||||
absolute_frame_indices = observation.get('absolute_frame_indices', None)
|
||||
episode_lengths = observation.get('episode_length', None)
|
||||
remaining_lengths = observation.get('remaining_length', None)
|
||||
|
||||
for i in range(batch_size):
|
||||
# Get metadata for this sample
|
||||
current_absolute_indices = None
|
||||
current_episode_length = None
|
||||
current_remaining_length = None
|
||||
|
||||
if absolute_frame_indices is not None:
|
||||
if isinstance(absolute_frame_indices, list):
|
||||
current_absolute_indices = absolute_frame_indices[i]
|
||||
else:
|
||||
current_absolute_indices = absolute_frame_indices
|
||||
|
||||
if episode_lengths is not None:
|
||||
if isinstance(episode_lengths, torch.Tensor) and episode_lengths.dim() > 0:
|
||||
current_episode_length = episode_lengths[i].item()
|
||||
else:
|
||||
current_episode_length = episode_lengths.item() if isinstance(episode_lengths, torch.Tensor) else episode_lengths
|
||||
|
||||
if remaining_lengths is not None:
|
||||
if isinstance(remaining_lengths, torch.Tensor) and remaining_lengths.dim() > 0:
|
||||
current_remaining_length = remaining_lengths[i].item()
|
||||
else:
|
||||
current_remaining_length = remaining_lengths.item() if isinstance(remaining_lengths, torch.Tensor) else remaining_lengths
|
||||
|
||||
if random.random() < self.config.rewind_ratio: # Use configurable rewind ratio
|
||||
# Apply video rewind augmentation (now returns tuple)
|
||||
rewound_video, progress = sample_reverse_video_feature(
|
||||
video_features[i],
|
||||
max_length=max_length,
|
||||
random_sample=True, # Use random sampling (original ReWiND)
|
||||
remaining_length=current_remaining_length
|
||||
)
|
||||
processed_videos.append(rewound_video.to(self.device))
|
||||
progress_targets.append(progress.to(self.device))
|
||||
else:
|
||||
# Normal video sampling (now returns tuple with progress targets)
|
||||
sampled_video, progress = sample_video_feature(
|
||||
video_features[i],
|
||||
max_length=max_length,
|
||||
random_sample=True, # Use random sampling (original ReWiND)
|
||||
remaining_length=current_remaining_length
|
||||
)
|
||||
processed_videos.append(sampled_video.to(self.device))
|
||||
progress_targets.append(progress.to(self.device))
|
||||
|
||||
processed_videos = torch.stack(processed_videos)
|
||||
progress_targets = torch.stack(progress_targets)
|
||||
|
||||
# Compute progress loss
|
||||
progress_loss = compute_progress_loss(
|
||||
self.rewind_transformer,
|
||||
processed_videos,
|
||||
text_features,
|
||||
progress_targets
|
||||
)
|
||||
|
||||
total_loss = progress_loss
|
||||
output_dict = {'progress_loss': progress_loss.item()}
|
||||
|
||||
# Compute misaligned loss if requested (20% probability to match original)
|
||||
if random.random() < 0.2: # 20% chance of adding misalignment loss (original ReWiND uses 20%)
|
||||
if 'misaligned_video_features' in batch and 'misaligned_text_features' in batch:
|
||||
misaligned_videos = batch['misaligned_video_features'].to(self.device)
|
||||
misaligned_texts = batch['misaligned_text_features'].to(self.device)
|
||||
else:
|
||||
# Create misaligned pairs by shuffling
|
||||
shuffle_idx = torch.randperm(batch_size)
|
||||
misaligned_videos = processed_videos[shuffle_idx]
|
||||
misaligned_texts = text_features
|
||||
|
||||
# Sample misaligned videos (function now returns tuple)
|
||||
# For misaligned pairs, we don't need correct progress targets (will be set to 0)
|
||||
misaligned_videos_sampled = []
|
||||
for i in range(batch_size):
|
||||
# For misaligned videos, use video length as remaining_length
|
||||
video_len = len(misaligned_videos[i])
|
||||
sampled, _ = sample_video_feature(
|
||||
misaligned_videos[i],
|
||||
max_length=max_length,
|
||||
random_sample=True,
|
||||
remaining_length=video_len # Use video length for misaligned pairs
|
||||
)
|
||||
misaligned_videos_sampled.append(sampled.to(self.device))
|
||||
misaligned_videos_sampled = torch.stack(misaligned_videos_sampled)
|
||||
|
||||
misaligned_loss = compute_misaligned_loss(
|
||||
self.rewind_transformer,
|
||||
misaligned_videos_sampled,
|
||||
misaligned_texts
|
||||
)
|
||||
|
||||
total_loss = total_loss + misaligned_loss
|
||||
output_dict['misaligned_loss'] = misaligned_loss.item()
|
||||
|
||||
output_dict['total_loss'] = total_loss.item()
|
||||
|
||||
return total_loss, output_dict
|
||||
|
||||
|
||||
# Loss utilities
|
||||
def compute_progress_loss(
|
||||
model: ReWiNDTransformer,
|
||||
video_features: torch.Tensor,
|
||||
text_features: torch.Tensor,
|
||||
target_progress: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute progress prediction loss.
|
||||
|
||||
Args:
|
||||
model: ReWiNDTransformer model
|
||||
video_features: Batch of video features (batch_size, max_length, feature_dim)
|
||||
text_features: Batch of text features (batch_size, text_dim)
|
||||
target_progress: Optional target progress values (batch_size, max_length).
|
||||
If None, uses linear progress from 0 to 1.
|
||||
|
||||
Returns:
|
||||
Mean squared error loss
|
||||
"""
|
||||
# Get predictions
|
||||
progress_preds = model(video_features, text_features)
|
||||
|
||||
# Create target progress if not provided
|
||||
if target_progress is None:
|
||||
batch_size, max_length = video_features.shape[:2]
|
||||
target_progress = torch.linspace(0, 1, max_length, device=video_features.device)
|
||||
target_progress = target_progress.unsqueeze(0).repeat(batch_size, 1)
|
||||
|
||||
# Ensure target has correct shape
|
||||
if target_progress.dim() == 2:
|
||||
target_progress = target_progress.unsqueeze(-1)
|
||||
|
||||
# Compute MSE loss
|
||||
loss = F.mse_loss(progress_preds, target_progress)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def compute_misaligned_loss(
|
||||
model: ReWiNDTransformer,
|
||||
video_features: torch.Tensor,
|
||||
misaligned_text_features: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute loss for misaligned video-text pairs (should predict 0 progress).
|
||||
|
||||
Args:
|
||||
model: ReWiNDTransformer model
|
||||
video_features: Batch of video features (batch_size, max_length, feature_dim)
|
||||
misaligned_text_features: Batch of misaligned text features (batch_size, text_dim)
|
||||
|
||||
Returns:
|
||||
Mean squared error loss (predictions should be close to 0)
|
||||
"""
|
||||
# Get predictions
|
||||
progress_preds = model(video_features, misaligned_text_features)
|
||||
|
||||
# Target is all zeros
|
||||
target_zeros = torch.zeros_like(progress_preds)
|
||||
|
||||
# Compute MSE loss
|
||||
loss = F.mse_loss(progress_preds, target_zeros)
|
||||
|
||||
return loss
|
||||
@@ -1,405 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rewind.configuration_rewind import ReWiNDConfig
|
||||
from lerobot.processor import (
|
||||
ProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
PolicyAction,
|
||||
DeviceProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.processor.pipeline import PipelineFeatureType
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.configs.types import PolicyFeature, FeatureType
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
class ReWiNDEncodingProcessorStep(ProcessorStep):
|
||||
"""
|
||||
ProcessorStep that encodes images and text for ReWiND training.
|
||||
|
||||
This step handles the DINO (image) and MiniLM (text) encoding that ReWiND needs.
|
||||
|
||||
Supports both single-frame and temporal sequence encoding:
|
||||
- Single frame: (B, C, H, W) → (B, 768) video features
|
||||
- Temporal sequence: (B, T, C, H, W) → (B, T, 768) video features
|
||||
|
||||
To use temporal sequences, configure the dataset with delta_timestamps for your image key.
|
||||
For example, to encode sequences of 32 frames:
|
||||
delta_timestamps = {
|
||||
"observation.images.top": [i / fps for i in range(-15, 17)] # 32 frames centered on current
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ReWiNDConfig,
|
||||
image_key: str | None = None,
|
||||
task_description: str | None = None,
|
||||
dataset_meta = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.image_key = image_key or config.image_key
|
||||
self.task_description = task_description or config.task_description
|
||||
self.dataset_meta = dataset_meta # Store dataset metadata for episode info
|
||||
|
||||
# Initialize encoders
|
||||
self._init_encoders()
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize DINO and MiniLM encoders."""
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
device = torch.device(
|
||||
self.config.device if self.config.device
|
||||
else "cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
|
||||
logging.info("Initializing DINO encoder for ReWiND...")
|
||||
self.dino_encoder = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
|
||||
self.dino_encoder.to(device)
|
||||
self.dino_encoder.eval()
|
||||
|
||||
logging.info("Initializing MiniLM encoder for ReWiND...")
|
||||
self.minilm_tokenizer = AutoTokenizer.from_pretrained(
|
||||
"sentence-transformers/all-MiniLM-L12-v2"
|
||||
)
|
||||
self.minilm_model = AutoModel.from_pretrained(
|
||||
"sentence-transformers/all-MiniLM-L12-v2"
|
||||
)
|
||||
self.minilm_model.to(device)
|
||||
self.minilm_model.eval()
|
||||
|
||||
self.device = device
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Encode images and text in the transition."""
|
||||
self._current_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition)
|
||||
new_transition = self._current_transition
|
||||
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None or not isinstance(observation, dict):
|
||||
# If no observation, just return the transition as-is
|
||||
return new_transition
|
||||
|
||||
# Extract images from observation and encode
|
||||
# For ReWiND, we need to load the sequence from episode start to current frame
|
||||
batch_size = 1
|
||||
if self.image_key in observation:
|
||||
image = observation[self.image_key]
|
||||
|
||||
# Handle different image formats
|
||||
if isinstance(image, torch.Tensor):
|
||||
image = image.cpu().numpy()
|
||||
|
||||
# Check if we have temporal sequences or single frames
|
||||
# Temporal sampling: Load from episode start to current frame
|
||||
# This will be handled by the dataset if configured with delta_timestamps
|
||||
# Otherwise, we just encode the single frame
|
||||
video_features = self._encode_images_batch(image)
|
||||
observation['video_features'] = video_features
|
||||
|
||||
# Get batch size from the encoded features
|
||||
batch_size = video_features.shape[0]
|
||||
|
||||
# Get task descriptions - check if 'task' field exists in the transition
|
||||
# This allows per-episode task descriptions (e.g., for datasets with multiple tasks)
|
||||
task_descriptions = None
|
||||
if 'task' in new_transition:
|
||||
task_descriptions = new_transition['task']
|
||||
# Convert to list if it's a single string
|
||||
if isinstance(task_descriptions, str):
|
||||
task_descriptions = [task_descriptions] * batch_size
|
||||
|
||||
# Encode text
|
||||
if task_descriptions is not None:
|
||||
# Encode per-sample task descriptions
|
||||
text_features = self._encode_text_batch_list(task_descriptions)
|
||||
else:
|
||||
# Fall back to config task description if no task field in transition
|
||||
text_features = self._encode_text_batch(self.task_description, batch_size)
|
||||
|
||||
observation['text_features'] = text_features
|
||||
|
||||
# Compute episode metadata for progress normalization (used by ReWiND)
|
||||
# We need to pass absolute frame indices and total episode length for correct progress calculation
|
||||
if self.dataset_meta is not None and 'episode_index' in new_transition and 'index' in new_transition:
|
||||
episode_indices = new_transition['episode_index']
|
||||
frame_indices = new_transition['index']
|
||||
|
||||
# Handle both single samples and batches
|
||||
if isinstance(episode_indices, (int, np.integer)):
|
||||
ep_idx = int(episode_indices)
|
||||
frame_idx = int(frame_indices)
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
episode_length = ep_end - ep_start
|
||||
|
||||
# For temporal sequences with observation_delta_indices:
|
||||
# If we loaded frames using delta_indices (e.g., [-31, -30, ..., 0]),
|
||||
# we need to compute the absolute indices of those frames
|
||||
# The current frame is at frame_idx, and we loaded max_length frames before it
|
||||
if 'video_features' in observation and len(observation['video_features'].shape) > 1:
|
||||
# We have a temporal sequence
|
||||
num_loaded_frames = observation['video_features'].shape[0] if observation['video_features'].dim() == 2 else observation['video_features'].shape[1]
|
||||
# Absolute indices: from (frame_idx - num_frames + 1) to frame_idx
|
||||
start_idx = max(ep_start, frame_idx - num_loaded_frames + 1)
|
||||
absolute_indices = torch.arange(start_idx, frame_idx + 1)
|
||||
observation['absolute_frame_indices'] = absolute_indices
|
||||
# Compute remaining length: from first loaded frame to episode end
|
||||
observation['remaining_length'] = ep_end - start_idx
|
||||
else:
|
||||
# Single frame
|
||||
observation['absolute_frame_indices'] = torch.tensor([frame_idx])
|
||||
# Remaining length from this frame to episode end
|
||||
observation['remaining_length'] = ep_end - frame_idx
|
||||
|
||||
observation['episode_length'] = episode_length
|
||||
else:
|
||||
# Batch case
|
||||
absolute_indices_list = []
|
||||
episode_lengths = []
|
||||
remaining_lengths = []
|
||||
for ep_idx, frame_idx in zip(episode_indices, frame_indices):
|
||||
ep_idx = int(ep_idx.item() if hasattr(ep_idx, 'item') else ep_idx)
|
||||
frame_idx = int(frame_idx.item() if hasattr(frame_idx, 'item') else frame_idx)
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
episode_length = ep_end - ep_start
|
||||
episode_lengths.append(episode_length)
|
||||
|
||||
# Compute absolute indices for this sample
|
||||
if 'video_features' in observation and len(observation['video_features'].shape) > 1:
|
||||
num_loaded_frames = observation['video_features'].shape[1]
|
||||
start_idx = max(ep_start, frame_idx - num_loaded_frames + 1)
|
||||
absolute_indices = torch.arange(start_idx, frame_idx + 1)
|
||||
absolute_indices_list.append(absolute_indices)
|
||||
# Remaining length from first loaded frame to episode end
|
||||
remaining_lengths.append(ep_end - start_idx)
|
||||
else:
|
||||
absolute_indices_list.append(torch.tensor([frame_idx]))
|
||||
# Remaining length from this frame to episode end
|
||||
remaining_lengths.append(ep_end - frame_idx)
|
||||
|
||||
observation['absolute_frame_indices'] = absolute_indices_list
|
||||
observation['episode_length'] = torch.tensor(episode_lengths)
|
||||
observation['remaining_length'] = torch.tensor(remaining_lengths)
|
||||
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
return new_transition
|
||||
|
||||
@torch.no_grad()
|
||||
def _encode_images_batch(self, images: np.ndarray) -> torch.Tensor:
|
||||
"""Encode a batch of images (with optional temporal dimension) using DINO.
|
||||
|
||||
Args:
|
||||
images: Batched images with shape:
|
||||
- (B, C, H, W) for single frames, or
|
||||
- (B, T, C, H, W) for temporal sequences
|
||||
|
||||
Returns:
|
||||
Encoded feature vectors with shape (B, 768) or (B, T, 768)
|
||||
"""
|
||||
from lerobot.policies.rewind.modeling_rewind import dino_load_image
|
||||
|
||||
# Check if we have temporal dimension
|
||||
has_temporal = len(images.shape) == 5
|
||||
|
||||
if has_temporal:
|
||||
# Shape: (B, T, C, H, W)
|
||||
batch_size, seq_length = images.shape[0], images.shape[1]
|
||||
|
||||
# Reshape to (B*T, C, H, W) to process all frames at once
|
||||
images = images.reshape(batch_size * seq_length, *images.shape[2:])
|
||||
elif len(images.shape) == 4:
|
||||
# Shape: (B, C, H, W)
|
||||
batch_size = images.shape[0]
|
||||
seq_length = 1
|
||||
else:
|
||||
raise ValueError(f"Expected 4D (B, C, H, W) or 5D (B, T, C, H, W) input, got shape {images.shape}")
|
||||
|
||||
# Convert to list of (H, W, C) images
|
||||
num_frames = images.shape[0]
|
||||
if images.shape[1] in [1, 3]: # Channel first (N, C, H, W)
|
||||
images_list = [images[i].transpose(1, 2, 0) for i in range(num_frames)]
|
||||
else: # Channel last (N, H, W, C)
|
||||
images_list = [images[i] for i in range(num_frames)]
|
||||
|
||||
# Encode each frame (can batch process with DINO for efficiency)
|
||||
all_embeddings = []
|
||||
for i in range(0, num_frames, self.config.dino_batch_size):
|
||||
batch_imgs = images_list[i:i + self.config.dino_batch_size]
|
||||
|
||||
# Prepare images for DINO
|
||||
dino_inputs = []
|
||||
for img in batch_imgs:
|
||||
# Handle single channel
|
||||
if img.shape[-1] == 1:
|
||||
img = np.repeat(img, 3, axis=-1)
|
||||
|
||||
# Convert to uint8
|
||||
if img.dtype != np.uint8:
|
||||
img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8)
|
||||
|
||||
dino_inputs.append(dino_load_image(img))
|
||||
|
||||
# Batch encode
|
||||
dino_batch = torch.cat(dino_inputs).to(self.device)
|
||||
embeddings = self.dino_encoder(dino_batch).detach().cpu()
|
||||
|
||||
# Handle single frame case
|
||||
if embeddings.dim() == 1:
|
||||
embeddings = embeddings.unsqueeze(0)
|
||||
|
||||
all_embeddings.append(embeddings)
|
||||
|
||||
# Concatenate all embeddings
|
||||
all_embeddings = torch.cat(all_embeddings) # (B*T, 768)
|
||||
|
||||
# Reshape back if temporal
|
||||
if has_temporal:
|
||||
all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 768)
|
||||
|
||||
return all_embeddings
|
||||
|
||||
@torch.no_grad()
|
||||
def _encode_text_batch(self, text: str, batch_size: int) -> torch.Tensor:
|
||||
"""Encode a text string using MiniLM and replicate for batch.
|
||||
|
||||
Args:
|
||||
text: Text string to encode
|
||||
batch_size: Batch size to replicate for
|
||||
|
||||
Returns:
|
||||
Encoded feature vectors with shape (B, 384)
|
||||
"""
|
||||
from lerobot.policies.rewind.modeling_rewind import mean_pooling
|
||||
|
||||
encoded_input = self.minilm_tokenizer(
|
||||
text, padding=True, truncation=True, return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
model_output = self.minilm_model(**encoded_input)
|
||||
text_embedding = mean_pooling(model_output, encoded_input["attention_mask"])
|
||||
text_embedding = text_embedding.squeeze().cpu()
|
||||
|
||||
# Replicate for batch (B, 384)
|
||||
text_embedding = text_embedding.unsqueeze(0).repeat(batch_size, 1)
|
||||
|
||||
return text_embedding
|
||||
|
||||
@torch.no_grad()
|
||||
def _encode_text_batch_list(self, text_list: list[str]) -> torch.Tensor:
|
||||
"""Encode a list of text strings using MiniLM (one per sample).
|
||||
|
||||
Args:
|
||||
text_list: List of text strings to encode
|
||||
|
||||
Returns:
|
||||
Encoded feature vectors with shape (B, 384)
|
||||
"""
|
||||
from lerobot.policies.rewind.modeling_rewind import mean_pooling
|
||||
|
||||
# Encode all texts in the batch at once
|
||||
encoded_input = self.minilm_tokenizer(
|
||||
text_list, padding=True, truncation=True, return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
model_output = self.minilm_model(**encoded_input)
|
||||
text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
|
||||
text_embeddings = text_embeddings.cpu()
|
||||
|
||||
return text_embeddings
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Adds video_features and text_features to the observation features.
|
||||
"""
|
||||
# Add the encoded features
|
||||
features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(768,) # DINO embedding dimension
|
||||
)
|
||||
features[PipelineFeatureType.OBSERVATION]['text_features'] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE,
|
||||
shape=(384,) # MiniLM embedding dimension
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
def make_rewind_pre_post_processors(
|
||||
config: ReWiNDConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
dataset_meta = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Create pre-processor and post-processor pipelines for ReWiND.
|
||||
|
||||
The pre-processing pipeline:
|
||||
1. Encodes images with DINO (768-dim)
|
||||
2. Encodes text with MiniLM (384-dim)
|
||||
3. Computes remaining episode length for progress normalization
|
||||
4. Adds batch dimension
|
||||
5. Moves data to device
|
||||
|
||||
The post-processing pipeline moves data back to CPU.
|
||||
|
||||
Args:
|
||||
config: ReWiND configuration
|
||||
dataset_stats: Dataset statistics (not used for ReWiND)
|
||||
dataset_meta: Dataset metadata for computing episode remaining length
|
||||
|
||||
Returns:
|
||||
Tuple of (preprocessor, postprocessor) pipelines
|
||||
"""
|
||||
input_steps = [
|
||||
AddBatchDimensionProcessorStep(),
|
||||
ReWiNDEncodingProcessorStep(config=config, dataset_meta=dataset_meta),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
output_steps = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -19,7 +19,6 @@ from lerobot.policies.sarm.modeling_sarm import (
|
||||
SARMRewardModel,
|
||||
SARMTransformer,
|
||||
compute_stage_loss,
|
||||
compute_progress_loss,
|
||||
)
|
||||
from lerobot.policies.sarm.processor_sarm import (
|
||||
SARMEncodingProcessorStep,
|
||||
@@ -31,7 +30,6 @@ __all__ = [
|
||||
"SARMRewardModel",
|
||||
"SARMTransformer",
|
||||
"compute_stage_loss",
|
||||
"compute_progress_loss",
|
||||
"SARMEncodingProcessorStep",
|
||||
"make_sarm_pre_post_processors",
|
||||
]
|
||||
|
||||
@@ -32,8 +32,8 @@ class SARMConfig(PreTrainedConfig):
|
||||
num_frames: int = 9 # 1 initial + 8 consecutive frames
|
||||
frame_gap: int = 30 # Frame gap between consecutive frames (at 30 fps = 1 second)
|
||||
|
||||
# Text encoding parameters
|
||||
text_dim: int = 384
|
||||
# Text encoding parameters (CLIP text encoder output dimension)
|
||||
text_dim: int = 512
|
||||
|
||||
# Joint state parameters
|
||||
state_dim: int | None = None # Auto-detected from dataset if None
|
||||
@@ -49,7 +49,6 @@ class SARMConfig(PreTrainedConfig):
|
||||
# Temporal parameters
|
||||
max_length: int = num_frames # Maximum video sequence length (matches num_frames)
|
||||
use_temporal_sampler: bool = True # Always enable temporal sequence loading
|
||||
sampling_mode: str = "sarm" # Sampling mode: "sarm" or "rewind"
|
||||
|
||||
# Training parameters
|
||||
batch_size: int = 64
|
||||
@@ -101,11 +100,6 @@ class SARMConfig(PreTrainedConfig):
|
||||
|
||||
if self.num_stages < 2:
|
||||
raise ValueError(f"num_stages must be at least 2, got {self.num_stages}")
|
||||
|
||||
if self.sampling_mode not in ["sarm", "rewind", "custom"]:
|
||||
raise ValueError(
|
||||
f"sampling_mode must be 'sarm' or 'rewind', got {self.sampling_mode}"
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
"""Get default optimizer configuration for SARM training."""
|
||||
|
||||
@@ -24,33 +24,13 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from transformers import AutoModel, AutoTokenizer, CLIPModel, CLIPProcessor
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
def mean_pooling(model_output, attention_mask):
|
||||
"""
|
||||
Mean pooling - take attention mask into account for correct averaging.
|
||||
|
||||
Args:
|
||||
model_output: Model output containing token embeddings.
|
||||
attention_mask: Attention mask for the tokens.
|
||||
|
||||
Returns:
|
||||
Mean-pooled embeddings.
|
||||
"""
|
||||
token_embeddings = model_output[0] # First element contains all token embeddings
|
||||
input_mask_expanded = (
|
||||
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
||||
)
|
||||
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
|
||||
input_mask_expanded.sum(1), min=1e-9
|
||||
)
|
||||
|
||||
|
||||
class SARMTransformer(nn.Module):
|
||||
"""
|
||||
SARM Transformer model for stage-aware reward prediction.
|
||||
@@ -65,7 +45,7 @@ class SARMTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
video_dim: int = 512,
|
||||
text_dim: int = 384,
|
||||
text_dim: int = 512, # CLIP text encoder output dimension (per SARM paper A.4)
|
||||
state_dim: int = 14,
|
||||
hidden_dim: int = 768,
|
||||
num_heads: int = 12,
|
||||
@@ -204,7 +184,7 @@ class SARMTransformer(nn.Module):
|
||||
stage_indices = torch.argmax(stage_probs, dim=-1) # [batch_size, seq_len]
|
||||
|
||||
# Get stage embeddings for conditioning
|
||||
stage_embeds = self.stage_embedding(stage_indices) # [batch_size, seq_len, hidden_dim//4]
|
||||
stage_embeds = self.stage_embedding(stage_indices)
|
||||
|
||||
# Concatenate frame features with stage embeddings
|
||||
conditioned_features = torch.cat([frame_features, stage_embeds], dim=-1)
|
||||
@@ -229,9 +209,11 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
"""
|
||||
SARM Reward Model for stage-aware task completion rewards.
|
||||
|
||||
Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
|
||||
to process both RGB image sequences and task descriptions."
|
||||
|
||||
This model combines:
|
||||
- CLIP for encoding video frames
|
||||
- MiniLM for encoding text descriptions
|
||||
- CLIP for encoding video frames AND text descriptions
|
||||
- SARMTransformer for predicting task stage and progress
|
||||
- Optional RA-BC (Reward-Aligned Behavior Cloning) for weighted training
|
||||
"""
|
||||
@@ -249,24 +231,13 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
if dataset_meta is not None:
|
||||
self._update_num_stages_from_dataset(dataset_meta)
|
||||
|
||||
# Initialize CLIP encoder for images
|
||||
logging.info("Loading CLIP encoder...")
|
||||
# Initialize CLIP encoder for images AND text (per SARM paper A.4)
|
||||
logging.info("Loading CLIP encoder for images and text...")
|
||||
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
|
||||
self.clip_model.to(self.device)
|
||||
self.clip_model.eval()
|
||||
|
||||
# Initialize MiniLM encoder for text
|
||||
logging.info("Loading MiniLM encoder...")
|
||||
self.minilm_tokenizer = AutoTokenizer.from_pretrained(
|
||||
"sentence-transformers/all-MiniLM-L12-v2"
|
||||
)
|
||||
self.minilm_model = AutoModel.from_pretrained(
|
||||
"sentence-transformers/all-MiniLM-L12-v2"
|
||||
)
|
||||
self.minilm_model.to(self.device)
|
||||
self.minilm_model.eval()
|
||||
|
||||
# Auto-detect state_dim from dataset_stats
|
||||
if config.state_dim is None:
|
||||
logging.info(f"Attempting to auto-detect state_dim. dataset_stats is None: {dataset_stats is None}")
|
||||
@@ -379,7 +350,6 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
super().to(device)
|
||||
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
||||
self.clip_model.to(device)
|
||||
self.minilm_model.to(device)
|
||||
self.sarm_transformer.to(device)
|
||||
return self
|
||||
|
||||
@@ -445,13 +415,13 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
@torch.no_grad()
|
||||
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
|
||||
"""
|
||||
Encode text using MiniLM.
|
||||
Encode text using CLIP text encoder (per SARM paper A.4).
|
||||
|
||||
Args:
|
||||
text: Text string or list of text strings.
|
||||
|
||||
Returns:
|
||||
Encoded text features (batch_size, 384) or (384,) for single text.
|
||||
Encoded text features (batch_size, 512) or (512,) for single text.
|
||||
"""
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
@@ -459,18 +429,18 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
else:
|
||||
single_text = False
|
||||
|
||||
# Use CLIP's tokenizer directly (avoids image processor validation issues)
|
||||
tokenizer = self.clip_processor.tokenizer
|
||||
|
||||
# Process in batches
|
||||
all_embeddings = []
|
||||
for i in range(0, len(text), self.config.batch_size):
|
||||
batch_text = text[i:i + self.config.batch_size]
|
||||
|
||||
encoded_input = self.minilm_tokenizer(
|
||||
batch_text, padding=True, truncation=True, return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
model_output = self.minilm_model(**encoded_input)
|
||||
text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
|
||||
inputs = tokenizer(batch_text, return_tensors="pt", padding=True, truncation=True)
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
text_embeddings = self.clip_model.get_text_features(**inputs)
|
||||
all_embeddings.append(text_embeddings.cpu())
|
||||
|
||||
result = torch.cat(all_embeddings).numpy()
|
||||
@@ -493,7 +463,7 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
Calculate rewards for given text, video, and state representations.
|
||||
|
||||
Args:
|
||||
text_embeddings: Encoded text representations (batch_size, 384)
|
||||
text_embeddings: Encoded text representations (batch_size, 512)
|
||||
video_embeddings: Encoded video representations (batch_size, num_frames, 512)
|
||||
state_features: Joint state features (batch_size, num_frames, state_dim)
|
||||
return_all_frames: If True, return rewards for all frames
|
||||
@@ -585,11 +555,10 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
logging.info("Checkpoint loaded successfully")
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
"""Set training mode. Note: CLIP and MiniLM encoders always stay in eval mode."""
|
||||
"""Set training mode. Note: CLIP encoder always stays in eval mode (frozen)."""
|
||||
super().train(mode)
|
||||
# Keep encoders in eval mode
|
||||
# Keep CLIP encoder in eval mode (frozen per SARM paper)
|
||||
self.clip_model.eval()
|
||||
self.minilm_model.eval()
|
||||
# Only transformer can be trained
|
||||
self.sarm_transformer.train(mode)
|
||||
return self
|
||||
@@ -618,30 +587,18 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
"""Required by PreTrainedPolicy but not used for SARM."""
|
||||
raise NotImplementedError("SARM model does not select actions")
|
||||
|
||||
def _get_remaining_length(self, observation: dict, idx: int) -> float | None:
|
||||
"""Extract remaining length for a sample from observation metadata."""
|
||||
remaining_lengths = observation.get('remaining_length')
|
||||
if remaining_lengths is None:
|
||||
return None
|
||||
if isinstance(remaining_lengths, torch.Tensor):
|
||||
return remaining_lengths[idx].item() if remaining_lengths.dim() > 0 else remaining_lengths.item()
|
||||
return remaining_lengths
|
||||
|
||||
def _compute_progress_targets(self, remaining_length: float | None, seq_len: int) -> torch.Tensor:
|
||||
"""Compute progress targets based on remaining trajectory length."""
|
||||
if remaining_length is not None and remaining_length > 0:
|
||||
return torch.arange(1, seq_len + 1, dtype=torch.float32, device=self.device) / remaining_length
|
||||
else:
|
||||
raise ValueError("Remaining length is None, but is required for progress targets")
|
||||
|
||||
def _apply_rewind_augmentation(
|
||||
def _apply_temporal_augmentation(
|
||||
self,
|
||||
video: torch.Tensor,
|
||||
progress: torch.Tensor,
|
||||
state: torch.Tensor | None,
|
||||
max_length: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
"""Apply rewind augmentation: append up to 4 reversed frames (SARM paper A.4)."""
|
||||
"""Apply temporal augmentation by appending reversed frames (SARM paper A.4).
|
||||
|
||||
This helps the model learn to handle non-monotonic progress (failures, recoveries).
|
||||
Appends 1-4 reversed frames to simulate going backwards in task progress.
|
||||
"""
|
||||
num_reverse = random.randint(1, min(4, max_length - 1))
|
||||
|
||||
# Reverse and take frames (skip first which is last of original)
|
||||
@@ -672,14 +629,20 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
"""
|
||||
Forward pass for SARM reward model training.
|
||||
|
||||
Uses annotation-based progress targets following SARM paper Eq. 2:
|
||||
yt = Pk-1 + α̅k × τt
|
||||
where:
|
||||
- τt = (t - sk) / (ek - sk) is within-subtask normalized time
|
||||
- Pk-1 is cumulative prior (sum of previous subtask proportions)
|
||||
- α̅k is the temporal proportion for subtask k
|
||||
|
||||
Args:
|
||||
batch: Dictionary with 'observation' containing:
|
||||
- 'video_features': (B, T, 512) pre-encoded video features
|
||||
- 'text_features': (B, 384) pre-encoded text features
|
||||
- 'text_features': (B, 512) pre-encoded text features (CLIP)
|
||||
- 'state_features': (B, T, state_dim) joint state features
|
||||
- 'remaining_length': (B,) remaining trajectory lengths (optional)
|
||||
- 'stage_labels': (B, T) stage labels (optional, from annotations)
|
||||
- 'progress_targets': (B, T, 1) progress targets (optional, from annotations)
|
||||
- 'stage_labels': (B, T) stage labels from annotations
|
||||
- 'progress_targets': (B, T, 1) progress targets from annotations
|
||||
|
||||
Returns:
|
||||
Tuple of (total_loss, output_dict with loss components)
|
||||
@@ -702,21 +665,31 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
if state_features is not None and state_features.dim() == 2:
|
||||
state_features = state_features.unsqueeze(1).expand(-1, max_length, -1)
|
||||
|
||||
# Process each sample: compute progress targets and apply rewind augmentation
|
||||
# Get annotation-based progress targets (required for SARM paper formula)
|
||||
progress_from_annotations = observation.get('progress_targets')
|
||||
if progress_from_annotations is None:
|
||||
raise ValueError("progress_targets from annotations is required for SARM training")
|
||||
|
||||
progress_from_annotations = progress_from_annotations.to(self.device)
|
||||
if progress_from_annotations.dim() == 2:
|
||||
progress_from_annotations = progress_from_annotations.unsqueeze(-1)
|
||||
if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1:
|
||||
progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1)
|
||||
|
||||
# Process each sample: apply temporal augmentation (SARM paper A.4)
|
||||
processed_videos = []
|
||||
processed_states = []
|
||||
progress_targets = []
|
||||
|
||||
for i in range(batch_size):
|
||||
remaining_length = self._get_remaining_length(observation, i)
|
||||
progress = self._compute_progress_targets(remaining_length, max_length)
|
||||
|
||||
video = video_features[i]
|
||||
state = state_features[i] if state_features is not None else None
|
||||
progress = progress_from_annotations[i].squeeze(-1) # (T,)
|
||||
|
||||
# Apply rewind augmentation with 50% probability (SARM paper)
|
||||
# Apply temporal augmentation with 50% probability (SARM paper A.4)
|
||||
# Appends up to 4 reversed frames to simulate failures/recoveries
|
||||
if random.random() < 0.5:
|
||||
video, progress, state = self._apply_rewind_augmentation(video, progress, state, max_length)
|
||||
video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length)
|
||||
|
||||
# Ensure correct sequence length
|
||||
video = self._ensure_sequence_length(video, max_length)
|
||||
@@ -739,32 +712,22 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
processed_videos, text_features, processed_states
|
||||
)
|
||||
|
||||
# Use annotation-based progress targets
|
||||
progress_from_annotations = observation.get('progress_targets')
|
||||
if progress_from_annotations is not None:
|
||||
progress_from_annotations = progress_from_annotations.to(self.device)
|
||||
if progress_from_annotations.dim() == 2:
|
||||
progress_from_annotations = progress_from_annotations.unsqueeze(-1)
|
||||
if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1:
|
||||
progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1)
|
||||
progress_targets = progress_from_annotations
|
||||
|
||||
# Compute progress loss
|
||||
# Compute progress loss (MSE)
|
||||
progress_loss = F.mse_loss(progress_preds, progress_targets)
|
||||
output_dict = {'progress_loss': progress_loss.item()}
|
||||
total_loss = progress_loss
|
||||
|
||||
# Compute stage loss if labels available
|
||||
# Compute stage loss (cross-entropy)
|
||||
stage_labels = observation.get('stage_labels')
|
||||
if stage_labels is not None:
|
||||
stage_labels = stage_labels.to(self.device)
|
||||
if stage_labels.dim() == 1:
|
||||
stage_labels = stage_labels.unsqueeze(0).expand(batch_size, -1)
|
||||
stage_loss = compute_stage_loss(stage_logits, stage_labels)
|
||||
total_loss = total_loss + self.config.stage_loss_weight * stage_loss
|
||||
output_dict['stage_loss'] = stage_loss.item()
|
||||
else:
|
||||
raise ValueError("Stage labels are None, but are required for stage loss")
|
||||
if stage_labels is None:
|
||||
raise ValueError("stage_labels from annotations is required for SARM training")
|
||||
|
||||
stage_labels = stage_labels.to(self.device)
|
||||
if stage_labels.dim() == 1:
|
||||
stage_labels = stage_labels.unsqueeze(0).expand(batch_size, -1)
|
||||
stage_loss = compute_stage_loss(stage_logits, stage_labels)
|
||||
total_loss = total_loss + self.config.stage_loss_weight * stage_loss
|
||||
output_dict['stage_loss'] = stage_loss.item()
|
||||
|
||||
# Misaligned loss: 20% probability (SARM paper - improve video-language alignment)
|
||||
if random.random() < 0.2:
|
||||
@@ -786,9 +749,3 @@ def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor)
|
||||
|
||||
loss = F.cross_entropy(stage_logits_flat, target_stages_flat)
|
||||
return loss
|
||||
|
||||
|
||||
def compute_progress_loss(progress_preds: torch.Tensor, target_progress: torch.Tensor) -> torch.Tensor:
|
||||
loss = F.mse_loss(progress_preds, target_progress)
|
||||
return loss
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
from transformers import AutoModel, AutoTokenizer, CLIPModel, CLIPProcessor
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.processor import (
|
||||
@@ -44,9 +44,12 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
"""
|
||||
ProcessorStep that encodes images and text for SARM training.
|
||||
|
||||
Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
|
||||
to process both RGB image sequences and task descriptions."
|
||||
|
||||
This step handles:
|
||||
- CLIP (image) encoding
|
||||
- MiniLM (text) encoding
|
||||
- CLIP image encoding (512-dim)
|
||||
- CLIP text encoding (512-dim)
|
||||
- Joint state normalization
|
||||
|
||||
Supports temporal sequences: (B, T, C, H, W) → (B, T, 512) video features
|
||||
@@ -76,28 +79,18 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
self._init_encoders()
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize CLIP and MiniLM encoders."""
|
||||
"""Initialize CLIP encoder for both images and text (per SARM paper A.4)."""
|
||||
device = torch.device(
|
||||
self.config.device if self.config.device
|
||||
else "cuda" if torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
|
||||
logging.info("Initializing CLIP encoder for SARM...")
|
||||
logging.info("Initializing CLIP encoder for SARM (images + text)...")
|
||||
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
|
||||
self.clip_model.to(device)
|
||||
self.clip_model.eval()
|
||||
|
||||
logging.info("Initializing MiniLM encoder for SARM...")
|
||||
self.minilm_tokenizer = AutoTokenizer.from_pretrained(
|
||||
"sentence-transformers/all-MiniLM-L12-v2"
|
||||
)
|
||||
self.minilm_model = AutoModel.from_pretrained(
|
||||
"sentence-transformers/all-MiniLM-L12-v2"
|
||||
)
|
||||
self.minilm_model.to(device)
|
||||
self.minilm_model.eval()
|
||||
|
||||
self.device = device
|
||||
|
||||
def _compute_temporal_proportions(self):
|
||||
@@ -167,11 +160,13 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
for name in self.subtask_names
|
||||
}
|
||||
else:
|
||||
# Equal proportions if no duration info
|
||||
self.temporal_proportions = {
|
||||
name: 1.0 / len(self.subtask_names)
|
||||
for name in self.subtask_names
|
||||
}
|
||||
raise ValueError(
|
||||
"Cannot compute temporal proportions: all subtask durations are zero. "
|
||||
"Check that your dataset has valid subtask annotations with start/end times."
|
||||
)
|
||||
|
||||
# Store in config for the model to use in progress output conversion (SARM paper Eq. 4)
|
||||
self.config.temporal_proportions = [self.temporal_proportions[name] for name in self.subtask_names]
|
||||
|
||||
logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}")
|
||||
|
||||
@@ -481,15 +476,9 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
|
||||
observation['state_features'] = torch.tensor(state_data, dtype=torch.float32)
|
||||
|
||||
# 3. Encode text with MiniLM
|
||||
# 3. Encode text with CLIP (per SARM paper A.4)
|
||||
batch_size = video_features.shape[0]
|
||||
task_descriptions = new_transition.get('task')
|
||||
if task_descriptions is not None:
|
||||
if isinstance(task_descriptions, str):
|
||||
task_descriptions = [task_descriptions] * batch_size
|
||||
observation['text_features'] = self._encode_text_batch_list(task_descriptions)
|
||||
else:
|
||||
observation['text_features'] = self._encode_text_batch(self.task_description, batch_size)
|
||||
observation['text_features'] = self._encode_text_clip(self.task_description, batch_size)
|
||||
|
||||
# 4. Extract frame/episode indices from complementary data
|
||||
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
@@ -609,54 +598,33 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
return all_embeddings
|
||||
|
||||
@torch.no_grad()
|
||||
def _encode_text_batch(self, text: str, batch_size: int) -> torch.Tensor:
|
||||
"""Encode a text string using MiniLM and replicate for batch.
|
||||
def _encode_text_clip(self, text: str, batch_size: int) -> torch.Tensor:
|
||||
"""Encode text using CLIP text encoder (per SARM paper A.4).
|
||||
|
||||
Args:
|
||||
text: Text string to encode
|
||||
text: Task description text to encode
|
||||
batch_size: Batch size to replicate for
|
||||
|
||||
Returns:
|
||||
Encoded feature vectors with shape (B, 384)
|
||||
Encoded text features with shape (B, 512)
|
||||
"""
|
||||
from lerobot.policies.rewind.modeling_rewind import mean_pooling
|
||||
# Use CLIP's tokenizer directly for text (avoids image processor validation issues)
|
||||
tokenizer = self.clip_processor.tokenizer
|
||||
inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True)
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
encoded_input = self.minilm_tokenizer(
|
||||
text, padding=True, truncation=True, return_tensors="pt"
|
||||
).to(self.device)
|
||||
# Get text features from CLIP
|
||||
text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
|
||||
|
||||
model_output = self.minilm_model(**encoded_input)
|
||||
text_embedding = mean_pooling(model_output, encoded_input["attention_mask"])
|
||||
text_embedding = text_embedding.squeeze().cpu()
|
||||
# Handle single text case
|
||||
if text_embedding.dim() == 1:
|
||||
text_embedding = text_embedding.unsqueeze(0)
|
||||
|
||||
# Replicate for batch (B, 384)
|
||||
text_embedding = text_embedding.unsqueeze(0).repeat(batch_size, 1)
|
||||
# Replicate for batch (B, 512)
|
||||
text_embedding = text_embedding.expand(batch_size, -1)
|
||||
|
||||
return text_embedding
|
||||
|
||||
@torch.no_grad()
|
||||
def _encode_text_batch_list(self, text_list: list[str]) -> torch.Tensor:
|
||||
"""Encode a list of text strings using MiniLM.
|
||||
|
||||
Args:
|
||||
text_list: List of text strings to encode
|
||||
|
||||
Returns:
|
||||
Encoded feature vectors with shape (B, 384)
|
||||
"""
|
||||
from lerobot.policies.rewind.modeling_rewind import mean_pooling
|
||||
|
||||
# Encode all texts in the batch at once
|
||||
encoded_input = self.minilm_tokenizer(
|
||||
text_list, padding=True, truncation=True, return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
model_output = self.minilm_model(**encoded_input)
|
||||
text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
|
||||
text_embeddings = text_embeddings.cpu()
|
||||
|
||||
return text_embeddings
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
@@ -688,9 +656,12 @@ def make_sarm_pre_post_processors(
|
||||
"""
|
||||
Create pre-processor and post-processor pipelines for SARM.
|
||||
|
||||
Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
|
||||
to process both RGB image sequences and task descriptions."
|
||||
|
||||
The pre-processing pipeline:
|
||||
1. Encodes images with CLIP (512-dim)
|
||||
2. Encodes text with MiniLM (384-dim)
|
||||
2. Encodes text with CLIP (512-dim)
|
||||
3. Normalizes joint states
|
||||
4. Adds batch dimension
|
||||
5. Moves data to device
|
||||
|
||||
@@ -229,8 +229,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# Only provide dataset_stats when not resuming from saved processor state
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
# For ReWiND and SARM, always provide dataset_meta for progress normalization
|
||||
if cfg.policy.type in ["rewind", "sarm"]:
|
||||
# For SARM, always provide dataset_meta for progress normalization
|
||||
if cfg.policy.type == "sarm":
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
if cfg.policy.pretrained_path is not None:
|
||||
@@ -319,20 +319,17 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
elif cfg.policy.type in ["rewind", "sarm"] and getattr(cfg.policy, "use_temporal_sampler", False):
|
||||
# Use temporal sequence sampler for loading sequences
|
||||
from lerobot.datasets.temporal_sampler import TemporalSequenceSampler
|
||||
elif cfg.policy.type == "sarm" and getattr(cfg.policy, "use_temporal_sampler", False):
|
||||
# Use SARM temporal sampler for reward model training
|
||||
from lerobot.datasets.temporal_sampler import SARMTemporalSampler
|
||||
|
||||
shuffle = False
|
||||
sampling_mode = getattr(cfg.policy, "sampling_mode", cfg.policy.type)
|
||||
sampler = TemporalSequenceSampler(
|
||||
sampler = SARMTemporalSampler(
|
||||
dataset_from_index=dataset.meta.episodes["dataset_from_index"],
|
||||
dataset_to_index=dataset.meta.episodes["dataset_to_index"],
|
||||
sequence_length=cfg.policy.max_length,
|
||||
stride=getattr(cfg.policy, "sequence_stride", 1) if cfg.policy.type == "rewind" else getattr(cfg.policy, "frame_gap", 30),
|
||||
frame_gap=getattr(cfg.policy, "frame_gap", 30),
|
||||
shuffle=True,
|
||||
seed=cfg.seed,
|
||||
sampling_mode=sampling_mode,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
|
||||
@@ -35,7 +35,7 @@ class RABCWeightComputer:
|
||||
and applies soft weighting based on progress deltas.
|
||||
|
||||
Args:
|
||||
reward_model: Pre-trained reward model (e.g., SARM, ReWiND)
|
||||
reward_model: Pre-trained reward model (e.g., SARM)
|
||||
kappa: Hard threshold for high-quality samples (default: 0.01)
|
||||
epsilon: Small constant for numerical stability (default: 1e-6)
|
||||
device: Device to run reward model on
|
||||
|
||||
Reference in New Issue
Block a user