mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 21:01:26 +00:00
fix
This commit is contained in:
@@ -314,13 +314,15 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
if flat.max() > 1.0:
|
||||
flat = flat / 255.0
|
||||
|
||||
# DINOv3 expects images in [0, 1] range, RGB format
|
||||
# Convert tensor to list of PIL-like arrays for processor
|
||||
flat_numpy = flat.permute(0, 2, 3, 1).cpu().numpy() # (BT, H, W, C)
|
||||
# Prepare inputs for processor: feed uint8 to avoid double-rescale; let processor rescale/normalize
|
||||
flat_clamped = flat.clamp(0.0, 1.0)
|
||||
flat_uint8 = (flat_clamped * 255.0).round().to(torch.uint8) # (BT, C, H, W)
|
||||
flat_numpy = flat_uint8.permute(0, 2, 3, 1).cpu().numpy() # (BT, H, W, C) uint8
|
||||
images_list = [flat_numpy[i] for i in range(B * T)]
|
||||
|
||||
|
||||
# Process through DINOv3 processor and model
|
||||
inputs = self.vision_processor(images=images_list, return_tensors="pt")
|
||||
# Disable center-crop to preserve motion near borders; allow default rescale/normalize
|
||||
inputs = self.vision_processor(images=images_list, return_tensors="pt", do_center_crop=False)
|
||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||
|
||||
# Process in batch through DINOv3 model
|
||||
@@ -411,6 +413,27 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
feature_mean = vision_features.mean().item()
|
||||
feature_std = vision_features.std().item()
|
||||
print(f"Feature stats: mean={feature_mean:.4f}, std={feature_std:.4f}")
|
||||
|
||||
# Extra DIAGNOSTIC: CLS vs patch mean/max deltas for one sample, two far-apart frames
|
||||
try:
|
||||
if 'last_hidden_state' in vision_outputs.__dict__ and T >= 2:
|
||||
# Recover CLS tokens
|
||||
cls_flat = tokens[:, 0, :] # (BT, D)
|
||||
cls = rearrange(cls_flat, '(b t) d -> b t d', b=B, t=T)
|
||||
b0 = 0
|
||||
f0, f1 = 0, T - 1
|
||||
# L2 between CLS at two frames
|
||||
cls_l2 = (cls[b0, f1] - cls[b0, f0]).pow(2).sum().sqrt().item()
|
||||
# Patch mean L2
|
||||
pm_f0 = patch_features[b0, f0].mean(dim=0)
|
||||
pm_f1 = patch_features[b0, f1].mean(dim=0)
|
||||
pm_l2 = (pm_f1 - pm_f0).pow(2).sum().sqrt().item()
|
||||
# Max over patches L2
|
||||
per_patch_l2 = (patch_features[b0, f1] - patch_features[b0, f0]).pow(2).sum(dim=1).sqrt()
|
||||
max_p_l2 = per_patch_l2.max().item()
|
||||
print(f"CLS ΔL2: {cls_l2:.6f} | mean(patches) ΔL2: {pm_l2:.6f} | max(patch) ΔL2: {max_p_l2:.6f}")
|
||||
except Exception as _:
|
||||
pass
|
||||
|
||||
# Check temporal variance for each sample
|
||||
for b_idx in range(min(B, 2)): # Debug first 2 samples
|
||||
@@ -631,7 +654,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
L_mismatch = torch.zeros((), device=device)
|
||||
if self.training and B > 1 and torch.rand(1, device=device).item() < self.config.mismatch_prob:
|
||||
# Create actual mismatches - ensure shuffled language != original language
|
||||
shuffled_indices = torch.randperm(B, device=device)
|
||||
shuffled_indices = torch.randperm(B, device=device)
|
||||
|
||||
# Find which samples actually got different languages
|
||||
mismatch_mask = []
|
||||
@@ -731,7 +754,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
pred_std = sample_preds.std()
|
||||
target_std = sample_targets.std()
|
||||
print(f" Variation - Target std: {target_std:.4f} | Pred std: {pred_std:.4f}")
|
||||
else:
|
||||
else:
|
||||
# For longer sequences, show first 8 and last 8
|
||||
print(f" Targets: {sample_targets[:8]} ... {sample_targets[-8:]}")
|
||||
print(f" Preds: {sample_preds[:8]} ... {sample_preds[-8:]}")
|
||||
@@ -1238,7 +1261,7 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
|
||||
default_progress = torch.stack(default_progress)
|
||||
else:
|
||||
# Fallback to window-relative progress
|
||||
default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1)
|
||||
default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1)
|
||||
|
||||
# Apply rewind augmentation to each sample in batch independently
|
||||
augmented_frames = []
|
||||
@@ -1260,14 +1283,14 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
# Split point i: between frame 2 and T-1
|
||||
i = torch.randint(2, T, (1,)).item()
|
||||
i = torch.randint(2, T, (1,)).item()
|
||||
|
||||
# Rewind length k: between 1 and i-1 frames
|
||||
if last3_prob is not None and torch.rand(1).item() < last3_prob and i >= 3:
|
||||
k = min(3, i - 1)
|
||||
else:
|
||||
k = torch.randint(1, i, (1,)).item()
|
||||
k = min(k, i - 1)
|
||||
# Rewind length k: between 1 and i-1 frames
|
||||
if last3_prob is not None and torch.rand(1).item() < last3_prob and i >= 3:
|
||||
k = min(3, i - 1)
|
||||
else:
|
||||
k = torch.randint(1, i, (1,)).item()
|
||||
k = min(k, i - 1)
|
||||
|
||||
# Create rewound sequence: frames[0:i] + reversed frames[i-k:i]
|
||||
forward_length = i
|
||||
@@ -1279,8 +1302,8 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
|
||||
# Perfect fit!
|
||||
forward_frames = frames[b, :i]
|
||||
reverse_frames = frames[b, max(0, i - k):i].flip(dims=[0])
|
||||
rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0)
|
||||
|
||||
rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0)
|
||||
|
||||
# Create corresponding progress labels based on episode-relative positions
|
||||
if window_frame_indices and episode_lengths:
|
||||
# Use episode-relative progress for rewind
|
||||
@@ -1298,9 +1321,9 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
|
||||
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
||||
else:
|
||||
# Fallback to window-relative progress
|
||||
denom = max(T - 1, 1)
|
||||
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
|
||||
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k) / denom), k, device=device)
|
||||
denom = max(T - 1, 1)
|
||||
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
|
||||
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k) / denom), k, device=device)
|
||||
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
||||
|
||||
success = True
|
||||
@@ -1335,15 +1358,15 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
|
||||
denom = max(T - 1, 1)
|
||||
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
|
||||
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k_extended) / denom), k_extended, device=device)
|
||||
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
||||
|
||||
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
||||
|
||||
success = True
|
||||
break
|
||||
# If too long or can't fix, try again with different i,k
|
||||
|
||||
if success:
|
||||
augmented_frames.append(rewound_seq)
|
||||
augmented_progress.append(rewound_progress)
|
||||
augmented_frames.append(rewound_seq)
|
||||
augmented_progress.append(rewound_progress)
|
||||
else:
|
||||
# Fallback: use original sequence if we can't create a good rewind
|
||||
augmented_frames.append(frames[b])
|
||||
|
||||
Reference in New Issue
Block a user