From 221e5862ea360edc7d3d8f940d4dd3b8383c5c3a Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sun, 31 Aug 2025 15:52:15 +0200 Subject: [PATCH] cleanup --- .../policies/rlearn/modeling_rlearn.py | 101 ++++++++++++------ 1 file changed, 66 insertions(+), 35 deletions(-) diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 16d543ae7..9e00f0cf6 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -464,10 +464,10 @@ class RLearNPolicy(PreTrainedPolicy): # IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window loss_dict: dict[str, float] = {} - # Generate progress targets + # Generate progress targets that span full 0-1 range if self.training and augmented_target is not None: - # Video rewind already generated targets - target = augmented_target[:, idx] + # Always create targets that span 0-1 across T_eff frames for better distribution + target = torch.linspace(0, 1, T_eff, device=device).unsqueeze(0).expand(B, -1) else: # Use anchor-based window-relative progress if anchor_stats.get("fallback_used", False): @@ -501,36 +501,60 @@ class RLearNPolicy(PreTrainedPolicy): # For logging, compute sigmoid predictions predicted_rewards = torch.sigmoid(raw_logits) - # Mismatched video-language pairs loss (always logit regression) + # Mismatched video-language pairs loss (only when languages actually differ) L_mismatch = torch.zeros((), device=device) if self.training and B > 1 and torch.rand(1, device=device).item() < self.config.mismatch_prob: - # Shuffle language within batch + # Create actual mismatches - ensure shuffled language != original language shuffled_indices = torch.randperm(B, device=device) - shuffled_commands = [commands[i] for i in shuffled_indices] - # Re-encode with mismatched language - lang_embeds_mm, mask_mm = self._encode_language_tokens(shuffled_commands, device) - lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm) + # Find which samples actually got different languages + mismatch_mask = [] + shuffled_commands = [] + for i in range(B): + shuffled_idx = shuffled_indices[i].item() + original_cmd = commands[i] + shuffled_cmd = commands[shuffled_idx] + + # Only count as mismatch if languages are actually different + is_mismatch = original_cmd != shuffled_cmd + mismatch_mask.append(is_mismatch) + shuffled_commands.append(shuffled_cmd) - # Pack and forward - tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d') - mask_mm = F.pad(mask_mm, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True) - attended_mm = self.decoder(tokens_mm, mask=mask_mm) - _, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d') - - # Process mismatch frames with single MLP - mismatch_tokens = self.frame_mlp(attended_video_mm) # (B, T, D) - - mismatch_embeds = self.mlp_predictor(mismatch_tokens) - - # Mismatched pairs should predict near-zero progress (logit mode) - normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds) - mismatch_raw_logits = self.reward_head(normalized_mismatch_embeds).squeeze(-1) - - # Target logit corresponding to sigmoid ≈ 0 - eps = self.config.logit_eps - zeros_target_logits = torch.logit(torch.full_like(target_expanded[:, :T_eff], eps)) - L_mismatch = F.mse_loss(mismatch_raw_logits, zeros_target_logits, reduction='mean') + # Only apply mismatch loss if we have actual mismatches + if any(mismatch_mask): + # Re-encode with mismatched language + lang_embeds_mm, mask_mm = self._encode_language_tokens(shuffled_commands, device) + lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm) + + # Pack and forward + tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d') + mask_mm = F.pad(mask_mm, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True) + attended_mm = self.decoder(tokens_mm, mask=mask_mm) + _, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d') + + # Process mismatch frames with single MLP + mismatch_tokens = self.frame_mlp(attended_video_mm) # (B, T, D) + mismatch_embeds = self.mlp_predictor(mismatch_tokens) + + # Predict near-zero progress for mismatched pairs + normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds) + mismatch_raw_logits = self.reward_head(normalized_mismatch_embeds).squeeze(-1) + + # Create mask tensor for loss calculation + mismatch_tensor = torch.tensor(mismatch_mask, device=device, dtype=torch.bool) + + if mismatch_tensor.any(): + # Target logit corresponding to sigmoid ≈ 0 + eps = self.config.logit_eps + zeros_target_logits = torch.logit(torch.full_like(target_expanded[:, :T_eff], eps)) + + # Only compute loss for samples that are actually mismatched + mismatch_loss_per_sample = F.mse_loss( + mismatch_raw_logits, zeros_target_logits, reduction='none' + ).mean(dim=1) # (B,) + + # Apply mask and average only over true mismatches + L_mismatch = mismatch_loss_per_sample[mismatch_tensor].mean() # Total loss total_loss = loss + L_mismatch @@ -545,9 +569,11 @@ class RLearNPolicy(PreTrainedPolicy): print(f"\n=== LOGIT REGRESSION DEBUG ===") print(f"Target: min={target_expanded.min():.3f}, max={target_expanded.max():.3f}, mean={target_expanded.mean():.3f}") + has_high_targets = (target_expanded > 0.8).any().item() + print(f"✓ Has targets >0.8: {has_high_targets} | T_eff: {T_eff}") print(f"Logits: min={raw_logits.min():.3f}, max={raw_logits.max():.3f}, mean={raw_logits.mean():.3f}") print(f"Preds: min={predicted_rewards.min():.3f}, max={predicted_rewards.max():.3f}, mean={predicted_rewards.mean():.3f}") - print(f"Sample {sample_idx}: targets={sample_targets[:8]} preds={sample_preds[:8]}") + print(f"Sample {sample_idx} (T_eff={T_eff}): target_range=[{sample_targets.min():.3f}, {sample_targets.max():.3f}] pred_range=[{sample_preds.min():.3f}, {sample_preds.max():.3f}]") print(f"Loss: {loss:.6f}") print("=" * 40) @@ -576,6 +602,8 @@ class RLearNPolicy(PreTrainedPolicy): "anchor_std": float(anchor_stats.get('anchor_std', 0.0)), "oob_fraction": float(anchor_stats.get('oob_fraction', 0.0)), "padded_fraction": float(anchor_stats.get('padded_fraction', 0.0)), + # Mismatch loss statistics + "mismatch_applied": float(L_mismatch.item() > 0), # Timing information "timing_vision_ms": float(vision_time * 1000), "timing_language_ms": float(lang_time * 1000), @@ -806,11 +834,14 @@ class RLearNPolicy(PreTrainedPolicy): return frames, anchor_stats def _calculate_anchor_based_progress(self, T_eff: int) -> Tensor: - """Generate window-relative progress (0 to 1 across window).""" + """Generate window-relative progress (0 to 1 across actual frames used).""" device = next(self.parameters()).device - # Simple window-relative progress: 0 to 1 across the temporal window - # This centers the mean around 0.5 and is stable regardless of episode length - progress = torch.linspace(0, 1, T_eff, device=device) + # Create progress that spans 0 to 1 across the T_eff frames we actually use + # This ensures we get samples at all progress levels including near 1.0 + if T_eff == 1: + progress = torch.tensor([0.5], device=device) # Single frame gets middle progress + else: + progress = torch.linspace(0, 1, T_eff, device=device) # Full 0-1 range return progress.unsqueeze(0) # (1, T_eff) - will broadcast to (B, T_eff) @@ -939,8 +970,8 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo B, T, C, H, W = frames.shape device = frames.device - # Create default progress labels using window-relative progress (0 to 1) - # This centers the mean around 0.5 and removes episode-length dependence + # Create default progress labels - will be properly scaled after stride/dropout + # Use frame indices that will give 0-1 range after subsampling default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1) # Apply rewind augmentation to each sample in batch independently