mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
exactly as rewind code
This commit is contained in:
@@ -66,8 +66,10 @@ class RLearNConfig(PreTrainedConfig):
|
||||
|
||||
# ReWiND-specific parameters
|
||||
use_video_rewind: bool = True # Enable video rewinding augmentation
|
||||
rewind_prob: float = 0.5 # Probability of applying rewind to each batch
|
||||
rewind_prob: float = 0.8 # Probability of applying rewind to each sample (paper: ~80%)
|
||||
rewind_last3_prob: float = 0.1 # Of the rewinds, 10% only rewind the last 3 frames
|
||||
use_mismatch_loss: bool = True # Enable mismatched language-video loss
|
||||
mismatch_prob: float = 0.2 # Probability to include a mismatched video-language forward pass (paper: ~20%)
|
||||
|
||||
# Loss hyperparameters (simplified for ReWiND)
|
||||
# The main loss is just MSE between predicted and target progress
|
||||
@@ -80,6 +82,18 @@ class RLearNConfig(PreTrainedConfig):
|
||||
}
|
||||
)
|
||||
|
||||
# Architectural knobs to better mirror ReWiND
|
||||
num_register_tokens: int = 4
|
||||
mlp_predictor_depth: int = 3 # depth of the per-frame MLP head
|
||||
|
||||
# HLGauss loss parameters
|
||||
use_hl_gauss_loss: bool = True
|
||||
reward_min_value: float = 0.0
|
||||
reward_max_value: float = 1.0
|
||||
reward_hl_gauss_loss_num_bins: int = 20
|
||||
categorical_rewards: bool = False
|
||||
reward_bins: int = 10 # only used if categorical_rewards=True
|
||||
|
||||
def validate_features(self) -> None:
|
||||
# Require at least one image feature. Language is recommended but optional (can be blank).
|
||||
if not self.image_features:
|
||||
|
||||
@@ -76,10 +76,24 @@ Notes
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from itertools import chain
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
# ReWiND dependencies
|
||||
try:
|
||||
from x_transformers import Decoder
|
||||
from hl_gauss_pytorch import HLGaussLayer
|
||||
import einx
|
||||
from einops import rearrange, repeat, pack, unpack
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"ReWiND dependencies not installed. Please install: "
|
||||
"pip install x-transformers hl-gauss-pytorch einx einops"
|
||||
) from e
|
||||
|
||||
from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
@@ -87,12 +101,12 @@ from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
|
||||
|
||||
|
||||
class RLearNPolicy(PreTrainedPolicy):
|
||||
"""Video-language conditioned reward model.
|
||||
"""Video-language conditioned reward model following ReWiND architecture exactly: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11.
|
||||
|
||||
- Visual encoder: frozen DINOv2 (base), returns per-frame embeddings.
|
||||
- Text encoder: frozen sentence-transformers (all-MiniLM-L12-v2), returns a language embedding.
|
||||
- Temporal module: causal transformer over time that cross-attends to language embedding.
|
||||
- Output: per-timestep reward logits; trainable small head.
|
||||
- Temporal module: x_transformers Decoder with packed tokens [lang | register | video].
|
||||
- Output: per-timestep rewards via HLGauss layer or categorical bins.
|
||||
"""
|
||||
|
||||
config_class = RLearNConfig
|
||||
@@ -102,6 +116,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.episode_data_index = episode_data_index # Store episode boundaries for progress calculation
|
||||
self.categorical_rewards = config.categorical_rewards
|
||||
|
||||
# Encoders - ReWiND paper setup: DINOv2 for vision, sentence-transformers for text
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
@@ -124,43 +139,48 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
for p in self.text_encoder.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
# x_transformers Decoder (matching ReWiND exactly)
|
||||
self.decoder = Decoder(
|
||||
dim=config.dim_model,
|
||||
depth=config.n_layers,
|
||||
heads=config.n_heads,
|
||||
attn_dim_head=64, # ReWiND default
|
||||
ff_mult=config.dim_feedforward // config.dim_model, # Convert to multiplier
|
||||
# Note: x_transformers uses attn_dropout and ff_dropout separately
|
||||
attn_dropout=config.dropout,
|
||||
ff_dropout=config.dropout,
|
||||
)
|
||||
|
||||
# Linear projections to the shared temporal model dimension
|
||||
self.visual_proj = nn.Linear(self.vision_hidden, config.dim_model)
|
||||
self.text_proj = nn.Linear(self.text_hidden, config.dim_model)
|
||||
self.to_lang_tokens = nn.Linear(self.text_hidden, config.dim_model)
|
||||
self.to_video_tokens = nn.Linear(self.vision_hidden, config.dim_model)
|
||||
|
||||
# Positional encodings over time
|
||||
self.register_buffer(
|
||||
"positional_encoding",
|
||||
create_sinusoidal_pos_encoding(config.max_seq_len, config.dim_model),
|
||||
persistent=False,
|
||||
# Only first frame gets a positional embed (no cheating on progress)
|
||||
self.first_pos_emb = nn.Parameter(torch.randn(config.dim_model) * 1e-2)
|
||||
|
||||
# Register / memory / attention sink tokens
|
||||
self.num_register_tokens = config.num_register_tokens
|
||||
self.register_tokens = nn.Parameter(torch.randn(config.num_register_tokens, config.dim_model) * 1e-2)
|
||||
|
||||
# MLP predictor (matching ReWiND's Feedforwards)
|
||||
from x_mlps_pytorch import Feedforwards
|
||||
self.mlp_predictor = Feedforwards(
|
||||
dim=config.dim_model,
|
||||
dim_out=config.reward_bins if config.categorical_rewards else None,
|
||||
depth=config.mlp_predictor_depth
|
||||
)
|
||||
# Optional first-frame learned bias to discourage position cheating
|
||||
self.first_frame_bias = (
|
||||
nn.Parameter(torch.zeros(1, 1, config.dim_model))
|
||||
if config.use_first_frame_positional_bias
|
||||
else None
|
||||
|
||||
# HLGauss layer or plain regression
|
||||
self.hl_gauss_layer = HLGaussLayer(
|
||||
dim=config.dim_model,
|
||||
use_regression=not config.use_hl_gauss_loss,
|
||||
hl_gauss_loss=dict(
|
||||
min_value=config.reward_min_value,
|
||||
max_value=config.reward_max_value,
|
||||
num_bins=config.reward_hl_gauss_loss_num_bins,
|
||||
) if config.use_hl_gauss_loss else None
|
||||
)
|
||||
|
||||
# Temporal aggregator: causal transformer over time with language cross-attention
|
||||
self.temporal = TemporalCausalTransformer(
|
||||
dim_model=config.dim_model,
|
||||
n_heads=config.n_heads,
|
||||
n_layers=config.n_layers,
|
||||
dim_feedforward=config.dim_feedforward,
|
||||
dropout=config.dropout,
|
||||
pre_norm=config.pre_norm,
|
||||
)
|
||||
|
||||
# Reward head with proper initialization
|
||||
head_linear = nn.Linear(config.dim_model, 1)
|
||||
# Initialize with small weights and bias to output values around 0
|
||||
nn.init.normal_(head_linear.weight, mean=0.0, std=0.02)
|
||||
nn.init.constant_(head_linear.bias, 0.0) # Start with 0 bias, sigmoid(0) = 0.5
|
||||
|
||||
head_layers: list[nn.Module] = [head_linear]
|
||||
if config.use_tanh_head:
|
||||
head_layers.append(nn.Tanh())
|
||||
self.head = nn.Sequential(*head_layers)
|
||||
|
||||
# Simple frame dropout probability
|
||||
self.frame_dropout_p = config.frame_dropout_p
|
||||
self.stride = max(1, config.stride)
|
||||
@@ -182,7 +202,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_rewards(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict per-timestep rewards for evaluation.
|
||||
"""Predict per-timestep rewards for evaluation using ReWiND architecture.
|
||||
|
||||
Args:
|
||||
batch: Input batch with OBS_IMAGES and optionally OBS_LANGUAGE
|
||||
@@ -190,83 +210,74 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
Returns:
|
||||
Predicted rewards tensor of shape (B, T)
|
||||
"""
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Extract frames and form (B, T, C, H, W), padding if needed
|
||||
# Extract frames and form (B, T, C, H, W)
|
||||
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
|
||||
B, T, C, H, W = frames.shape
|
||||
|
||||
# Apply stride (no dropout during eval)
|
||||
idx = torch.arange(0, T, self.stride, device=frames.device)
|
||||
frames = frames[:, idx]
|
||||
B, T_eff, C, H, W = frames.shape # NEW: effective length after stride
|
||||
T_eff = frames.shape[1]
|
||||
|
||||
# Encode language using sentence-transformers
|
||||
lang_emb = encode_language(
|
||||
batch.get(OBS_LANGUAGE, None), self.text_encoder, batch_size=B
|
||||
)
|
||||
# Ensure embeddings are normal tensors on the correct device (not inference tensors)
|
||||
lang_emb = lang_emb.detach().clone().to(self.text_proj.weight.device)
|
||||
lang_emb = self.text_proj(lang_emb) # (B, D)
|
||||
# Get language commands
|
||||
commands = batch.get(OBS_LANGUAGE, None)
|
||||
if commands is None:
|
||||
commands = [""] * B
|
||||
elif not isinstance(commands, list):
|
||||
commands = [str(commands)] * B
|
||||
|
||||
# Process frames with DINOv2
|
||||
# Flatten (B, T_eff, C, H, W) -> (BT, C, H, W)
|
||||
BT = B * T_eff
|
||||
flat = frames.reshape(BT, C, H, W)
|
||||
|
||||
# Convert to list of PIL images or numpy arrays for the processor
|
||||
# DINOv2 processor expects images in HWC format
|
||||
images_list = []
|
||||
for i in range(BT):
|
||||
img = flat[i] # (C, H, W)
|
||||
# Convert to HWC format
|
||||
img = img.permute(1, 2, 0) # (H, W, C)
|
||||
|
||||
# Convert to numpy if needed
|
||||
if img.dtype == torch.uint8:
|
||||
img = img.cpu().numpy()
|
||||
else:
|
||||
# Convert to uint8 range
|
||||
img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
|
||||
|
||||
images_list.append(img)
|
||||
# Forward through ReWiND model (inference mode)
|
||||
device = next(self.parameters()).device
|
||||
frames = frames.to(device)
|
||||
|
||||
# Process with DINOv2 processor
|
||||
processed = self.vision_processor(images=images_list, return_tensors="pt")
|
||||
pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device)
|
||||
|
||||
# Encode frames through DINOv2
|
||||
vision_outputs = self.vision_encoder(pixel_values)
|
||||
|
||||
# Extract CLS tokens for temporal modeling
|
||||
# DINOv2 outputs last_hidden_state of shape (batch_size, sequence_length, hidden_size)
|
||||
# The CLS token is the first token
|
||||
if hasattr(vision_outputs, "last_hidden_state"):
|
||||
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D_vision)
|
||||
else:
|
||||
raise RuntimeError("Vision encoder must output last_hidden_state")
|
||||
|
||||
# Project CLS tokens for temporal sequence
|
||||
visual_seq = self.visual_proj(cls_tokens).reshape(B, T_eff, self.config.dim_model) # (B, T', D)
|
||||
|
||||
# Add temporal positional encodings and optional first-frame bias
|
||||
pe = (
|
||||
self.positional_encoding[: visual_seq.shape[1]]
|
||||
.unsqueeze(0)
|
||||
.to(visual_seq.dtype)
|
||||
.to(visual_seq.device)
|
||||
# Process video frames
|
||||
video_embeds = self._encode_video_frames(frames) # (B, T, D_vision)
|
||||
|
||||
# Language embeddings
|
||||
lang_embeds = self.text_encoder.encode(
|
||||
commands,
|
||||
output_value='token_embeddings',
|
||||
convert_to_tensor=True,
|
||||
device=device
|
||||
)
|
||||
visual_seq = visual_seq + pe
|
||||
if self.first_frame_bias is not None:
|
||||
visual_seq = visual_seq.clone()
|
||||
visual_seq[:, :1] = visual_seq[:, :1] + self.first_frame_bias
|
||||
|
||||
# Temporal model with cross-attention to language
|
||||
temporal_features = self.temporal(visual_seq, lang_emb, return_features=True) # (B, T', D)
|
||||
values = self.head(temporal_features).squeeze(-1) # (B, T')
|
||||
|
||||
return values
|
||||
lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device)
|
||||
lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device)
|
||||
mask = self._mask_from_lens(lens)
|
||||
|
||||
# Register tokens
|
||||
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b=B)
|
||||
|
||||
# Project embeddings
|
||||
lang_tokens = self.to_lang_tokens(lang_embeds)
|
||||
video_tokens = self.to_video_tokens(video_embeds)
|
||||
|
||||
# Add first frame positional embedding
|
||||
first_video_token, rest_video_tokens = video_tokens[:, :1], video_tokens[:, 1:]
|
||||
first_video_token = first_video_token + repeat(self.first_pos_emb, 'd -> b 1 d', b=B)
|
||||
video_tokens = torch.cat((first_video_token, rest_video_tokens), dim=1)
|
||||
|
||||
# Pack all tokens for attention
|
||||
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
|
||||
|
||||
# Extend mask for register and video tokens
|
||||
mask = F.pad(mask, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
|
||||
|
||||
# Forward through decoder
|
||||
attended = self.decoder(tokens, mask=mask)
|
||||
|
||||
# Unpack and get video token features
|
||||
_, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d')
|
||||
|
||||
# MLP predictor
|
||||
video_frame_embeds = self.mlp_predictor(attended_video_tokens)
|
||||
|
||||
# Get rewards via HLGauss layer
|
||||
if self.categorical_rewards:
|
||||
return video_frame_embeds # Return logits directly
|
||||
else:
|
||||
return self.hl_gauss_layer(video_frame_embeds).squeeze(-1) # (B, T)
|
||||
|
||||
def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# Initial version: no-op; rely on upstream processors if any
|
||||
@@ -276,6 +287,41 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# Initial version: no-op
|
||||
return batch
|
||||
|
||||
def _encode_video_frames(self, frames: Tensor) -> Tensor:
|
||||
"""Encode video frames through DINOv2 to get per-frame embeddings.
|
||||
|
||||
Args:
|
||||
frames: (B, T, C, H, W)
|
||||
|
||||
Returns:
|
||||
(B, T, D_vision)
|
||||
"""
|
||||
B, T, C, H, W = frames.shape
|
||||
flat = rearrange(frames, 'b t c h w -> (b t) c h w')
|
||||
|
||||
# Process with DINOv2
|
||||
images_list = []
|
||||
for i in range(B * T):
|
||||
img = flat[i].permute(1, 2, 0) # CHW -> HWC
|
||||
if img.dtype == torch.uint8:
|
||||
img = img.cpu().numpy()
|
||||
else:
|
||||
img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
|
||||
images_list.append(img)
|
||||
|
||||
processed = self.vision_processor(images=images_list, return_tensors="pt")
|
||||
pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device)
|
||||
vision_outputs = self.vision_encoder(pixel_values)
|
||||
|
||||
# Extract CLS tokens
|
||||
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D_vision)
|
||||
return rearrange(cls_tokens, '(b t) d -> b t d', b=B, t=T)
|
||||
|
||||
def _mask_from_lens(self, lens: Tensor) -> Tensor:
|
||||
"""Create mask from sequence lengths."""
|
||||
seq = torch.arange(lens.amax().item(), device=lens.device)
|
||||
return einx.less('n, b -> b n', seq, lens)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Compute ReWiND training loss with on-the-fly progress label generation.
|
||||
|
||||
@@ -289,18 +335,22 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Extract frames and form (B, T, C, H, W), padding if needed
|
||||
# Extract frames and form (B, T, C, H, W)
|
||||
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
|
||||
B, T, C, H, W = frames.shape
|
||||
device = next(self.parameters()).device
|
||||
frames = frames.to(device)
|
||||
|
||||
# Apply video rewinding augmentation during training
|
||||
augmented_target = None
|
||||
if self.training and self.config.use_video_rewind:
|
||||
frames, augmented_target = apply_video_rewind(frames, rewind_prob=self.config.rewind_prob)
|
||||
# Use augmented progress labels if rewinding was applied
|
||||
if REWARD in batch:
|
||||
target = augmented_target
|
||||
frames, augmented_target = apply_video_rewind(
|
||||
frames,
|
||||
rewind_prob=self.config.rewind_prob,
|
||||
last3_prob=getattr(self.config, "rewind_last3_prob", None),
|
||||
)
|
||||
|
||||
# Apply stride and frame dropout during training
|
||||
# Apply stride and frame dropout
|
||||
idx = torch.arange(0, T, self.stride, device=frames.device)
|
||||
if self.training and self.frame_dropout_p > 0.0 and T > 1:
|
||||
mask = torch.rand_like(idx.float()) > self.frame_dropout_p
|
||||
@@ -308,69 +358,55 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
if idx.numel() == 0:
|
||||
idx = torch.tensor([0], device=frames.device)
|
||||
frames = frames[:, idx]
|
||||
T_eff = frames.shape[1]
|
||||
|
||||
# Encode language using sentence-transformers
|
||||
lang_emb = encode_language(
|
||||
batch.get(OBS_LANGUAGE, None), self.text_encoder, batch_size=B
|
||||
)
|
||||
# Ensure embeddings are normal tensors on the correct device (not inference tensors)
|
||||
lang_emb = lang_emb.detach().clone().to(self.text_proj.weight.device)
|
||||
lang_emb = self.text_proj(lang_emb) # (B, D)
|
||||
# Get language commands
|
||||
commands = batch.get(OBS_LANGUAGE, None)
|
||||
if commands is None:
|
||||
commands = [""] * B
|
||||
elif not isinstance(commands, list):
|
||||
commands = [str(commands)] * B
|
||||
|
||||
# Encode frames through DINOv2 visual encoder
|
||||
# Flatten time for batched encode
|
||||
BT = B * frames.shape[1]
|
||||
flat = frames.reshape(BT, C, H, W)
|
||||
|
||||
# Convert to list of PIL images or numpy arrays for the processor
|
||||
# DINOv2 processor expects images in HWC format
|
||||
images_list = []
|
||||
for i in range(BT):
|
||||
img = flat[i] # (C, H, W)
|
||||
# Convert to HWC format
|
||||
img = img.permute(1, 2, 0) # (H, W, C)
|
||||
|
||||
# Convert to numpy if needed
|
||||
if img.dtype == torch.uint8:
|
||||
img = img.cpu().numpy()
|
||||
else:
|
||||
# Convert to uint8 range
|
||||
img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
|
||||
|
||||
images_list.append(img)
|
||||
# Process video frames through DINOv2
|
||||
video_embeds = self._encode_video_frames(frames) # (B, T_eff, D_vision)
|
||||
|
||||
# Process with DINOv2 processor
|
||||
processed = self.vision_processor(images=images_list, return_tensors="pt")
|
||||
pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device)
|
||||
|
||||
# Encode through DINOv2 model
|
||||
vision_outputs = self.vision_encoder(pixel_values)
|
||||
|
||||
# Extract CLS token for temporal modeling
|
||||
# DINOv2 outputs last_hidden_state of shape (batch_size, sequence_length, hidden_size)
|
||||
if hasattr(vision_outputs, "last_hidden_state"):
|
||||
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D) - CLS token
|
||||
else:
|
||||
raise RuntimeError("Vision encoder must output last_hidden_state")
|
||||
|
||||
# Project CLS tokens for temporal sequence
|
||||
visual_seq = self.visual_proj(cls_tokens).reshape(B, -1, self.config.dim_model) # (B, T', D)
|
||||
|
||||
# Add temporal positional encodings and optional first-frame bias
|
||||
pe = (
|
||||
self.positional_encoding[: visual_seq.shape[1]]
|
||||
.unsqueeze(0)
|
||||
.to(visual_seq.dtype)
|
||||
.to(visual_seq.device)
|
||||
# Language embeddings
|
||||
lang_embeds = self.text_encoder.encode(
|
||||
commands,
|
||||
output_value='token_embeddings',
|
||||
convert_to_tensor=True,
|
||||
device=device
|
||||
)
|
||||
visual_seq = visual_seq + pe
|
||||
if self.first_frame_bias is not None:
|
||||
visual_seq = visual_seq.clone()
|
||||
visual_seq[:, :1] = visual_seq[:, :1] + self.first_frame_bias
|
||||
|
||||
# Temporal model with cross-attention to language
|
||||
temporal_features = self.temporal(visual_seq, lang_emb, return_features=True) # (B, T', D)
|
||||
values = self.head(temporal_features).squeeze(-1) # (B, T')
|
||||
lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device)
|
||||
lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device)
|
||||
mask = self._mask_from_lens(lens)
|
||||
|
||||
# Register tokens
|
||||
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b=B)
|
||||
|
||||
# Project embeddings
|
||||
lang_tokens = self.to_lang_tokens(lang_embeds)
|
||||
video_tokens = self.to_video_tokens(video_embeds)
|
||||
|
||||
# Add first frame positional embedding
|
||||
first_video_token, rest_video_tokens = video_tokens[:, :1], video_tokens[:, 1:]
|
||||
first_video_token = first_video_token + repeat(self.first_pos_emb, 'd -> b 1 d', b=B)
|
||||
video_tokens = torch.cat((first_video_token, rest_video_tokens), dim=1)
|
||||
|
||||
# Pack all tokens for attention [lang | register | video]
|
||||
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
|
||||
|
||||
# Extend mask for register and video tokens
|
||||
mask = F.pad(mask, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
|
||||
|
||||
# Forward through x_transformers Decoder
|
||||
attended = self.decoder(tokens, mask=mask)
|
||||
|
||||
# Unpack and get video token features
|
||||
_, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d')
|
||||
|
||||
# MLP predictor
|
||||
video_frame_embeds = self.mlp_predictor(attended_video_tokens)
|
||||
|
||||
# Generate progress labels on-the-fly (ReWiND approach)
|
||||
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
|
||||
@@ -451,149 +487,78 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
# During inference, we might not want to compute loss
|
||||
if not self.training and target is None:
|
||||
loss = values.mean() * 0.0
|
||||
loss_dict["has_labels"] = 0.0
|
||||
return loss, {**loss_dict, "values_mean": values.mean().item()}
|
||||
# Return predictions without loss
|
||||
if self.categorical_rewards:
|
||||
return video_frame_embeds.mean() * 0.0, {"has_labels": 0.0}
|
||||
else:
|
||||
rewards = self.hl_gauss_layer(video_frame_embeds)
|
||||
return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}
|
||||
|
||||
# ReWiND Loss (following the paper exactly)
|
||||
# The core loss is progress regression with video rewinding augmentation
|
||||
# Calculate loss using HLGauss or categorical
|
||||
if self.categorical_rewards:
|
||||
# Categorical cross-entropy loss
|
||||
assert target.dtype in (torch.long, torch.int), "Categorical rewards require integer targets"
|
||||
loss = F.cross_entropy(
|
||||
rearrange(video_frame_embeds, 'b t l -> b l t'),
|
||||
target.long(),
|
||||
ignore_index=-1
|
||||
)
|
||||
else:
|
||||
# HLGauss loss or MSE regression
|
||||
assert target.dtype == torch.float, "Continuous rewards require float targets"
|
||||
# Create video mask for variable length support
|
||||
video_mask = torch.ones(B, T_eff, dtype=torch.bool, device=device)
|
||||
loss = self.hl_gauss_layer(video_frame_embeds, target[:, :T_eff], mask=video_mask)
|
||||
|
||||
# 1) Main progress regression loss for matched sequences
|
||||
# Target should be normalized progress from 0 to 1 (t/T)
|
||||
L_progress = F.mse_loss(values, target)
|
||||
# Optional: Mismatched video-language pairs loss
|
||||
L_mismatch = torch.zeros((), device=device)
|
||||
if self.training and self.config.use_mismatch_loss and B > 1:
|
||||
if torch.rand(1, device=device).item() < getattr(self.config, "mismatch_prob", 0.2):
|
||||
# Shuffle language within batch
|
||||
shuffled_indices = torch.randperm(B, device=device)
|
||||
shuffled_commands = [commands[i] for i in shuffled_indices]
|
||||
|
||||
# Re-encode with mismatched language
|
||||
lang_embeds_mm = self.text_encoder.encode(
|
||||
shuffled_commands,
|
||||
output_value='token_embeddings',
|
||||
convert_to_tensor=True,
|
||||
device=device
|
||||
)
|
||||
lang_embeds_mm = pad_sequence(lang_embeds_mm, batch_first=True).to(device)
|
||||
lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
|
||||
|
||||
# Pack and forward
|
||||
tokens_mm, _ = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d')
|
||||
attended_mm = self.decoder(tokens_mm, mask=mask)
|
||||
_, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape, 'b * d')
|
||||
mismatch_embeds = self.mlp_predictor(attended_video_mm)
|
||||
|
||||
# Mismatched pairs should predict zero progress
|
||||
zeros_target = torch.zeros_like(target[:, :T_eff])
|
||||
if self.categorical_rewards:
|
||||
L_mismatch = F.cross_entropy(
|
||||
rearrange(mismatch_embeds, 'b t l -> b l t'),
|
||||
zeros_target.long(),
|
||||
ignore_index=-1
|
||||
)
|
||||
else:
|
||||
L_mismatch = self.hl_gauss_layer(mismatch_embeds, zeros_target, mask=video_mask)
|
||||
|
||||
# 2) Mismatched video-language pairs should predict zero progress
|
||||
L_mismatch = torch.zeros((), device=values.device)
|
||||
if self.training and self.config.use_mismatch_loss and values.size(0) > 1:
|
||||
# Randomly shuffle language instructions within the batch
|
||||
shuffled_indices = torch.randperm(B, device=values.device)
|
||||
lang_mismatch = lang_emb[shuffled_indices]
|
||||
|
||||
# Forward pass with mismatched language
|
||||
mismatch_feat = self.temporal(visual_seq, lang_mismatch, return_features=True)
|
||||
mismatch_values = self.head(mismatch_feat).squeeze(-1)
|
||||
|
||||
# Mismatched pairs should predict zero progress
|
||||
L_mismatch = F.mse_loss(mismatch_values, torch.zeros_like(target))
|
||||
|
||||
# Total loss is just progress regression (rewinding is handled via data augmentation)
|
||||
loss = L_progress + L_mismatch
|
||||
# Total loss
|
||||
total_loss = loss + L_mismatch
|
||||
|
||||
# Log individual loss components
|
||||
loss_dict.update(
|
||||
{
|
||||
"loss_progress": L_progress.item(),
|
||||
"loss_mismatch": L_mismatch.item(),
|
||||
}
|
||||
)
|
||||
loss_dict.update({
|
||||
"loss": total_loss.item(),
|
||||
"loss_main": loss.item(),
|
||||
"loss_mismatch": L_mismatch.item(),
|
||||
})
|
||||
|
||||
loss_dict["loss"] = loss.item()
|
||||
loss_dict["values_mean"] = values.mean().item()
|
||||
return loss, loss_dict
|
||||
return total_loss, loss_dict
|
||||
|
||||
|
||||
class TemporalCausalTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim_model: int,
|
||||
n_heads: int,
|
||||
n_layers: int,
|
||||
dim_feedforward: int,
|
||||
dropout: float,
|
||||
pre_norm: bool,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
TemporalCausalTransformerLayer(dim_model, n_heads, dim_feedforward, dropout, pre_norm)
|
||||
for _ in range(n_layers)
|
||||
]
|
||||
)
|
||||
self.norm = nn.LayerNorm(dim_model)
|
||||
self.head = nn.Linear(dim_model, 1)
|
||||
|
||||
def forward(self, x: Tensor, lang_emb: Tensor, return_features: bool = False) -> Tensor:
|
||||
# x: (B, T, D), lang_emb: (B, D)
|
||||
B, T, D = x.shape
|
||||
# Prepare language as a single token for cross-attention context
|
||||
lang_token = lang_emb.unsqueeze(1) # (B, 1, D)
|
||||
|
||||
x = x.transpose(0, 1) # (T, B, D)
|
||||
lang_token = lang_token.transpose(0, 1) # (1, B, D)
|
||||
causal_mask = generate_causal_mask(T, device=x.device)
|
||||
for layer in self.layers:
|
||||
x = layer(x, lang_token, causal_mask)
|
||||
x = self.norm(x)
|
||||
x = x.transpose(0, 1) # (B, T, D)
|
||||
if return_features:
|
||||
return x
|
||||
return self.head(x) # (B, T, 1)
|
||||
|
||||
|
||||
class TemporalCausalTransformerLayer(nn.Module):
|
||||
def __init__(self, dim_model: int, n_heads: int, dim_feedforward: int, dropout: float, pre_norm: bool):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(dim_model, n_heads, dropout=dropout, batch_first=False)
|
||||
self.cross_attn = nn.MultiheadAttention(dim_model, n_heads, dropout=dropout, batch_first=False)
|
||||
self.linear1 = nn.Linear(dim_model, dim_feedforward)
|
||||
self.linear2 = nn.Linear(dim_feedforward, dim_model)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
self.norm1 = nn.LayerNorm(dim_model)
|
||||
self.norm2 = nn.LayerNorm(dim_model)
|
||||
self.norm3 = nn.LayerNorm(dim_model)
|
||||
self.activation = F.gelu
|
||||
self.pre_norm = pre_norm
|
||||
|
||||
def forward(self, x: Tensor, lang_token: Tensor, causal_mask: Tensor) -> Tensor:
|
||||
# Self-attention with causal mask
|
||||
residual = x
|
||||
if self.pre_norm:
|
||||
x = self.norm1(x)
|
||||
x = self.self_attn(x, x, x, attn_mask=causal_mask)[0]
|
||||
x = residual + self.dropout1(x)
|
||||
if not self.pre_norm:
|
||||
x = self.norm1(x)
|
||||
|
||||
# Cross-attention to language token (keys/values from language, queries are time tokens)
|
||||
residual = x
|
||||
if self.pre_norm:
|
||||
x = self.norm2(x)
|
||||
# Broadcast language token across time
|
||||
T = x.shape[0]
|
||||
lang_kv = lang_token.expand(1, x.shape[1], x.shape[2]) # (1, B, D)
|
||||
x = self.cross_attn(x, lang_kv, lang_kv)[0]
|
||||
x = residual + self.dropout2(x)
|
||||
if not self.pre_norm:
|
||||
x = self.norm2(x)
|
||||
|
||||
# Feed-forward
|
||||
residual = x
|
||||
if self.pre_norm:
|
||||
x = self.norm3(x)
|
||||
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
||||
x = residual + self.dropout3(x)
|
||||
if not self.pre_norm:
|
||||
x = self.norm3(x)
|
||||
return x
|
||||
|
||||
|
||||
def create_sinusoidal_pos_encoding(max_len: int, dim: int) -> Tensor:
|
||||
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) # (L, 1)
|
||||
div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)) # (D/2)
|
||||
pe = torch.zeros(max_len, dim)
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
return pe # (L, D)
|
||||
|
||||
|
||||
def generate_causal_mask(T: int, device=None) -> Tensor:
|
||||
# (T, T) with True where masking should occur for MultiheadAttention expects float mask or bool?
|
||||
mask = torch.full((T, T), float("-inf"), device=device)
|
||||
mask = torch.triu(mask, diagonal=1)
|
||||
return mask
|
||||
# Helper functions for ReWiND architecture
|
||||
|
||||
|
||||
def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None) -> Tensor:
|
||||
@@ -669,28 +634,10 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None
|
||||
return frames
|
||||
|
||||
|
||||
def encode_language(
|
||||
language_input: Tensor | list | str | None, text_encoder, batch_size: int
|
||||
) -> Tensor:
|
||||
"""Encode language using sentence-transformers (ReWiND paper setup)."""
|
||||
# language_input can be: list[str] length B, or None
|
||||
if language_input is None:
|
||||
texts = [""] * batch_size
|
||||
elif isinstance(language_input, list):
|
||||
texts = language_input
|
||||
else:
|
||||
# Single string for the batch
|
||||
texts = [str(language_input)] * batch_size
|
||||
|
||||
# For sentence-transformers, we can directly encode
|
||||
# Returns tensor of shape (batch_size, embedding_dim)
|
||||
device = next(iter(text_encoder.parameters())).device if hasattr(text_encoder, 'parameters') else 'cpu'
|
||||
embeddings = text_encoder.encode(texts, convert_to_tensor=True, device=device)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5) -> tuple[Tensor, Tensor]:
|
||||
|
||||
def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None) -> tuple[Tensor, Tensor]:
|
||||
"""Apply video rewinding augmentation as described in ReWiND paper.
|
||||
|
||||
Each video in the batch has an independent chance of being rewound.
|
||||
@@ -726,8 +673,11 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5) -> tuple[Tensor
|
||||
# Split point i: between frame 2 and T-1
|
||||
i = torch.randint(2, T, (1,)).item()
|
||||
|
||||
# Rewind length k: between 1 and i-1 frames
|
||||
k = torch.randint(1, min(i, T - i + 1), (1,)).item()
|
||||
# Rewind length k: between 1 and i-1 frames, with option to force last-3 frames occasionally
|
||||
if last3_prob is not None and torch.rand(1).item() < last3_prob and i >= 3:
|
||||
k = min(3, i - 1)
|
||||
else:
|
||||
k = torch.randint(1, min(i, T - i + 1), (1,)).item()
|
||||
|
||||
# Create rewound sequence: o1...oi, oi-1, ..., oi-k
|
||||
forward_frames = frames[b, :i] # Frames up to split point
|
||||
@@ -761,4 +711,4 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5) -> tuple[Tensor
|
||||
augmented_frames.append(rewound_seq)
|
||||
augmented_progress.append(rewound_progress)
|
||||
|
||||
return torch.stack(augmented_frames), torch.stack(augmented_progress)
|
||||
return torch.stack(augmented_frames), torch.stack(augmented_progress)
|
||||
@@ -75,10 +75,9 @@ _ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
|
||||
- Implement on-the-fly progress label generation (no need for pre-annotated rewards) [x]
|
||||
- Try different losses
|
||||
- Only rewind loss [x]
|
||||
- Exactly similar to: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11 [x]
|
||||
- Try DINO v2 as encoder Base 86 M: with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 [x]
|
||||
- check code is same as rewind repo code (architecture and trainign details) []
|
||||
- Test only rewind loss (evaluate) []
|
||||
- Check rewind implementation by hand/cleanup []
|
||||
- Only vlc loss then eval []
|
||||
- Vlc + Rewind loss then eval []
|
||||
- Cleanup code []
|
||||
|
||||
Reference in New Issue
Block a user