From def71cc439b6cff2311aafae37df1d467272da0e Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sun, 31 Aug 2025 15:20:20 +0200 Subject: [PATCH] change sampling --- .../policies/rlearn/configuration_rlearn.py | 17 +- .../policies/rlearn/modeling_rlearn.py | 587 ++++++++++++------ src/lerobot/scripts/train.py | 9 + 3 files changed, 424 insertions(+), 189 deletions(-) diff --git a/src/lerobot/policies/rlearn/configuration_rlearn.py b/src/lerobot/policies/rlearn/configuration_rlearn.py index 914f68a02..d35262201 100644 --- a/src/lerobot/policies/rlearn/configuration_rlearn.py +++ b/src/lerobot/policies/rlearn/configuration_rlearn.py @@ -69,12 +69,23 @@ class RLearNConfig(PreTrainedConfig): # ReWiND-specific parameters use_video_rewind: bool = True # Enable video rewinding augmentation - rewind_prob: float = 0.8 # Probability of applying rewind to each sample (paper: ~80%) - rewind_last3_prob: float = 0.1 # Of the rewinds, 10% only rewind the last 3 frames + rewind_prob: float = 0.5 # Reduced from 0.8 to avoid too many artifacts + rewind_last3_prob: float = 0.3 # Increased to favor smaller rewinds use_mismatch_loss: bool = False # Enable mismatched language-video loss mismatch_prob: float = ( 0.2 # Probability to include a mismatched video-language forward pass (paper: ~20%) ) + + # NEW: Loss and head improvements + use_logit_regression: bool = True # Use logit space regression instead of sigmoid+MSE + logit_eps: float = 1e-6 # Clipping epsilon for logit transform: logit(clamp(target, eps, 1-eps)) + head_lr_multiplier: float = 2.0 # Increase head learning rate relative to base + head_weight_init_std: float = 0.05 # Larger head weight initialization for faster positive logits + remove_head_bias_wd: bool = True # Remove weight decay from head bias + + # Window sampling improvements + use_random_anchor_sampling: bool = True # Use explicit random anchor sampling during training + use_window_relative_progress: bool = True # Use window-relative progress (0-1 across window) instead of episode-relative # Loss hyperparameters (simplified for ReWiND) # The main loss is just MSE between predicted and target progress @@ -91,7 +102,7 @@ class RLearNConfig(PreTrainedConfig): num_register_tokens: int = 4 # register / memory tokens, can't hurt mlp_predictor_depth: int = 3 # depth of the per-frame MLP head - # Simple MSE regression loss (no binning) + # Loss configuration - supports both sigmoid+MSE and logit regression # Evaluation visualization parameters enable_eval_visualizations: bool = False # Enable reward evaluation visualizations during training diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 60d141ebc..fb9c1ea6b 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -190,16 +190,21 @@ class RLearNPolicy(PreTrainedPolicy): # Layer normalization before reward head to stabilize MLP outputs self.pre_reward_norm = nn.LayerNorm(config.dim_model) - # MSE regression head with sigmoid activation to bound outputs to [0,1] + # Regression head - supports both logit and sigmoid modes self.reward_head = nn.Linear(config.dim_model, 1) - # Initialize with small weights to prevent sigmoid saturation - # Target: sigmoid(0) = 0.5, so we want raw logits around [-2, 2] range + # Initialize head with improved settings with torch.no_grad(): - self.reward_head.weight.normal_(0.0, 0.02) # Small but not tiny - self.reward_head.bias.fill_(0.0) # Start at sigmoid(0) = 0.5 + if config.use_logit_regression: + # Logit regression: can use larger weights since no saturation issues + self.reward_head.weight.normal_(0.0, config.head_weight_init_std) + self.reward_head.bias.fill_(0.0) # Neutral start in logit space + else: + # Sigmoid mode: moderate initialization + self.reward_head.weight.normal_(0.0, 0.02) + self.reward_head.bias.fill_(0.0) - self.sigmoid = nn.Sigmoid() + self.sigmoid = nn.Sigmoid() if not config.use_logit_regression else None # Simple frame dropout probability self.frame_dropout_p = config.frame_dropout_p @@ -224,9 +229,55 @@ class RLearNPolicy(PreTrainedPolicy): print(f"⚠️ torch.compile failed: {e}") # Continue without compilation - def get_optim_params(self) -> dict: - # Train only projections, temporal module and head by default if backbones are frozen - return [p for p in self.parameters() if p.requires_grad] + def get_optim_params(self) -> list: + """Return parameter groups with custom LR and weight decay settings.""" + # Collect trainable parameters + base_params = [] + head_weight_params = [] + head_bias_params = [] + + for name, param in self.named_parameters(): + if not param.requires_grad: + continue + + if "reward_head" in name: + if "bias" in name: + head_bias_params.append(param) + else: + head_weight_params.append(param) + else: + base_params.append(param) + + # Create parameter groups with different settings + param_groups = [] + + # Base parameters (everything except head) + if base_params: + param_groups.append({ + "params": base_params, + "name": "base" + }) + + # Head weight parameters (higher LR) + if head_weight_params: + param_groups.append({ + "params": head_weight_params, + "lr": self.config.learning_rate * self.config.head_lr_multiplier, + "name": "head_weights" + }) + + # Head bias parameters (higher LR, optionally no weight decay) + if head_bias_params: + head_bias_group = { + "params": head_bias_params, + "lr": self.config.learning_rate * self.config.head_lr_multiplier, + "name": "head_bias" + } + if self.config.remove_head_bias_wd: + head_bias_group["weight_decay"] = 0.0 + param_groups.append(head_bias_group) + + return param_groups def reset(self): pass @@ -311,9 +362,16 @@ class RLearNPolicy(PreTrainedPolicy): # MLP predictor video_frame_embeds = self.mlp_predictor(frame_specific_tokens) - # Get rewards via linear head with sigmoid activation + # Get rewards via linear head normalized_embeds = self.pre_reward_norm(video_frame_embeds) - return self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1) # (B, T) + raw_logits = self.reward_head(normalized_embeds).squeeze(-1) # (B, T) + + if self.config.use_logit_regression: + # In logit mode, apply sigmoid at inference + return torch.sigmoid(raw_logits) + else: + # In sigmoid mode, apply sigmoid as usual + return self.sigmoid(raw_logits) def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: # Initial version: no-op; rely on upstream processors if any @@ -386,16 +444,22 @@ class RLearNPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) batch = self.normalize_targets(batch) - # Extract frames and form (B, T, C, H, W) - frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len) + # NEW: Explicit random anchor window sampling for training + if self.training: + frames, anchor_stats = self._sample_random_anchor_windows(batch) + else: + # During inference, use the generic extractor + frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len) + anchor_stats = None + B, T, C, H, W = frames.shape device = next(self.parameters()).device frames = frames.to(device) - # Apply video rewinding augmentation during training + # Apply video rewinding augmentation during training (FIXED: no constant padding) augmented_target = None if self.training and self.config.use_video_rewind: - frames, augmented_target = apply_video_rewind( + frames, augmented_target = apply_video_rewind_fixed( frames, rewind_prob=self.config.rewind_prob, last3_prob=getattr(self.config, "rewind_last3_prob", None), @@ -472,122 +536,17 @@ class RLearNPolicy(PreTrainedPolicy): loss_dict: dict[str, float] = {} # Check if video rewinding already set the target - if self.training and self.config.use_video_rewind and "augmented_target" in locals(): + if self.training and self.config.use_video_rewind and augmented_target is not None: # Use the augmented target from video rewinding and align with temporal subsampling target = augmented_target[:, idx] + elif self.training and anchor_stats is not None and not anchor_stats.get("fallback_used", False): + # NEW: Calculate progress using the known random anchors + target = self._calculate_anchor_based_progress(batch, anchor_stats, T_eff) else: - # Calculate true episode progress using episode_index and frame_index from batch + # Fallback: Calculate episode progress the old way episode_indices, frame_indices = self._extract_episode_and_frame_indices(batch) if episode_indices is not None and frame_indices is not None and self.episode_data_index is not None: - - # Calculate progress for the current frame in each sample - progress_values = [] - - for b_idx in range(B): - ep_idx = episode_indices[b_idx].item() - frame_idx = frame_indices[b_idx].item() - - # Get episode boundaries - ep_start = self.episode_data_index["from"][ep_idx].item() - ep_end = self.episode_data_index["to"][ep_idx].item() - ep_length = ep_end - ep_start - - # Progress from 0 to 1 within the episode - # frame_index is relative to the episode (0-based within episode) - progress = frame_idx / max(1, ep_length - 1) - progress_values.append(progress) - - # Create progress tensor for the current frame (last in temporal sequence) - current_progress = torch.tensor(progress_values, device=video_frame_embeds.device, dtype=video_frame_embeds.dtype) - - # Now calculate progress for ALL frames in the temporal window - # The observation_delta_indices tell us which frames we're looking at - delta_indices = self.config.observation_delta_indices # e.g., [-15, -14, ..., 0] - - # Calculate progress for each frame in the temporal window - all_progress = [] - - # DEBUG: Log indexing details for first sample occasionally - debug_indexing = torch.rand(1).item() < 0.10 # 10% chance - increased for debugging - if debug_indexing: - print(f"\n=== INDEXING DEBUG ===") - print(f"Delta indices: {delta_indices}") - print(f"Batch size: {B}") - - # Check if batch samples have diverse frame indices (red flag if all identical) - unique_frames = torch.unique(frame_indices).tolist() - unique_episodes = torch.unique(episode_indices).tolist() - print(f"Unique frame indices in batch: {unique_frames[:10]}{'...' if len(unique_frames) > 10 else ''}") - print(f"Unique episode indices in batch: {unique_episodes[:10]}{'...' if len(unique_episodes) > 10 else ''}") - - if len(unique_frames) == 1: - print("🚨 RED FLAG: All samples have IDENTICAL frame index! This causes identical targets.") - - # First sample details - ep_idx_0 = episode_indices[0].item() - frame_idx_0 = frame_indices[0].item() - ep_start_0 = self.episode_data_index["from"][ep_idx_0].item() - ep_end_0 = self.episode_data_index["to"][ep_idx_0].item() - ep_length_0 = ep_end_0 - ep_start_0 - print(f"First sample - Episode: {ep_idx_0}, Frame: {frame_idx_0}/{ep_length_0}, Episode length: {ep_length_0}") - - # Check boundary proximity - frames_from_start = frame_idx_0 - frames_from_end = ep_length_0 - frame_idx_0 - 1 - print(f"First sample proximity - Start: {frames_from_start}, End: {frames_from_end}") - - if frames_from_start < 15: - print(f"⚠️ Close to episode START: many deltas will go negative") - if frames_from_end < 15: - print(f"⚠️ Close to episode END: many deltas will exceed episode") - - for i, delta in enumerate(delta_indices): - # For each sample, calculate the progress of the frame at delta offset - frame_progress = [] - for b_idx in range(B): - ep_idx = episode_indices[b_idx].item() - frame_idx = frame_indices[b_idx].item() - - # Calculate the actual frame index with delta - target_frame_idx = frame_idx + delta - - # Get episode boundaries - ep_start = self.episode_data_index["from"][ep_idx].item() - ep_end = self.episode_data_index["to"][ep_idx].item() - ep_length = ep_end - ep_start - - # Calculate progress with proper boundary handling - if target_frame_idx < 0: - # Before episode start: extrapolate negative progress - prog = target_frame_idx / max(1, ep_length - 1) - elif target_frame_idx >= ep_length: - # After episode end: extrapolate progress beyond 1.0 - prog = target_frame_idx / max(1, ep_length - 1) - else: - # Within episode: normal progress calculation - prog = target_frame_idx / max(1, ep_length - 1) - - # Clip to reasonable bounds to prevent extreme values - prog = max(-1.0, min(2.0, prog)) # Allow some extrapolation - frame_progress.append(prog) - - # DEBUG: Log first sample's calculation - if debug_indexing and b_idx == 0: - boundary_status = "BEFORE" if target_frame_idx < 0 else "AFTER" if target_frame_idx >= ep_length else "WITHIN" - print(f" Frame {i:2d} (δ={delta:3d}): target_idx={target_frame_idx:3d} [{boundary_status}] → progress={prog:.6f}") - - all_progress.append( - torch.tensor(frame_progress, device=video_frame_embeds.device, dtype=video_frame_embeds.dtype) - ) - - if debug_indexing: - print("=" * 22) - - # Stack to get (B, T) tensor where T is the temporal sequence length - target = torch.stack(all_progress, dim=1) # (B, max_seq_len) - - # Apply stride/dropout indexing to match the processed frames - target = target[:, idx] + target = self._calculate_episode_progress(batch, episode_indices, frame_indices, T_eff, idx) else: raise ValueError( "No episode information found to build full-episode progress. " @@ -603,16 +562,26 @@ class RLearNPolicy(PreTrainedPolicy): rewards = self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1) return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()} - # Calculate loss using MSE + # Calculate loss using the configured mode (logit regression or sigmoid+MSE) loss_start = time.perf_counter() assert target.dtype == torch.float, "Continuous rewards require float targets" - # Get reward predictions with sigmoid activation + # Get model outputs normalized_embeds = self.pre_reward_norm(video_frame_embeds) - predicted_rewards = self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1) # (B, T_eff) + raw_logits = self.reward_head(normalized_embeds).squeeze(-1) # (B, T_eff) - # MSE loss with masking for variable length sequences - loss = F.mse_loss(predicted_rewards, target[:, :T_eff], reduction='mean') + if self.config.use_logit_regression: + # Logit regression: transform targets to logit space and compute MSE on logits + eps = self.config.logit_eps + target_clamped = torch.clamp(target[:, :T_eff], eps, 1 - eps) + target_logits = torch.logit(target_clamped) + loss = F.mse_loss(raw_logits, target_logits, reduction='mean') + # For logging/debug, also compute sigmoid predictions + predicted_rewards = torch.sigmoid(raw_logits) + else: + # Sigmoid mode: apply sigmoid and compute MSE on probabilities + predicted_rewards = self.sigmoid(raw_logits) + loss = F.mse_loss(predicted_rewards, target[:, :T_eff], reduction='mean') # Optional: Mismatched video-language pairs loss L_mismatch = torch.zeros((), device=device) @@ -644,9 +613,18 @@ class RLearNPolicy(PreTrainedPolicy): # Mismatched pairs should predict zero progress normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds) - mismatch_predictions = self.sigmoid(self.reward_head(normalized_mismatch_embeds)).squeeze(-1) - zeros_target = torch.zeros_like(target[:, :T_eff]) - L_mismatch = F.mse_loss(mismatch_predictions, zeros_target, reduction='mean') + mismatch_raw_logits = self.reward_head(normalized_mismatch_embeds).squeeze(-1) + + if self.config.use_logit_regression: + # In logit mode, target logit of ~0 corresponds to sigmoid(x)≈0 + eps = self.config.logit_eps + zeros_target_logits = torch.logit(torch.full_like(target[:, :T_eff], eps)) + L_mismatch = F.mse_loss(mismatch_raw_logits, zeros_target_logits, reduction='mean') + else: + # In sigmoid mode, target sigmoid output of 0 + mismatch_predictions = self.sigmoid(mismatch_raw_logits) + zeros_target = torch.zeros_like(target[:, :T_eff]) + L_mismatch = F.mse_loss(mismatch_predictions, zeros_target, reduction='mean') # Total loss total_loss = loss + L_mismatch @@ -713,8 +691,23 @@ class RLearNPolicy(PreTrainedPolicy): "target_min": float(target.min().item()), "target_max": float(target.max().item()), "target_mean": float(target.mean().item()), + "target_std": float(target.std().item()), # Prediction statistics "pred_mean": float(predicted_rewards.mean().item()), + "pred_std": float(predicted_rewards.std().item()), + # Raw logits statistics (useful for monitoring head behavior) + "raw_logits_mean": float(raw_logits.mean().item()), + "raw_logits_std": float(raw_logits.std().item()), + # NEW: Anchor sampling statistics if available + **({ + "anchor_mean": float(anchor_stats['anchor_mean']) if anchor_stats and not anchor_stats.get('fallback_used', False) else 0.0, + "anchor_std": float(anchor_stats['anchor_std']) if anchor_stats and not anchor_stats.get('fallback_used', False) else 0.0, + "oob_fraction": float(anchor_stats['oob_fraction']) if anchor_stats and not anchor_stats.get('fallback_used', False) else 0.0, + "padded_fraction": float(anchor_stats['padded_fraction']) if anchor_stats and not anchor_stats.get('fallback_used', False) else 0.0, + "use_random_anchors": not (anchor_stats and anchor_stats.get('fallback_used', False)) if anchor_stats else False, + }), + # Loss mode indicator + "logit_regression": bool(self.config.use_logit_regression), # Timing information "timing_vision_ms": float(vision_time * 1000), "timing_language_ms": float(lang_time * 1000), @@ -874,6 +867,211 @@ class RLearNPolicy(PreTrainedPolicy): return ep, fr + def _sample_random_anchor_windows(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: + """Sample random anchor windows for training to avoid sampling bias. + + Returns: + frames: (B, T, C, H, W) tensor with T = max_seq_len + anchor_stats: dict with sampling statistics for logging + """ + # Extract episode and frame indices + episode_indices, frame_indices = self._extract_episode_and_frame_indices(batch) + + if episode_indices is None or frame_indices is None or self.episode_data_index is None: + # Fallback to generic extractor if we don't have episode info + frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len) + return frames, {"fallback_used": True} + + device = next(self.parameters()).device + B = len(episode_indices) + T = self.config.max_seq_len + delta_indices = self.config.observation_delta_indices # [-15, -14, ..., 0] + + # Get raw image data - assume it's already a temporal sequence from dataset + raw_frames = extract_visual_sequence(batch, target_seq_len=None) # Don't force padding + available_T = raw_frames.shape[1] + + # For each sample, choose a random anchor and build the window + sampled_frames = [] + anchor_positions = [] + oob_count = 0 + padded_count = 0 + resampled_count = 0 + + for b_idx in range(B): + ep_idx = episode_indices[b_idx].item() + frame_idx = frame_indices[b_idx].item() # Current frame position in episode + + # Get episode boundaries + ep_start = self.episode_data_index["from"][ep_idx].item() + ep_end = self.episode_data_index["to"][ep_idx].item() + ep_length = ep_end - ep_start + + # Choose random anchor within episode bounds such that we can get a full window + # The anchor is the "current" frame (delta=0), so we need at least T-1 frames before it + min_anchor = T - 1 # Need 15 frames before for [-15..0] window + max_anchor = ep_length - 1 # Episode frame indices are 0-based + + if min_anchor > max_anchor: + # Episode too short for full window - use available frames with padding + anchor = max_anchor + padded_count += 1 + else: + # Sample uniformly from valid range + anchor = torch.randint(min_anchor, max_anchor + 1, (1,)).item() + + anchor_positions.append(anchor) + + # Build window indices relative to episode start + window_indices = [anchor + delta for delta in delta_indices] + + # Handle out-of-bounds with reflection or clamping + valid_indices = [] + had_oob = False + for idx in window_indices: + if idx < 0: + # Reflect at episode boundary + valid_indices.append(-idx) + had_oob = True + elif idx >= ep_length: + # Reflect at episode end + valid_indices.append(2 * (ep_length - 1) - idx) + had_oob = True + else: + valid_indices.append(idx) + + if had_oob: + oob_count += 1 + + # Extract frames at these indices from the raw temporal sequence + # Map episode-relative indices to sequence indices + frame_tensors = [] + for ep_rel_idx in valid_indices: + if ep_rel_idx < available_T: + frame_tensors.append(raw_frames[b_idx, ep_rel_idx]) + else: + # Fallback: repeat last available frame + frame_tensors.append(raw_frames[b_idx, -1]) + padded_count += 1 + + sampled_frames.append(torch.stack(frame_tensors)) # (T, C, H, W) + + frames = torch.stack(sampled_frames, dim=0) # (B, T, C, H, W) + + anchor_stats = { + "anchor_mean": float(torch.tensor(anchor_positions).float().mean()), + "anchor_std": float(torch.tensor(anchor_positions).float().std()), + "oob_fraction": float(oob_count) / B, + "padded_fraction": float(padded_count) / B, + "resampled_count": resampled_count, + "fallback_used": False + } + + return frames, anchor_stats + + def _calculate_anchor_based_progress(self, batch: dict[str, Tensor], anchor_stats: dict, T_eff: int) -> Tensor: + """Calculate progress labels based on known random anchors (more efficient).""" + episode_indices, _ = self._extract_episode_and_frame_indices(batch) + if episode_indices is None: + raise ValueError("Need episode_indices for anchor-based progress calculation") + + device = next(self.parameters()).device + B = len(episode_indices) + delta_indices = self.config.observation_delta_indices + + # Build progress for each anchor position in the batch + all_progress = [] + + for i, delta in enumerate(delta_indices[:T_eff]): # Only compute for frames we'll actually use + frame_progress = [] + for b_idx in range(B): + ep_idx = episode_indices[b_idx].item() + + # Get episode length + ep_start = self.episode_data_index["from"][ep_idx].item() + ep_end = self.episode_data_index["to"][ep_idx].item() + ep_length = ep_end - ep_start + + # The anchor was chosen during window sampling + # For anchor-based progress, we use window-relative progress to center around 0.5 + # This is more stable and matches ReWiND's simple approach + window_position = i # Position in window [0, T_eff-1] + progress = window_position / max(1, T_eff - 1) # 0 to 1 across window + + frame_progress.append(progress) + + all_progress.append( + torch.tensor(frame_progress, device=device, dtype=torch.float32) + ) + + return torch.stack(all_progress, dim=1) # (B, T_eff) + + def _calculate_episode_progress(self, batch: dict[str, Tensor], episode_indices: Tensor, + frame_indices: Tensor, T_eff: int, idx: Tensor) -> Tensor: + """Calculate progress labels using episode-relative positions (legacy fallback).""" + device = next(self.parameters()).device + B = len(episode_indices) + delta_indices = self.config.observation_delta_indices + + # Calculate progress for each frame in the temporal window + all_progress = [] + + # DEBUG: Log indexing details for first sample occasionally + debug_indexing = torch.rand(1).item() < 0.05 # 5% chance + if debug_indexing: + print(f"\n=== EPISODE PROGRESS DEBUG ===") + print(f"Delta indices: {delta_indices}") + print(f"Batch size: {B}, T_eff: {T_eff}") + + # Check if batch samples have diverse frame indices + unique_frames = torch.unique(frame_indices).tolist() + unique_episodes = torch.unique(episode_indices).tolist() + print(f"Unique frame indices in batch: {len(unique_frames)} values") + print(f"Unique episode indices in batch: {len(unique_episodes)} values") + + if len(unique_frames) == 1: + print("🚨 RED FLAG: All samples have IDENTICAL frame index!") + + for i, delta in enumerate(delta_indices): + # For each sample, calculate the progress of the frame at delta offset + frame_progress = [] + for b_idx in range(B): + ep_idx = episode_indices[b_idx].item() + frame_idx = frame_indices[b_idx].item() + + # Calculate the actual frame index with delta + target_frame_idx = frame_idx + delta + + # Get episode boundaries + ep_start = self.episode_data_index["from"][ep_idx].item() + ep_end = self.episode_data_index["to"][ep_idx].item() + ep_length = ep_end - ep_start + + # Calculate progress with proper boundary handling + if target_frame_idx < 0: + prog = target_frame_idx / max(1, ep_length - 1) + elif target_frame_idx >= ep_length: + prog = target_frame_idx / max(1, ep_length - 1) + else: + prog = target_frame_idx / max(1, ep_length - 1) + + # Clip to reasonable bounds and clamp to [0,1] as recommended + prog = max(0.0, min(1.0, prog)) + frame_progress.append(prog) + + all_progress.append( + torch.tensor(frame_progress, device=device, dtype=torch.float32) + ) + + if debug_indexing: + print("=" * 30) + + # Stack to get (B, T) tensor where T is the temporal sequence length + target = torch.stack(all_progress, dim=1) # (B, max_seq_len) + + # Apply stride/dropout indexing to match the processed frames + return target[:, idx] + def _load_episode_index_from_jsonl(self, path: str) -> dict[str, Tensor]: import json lengths: list[int] = [] @@ -984,17 +1182,16 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None return frames - - - -def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None) -> tuple[Tensor, Tensor]: - """Apply video rewinding augmentation as described in ReWiND paper. - - Each video in the batch has an independent chance of being rewound. +def apply_video_rewind_fixed(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None) -> tuple[Tensor, Tensor]: + """Apply video rewinding augmentation WITHOUT constant-value padding (FIXED version). + + This version ensures the rewound sequence is exactly T frames without flat plateaus + that drag down the target mean. Args: frames: Tensor of shape (B, T, C, H, W) rewind_prob: Probability of applying rewind augmentation to each video + last3_prob: Probability of limiting rewind to last 3 frames Returns: Augmented frames and corresponding progress labels @@ -1002,8 +1199,8 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo B, T, C, H, W = frames.shape device = frames.device - # Create default progress labels (linearly increasing from 0 to 1 with denominator T-1) - # torch.linspace(0, 1, T) already yields j/(T-1) at step j + # Create default progress labels using window-relative progress (0 to 1) + # This centers the mean around 0.5 and removes episode-length dependence default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1) # Apply rewind augmentation to each sample in batch independently @@ -1020,50 +1217,68 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo augmented_progress.append(default_progress[b]) continue - # Apply rewinding to this video - # Split point i: between frame 2 and T-1 (upper bound exclusive in torch.randint) - i = torch.randint(2, T, (1,)).item() + # Apply rewinding - but ensure we get exactly T frames + max_attempts = 10 # Limit resampling attempts + success = False + + for attempt in range(max_attempts): + # Split point i: between frame 2 and T-1 + i = torch.randint(2, T, (1,)).item() - # Rewind length k: between 1 and i-1 frames - if last3_prob is not None and torch.rand(1).item() < last3_prob and i >= 3: - k = min(3, i - 1) + # Rewind length k: between 1 and i-1 frames + if last3_prob is not None and torch.rand(1).item() < last3_prob and i >= 3: + k = min(3, i - 1) + else: + k = torch.randint(1, i, (1,)).item() + k = min(k, i - 1) + + # Create rewound sequence: frames[0:i] + reversed frames[i-k:i] + forward_length = i + reverse_length = k + total_length = forward_length + reverse_length + + # Check if we can make exactly T frames + if total_length == T: + # Perfect fit! + forward_frames = frames[b, :i] + reverse_frames = frames[b, max(0, i - k):i].flip(dims=[0]) + rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0) + + # Create corresponding progress labels without constant padding + denom = max(T - 1, 1) + forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device) + reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k) / denom), k, device=device) + rewound_progress = torch.cat([forward_progress, reverse_progress]) + + success = True + break + elif total_length < T: + # Too short - try to extend by adjusting k + needed = T - total_length + if i + needed <= T: # Can we extend k? + k_extended = k + needed + if i - k_extended >= 0: + forward_frames = frames[b, :i] + reverse_frames = frames[b, max(0, i - k_extended):i].flip(dims=[0]) + rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0) + + if rewound_seq.shape[0] == T: + # Create progress labels + denom = max(T - 1, 1) + forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device) + reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k_extended) / denom), k_extended, device=device) + rewound_progress = torch.cat([forward_progress, reverse_progress]) + + success = True + break + # If too long or can't fix, try again with different i,k + + if success: + augmented_frames.append(rewound_seq) + augmented_progress.append(rewound_progress) else: - k = torch.randint(1, i, (1,)).item() - k = min(k, i - 1) - - # Create rewound sequence: o1...oi, oi-1, ..., oi-k - forward_frames = frames[b, :i] # Frames up to split point - reverse_frames = frames[b, max(0, i - k) : i].flip(dims=[0]) # Reversed frames - - # Concatenate forward and reverse parts - rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0) - - # Pad by repeating the last real frame if needed to maintain fixed length T - if rewound_seq.shape[0] < T: - last_frame = rewound_seq[-1:] - pad_frames = last_frame.expand(T - rewound_seq.shape[0], C, H, W) - rewound_seq = torch.cat([rewound_seq, pad_frames], dim=0) - elif rewound_seq.shape[0] > T: - rewound_seq = rewound_seq[:T] - - # Create corresponding progress labels - denom = max(T - 1, 1) - # Forward part: increasing progress using denominator T-1 - forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device) - # Reverse part: decreasing progress starting from (i-1)/(T-1) - reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k) / denom), k, device=device) - - rewound_progress = torch.cat([forward_progress, reverse_progress]) - - # Pad progress by repeating the last real progress if needed - if rewound_progress.shape[0] < T: - last_val = rewound_progress[-1] - pad_vals = last_val.expand(T - rewound_progress.shape[0]) - rewound_progress = torch.cat([rewound_progress, pad_vals]) - elif rewound_progress.shape[0] > T: - rewound_progress = rewound_progress[:T] - - augmented_frames.append(rewound_seq) - augmented_progress.append(rewound_progress) + # Fallback: use original sequence if we can't create a good rewind + augmented_frames.append(frames[b]) + augmented_progress.append(default_progress[b]) return torch.stack(augmented_frames), torch.stack(augmented_progress) \ No newline at end of file diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 3067580af..e5b6c33de 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -248,6 +248,15 @@ def train(cfg: TrainPipelineConfig): drop_n_last_frames=cfg.policy.drop_n_last_frames, shuffle=True, ) + elif cfg.policy.type == "rlearn": + # For RLearN, drop first 15 frames to avoid padding issues with temporal windows + shuffle = False + sampler = EpisodeAwareSampler( + dataset.episode_data_index, + drop_n_first_frames=15, # Skip frames that would need padding + drop_n_last_frames=0, + shuffle=True, + ) else: shuffle = True sampler = None