From c51d40ad56532a65d4e6a8b31769384d45eaddc4 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sun, 31 Aug 2025 18:38:50 +0200 Subject: [PATCH] add vision feature debug --- .../policies/rlearn/modeling_rlearn.py | 56 ++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 6f5a224eb..ae4b42e3c 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -395,7 +395,61 @@ class RLearNPolicy(PreTrainedPolicy): vision_outputs = self.vision_model.vision_model(pixel_values=pixel_values) cls_tokens = vision_outputs.last_hidden_state[:, 0] - return rearrange(cls_tokens, '(b t) d -> b t d', b=B, t=T) + # Reshape to (B, T, D) for analysis + vision_features = rearrange(cls_tokens, '(b t) d -> b t d', b=B, t=T) + + # DEBUG: Analyze vision feature variability + if self.training and torch.rand(1).item() < 0.05: # 5% of training steps + with torch.no_grad(): + print(f"\n🔍 VISION FEATURE DEBUG (B={B}, T={T}):") + + # Check feature statistics + feature_mean = vision_features.mean().item() + feature_std = vision_features.std().item() + print(f"Feature stats: mean={feature_mean:.4f}, std={feature_std:.4f}") + + # Check temporal variance for each sample + for b_idx in range(min(B, 2)): # Debug first 2 samples + sample_features = vision_features[b_idx] # (T, D) + + # Variance across time dimension + temporal_variance = sample_features.var(dim=0).mean().item() + temporal_std = sample_features.std(dim=0).mean().item() + print(f"Sample {b_idx} temporal variance: {temporal_variance:.6f} (std: {temporal_std:.6f})") + + # Frame-to-frame differences + if T > 1: + frame_diffs = torch.norm(sample_features[1:] - sample_features[:-1], dim=-1) + avg_frame_diff = frame_diffs.mean().item() + max_frame_diff = frame_diffs.max().item() + min_frame_diff = frame_diffs.min().item() + print(f"Sample {b_idx} frame differences: avg={avg_frame_diff:.6f}, " + f"max={max_frame_diff:.6f}, min={min_frame_diff:.6f}") + + # Check if features are nearly identical + if avg_frame_diff < 0.001: + print(f" ⚠️ FEATURES BARELY CHANGING! Avg diff: {avg_frame_diff:.8f}") + elif avg_frame_diff < 0.01: + print(f" ⚠️ Features changing slowly. Avg diff: {avg_frame_diff:.6f}") + else: + print(f" ✓ Features changing normally. Avg diff: {avg_frame_diff:.6f}") + + # Overall batch statistics + if B > 1 and T > 1: + all_diffs = torch.norm( + vision_features[:, 1:, :] - vision_features[:, :-1, :], + dim=-1 + ).flatten() + print(f"Batch-wide frame differences: mean={all_diffs.mean():.6f}, " + f"std={all_diffs.std():.6f}") + + # Check percentage of very small differences + small_diffs = (all_diffs < 0.001).float().mean().item() * 100 + print(f"Percentage of tiny differences (<0.001): {small_diffs:.1f}%") + + print("=" * 50) + + return vision_features def _mask_from_lens(self, lens: Tensor) -> Tensor: """Create mask from sequence lengths."""