mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
cleanup
This commit is contained in:
@@ -464,10 +464,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
|
||||
loss_dict: dict[str, float] = {}
|
||||
|
||||
# Generate progress targets
|
||||
# Generate progress targets that span full 0-1 range
|
||||
if self.training and augmented_target is not None:
|
||||
# Video rewind already generated targets
|
||||
target = augmented_target[:, idx]
|
||||
# Always create targets that span 0-1 across T_eff frames for better distribution
|
||||
target = torch.linspace(0, 1, T_eff, device=device).unsqueeze(0).expand(B, -1)
|
||||
else:
|
||||
# Use anchor-based window-relative progress
|
||||
if anchor_stats.get("fallback_used", False):
|
||||
@@ -501,36 +501,60 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# For logging, compute sigmoid predictions
|
||||
predicted_rewards = torch.sigmoid(raw_logits)
|
||||
|
||||
# Mismatched video-language pairs loss (always logit regression)
|
||||
# Mismatched video-language pairs loss (only when languages actually differ)
|
||||
L_mismatch = torch.zeros((), device=device)
|
||||
if self.training and B > 1 and torch.rand(1, device=device).item() < self.config.mismatch_prob:
|
||||
# Shuffle language within batch
|
||||
# Create actual mismatches - ensure shuffled language != original language
|
||||
shuffled_indices = torch.randperm(B, device=device)
|
||||
shuffled_commands = [commands[i] for i in shuffled_indices]
|
||||
|
||||
# Re-encode with mismatched language
|
||||
lang_embeds_mm, mask_mm = self._encode_language_tokens(shuffled_commands, device)
|
||||
lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
|
||||
# Find which samples actually got different languages
|
||||
mismatch_mask = []
|
||||
shuffled_commands = []
|
||||
for i in range(B):
|
||||
shuffled_idx = shuffled_indices[i].item()
|
||||
original_cmd = commands[i]
|
||||
shuffled_cmd = commands[shuffled_idx]
|
||||
|
||||
# Only count as mismatch if languages are actually different
|
||||
is_mismatch = original_cmd != shuffled_cmd
|
||||
mismatch_mask.append(is_mismatch)
|
||||
shuffled_commands.append(shuffled_cmd)
|
||||
|
||||
# Pack and forward
|
||||
tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d')
|
||||
mask_mm = F.pad(mask_mm, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
|
||||
attended_mm = self.decoder(tokens_mm, mask=mask_mm)
|
||||
_, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d')
|
||||
|
||||
# Process mismatch frames with single MLP
|
||||
mismatch_tokens = self.frame_mlp(attended_video_mm) # (B, T, D)
|
||||
|
||||
mismatch_embeds = self.mlp_predictor(mismatch_tokens)
|
||||
|
||||
# Mismatched pairs should predict near-zero progress (logit mode)
|
||||
normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds)
|
||||
mismatch_raw_logits = self.reward_head(normalized_mismatch_embeds).squeeze(-1)
|
||||
|
||||
# Target logit corresponding to sigmoid ≈ 0
|
||||
eps = self.config.logit_eps
|
||||
zeros_target_logits = torch.logit(torch.full_like(target_expanded[:, :T_eff], eps))
|
||||
L_mismatch = F.mse_loss(mismatch_raw_logits, zeros_target_logits, reduction='mean')
|
||||
# Only apply mismatch loss if we have actual mismatches
|
||||
if any(mismatch_mask):
|
||||
# Re-encode with mismatched language
|
||||
lang_embeds_mm, mask_mm = self._encode_language_tokens(shuffled_commands, device)
|
||||
lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
|
||||
|
||||
# Pack and forward
|
||||
tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d')
|
||||
mask_mm = F.pad(mask_mm, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
|
||||
attended_mm = self.decoder(tokens_mm, mask=mask_mm)
|
||||
_, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d')
|
||||
|
||||
# Process mismatch frames with single MLP
|
||||
mismatch_tokens = self.frame_mlp(attended_video_mm) # (B, T, D)
|
||||
mismatch_embeds = self.mlp_predictor(mismatch_tokens)
|
||||
|
||||
# Predict near-zero progress for mismatched pairs
|
||||
normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds)
|
||||
mismatch_raw_logits = self.reward_head(normalized_mismatch_embeds).squeeze(-1)
|
||||
|
||||
# Create mask tensor for loss calculation
|
||||
mismatch_tensor = torch.tensor(mismatch_mask, device=device, dtype=torch.bool)
|
||||
|
||||
if mismatch_tensor.any():
|
||||
# Target logit corresponding to sigmoid ≈ 0
|
||||
eps = self.config.logit_eps
|
||||
zeros_target_logits = torch.logit(torch.full_like(target_expanded[:, :T_eff], eps))
|
||||
|
||||
# Only compute loss for samples that are actually mismatched
|
||||
mismatch_loss_per_sample = F.mse_loss(
|
||||
mismatch_raw_logits, zeros_target_logits, reduction='none'
|
||||
).mean(dim=1) # (B,)
|
||||
|
||||
# Apply mask and average only over true mismatches
|
||||
L_mismatch = mismatch_loss_per_sample[mismatch_tensor].mean()
|
||||
|
||||
# Total loss
|
||||
total_loss = loss + L_mismatch
|
||||
@@ -545,9 +569,11 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
print(f"\n=== LOGIT REGRESSION DEBUG ===")
|
||||
print(f"Target: min={target_expanded.min():.3f}, max={target_expanded.max():.3f}, mean={target_expanded.mean():.3f}")
|
||||
has_high_targets = (target_expanded > 0.8).any().item()
|
||||
print(f"✓ Has targets >0.8: {has_high_targets} | T_eff: {T_eff}")
|
||||
print(f"Logits: min={raw_logits.min():.3f}, max={raw_logits.max():.3f}, mean={raw_logits.mean():.3f}")
|
||||
print(f"Preds: min={predicted_rewards.min():.3f}, max={predicted_rewards.max():.3f}, mean={predicted_rewards.mean():.3f}")
|
||||
print(f"Sample {sample_idx}: targets={sample_targets[:8]} preds={sample_preds[:8]}")
|
||||
print(f"Sample {sample_idx} (T_eff={T_eff}): target_range=[{sample_targets.min():.3f}, {sample_targets.max():.3f}] pred_range=[{sample_preds.min():.3f}, {sample_preds.max():.3f}]")
|
||||
print(f"Loss: {loss:.6f}")
|
||||
print("=" * 40)
|
||||
|
||||
@@ -576,6 +602,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
"anchor_std": float(anchor_stats.get('anchor_std', 0.0)),
|
||||
"oob_fraction": float(anchor_stats.get('oob_fraction', 0.0)),
|
||||
"padded_fraction": float(anchor_stats.get('padded_fraction', 0.0)),
|
||||
# Mismatch loss statistics
|
||||
"mismatch_applied": float(L_mismatch.item() > 0),
|
||||
# Timing information
|
||||
"timing_vision_ms": float(vision_time * 1000),
|
||||
"timing_language_ms": float(lang_time * 1000),
|
||||
@@ -806,11 +834,14 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
return frames, anchor_stats
|
||||
|
||||
def _calculate_anchor_based_progress(self, T_eff: int) -> Tensor:
|
||||
"""Generate window-relative progress (0 to 1 across window)."""
|
||||
"""Generate window-relative progress (0 to 1 across actual frames used)."""
|
||||
device = next(self.parameters()).device
|
||||
# Simple window-relative progress: 0 to 1 across the temporal window
|
||||
# This centers the mean around 0.5 and is stable regardless of episode length
|
||||
progress = torch.linspace(0, 1, T_eff, device=device)
|
||||
# Create progress that spans 0 to 1 across the T_eff frames we actually use
|
||||
# This ensures we get samples at all progress levels including near 1.0
|
||||
if T_eff == 1:
|
||||
progress = torch.tensor([0.5], device=device) # Single frame gets middle progress
|
||||
else:
|
||||
progress = torch.linspace(0, 1, T_eff, device=device) # Full 0-1 range
|
||||
return progress.unsqueeze(0) # (1, T_eff) - will broadcast to (B, T_eff)
|
||||
|
||||
|
||||
@@ -939,8 +970,8 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
|
||||
B, T, C, H, W = frames.shape
|
||||
device = frames.device
|
||||
|
||||
# Create default progress labels using window-relative progress (0 to 1)
|
||||
# This centers the mean around 0.5 and removes episode-length dependence
|
||||
# Create default progress labels - will be properly scaled after stride/dropout
|
||||
# Use frame indices that will give 0-1 range after subsampling
|
||||
default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1)
|
||||
|
||||
# Apply rewind augmentation to each sample in batch independently
|
||||
|
||||
Reference in New Issue
Block a user