exactly as rewind code

This commit is contained in:
Pepijn
2025-08-28 21:18:41 +02:00
parent cc05067a76
commit 7dce022a05
3 changed files with 291 additions and 328 deletions

View File

@@ -66,8 +66,10 @@ class RLearNConfig(PreTrainedConfig):
# ReWiND-specific parameters
use_video_rewind: bool = True # Enable video rewinding augmentation
rewind_prob: float = 0.5 # Probability of applying rewind to each batch
rewind_prob: float = 0.8 # Probability of applying rewind to each sample (paper: ~80%)
rewind_last3_prob: float = 0.1 # Of the rewinds, 10% only rewind the last 3 frames
use_mismatch_loss: bool = True # Enable mismatched language-video loss
mismatch_prob: float = 0.2 # Probability to include a mismatched video-language forward pass (paper: ~20%)
# Loss hyperparameters (simplified for ReWiND)
# The main loss is just MSE between predicted and target progress
@@ -80,6 +82,18 @@ class RLearNConfig(PreTrainedConfig):
}
)
# Architectural knobs to better mirror ReWiND
num_register_tokens: int = 4
mlp_predictor_depth: int = 3 # depth of the per-frame MLP head
# HLGauss loss parameters
use_hl_gauss_loss: bool = True
reward_min_value: float = 0.0
reward_max_value: float = 1.0
reward_hl_gauss_loss_num_bins: int = 20
categorical_rewards: bool = False
reward_bins: int = 10 # only used if categorical_rewards=True
def validate_features(self) -> None:
# Require at least one image feature. Language is recommended but optional (can be blank).
if not self.image_features:

View File

@@ -76,10 +76,24 @@ Notes
from __future__ import annotations
import math
from itertools import chain
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
# ReWiND dependencies
try:
from x_transformers import Decoder
from hl_gauss_pytorch import HLGaussLayer
import einx
from einops import rearrange, repeat, pack, unpack
except ImportError as e:
raise ImportError(
"ReWiND dependencies not installed. Please install: "
"pip install x-transformers hl-gauss-pytorch einx einops"
) from e
from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
from lerobot.policies.pretrained import PreTrainedPolicy
@@ -87,12 +101,12 @@ from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
class RLearNPolicy(PreTrainedPolicy):
"""Video-language conditioned reward model.
"""Video-language conditioned reward model following ReWiND architecture exactly: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11.
- Visual encoder: frozen DINOv2 (base), returns per-frame embeddings.
- Text encoder: frozen sentence-transformers (all-MiniLM-L12-v2), returns a language embedding.
- Temporal module: causal transformer over time that cross-attends to language embedding.
- Output: per-timestep reward logits; trainable small head.
- Temporal module: x_transformers Decoder with packed tokens [lang | register | video].
- Output: per-timestep rewards via HLGauss layer or categorical bins.
"""
config_class = RLearNConfig
@@ -102,6 +116,7 @@ class RLearNPolicy(PreTrainedPolicy):
super().__init__(config)
self.config = config
self.episode_data_index = episode_data_index # Store episode boundaries for progress calculation
self.categorical_rewards = config.categorical_rewards
# Encoders - ReWiND paper setup: DINOv2 for vision, sentence-transformers for text
from transformers import AutoImageProcessor, AutoModel
@@ -124,43 +139,48 @@ class RLearNPolicy(PreTrainedPolicy):
for p in self.text_encoder.parameters():
p.requires_grad = False
# x_transformers Decoder (matching ReWiND exactly)
self.decoder = Decoder(
dim=config.dim_model,
depth=config.n_layers,
heads=config.n_heads,
attn_dim_head=64, # ReWiND default
ff_mult=config.dim_feedforward // config.dim_model, # Convert to multiplier
# Note: x_transformers uses attn_dropout and ff_dropout separately
attn_dropout=config.dropout,
ff_dropout=config.dropout,
)
# Linear projections to the shared temporal model dimension
self.visual_proj = nn.Linear(self.vision_hidden, config.dim_model)
self.text_proj = nn.Linear(self.text_hidden, config.dim_model)
self.to_lang_tokens = nn.Linear(self.text_hidden, config.dim_model)
self.to_video_tokens = nn.Linear(self.vision_hidden, config.dim_model)
# Positional encodings over time
self.register_buffer(
"positional_encoding",
create_sinusoidal_pos_encoding(config.max_seq_len, config.dim_model),
persistent=False,
# Only first frame gets a positional embed (no cheating on progress)
self.first_pos_emb = nn.Parameter(torch.randn(config.dim_model) * 1e-2)
# Register / memory / attention sink tokens
self.num_register_tokens = config.num_register_tokens
self.register_tokens = nn.Parameter(torch.randn(config.num_register_tokens, config.dim_model) * 1e-2)
# MLP predictor (matching ReWiND's Feedforwards)
from x_mlps_pytorch import Feedforwards
self.mlp_predictor = Feedforwards(
dim=config.dim_model,
dim_out=config.reward_bins if config.categorical_rewards else None,
depth=config.mlp_predictor_depth
)
# Optional first-frame learned bias to discourage position cheating
self.first_frame_bias = (
nn.Parameter(torch.zeros(1, 1, config.dim_model))
if config.use_first_frame_positional_bias
else None
# HLGauss layer or plain regression
self.hl_gauss_layer = HLGaussLayer(
dim=config.dim_model,
use_regression=not config.use_hl_gauss_loss,
hl_gauss_loss=dict(
min_value=config.reward_min_value,
max_value=config.reward_max_value,
num_bins=config.reward_hl_gauss_loss_num_bins,
) if config.use_hl_gauss_loss else None
)
# Temporal aggregator: causal transformer over time with language cross-attention
self.temporal = TemporalCausalTransformer(
dim_model=config.dim_model,
n_heads=config.n_heads,
n_layers=config.n_layers,
dim_feedforward=config.dim_feedforward,
dropout=config.dropout,
pre_norm=config.pre_norm,
)
# Reward head with proper initialization
head_linear = nn.Linear(config.dim_model, 1)
# Initialize with small weights and bias to output values around 0
nn.init.normal_(head_linear.weight, mean=0.0, std=0.02)
nn.init.constant_(head_linear.bias, 0.0) # Start with 0 bias, sigmoid(0) = 0.5
head_layers: list[nn.Module] = [head_linear]
if config.use_tanh_head:
head_layers.append(nn.Tanh())
self.head = nn.Sequential(*head_layers)
# Simple frame dropout probability
self.frame_dropout_p = config.frame_dropout_p
self.stride = max(1, config.stride)
@@ -182,7 +202,7 @@ class RLearNPolicy(PreTrainedPolicy):
@torch.no_grad()
def predict_rewards(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict per-timestep rewards for evaluation.
"""Predict per-timestep rewards for evaluation using ReWiND architecture.
Args:
batch: Input batch with OBS_IMAGES and optionally OBS_LANGUAGE
@@ -190,83 +210,74 @@ class RLearNPolicy(PreTrainedPolicy):
Returns:
Predicted rewards tensor of shape (B, T)
"""
batch = self.normalize_inputs(batch)
# Extract frames and form (B, T, C, H, W), padding if needed
# Extract frames and form (B, T, C, H, W)
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
B, T, C, H, W = frames.shape
# Apply stride (no dropout during eval)
idx = torch.arange(0, T, self.stride, device=frames.device)
frames = frames[:, idx]
B, T_eff, C, H, W = frames.shape # NEW: effective length after stride
T_eff = frames.shape[1]
# Encode language using sentence-transformers
lang_emb = encode_language(
batch.get(OBS_LANGUAGE, None), self.text_encoder, batch_size=B
)
# Ensure embeddings are normal tensors on the correct device (not inference tensors)
lang_emb = lang_emb.detach().clone().to(self.text_proj.weight.device)
lang_emb = self.text_proj(lang_emb) # (B, D)
# Get language commands
commands = batch.get(OBS_LANGUAGE, None)
if commands is None:
commands = [""] * B
elif not isinstance(commands, list):
commands = [str(commands)] * B
# Process frames with DINOv2
# Flatten (B, T_eff, C, H, W) -> (BT, C, H, W)
BT = B * T_eff
flat = frames.reshape(BT, C, H, W)
# Convert to list of PIL images or numpy arrays for the processor
# DINOv2 processor expects images in HWC format
images_list = []
for i in range(BT):
img = flat[i] # (C, H, W)
# Convert to HWC format
img = img.permute(1, 2, 0) # (H, W, C)
# Convert to numpy if needed
if img.dtype == torch.uint8:
img = img.cpu().numpy()
else:
# Convert to uint8 range
img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
images_list.append(img)
# Forward through ReWiND model (inference mode)
device = next(self.parameters()).device
frames = frames.to(device)
# Process with DINOv2 processor
processed = self.vision_processor(images=images_list, return_tensors="pt")
pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device)
# Encode frames through DINOv2
vision_outputs = self.vision_encoder(pixel_values)
# Extract CLS tokens for temporal modeling
# DINOv2 outputs last_hidden_state of shape (batch_size, sequence_length, hidden_size)
# The CLS token is the first token
if hasattr(vision_outputs, "last_hidden_state"):
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D_vision)
else:
raise RuntimeError("Vision encoder must output last_hidden_state")
# Project CLS tokens for temporal sequence
visual_seq = self.visual_proj(cls_tokens).reshape(B, T_eff, self.config.dim_model) # (B, T', D)
# Add temporal positional encodings and optional first-frame bias
pe = (
self.positional_encoding[: visual_seq.shape[1]]
.unsqueeze(0)
.to(visual_seq.dtype)
.to(visual_seq.device)
# Process video frames
video_embeds = self._encode_video_frames(frames) # (B, T, D_vision)
# Language embeddings
lang_embeds = self.text_encoder.encode(
commands,
output_value='token_embeddings',
convert_to_tensor=True,
device=device
)
visual_seq = visual_seq + pe
if self.first_frame_bias is not None:
visual_seq = visual_seq.clone()
visual_seq[:, :1] = visual_seq[:, :1] + self.first_frame_bias
# Temporal model with cross-attention to language
temporal_features = self.temporal(visual_seq, lang_emb, return_features=True) # (B, T', D)
values = self.head(temporal_features).squeeze(-1) # (B, T')
return values
lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device)
lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device)
mask = self._mask_from_lens(lens)
# Register tokens
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b=B)
# Project embeddings
lang_tokens = self.to_lang_tokens(lang_embeds)
video_tokens = self.to_video_tokens(video_embeds)
# Add first frame positional embedding
first_video_token, rest_video_tokens = video_tokens[:, :1], video_tokens[:, 1:]
first_video_token = first_video_token + repeat(self.first_pos_emb, 'd -> b 1 d', b=B)
video_tokens = torch.cat((first_video_token, rest_video_tokens), dim=1)
# Pack all tokens for attention
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
# Extend mask for register and video tokens
mask = F.pad(mask, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
# Forward through decoder
attended = self.decoder(tokens, mask=mask)
# Unpack and get video token features
_, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d')
# MLP predictor
video_frame_embeds = self.mlp_predictor(attended_video_tokens)
# Get rewards via HLGauss layer
if self.categorical_rewards:
return video_frame_embeds # Return logits directly
else:
return self.hl_gauss_layer(video_frame_embeds).squeeze(-1) # (B, T)
def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# Initial version: no-op; rely on upstream processors if any
@@ -276,6 +287,41 @@ class RLearNPolicy(PreTrainedPolicy):
# Initial version: no-op
return batch
def _encode_video_frames(self, frames: Tensor) -> Tensor:
    """Embed every frame of a batch of clips with the vision encoder.

    Frames are flattened over batch and time, converted to uint8 HWC numpy
    images, passed through ``self.vision_processor`` and
    ``self.vision_encoder``, and the first token of ``last_hidden_state``
    (the CLS token) is kept as the per-frame embedding.

    Args:
        frames: (B, T, C, H, W). Frames that are not already uint8 are
            clamped to [0, 1] and scaled by 255 before conversion.

    Returns:
        (B, T, D_vision)
    """
    B, T, C, H, W = frames.shape
    flat_frames = rearrange(frames, 'b t c h w -> (b t) c h w')

    def as_uint8_image(frame):
        # The processor is fed channel-last images: CHW -> HWC.
        hwc = frame.permute(1, 2, 0)
        if hwc.dtype != torch.uint8:
            hwc = (hwc.clamp(0, 1) * 255).to(torch.uint8)
        return hwc.cpu().numpy()

    images_list = [as_uint8_image(flat_frames[i]) for i in range(B * T)]

    processed = self.vision_processor(images=images_list, return_tensors="pt")
    # Move the processed pixels to wherever the encoder weights live.
    pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device)

    # First token of the last hidden state is the CLS token: (BT, D_vision).
    cls_tokens = self.vision_encoder(pixel_values).last_hidden_state[:, 0]
    return rearrange(cls_tokens, '(b t) d -> b t d', b=B, t=T)
def _mask_from_lens(self, lens: Tensor) -> Tensor:
    """Build a boolean validity mask from per-sample sequence lengths.

    Args:
        lens: (B,) lengths.

    Returns:
        (B, N) mask with N = max(lens); entry [b, n] is True where
        n < lens[b].
    """
    longest = lens.amax().item()
    positions = torch.arange(longest, device=lens.device)
    # Compare every position index against every sample length.
    return einx.less('n, b -> b n', positions, lens)
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Compute ReWiND training loss with on-the-fly progress label generation.
@@ -289,18 +335,22 @@ class RLearNPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
# Extract frames and form (B, T, C, H, W), padding if needed
# Extract frames and form (B, T, C, H, W)
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
B, T, C, H, W = frames.shape
device = next(self.parameters()).device
frames = frames.to(device)
# Apply video rewinding augmentation during training
augmented_target = None
if self.training and self.config.use_video_rewind:
frames, augmented_target = apply_video_rewind(frames, rewind_prob=self.config.rewind_prob)
# Use augmented progress labels if rewinding was applied
if REWARD in batch:
target = augmented_target
frames, augmented_target = apply_video_rewind(
frames,
rewind_prob=self.config.rewind_prob,
last3_prob=getattr(self.config, "rewind_last3_prob", None),
)
# Apply stride and frame dropout during training
# Apply stride and frame dropout
idx = torch.arange(0, T, self.stride, device=frames.device)
if self.training and self.frame_dropout_p > 0.0 and T > 1:
mask = torch.rand_like(idx.float()) > self.frame_dropout_p
@@ -308,69 +358,55 @@ class RLearNPolicy(PreTrainedPolicy):
if idx.numel() == 0:
idx = torch.tensor([0], device=frames.device)
frames = frames[:, idx]
T_eff = frames.shape[1]
# Encode language using sentence-transformers
lang_emb = encode_language(
batch.get(OBS_LANGUAGE, None), self.text_encoder, batch_size=B
)
# Ensure embeddings are normal tensors on the correct device (not inference tensors)
lang_emb = lang_emb.detach().clone().to(self.text_proj.weight.device)
lang_emb = self.text_proj(lang_emb) # (B, D)
# Get language commands
commands = batch.get(OBS_LANGUAGE, None)
if commands is None:
commands = [""] * B
elif not isinstance(commands, list):
commands = [str(commands)] * B
# Encode frames through DINOv2 visual encoder
# Flatten time for batched encode
BT = B * frames.shape[1]
flat = frames.reshape(BT, C, H, W)
# Convert to list of PIL images or numpy arrays for the processor
# DINOv2 processor expects images in HWC format
images_list = []
for i in range(BT):
img = flat[i] # (C, H, W)
# Convert to HWC format
img = img.permute(1, 2, 0) # (H, W, C)
# Convert to numpy if needed
if img.dtype == torch.uint8:
img = img.cpu().numpy()
else:
# Convert to uint8 range
img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
images_list.append(img)
# Process video frames through DINOv2
video_embeds = self._encode_video_frames(frames) # (B, T_eff, D_vision)
# Process with DINOv2 processor
processed = self.vision_processor(images=images_list, return_tensors="pt")
pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device)
# Encode through DINOv2 model
vision_outputs = self.vision_encoder(pixel_values)
# Extract CLS token for temporal modeling
# DINOv2 outputs last_hidden_state of shape (batch_size, sequence_length, hidden_size)
if hasattr(vision_outputs, "last_hidden_state"):
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D) - CLS token
else:
raise RuntimeError("Vision encoder must output last_hidden_state")
# Project CLS tokens for temporal sequence
visual_seq = self.visual_proj(cls_tokens).reshape(B, -1, self.config.dim_model) # (B, T', D)
# Add temporal positional encodings and optional first-frame bias
pe = (
self.positional_encoding[: visual_seq.shape[1]]
.unsqueeze(0)
.to(visual_seq.dtype)
.to(visual_seq.device)
# Language embeddings
lang_embeds = self.text_encoder.encode(
commands,
output_value='token_embeddings',
convert_to_tensor=True,
device=device
)
visual_seq = visual_seq + pe
if self.first_frame_bias is not None:
visual_seq = visual_seq.clone()
visual_seq[:, :1] = visual_seq[:, :1] + self.first_frame_bias
# Temporal model with cross-attention to language
temporal_features = self.temporal(visual_seq, lang_emb, return_features=True) # (B, T', D)
values = self.head(temporal_features).squeeze(-1) # (B, T')
lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device)
lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device)
mask = self._mask_from_lens(lens)
# Register tokens
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b=B)
# Project embeddings
lang_tokens = self.to_lang_tokens(lang_embeds)
video_tokens = self.to_video_tokens(video_embeds)
# Add first frame positional embedding
first_video_token, rest_video_tokens = video_tokens[:, :1], video_tokens[:, 1:]
first_video_token = first_video_token + repeat(self.first_pos_emb, 'd -> b 1 d', b=B)
video_tokens = torch.cat((first_video_token, rest_video_tokens), dim=1)
# Pack all tokens for attention [lang | register | video]
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
# Extend mask for register and video tokens
mask = F.pad(mask, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
# Forward through x_transformers Decoder
attended = self.decoder(tokens, mask=mask)
# Unpack and get video token features
_, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d')
# MLP predictor
video_frame_embeds = self.mlp_predictor(attended_video_tokens)
# Generate progress labels on-the-fly (ReWiND approach)
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
@@ -451,149 +487,78 @@ class RLearNPolicy(PreTrainedPolicy):
# During inference, we might not want to compute loss
if not self.training and target is None:
loss = values.mean() * 0.0
loss_dict["has_labels"] = 0.0
return loss, {**loss_dict, "values_mean": values.mean().item()}
# Return predictions without loss
if self.categorical_rewards:
return video_frame_embeds.mean() * 0.0, {"has_labels": 0.0}
else:
rewards = self.hl_gauss_layer(video_frame_embeds)
return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}
# ReWiND Loss (following the paper exactly)
# The core loss is progress regression with video rewinding augmentation
# Calculate loss using HLGauss or categorical
if self.categorical_rewards:
# Categorical cross-entropy loss
assert target.dtype in (torch.long, torch.int), "Categorical rewards require integer targets"
loss = F.cross_entropy(
rearrange(video_frame_embeds, 'b t l -> b l t'),
target.long(),
ignore_index=-1
)
else:
# HLGauss loss or MSE regression
assert target.dtype == torch.float, "Continuous rewards require float targets"
# Create video mask for variable length support
video_mask = torch.ones(B, T_eff, dtype=torch.bool, device=device)
loss = self.hl_gauss_layer(video_frame_embeds, target[:, :T_eff], mask=video_mask)
# 1) Main progress regression loss for matched sequences
# Target should be normalized progress from 0 to 1 (t/T)
L_progress = F.mse_loss(values, target)
# Optional: Mismatched video-language pairs loss
L_mismatch = torch.zeros((), device=device)
if self.training and self.config.use_mismatch_loss and B > 1:
if torch.rand(1, device=device).item() < getattr(self.config, "mismatch_prob", 0.2):
# Shuffle language within batch
shuffled_indices = torch.randperm(B, device=device)
shuffled_commands = [commands[i] for i in shuffled_indices]
# Re-encode with mismatched language
lang_embeds_mm = self.text_encoder.encode(
shuffled_commands,
output_value='token_embeddings',
convert_to_tensor=True,
device=device
)
lang_embeds_mm = pad_sequence(lang_embeds_mm, batch_first=True).to(device)
lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
# Pack and forward
tokens_mm, _ = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d')
attended_mm = self.decoder(tokens_mm, mask=mask)
_, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape, 'b * d')
mismatch_embeds = self.mlp_predictor(attended_video_mm)
# Mismatched pairs should predict zero progress
zeros_target = torch.zeros_like(target[:, :T_eff])
if self.categorical_rewards:
L_mismatch = F.cross_entropy(
rearrange(mismatch_embeds, 'b t l -> b l t'),
zeros_target.long(),
ignore_index=-1
)
else:
L_mismatch = self.hl_gauss_layer(mismatch_embeds, zeros_target, mask=video_mask)
# 2) Mismatched video-language pairs should predict zero progress
L_mismatch = torch.zeros((), device=values.device)
if self.training and self.config.use_mismatch_loss and values.size(0) > 1:
# Randomly shuffle language instructions within the batch
shuffled_indices = torch.randperm(B, device=values.device)
lang_mismatch = lang_emb[shuffled_indices]
# Forward pass with mismatched language
mismatch_feat = self.temporal(visual_seq, lang_mismatch, return_features=True)
mismatch_values = self.head(mismatch_feat).squeeze(-1)
# Mismatched pairs should predict zero progress
L_mismatch = F.mse_loss(mismatch_values, torch.zeros_like(target))
# Total loss is just progress regression (rewinding is handled via data augmentation)
loss = L_progress + L_mismatch
# Total loss
total_loss = loss + L_mismatch
# Log individual loss components
loss_dict.update(
{
"loss_progress": L_progress.item(),
"loss_mismatch": L_mismatch.item(),
}
)
loss_dict.update({
"loss": total_loss.item(),
"loss_main": loss.item(),
"loss_mismatch": L_mismatch.item(),
})
loss_dict["loss"] = loss.item()
loss_dict["values_mean"] = values.mean().item()
return loss, loss_dict
return total_loss, loss_dict
class TemporalCausalTransformer(nn.Module):
    """Stack of causal temporal layers that cross-attend to one language token.

    The input sequence is processed time-major through ``n_layers``
    ``TemporalCausalTransformerLayer`` blocks under a causal mask, then
    layer-normalised. A linear head maps each timestep to a scalar unless the
    caller asks for the features.
    """

    def __init__(
        self,
        dim_model: int,
        n_heads: int,
        n_layers: int,
        dim_feedforward: int,
        dropout: float,
        pre_norm: bool,
    ):
        super().__init__()
        blocks = []
        for _ in range(n_layers):
            blocks.append(
                TemporalCausalTransformerLayer(dim_model, n_heads, dim_feedforward, dropout, pre_norm)
            )
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(dim_model)
        self.head = nn.Linear(dim_model, 1)

    def forward(self, x: Tensor, lang_emb: Tensor, return_features: bool = False) -> Tensor:
        """Run the temporal stack.

        Args:
            x: (B, T, D) per-timestep tokens.
            lang_emb: (B, D) language embedding, used as a single
                cross-attention token.
            return_features: when True return the normalised (B, T, D)
                features instead of the head output.

        Returns:
            (B, T, D) features, or (B, T, 1) head outputs.
        """
        _, seq_len, _ = x.shape

        # The layers use sequence-first attention, so go time-major.
        hidden = x.transpose(0, 1)  # (T, B, D)
        lang_token = lang_emb.unsqueeze(1).transpose(0, 1)  # (1, B, D)
        causal_mask = generate_causal_mask(seq_len, device=hidden.device)

        for block in self.layers:
            hidden = block(hidden, lang_token, causal_mask)

        features = self.norm(hidden).transpose(0, 1)  # (B, T, D)
        return features if return_features else self.head(features)
class TemporalCausalTransformerLayer(nn.Module):
def __init__(self, dim_model: int, n_heads: int, dim_feedforward: int, dropout: float, pre_norm: bool):
super().__init__()
self.self_attn = nn.MultiheadAttention(dim_model, n_heads, dropout=dropout, batch_first=False)
self.cross_attn = nn.MultiheadAttention(dim_model, n_heads, dropout=dropout, batch_first=False)
self.linear1 = nn.Linear(dim_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, dim_model)
self.dropout = nn.Dropout(dropout)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(dim_model)
self.norm2 = nn.LayerNorm(dim_model)
self.norm3 = nn.LayerNorm(dim_model)
self.activation = F.gelu
self.pre_norm = pre_norm
def forward(self, x: Tensor, lang_token: Tensor, causal_mask: Tensor) -> Tensor:
# Self-attention with causal mask
residual = x
if self.pre_norm:
x = self.norm1(x)
x = self.self_attn(x, x, x, attn_mask=causal_mask)[0]
x = residual + self.dropout1(x)
if not self.pre_norm:
x = self.norm1(x)
# Cross-attention to language token (keys/values from language, queries are time tokens)
residual = x
if self.pre_norm:
x = self.norm2(x)
# Broadcast language token across time
T = x.shape[0]
lang_kv = lang_token.expand(1, x.shape[1], x.shape[2]) # (1, B, D)
x = self.cross_attn(x, lang_kv, lang_kv)[0]
x = residual + self.dropout2(x)
if not self.pre_norm:
x = self.norm2(x)
# Feed-forward
residual = x
if self.pre_norm:
x = self.norm3(x)
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
x = residual + self.dropout3(x)
if not self.pre_norm:
x = self.norm3(x)
return x
def create_sinusoidal_pos_encoding(max_len: int, dim: int) -> Tensor:
    """Build a fixed sinusoidal positional-encoding table.

    Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
    with ``w_k = exp(-2k * ln(10000) / dim)``.

    Args:
        max_len: number of positions (rows).
        dim: encoding width (columns); may be even or odd.

    Returns:
        (max_len, dim) float32 tensor.
    """
    position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (L, 1)
    div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))  # (ceil(D/2),)
    angles = position * div_term  # (L, ceil(D/2))

    pe = torch.zeros(max_len, dim)
    pe[:, 0::2] = torch.sin(angles)
    # For odd `dim` there is one fewer cosine column than sine columns, so
    # trim the frequencies; without this the assignment has a shape mismatch.
    pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
    return pe  # (L, D)
def generate_causal_mask(T: int, device=None) -> Tensor:
    """Return a (T, T) additive float mask for causal attention.

    Entry [i, j] is ``-inf`` when j > i (a future position) and 0 otherwise,
    i.e. the float ``attn_mask`` form that is added to attention scores.
    """
    steps = torch.arange(T, device=device)
    is_future = steps.unsqueeze(0) > steps.unsqueeze(1)  # True strictly above the diagonal
    return torch.zeros(T, T, device=device).masked_fill(is_future, float("-inf"))
# Helper functions for ReWiND architecture
def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None) -> Tensor:
@@ -669,28 +634,10 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None
return frames
def encode_language(
    language_input: Tensor | list | str | None, text_encoder, batch_size: int
) -> Tensor:
    """Encode language instructions with the text encoder's ``encode`` method.

    Args:
        language_input: a list or tuple of strings (one per sample), a single
            value that is stringified and repeated for the whole batch, or
            None for blank instructions.
        text_encoder: object exposing ``encode(texts, convert_to_tensor=...,
            device=...)``; if it has ``parameters()``, the device of its first
            parameter is used.
        batch_size: number of samples, used when the input is not a sequence.

    Returns:
        Whatever ``text_encoder.encode`` returns for the texts.
    """
    if language_input is None:
        texts = [""] * batch_size
    elif isinstance(language_input, (list, tuple)):
        # Tuples of strings are per-sample instructions too; stringifying the
        # whole tuple would feed its repr to the encoder instead.
        texts = list(language_input)
    else:
        # Single string for the batch
        texts = [str(language_input)] * batch_size

    # Encode on the encoder's own device. Fall back to CPU when the encoder
    # exposes no parameters (or none at all) instead of raising StopIteration.
    device = 'cpu'
    if hasattr(text_encoder, 'parameters'):
        first_param = next(iter(text_encoder.parameters()), None)
        if first_param is not None:
            device = first_param.device

    return text_encoder.encode(texts, convert_to_tensor=True, device=device)
def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5) -> tuple[Tensor, Tensor]:
def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None) -> tuple[Tensor, Tensor]:
"""Apply video rewinding augmentation as described in ReWiND paper.
Each video in the batch has an independent chance of being rewound.
@@ -726,8 +673,11 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5) -> tuple[Tensor
# Split point i: between frame 2 and T-1
i = torch.randint(2, T, (1,)).item()
# Rewind length k: between 1 and i-1 frames
k = torch.randint(1, min(i, T - i + 1), (1,)).item()
# Rewind length k: between 1 and i-1 frames, with option to force last-3 frames occasionally
if last3_prob is not None and torch.rand(1).item() < last3_prob and i >= 3:
k = min(3, i - 1)
else:
k = torch.randint(1, min(i, T - i + 1), (1,)).item()
# Create rewound sequence: o1...oi, oi-1, ..., oi-k
forward_frames = frames[b, :i] # Frames up to split point
@@ -761,4 +711,4 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5) -> tuple[Tensor
augmented_frames.append(rewound_seq)
augmented_progress.append(rewound_progress)
return torch.stack(augmented_frames), torch.stack(augmented_progress)
return torch.stack(augmented_frames), torch.stack(augmented_progress)

View File

@@ -75,10 +75,9 @@ _ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
- Implement on-the-fly progress label generation (no need for pre-annotated rewards) [x]
- Try different losses
- Only rewind loss [x]
- Exactly similar to: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11 [x]
- Try DINO v2 as encoder Base 86 M: with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 [x]
- Check that the code matches the ReWiND repo code (architecture and training details) []
- Test only rewind loss (evaluate) []
- Check rewind implementation by hand/cleanup []
- Only vlc loss then eval []
- Vlc + Rewind loss then eval []
- Cleanup code []