mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
fix dinov3
This commit is contained in:
@@ -393,8 +393,24 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
else:
|
||||
raise RuntimeError("DINOv3 outputs do not contain last_hidden_state or pooler_output")
|
||||
|
||||
# Reshape to (B, T, D) for analysis
|
||||
vision_features = rearrange(vision_features_flat, '(b t) d -> b t d', b=B, t=T)
|
||||
# Reshape to (B, T, D): try both flatten orders and keep the one with the larger mean distance between consecutive time steps
|
||||
try:
|
||||
cand1 = rearrange(vision_features_flat, '(b t) d -> b t d', b=B, t=T)
|
||||
cand2 = rearrange(vision_features_flat, '(t b) d -> b t d', t=T, b=B)
|
||||
# Compute mean temporal difference per sample
|
||||
def mean_time_diff(x):
    """Return the mean L2 distance between consecutive time steps of ``x``.

    Args:
        x: Tensor indexed as (batch, time, feature); the slicing below
            relies on three dimensions with time on axis 1.

    Returns:
        A 0-dim tensor: the average, over every batch item and every
        adjacent pair of time steps, of the Euclidean norm of the feature
        difference. A zero tensor on ``x.device`` when there are fewer
        than two time steps, since no difference can be formed.
    """
    # Read the sequence length from the tensor itself rather than from the
    # enclosing scope's T, so the guard always matches the input given.
    if x.shape[1] <= 1:
        return torch.tensor(0.0, device=x.device)
    step_deltas = x[:, 1:, :] - x[:, :-1, :]
    # Euclidean norm over the feature axis, then average over batch and time.
    diffs = step_deltas.pow(2).sum(dim=-1).sqrt()
    return diffs.mean()
|
||||
diff1 = mean_time_diff(cand1)
|
||||
diff2 = mean_time_diff(cand2)
|
||||
vision_features = cand1 if diff1 >= diff2 else cand2
|
||||
if self.training and torch.rand(1).item() < 0.05:
|
||||
print(f"DINO reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}")
|
||||
except Exception:
|
||||
# Fallback to default
|
||||
vision_features = rearrange(vision_features_flat, '(b t) d -> b t d', b=B, t=T)
|
||||
|
||||
# DEBUG: Analyze vision feature variability
|
||||
if self.training and torch.rand(1).item() < 0.1: # 10% of training steps for more frequent debugging
|
||||
|
||||
Reference in New Issue
Block a user