Remove rewind, use clip tokenizer

This commit is contained in:
Pepijn
2025-11-26 21:06:20 +01:00
parent 425eced2de
commit 3ed0425d2c
17 changed files with 172 additions and 2668 deletions

View File

@@ -1,528 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Inference script for ReWiND Reward Model.
This script loads a trained ReWiND model and runs inference on a dataset episode,
generating visualizations of the predicted task progression over time.
Example usage:
python scripts/visualize_rewind_predictions.py \
--model-id username/rewind-model \
--dataset-repo lerobot/aloha_sim_insertion_human \
--episode-index 0 \
--output-dir outputs/rewind_viz \
--task-description "insert the peg into the socket"
"""
import argparse
import logging
from pathlib import Path
from typing import Optional
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import torch
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.rewind.modeling_rewind import ReWiNDRewardModel
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Run ReWiND inference and visualize predictions")
# Model arguments
parser.add_argument(
"--model-id",
type=str,
required=True,
help="HuggingFace model ID or local path to trained ReWiND model"
)
# Dataset arguments
parser.add_argument(
"--dataset-repo",
type=str,
required=True,
help="HuggingFace dataset repository ID (e.g., lerobot/aloha_sim_insertion_human)"
)
parser.add_argument(
"--episode-index",
type=int,
default=0,
help="Index of the episode to visualize (default: 0)"
)
parser.add_argument(
"--task-description",
type=str,
default="perform the task",
help="Task description for the reward model (default: 'perform the task')"
)
# Output arguments
parser.add_argument(
"--output-dir",
type=str,
default="outputs/rewind_inference",
help="Directory to save visualization outputs (default: outputs/rewind_inference)"
)
parser.add_argument(
"--image-key",
type=str,
default=None,
help="Key for images in dataset (e.g., observation.images.image for jaco_play). If not specified, uses model config's image_key"
)
# Visualization options
parser.add_argument(
"--show-frames",
action="store_true",
help="Include sample frames in the visualization"
)
parser.add_argument(
"--num-sample-frames",
type=int,
default=8,
help="Number of sample frames to show (default: 8)"
)
parser.add_argument(
"--figsize",
type=int,
nargs=2,
default=[12, 6],
help="Figure size as width height (default: 12 6)"
)
# Device
parser.add_argument(
"--device",
type=str,
default=None,
help="Device to run inference on (cuda/cpu, default: auto-detect)"
)
return parser.parse_args()
def load_episode_data(
dataset: LeRobotDataset,
episode_index: int,
image_key: str
) -> tuple[np.ndarray, int, int, str]:
"""
Load all frames from a specific episode.
Args:
dataset: LeRobotDataset instance
episode_index: Index of the episode to load
image_key: Key for accessing images in the dataset
Returns:
Tuple of (frames, start_index, end_index, task_description)
"""
# Get episode boundaries
episode_data = dataset.meta.episodes
start_idx = episode_data["dataset_from_index"][episode_index]
end_idx = episode_data["dataset_to_index"][episode_index]
logger.info(f"Loading episode {episode_index}: frames {start_idx} to {end_idx} ({end_idx - start_idx} frames)")
# Get task description from the dataset if available
task_description = None
first_item = dataset[start_idx]
if "task" in first_item:
task_description = first_item["task"]
print(f"✓ Extracted task from episode {episode_index}: '{task_description}'")
# Load all frames from the episode
frames = []
for idx in tqdm(range(start_idx, end_idx), desc="Loading frames"):
item = dataset[idx]
# Get image from the item
img = item[image_key]
# Convert to numpy if needed
if isinstance(img, torch.Tensor):
img = img.cpu().numpy()
# Handle different image formats (C, H, W) or (H, W, C)
if img.shape[0] in [1, 3]: # Channel first
img = np.transpose(img, (1, 2, 0))
# Convert to uint8 if needed
if img.dtype != np.uint8:
if img.max() <= 1.0:
img = (img * 255).astype(np.uint8)
else:
img = img.astype(np.uint8)
frames.append(img)
frames = np.array(frames)
logger.info(f"Loaded {len(frames)} frames with shape {frames[0].shape}")
return frames, start_idx, end_idx, task_description
@torch.no_grad()
def run_inference(
model: ReWiNDRewardModel,
frames: np.ndarray,
task_description: str,
batch_size: int = 32
) -> tuple[np.ndarray, np.ndarray]:
"""
Run ReWiND inference on video frames using the original ReWiND approach.
This function creates video slices for all frames at once (similar to the original
metaworld_label_reward.py), where each slice contains frames from start up to that point.
Progress Normalization (from original ReWiND dataset.py):
- Training: progress = [1, 2, ..., N] / remaining_length
where remaining_length = episode_end - sequence_start
- Inference: Starting from frame 0, remaining_length = total_episode_length
So expected progress for frame i = (i + 1) / total_episode_length
This function computes both:
1. Model predictions (what the model actually predicts)
2. Expected progress (ground truth based on frame position)
Args:
model: ReWiND model
frames: Video frames (num_frames, H, W, C)
task_description: Task description text
batch_size: Batch size for processing slices
Returns:
Tuple of:
- Model predictions for each frame (num_frames,)
- Expected progress for each frame (num_frames,)
"""
total_frames = len(frames)
logger.info("Encoding video frames with DINO...")
video_embeddings = model.encode_images(frames)
logger.info("Encoding task description with MiniLM...")
text_embedding = model.encode_text(task_description)
logger.info("Creating video slices (original ReWiND approach)...")
# Convert to tensors
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
# Create video slices: for each frame i, create a sequence of frames [0:i+1]
# This matches the original ReWiND inference approach
video_slices = []
for i in tqdm(range(len(video_embeddings)), desc="Creating slices"):
# Slice from start to current frame (inclusive)
video_slice = video_embeddings[:i + 1]
# Pad or subsample to max_length
if model.config.subsample_video:
video_slice = model.padding_video(video_slice, model.config.max_length)
video_slices.append(video_slice)
video_slices = torch.stack(video_slices) # (num_frames, max_length, 768)
# Create last_index_mask to extract the relevant prediction for each slice
# For slice i, the last valid frame is at position min(i, max_length-1)
max_length = model.config.max_length
last_index_mask = torch.zeros((len(video_slices), max_length), dtype=torch.bool)
for i in range(len(video_slices)):
last_frame_idx = min(i, max_length - 1)
last_index_mask[i, last_frame_idx] = 1
logger.info("Running ReWiND inference on all slices...")
# Process in batches
all_progress = []
for i in tqdm(range(0, len(video_slices), batch_size), desc="Inference"):
batch_video = video_slices[i:i + batch_size].to(model.device)
batch_mask = last_index_mask[i:i + batch_size].to(model.device)
batch_size_actual = batch_video.shape[0]
# Replicate text embedding for batch
batch_text = text_embedding.unsqueeze(0).repeat(batch_size_actual, 1).to(model.device)
# Get predictions for all frames in batch
progress_preds = model.rewind_transformer(batch_video, batch_text) # (batch, max_length, 1)
progress_preds = progress_preds.squeeze(-1) # (batch, max_length)
# Extract predictions using the last_index_mask
# This gets the prediction for the last valid frame in each slice
batch_progress = progress_preds[batch_mask].cpu().numpy()
all_progress.extend(batch_progress)
predictions = np.array(all_progress)
# Compute expected progress based on original ReWiND normalization
# When starting from frame 0, remaining_length = total_episode_length
# Expected progress for frame i = (i + 1) / total_frames
expected_progress = np.arange(1, total_frames + 1, dtype=np.float32) / total_frames
logger.info(f"Inference complete. Predicted progress range: [{predictions.min():.3f}, {predictions.max():.3f}]")
logger.info(f"Expected progress range: [{expected_progress.min():.3f}, {expected_progress.max():.3f}]")
return predictions, expected_progress
def visualize_predictions(
frames: np.ndarray,
predictions: np.ndarray,
expected_progress: np.ndarray,
task_description: str,
output_path: Path,
show_frames: bool = False,
num_sample_frames: int = 8,
figsize: tuple = (12, 6)
):
"""
Create visualization of ReWiND predictions with expected progress comparison.
Args:
frames: Video frames (num_frames, H, W, C)
predictions: Model progress predictions (num_frames,)
expected_progress: Expected progress based on frame position (num_frames,)
task_description: Task description
output_path: Path to save the figure
show_frames: Whether to include sample frames
num_sample_frames: Number of frames to show
figsize: Figure size (width, height)
"""
if show_frames:
# Create figure with progress plot and sample frames
fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
gs = gridspec.GridSpec(2, 1, height_ratios=[2, 1], hspace=0.3)
# Progress plot
ax_progress = fig.add_subplot(gs[0])
else:
# Just progress plot
fig, ax_progress = plt.subplots(1, 1, figsize=figsize)
# Plot progress over time
frame_indices = np.arange(len(predictions))
# Plot expected progress (ground truth)
ax_progress.plot(frame_indices, expected_progress, linewidth=2, color='#A8DADC',
linestyle='--', label='Expected Progress (Linear)', alpha=0.7)
# Plot model predictions
ax_progress.plot(frame_indices, predictions, linewidth=2.5, color='#2E86AB',
label='Model Predictions')
ax_progress.fill_between(frame_indices, 0, predictions, alpha=0.2, color='#2E86AB')
# Add reference line at 1.0
ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
# Styling
ax_progress.set_xlabel('Frame Index', fontsize=12)
ax_progress.set_ylabel('Task Progress', fontsize=12)
ax_progress.set_title(f'ReWiND Task Progress Prediction\nTask: "{task_description}"',
fontsize=14, fontweight='bold')
ax_progress.grid(True, alpha=0.3)
ax_progress.set_ylim(-0.05, 1.1)
ax_progress.legend(loc='upper left')
# Compute alignment metrics
mae = np.mean(np.abs(predictions - expected_progress))
rmse = np.sqrt(np.mean((predictions - expected_progress) ** 2))
# Add statistics box
stats_text = (
f'Frames: {len(predictions)}\n'
f'Model Final: {predictions[-1]:.3f}\n'
f'Model Max: {predictions.max():.3f}\n'
f'Model Mean: {predictions.mean():.3f}\n'
f'MAE: {mae:.3f}\n'
f'RMSE: {rmse:.3f}'
)
ax_progress.text(0.98, 0.02, stats_text, transform=ax_progress.transAxes,
fontsize=10, verticalalignment='bottom', horizontalalignment='right',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
# Show sample frames if requested
if show_frames:
# Select evenly spaced frames
frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int)
# Create subplot for frames
ax_frames = fig.add_subplot(gs[1])
ax_frames.axis('off')
# Create grid for frames
frame_height = frames[0].shape[0]
frame_width = frames[0].shape[1]
combined_width = frame_width * num_sample_frames
combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8)
for i, frame_idx in enumerate(frame_indices_to_show):
frame = frames[frame_idx]
if frame.shape[-1] == 1:
frame = np.repeat(frame, 3, axis=-1)
# Add frame to combined image
x_start = i * frame_width
x_end = (i + 1) * frame_width
combined_image[:, x_start:x_end] = frame
# Add frame number and progress value
progress_val = predictions[frame_idx]
label = f'Frame {frame_idx}\nProgress: {progress_val:.3f}'
# Draw label on image
ax_frames.text(x_start + frame_width / 2, -10, label,
ha='center', va='top', fontsize=8,
bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
ax_frames.imshow(combined_image)
ax_frames.set_title('Sample Frames', fontsize=12, pad=20)
# Save figure
plt.tight_layout()
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=150, bbox_inches='tight')
logger.info(f"Saved visualization to {output_path}")
plt.close()
def main():
args = parse_args()
# Setup device
if args.device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
else:
device = args.device
logger.info(f"Using device: {device}")
# Load model
logger.info(f"Loading ReWiND model from {args.model_id}...")
model = ReWiNDRewardModel.from_pretrained(args.model_id)
model.to(device)
model.eval()
logger.info("Model loaded successfully")
# Load dataset
logger.info(f"Loading dataset {args.dataset_repo}...")
dataset = LeRobotDataset(args.dataset_repo)
logger.info(f"Dataset loaded: {len(dataset.meta.episodes)} episodes, {len(dataset)} frames")
# Validate episode index
if args.episode_index >= len(dataset.meta.episodes):
raise ValueError(
f"Episode index {args.episode_index} out of range. "
f"Dataset has {len(dataset.meta.episodes)} episodes."
)
# Determine which image key to use
image_key = args.image_key if args.image_key is not None else model.config.image_key
logger.info(f"Using image key: {image_key}")
# Load episode data (this also extracts the task description from the episode)
frames, start_idx, end_idx, dataset_task = load_episode_data(dataset, args.episode_index, image_key)
# Use task description from dataset if available, otherwise use command-line argument
task_description = dataset_task if dataset_task is not None else args.task_description
logger.info(f"Using task description: '{task_description}'")
# Run inference
predictions, expected_progress = run_inference(model, frames, task_description)
# Create visualization
output_dir = Path(args.output_dir)
output_path = output_dir / f"rewind_prediction_ep{args.episode_index}.png"
visualize_predictions(
frames,
predictions,
expected_progress,
task_description,
output_path,
show_frames=args.show_frames,
num_sample_frames=args.num_sample_frames,
figsize=tuple(args.figsize)
)
# Save predictions and expected progress as numpy arrays
predictions_path = output_dir / f"predictions_ep{args.episode_index}.npy"
expected_path = output_dir / f"expected_progress_ep{args.episode_index}.npy"
np.save(predictions_path, predictions)
np.save(expected_path, expected_progress)
logger.info(f"Saved predictions array to {predictions_path}")
logger.info(f"Saved expected progress to {expected_path}")
# Compute alignment metrics
mae = np.mean(np.abs(predictions - expected_progress))
rmse = np.sqrt(np.mean((predictions - expected_progress) ** 2))
correlation = np.corrcoef(predictions, expected_progress)[0, 1]
# Print summary
logger.info("\n" + "="*60)
logger.info("INFERENCE SUMMARY")
logger.info("="*60)
logger.info(f"Model: {args.model_id}")
logger.info(f"Dataset: {args.dataset_repo}")
logger.info(f"Episode: {args.episode_index}")
logger.info(f"Task: {task_description}")
logger.info(f"Frames: {len(frames)}")
logger.info(f"\nModel Predictions:")
logger.info(f" Final: {predictions[-1]:.3f}")
logger.info(f" Max: {predictions.max():.3f}")
logger.info(f" Mean: {predictions.mean():.3f}")
logger.info(f" Std: {predictions.std():.3f}")
logger.info(f"\nExpected Progress (Linear):")
logger.info(f" Final: {expected_progress[-1]:.3f}")
logger.info(f" Mean: {expected_progress.mean():.3f}")
logger.info(f"\nAlignment Metrics:")
logger.info(f" MAE: {mae:.3f}")
logger.info(f" RMSE: {rmse:.3f}")
logger.info(f" Correlation: {correlation:.3f}")
logger.info(f"\nOutput:")
logger.info(f" Visualization: {output_path}")
logger.info("="*60)
# Diagnostic warnings
if predictions.std() < 0.05:
logger.warning("\n⚠ WARNING: Mode collapse detected (std < 0.05)")
logger.warning(" Model predictions show very low variance.")
logger.warning(" This indicates the model was likely trained with incorrect")
logger.warning(" progress normalization (absolute indices instead of remaining length).")
elif mae > 0.3:
logger.warning("\n⚠ WARNING: High prediction error (MAE > 0.3)")
logger.warning(" Model predictions deviate significantly from expected linear progress.")
logger.warning(" Consider retraining with correct progress normalization.")
elif correlation < 0.5:
logger.warning("\n⚠ WARNING: Low correlation with expected progress (< 0.5)")
logger.warning(" Model predictions don't align well with linear task progression.")
else:
logger.info("\n✓ Model predictions show healthy progression!")
if __name__ == "__main__":
main()

View File

@@ -244,7 +244,7 @@ def run_inference(
logger.info("Encoding video frames with CLIP...")
video_embeddings = model.encode_images(frames)
logger.info("Encoding task description with MiniLM...")
logger.info("Encoding task description with CLIP...")
text_embedding = model.encode_text(task_description)
# Get config values

View File

@@ -1,128 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ReWiND Sampler for temporal sequence loading.
"""
import logging
from typing import Iterator, Optional
import numpy as np
import torch
from torch.utils.data import Sampler
import random
class ReWiNDTemporalSampler(Sampler):
"""
Sampler for ReWiND that samples random temporal windows from episodes.
Matches original ReWiND sampling:
- Samples random start and end points within episodes
- Minimum window size of 3 frames
- Can sample from beginning, middle, or end of episodes
Args:
dataset_from_index: Start indices of episodes
dataset_to_index: End indices of episodes
sequence_length: Maximum sequence length (for padding/subsampling)
stride: Not used (kept for API compatibility)
shuffle: Whether to shuffle sampling order
seed: Random seed
"""
def __init__(
self,
dataset_from_index: np.ndarray,
dataset_to_index: np.ndarray,
sequence_length: int = 32,
stride: int = 1,
shuffle: bool = True,
seed: Optional[int] = None,
):
self.dataset_from_index = np.array(dataset_from_index)
self.dataset_to_index = np.array(dataset_to_index)
self.sequence_length = sequence_length
self.shuffle = shuffle
if seed is not None:
self.seed = seed
random.seed(seed)
np.random.seed(seed)
self.generator = torch.Generator().manual_seed(seed)
else:
self.generator = torch.Generator()
# Compute valid episodes (those with at least 3 frames)
self._compute_valid_episodes()
# Number of samples per epoch (matching original ReWiND)
self.samples_per_epoch = 100 * 64 # 100 batches of 64
logging.info(
f"ReWiNDTemporalSampler: {len(self.valid_episodes)} valid episodes, "
f"{self.samples_per_epoch} samples per epoch"
)
def _compute_valid_episodes(self):
"""Compute valid episodes (those with at least 3 frames)."""
self.valid_episodes = []
for ep_idx in range(len(self.dataset_from_index)):
ep_start = self.dataset_from_index[ep_idx]
ep_end = self.dataset_to_index[ep_idx]
episode_length = ep_end - ep_start
if episode_length >= 3: # Minimum 3 frames
self.valid_episodes.append((ep_idx, ep_start, ep_end))
self.valid_episodes = np.array(self.valid_episodes)
def __len__(self) -> int:
return self.samples_per_epoch
def __iter__(self) -> Iterator[int]:
"""
Yields ONE index per sample (the end of a random window).
Matches original ReWiND behavior:
1. Pick random episode
2. Pick random end frame (at least 3 frames from start)
3. Yield that end frame index
4. Dataset/processor loads from episode start to this end frame
5. Model pads/subsamples to sequence_length (32)
This allows sampling from anywhere in episodes:
- Early frames → short sequences (mostly padding) → low progress
- Middle frames → medium sequences (some subsampling) → medium progress
- End frames → long sequences (full subsampling) → high progress approaching 1.0
"""
for _ in range(self.samples_per_epoch):
# Randomly select an episode
ep_idx, ep_start, ep_end = self.valid_episodes[
np.random.randint(0, len(self.valid_episodes))
]
episode_length = ep_end - ep_start
# Sample a random end point (must be at least 3 frames from start)
# This matches original: random.randint(start_idx+3, len(progress_dataset))
end_offset = np.random.randint(3, episode_length + 1)
end_idx = ep_start + end_offset
# Yield ONLY the end index
# The dataset will load all frames from ep_start to end_idx
yield int(end_idx - 1) # -1 because end_idx is exclusive

View File

@@ -15,12 +15,10 @@
# limitations under the License.
"""
Temporal Sequence Sampler for reward models and temporal policies.
SARM Temporal Sampler for reward model training.
Supports multiple sampling modes:
- "rewind": ReWiND-style sampling (random windows from episode start)
- "sarm": SARM-style sampling (9-frame sequences with specific pattern)
- "custom": Custom temporal sampling
Samples frames from episodes ensuring sufficient temporal history for SARM's
9-frame pattern (1 initial + 8 consecutive with frame_gap spacing).
"""
import logging
@@ -31,24 +29,23 @@ from torch.utils.data import Sampler
import random
class TemporalSequenceSampler(Sampler):
class SARMTemporalSampler(Sampler):
"""
Generalized temporal sampler for reward models.
Temporal sampler for SARM reward model training.
Supports multiple sampling modes:
- "rewind": Consecutive frames from episode start to random end point (ReWiND: 32 consecutive frames)
- "sarm": 9-frame sequences with 1 initial + 8 consecutive (SARM)
- "custom": Custom temporal sampling
SARM uses 9 frames per sample:
- Frame 0: Initial frame of the episode (always frame 0)
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame
This sampler ensures we only sample from positions that have enough
temporal history (at least 7 * frame_gap frames from episode start).
Args:
dataset_from_index: Start indices of episodes
dataset_to_index: End indices of episodes
sequence_length: Maximum sequence length (for padding/subsampling)
stride: Frame stride for consecutive sampling (SARM mode)
dataset_from_index: Start indices of episodes (global dataset indices)
dataset_to_index: End indices of episodes (global dataset indices)
frame_gap: Gap between consecutive frames (default: 30 = 1 second at 30fps)
shuffle: Whether to shuffle sampling order
seed: Random seed
sampling_mode: Sampling mode ("rewind", "sarm", or "custom")
min_frames: Minimum frames per episode (default: 3)
seed: Random seed for reproducibility
samples_per_epoch: Number of samples per epoch (default: 6400)
"""
@@ -56,25 +53,21 @@ class TemporalSequenceSampler(Sampler):
self,
dataset_from_index: np.ndarray,
dataset_to_index: np.ndarray,
sequence_length: int = 32,
stride: int = 1,
frame_gap: int = 30,
shuffle: bool = True,
seed: Optional[int] = None,
sampling_mode: str = "rewind",
min_frames: int = 3,
samples_per_epoch: int = 6400,
):
self.dataset_from_index = np.array(dataset_from_index)
self.dataset_to_index = np.array(dataset_to_index)
self.sequence_length = sequence_length
self.stride = stride
self.frame_gap = frame_gap
self.shuffle = shuffle
self.sampling_mode = sampling_mode
self.min_frames = min_frames
self.samples_per_epoch = samples_per_epoch
if sampling_mode not in ["rewind", "sarm", "custom"]:
raise ValueError(f"sampling_mode must be 'rewind', 'sarm', or 'custom', got {sampling_mode}")
# Minimum frames needed for SARM pattern:
# 8 consecutive frames with frame_gap spacing = 7 * frame_gap + 1
# (Plus the initial frame which is always available)
self.min_frames_needed = 7 * frame_gap + 1
if seed is not None:
self.seed = seed
@@ -84,98 +77,68 @@ class TemporalSequenceSampler(Sampler):
else:
self.generator = torch.Generator()
# Compute valid episodes
self._compute_valid_episodes()
# Compute valid episodes and sampling positions
self._compute_valid_positions()
logging.info(
f"TemporalSequenceSampler ({sampling_mode} mode): "
f"{len(self.valid_episodes)} valid episodes, "
f"{self.samples_per_epoch} samples per epoch"
f"SARMTemporalSampler: {len(self.valid_episodes)} valid episodes, "
f"{len(self.all_valid_positions)} valid positions, "
f"{self.samples_per_epoch} samples per epoch, "
f"frame_gap={frame_gap}"
)
def _compute_valid_episodes(self):
"""Compute valid episodes based on minimum frame requirement."""
def _compute_valid_positions(self):
"""Compute valid episodes and all valid sampling positions."""
self.valid_episodes = []
self.all_valid_positions = []
for ep_idx in range(len(self.dataset_from_index)):
ep_start = self.dataset_from_index[ep_idx]
ep_end = self.dataset_to_index[ep_idx]
episode_length = ep_end - ep_start
# For SARM mode, need enough frames for the sequence pattern
if self.sampling_mode == "sarm":
# Need at least sequence_length * stride frames
min_required = self.sequence_length * self.stride
if episode_length >= min_required:
self.valid_episodes.append((ep_idx, ep_start, ep_end))
else:
# For rewind mode, use min_frames
if episode_length >= self.min_frames:
self.valid_episodes.append((ep_idx, ep_start, ep_end))
# Episode must have enough frames for SARM pattern
if episode_length >= self.min_frames_needed:
self.valid_episodes.append((ep_idx, ep_start, ep_end))
# Valid positions: from min_frames_needed to episode end
# These are global dataset indices
for pos in range(ep_start + self.min_frames_needed - 1, ep_end):
self.all_valid_positions.append(pos)
self.valid_episodes = np.array(self.valid_episodes)
self.all_valid_positions = np.array(self.all_valid_positions)
if len(self.all_valid_positions) == 0:
raise ValueError(
f"No valid sampling positions found! "
f"Episodes need at least {self.min_frames_needed} frames "
f"(7 * frame_gap + 1 = 7 * {self.frame_gap} + 1)."
)
def __len__(self) -> int:
return self.samples_per_epoch
def __iter__(self) -> Iterator[int]:
"""
Yields ONE index per sample.
Yields global dataset indices for sampling.
Sampling behavior depends on mode:
ReWiND mode:
1. Pick random episode
2. Pick random end frame (at least min_frames from start)
3. Yield that end frame index
4. Dataset loads from episode start to this end frame
SARM mode:
1. Pick random episode
2. Pick random end frame (must allow sequence_length frames with stride)
3. Yield that end frame index
4. Dataset loads sequence_length frames with stride spacing ending at this frame
Each yielded index represents the "current frame" position.
The dataset's observation_delta_indices then handles loading:
- Frame 0: Episode initial frame (via large negative delta clamping)
- Frames 1-8: Consecutive frames ending at the yielded index
"""
for _ in range(self.samples_per_epoch):
# Randomly select an episode
ep_idx, ep_start, ep_end = self.valid_episodes[
np.random.randint(0, len(self.valid_episodes))
]
episode_length = ep_end - ep_start
if self.sampling_mode == "rewind":
# ReWiND: Sample random end point (at least min_frames from start)
end_offset = np.random.randint(self.min_frames, episode_length + 1)
end_idx = ep_start + end_offset
# Yield the end index (dataset will load from start to this point)
yield int(end_idx - 1) # -1 because end_idx is exclusive
elif self.sampling_mode == "sarm":
# SARM: Sample end point that allows full sequence
# We need sequence_length frames with stride spacing
min_end_offset = self.sequence_length * self.stride
if episode_length >= min_end_offset:
# Can sample anywhere from min_end_offset to episode_length
end_offset = np.random.randint(min_end_offset, episode_length + 1)
else:
# Episode is exactly the minimum length
end_offset = episode_length
end_idx = ep_start + end_offset
# Yield the end index (dataset will load sequence with stride)
yield int(end_idx - 1) # -1 because end_idx is exclusive
else: # custom mode
# Default to rewind-style sampling
end_offset = np.random.randint(self.min_frames, episode_length + 1)
end_idx = ep_start + end_offset
yield int(end_idx - 1)
if self.shuffle:
# Randomly sample from all valid positions
for _ in range(self.samples_per_epoch):
idx = np.random.randint(0, len(self.all_valid_positions))
yield int(self.all_valid_positions[idx])
else:
# Sequential sampling with wrap-around
for i in range(self.samples_per_epoch):
idx = i % len(self.all_valid_positions)
yield int(self.all_valid_positions[idx])
# Backwards compatibility alias
ReWiNDTemporalSampler = TemporalSequenceSampler
TemporalSequenceSampler = SARMTemporalSampler

View File

@@ -1,383 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Video sampling utilities for temporal data augmentation and frame selection.
This module provides utilities for sampling and augmenting video sequences, particularly
for reward model training. It includes functions for:
- Padding/sampling videos to fixed lengths
- Video rewind augmentation for learning to decrease rewards
"""
import random
from typing import Tuple
import numpy as np
import torch
def sample_video_feature(
video_feature: torch.Tensor,
max_length: int = 32,
random_sample: bool = True,
remaining_length: int = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Sample or pad video features to a fixed length with progress targets.
Progress normalization matches original ReWiND implementation:
- Progress = (position_in_sequence + 1) / remaining_trajectory_length
- remaining_trajectory_length = frames from first sampled frame to episode end
Original ReWiND logic (dataset.py lines 12493-12499):
video_frames = frames[start_idx:end_idx]
full_frames = frames[start_idx:] # All frames from start to episode end
progress = [1, 2, ..., len(video_frames)] / len(full_frames)
This ensures all sequences show increasing progress from near-zero, regardless
of where they're sampled from in the episode.
Uses original ReWiND sampling: random start/end points with minimum 3 frames.
Args:
video_feature: Video features tensor (num_frames, feature_dim)
max_length: Target sequence length
random_sample: If True, randomly sample frames. If False, uniformly sample consecutive frames.
remaining_length: Remaining trajectory length from first frame to episode end
Returns:
Tuple of:
- Sampled/padded video features (max_length, feature_dim)
- Progress targets for each frame (max_length,)
"""
video_length = len(video_feature)
# Original ReWiND sampling: random start/end with minimum 3 frames
if video_length > 3:
# Sample random start index (ensuring we can get at least 3 frames)
start_idx = random.randint(0, max(0, video_length - 3))
# Sample random end index (at least 3 frames after start, up to video_length)
end_idx = random.randint(min(start_idx + 3, video_length), video_length)
# Extract the sampled segment
video_feature = video_feature[start_idx:end_idx]
# Update video_length for the sampled segment
video_length = len(video_feature)
# Adjust remaining_length to be from start_idx to episode end
if remaining_length is not None:
# The remaining length should be from start_idx to episode end
# If we started at start_idx, we've already consumed start_idx frames
remaining_length = remaining_length - start_idx if remaining_length > start_idx else video_length
# Generate progress targets using ORIGINAL ReWiND formula
# Progress = (position_in_sequence + 1) / remaining_trajectory_length
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
progress_targets = progress_indices / remaining_length
if video_length < max_length:
# Pad with last frame
padding_length = max_length - video_length
last_frame = video_feature[-1].unsqueeze(0)
padding_frames = last_frame.repeat(padding_length, 1)
video_feature = torch.cat([video_feature, padding_frames], dim=0)
# Pad progress with last progress value
last_progress = progress_targets[-1]
padding_progress = torch.full((padding_length,), last_progress)
progress_targets = torch.cat([progress_targets, padding_progress])
elif video_length > max_length:
if random_sample:
# Random sampling (maintains temporal order via sorted indices)
frame_idx = sorted(random.sample(range(video_length), max_length))
else:
# Uniform sampling (consecutive frames with even spacing)
frame_idx = np.linspace(0, video_length - 1, max_length, dtype=int)
video_feature = video_feature[frame_idx]
progress_targets = progress_targets[frame_idx]
return video_feature, progress_targets
def sample_reverse_video_feature(
video_feature: torch.Tensor,
max_length: int = 32,
random_sample: bool = True,
remaining_length: int = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Sample video with reverse augmentation (video rewind) - ORIGINAL REWIND LOGIC.
This implements the EXACT video rewind augmentation from the original ReWiND paper:
1. Take forward sequence (sampled with random start/end, min 3 frames)
2. Append reversed frames from the END backwards
3. Progress increases then decreases (simulating task completion then failure)
Progress normalization matches original ReWiND (same as sample_video_feature).
Original ReWiND logic (dataset.py lines 12526-12541):
progress = [1, 2, ..., len(video_frames)] / len(full_frames)
reverse_progress = progress[::-1][1:selected_end_point]
Args:
video_feature: Video features tensor (num_frames, feature_dim)
max_length: Target sequence length
random_sample: If True, use random sampling for frame selection
remaining_length: Remaining trajectory length from first frame to episode end
Returns:
Tuple of:
- Rewound video features (max_length, feature_dim)
- Progress targets for each frame (max_length,)
"""
video_length = len(video_feature)
# Original logic: start from first half, end in second half, ensure min 3 frames
if video_length > 3:
# Sample start from first half
start_idx = random.randint(0, video_length // 2)
# Sample end from second half
end_idx = random.randint(video_length // 2, video_length)
# Ensure minimum 3 frames difference (original uses while loop)
while end_idx - start_idx < 3:
start_idx = random.randint(0, video_length // 2)
end_idx = random.randint(video_length // 2, video_length)
# Extract the forward segment
video_feature = video_feature[start_idx:end_idx]
video_length = len(video_feature)
# Adjust remaining_length
if remaining_length is not None:
remaining_length = remaining_length - start_idx if remaining_length > start_idx else video_length
# Generate forward progress targets using ORIGINAL ReWiND formula
# Progress = (position_in_sequence + 1) / remaining_trajectory_length
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
forward_progress = progress_indices / remaining_length
# ORIGINAL LOGIC: Reverse from END backwards, then append to forward sequence
# Example: video=[A,B,C,D,E] -> reversed=[E,D,C,B,A] -> take some from reversed (skip first)
# Result: [A,B,C,D,E] + [D,C,B] = progress increases then decreases
# Randomly select how many frames to reverse and append
selected_end_point = random.randint(2, min(video_length, max_length))
# Reverse the entire video and its progress
reversed_video = video_feature.flip(0)
reversed_progress = forward_progress.flip(0)
# Take frames from reversed (skip the first frame which is the last frame of original)
reverse_frames = reversed_video[1:selected_end_point]
reverse_progress = reversed_progress[1:selected_end_point]
# Concatenate forward + reversed (creates rewind effect)
rewound_video = torch.cat([video_feature, reverse_frames], dim=0)
progress_targets = torch.cat([forward_progress, reverse_progress], dim=0)
# Pad or sample to target length
if len(rewound_video) < max_length:
# Pad with last frame
padding_length = max_length - len(rewound_video)
last_frame = rewound_video[-1].unsqueeze(0)
padding_frames = last_frame.repeat(padding_length, 1)
rewound_video = torch.cat([rewound_video, padding_frames], dim=0)
# Pad progress with last progress value
last_progress = progress_targets[-1]
padding_progress = torch.full((padding_length,), last_progress)
progress_targets = torch.cat([progress_targets, padding_progress])
elif len(rewound_video) > max_length:
# Sample frames to fit max_length
if random_sample:
frame_idx = sorted(random.sample(range(len(rewound_video)), max_length))
else:
frame_idx = np.linspace(0, len(rewound_video) - 1, max_length, dtype=int)
rewound_video = rewound_video[frame_idx]
progress_targets = progress_targets[frame_idx]
return rewound_video, progress_targets
def sample_sarm_video_feature(
video_feature: torch.Tensor,
num_frames: int = 9,
frame_gap: int = 30,
random_sample: bool = True,
absolute_indices: torch.Tensor = None,
episode_length: int = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Sample video features for SARM (Stage-Aware Reward Modeling).
SARM uses a specific pattern:
- 1 initial frame (from episode start)
- 8 consecutive frames with frame_gap spacing
Progress normalization matches SARM implementation:
- Progress = absolute_frame_index / total_episode_length
Args:
video_feature: Video features tensor (num_frames_available, feature_dim)
num_frames: Target number of frames (default: 9)
frame_gap: Gap between consecutive frames (default: 30, i.e., 1 second at 30fps)
random_sample: If True, use random sampling (not used for SARM's fixed pattern)
absolute_indices: Absolute frame indices in the episode (num_frames_available,)
episode_length: Total length of the episode
Returns:
Tuple of:
- Sampled video features (num_frames, feature_dim)
- Progress targets for each frame (num_frames,)
"""
video_length = len(video_feature)
# Generate progress targets based on relative position within sampled sequence
# Note: SARM paper uses subtask annotations (Equation 2: yt = Pk1 + ᾱk * τt)
# Without annotations, we use linear progress relative to sequence position
if absolute_indices is not None and episode_length is not None:
# Compute relative progress: position within sequence / remaining trajectory
# This ensures progress starts near 0 and increases, not starting at 0.8 if sampled from end
first_frame_idx = absolute_indices[0].item() if isinstance(absolute_indices[0], torch.Tensor) else absolute_indices[0]
remaining_length = episode_length - first_frame_idx
# Progress = (position_in_sequence + 1) / remaining_trajectory_length
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
progress_targets = progress_indices / remaining_length
else:
# Fallback: linear progress
progress_targets = torch.linspace(1.0/video_length, 1.0, video_length)
# SARM pattern: first frame + (num_frames-1) consecutive frames with frame_gap
# The first frame should be from the beginning of the sequence
# The remaining frames are sampled with frame_gap spacing
if video_length < num_frames:
# Not enough frames, pad with last frame
sampled_video = video_feature
sampled_progress = progress_targets
padding_length = num_frames - video_length
last_frame = sampled_video[-1].unsqueeze(0)
padding_frames = last_frame.repeat(padding_length, 1)
sampled_video = torch.cat([sampled_video, padding_frames], dim=0)
last_progress = sampled_progress[-1]
padding_progress = torch.full((padding_length,), last_progress)
sampled_progress = torch.cat([sampled_progress, padding_progress])
else:
# Sample frames: first frame + (num_frames-1) with frame_gap
# The indices should represent: [0, gap, 2*gap, 3*gap, ..., (num_frames-1)*gap]
# But we need to ensure we don't exceed video_length
frame_indices = [0] # First frame
for i in range(1, num_frames):
idx = i * frame_gap
if idx >= video_length:
idx = video_length - 1
frame_indices.append(idx)
frame_indices = torch.tensor(frame_indices, dtype=torch.long)
sampled_video = video_feature[frame_indices]
sampled_progress = progress_targets[frame_indices]
return sampled_video, sampled_progress
def sample_sarm_reverse_video_feature(
video_feature: torch.Tensor,
num_frames: int = 9,
frame_gap: int = 30,
random_sample: bool = True,
absolute_indices: torch.Tensor = None,
episode_length: int = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Sample video with reverse augmentation for SARM (rewind augmentation).
Similar to ReWiND's rewind augmentation but adapted for SARM's frame pattern:
1. Take forward sequence (1 initial + 8 consecutive)
2. Append some reversed frames from the end backwards
3. Progress increases then decreases
Args:
video_feature: Video features tensor (num_frames_available, feature_dim)
num_frames: Target number of frames (default: 9)
frame_gap: Gap between consecutive frames (default: 30)
random_sample: If True, use random sampling for reverse section
absolute_indices: Absolute frame indices in the episode
episode_length: Total length of the episode
Returns:
Tuple of:
- Rewound video features (num_frames, feature_dim)
- Progress targets for each frame (num_frames,)
"""
video_length = len(video_feature)
# Generate forward progress targets (relative to sequence, not absolute)
if absolute_indices is not None and episode_length is not None:
# Use same relative progress as normal sampling
first_frame_idx = absolute_indices[0].item() if isinstance(absolute_indices[0], torch.Tensor) else absolute_indices[0]
remaining_length = episode_length - first_frame_idx
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
forward_progress = progress_indices / remaining_length
else:
forward_progress = torch.linspace(1.0/video_length, 1.0, video_length)
# Sample forward sequence first
forward_video, forward_progress_sampled = sample_sarm_video_feature(
video_feature, num_frames, frame_gap, random_sample, absolute_indices, episode_length
)
# Randomly select how many frames to reverse and append
# For SARM, we append 2-4 reversed frames
num_reverse = random.randint(2, min(4, num_frames - 1))
# Reverse the video and progress
reversed_video = video_feature.flip(0)
reversed_progress = forward_progress.flip(0)
# Take frames from reversed (skip the first frame which is the last frame of original)
reverse_frames = reversed_video[1:num_reverse+1]
reverse_progress = reversed_progress[1:num_reverse+1]
# Concatenate forward + reversed (creates rewind effect)
rewound_video = torch.cat([forward_video, reverse_frames], dim=0)
progress_targets = torch.cat([forward_progress_sampled, reverse_progress], dim=0)
# Trim to num_frames if necessary
if len(rewound_video) > num_frames:
# Keep the first num_frames
rewound_video = rewound_video[:num_frames]
progress_targets = progress_targets[:num_frames]
elif len(rewound_video) < num_frames:
# Pad if necessary
padding_length = num_frames - len(rewound_video)
last_frame = rewound_video[-1].unsqueeze(0)
padding_frames = last_frame.repeat(padding_length, 1)
rewound_video = torch.cat([rewound_video, padding_frames], dim=0)
last_progress = progress_targets[-1]
padding_progress = torch.full((padding_length,), last_progress)
progress_targets = torch.cat([progress_targets, padding_progress])
return rewound_video, progress_targets

View File

@@ -132,41 +132,6 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
return LambdaLR(optimizer, lr_lambda, -1)
@LRSchedulerConfig.register_subclass("cosine_with_min_lr")
@dataclass
class CosineWithMinLRSchedulerConfig(LRSchedulerConfig):
"""Cosine learning rate scheduler with minimum learning rate floor.
Used by ReWiND for reward model training. Includes linear warmup phase
followed by cosine annealing with a minimum learning rate.
"""
num_warmup_steps: int
min_lr: float = 0.0
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
def lr_lambda(current_step):
# Get base learning rate from optimizer
base_lr = optimizer.param_groups[0]['lr']
if current_step <= self.num_warmup_steps:
# Linear warmup
if self.num_warmup_steps == 0:
return 1.0
return float(current_step) / float(max(1, self.num_warmup_steps))
else:
# Cosine annealing with minimum learning rate
progress = (current_step - self.num_warmup_steps) / float(
max(1, num_training_steps - self.num_warmup_steps)
)
cosine_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
# Scale between min_lr and base_lr
min_factor = self.min_lr / base_lr if base_lr > 0 else 0.0
return min_factor + (1.0 - min_factor) * cosine_factor
return LambdaLR(optimizer, lr_lambda, -1)
def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
state_dict = scheduler.state_dict()
write_json(state_dict, save_dir / SCHEDULER_STATE)

View File

@@ -34,7 +34,6 @@ from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rewind.configuration_rewind import ReWiNDConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
@@ -105,10 +104,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "rewind":
from lerobot.policies.rewind.modeling_rewind import ReWiNDRewardModel
return ReWiNDRewardModel
elif name == "sarm":
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
@@ -332,15 +327,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, ReWiNDConfig):
from lerobot.policies.rewind.processor_rewind import make_rewind_pre_post_processors
processors = make_rewind_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(policy_cfg, SARMConfig):
from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors

View File

@@ -1,34 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.policies.rewind.configuration_rewind import ReWiNDConfig
from lerobot.policies.rewind.modeling_rewind import (
ReWiNDRewardModel,
ReWiNDTransformer,
)
from lerobot.policies.rewind.processor_rewind import (
ReWiNDEncodingProcessorStep,
make_rewind_pre_post_processors,
)
__all__ = [
"ReWiNDConfig",
"ReWiNDRewardModel",
"ReWiNDTransformer",
"ReWiNDEncodingProcessorStep",
"make_rewind_pre_post_processors",
]

View File

@@ -1,138 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("rewind")
@dataclass
class ReWiNDConfig(PreTrainedConfig):
"""Configuration class for ReWiND Reward Model.
ReWiND (Reward from Video and Natural language Descriptions) is a reward model
that computes task completion/progress rewards from video observations and
language task descriptions.
"""
# Model architecture parameters
video_dim: int = 768 # DINO embedding dimension
text_dim: int = 384 # MiniLM embedding dimension
hidden_dim: int = 512
num_heads: int = 8
num_layers: int = 4
# Temporal parameters
max_length: int = 32 # Maximum video sequence length, ORIGINAL: 16!
subsample_video: bool = True # Whether to pad/subsample videos to max_length
use_temporal_sampler: bool = True # Always enable temporal sequence loading
sequence_stride: int = 1 # Stride between frames when using temporal sampler
rewind_ratio: float = 0.8 # Probability of applying rewind augmentation (original: 0.8)
# Training parameters
batch_size: int = 64
dino_batch_size: int = 64
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
# Model loading
pretrained_model_path: str | None = None
# Device settings
device: str | None = None
# Dropout
dropout: float = 0.1 # Dropout rate for transformer
# Processor settings (for automatic preprocessing)
image_key: str = "observation.images.top" # Key for images in dataset
task_description: str = "perform the task" # Default task description (used if no task field in data)
encode_on_the_fly: bool = True # Encode images/text during training
use_dataset_task: bool = True # Use task descriptions from dataset (per-episode)
# Features (required by PreTrainedPolicy)
input_features: dict = field(default_factory=lambda: {
"video_features": {"shape": [768], "dtype": "float32"},
"text_features": {"shape": [384], "dtype": "float32"}
})
output_features: dict = field(default_factory=lambda: {
"progress": {"shape": [1], "dtype": "float32"}
})
def __post_init__(self):
super().__post_init__()
# Validate configuration
if self.hidden_dim % self.num_heads != 0:
raise ValueError(
f"hidden_dim ({self.hidden_dim}) must be divisible by num_heads ({self.num_heads})"
)
if self.max_length <= 0:
raise ValueError(f"max_length must be positive, got {self.max_length}")
if self.dropout < 0 or self.dropout >= 1:
raise ValueError(f"dropout must be in [0, 1), got {self.dropout}")
def get_optimizer_preset(self) -> AdamWConfig:
"""Get default optimizer configuration for ReWiND training."""
return AdamWConfig(
lr=1e-4,
weight_decay=1e-4,
betas=(0.9, 0.999),
eps=1e-8,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
"""Get default learning rate scheduler configuration."""
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=1e-4,
decay_lr=1e-5,
num_warmup_steps=1000,
num_decay_steps=100000,
)
def validate_features(self) -> None:
pass
@property
def observation_delta_indices(self) -> list[int]:
"""Load all frames from episode start up to current frame.
The sampler yields a random end point in each episode.
This property tells the dataset to load all frames from -(end_idx - start_idx) to 0.
Since we don't know the exact window size in advance, we load up to max_length frames.
The dataset will automatically clamp to episode boundaries.
Returns:
Indices for loading history: [-31, -30, ..., -1, 0] for max_length=32
"""
# Load the last max_length frames (or up to episode start)
return list(range(-(self.max_length - 1), 1))
@property
def action_delta_indices(self) -> None:
"""ReWiND is a reward model, not an action policy."""
return None
@property
def reward_delta_indices(self) -> None:
"""ReWiND doesn't use delta rewards."""
return None

View File

@@ -1,711 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Union, Dict, Optional
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import AutoModel, AutoTokenizer
import torchvision.transforms as T
from torch import Tensor
from lerobot.policies.rewind.configuration_rewind import ReWiNDConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.datasets.video_sampler import sample_video_feature, sample_reverse_video_feature
# Helper functions for encoding
def dino_load_image(img: np.ndarray) -> torch.Tensor:
"""
Load an image and return a tensor that can be used as an input to DINOv2.
Args:
img: Input image as numpy array (H, W, C) in uint8 format.
Returns:
Transformed image tensor ready for DINO encoder (1, 3, 224, 224).
"""
# Define transform: center crop to 224x224, normalize to [-1, 1]
dino_transform = T.Compose([
T.ToTensor(),
T.CenterCrop(224),
T.Normalize([0.5], [0.5])
])
img_pil = Image.fromarray(img)
transformed_img = dino_transform(img_pil)[:3].unsqueeze(0)
return transformed_img
def mean_pooling(model_output, attention_mask):
"""
Mean pooling - take attention mask into account for correct averaging.
Args:
model_output: Model output containing token embeddings.
attention_mask: Attention mask for the tokens.
Returns:
Mean-pooled embeddings.
"""
token_embeddings = model_output[0] # First element contains all token embeddings
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
input_mask_expanded.sum(1), min=1e-9
)
class ReWiNDTransformer(nn.Module):
"""
ReWiND Transformer model for predicting task progress from video and text.
This model takes video frame embeddings and text embeddings as input,
and predicts a progress score (0-1) for each frame indicating how much
of the task has been completed.
"""
def __init__(
self,
video_dim: int = 768,
text_dim: int = 384,
hidden_dim: int = 512,
num_heads: int = 8,
num_layers: int = 4,
max_length: int = 32,
dropout: float = 0.1
):
super().__init__()
self.hidden_dim = hidden_dim
self.max_length = max_length
# Project video and text to common dimension
self.video_proj = nn.Linear(video_dim, hidden_dim)
self.text_proj = nn.Linear(text_dim, hidden_dim)
# Position embeddings for video sequence
# We only add positional embedding to the first frame as in the original
self.first_pos_embed = nn.Parameter(torch.randn(1, hidden_dim))
# Transformer encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_dim,
nhead=num_heads,
dim_feedforward=hidden_dim * 4,
dropout=dropout,
batch_first=True
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
# Progress prediction head (applied to each frame)
self.progress_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.LayerNorm(hidden_dim // 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, 1),
nn.Sigmoid()
)
# Attention mask for causal self-attention
# Will be created on-demand based on sequence length
self.register_buffer("attention_mask", None, persistent=False)
def _get_attention_mask(self, seq_length: int, device: torch.device) -> torch.Tensor:
"""Generate or retrieve cached causal attention mask."""
if self.attention_mask is None or self.attention_mask.shape[0] != seq_length:
# Create causal mask (upper triangular with -inf)
mask = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device)
self.attention_mask = mask
return self.attention_mask
def forward(self, video_frames: torch.Tensor, text_embed: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the ReWiND transformer.
Args:
video_frames: Video frame embeddings (batch_size, seq_len, video_dim)
text_embed: Text embeddings (batch_size, text_dim)
Returns:
Progress predictions for each frame (batch_size, seq_len, 1)
"""
batch_size = video_frames.shape[0]
# Project inputs to common dimension
video_embed = self.video_proj(video_frames) # [batch_size, seq_len, hidden_dim]
text_embed = self.text_proj(text_embed).unsqueeze(1) # [batch_size, 1, hidden_dim]
# Add positional embedding to first video frame
video_embed[:, 0] += self.first_pos_embed
# Combine sequence: [text, video_frames]
sequence = torch.cat([text_embed, video_embed], dim=1)
# Get causal attention mask
seq_length = sequence.shape[1]
attention_mask = self._get_attention_mask(seq_length, sequence.device)
# Pass through transformer with causal masking
transformed = self.transformer(sequence, mask=attention_mask, is_causal=True)
# Get progress predictions for each frame (exclude text token)
progress_preds = self.progress_head(transformed[:, 1:])
return progress_preds
class ReWiNDRewardModel(PreTrainedPolicy):
"""
ReWiND Reward Model for computing task completion rewards from video and text.
This model combines:
- DINO (DINOv2) for encoding video frames
- MiniLM for encoding text descriptions
- ReWiNDTransformer for predicting task progress
"""
name = "rewind"
config_class = ReWiNDConfig
def __init__(self, config: ReWiNDConfig, dataset_stats: dict | None = None):
super().__init__(config, dataset_stats)
self.config = config
self.dataset_stats = dataset_stats
self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu")
# Initialize DINO encoder for images
logging.info("Loading DINO encoder...")
self.dino_encoder = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
self.dino_encoder.to(self.device)
self.dino_encoder.eval()
# Initialize MiniLM encoder for text
logging.info("Loading MiniLM encoder...")
self.minilm_tokenizer = AutoTokenizer.from_pretrained(
"sentence-transformers/all-MiniLM-L12-v2"
)
self.minilm_model = AutoModel.from_pretrained(
"sentence-transformers/all-MiniLM-L12-v2"
)
self.minilm_model.to(self.device)
self.minilm_model.eval()
# Initialize ReWiND transformer with explicit architecture parameters
self.rewind_transformer = ReWiNDTransformer(
video_dim=config.video_dim,
text_dim=config.text_dim,
hidden_dim=config.hidden_dim,
num_heads=config.num_heads,
num_layers=config.num_layers,
max_length=config.max_length,
dropout=config.dropout
)
self.rewind_transformer.to(self.device)
logging.info(f"ReWiND Reward Model initialized on {self.device}")
def to(self, device):
"""Override to method to ensure all components move together."""
super().to(device)
self.device = device if isinstance(device, torch.device) else torch.device(device)
self.dino_encoder.to(device)
self.minilm_model.to(device)
self.rewind_transformer.to(device)
return self
@torch.no_grad()
def encode_images(self, images: np.ndarray) -> np.ndarray:
"""
Encode video frames using DINO.
Args:
images: Video frames with shape (num_videos, num_frames, H, W, C) in uint8.
Can also be (num_frames, H, W, C) for a single video.
Returns:
Encoded image features (num_videos, num_frames, 768) or (num_frames, 768).
"""
# Handle single video case
single_video = False
if len(images.shape) == 4:
images = images[np.newaxis, ...]
single_video = True
assert len(images.shape) == 5, f"Expected 5D input (num_videos, num_frames, H, W, C), got {images.shape}"
# Ensure channels are in correct position
if images.shape[-1] == 3 and images.shape[2] != 3:
images = np.transpose(images, (0, 1, 4, 2, 3))
all_embeddings = []
for video in images:
# Process each video
video_embeddings = []
# Convert frames to list of numpy arrays
frames = [frame.transpose(1, 2, 0).astype(np.uint8) if frame.shape[0] == 3 else frame for frame in video]
# Batch process frames with DINO
episode_images_dino = [dino_load_image(frame) for frame in frames]
# Process in batches
for i in range(0, len(episode_images_dino), self.config.dino_batch_size):
batch = torch.cat(episode_images_dino[i:i + self.config.dino_batch_size])
batch = batch.to(self.device)
embeddings = self.dino_encoder(batch).squeeze().detach().cpu()
# Handle single frame case
if embeddings.dim() == 1:
embeddings = embeddings.unsqueeze(0)
video_embeddings.append(embeddings)
video_embeddings = torch.cat(video_embeddings)
all_embeddings.append(video_embeddings)
result = torch.stack(all_embeddings).numpy()
if single_video:
result = result[0]
return result
@torch.no_grad()
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
"""
Encode text using MiniLM.
Args:
text: Text string or list of text strings.
Returns:
Encoded text features (batch_size, 384) or (384,) for single text.
"""
if isinstance(text, str):
text = [text]
single_text = True
else:
single_text = False
# Process in batches
all_embeddings = []
for i in range(0, len(text), self.config.batch_size):
batch_text = text[i:i + self.config.batch_size]
encoded_input = self.minilm_tokenizer(
batch_text, padding=True, truncation=True, return_tensors="pt"
).to(self.device)
model_output = self.minilm_model(**encoded_input)
text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
all_embeddings.append(text_embeddings.cpu())
result = torch.cat(all_embeddings).numpy()
if single_text:
result = result[0]
return result
def padding_video(self, video_frames: torch.Tensor, max_length: int) -> torch.Tensor:
"""
Pad or subsample video frames to a fixed length.
Args:
video_frames: Video frames tensor (num_frames, embedding_dim)
max_length: Target sequence length
Returns:
Padded/subsampled video frames (max_length, embedding_dim)
"""
video_length = len(video_frames)
if isinstance(video_frames, np.ndarray):
video_frames = torch.tensor(video_frames)
if video_length < max_length:
# Pad with last frame
padding_length = max_length - video_length
last_frame = video_frames[-1].unsqueeze(0)
padding_frames = last_frame.repeat(padding_length, 1)
video_frames = torch.cat([video_frames, padding_frames], dim=0)
elif video_length > max_length:
# Subsample uniformly
frame_idx = np.linspace(0, video_length - 1, max_length).astype(int)
video_frames = video_frames[frame_idx]
return video_frames
@torch.no_grad()
def calculate_rewards(
self,
text_embeddings: Union[np.ndarray, torch.Tensor],
video_embeddings: Union[np.ndarray, torch.Tensor],
return_all_frames: bool = False
) -> np.ndarray:
"""
Calculate rewards for given text and video representations.
Args:
text_embeddings: Encoded text representations (batch_size, 384)
video_embeddings: Encoded video representations (batch_size, num_frames, 768)
return_all_frames: If True, return rewards for all frames. If False, return only last frame.
Returns:
Reward values (batch_size,) or (batch_size, num_frames) if return_all_frames=True
"""
# Convert to tensors if needed
if isinstance(text_embeddings, np.ndarray):
text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
if isinstance(video_embeddings, np.ndarray):
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
# Handle single sample case
if text_embeddings.dim() == 1:
text_embeddings = text_embeddings.unsqueeze(0)
video_embeddings = video_embeddings.unsqueeze(0)
single_sample = True
else:
single_sample = False
# Process in batches
all_rewards = []
for i in range(0, len(video_embeddings), self.config.batch_size):
batch_texts = text_embeddings[i:i + self.config.batch_size].to(self.device)
batch_videos = video_embeddings[i:i + self.config.batch_size].to(self.device)
# Pad/subsample videos if needed
if self.config.subsample_video:
padded_videos = []
for video in batch_videos:
padded_video = self.padding_video(video, self.config.max_length)
padded_videos.append(padded_video)
batch_videos = torch.stack(padded_videos).to(self.device)
# Get progress predictions
rewards = self.rewind_transformer(batch_videos.float(), batch_texts.float())
if return_all_frames:
all_rewards.append(rewards.squeeze(-1).cpu())
else:
# Return only last frame reward
all_rewards.append(rewards[:, -1, 0].cpu())
result = torch.cat(all_rewards).numpy()
if single_sample:
result = result[0] if not return_all_frames else result[0]
return result
def load_pretrained_checkpoint(self, checkpoint_path: str, strict: bool = False):
"""
Load pretrained model weights from a checkpoint file.
Args:
checkpoint_path: Path to the .pth checkpoint file
strict: Whether to strictly enforce that the keys match
"""
logging.info(f"Loading pretrained checkpoint from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
# Handle different checkpoint formats
if "model_state_dict" in checkpoint:
state_dict = checkpoint["model_state_dict"]
# Check for architecture parameters in checkpoint
if "args" in checkpoint:
args = checkpoint["args"]
logging.info(f"Checkpoint was trained with: max_length={args.max_length}")
# Warn if max_length differs
if hasattr(args, 'max_length') and args.max_length != self.config.max_length:
logging.warning(
f"Checkpoint max_length ({args.max_length}) differs from config ({self.config.max_length}). "
"This may cause issues if sequence lengths don't match."
)
else:
state_dict = checkpoint
# Load only the ReWiNDTransformer weights
missing_keys, unexpected_keys = self.rewind_transformer.load_state_dict(state_dict, strict=strict)
if missing_keys:
logging.warning(f"Missing keys when loading checkpoint: {missing_keys}")
if unexpected_keys:
logging.warning(f"Unexpected keys when loading checkpoint: {unexpected_keys}")
logging.info("Checkpoint loaded successfully")
def train(self, mode: bool = True):
"""Set training mode. Note: DINO and MiniLM encoders always stay in eval mode."""
super().train(mode)
# Keep encoders in eval mode
self.dino_encoder.eval()
self.minilm_model.eval()
# Only transformer can be trained
self.rewind_transformer.train(mode)
return self
def eval(self):
"""Set evaluation mode."""
return self.train(False)
def parameters(self):
"""Return trainable parameters (only ReWiND transformer, not encoders)."""
return self.rewind_transformer.parameters()
def get_optim_params(self):
"""Return optimizer parameters for the policy."""
return self.parameters()
def reset(self):
"""
This method is required by PreTrainedPolicy but not used for reward models.
The reward model does not maintain state between episodes.
"""
pass
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""
This method is required by PreTrainedPolicy but not used for reward models.
The rewind model is not an actor and does not produce action chunks.
"""
raise NotImplementedError("Rewind model does not predict action chunks")
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""
This method is required by PreTrainedPolicy but not used for rewind.
The rewind model is not an actor and does not select actions.
"""
raise NotImplementedError("Rewind model does not select actions")
def forward(self, batch):
"""
Forward pass compatible with lerobot training pipeline.
Args:
batch: Dictionary containing observation with:
- 'video_features': Pre-encoded video features (B, 768) or (B, T, 768)
- 'text_features': Pre-encoded text features (B, 384)
Returns:
loss: Total training loss
output_dict: Dictionary of loss components for logging
"""
# Extract from observation dict
observation = batch.get('observation', batch)
video_features = observation['video_features'].to(self.device)
text_features = observation['text_features'].to(self.device)
batch_size = video_features.shape[0]
max_length = self.config.max_length
# Handle both single frames (B, 768) and sequences (B, T, 768)
if video_features.dim() == 2:
# Single frames: replicate to create pseudo-sequences
video_features = video_features.unsqueeze(1).repeat(1, max_length, 1) # (B, max_length, 768)
# Now video_features is (B, T, 768) where T might be > max_length
# Process videos (with potential rewind augmentation)
import random
from lerobot.datasets.video_sampler import sample_video_feature, sample_reverse_video_feature
processed_videos = []
progress_targets = []
# Extract episode metadata for correct progress normalization
absolute_frame_indices = observation.get('absolute_frame_indices', None)
episode_lengths = observation.get('episode_length', None)
remaining_lengths = observation.get('remaining_length', None)
for i in range(batch_size):
# Get metadata for this sample
current_absolute_indices = None
current_episode_length = None
current_remaining_length = None
if absolute_frame_indices is not None:
if isinstance(absolute_frame_indices, list):
current_absolute_indices = absolute_frame_indices[i]
else:
current_absolute_indices = absolute_frame_indices
if episode_lengths is not None:
if isinstance(episode_lengths, torch.Tensor) and episode_lengths.dim() > 0:
current_episode_length = episode_lengths[i].item()
else:
current_episode_length = episode_lengths.item() if isinstance(episode_lengths, torch.Tensor) else episode_lengths
if remaining_lengths is not None:
if isinstance(remaining_lengths, torch.Tensor) and remaining_lengths.dim() > 0:
current_remaining_length = remaining_lengths[i].item()
else:
current_remaining_length = remaining_lengths.item() if isinstance(remaining_lengths, torch.Tensor) else remaining_lengths
if random.random() < self.config.rewind_ratio: # Use configurable rewind ratio
# Apply video rewind augmentation (now returns tuple)
rewound_video, progress = sample_reverse_video_feature(
video_features[i],
max_length=max_length,
random_sample=True, # Use random sampling (original ReWiND)
remaining_length=current_remaining_length
)
processed_videos.append(rewound_video.to(self.device))
progress_targets.append(progress.to(self.device))
else:
# Normal video sampling (now returns tuple with progress targets)
sampled_video, progress = sample_video_feature(
video_features[i],
max_length=max_length,
random_sample=True, # Use random sampling (original ReWiND)
remaining_length=current_remaining_length
)
processed_videos.append(sampled_video.to(self.device))
progress_targets.append(progress.to(self.device))
processed_videos = torch.stack(processed_videos)
progress_targets = torch.stack(progress_targets)
# Compute progress loss
progress_loss = compute_progress_loss(
self.rewind_transformer,
processed_videos,
text_features,
progress_targets
)
total_loss = progress_loss
output_dict = {'progress_loss': progress_loss.item()}
# Compute misaligned loss if requested (20% probability to match original)
if random.random() < 0.2: # 20% chance of adding misalignment loss (original ReWiND uses 20%)
if 'misaligned_video_features' in batch and 'misaligned_text_features' in batch:
misaligned_videos = batch['misaligned_video_features'].to(self.device)
misaligned_texts = batch['misaligned_text_features'].to(self.device)
else:
# Create misaligned pairs by shuffling
shuffle_idx = torch.randperm(batch_size)
misaligned_videos = processed_videos[shuffle_idx]
misaligned_texts = text_features
# Sample misaligned videos (function now returns tuple)
# For misaligned pairs, we don't need correct progress targets (will be set to 0)
misaligned_videos_sampled = []
for i in range(batch_size):
# For misaligned videos, use video length as remaining_length
video_len = len(misaligned_videos[i])
sampled, _ = sample_video_feature(
misaligned_videos[i],
max_length=max_length,
random_sample=True,
remaining_length=video_len # Use video length for misaligned pairs
)
misaligned_videos_sampled.append(sampled.to(self.device))
misaligned_videos_sampled = torch.stack(misaligned_videos_sampled)
misaligned_loss = compute_misaligned_loss(
self.rewind_transformer,
misaligned_videos_sampled,
misaligned_texts
)
total_loss = total_loss + misaligned_loss
output_dict['misaligned_loss'] = misaligned_loss.item()
output_dict['total_loss'] = total_loss.item()
return total_loss, output_dict
# Loss utilities
def compute_progress_loss(
model: ReWiNDTransformer,
video_features: torch.Tensor,
text_features: torch.Tensor,
target_progress: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Compute progress prediction loss.
Args:
model: ReWiNDTransformer model
video_features: Batch of video features (batch_size, max_length, feature_dim)
text_features: Batch of text features (batch_size, text_dim)
target_progress: Optional target progress values (batch_size, max_length).
If None, uses linear progress from 0 to 1.
Returns:
Mean squared error loss
"""
# Get predictions
progress_preds = model(video_features, text_features)
# Create target progress if not provided
if target_progress is None:
batch_size, max_length = video_features.shape[:2]
target_progress = torch.linspace(0, 1, max_length, device=video_features.device)
target_progress = target_progress.unsqueeze(0).repeat(batch_size, 1)
# Ensure target has correct shape
if target_progress.dim() == 2:
target_progress = target_progress.unsqueeze(-1)
# Compute MSE loss
loss = F.mse_loss(progress_preds, target_progress)
return loss
def compute_misaligned_loss(
model: ReWiNDTransformer,
video_features: torch.Tensor,
misaligned_text_features: torch.Tensor
) -> torch.Tensor:
"""
Compute loss for misaligned video-text pairs (should predict 0 progress).
Args:
model: ReWiNDTransformer model
video_features: Batch of video features (batch_size, max_length, feature_dim)
misaligned_text_features: Batch of misaligned text features (batch_size, text_dim)
Returns:
Mean squared error loss (predictions should be close to 0)
"""
# Get predictions
progress_preds = model(video_features, misaligned_text_features)
# Target is all zeros
target_zeros = torch.zeros_like(progress_preds)
# Compute MSE loss
loss = F.mse_loss(progress_preds, target_zeros)
return loss

View File

@@ -1,405 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, Any, List, Optional
import numpy as np
import torch
from lerobot.policies.rewind.configuration_rewind import ReWiNDConfig
from lerobot.processor import (
ProcessorStep,
PolicyProcessorPipeline,
PolicyAction,
DeviceProcessorStep,
AddBatchDimensionProcessorStep,
)
from lerobot.processor.converters import (
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.processor.pipeline import PipelineFeatureType
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.configs.types import PolicyFeature, FeatureType
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
class ReWiNDEncodingProcessorStep(ProcessorStep):
"""
ProcessorStep that encodes images and text for ReWiND training.
This step handles the DINO (image) and MiniLM (text) encoding that ReWiND needs.
Supports both single-frame and temporal sequence encoding:
- Single frame: (B, C, H, W) → (B, 768) video features
- Temporal sequence: (B, T, C, H, W) → (B, T, 768) video features
To use temporal sequences, configure the dataset with delta_timestamps for your image key.
For example, to encode sequences of 32 frames:
delta_timestamps = {
"observation.images.top": [i / fps for i in range(-15, 17)] # 32 frames centered on current
}
"""
def __init__(
self,
config: ReWiNDConfig,
image_key: str | None = None,
task_description: str | None = None,
dataset_meta = None,
):
super().__init__()
self.config = config
self.image_key = image_key or config.image_key
self.task_description = task_description or config.task_description
self.dataset_meta = dataset_meta # Store dataset metadata for episode info
# Initialize encoders
self._init_encoders()
def _init_encoders(self):
"""Initialize DINO and MiniLM encoders."""
from transformers import AutoModel, AutoTokenizer
device = torch.device(
self.config.device if self.config.device
else "cuda" if torch.cuda.is_available() else "cpu"
)
logging.info("Initializing DINO encoder for ReWiND...")
self.dino_encoder = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
self.dino_encoder.to(device)
self.dino_encoder.eval()
logging.info("Initializing MiniLM encoder for ReWiND...")
self.minilm_tokenizer = AutoTokenizer.from_pretrained(
"sentence-transformers/all-MiniLM-L12-v2"
)
self.minilm_model = AutoModel.from_pretrained(
"sentence-transformers/all-MiniLM-L12-v2"
)
self.minilm_model.to(device)
self.minilm_model.eval()
self.device = device
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Encode images and text in the transition."""
self._current_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition)
new_transition = self._current_transition
observation = new_transition.get(TransitionKey.OBSERVATION)
if observation is None or not isinstance(observation, dict):
# If no observation, just return the transition as-is
return new_transition
# Extract images from observation and encode
# For ReWiND, we need to load the sequence from episode start to current frame
batch_size = 1
if self.image_key in observation:
image = observation[self.image_key]
# Handle different image formats
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
# Check if we have temporal sequences or single frames
# Temporal sampling: Load from episode start to current frame
# This will be handled by the dataset if configured with delta_timestamps
# Otherwise, we just encode the single frame
video_features = self._encode_images_batch(image)
observation['video_features'] = video_features
# Get batch size from the encoded features
batch_size = video_features.shape[0]
# Get task descriptions - check if 'task' field exists in the transition
# This allows per-episode task descriptions (e.g., for datasets with multiple tasks)
task_descriptions = None
if 'task' in new_transition:
task_descriptions = new_transition['task']
# Convert to list if it's a single string
if isinstance(task_descriptions, str):
task_descriptions = [task_descriptions] * batch_size
# Encode text
if task_descriptions is not None:
# Encode per-sample task descriptions
text_features = self._encode_text_batch_list(task_descriptions)
else:
# Fall back to config task description if no task field in transition
text_features = self._encode_text_batch(self.task_description, batch_size)
observation['text_features'] = text_features
# Compute episode metadata for progress normalization (used by ReWiND)
# We need to pass absolute frame indices and total episode length for correct progress calculation
if self.dataset_meta is not None and 'episode_index' in new_transition and 'index' in new_transition:
episode_indices = new_transition['episode_index']
frame_indices = new_transition['index']
# Handle both single samples and batches
if isinstance(episode_indices, (int, np.integer)):
ep_idx = int(episode_indices)
frame_idx = int(frame_indices)
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
episode_length = ep_end - ep_start
# For temporal sequences with observation_delta_indices:
# If we loaded frames using delta_indices (e.g., [-31, -30, ..., 0]),
# we need to compute the absolute indices of those frames
# The current frame is at frame_idx, and we loaded max_length frames before it
if 'video_features' in observation and len(observation['video_features'].shape) > 1:
# We have a temporal sequence
num_loaded_frames = observation['video_features'].shape[0] if observation['video_features'].dim() == 2 else observation['video_features'].shape[1]
# Absolute indices: from (frame_idx - num_frames + 1) to frame_idx
start_idx = max(ep_start, frame_idx - num_loaded_frames + 1)
absolute_indices = torch.arange(start_idx, frame_idx + 1)
observation['absolute_frame_indices'] = absolute_indices
# Compute remaining length: from first loaded frame to episode end
observation['remaining_length'] = ep_end - start_idx
else:
# Single frame
observation['absolute_frame_indices'] = torch.tensor([frame_idx])
# Remaining length from this frame to episode end
observation['remaining_length'] = ep_end - frame_idx
observation['episode_length'] = episode_length
else:
# Batch case
absolute_indices_list = []
episode_lengths = []
remaining_lengths = []
for ep_idx, frame_idx in zip(episode_indices, frame_indices):
ep_idx = int(ep_idx.item() if hasattr(ep_idx, 'item') else ep_idx)
frame_idx = int(frame_idx.item() if hasattr(frame_idx, 'item') else frame_idx)
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
episode_length = ep_end - ep_start
episode_lengths.append(episode_length)
# Compute absolute indices for this sample
if 'video_features' in observation and len(observation['video_features'].shape) > 1:
num_loaded_frames = observation['video_features'].shape[1]
start_idx = max(ep_start, frame_idx - num_loaded_frames + 1)
absolute_indices = torch.arange(start_idx, frame_idx + 1)
absolute_indices_list.append(absolute_indices)
# Remaining length from first loaded frame to episode end
remaining_lengths.append(ep_end - start_idx)
else:
absolute_indices_list.append(torch.tensor([frame_idx]))
# Remaining length from this frame to episode end
remaining_lengths.append(ep_end - frame_idx)
observation['absolute_frame_indices'] = absolute_indices_list
observation['episode_length'] = torch.tensor(episode_lengths)
observation['remaining_length'] = torch.tensor(remaining_lengths)
new_transition[TransitionKey.OBSERVATION] = observation
return new_transition
@torch.no_grad()
def _encode_images_batch(self, images: np.ndarray) -> torch.Tensor:
"""Encode a batch of images (with optional temporal dimension) using DINO.
Args:
images: Batched images with shape:
- (B, C, H, W) for single frames, or
- (B, T, C, H, W) for temporal sequences
Returns:
Encoded feature vectors with shape (B, 768) or (B, T, 768)
"""
from lerobot.policies.rewind.modeling_rewind import dino_load_image
# Check if we have temporal dimension
has_temporal = len(images.shape) == 5
if has_temporal:
# Shape: (B, T, C, H, W)
batch_size, seq_length = images.shape[0], images.shape[1]
# Reshape to (B*T, C, H, W) to process all frames at once
images = images.reshape(batch_size * seq_length, *images.shape[2:])
elif len(images.shape) == 4:
# Shape: (B, C, H, W)
batch_size = images.shape[0]
seq_length = 1
else:
raise ValueError(f"Expected 4D (B, C, H, W) or 5D (B, T, C, H, W) input, got shape {images.shape}")
# Convert to list of (H, W, C) images
num_frames = images.shape[0]
if images.shape[1] in [1, 3]: # Channel first (N, C, H, W)
images_list = [images[i].transpose(1, 2, 0) for i in range(num_frames)]
else: # Channel last (N, H, W, C)
images_list = [images[i] for i in range(num_frames)]
# Encode each frame (can batch process with DINO for efficiency)
all_embeddings = []
for i in range(0, num_frames, self.config.dino_batch_size):
batch_imgs = images_list[i:i + self.config.dino_batch_size]
# Prepare images for DINO
dino_inputs = []
for img in batch_imgs:
# Handle single channel
if img.shape[-1] == 1:
img = np.repeat(img, 3, axis=-1)
# Convert to uint8
if img.dtype != np.uint8:
img = (img * 255).astype(np.uint8) if img.max() <= 1.0 else img.astype(np.uint8)
dino_inputs.append(dino_load_image(img))
# Batch encode
dino_batch = torch.cat(dino_inputs).to(self.device)
embeddings = self.dino_encoder(dino_batch).detach().cpu()
# Handle single frame case
if embeddings.dim() == 1:
embeddings = embeddings.unsqueeze(0)
all_embeddings.append(embeddings)
# Concatenate all embeddings
all_embeddings = torch.cat(all_embeddings) # (B*T, 768)
# Reshape back if temporal
if has_temporal:
all_embeddings = all_embeddings.reshape(batch_size, seq_length, -1) # (B, T, 768)
return all_embeddings
@torch.no_grad()
def _encode_text_batch(self, text: str, batch_size: int) -> torch.Tensor:
"""Encode a text string using MiniLM and replicate for batch.
Args:
text: Text string to encode
batch_size: Batch size to replicate for
Returns:
Encoded feature vectors with shape (B, 384)
"""
from lerobot.policies.rewind.modeling_rewind import mean_pooling
encoded_input = self.minilm_tokenizer(
text, padding=True, truncation=True, return_tensors="pt"
).to(self.device)
model_output = self.minilm_model(**encoded_input)
text_embedding = mean_pooling(model_output, encoded_input["attention_mask"])
text_embedding = text_embedding.squeeze().cpu()
# Replicate for batch (B, 384)
text_embedding = text_embedding.unsqueeze(0).repeat(batch_size, 1)
return text_embedding
@torch.no_grad()
def _encode_text_batch_list(self, text_list: list[str]) -> torch.Tensor:
"""Encode a list of text strings using MiniLM (one per sample).
Args:
text_list: List of text strings to encode
Returns:
Encoded feature vectors with shape (B, 384)
"""
from lerobot.policies.rewind.modeling_rewind import mean_pooling
# Encode all texts in the batch at once
encoded_input = self.minilm_tokenizer(
text_list, padding=True, truncation=True, return_tensors="pt"
).to(self.device)
model_output = self.minilm_model(**encoded_input)
text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
text_embeddings = text_embeddings.cpu()
return text_embeddings
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Adds video_features and text_features to the observation features.
"""
# Add the encoded features
features[PipelineFeatureType.OBSERVATION]['video_features'] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(768,) # DINO embedding dimension
)
features[PipelineFeatureType.OBSERVATION]['text_features'] = PolicyFeature(
type=FeatureType.LANGUAGE,
shape=(384,) # MiniLM embedding dimension
)
return features
def make_rewind_pre_post_processors(
config: ReWiNDConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
dataset_meta = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Create pre-processor and post-processor pipelines for ReWiND.
The pre-processing pipeline:
1. Encodes images with DINO (768-dim)
2. Encodes text with MiniLM (384-dim)
3. Computes remaining episode length for progress normalization
4. Adds batch dimension
5. Moves data to device
The post-processing pipeline moves data back to CPU.
Args:
config: ReWiND configuration
dataset_stats: Dataset statistics (not used for ReWiND)
dataset_meta: Dataset metadata for computing episode remaining length
Returns:
Tuple of (preprocessor, postprocessor) pipelines
"""
input_steps = [
AddBatchDimensionProcessorStep(),
ReWiNDEncodingProcessorStep(config=config, dataset_meta=dataset_meta),
DeviceProcessorStep(device=config.device),
]
output_steps = [
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)

View File

@@ -19,7 +19,6 @@ from lerobot.policies.sarm.modeling_sarm import (
SARMRewardModel,
SARMTransformer,
compute_stage_loss,
compute_progress_loss,
)
from lerobot.policies.sarm.processor_sarm import (
SARMEncodingProcessorStep,
@@ -31,7 +30,6 @@ __all__ = [
"SARMRewardModel",
"SARMTransformer",
"compute_stage_loss",
"compute_progress_loss",
"SARMEncodingProcessorStep",
"make_sarm_pre_post_processors",
]

View File

@@ -32,8 +32,8 @@ class SARMConfig(PreTrainedConfig):
num_frames: int = 9 # 1 initial + 8 consecutive frames
frame_gap: int = 30 # Frame gap between consecutive frames (at 30 fps = 1 second)
# Text encoding parameters
text_dim: int = 384
# Text encoding parameters (CLIP text encoder output dimension)
text_dim: int = 512
# Joint state parameters
state_dim: int | None = None # Auto-detected from dataset if None
@@ -49,7 +49,6 @@ class SARMConfig(PreTrainedConfig):
# Temporal parameters
max_length: int = num_frames # Maximum video sequence length (matches num_frames)
use_temporal_sampler: bool = True # Always enable temporal sequence loading
sampling_mode: str = "sarm" # Sampling mode: "sarm" or "rewind"
# Training parameters
batch_size: int = 64
@@ -101,11 +100,6 @@ class SARMConfig(PreTrainedConfig):
if self.num_stages < 2:
raise ValueError(f"num_stages must be at least 2, got {self.num_stages}")
if self.sampling_mode not in ["sarm", "rewind", "custom"]:
raise ValueError(
f"sampling_mode must be 'sarm' or 'rewind', got {self.sampling_mode}"
)
def get_optimizer_preset(self) -> AdamWConfig:
"""Get default optimizer configuration for SARM training."""

View File

@@ -24,33 +24,13 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import AutoModel, AutoTokenizer, CLIPModel, CLIPProcessor
from transformers import CLIPModel, CLIPProcessor
from torch import Tensor
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.pretrained import PreTrainedPolicy
def mean_pooling(model_output, attention_mask):
"""
Mean pooling - take attention mask into account for correct averaging.
Args:
model_output: Model output containing token embeddings.
attention_mask: Attention mask for the tokens.
Returns:
Mean-pooled embeddings.
"""
token_embeddings = model_output[0] # First element contains all token embeddings
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
input_mask_expanded.sum(1), min=1e-9
)
class SARMTransformer(nn.Module):
"""
SARM Transformer model for stage-aware reward prediction.
@@ -65,7 +45,7 @@ class SARMTransformer(nn.Module):
def __init__(
self,
video_dim: int = 512,
text_dim: int = 384,
text_dim: int = 512, # CLIP text encoder output dimension (per SARM paper A.4)
state_dim: int = 14,
hidden_dim: int = 768,
num_heads: int = 12,
@@ -204,7 +184,7 @@ class SARMTransformer(nn.Module):
stage_indices = torch.argmax(stage_probs, dim=-1) # [batch_size, seq_len]
# Get stage embeddings for conditioning
stage_embeds = self.stage_embedding(stage_indices) # [batch_size, seq_len, hidden_dim//4]
stage_embeds = self.stage_embedding(stage_indices)
# Concatenate frame features with stage embeddings
conditioned_features = torch.cat([frame_features, stage_embeds], dim=-1)
@@ -229,9 +209,11 @@ class SARMRewardModel(PreTrainedPolicy):
"""
SARM Reward Model for stage-aware task completion rewards.
Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
to process both RGB image sequences and task descriptions."
This model combines:
- CLIP for encoding video frames
- MiniLM for encoding text descriptions
- CLIP for encoding video frames AND text descriptions
- SARMTransformer for predicting task stage and progress
- Optional RA-BC (Reward-Aligned Behavior Cloning) for weighted training
"""
@@ -249,24 +231,13 @@ class SARMRewardModel(PreTrainedPolicy):
if dataset_meta is not None:
self._update_num_stages_from_dataset(dataset_meta)
# Initialize CLIP encoder for images
logging.info("Loading CLIP encoder...")
# Initialize CLIP encoder for images AND text (per SARM paper A.4)
logging.info("Loading CLIP encoder for images and text...")
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
self.clip_model.to(self.device)
self.clip_model.eval()
# Initialize MiniLM encoder for text
logging.info("Loading MiniLM encoder...")
self.minilm_tokenizer = AutoTokenizer.from_pretrained(
"sentence-transformers/all-MiniLM-L12-v2"
)
self.minilm_model = AutoModel.from_pretrained(
"sentence-transformers/all-MiniLM-L12-v2"
)
self.minilm_model.to(self.device)
self.minilm_model.eval()
# Auto-detect state_dim from dataset_stats
if config.state_dim is None:
logging.info(f"Attempting to auto-detect state_dim. dataset_stats is None: {dataset_stats is None}")
@@ -379,7 +350,6 @@ class SARMRewardModel(PreTrainedPolicy):
super().to(device)
self.device = device if isinstance(device, torch.device) else torch.device(device)
self.clip_model.to(device)
self.minilm_model.to(device)
self.sarm_transformer.to(device)
return self
@@ -445,13 +415,13 @@ class SARMRewardModel(PreTrainedPolicy):
@torch.no_grad()
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
"""
Encode text using MiniLM.
Encode text using CLIP text encoder (per SARM paper A.4).
Args:
text: Text string or list of text strings.
Returns:
Encoded text features (batch_size, 384) or (384,) for single text.
Encoded text features (batch_size, 512) or (512,) for single text.
"""
if isinstance(text, str):
text = [text]
@@ -459,18 +429,18 @@ class SARMRewardModel(PreTrainedPolicy):
else:
single_text = False
# Use CLIP's tokenizer directly (avoids image processor validation issues)
tokenizer = self.clip_processor.tokenizer
# Process in batches
all_embeddings = []
for i in range(0, len(text), self.config.batch_size):
batch_text = text[i:i + self.config.batch_size]
encoded_input = self.minilm_tokenizer(
batch_text, padding=True, truncation=True, return_tensors="pt"
).to(self.device)
model_output = self.minilm_model(**encoded_input)
text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
inputs = tokenizer(batch_text, return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
text_embeddings = self.clip_model.get_text_features(**inputs)
all_embeddings.append(text_embeddings.cpu())
result = torch.cat(all_embeddings).numpy()
@@ -493,7 +463,7 @@ class SARMRewardModel(PreTrainedPolicy):
Calculate rewards for given text, video, and state representations.
Args:
text_embeddings: Encoded text representations (batch_size, 384)
text_embeddings: Encoded text representations (batch_size, 512)
video_embeddings: Encoded video representations (batch_size, num_frames, 512)
state_features: Joint state features (batch_size, num_frames, state_dim)
return_all_frames: If True, return rewards for all frames
@@ -585,11 +555,10 @@ class SARMRewardModel(PreTrainedPolicy):
logging.info("Checkpoint loaded successfully")
def train(self, mode: bool = True):
"""Set training mode. Note: CLIP and MiniLM encoders always stay in eval mode."""
"""Set training mode. Note: CLIP encoder always stays in eval mode (frozen)."""
super().train(mode)
# Keep encoders in eval mode
# Keep CLIP encoder in eval mode (frozen per SARM paper)
self.clip_model.eval()
self.minilm_model.eval()
# Only transformer can be trained
self.sarm_transformer.train(mode)
return self
@@ -618,30 +587,18 @@ class SARMRewardModel(PreTrainedPolicy):
"""Required by PreTrainedPolicy but not used for SARM."""
raise NotImplementedError("SARM model does not select actions")
def _get_remaining_length(self, observation: dict, idx: int) -> float | None:
"""Extract remaining length for a sample from observation metadata."""
remaining_lengths = observation.get('remaining_length')
if remaining_lengths is None:
return None
if isinstance(remaining_lengths, torch.Tensor):
return remaining_lengths[idx].item() if remaining_lengths.dim() > 0 else remaining_lengths.item()
return remaining_lengths
def _compute_progress_targets(self, remaining_length: float | None, seq_len: int) -> torch.Tensor:
"""Compute progress targets based on remaining trajectory length."""
if remaining_length is not None and remaining_length > 0:
return torch.arange(1, seq_len + 1, dtype=torch.float32, device=self.device) / remaining_length
else:
raise ValueError("Remaining length is None, but is required for progress targets")
def _apply_rewind_augmentation(
def _apply_temporal_augmentation(
self,
video: torch.Tensor,
progress: torch.Tensor,
state: torch.Tensor | None,
max_length: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
"""Apply rewind augmentation: append up to 4 reversed frames (SARM paper A.4)."""
"""Apply temporal augmentation by appending reversed frames (SARM paper A.4).
This helps the model learn to handle non-monotonic progress (failures, recoveries).
Appends 1-4 reversed frames to simulate going backwards in task progress.
"""
num_reverse = random.randint(1, min(4, max_length - 1))
# Reverse and take frames (skip first which is last of original)
@@ -672,14 +629,20 @@ class SARMRewardModel(PreTrainedPolicy):
"""
Forward pass for SARM reward model training.
Uses annotation-based progress targets following SARM paper Eq. 2:
yt = Pk-1 + α̅k × τt
where:
- τt = (t - sk) / (ek - sk) is within-subtask normalized time
- Pk-1 is cumulative prior (sum of previous subtask proportions)
- α̅k is the temporal proportion for subtask k
Args:
batch: Dictionary with 'observation' containing:
- 'video_features': (B, T, 512) pre-encoded video features
- 'text_features': (B, 384) pre-encoded text features
- 'text_features': (B, 512) pre-encoded text features (CLIP)
- 'state_features': (B, T, state_dim) joint state features
- 'remaining_length': (B,) remaining trajectory lengths (optional)
- 'stage_labels': (B, T) stage labels (optional, from annotations)
- 'progress_targets': (B, T, 1) progress targets (optional, from annotations)
- 'stage_labels': (B, T) stage labels from annotations
- 'progress_targets': (B, T, 1) progress targets from annotations
Returns:
Tuple of (total_loss, output_dict with loss components)
@@ -702,21 +665,31 @@ class SARMRewardModel(PreTrainedPolicy):
if state_features is not None and state_features.dim() == 2:
state_features = state_features.unsqueeze(1).expand(-1, max_length, -1)
# Process each sample: compute progress targets and apply rewind augmentation
# Get annotation-based progress targets (required for SARM paper formula)
progress_from_annotations = observation.get('progress_targets')
if progress_from_annotations is None:
raise ValueError("progress_targets from annotations is required for SARM training")
progress_from_annotations = progress_from_annotations.to(self.device)
if progress_from_annotations.dim() == 2:
progress_from_annotations = progress_from_annotations.unsqueeze(-1)
if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1:
progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1)
# Process each sample: apply temporal augmentation (SARM paper A.4)
processed_videos = []
processed_states = []
progress_targets = []
for i in range(batch_size):
remaining_length = self._get_remaining_length(observation, i)
progress = self._compute_progress_targets(remaining_length, max_length)
video = video_features[i]
state = state_features[i] if state_features is not None else None
progress = progress_from_annotations[i].squeeze(-1) # (T,)
# Apply rewind augmentation with 50% probability (SARM paper)
# Apply temporal augmentation with 50% probability (SARM paper A.4)
# Appends up to 4 reversed frames to simulate failures/recoveries
if random.random() < 0.5:
video, progress, state = self._apply_rewind_augmentation(video, progress, state, max_length)
video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length)
# Ensure correct sequence length
video = self._ensure_sequence_length(video, max_length)
@@ -739,32 +712,22 @@ class SARMRewardModel(PreTrainedPolicy):
processed_videos, text_features, processed_states
)
# Use annotation-based progress targets
progress_from_annotations = observation.get('progress_targets')
if progress_from_annotations is not None:
progress_from_annotations = progress_from_annotations.to(self.device)
if progress_from_annotations.dim() == 2:
progress_from_annotations = progress_from_annotations.unsqueeze(-1)
if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1:
progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1)
progress_targets = progress_from_annotations
# Compute progress loss
# Compute progress loss (MSE)
progress_loss = F.mse_loss(progress_preds, progress_targets)
output_dict = {'progress_loss': progress_loss.item()}
total_loss = progress_loss
# Compute stage loss if labels available
# Compute stage loss (cross-entropy)
stage_labels = observation.get('stage_labels')
if stage_labels is not None:
stage_labels = stage_labels.to(self.device)
if stage_labels.dim() == 1:
stage_labels = stage_labels.unsqueeze(0).expand(batch_size, -1)
stage_loss = compute_stage_loss(stage_logits, stage_labels)
total_loss = total_loss + self.config.stage_loss_weight * stage_loss
output_dict['stage_loss'] = stage_loss.item()
else:
raise ValueError("Stage labels are None, but are required for stage loss")
if stage_labels is None:
raise ValueError("stage_labels from annotations is required for SARM training")
stage_labels = stage_labels.to(self.device)
if stage_labels.dim() == 1:
stage_labels = stage_labels.unsqueeze(0).expand(batch_size, -1)
stage_loss = compute_stage_loss(stage_logits, stage_labels)
total_loss = total_loss + self.config.stage_loss_weight * stage_loss
output_dict['stage_loss'] = stage_loss.item()
# Misaligned loss: 20% probability (SARM paper - improve video-language alignment)
if random.random() < 0.2:
@@ -786,9 +749,3 @@ def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor)
loss = F.cross_entropy(stage_logits_flat, target_stages_flat)
return loss
def compute_progress_loss(progress_preds: torch.Tensor, target_progress: torch.Tensor) -> torch.Tensor:
loss = F.mse_loss(progress_preds, target_progress)
return loss

View File

@@ -20,7 +20,7 @@ import numpy as np
import torch
from PIL import Image
import pandas as pd
from transformers import AutoModel, AutoTokenizer, CLIPModel, CLIPProcessor
from transformers import CLIPModel, CLIPProcessor
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.processor import (
@@ -44,9 +44,12 @@ class SARMEncodingProcessorStep(ProcessorStep):
"""
ProcessorStep that encodes images and text for SARM training.
Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
to process both RGB image sequences and task descriptions."
This step handles:
- CLIP (image) encoding
- MiniLM (text) encoding
- CLIP image encoding (512-dim)
- CLIP text encoding (512-dim)
- Joint state normalization
Supports temporal sequences: (B, T, C, H, W) → (B, T, 512) video features
@@ -76,28 +79,18 @@ class SARMEncodingProcessorStep(ProcessorStep):
self._init_encoders()
def _init_encoders(self):
"""Initialize CLIP and MiniLM encoders."""
"""Initialize CLIP encoder for both images and text (per SARM paper A.4)."""
device = torch.device(
self.config.device if self.config.device
else "cuda" if torch.cuda.is_available() else "cpu"
)
logging.info("Initializing CLIP encoder for SARM...")
logging.info("Initializing CLIP encoder for SARM (images + text)...")
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
self.clip_model.to(device)
self.clip_model.eval()
logging.info("Initializing MiniLM encoder for SARM...")
self.minilm_tokenizer = AutoTokenizer.from_pretrained(
"sentence-transformers/all-MiniLM-L12-v2"
)
self.minilm_model = AutoModel.from_pretrained(
"sentence-transformers/all-MiniLM-L12-v2"
)
self.minilm_model.to(device)
self.minilm_model.eval()
self.device = device
def _compute_temporal_proportions(self):
@@ -167,11 +160,13 @@ class SARMEncodingProcessorStep(ProcessorStep):
for name in self.subtask_names
}
else:
# Equal proportions if no duration info
self.temporal_proportions = {
name: 1.0 / len(self.subtask_names)
for name in self.subtask_names
}
raise ValueError(
"Cannot compute temporal proportions: all subtask durations are zero. "
"Check that your dataset has valid subtask annotations with start/end times."
)
# Store in config for the model to use in progress output conversion (SARM paper Eq. 4)
self.config.temporal_proportions = [self.temporal_proportions[name] for name in self.subtask_names]
logging.info(f"Computed temporal proportions for {len(self.subtask_names)} subtasks: {self.temporal_proportions}")
@@ -481,15 +476,9 @@ class SARMEncodingProcessorStep(ProcessorStep):
observation['state_features'] = torch.tensor(state_data, dtype=torch.float32)
# 3. Encode text with MiniLM
# 3. Encode text with CLIP (per SARM paper A.4)
batch_size = video_features.shape[0]
task_descriptions = new_transition.get('task')
if task_descriptions is not None:
if isinstance(task_descriptions, str):
task_descriptions = [task_descriptions] * batch_size
observation['text_features'] = self._encode_text_batch_list(task_descriptions)
else:
observation['text_features'] = self._encode_text_batch(self.task_description, batch_size)
observation['text_features'] = self._encode_text_clip(self.task_description, batch_size)
# 4. Extract frame/episode indices from complementary data
comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
@@ -609,54 +598,33 @@ class SARMEncodingProcessorStep(ProcessorStep):
return all_embeddings
@torch.no_grad()
def _encode_text_batch(self, text: str, batch_size: int) -> torch.Tensor:
"""Encode a text string using MiniLM and replicate for batch.
def _encode_text_clip(self, text: str, batch_size: int) -> torch.Tensor:
"""Encode text using CLIP text encoder (per SARM paper A.4).
Args:
text: Text string to encode
text: Task description text to encode
batch_size: Batch size to replicate for
Returns:
Encoded feature vectors with shape (B, 384)
Encoded text features with shape (B, 512)
"""
from lerobot.policies.rewind.modeling_rewind import mean_pooling
# Use CLIP's tokenizer directly for text (avoids image processor validation issues)
tokenizer = self.clip_processor.tokenizer
inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
encoded_input = self.minilm_tokenizer(
text, padding=True, truncation=True, return_tensors="pt"
).to(self.device)
# Get text features from CLIP
text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
model_output = self.minilm_model(**encoded_input)
text_embedding = mean_pooling(model_output, encoded_input["attention_mask"])
text_embedding = text_embedding.squeeze().cpu()
# Handle single text case
if text_embedding.dim() == 1:
text_embedding = text_embedding.unsqueeze(0)
# Replicate for batch (B, 384)
text_embedding = text_embedding.unsqueeze(0).repeat(batch_size, 1)
# Replicate for batch (B, 512)
text_embedding = text_embedding.expand(batch_size, -1)
return text_embedding
@torch.no_grad()
def _encode_text_batch_list(self, text_list: list[str]) -> torch.Tensor:
"""Encode a list of text strings using MiniLM.
Args:
text_list: List of text strings to encode
Returns:
Encoded feature vectors with shape (B, 384)
"""
from lerobot.policies.rewind.modeling_rewind import mean_pooling
# Encode all texts in the batch at once
encoded_input = self.minilm_tokenizer(
text_list, padding=True, truncation=True, return_tensors="pt"
).to(self.device)
model_output = self.minilm_model(**encoded_input)
text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
text_embeddings = text_embeddings.cpu()
return text_embeddings
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
@@ -688,9 +656,12 @@ def make_sarm_pre_post_processors(
"""
Create pre-processor and post-processor pipelines for SARM.
Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
to process both RGB image sequences and task descriptions."
The pre-processing pipeline:
1. Encodes images with CLIP (512-dim)
2. Encodes text with MiniLM (384-dim)
2. Encodes text with CLIP (512-dim)
3. Normalizes joint states
4. Adds batch dimension
5. Moves data to device

View File

@@ -229,8 +229,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
# For ReWiND and SARM, always provide dataset_meta for progress normalization
if cfg.policy.type in ["rewind", "sarm"]:
# For SARM, always provide dataset_meta for progress normalization
if cfg.policy.type == "sarm":
processor_kwargs["dataset_meta"] = dataset.meta
if cfg.policy.pretrained_path is not None:
@@ -319,20 +319,17 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True,
)
elif cfg.policy.type in ["rewind", "sarm"] and getattr(cfg.policy, "use_temporal_sampler", False):
# Use temporal sequence sampler for loading sequences
from lerobot.datasets.temporal_sampler import TemporalSequenceSampler
elif cfg.policy.type == "sarm" and getattr(cfg.policy, "use_temporal_sampler", False):
# Use SARM temporal sampler for reward model training
from lerobot.datasets.temporal_sampler import SARMTemporalSampler
shuffle = False
sampling_mode = getattr(cfg.policy, "sampling_mode", cfg.policy.type)
sampler = TemporalSequenceSampler(
sampler = SARMTemporalSampler(
dataset_from_index=dataset.meta.episodes["dataset_from_index"],
dataset_to_index=dataset.meta.episodes["dataset_to_index"],
sequence_length=cfg.policy.max_length,
stride=getattr(cfg.policy, "sequence_stride", 1) if cfg.policy.type == "rewind" else getattr(cfg.policy, "frame_gap", 30),
frame_gap=getattr(cfg.policy, "frame_gap", 30),
shuffle=True,
seed=cfg.seed,
sampling_mode=sampling_mode,
)
else:
shuffle = True

View File

@@ -35,7 +35,7 @@ class RABCWeightComputer:
and applies soft weighting based on progress deltas.
Args:
reward_model: Pre-trained reward model (e.g., SARM, ReWiND)
reward_model: Pre-trained reward model (e.g., SARM)
kappa: Hard threshold for high-quality samples (default: 0.01)
epsilon: Small constant for numerical stability (default: 1e-6)
device: Device to run reward model on