mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
debug same frame
This commit is contained in:
@@ -389,10 +389,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||
|
||||
# Process in batch through DINOv3 model
|
||||
# Use inference mode for better performance when possible
|
||||
context_manager = torch.inference_mode() if not self.training else nullcontext()
|
||||
with context_manager:
|
||||
vision_outputs = self.vision_model(**inputs)
|
||||
# DEBUGGING: Disable inference mode to check if it's causing caching issues
|
||||
# context_manager = torch.inference_mode() if not self.training else nullcontext()
|
||||
# with context_manager:
|
||||
vision_outputs = self.vision_model(**inputs)
|
||||
|
||||
# Use pooler_output from DINOv3 (better than CLS token)
|
||||
if hasattr(vision_outputs, 'pooler_output') and vision_outputs.pooler_output is not None:
|
||||
@@ -405,10 +405,53 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
vision_features = rearrange(vision_features_flat, '(b t) d -> b t d', b=B, t=T)
|
||||
|
||||
# DEBUG: Analyze vision feature variability
|
||||
if self.training and torch.rand(1).item() < 0.05: # 5% of training steps
|
||||
if self.training and torch.rand(1).item() < 0.1: # 10% of training steps for more frequent debugging
|
||||
with torch.no_grad():
|
||||
print(f"\n🔍 DINOv3 VISION FEATURE DEBUG (B={B}, T={T}):")
|
||||
|
||||
# CRITICAL: Check if input frames are actually different
|
||||
print(f"Raw frame tensor stats: mean={frames.mean():.6f}, std={frames.std():.6f}")
|
||||
|
||||
# Check frame-to-frame differences in raw input
|
||||
if T > 1:
|
||||
raw_frame_diffs = torch.norm(
|
||||
frames[:, 1:, :, :, :] - frames[:, :-1, :, :, :],
|
||||
dim=(2, 3, 4) # Across C, H, W
|
||||
).mean()
|
||||
print(f"Raw input frame differences: {raw_frame_diffs:.6f}")
|
||||
|
||||
if raw_frame_diffs < 0.001:
|
||||
print(f" ⚠️ INPUT FRAMES ARE NEARLY IDENTICAL! Diff: {raw_frame_diffs:.8f}")
|
||||
else:
|
||||
print(f" ✓ Input frames are different. Diff: {raw_frame_diffs:.6f}")
|
||||
|
||||
# Check processed pixel values
|
||||
first_sample_pixels = inputs['pixel_values'][:T] # First sample's pixels
|
||||
if T > 1:
|
||||
pixel_diffs = torch.norm(
|
||||
first_sample_pixels[1:] - first_sample_pixels[:-1],
|
||||
dim=(1, 2, 3) # Across C, H, W
|
||||
).mean()
|
||||
print(f"Processed pixel_values differences: {pixel_diffs:.6f}")
|
||||
|
||||
if pixel_diffs < 0.001:
|
||||
print(f" ⚠️ PROCESSED PIXELS ARE NEARLY IDENTICAL! Diff: {pixel_diffs:.8f}")
|
||||
else:
|
||||
print(f" ✓ Processed pixels are different. Diff: {pixel_diffs:.6f}")
|
||||
|
||||
# Check if all samples in batch have same first frame
|
||||
if B > 1:
|
||||
batch_first_frame_diff = torch.norm(
|
||||
inputs['pixel_values'][::T] - inputs['pixel_values'][0].unsqueeze(0),
|
||||
dim=(1, 2, 3)
|
||||
).mean()
|
||||
print(f"Batch first-frame differences: {batch_first_frame_diff:.6f}")
|
||||
|
||||
if batch_first_frame_diff < 0.001:
|
||||
print(f" ⚠️ ALL BATCH SAMPLES HAVE SAME FIRST FRAME! Diff: {batch_first_frame_diff:.8f}")
|
||||
else:
|
||||
print(f" ✓ Batch samples have different first frames. Diff: {batch_first_frame_diff:.6f}")
|
||||
|
||||
# Check feature statistics
|
||||
feature_mean = vision_features.mean().item()
|
||||
feature_std = vision_features.std().item()
|
||||
@@ -988,6 +1031,27 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
frame_tensors = [raw_frames[b_idx, idx] for idx in window_indices]
|
||||
sampled_frames.append(torch.stack(frame_tensors))
|
||||
window_frame_indices.append(frame_indices_for_progress)
|
||||
|
||||
# DEBUG: Check if stride sampling is producing different frames
|
||||
if torch.rand(1).item() < 0.1 and b_idx == 0: # Debug first sample occasionally
|
||||
print(f"\n🔍 STRIDE SAMPLING DEBUG (Sample {b_idx}):")
|
||||
print(f"Episode length: {ep_length}, Anchor: {anchor}")
|
||||
print(f"Window indices: {window_indices[:5]}...{window_indices[-5:]}") # First and last 5
|
||||
print(f"Frame indices for progress: {frame_indices_for_progress[:5]}...{frame_indices_for_progress[-5:]}")
|
||||
|
||||
# Check if window indices are all the same
|
||||
unique_indices = len(set(window_indices))
|
||||
print(f"Unique window indices: {unique_indices} out of {len(window_indices)}")
|
||||
if unique_indices == 1:
|
||||
print(f" ⚠️ ALL WINDOW INDICES ARE THE SAME! Index: {window_indices[0]}")
|
||||
|
||||
# Check frame tensor differences
|
||||
if len(frame_tensors) > 1:
|
||||
frame_diff = torch.norm(frame_tensors[1] - frame_tensors[0]).item()
|
||||
print(f"First frame difference: {frame_diff:.6f}")
|
||||
if frame_diff < 0.001:
|
||||
print(f" ⚠️ CONSECUTIVE SAMPLED FRAMES ARE NEARLY IDENTICAL!")
|
||||
print("-" * 50)
|
||||
|
||||
frames = torch.stack(sampled_frames, dim=0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user