diff --git a/src/lerobot/policies/rlearn/configuration_rlearn.py b/src/lerobot/policies/rlearn/configuration_rlearn.py index c8b77ee2e..b167ae50f 100644 --- a/src/lerobot/policies/rlearn/configuration_rlearn.py +++ b/src/lerobot/policies/rlearn/configuration_rlearn.py @@ -66,8 +66,10 @@ class RLearNConfig(PreTrainedConfig): # ReWiND-specific parameters use_video_rewind: bool = True # Enable video rewinding augmentation - rewind_prob: float = 0.5 # Probability of applying rewind to each batch + rewind_prob: float = 0.8 # Probability of applying rewind to each sample (paper: ~80%) + rewind_last3_prob: float = 0.1 # Of the rewinds, 10% only rewind the last 3 frames use_mismatch_loss: bool = True # Enable mismatched language-video loss + mismatch_prob: float = 0.2 # Probability to include a mismatched video-language forward pass (paper: ~20%) # Loss hyperparameters (simplified for ReWiND) # The main loss is just MSE between predicted and target progress @@ -80,6 +82,18 @@ class RLearNConfig(PreTrainedConfig): } ) + # Architectural knobs to better mirror ReWiND + num_register_tokens: int = 4 + mlp_predictor_depth: int = 3 # depth of the per-frame MLP head + + # HLGauss loss parameters + use_hl_gauss_loss: bool = True + reward_min_value: float = 0.0 + reward_max_value: float = 1.0 + reward_hl_gauss_loss_num_bins: int = 20 + categorical_rewards: bool = False + reward_bins: int = 10 # only used if categorical_rewards=True + def validate_features(self) -> None: # Require at least one image feature. Language is recommended but optional (can be blank). if not self.image_features: diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 46d2b4819..c108b5896 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -76,10 +76,24 @@ Notes from __future__ import annotations import math +from itertools import chain import torch import torch.nn.functional as F from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence + +# ReWiND dependencies +try: + from x_transformers import Decoder + from hl_gauss_pytorch import HLGaussLayer + import einx + from einops import rearrange, repeat, pack, unpack +except ImportError as e: + raise ImportError( + "ReWiND dependencies not installed. Please install: " + "pip install x-transformers hl-gauss-pytorch einx einops" + ) from e from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD from lerobot.policies.pretrained import PreTrainedPolicy @@ -87,12 +101,12 @@ from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig class RLearNPolicy(PreTrainedPolicy): - """Video-language conditioned reward model. + """Video-language conditioned reward model following ReWiND architecture exactly: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11. - Visual encoder: frozen DINOv2 (base), returns per-frame embeddings. - Text encoder: frozen sentence-transformers (all-MiniLM-L12-v2), returns a language embedding. - - Temporal module: causal transformer over time that cross-attends to language embedding. - - Output: per-timestep reward logits; trainable small head. + - Temporal module: x_transformers Decoder with packed tokens [lang | register | video]. + - Output: per-timestep rewards via HLGauss layer or categorical bins. """ config_class = RLearNConfig @@ -102,6 +116,7 @@ class RLearNPolicy(PreTrainedPolicy): super().__init__(config) self.config = config self.episode_data_index = episode_data_index # Store episode boundaries for progress calculation + self.categorical_rewards = config.categorical_rewards # Encoders - ReWiND paper setup: DINOv2 for vision, sentence-transformers for text from transformers import AutoImageProcessor, AutoModel @@ -124,43 +139,48 @@ class RLearNPolicy(PreTrainedPolicy): for p in self.text_encoder.parameters(): p.requires_grad = False + # x_transformers Decoder (matching ReWiND exactly) + self.decoder = Decoder( + dim=config.dim_model, + depth=config.n_layers, + heads=config.n_heads, + attn_dim_head=64, # ReWiND default + ff_mult=config.dim_feedforward // config.dim_model, # Convert to multiplier + # Note: x_transformers uses attn_dropout and ff_dropout separately + attn_dropout=config.dropout, + ff_dropout=config.dropout, + ) + # Linear projections to the shared temporal model dimension - self.visual_proj = nn.Linear(self.vision_hidden, config.dim_model) - self.text_proj = nn.Linear(self.text_hidden, config.dim_model) + self.to_lang_tokens = nn.Linear(self.text_hidden, config.dim_model) + self.to_video_tokens = nn.Linear(self.vision_hidden, config.dim_model) - # Positional encodings over time - self.register_buffer( - "positional_encoding", - create_sinusoidal_pos_encoding(config.max_seq_len, config.dim_model), - persistent=False, + # Only first frame gets a positional embed (no cheating on progress) + self.first_pos_emb = nn.Parameter(torch.randn(config.dim_model) * 1e-2) + + # Register / memory / attention sink tokens + self.num_register_tokens = config.num_register_tokens + self.register_tokens = nn.Parameter(torch.randn(config.num_register_tokens, config.dim_model) * 1e-2) + + # MLP predictor (matching ReWiND's Feedforwards) + from x_mlps_pytorch import Feedforwards + self.mlp_predictor = Feedforwards( + dim=config.dim_model, + dim_out=config.reward_bins if config.categorical_rewards else None, + depth=config.mlp_predictor_depth ) - # Optional first-frame learned bias to discourage position cheating - self.first_frame_bias = ( - nn.Parameter(torch.zeros(1, 1, config.dim_model)) - if config.use_first_frame_positional_bias - else None + + # HLGauss layer or plain regression + self.hl_gauss_layer = HLGaussLayer( + dim=config.dim_model, + use_regression=not config.use_hl_gauss_loss, + hl_gauss_loss=dict( + min_value=config.reward_min_value, + max_value=config.reward_max_value, + num_bins=config.reward_hl_gauss_loss_num_bins, + ) if config.use_hl_gauss_loss else None ) - - # Temporal aggregator: causal transformer over time with language cross-attention - self.temporal = TemporalCausalTransformer( - dim_model=config.dim_model, - n_heads=config.n_heads, - n_layers=config.n_layers, - dim_feedforward=config.dim_feedforward, - dropout=config.dropout, - pre_norm=config.pre_norm, - ) - - # Reward head with proper initialization - head_linear = nn.Linear(config.dim_model, 1) - # Initialize with small weights and bias to output values around 0 - nn.init.normal_(head_linear.weight, mean=0.0, std=0.02) - nn.init.constant_(head_linear.bias, 0.0) # Start with 0 bias, sigmoid(0) = 0.5 - - head_layers: list[nn.Module] = [head_linear] - if config.use_tanh_head: - head_layers.append(nn.Tanh()) - self.head = nn.Sequential(*head_layers) + # Simple frame dropout probability self.frame_dropout_p = config.frame_dropout_p self.stride = max(1, config.stride) @@ -182,7 +202,7 @@ class RLearNPolicy(PreTrainedPolicy): @torch.no_grad() def predict_rewards(self, batch: dict[str, Tensor]) -> Tensor: - """Predict per-timestep rewards for evaluation. + """Predict per-timestep rewards for evaluation using ReWiND architecture. Args: batch: Input batch with OBS_IMAGES and optionally OBS_LANGUAGE @@ -190,83 +210,74 @@ class RLearNPolicy(PreTrainedPolicy): Returns: Predicted rewards tensor of shape (B, T) """ - batch = self.normalize_inputs(batch) - # Extract frames and form (B, T, C, H, W), padding if needed + # Extract frames and form (B, T, C, H, W) frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len) B, T, C, H, W = frames.shape # Apply stride (no dropout during eval) idx = torch.arange(0, T, self.stride, device=frames.device) frames = frames[:, idx] - B, T_eff, C, H, W = frames.shape # NEW: effective length after stride + T_eff = frames.shape[1] - # Encode language using sentence-transformers - lang_emb = encode_language( - batch.get(OBS_LANGUAGE, None), self.text_encoder, batch_size=B - ) - # Ensure embeddings are normal tensors on the correct device (not inference tensors) - lang_emb = lang_emb.detach().clone().to(self.text_proj.weight.device) - lang_emb = self.text_proj(lang_emb) # (B, D) + # Get language commands + commands = batch.get(OBS_LANGUAGE, None) + if commands is None: + commands = [""] * B + elif not isinstance(commands, list): + commands = [str(commands)] * B - # Process frames with DINOv2 - # Flatten (B, T_eff, C, H, W) -> (BT, C, H, W) - BT = B * T_eff - flat = frames.reshape(BT, C, H, W) - - # Convert to list of PIL images or numpy arrays for the processor - # DINOv2 processor expects images in HWC format - images_list = [] - for i in range(BT): - img = flat[i] # (C, H, W) - # Convert to HWC format - img = img.permute(1, 2, 0) # (H, W, C) - - # Convert to numpy if needed - if img.dtype == torch.uint8: - img = img.cpu().numpy() - else: - # Convert to uint8 range - img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy() - - images_list.append(img) + # Forward through ReWiND model (inference mode) + device = next(self.parameters()).device + frames = frames.to(device) - # Process with DINOv2 processor - processed = self.vision_processor(images=images_list, return_tensors="pt") - pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device) - - # Encode frames through DINOv2 - vision_outputs = self.vision_encoder(pixel_values) - - # Extract CLS tokens for temporal modeling - # DINOv2 outputs last_hidden_state of shape (batch_size, sequence_length, hidden_size) - # The CLS token is the first token - if hasattr(vision_outputs, "last_hidden_state"): - cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D_vision) - else: - raise RuntimeError("Vision encoder must output last_hidden_state") - - # Project CLS tokens for temporal sequence - visual_seq = self.visual_proj(cls_tokens).reshape(B, T_eff, self.config.dim_model) # (B, T', D) - - # Add temporal positional encodings and optional first-frame bias - pe = ( - self.positional_encoding[: visual_seq.shape[1]] - .unsqueeze(0) - .to(visual_seq.dtype) - .to(visual_seq.device) + # Process video frames + video_embeds = self._encode_video_frames(frames) # (B, T, D_vision) + + # Language embeddings + lang_embeds = self.text_encoder.encode( + commands, + output_value='token_embeddings', + convert_to_tensor=True, + device=device ) - visual_seq = visual_seq + pe - if self.first_frame_bias is not None: - visual_seq = visual_seq.clone() - visual_seq[:, :1] = visual_seq[:, :1] + self.first_frame_bias - - # Temporal model with cross-attention to language - temporal_features = self.temporal(visual_seq, lang_emb, return_features=True) # (B, T', D) - values = self.head(temporal_features).squeeze(-1) # (B, T') - - return values + lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device) + lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device) + mask = self._mask_from_lens(lens) + + # Register tokens + register_tokens = repeat(self.register_tokens, 'n d -> b n d', b=B) + + # Project embeddings + lang_tokens = self.to_lang_tokens(lang_embeds) + video_tokens = self.to_video_tokens(video_embeds) + + # Add first frame positional embedding + first_video_token, rest_video_tokens = video_tokens[:, :1], video_tokens[:, 1:] + first_video_token = first_video_token + repeat(self.first_pos_emb, 'd -> b 1 d', b=B) + video_tokens = torch.cat((first_video_token, rest_video_tokens), dim=1) + + # Pack all tokens for attention + tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d') + + # Extend mask for register and video tokens + mask = F.pad(mask, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True) + + # Forward through decoder + attended = self.decoder(tokens, mask=mask) + + # Unpack and get video token features + _, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d') + + # MLP predictor + video_frame_embeds = self.mlp_predictor(attended_video_tokens) + + # Get rewards via HLGauss layer + if self.categorical_rewards: + return video_frame_embeds # Return logits directly + else: + return self.hl_gauss_layer(video_frame_embeds).squeeze(-1) # (B, T) def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: # Initial version: no-op; rely on upstream processors if any @@ -276,6 +287,41 @@ class RLearNPolicy(PreTrainedPolicy): # Initial version: no-op return batch + def _encode_video_frames(self, frames: Tensor) -> Tensor: + """Encode video frames through DINOv2 to get per-frame embeddings. + + Args: + frames: (B, T, C, H, W) + + Returns: + (B, T, D_vision) + """ + B, T, C, H, W = frames.shape + flat = rearrange(frames, 'b t c h w -> (b t) c h w') + + # Process with DINOv2 + images_list = [] + for i in range(B * T): + img = flat[i].permute(1, 2, 0) # CHW -> HWC + if img.dtype == torch.uint8: + img = img.cpu().numpy() + else: + img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy() + images_list.append(img) + + processed = self.vision_processor(images=images_list, return_tensors="pt") + pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device) + vision_outputs = self.vision_encoder(pixel_values) + + # Extract CLS tokens + cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D_vision) + return rearrange(cls_tokens, '(b t) d -> b t d', b=B, t=T) + + def _mask_from_lens(self, lens: Tensor) -> Tensor: + """Create mask from sequence lengths.""" + seq = torch.arange(lens.amax().item(), device=lens.device) + return einx.less('n, b -> b n', seq, lens) + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """Compute ReWiND training loss with on-the-fly progress label generation. @@ -289,18 +335,22 @@ class RLearNPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) batch = self.normalize_targets(batch) - # Extract frames and form (B, T, C, H, W), padding if needed + # Extract frames and form (B, T, C, H, W) frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len) B, T, C, H, W = frames.shape + device = next(self.parameters()).device + frames = frames.to(device) # Apply video rewinding augmentation during training + augmented_target = None if self.training and self.config.use_video_rewind: - frames, augmented_target = apply_video_rewind(frames, rewind_prob=self.config.rewind_prob) - # Use augmented progress labels if rewinding was applied - if REWARD in batch: - target = augmented_target + frames, augmented_target = apply_video_rewind( + frames, + rewind_prob=self.config.rewind_prob, + last3_prob=getattr(self.config, "rewind_last3_prob", None), + ) - # Apply stride and frame dropout during training + # Apply stride and frame dropout idx = torch.arange(0, T, self.stride, device=frames.device) if self.training and self.frame_dropout_p > 0.0 and T > 1: mask = torch.rand_like(idx.float()) > self.frame_dropout_p @@ -308,69 +358,55 @@ class RLearNPolicy(PreTrainedPolicy): if idx.numel() == 0: idx = torch.tensor([0], device=frames.device) frames = frames[:, idx] + T_eff = frames.shape[1] - # Encode language using sentence-transformers - lang_emb = encode_language( - batch.get(OBS_LANGUAGE, None), self.text_encoder, batch_size=B - ) - # Ensure embeddings are normal tensors on the correct device (not inference tensors) - lang_emb = lang_emb.detach().clone().to(self.text_proj.weight.device) - lang_emb = self.text_proj(lang_emb) # (B, D) + # Get language commands + commands = batch.get(OBS_LANGUAGE, None) + if commands is None: + commands = [""] * B + elif not isinstance(commands, list): + commands = [str(commands)] * B - # Encode frames through DINOv2 visual encoder - # Flatten time for batched encode - BT = B * frames.shape[1] - flat = frames.reshape(BT, C, H, W) - - # Convert to list of PIL images or numpy arrays for the processor - # DINOv2 processor expects images in HWC format - images_list = [] - for i in range(BT): - img = flat[i] # (C, H, W) - # Convert to HWC format - img = img.permute(1, 2, 0) # (H, W, C) - - # Convert to numpy if needed - if img.dtype == torch.uint8: - img = img.cpu().numpy() - else: - # Convert to uint8 range - img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy() - - images_list.append(img) + # Process video frames through DINOv2 + video_embeds = self._encode_video_frames(frames) # (B, T_eff, D_vision) - # Process with DINOv2 processor - processed = self.vision_processor(images=images_list, return_tensors="pt") - pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device) - - # Encode through DINOv2 model - vision_outputs = self.vision_encoder(pixel_values) - - # Extract CLS token for temporal modeling - # DINOv2 outputs last_hidden_state of shape (batch_size, sequence_length, hidden_size) - if hasattr(vision_outputs, "last_hidden_state"): - cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D) - CLS token - else: - raise RuntimeError("Vision encoder must output last_hidden_state") - - # Project CLS tokens for temporal sequence - visual_seq = self.visual_proj(cls_tokens).reshape(B, -1, self.config.dim_model) # (B, T', D) - - # Add temporal positional encodings and optional first-frame bias - pe = ( - self.positional_encoding[: visual_seq.shape[1]] - .unsqueeze(0) - .to(visual_seq.dtype) - .to(visual_seq.device) + # Language embeddings + lang_embeds = self.text_encoder.encode( + commands, + output_value='token_embeddings', + convert_to_tensor=True, + device=device ) - visual_seq = visual_seq + pe - if self.first_frame_bias is not None: - visual_seq = visual_seq.clone() - visual_seq[:, :1] = visual_seq[:, :1] + self.first_frame_bias - - # Temporal model with cross-attention to language - temporal_features = self.temporal(visual_seq, lang_emb, return_features=True) # (B, T', D) - values = self.head(temporal_features).squeeze(-1) # (B, T') + lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device) + lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device) + mask = self._mask_from_lens(lens) + + # Register tokens + register_tokens = repeat(self.register_tokens, 'n d -> b n d', b=B) + + # Project embeddings + lang_tokens = self.to_lang_tokens(lang_embeds) + video_tokens = self.to_video_tokens(video_embeds) + + # Add first frame positional embedding + first_video_token, rest_video_tokens = video_tokens[:, :1], video_tokens[:, 1:] + first_video_token = first_video_token + repeat(self.first_pos_emb, 'd -> b 1 d', b=B) + video_tokens = torch.cat((first_video_token, rest_video_tokens), dim=1) + + # Pack all tokens for attention [lang | register | video] + tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d') + + # Extend mask for register and video tokens + mask = F.pad(mask, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True) + + # Forward through x_transformers Decoder + attended = self.decoder(tokens, mask=mask) + + # Unpack and get video token features + _, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d') + + # MLP predictor + video_frame_embeds = self.mlp_predictor(attended_video_tokens) # Generate progress labels on-the-fly (ReWiND approach) # IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window @@ -451,149 +487,78 @@ class RLearNPolicy(PreTrainedPolicy): # During inference, we might not want to compute loss if not self.training and target is None: - loss = values.mean() * 0.0 - loss_dict["has_labels"] = 0.0 - return loss, {**loss_dict, "values_mean": values.mean().item()} + # Return predictions without loss + if self.categorical_rewards: + return video_frame_embeds.mean() * 0.0, {"has_labels": 0.0} + else: + rewards = self.hl_gauss_layer(video_frame_embeds) + return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()} - # ReWiND Loss (following the paper exactly) - # The core loss is progress regression with video rewinding augmentation + # Calculate loss using HLGauss or categorical + if self.categorical_rewards: + # Categorical cross-entropy loss + assert target.dtype in (torch.long, torch.int), "Categorical rewards require integer targets" + loss = F.cross_entropy( + rearrange(video_frame_embeds, 'b t l -> b l t'), + target.long(), + ignore_index=-1 + ) + else: + # HLGauss loss or MSE regression + assert target.dtype == torch.float, "Continuous rewards require float targets" + # Create video mask for variable length support + video_mask = torch.ones(B, T_eff, dtype=torch.bool, device=device) + loss = self.hl_gauss_layer(video_frame_embeds, target[:, :T_eff], mask=video_mask) - # 1) Main progress regression loss for matched sequences - # Target should be normalized progress from 0 to 1 (t/T) - L_progress = F.mse_loss(values, target) + # Optional: Mismatched video-language pairs loss + L_mismatch = torch.zeros((), device=device) + if self.training and self.config.use_mismatch_loss and B > 1: + if torch.rand(1, device=device).item() < getattr(self.config, "mismatch_prob", 0.2): + # Shuffle language within batch + shuffled_indices = torch.randperm(B, device=device) + shuffled_commands = [commands[i] for i in shuffled_indices] + + # Re-encode with mismatched language + lang_embeds_mm = self.text_encoder.encode( + shuffled_commands, + output_value='token_embeddings', + convert_to_tensor=True, + device=device + ) + lang_embeds_mm = pad_sequence(lang_embeds_mm, batch_first=True).to(device) + lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm) + + # Pack and forward + tokens_mm, _ = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d') + attended_mm = self.decoder(tokens_mm, mask=mask) + _, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape, 'b * d') + mismatch_embeds = self.mlp_predictor(attended_video_mm) + + # Mismatched pairs should predict zero progress + zeros_target = torch.zeros_like(target[:, :T_eff]) + if self.categorical_rewards: + L_mismatch = F.cross_entropy( + rearrange(mismatch_embeds, 'b t l -> b l t'), + zeros_target.long(), + ignore_index=-1 + ) + else: + L_mismatch = self.hl_gauss_layer(mismatch_embeds, zeros_target, mask=video_mask) - # 2) Mismatched video-language pairs should predict zero progress - L_mismatch = torch.zeros((), device=values.device) - if self.training and self.config.use_mismatch_loss and values.size(0) > 1: - # Randomly shuffle language instructions within the batch - shuffled_indices = torch.randperm(B, device=values.device) - lang_mismatch = lang_emb[shuffled_indices] - - # Forward pass with mismatched language - mismatch_feat = self.temporal(visual_seq, lang_mismatch, return_features=True) - mismatch_values = self.head(mismatch_feat).squeeze(-1) - - # Mismatched pairs should predict zero progress - L_mismatch = F.mse_loss(mismatch_values, torch.zeros_like(target)) - - # Total loss is just progress regression (rewinding is handled via data augmentation) - loss = L_progress + L_mismatch + # Total loss + total_loss = loss + L_mismatch # Log individual loss components - loss_dict.update( - { - "loss_progress": L_progress.item(), - "loss_mismatch": L_mismatch.item(), - } - ) + loss_dict.update({ + "loss": total_loss.item(), + "loss_main": loss.item(), + "loss_mismatch": L_mismatch.item(), + }) - loss_dict["loss"] = loss.item() - loss_dict["values_mean"] = values.mean().item() - return loss, loss_dict + return total_loss, loss_dict -class TemporalCausalTransformer(nn.Module): - def __init__( - self, - dim_model: int, - n_heads: int, - n_layers: int, - dim_feedforward: int, - dropout: float, - pre_norm: bool, - ): - super().__init__() - self.layers = nn.ModuleList( - [ - TemporalCausalTransformerLayer(dim_model, n_heads, dim_feedforward, dropout, pre_norm) - for _ in range(n_layers) - ] - ) - self.norm = nn.LayerNorm(dim_model) - self.head = nn.Linear(dim_model, 1) - - def forward(self, x: Tensor, lang_emb: Tensor, return_features: bool = False) -> Tensor: - # x: (B, T, D), lang_emb: (B, D) - B, T, D = x.shape - # Prepare language as a single token for cross-attention context - lang_token = lang_emb.unsqueeze(1) # (B, 1, D) - - x = x.transpose(0, 1) # (T, B, D) - lang_token = lang_token.transpose(0, 1) # (1, B, D) - causal_mask = generate_causal_mask(T, device=x.device) - for layer in self.layers: - x = layer(x, lang_token, causal_mask) - x = self.norm(x) - x = x.transpose(0, 1) # (B, T, D) - if return_features: - return x - return self.head(x) # (B, T, 1) - - -class TemporalCausalTransformerLayer(nn.Module): - def __init__(self, dim_model: int, n_heads: int, dim_feedforward: int, dropout: float, pre_norm: bool): - super().__init__() - self.self_attn = nn.MultiheadAttention(dim_model, n_heads, dropout=dropout, batch_first=False) - self.cross_attn = nn.MultiheadAttention(dim_model, n_heads, dropout=dropout, batch_first=False) - self.linear1 = nn.Linear(dim_model, dim_feedforward) - self.linear2 = nn.Linear(dim_feedforward, dim_model) - self.dropout = nn.Dropout(dropout) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.dropout3 = nn.Dropout(dropout) - self.norm1 = nn.LayerNorm(dim_model) - self.norm2 = nn.LayerNorm(dim_model) - self.norm3 = nn.LayerNorm(dim_model) - self.activation = F.gelu - self.pre_norm = pre_norm - - def forward(self, x: Tensor, lang_token: Tensor, causal_mask: Tensor) -> Tensor: - # Self-attention with causal mask - residual = x - if self.pre_norm: - x = self.norm1(x) - x = self.self_attn(x, x, x, attn_mask=causal_mask)[0] - x = residual + self.dropout1(x) - if not self.pre_norm: - x = self.norm1(x) - - # Cross-attention to language token (keys/values from language, queries are time tokens) - residual = x - if self.pre_norm: - x = self.norm2(x) - # Broadcast language token across time - T = x.shape[0] - lang_kv = lang_token.expand(1, x.shape[1], x.shape[2]) # (1, B, D) - x = self.cross_attn(x, lang_kv, lang_kv)[0] - x = residual + self.dropout2(x) - if not self.pre_norm: - x = self.norm2(x) - - # Feed-forward - residual = x - if self.pre_norm: - x = self.norm3(x) - x = self.linear2(self.dropout(self.activation(self.linear1(x)))) - x = residual + self.dropout3(x) - if not self.pre_norm: - x = self.norm3(x) - return x - - -def create_sinusoidal_pos_encoding(max_len: int, dim: int) -> Tensor: - position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) # (L, 1) - div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)) # (D/2) - pe = torch.zeros(max_len, dim) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - return pe # (L, D) - - -def generate_causal_mask(T: int, device=None) -> Tensor: - # (T, T) with True where masking should occur for MultiheadAttention expects float mask or bool? - mask = torch.full((T, T), float("-inf"), device=device) - mask = torch.triu(mask, diagonal=1) - return mask +# Helper functions for ReWiND architecture def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None) -> Tensor: @@ -669,28 +634,10 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None return frames -def encode_language( - language_input: Tensor | list | str | None, text_encoder, batch_size: int -) -> Tensor: - """Encode language using sentence-transformers (ReWiND paper setup).""" - # language_input can be: list[str] length B, or None - if language_input is None: - texts = [""] * batch_size - elif isinstance(language_input, list): - texts = language_input - else: - # Single string for the batch - texts = [str(language_input)] * batch_size - - # For sentence-transformers, we can directly encode - # Returns tensor of shape (batch_size, embedding_dim) - device = next(iter(text_encoder.parameters())).device if hasattr(text_encoder, 'parameters') else 'cpu' - embeddings = text_encoder.encode(texts, convert_to_tensor=True, device=device) - - return embeddings -def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5) -> tuple[Tensor, Tensor]: + +def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None) -> tuple[Tensor, Tensor]: """Apply video rewinding augmentation as described in ReWiND paper. Each video in the batch has an independent chance of being rewound. @@ -726,8 +673,11 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5) -> tuple[Tensor # Split point i: between frame 2 and T-1 i = torch.randint(2, T, (1,)).item() - # Rewind length k: between 1 and i-1 frames - k = torch.randint(1, min(i, T - i + 1), (1,)).item() + # Rewind length k: between 1 and i-1 frames, with option to force last-3 frames occasionally + if last3_prob is not None and torch.rand(1).item() < last3_prob and i >= 3: + k = min(3, i - 1) + else: + k = torch.randint(1, min(i, T - i + 1), (1,)).item() # Create rewound sequence: o1...oi, oi-1, ..., oi-k forward_frames = frames[b, :i] # Frames up to split point @@ -761,4 +711,4 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5) -> tuple[Tensor augmented_frames.append(rewound_seq) augmented_progress.append(rewound_progress) - return torch.stack(augmented_frames), torch.stack(augmented_progress) + return torch.stack(augmented_frames), torch.stack(augmented_progress) \ No newline at end of file diff --git a/src/lerobot/policies/rlearn/rlearn_plan.md b/src/lerobot/policies/rlearn/rlearn_plan.md index 58295dafc..c15df520a 100644 --- a/src/lerobot/policies/rlearn/rlearn_plan.md +++ b/src/lerobot/policies/rlearn/rlearn_plan.md @@ -75,10 +75,9 @@ _ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/ - Implement on-the-fly progress label generation (no need for pre-annotated rewards) [x] - Try different losses - Only rewind loss [x] + - Exactly similar to: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11 [x] - Try DINO v2 as encoder Base 86 M: with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 [x] - - check code is same as rewind repo code (architecture and trainign details) [] - Test only rewind loss (evaluate) [] - - Check rewind implementation by hand/cleanup [] - Only vlc loss then eval [] - Vlc + Rewind loss then eval [] - Cleanup code []