add vision feature debug

This commit is contained in:
Pepijn
2025-08-31 18:38:50 +02:00
parent 5c1d930a34
commit c51d40ad56

View File

@@ -395,7 +395,61 @@ class RLearNPolicy(PreTrainedPolicy):
vision_outputs = self.vision_model.vision_model(pixel_values=pixel_values)
cls_tokens = vision_outputs.last_hidden_state[:, 0]
return rearrange(cls_tokens, '(b t) d -> b t d', b=B, t=T)
# Reshape to (B, T, D) for analysis
vision_features = rearrange(cls_tokens, '(b t) d -> b t d', b=B, t=T)
# DEBUG: Analyze vision feature variability
if self.training and torch.rand(1).item() < 0.05: # 5% of training steps
with torch.no_grad():
print(f"\n🔍 VISION FEATURE DEBUG (B={B}, T={T}):")
# Check feature statistics
feature_mean = vision_features.mean().item()
feature_std = vision_features.std().item()
print(f"Feature stats: mean={feature_mean:.4f}, std={feature_std:.4f}")
# Check temporal variance for each sample
for b_idx in range(min(B, 2)): # Debug first 2 samples
sample_features = vision_features[b_idx] # (T, D)
# Variance across time dimension
temporal_variance = sample_features.var(dim=0).mean().item()
temporal_std = sample_features.std(dim=0).mean().item()
print(f"Sample {b_idx} temporal variance: {temporal_variance:.6f} (std: {temporal_std:.6f})")
# Frame-to-frame differences
if T > 1:
frame_diffs = torch.norm(sample_features[1:] - sample_features[:-1], dim=-1)
avg_frame_diff = frame_diffs.mean().item()
max_frame_diff = frame_diffs.max().item()
min_frame_diff = frame_diffs.min().item()
print(f"Sample {b_idx} frame differences: avg={avg_frame_diff:.6f}, "
f"max={max_frame_diff:.6f}, min={min_frame_diff:.6f}")
# Check if features are nearly identical
if avg_frame_diff < 0.001:
print(f" ⚠️ FEATURES BARELY CHANGING! Avg diff: {avg_frame_diff:.8f}")
elif avg_frame_diff < 0.01:
print(f" ⚠️ Features changing slowly. Avg diff: {avg_frame_diff:.6f}")
else:
print(f" ✓ Features changing normally. Avg diff: {avg_frame_diff:.6f}")
# Overall batch statistics
if B > 1 and T > 1:
all_diffs = torch.norm(
vision_features[:, 1:, :] - vision_features[:, :-1, :],
dim=-1
).flatten()
print(f"Batch-wide frame differences: mean={all_diffs.mean():.6f}, "
f"std={all_diffs.std():.6f}")
# Check percentage of very small differences
small_diffs = (all_diffs < 0.001).float().mean().item() * 100
print(f"Percentage of tiny differences (<0.001): {small_diffs:.1f}%")
print("=" * 50)
return vision_features
def _mask_from_lens(self, lens: Tensor) -> Tensor:
"""Create mask from sequence lengths."""