add initial modeling

This commit is contained in:
Pepijn
2025-10-15 12:52:33 +02:00
parent a6ff3cfebb
commit d9f0c8c3ae
5 changed files with 1004 additions and 0 deletions

View File

@@ -0,0 +1,154 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Video sampling utilities for temporal data augmentation and frame selection.
This module provides utilities for sampling and augmenting video sequences, particularly
for reward model training. It includes functions for:
- Padding/sampling videos to fixed lengths
- Video rewind augmentation for learning to decrease rewards
"""
import random
from typing import Tuple
import numpy as np
import torch
def sample_video_feature(
    video_feature: torch.Tensor,
    max_length: int = 32,
    random_sample: bool = True
) -> torch.Tensor:
    """
    Sample or pad video features to a fixed length along the first axis.

    Videos shorter than ``max_length`` are padded by repeating the last frame;
    longer videos are subsampled (frame order is always preserved); videos of
    exactly ``max_length`` frames are returned unchanged (same object).

    Args:
        video_feature: Video features tensor whose first axis indexes frames,
            e.g. (num_frames, feature_dim). Any trailing shape is supported.
        max_length: Target sequence length.
        random_sample: If True, pick ``max_length`` distinct frames at random
            (using the ``random`` module). If False, pick evenly spaced frames.

    Returns:
        Tensor with ``max_length`` frames and the same trailing shape as the input.

    Raises:
        ValueError: If the video is empty but padding is required.
    """
    video_length = len(video_feature)

    if video_length < max_length:
        if video_length == 0:
            raise ValueError("video_feature must contain at least one frame to be padded")
        # Repeat the last frame; expand() keeps this independent of the feature rank.
        padding_frames = video_feature[-1:].expand(
            max_length - video_length, *video_feature.shape[1:]
        )
        return torch.cat([video_feature, padding_frames], dim=0)

    if video_length > max_length:
        if random_sample:
            # Sorted so the temporal order of the selected frames is kept.
            frame_idx = sorted(random.sample(range(video_length), max_length))
        else:
            frame_idx = np.linspace(0, video_length - 1, max_length, dtype=int).tolist()
        index = torch.as_tensor(frame_idx, dtype=torch.long, device=video_feature.device)
        return video_feature[index]

    return video_feature
def sample_reverse_video_feature(
    video_feature: torch.Tensor,
    max_length: int = 32,
    random_sample: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample a video with reverse augmentation (video rewind).

    A split point is drawn at random, the frames before it are kept in order,
    and then up to ``split_idx`` of those frames are appended in reverse order.
    The result looks like progress followed by regression, and the returned
    progress targets rise and then fall accordingly. The sequence is finally
    padded (last frame / last target repeated) to ``max_length``.

    Videos that are too short to split (a single frame, or ``max_length == 1``)
    are returned without a rewind instead of raising.

    Args:
        video_feature: Video features tensor whose first axis indexes frames,
            e.g. (num_frames, feature_dim).
        max_length: Target sequence length (must be >= 1).
        random_sample: Used only if the rewound sequence is longer than
            ``max_length``; with the current split/reverse bounds it never is.

    Returns:
        Tuple of:
            - Rewound video features with ``max_length`` frames.
            - Progress targets, shape (max_length,), float32, on the same device
              as ``video_feature``.

    Raises:
        ValueError: If the video is empty or ``max_length`` is smaller than 1.
    """
    video_length = len(video_feature)
    if video_length == 0:
        raise ValueError("video_feature must contain at least one frame")
    if max_length < 1:
        raise ValueError(f"max_length must be positive, got {max_length}")

    # Build targets next to the features so callers can stack them with
    # targets created on the model device.
    device = video_feature.device

    # NOTE(review): the split is capped at max_length - 1, so for videos longer
    # than max_length only the leading frames are ever rewound. Confirm this is
    # intended rather than subsampling the video first.
    max_split = min(video_length, max_length) - 1

    if max_split >= 1:
        # Split point (where the rewind starts) and number of reversed frames.
        split_idx = random.randint(1, max_split)
        reverse_length = random.randint(1, min(split_idx, max_length - split_idx))

        forward_frames = video_feature[:split_idx]
        # Starts at frame split_idx - 1, i.e. the last forward frame is repeated.
        reverse_frames = video_feature[split_idx - reverse_length:split_idx].flip(0)
        rewound_video = torch.cat([forward_frames, reverse_frames], dim=0)

        # NOTE(review): the repeated pivot frame is labelled split_idx / L in the
        # forward part and (split_idx - 1) / L in the reverse part. Kept as-is;
        # verify against the reference implementation.
        forward_progress = torch.linspace(
            0, split_idx / video_length, split_idx, device=device
        )
        reverse_progress = torch.linspace(
            (split_idx - 1) / video_length,
            (split_idx - reverse_length) / video_length,
            reverse_length,
            device=device,
        )
        progress_targets = torch.cat([forward_progress, reverse_progress])
    else:
        # Too short to split: keep the (truncated) video without a rewind.
        rewound_video = video_feature[:max_length]
        progress_targets = torch.linspace(
            0,
            min(max_length, video_length) / video_length,
            len(rewound_video),
            device=device,
        )

    current_length = len(rewound_video)
    if current_length < max_length:
        padding_length = max_length - current_length
        # Repeat the last frame and hold the last progress value.
        padding_frames = rewound_video[-1:].expand(padding_length, *rewound_video.shape[1:])
        rewound_video = torch.cat([rewound_video, padding_frames], dim=0)
        progress_targets = torch.cat(
            [progress_targets, progress_targets[-1:].expand(padding_length)]
        )
    elif current_length > max_length:
        if random_sample:
            frame_idx = sorted(random.sample(range(current_length), max_length))
        else:
            frame_idx = np.linspace(0, current_length - 1, max_length, dtype=int).tolist()
        index = torch.as_tensor(frame_idx, dtype=torch.long, device=device)
        rewound_video = rewound_video[index]
        progress_targets = progress_targets[index]

    return rewound_video, progress_targets

View File

@@ -111,6 +111,41 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
return LambdaLR(optimizer, lr_lambda, -1)
@LRSchedulerConfig.register_subclass("cosine_with_min_lr")
@dataclass
class CosineWithMinLRSchedulerConfig(LRSchedulerConfig):
    """Cosine learning rate scheduler with minimum learning rate floor.

    Used by ReWiND for reward model training. Includes linear warmup phase
    followed by cosine annealing with a minimum learning rate.
    """

    # Number of steps over which the LR multiplier ramps linearly from 0 to 1.
    num_warmup_steps: int
    # Absolute learning-rate floor reached at the end of the cosine decay.
    min_lr: float = 0.0

    def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
        """Build a ``LambdaLR`` with one warmup + cosine multiplier per param group.

        Args:
            optimizer: Optimizer whose param-group learning rates are scheduled.
            num_training_steps: Total number of steps; the floor is reached here.

        Returns:
            The configured ``LambdaLR`` scheduler.
        """
        warmup_steps = self.num_warmup_steps
        decay_steps = max(1, num_training_steps - warmup_steps)

        def make_lr_lambda(base_lr: float):
            # The floor is expressed relative to the *base* LR. It must be fixed
            # here: LambdaLR rewrites group["lr"] on every step, so reading it
            # inside the lambda would compare min_lr against the already-scaled LR.
            min_factor = self.min_lr / base_lr if base_lr > 0 else 0.0

            def lr_lambda(current_step):
                if current_step <= warmup_steps:
                    # Linear warmup
                    if warmup_steps == 0:
                        return 1.0
                    return float(current_step) / float(max(1, warmup_steps))
                # Clamp so the LR stays at the floor past num_training_steps
                # instead of following the cosine back up.
                progress = min(1.0, (current_step - warmup_steps) / float(decay_steps))
                cosine_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
                return min_factor + (1.0 - min_factor) * cosine_factor

            return lr_lambda

        # "initial_lr" is present when the optimizer was already attached to a
        # scheduler (e.g. on resume); otherwise "lr" still holds the base value.
        lr_lambdas = [
            make_lr_lambda(group.get("initial_lr", group["lr"]))
            for group in optimizer.param_groups
        ]
        return LambdaLR(optimizer, lr_lambdas, -1)
def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
    """Write the scheduler's ``state_dict()`` as JSON to ``save_dir / SCHEDULER_STATE``."""
    state_dict = scheduler.state_dict()
    write_json(state_dict, save_dir / SCHEDULER_STATE)

View File

@@ -0,0 +1,36 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.policies.rewind.configuration_rewind import ReWiNDConfig
from lerobot.policies.rewind.modeling_rewind import (
ReWiNDRewardModel,
ReWiNDTransformer,
train_step_fn,
create_training_batch,
compute_progress_loss,
compute_misaligned_loss,
)
# Names re-exported by this package: the config, the two model classes and the
# training helpers imported above.
__all__ = [
    "ReWiNDConfig",
    "ReWiNDRewardModel",
    "ReWiNDTransformer",
    "train_step_fn",
    "create_training_batch",
    "compute_progress_loss",
    "compute_misaligned_loss",
]

View File

@@ -0,0 +1,71 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from lerobot.configs.policies import PreTrainedConfig
@PreTrainedConfig.register_subclass("rewind")
@dataclass
class ReWiNDConfig(PreTrainedConfig):
    """Configuration class for ReWiND Reward Model.

    ReWiND is a reward model that computes task completion/progress rewards
    from video observations and language task descriptions.

    NOTE(review): confirm that ``PreTrainedConfig`` does not already declare a
    ``device`` field or require members this class does not define.
    """

    # Model architecture parameters
    video_dim: int = 768  # DINO embedding dimension
    text_dim: int = 384  # MiniLM embedding dimension
    hidden_dim: int = 512
    num_heads: int = 8
    num_layers: int = 4

    # Temporal parameters
    max_length: int = 32  # Maximum video sequence length
    subsample_video: bool = True  # Whether to pad/subsample videos to max_length

    # Training parameters
    batch_size: int = 64
    dino_batch_size: int = 64
    gradient_checkpointing: bool = False  # Enable gradient checkpointing for memory optimization

    # Model loading
    pretrained_model_path: str | None = None

    # Device settings
    device: str | None = None

    # Dropout
    dropout: float = 0.1  # Dropout rate for transformer

    def __post_init__(self):
        """Validate the configuration after the parent's own post-init.

        Raises:
            ValueError: If a size field is not positive, ``hidden_dim`` is not
                divisible by ``num_heads``, or ``dropout`` is outside [0, 1).
        """
        super().__post_init__()

        # Size-like fields must be positive. Checking num_heads here also keeps
        # the divisibility test below from raising ZeroDivisionError.
        for field_name in (
            "video_dim",
            "text_dim",
            "hidden_dim",
            "num_heads",
            "num_layers",
            "max_length",
            "batch_size",
            "dino_batch_size",
        ):
            value = getattr(self, field_name)
            if value <= 0:
                raise ValueError(f"{field_name} must be positive, got {value}")

        if self.hidden_dim % self.num_heads != 0:
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by num_heads ({self.num_heads})"
            )
        if self.dropout < 0 or self.dropout >= 1:
            raise ValueError(f"dropout must be in [0, 1), got {self.dropout}")

View File

@@ -0,0 +1,708 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Union, Dict, Optional
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import AutoModel, AutoTokenizer
import torchvision.transforms as T
from lerobot.policies.rewind.configuration_rewind import ReWiNDConfig
from lerobot.datasets.video_sampler import sample_video_feature, sample_reverse_video_feature
# Helper functions for encoding
def dino_load_image(img: np.ndarray) -> torch.Tensor:
    """
    Convert a single image array into a batched tensor for the DINO encoder.

    The image is turned into a PIL image, converted to a tensor, center-cropped
    to 224x224 and normalized with mean 0.5 / std 0.5. At most the first three
    channels are kept and a leading batch axis is added.

    Args:
        img: Input image as numpy array; expected to be (H, W, C) uint8
            (anything ``PIL.Image.fromarray`` accepts).

    Returns:
        Image tensor with a leading batch dimension, (1, 3, 224, 224) for
        inputs with at least three channels.
    """
    preprocess = T.Compose([
        T.ToTensor(),
        T.CenterCrop(224),
        T.Normalize([0.5], [0.5]),
    ])
    tensor = preprocess(Image.fromarray(img))
    # Drop any extra channel (e.g. alpha), then add the batch axis.
    rgb = tensor[:3]
    return rgb.unsqueeze(0)
def mean_pooling(model_output, attention_mask):
    """
    Average token embeddings over the sequence axis, ignoring masked tokens.

    Args:
        model_output: Indexable model output whose first element holds the
            token embeddings (assumed (batch, tokens, dim)).
        attention_mask: Mask with one entry per token; zero entries are
            excluded from the average.

    Returns:
        Mean-pooled embeddings, one vector per batch element.
    """
    embeddings = model_output[0]
    # Broadcast the mask over the embedding dimension.
    weights = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    weighted_sum = torch.sum(embeddings * weights, 1)
    # Clamp so a fully masked row divides by a tiny number instead of zero.
    token_counts = torch.clamp(weights.sum(1), min=1e-9)
    return weighted_sum / token_counts
class ReWiNDTransformer(nn.Module):
    """
    ReWiND Transformer model for predicting task progress from video and text.

    This model takes video frame embeddings and text embeddings as input,
    and predicts a progress score (0-1) for each frame indicating how much
    of the task has been completed.
    """

    def __init__(
        self,
        video_dim: int = 768,
        text_dim: int = 384,
        hidden_dim: int = 512,
        num_heads: int = 8,
        num_layers: int = 4,
        max_length: int = 32,
        dropout: float = 0.1
    ):
        """
        Args:
            video_dim: Size of each input video frame embedding.
            text_dim: Size of the input text embedding.
            hidden_dim: Transformer model dimension.
            num_heads: Number of attention heads.
            num_layers: Number of transformer encoder layers.
            max_length: Stored on the module; not used to limit the sequence length here.
            dropout: Dropout rate for the encoder layers and the progress head.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.max_length = max_length

        # Project video and text to common dimension
        self.video_proj = nn.Linear(video_dim, hidden_dim)
        self.text_proj = nn.Linear(text_dim, hidden_dim)

        # Positional embedding that is added to the first video frame only.
        self.first_pos_embed = nn.Parameter(torch.randn(1, hidden_dim))

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Progress prediction head (applied to each frame); Sigmoid keeps outputs in (0, 1).
        self.progress_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

        # Cache slot for the causal mask; non-persistent so it is not part of the state dict.
        self.register_buffer("attention_mask", None, persistent=False)

    def _get_attention_mask(self, seq_length: int, device: torch.device) -> torch.Tensor:
        """Return a causal mask of size ``seq_length`` on ``device``, reusing the cache when it fits."""
        cached = self.attention_mask
        if cached is None or cached.shape[0] != seq_length or cached.device != device:
            # Upper-triangular -inf mask; rebuilt when the length or device changes
            # so a stale mask is never combined with inputs on another device.
            cached = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device)
            self.attention_mask = cached
        return cached

    def forward(self, video_frames: torch.Tensor, text_embed: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the ReWiND transformer.

        Args:
            video_frames: Video frame embeddings (batch_size, seq_len, video_dim)
            text_embed: Text embeddings (batch_size, text_dim)

        Returns:
            Progress predictions for each frame (batch_size, seq_len, 1)
        """
        # Project inputs to common dimension
        video_embed = self.video_proj(video_frames)  # [batch_size, seq_len, hidden_dim]
        text_embed = self.text_proj(text_embed).unsqueeze(1)  # [batch_size, 1, hidden_dim]

        # Add positional embedding to first video frame
        video_embed[:, 0] += self.first_pos_embed

        # Combine sequence: [text, video_frames]
        sequence = torch.cat([text_embed, video_embed], dim=1)

        # Causal masking: each position attends only to the text token and earlier frames.
        attention_mask = self._get_attention_mask(sequence.shape[1], sequence.device)
        transformed = self.transformer(sequence, mask=attention_mask, is_causal=True)

        # Get progress predictions for each frame (exclude text token)
        return self.progress_head(transformed[:, 1:])
class ReWiNDRewardModel(nn.Module):
    """
    ReWiND Reward Model for computing task completion rewards from video and text.

    This model combines:
    - DINO (DINOv2) for encoding video frames
    - MiniLM for encoding text descriptions
    - ReWiNDTransformer for predicting task progress
    """

    def __init__(self, config: "ReWiNDConfig"):
        """Load both encoders, build the transformer and place everything on the device.

        Args:
            config: Model configuration. ``config.device`` selects the device; when
                unset, CUDA is used if available, otherwise CPU.
        """
        super().__init__()
        self.config = config
        if config.device:
            device_name = config.device
        else:
            device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)

        # Initialize DINO encoder for images
        logging.info("Loading DINO encoder...")
        self.dino_encoder = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
        self.dino_encoder.to(self.device)
        self.dino_encoder.eval()

        # Initialize MiniLM encoder for text
        logging.info("Loading MiniLM encoder...")
        self.minilm_tokenizer = AutoTokenizer.from_pretrained(
            "sentence-transformers/all-MiniLM-L12-v2"
        )
        self.minilm_model = AutoModel.from_pretrained(
            "sentence-transformers/all-MiniLM-L12-v2"
        )
        self.minilm_model.to(self.device)
        self.minilm_model.eval()

        # Initialize ReWiND transformer with explicit architecture parameters
        self.rewind_transformer = ReWiNDTransformer(
            video_dim=config.video_dim,
            text_dim=config.text_dim,
            hidden_dim=config.hidden_dim,
            num_heads=config.num_heads,
            num_layers=config.num_layers,
            max_length=config.max_length,
            dropout=config.dropout
        )
        self.rewind_transformer.to(self.device)

        logging.info(f"ReWiND Reward Model initialized on {self.device}")

    def to(self, *args, **kwargs):
        """Move/cast the module like ``nn.Module.to`` and keep ``self.device`` in sync.

        Accepts every call form of ``nn.Module.to`` (device, dtype, tensor,
        keyword arguments). ``self.device`` is updated only when a device is given.
        """
        # Registered submodules (encoders and transformer) are moved by the parent call.
        super().to(*args, **kwargs)

        target = kwargs.get("device")
        if target is None and args:
            first = args[0]
            if isinstance(first, torch.Tensor):
                target = first.device
            elif isinstance(first, (str, int, torch.device)):
                target = first
        if target is not None:
            self.device = target if isinstance(target, torch.device) else torch.device(target)
        return self

    @torch.no_grad()
    def encode_images(self, images: np.ndarray) -> np.ndarray:
        """
        Encode video frames using DINO.

        Args:
            images: Video frames with shape (num_videos, num_frames, H, W, C) in uint8.
                Can also be (num_frames, H, W, C) for a single video. Channel-first
                frames (..., C, H, W) with C == 3 are accepted as well.

        Returns:
            Encoded image features (num_videos, num_frames, D), or (num_frames, D)
            for a single video, where D is the DINO embedding size.

        Raises:
            ValueError: If the input is neither 4D nor 5D.
        """
        single_video = images.ndim == 4
        if single_video:
            images = images[np.newaxis, ...]
        if images.ndim != 5:
            raise ValueError(
                f"Expected 5D input (num_videos, num_frames, H, W, C), got {images.shape}"
            )

        # dino_load_image expects channel-last frames; convert channel-first input once.
        if images.shape[-1] != 3 and images.shape[2] == 3:
            images = np.transpose(images, (0, 1, 3, 4, 2))

        chunk_size = self.config.dino_batch_size
        all_embeddings = []
        for video in images:
            frame_tensors = [
                dino_load_image(frame.astype(np.uint8, copy=False)) for frame in video
            ]
            video_embeddings = []
            for start in range(0, len(frame_tensors), chunk_size):
                batch = torch.cat(frame_tensors[start:start + chunk_size]).to(self.device)
                embeddings = self.dino_encoder(batch)
                # One row per frame, also when the chunk holds a single frame
                # (assumes the encoder returns one embedding vector per image).
                video_embeddings.append(embeddings.reshape(batch.shape[0], -1).detach().cpu())
            all_embeddings.append(torch.cat(video_embeddings))

        result = torch.stack(all_embeddings).numpy()
        return result[0] if single_video else result

    @torch.no_grad()
    def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
        """
        Encode text using MiniLM with attention-masked mean pooling.

        Args:
            text: Text string or list of text strings.

        Returns:
            Encoded text features (num_texts, D), or (D,) for a single string.
        """
        single_text = isinstance(text, str)
        if single_text:
            text = [text]

        chunk_size = self.config.batch_size
        all_embeddings = []
        for start in range(0, len(text), chunk_size):
            encoded_input = self.minilm_tokenizer(
                text[start:start + chunk_size], padding=True, truncation=True, return_tensors="pt"
            ).to(self.device)
            model_output = self.minilm_model(**encoded_input)
            text_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
            all_embeddings.append(text_embeddings.cpu())

        result = torch.cat(all_embeddings).numpy()
        return result[0] if single_text else result

    def padding_video(self, video_frames: torch.Tensor, max_length: int) -> torch.Tensor:
        """
        Pad or subsample video frames to a fixed length.

        Shorter videos are padded by repeating the last frame; longer videos are
        subsampled at evenly spaced indices.

        Args:
            video_frames: Video frames, tensor or numpy array (num_frames, embedding_dim)
            max_length: Target sequence length

        Returns:
            Padded/subsampled video frames (max_length, embedding_dim)

        Raises:
            ValueError: If the video is empty but padding is required.
        """
        if isinstance(video_frames, np.ndarray):
            video_frames = torch.tensor(video_frames)
        video_length = len(video_frames)

        if video_length < max_length:
            if video_length == 0:
                raise ValueError("video_frames must contain at least one frame to be padded")
            padding_frames = video_frames[-1:].expand(
                max_length - video_length, *video_frames.shape[1:]
            )
            video_frames = torch.cat([video_frames, padding_frames], dim=0)
        elif video_length > max_length:
            frame_idx = np.linspace(0, video_length - 1, max_length).astype(int)
            index = torch.as_tensor(frame_idx, dtype=torch.long, device=video_frames.device)
            video_frames = video_frames[index]

        return video_frames

    @torch.no_grad()
    def calculate_rewards(
        self,
        text_embeddings: Union[np.ndarray, torch.Tensor],
        video_embeddings: Union[np.ndarray, torch.Tensor],
        return_all_frames: bool = False
    ) -> np.ndarray:
        """
        Calculate rewards for given text and video representations.

        Args:
            text_embeddings: Encoded text representations (batch_size, text_dim),
                or (text_dim,) for a single sample.
            video_embeddings: Encoded video representations (batch_size, num_frames, video_dim),
                or (num_frames, video_dim) when ``text_embeddings`` is 1D.
            return_all_frames: If True, return rewards for all frames. If False, return only last frame.

        Returns:
            Reward values (batch_size,) or (batch_size, num_frames) if return_all_frames=True.
            For a single sample the leading batch axis is removed.
        """
        if isinstance(text_embeddings, np.ndarray):
            text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
        if isinstance(video_embeddings, np.ndarray):
            video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)

        # A 1D text embedding marks the single-sample call form.
        single_sample = text_embeddings.dim() == 1
        if single_sample:
            text_embeddings = text_embeddings.unsqueeze(0)
            video_embeddings = video_embeddings.unsqueeze(0)

        chunk_size = self.config.batch_size
        all_rewards = []
        for start in range(0, len(video_embeddings), chunk_size):
            batch_texts = text_embeddings[start:start + chunk_size].to(self.device)
            batch_videos = video_embeddings[start:start + chunk_size].to(self.device)

            if self.config.subsample_video:
                batch_videos = torch.stack(
                    [self.padding_video(video, self.config.max_length) for video in batch_videos]
                ).to(self.device)

            rewards = self.rewind_transformer(batch_videos.float(), batch_texts.float())

            if return_all_frames:
                all_rewards.append(rewards.squeeze(-1).cpu())
            else:
                # Only the prediction for the last frame
                all_rewards.append(rewards[:, -1, 0].cpu())

        result = torch.cat(all_rewards).numpy()
        return result[0] if single_sample else result

    def load_pretrained_checkpoint(self, checkpoint_path: str, strict: bool = False):
        """
        Load pretrained ReWiNDTransformer weights from a checkpoint file.

        The checkpoint is either a bare state dict or a dict with a
        "model_state_dict" entry and an optional "args" entry.

        Args:
            checkpoint_path: Path to the .pth checkpoint file
            strict: Whether to strictly enforce that the keys match
        """
        logging.info(f"Loading pretrained checkpoint from {checkpoint_path}")

        # SECURITY: weights_only=False unpickles arbitrary objects (needed for the
        # stored "args" object). Only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

        if "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]

            args = checkpoint.get("args")
            if args is not None:
                # "args" may be a namespace-like object or a dict, and may lack max_length.
                if isinstance(args, dict):
                    ckpt_max_length = args.get("max_length")
                else:
                    ckpt_max_length = getattr(args, "max_length", None)

                if ckpt_max_length is not None:
                    logging.info(f"Checkpoint was trained with: max_length={ckpt_max_length}")
                    if ckpt_max_length != self.config.max_length:
                        logging.warning(
                            f"Checkpoint max_length ({ckpt_max_length}) differs from config ({self.config.max_length}). "
                            "This may cause issues if sequence lengths don't match."
                        )
        else:
            state_dict = checkpoint

        # Load only the ReWiNDTransformer weights
        missing_keys, unexpected_keys = self.rewind_transformer.load_state_dict(state_dict, strict=strict)

        if missing_keys:
            logging.warning(f"Missing keys when loading checkpoint: {missing_keys}")
        if unexpected_keys:
            logging.warning(f"Unexpected keys when loading checkpoint: {unexpected_keys}")

        logging.info("Checkpoint loaded successfully")

    def train(self, mode: bool = True):
        """Set training mode. Note: DINO and MiniLM encoders always stay in eval mode."""
        super().train(mode)
        # Keep encoders in eval mode
        self.dino_encoder.eval()
        self.minilm_model.eval()
        # Only transformer can be trained
        self.rewind_transformer.train(mode)
        return self

    def eval(self):
        """Set evaluation mode."""
        return self.train(False)
# Training utilities
def compute_progress_loss(
    model: "ReWiNDTransformer",
    video_features: torch.Tensor,
    text_features: torch.Tensor,
    target_progress: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Compute progress prediction loss.

    Args:
        model: ReWiNDTransformer model (called as ``model(video_features, text_features)``)
        video_features: Batch of video features (batch_size, max_length, feature_dim)
        text_features: Batch of text features (batch_size, text_dim)
        target_progress: Optional target progress values, (batch_size, max_length)
            or already shaped like the predictions. If None, uses linear
            progress from 0 to 1.

    Returns:
        Mean squared error loss

    Raises:
        ValueError: If the targets do not match the prediction shape.
    """
    progress_preds = model(video_features, text_features)

    if target_progress is None:
        batch_size, max_length = video_features.shape[:2]
        target_progress = torch.linspace(0, 1, max_length, device=video_features.device)
        target_progress = target_progress.unsqueeze(0).repeat(batch_size, 1)

    # Predictions carry a trailing singleton axis; give 2D targets the same layout.
    if target_progress.dim() == 2:
        target_progress = target_progress.unsqueeze(-1)

    # mse_loss would silently broadcast mismatched shapes into a wrong loss.
    if target_progress.shape != progress_preds.shape:
        raise ValueError(
            f"target_progress shape {tuple(target_progress.shape)} does not match "
            f"predictions shape {tuple(progress_preds.shape)}"
        )

    # Targets may arrive on another device or in another float dtype.
    target_progress = target_progress.to(device=progress_preds.device, dtype=progress_preds.dtype)

    return F.mse_loss(progress_preds, target_progress)
def compute_misaligned_loss(
    model: ReWiNDTransformer,
    video_features: torch.Tensor,
    misaligned_text_features: torch.Tensor
) -> torch.Tensor:
    """
    Loss for video/text pairs that do not belong together.

    The model is run on the mismatched pairs and its predictions are
    regressed towards zero progress.

    Args:
        model: ReWiNDTransformer model
        video_features: Batch of video features (batch_size, max_length, feature_dim)
        misaligned_text_features: Batch of misaligned text features (batch_size, text_dim)

    Returns:
        Mean squared error between the predictions and an all-zero target.
    """
    predicted = model(video_features, misaligned_text_features)
    # Every frame of a mismatched pair should score zero.
    return F.mse_loss(predicted, torch.zeros_like(predicted))
def train_step_fn(
    model: ReWiNDRewardModel,
    batch: Dict[str, torch.Tensor],
    optimizer: torch.optim.Optimizer,
    use_rewind: bool = True,
    rewind_prob: float = 0.5,
    misaligned_prob: float = 0.5,
    gradient_clip: float = 1.0
) -> Dict[str, float]:
    """
    Perform a single training step for the ReWiND model.

    The step combines:
    - Progress prediction on aligned video-text pairs
    - Video rewind augmentation for learning to decrease rewards
    - Misaligned video-text pairs for learning to output zero rewards

    Args:
        model: ReWiNDRewardModel instance
        batch: Dictionary containing:
            - 'video_features': Pre-computed video embeddings (batch_size, num_frames, video_dim)
            - 'text_features': Pre-computed text embeddings (batch_size, text_dim)
            - 'misaligned_video_features': Optional misaligned videos
            - 'misaligned_text_features': Optional misaligned texts
        optimizer: Optimizer for updating model parameters
        use_rewind: Whether to use video rewind augmentation
        rewind_prob: Probability of applying rewind to each sample
        misaligned_prob: Probability of including misaligned loss
        gradient_clip: Max gradient norm for the transformer parameters; values <= 0 disable clipping

    Returns:
        Dictionary of loss values for logging: 'progress_loss', 'total_loss' and,
        when it was computed, 'misaligned_loss'.
    """
    model.train()
    optimizer.zero_grad()

    device = model.device
    video_features = batch['video_features'].to(device)
    text_features = batch['text_features'].to(device)

    batch_size = video_features.shape[0]
    max_length = model.config.max_length

    # Process videos (with potential rewind augmentation)
    processed_videos = []
    progress_targets = []
    for video in video_features:
        if use_rewind and random.random() < rewind_prob:
            sampled_video, progress = sample_reverse_video_feature(
                video,
                max_length=max_length,
                random_sample=True
            )
        else:
            sampled_video = sample_video_feature(
                video,
                max_length=max_length,
                random_sample=True
            )
            # Linear progress from 0 to 1
            progress = torch.linspace(0, 1, max_length, device=device)
        processed_videos.append(sampled_video)
        # Targets from the sampler may live on another device; stacking needs one device.
        progress_targets.append(progress.to(device))

    processed_videos = torch.stack(processed_videos)
    progress_targets = torch.stack(progress_targets)

    progress_loss = compute_progress_loss(
        model.rewind_transformer,
        processed_videos,
        text_features,
        progress_targets
    )

    total_loss = progress_loss
    losses = {'progress_loss': progress_loss.item()}

    # Compute misaligned loss if requested
    if random.random() < misaligned_prob:
        misaligned_videos = None
        misaligned_texts = None
        if 'misaligned_video_features' in batch and 'misaligned_text_features' in batch:
            misaligned_videos = batch['misaligned_video_features'].to(device)
            misaligned_texts = batch['misaligned_text_features'].to(device)
        elif batch_size > 1:
            # Pair every text with a *different* sample's video. A plain random
            # permutation can leave fixed points (aligned pairs trained towards
            # zero), so build a single random cycle instead.
            # NOTE(review): samples that share the same task text can still end
            # up paired with each other.
            perm = torch.randperm(batch_size)
            shuffle_idx = torch.empty_like(perm)
            shuffle_idx[perm] = perm.roll(-1)
            misaligned_videos = processed_videos[shuffle_idx.to(device)]
            misaligned_texts = text_features
        # With a single sample and no explicit negatives there is nothing to misalign.

        if misaligned_videos is not None:
            misaligned_videos_sampled = torch.stack([
                sample_video_feature(video, max_length=max_length, random_sample=True)
                for video in misaligned_videos
            ])

            misaligned_loss = compute_misaligned_loss(
                model.rewind_transformer,
                misaligned_videos_sampled,
                misaligned_texts
            )

            total_loss = total_loss + misaligned_loss
            losses['misaligned_loss'] = misaligned_loss.item()

    # Backward pass
    total_loss.backward()

    # Gradient clipping
    if gradient_clip > 0:
        torch.nn.utils.clip_grad_norm_(model.rewind_transformer.parameters(), gradient_clip)

    # Optimizer step
    optimizer.step()

    losses['total_loss'] = total_loss.item()
    return losses
def _as_float_tensor(values) -> torch.Tensor:
    """Return a float32 tensor copy of ``values`` (tensor, array or nested list)."""
    if isinstance(values, torch.Tensor):
        # torch.tensor(tensor) emits a copy-construct warning; clone explicitly.
        return values.detach().clone().to(torch.float32)
    return torch.tensor(values, dtype=torch.float32)


def create_training_batch(
    model: "ReWiNDRewardModel",
    videos: np.ndarray,
    texts: List[str],
    batch_size: int = 32,
    encode_on_the_fly: bool = True
) -> Dict[str, torch.Tensor]:
    """
    Create a training batch from raw videos and texts.

    Args:
        model: ReWiNDRewardModel instance; only used when ``encode_on_the_fly`` is True.
        videos: Raw video frames (batch_size, num_frames, H, W, C), or pre-computed
            video features when ``encode_on_the_fly`` is False.
        texts: List of text descriptions, or pre-computed text features when
            ``encode_on_the_fly`` is False.
        batch_size: Currently unused; the encoders chunk their inputs using the
            model's own config. Kept for backward compatibility.
        encode_on_the_fly: If True, encode videos and texts. If False, assume pre-encoded.

    Returns:
        Dictionary with float32 tensors under 'video_features' and 'text_features'.
    """
    if encode_on_the_fly:
        # Encode videos using DINO and texts using MiniLM
        videos = model.encode_images(videos)
        texts = model.encode_text(texts)

    return {
        'video_features': _as_float_tensor(videos),
        'text_features': _as_float_tensor(texts)
    }