This commit is contained in:
Pepijn
2025-08-31 21:38:46 +02:00
parent d8c875e069
commit d51bbe9492

View File

@@ -314,13 +314,15 @@ class RLearNPolicy(PreTrainedPolicy):
if flat.max() > 1.0:
flat = flat / 255.0
# DINOv3 expects images in [0, 1] range, RGB format
# Convert tensor to list of PIL-like arrays for processor
flat_numpy = flat.permute(0, 2, 3, 1).cpu().numpy() # (BT, H, W, C)
# Prepare inputs for processor: feed uint8 to avoid double-rescale; let processor rescale/normalize
flat_clamped = flat.clamp(0.0, 1.0)
flat_uint8 = (flat_clamped * 255.0).round().to(torch.uint8) # (BT, C, H, W)
flat_numpy = flat_uint8.permute(0, 2, 3, 1).cpu().numpy() # (BT, H, W, C) uint8
images_list = [flat_numpy[i] for i in range(B * T)]
# Process through DINOv3 processor and model
inputs = self.vision_processor(images=images_list, return_tensors="pt")
# Disable center-crop to preserve motion near borders; allow default rescale/normalize
inputs = self.vision_processor(images=images_list, return_tensors="pt", do_center_crop=False)
inputs = {k: v.to(device) for k, v in inputs.items()}
# Process in batch through DINOv3 model
@@ -411,6 +413,27 @@ class RLearNPolicy(PreTrainedPolicy):
feature_mean = vision_features.mean().item()
feature_std = vision_features.std().item()
print(f"Feature stats: mean={feature_mean:.4f}, std={feature_std:.4f}")
# Extra DIAGNOSTIC: CLS vs patch mean/max deltas for one sample, two far-apart frames
try:
if 'last_hidden_state' in vision_outputs.__dict__ and T >= 2:
# Recover CLS tokens
cls_flat = tokens[:, 0, :] # (BT, D)
cls = rearrange(cls_flat, '(b t) d -> b t d', b=B, t=T)
b0 = 0
f0, f1 = 0, T - 1
# L2 between CLS at two frames
cls_l2 = (cls[b0, f1] - cls[b0, f0]).pow(2).sum().sqrt().item()
# Patch mean L2
pm_f0 = patch_features[b0, f0].mean(dim=0)
pm_f1 = patch_features[b0, f1].mean(dim=0)
pm_l2 = (pm_f1 - pm_f0).pow(2).sum().sqrt().item()
# Max over patches L2
per_patch_l2 = (patch_features[b0, f1] - patch_features[b0, f0]).pow(2).sum(dim=1).sqrt()
max_p_l2 = per_patch_l2.max().item()
print(f"CLS ΔL2: {cls_l2:.6f} | mean(patches) ΔL2: {pm_l2:.6f} | max(patch) ΔL2: {max_p_l2:.6f}")
except Exception as _:
pass
# Check temporal variance for each sample
for b_idx in range(min(B, 2)): # Debug first 2 samples
@@ -631,7 +654,7 @@ class RLearNPolicy(PreTrainedPolicy):
L_mismatch = torch.zeros((), device=device)
if self.training and B > 1 and torch.rand(1, device=device).item() < self.config.mismatch_prob:
# Create actual mismatches - ensure shuffled language != original language
shuffled_indices = torch.randperm(B, device=device)
shuffled_indices = torch.randperm(B, device=device)
# Find which samples actually got different languages
mismatch_mask = []
@@ -731,7 +754,7 @@ class RLearNPolicy(PreTrainedPolicy):
pred_std = sample_preds.std()
target_std = sample_targets.std()
print(f" Variation - Target std: {target_std:.4f} | Pred std: {pred_std:.4f}")
else:
else:
# For longer sequences, show first 8 and last 8
print(f" Targets: {sample_targets[:8]} ... {sample_targets[-8:]}")
print(f" Preds: {sample_preds[:8]} ... {sample_preds[-8:]}")
@@ -1238,7 +1261,7 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
default_progress = torch.stack(default_progress)
else:
# Fallback to window-relative progress
default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1)
default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1)
# Apply rewind augmentation to each sample in batch independently
augmented_frames = []
@@ -1260,14 +1283,14 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
for attempt in range(max_attempts):
# Split point i: between frame 2 and T-1
i = torch.randint(2, T, (1,)).item()
i = torch.randint(2, T, (1,)).item()
# Rewind length k: between 1 and i-1 frames
if last3_prob is not None and torch.rand(1).item() < last3_prob and i >= 3:
k = min(3, i - 1)
else:
k = torch.randint(1, i, (1,)).item()
k = min(k, i - 1)
# Rewind length k: between 1 and i-1 frames
if last3_prob is not None and torch.rand(1).item() < last3_prob and i >= 3:
k = min(3, i - 1)
else:
k = torch.randint(1, i, (1,)).item()
k = min(k, i - 1)
# Create rewound sequence: frames[0:i] + reversed frames[i-k:i]
forward_length = i
@@ -1279,8 +1302,8 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
# Perfect fit!
forward_frames = frames[b, :i]
reverse_frames = frames[b, max(0, i - k):i].flip(dims=[0])
rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0)
rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0)
# Create corresponding progress labels based on episode-relative positions
if window_frame_indices and episode_lengths:
# Use episode-relative progress for rewind
@@ -1298,9 +1321,9 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
rewound_progress = torch.cat([forward_progress, reverse_progress])
else:
# Fallback to window-relative progress
denom = max(T - 1, 1)
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k) / denom), k, device=device)
denom = max(T - 1, 1)
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k) / denom), k, device=device)
rewound_progress = torch.cat([forward_progress, reverse_progress])
success = True
@@ -1335,15 +1358,15 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
denom = max(T - 1, 1)
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k_extended) / denom), k_extended, device=device)
rewound_progress = torch.cat([forward_progress, reverse_progress])
rewound_progress = torch.cat([forward_progress, reverse_progress])
success = True
break
# If too long or can't fix, try again with different i,k
if success:
augmented_frames.append(rewound_seq)
augmented_progress.append(rewound_progress)
augmented_frames.append(rewound_seq)
augmented_progress.append(rewound_progress)
else:
# Fallback: use original sequence if we can't create a good rewind
augmented_frames.append(frames[b])