From 9204a8bccdf4f628378fbe91c9d9148c8af12b80 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sun, 31 Aug 2025 19:06:30 +0200 Subject: [PATCH] debug same frame --- .../policies/rlearn/modeling_rlearn.py | 74 +++++++++++++++++-- 1 file changed, 69 insertions(+), 5 deletions(-) diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 2e5ad50ce..1c1bdfb66 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -389,10 +389,10 @@ class RLearNPolicy(PreTrainedPolicy): inputs = {k: v.to(device) for k, v in inputs.items()} # Process in batch through DINOv3 model - # Use inference mode for better performance when possible - context_manager = torch.inference_mode() if not self.training else nullcontext() - with context_manager: - vision_outputs = self.vision_model(**inputs) + # DEBUGGING: Disable inference mode to check if it's causing caching issues + # context_manager = torch.inference_mode() if not self.training else nullcontext() + # with context_manager: + vision_outputs = self.vision_model(**inputs) # Use pooler_output from DINOv3 (better than CLS token) if hasattr(vision_outputs, 'pooler_output') and vision_outputs.pooler_output is not None: @@ -405,10 +405,53 @@ class RLearNPolicy(PreTrainedPolicy): vision_features = rearrange(vision_features_flat, '(b t) d -> b t d', b=B, t=T) # DEBUG: Analyze vision feature variability - if self.training and torch.rand(1).item() < 0.05: # 5% of training steps + if self.training and torch.rand(1).item() < 0.1: # 10% of training steps for more frequent debugging with torch.no_grad(): print(f"\n🔍 DINOv3 VISION FEATURE DEBUG (B={B}, T={T}):") + # CRITICAL: Check if input frames are actually different + print(f"Raw frame tensor stats: mean={frames.mean():.6f}, std={frames.std():.6f}") + + # Check frame-to-frame differences in raw input + if T > 1: + raw_frame_diffs = torch.norm( + frames[:, 1:, :, :, :] - frames[:, :-1, :, :, :], + dim=(2, 3, 4) # Across C, H, W + ).mean() + print(f"Raw input frame differences: {raw_frame_diffs:.6f}") + + if raw_frame_diffs < 0.001: + print(f" ⚠️ INPUT FRAMES ARE NEARLY IDENTICAL! Diff: {raw_frame_diffs:.8f}") + else: + print(f" ✓ Input frames are different. Diff: {raw_frame_diffs:.6f}") + + # Check processed pixel values + first_sample_pixels = inputs['pixel_values'][:T] # First sample's pixels + if T > 1: + pixel_diffs = torch.norm( + first_sample_pixels[1:] - first_sample_pixels[:-1], + dim=(1, 2, 3) # Across C, H, W + ).mean() + print(f"Processed pixel_values differences: {pixel_diffs:.6f}") + + if pixel_diffs < 0.001: + print(f" ⚠️ PROCESSED PIXELS ARE NEARLY IDENTICAL! Diff: {pixel_diffs:.8f}") + else: + print(f" ✓ Processed pixels are different. Diff: {pixel_diffs:.6f}") + + # Check if all samples in batch have same first frame + if B > 1: + batch_first_frame_diff = torch.norm( + inputs['pixel_values'][::T] - inputs['pixel_values'][0].unsqueeze(0), + dim=(1, 2, 3) + ).mean() + print(f"Batch first-frame differences: {batch_first_frame_diff:.6f}") + + if batch_first_frame_diff < 0.001: + print(f" ⚠️ ALL BATCH SAMPLES HAVE SAME FIRST FRAME! Diff: {batch_first_frame_diff:.8f}") + else: + print(f" ✓ Batch samples have different first frames. Diff: {batch_first_frame_diff:.6f}") + # Check feature statistics feature_mean = vision_features.mean().item() feature_std = vision_features.std().item() @@ -988,6 +1031,27 @@ class RLearNPolicy(PreTrainedPolicy): frame_tensors = [raw_frames[b_idx, idx] for idx in window_indices] sampled_frames.append(torch.stack(frame_tensors)) window_frame_indices.append(frame_indices_for_progress) + + # DEBUG: Check if stride sampling is producing different frames + if torch.rand(1).item() < 0.1 and b_idx == 0: # Debug first sample occasionally + print(f"\n🔍 STRIDE SAMPLING DEBUG (Sample {b_idx}):") + print(f"Episode length: {ep_length}, Anchor: {anchor}") + print(f"Window indices: {window_indices[:5]}...{window_indices[-5:]}") # First and last 5 + print(f"Frame indices for progress: {frame_indices_for_progress[:5]}...{frame_indices_for_progress[-5:]}") + + # Check if window indices are all the same + unique_indices = len(set(window_indices)) + print(f"Unique window indices: {unique_indices} out of {len(window_indices)}") + if unique_indices == 1: + print(f" ⚠️ ALL WINDOW INDICES ARE THE SAME! Index: {window_indices[0]}") + + # Check frame tensor differences + if len(frame_tensors) > 1: + frame_diff = torch.norm(frame_tensors[1] - frame_tensors[0]).item() + print(f"First frame difference: {frame_diff:.6f}") + if frame_diff < 0.001: + print(f" ⚠️ CONSECUTIVE SAMPLED FRAMES ARE NEARLY IDENTICAL!") + print("-" * 50) frames = torch.stack(sampled_frames, dim=0)