This commit is contained in:
Pepijn
2025-08-31 15:52:15 +02:00
parent 1e1b010257
commit 221e5862ea

View File

@@ -464,10 +464,10 @@ class RLearNPolicy(PreTrainedPolicy):
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
loss_dict: dict[str, float] = {}
# Generate progress targets
# Generate progress targets that span full 0-1 range
if self.training and augmented_target is not None:
# Video rewind already generated targets
target = augmented_target[:, idx]
# Always create targets that span 0-1 across T_eff frames for better distribution
target = torch.linspace(0, 1, T_eff, device=device).unsqueeze(0).expand(B, -1)
else:
# Use anchor-based window-relative progress
if anchor_stats.get("fallback_used", False):
@@ -501,36 +501,60 @@ class RLearNPolicy(PreTrainedPolicy):
# For logging, compute sigmoid predictions
predicted_rewards = torch.sigmoid(raw_logits)
# Mismatched video-language pairs loss (always logit regression)
# Mismatched video-language pairs loss (only when languages actually differ)
L_mismatch = torch.zeros((), device=device)
if self.training and B > 1 and torch.rand(1, device=device).item() < self.config.mismatch_prob:
# Shuffle language within batch
# Create actual mismatches - ensure shuffled language != original language
shuffled_indices = torch.randperm(B, device=device)
shuffled_commands = [commands[i] for i in shuffled_indices]
# Re-encode with mismatched language
lang_embeds_mm, mask_mm = self._encode_language_tokens(shuffled_commands, device)
lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
# Find which samples actually got different languages
mismatch_mask = []
shuffled_commands = []
for i in range(B):
shuffled_idx = shuffled_indices[i].item()
original_cmd = commands[i]
shuffled_cmd = commands[shuffled_idx]
# Only count as mismatch if languages are actually different
is_mismatch = original_cmd != shuffled_cmd
mismatch_mask.append(is_mismatch)
shuffled_commands.append(shuffled_cmd)
# Pack and forward
tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d')
mask_mm = F.pad(mask_mm, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
attended_mm = self.decoder(tokens_mm, mask=mask_mm)
_, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d')
# Process mismatch frames with single MLP
mismatch_tokens = self.frame_mlp(attended_video_mm) # (B, T, D)
mismatch_embeds = self.mlp_predictor(mismatch_tokens)
# Mismatched pairs should predict near-zero progress (logit mode)
normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds)
mismatch_raw_logits = self.reward_head(normalized_mismatch_embeds).squeeze(-1)
# Target logit corresponding to sigmoid ≈ 0
eps = self.config.logit_eps
zeros_target_logits = torch.logit(torch.full_like(target_expanded[:, :T_eff], eps))
L_mismatch = F.mse_loss(mismatch_raw_logits, zeros_target_logits, reduction='mean')
# Only apply mismatch loss if we have actual mismatches
if any(mismatch_mask):
# Re-encode with mismatched language
lang_embeds_mm, mask_mm = self._encode_language_tokens(shuffled_commands, device)
lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
# Pack and forward
tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d')
mask_mm = F.pad(mask_mm, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
attended_mm = self.decoder(tokens_mm, mask=mask_mm)
_, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d')
# Process mismatch frames with single MLP
mismatch_tokens = self.frame_mlp(attended_video_mm) # (B, T, D)
mismatch_embeds = self.mlp_predictor(mismatch_tokens)
# Predict near-zero progress for mismatched pairs
normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds)
mismatch_raw_logits = self.reward_head(normalized_mismatch_embeds).squeeze(-1)
# Create mask tensor for loss calculation
mismatch_tensor = torch.tensor(mismatch_mask, device=device, dtype=torch.bool)
if mismatch_tensor.any():
# Target logit corresponding to sigmoid ≈ 0
eps = self.config.logit_eps
zeros_target_logits = torch.logit(torch.full_like(target_expanded[:, :T_eff], eps))
# Only compute loss for samples that are actually mismatched
mismatch_loss_per_sample = F.mse_loss(
mismatch_raw_logits, zeros_target_logits, reduction='none'
).mean(dim=1) # (B,)
# Apply mask and average only over true mismatches
L_mismatch = mismatch_loss_per_sample[mismatch_tensor].mean()
# Total loss
total_loss = loss + L_mismatch
@@ -545,9 +569,11 @@ class RLearNPolicy(PreTrainedPolicy):
print(f"\n=== LOGIT REGRESSION DEBUG ===")
print(f"Target: min={target_expanded.min():.3f}, max={target_expanded.max():.3f}, mean={target_expanded.mean():.3f}")
has_high_targets = (target_expanded > 0.8).any().item()
print(f"✓ Has targets >0.8: {has_high_targets} | T_eff: {T_eff}")
print(f"Logits: min={raw_logits.min():.3f}, max={raw_logits.max():.3f}, mean={raw_logits.mean():.3f}")
print(f"Preds: min={predicted_rewards.min():.3f}, max={predicted_rewards.max():.3f}, mean={predicted_rewards.mean():.3f}")
print(f"Sample {sample_idx}: targets={sample_targets[:8]} preds={sample_preds[:8]}")
print(f"Sample {sample_idx} (T_eff={T_eff}): target_range=[{sample_targets.min():.3f}, {sample_targets.max():.3f}] pred_range=[{sample_preds.min():.3f}, {sample_preds.max():.3f}]")
print(f"Loss: {loss:.6f}")
print("=" * 40)
@@ -576,6 +602,8 @@ class RLearNPolicy(PreTrainedPolicy):
"anchor_std": float(anchor_stats.get('anchor_std', 0.0)),
"oob_fraction": float(anchor_stats.get('oob_fraction', 0.0)),
"padded_fraction": float(anchor_stats.get('padded_fraction', 0.0)),
# Mismatch loss statistics
"mismatch_applied": float(L_mismatch.item() > 0),
# Timing information
"timing_vision_ms": float(vision_time * 1000),
"timing_language_ms": float(lang_time * 1000),
@@ -806,11 +834,14 @@ class RLearNPolicy(PreTrainedPolicy):
return frames, anchor_stats
def _calculate_anchor_based_progress(self, T_eff: int) -> Tensor:
    """Generate window-relative progress (0 to 1 across actual frames used).

    Args:
        T_eff: Number of frames actually used in the temporal window.

    Returns:
        Tensor of shape (1, T_eff), placed on the device of the module's
        first parameter, intended to broadcast against (B, T_eff).
    """
    device = next(self.parameters()).device

    if T_eff == 1:
        # linspace(0, 1, 1) would yield 0.0; a lone frame gets the midpoint instead.
        progress = torch.tensor([0.5], device=device)
    else:
        # Endpoints are inclusive, so the labels reach exactly 1.0 on the last frame.
        progress = torch.linspace(0, 1, T_eff, device=device)

    return progress.unsqueeze(0)  # (1, T_eff) - will broadcast to (B, T_eff)
@@ -939,8 +970,8 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
B, T, C, H, W = frames.shape
device = frames.device
# Create default progress labels using window-relative progress (0 to 1)
# This centers the mean around 0.5 and removes episode-length dependence
# Create default progress labels - will be properly scaled after stride/dropout
# Use frame indices that will give 0-1 range after subsampling
default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1)
# Apply rewind augmentation to each sample in batch independently