diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 963267b04..d6748cec9 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -314,13 +314,15 @@ class RLearNPolicy(PreTrainedPolicy): if flat.max() > 1.0: flat = flat / 255.0 - # DINOv3 expects images in [0, 1] range, RGB format - # Convert tensor to list of PIL-like arrays for processor - flat_numpy = flat.permute(0, 2, 3, 1).cpu().numpy() # (BT, H, W, C) + # Prepare inputs for processor: feed uint8 to avoid double-rescale; let processor rescale/normalize + flat_clamped = flat.clamp(0.0, 1.0) + flat_uint8 = (flat_clamped * 255.0).round().to(torch.uint8) # (BT, C, H, W) + flat_numpy = flat_uint8.permute(0, 2, 3, 1).cpu().numpy() # (BT, H, W, C) uint8 images_list = [flat_numpy[i] for i in range(B * T)] - + # Process through DINOv3 processor and model - inputs = self.vision_processor(images=images_list, return_tensors="pt") + # Disable center-crop to preserve motion near borders; allow default rescale/normalize + inputs = self.vision_processor(images=images_list, return_tensors="pt", do_center_crop=False) inputs = {k: v.to(device) for k, v in inputs.items()} # Process in batch through DINOv3 model @@ -411,6 +413,27 @@ class RLearNPolicy(PreTrainedPolicy): feature_mean = vision_features.mean().item() feature_std = vision_features.std().item() print(f"Feature stats: mean={feature_mean:.4f}, std={feature_std:.4f}") + + # Extra DIAGNOSTIC: CLS vs patch mean/max deltas for one sample, two far-apart frames + try: + if 'last_hidden_state' in vision_outputs.__dict__ and T >= 2: + # Recover CLS tokens + cls_flat = tokens[:, 0, :] # (BT, D) + cls = rearrange(cls_flat, '(b t) d -> b t d', b=B, t=T) + b0 = 0 + f0, f1 = 0, T - 1 + # L2 between CLS at two frames + cls_l2 = (cls[b0, f1] - cls[b0, f0]).pow(2).sum().sqrt().item() + # Patch mean L2 + pm_f0 = patch_features[b0, f0].mean(dim=0) + pm_f1 = patch_features[b0, f1].mean(dim=0) + pm_l2 = (pm_f1 - pm_f0).pow(2).sum().sqrt().item() + # Max over patches L2 + per_patch_l2 = (patch_features[b0, f1] - patch_features[b0, f0]).pow(2).sum(dim=1).sqrt() + max_p_l2 = per_patch_l2.max().item() + print(f"CLS ΔL2: {cls_l2:.6f} | mean(patches) ΔL2: {pm_l2:.6f} | max(patch) ΔL2: {max_p_l2:.6f}") + except Exception as _: + pass # Check temporal variance for each sample for b_idx in range(min(B, 2)): # Debug first 2 samples @@ -631,7 +654,7 @@ class RLearNPolicy(PreTrainedPolicy): L_mismatch = torch.zeros((), device=device) if self.training and B > 1 and torch.rand(1, device=device).item() < self.config.mismatch_prob: # Create actual mismatches - ensure shuffled language != original language - shuffled_indices = torch.randperm(B, device=device) + shuffled_indices = torch.randperm(B, device=device) # Find which samples actually got different languages mismatch_mask = [] @@ -731,7 +754,7 @@ class RLearNPolicy(PreTrainedPolicy): pred_std = sample_preds.std() target_std = sample_targets.std() print(f" Variation - Target std: {target_std:.4f} | Pred std: {pred_std:.4f}") - else: + else: # For longer sequences, show first 8 and last 8 print(f" Targets: {sample_targets[:8]} ... {sample_targets[-8:]}") print(f" Preds: {sample_preds[:8]} ... {sample_preds[-8:]}") @@ -1238,7 +1261,7 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo default_progress = torch.stack(default_progress) else: # Fallback to window-relative progress - default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1) + default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1) # Apply rewind augmentation to each sample in batch independently augmented_frames = [] @@ -1260,14 +1283,14 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo for attempt in range(max_attempts): # Split point i: between frame 2 and T-1 - i = torch.randint(2, T, (1,)).item() + i = torch.randint(2, T, (1,)).item() - # Rewind length k: between 1 and i-1 frames - if last3_prob is not None and torch.rand(1).item() < last3_prob and i >= 3: - k = min(3, i - 1) - else: - k = torch.randint(1, i, (1,)).item() - k = min(k, i - 1) + # Rewind length k: between 1 and i-1 frames + if last3_prob is not None and torch.rand(1).item() < last3_prob and i >= 3: + k = min(3, i - 1) + else: + k = torch.randint(1, i, (1,)).item() + k = min(k, i - 1) # Create rewound sequence: frames[0:i] + reversed frames[i-k:i] forward_length = i @@ -1279,8 +1302,8 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo # Perfect fit! forward_frames = frames[b, :i] reverse_frames = frames[b, max(0, i - k):i].flip(dims=[0]) - rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0) - + rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0) + # Create corresponding progress labels based on episode-relative positions if window_frame_indices and episode_lengths: # Use episode-relative progress for rewind @@ -1298,9 +1321,9 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo rewound_progress = torch.cat([forward_progress, reverse_progress]) else: # Fallback to window-relative progress - denom = max(T - 1, 1) - forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device) - reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k) / denom), k, device=device) + denom = max(T - 1, 1) + forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device) + reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k) / denom), k, device=device) rewound_progress = torch.cat([forward_progress, reverse_progress]) success = True @@ -1335,15 +1358,15 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo denom = max(T - 1, 1) forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device) reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k_extended) / denom), k_extended, device=device) - rewound_progress = torch.cat([forward_progress, reverse_progress]) - + rewound_progress = torch.cat([forward_progress, reverse_progress]) + success = True break # If too long or can't fix, try again with different i,k if success: - augmented_frames.append(rewound_seq) - augmented_progress.append(rewound_progress) + augmented_frames.append(rewound_seq) + augmented_progress.append(rewound_progress) else: # Fallback: use original sequence if we can't create a good rewind augmented_frames.append(frames[b])