debug same frame

This commit is contained in:
Pepijn
2025-08-31 19:06:30 +02:00
parent 43eedf62e4
commit 9204a8bccd

View File

@@ -389,10 +389,10 @@ class RLearNPolicy(PreTrainedPolicy):
inputs = {k: v.to(device) for k, v in inputs.items()}
# Process in batch through DINOv3 model
# Use inference mode for better performance when possible
context_manager = torch.inference_mode() if not self.training else nullcontext()
with context_manager:
vision_outputs = self.vision_model(**inputs)
# DEBUGGING: Disable inference mode to check if it's causing caching issues
# context_manager = torch.inference_mode() if not self.training else nullcontext()
# with context_manager:
vision_outputs = self.vision_model(**inputs)
# Use pooler_output from DINOv3 (better than CLS token)
if hasattr(vision_outputs, 'pooler_output') and vision_outputs.pooler_output is not None:
@@ -405,10 +405,53 @@ class RLearNPolicy(PreTrainedPolicy):
vision_features = rearrange(vision_features_flat, '(b t) d -> b t d', b=B, t=T)
# DEBUG: Analyze vision feature variability
if self.training and torch.rand(1).item() < 0.05: # 5% of training steps
if self.training and torch.rand(1).item() < 0.1: # 10% of training steps for more frequent debugging
with torch.no_grad():
print(f"\n🔍 DINOv3 VISION FEATURE DEBUG (B={B}, T={T}):")
# CRITICAL: Check if input frames are actually different
print(f"Raw frame tensor stats: mean={frames.mean():.6f}, std={frames.std():.6f}")
# Check frame-to-frame differences in raw input
if T > 1:
raw_frame_diffs = torch.norm(
frames[:, 1:, :, :, :] - frames[:, :-1, :, :, :],
dim=(2, 3, 4) # Across C, H, W
).mean()
print(f"Raw input frame differences: {raw_frame_diffs:.6f}")
if raw_frame_diffs < 0.001:
print(f" ⚠️ INPUT FRAMES ARE NEARLY IDENTICAL! Diff: {raw_frame_diffs:.8f}")
else:
print(f" ✓ Input frames are different. Diff: {raw_frame_diffs:.6f}")
# Check processed pixel values
first_sample_pixels = inputs['pixel_values'][:T] # First sample's pixels
if T > 1:
pixel_diffs = torch.norm(
first_sample_pixels[1:] - first_sample_pixels[:-1],
dim=(1, 2, 3) # Across C, H, W
).mean()
print(f"Processed pixel_values differences: {pixel_diffs:.6f}")
if pixel_diffs < 0.001:
print(f" ⚠️ PROCESSED PIXELS ARE NEARLY IDENTICAL! Diff: {pixel_diffs:.8f}")
else:
print(f" ✓ Processed pixels are different. Diff: {pixel_diffs:.6f}")
# Check if all samples in batch have same first frame
if B > 1:
batch_first_frame_diff = torch.norm(
inputs['pixel_values'][::T] - inputs['pixel_values'][0].unsqueeze(0),
dim=(1, 2, 3)
).mean()
print(f"Batch first-frame differences: {batch_first_frame_diff:.6f}")
if batch_first_frame_diff < 0.001:
print(f" ⚠️ ALL BATCH SAMPLES HAVE SAME FIRST FRAME! Diff: {batch_first_frame_diff:.8f}")
else:
print(f" ✓ Batch samples have different first frames. Diff: {batch_first_frame_diff:.6f}")
# Check feature statistics
feature_mean = vision_features.mean().item()
feature_std = vision_features.std().item()
@@ -988,6 +1031,27 @@ class RLearNPolicy(PreTrainedPolicy):
frame_tensors = [raw_frames[b_idx, idx] for idx in window_indices]
sampled_frames.append(torch.stack(frame_tensors))
window_frame_indices.append(frame_indices_for_progress)
# DEBUG: Check if stride sampling is producing different frames
if torch.rand(1).item() < 0.1 and b_idx == 0: # Debug first sample occasionally
print(f"\n🔍 STRIDE SAMPLING DEBUG (Sample {b_idx}):")
print(f"Episode length: {ep_length}, Anchor: {anchor}")
print(f"Window indices: {window_indices[:5]}...{window_indices[-5:]}") # First and last 5
print(f"Frame indices for progress: {frame_indices_for_progress[:5]}...{frame_indices_for_progress[-5:]}")
# Check if window indices are all the same
unique_indices = len(set(window_indices))
print(f"Unique window indices: {unique_indices} out of {len(window_indices)}")
if unique_indices == 1:
print(f" ⚠️ ALL WINDOW INDICES ARE THE SAME! Index: {window_indices[0]}")
# Check frame tensor differences
if len(frame_tensors) > 1:
frame_diff = torch.norm(frame_tensors[1] - frame_tensors[0]).item()
print(f"First frame difference: {frame_diff:.6f}")
if frame_diff < 0.001:
print(f" ⚠️ CONSECUTIVE SAMPLED FRAMES ARE NEARLY IDENTICAL!")
print("-" * 50)
frames = torch.stack(sampled_frames, dim=0)