fix dinov3

This commit is contained in:
Pepijn
2025-08-31 20:21:58 +02:00
parent 79c3466f0f
commit a1a3fa435d

View File

@@ -393,8 +393,24 @@ class RLearNPolicy(PreTrainedPolicy):
else:
raise RuntimeError("DINOv3 outputs do not contain last_hidden_state or pooler_output")
# Reshape to (B, T, D) for analysis
vision_features = rearrange(vision_features_flat, '(b t) d -> b t d', b=B, t=T)
# Reshape to (B, T, D), choosing between the (b t) and (t b) flatten orders: keep whichever gives the larger mean L2 distance between consecutive time steps
try:
cand1 = rearrange(vision_features_flat, '(b t) d -> b t d', b=B, t=T)
cand2 = rearrange(vision_features_flat, '(t b) d -> b t d', t=T, b=B)
# Mean L2 distance between consecutive time steps, averaged over the whole batch (a single scalar)
def mean_time_diff(x):
    """Return the mean L2 distance between consecutive time steps of ``x``.

    Args:
        x: Tensor indexed as (batch, time, feature); the slicing below
            treats dim 1 as time and the norm is taken over the last dim.

    Returns:
        A 0-dim tensor: the step distances averaged over every batch
        element and time step, or zero when there are fewer than two
        time steps.
    """
    # Read the time length from the argument itself rather than from an
    # enclosing-scope variable, so the guard always matches the slices below.
    if x.shape[1] <= 1:
        # No consecutive pairs exist; mean() over an empty tensor would be NaN.
        # new_zeros keeps the dtype and device of x, matching the normal path.
        return x.new_zeros(())
    steps = x[:, 1:, :] - x[:, :-1, :]
    return steps.pow(2).sum(dim=-1).sqrt().mean()
diff1 = mean_time_diff(cand1)
diff2 = mean_time_diff(cand2)
vision_features = cand1 if diff1 >= diff2 else cand2
if self.training and torch.rand(1).item() < 0.05:
print(f"DINO reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}")
except Exception:
# Fallback to default
vision_features = rearrange(vision_features_flat, '(b t) d -> b t d', b=B, t=T)
# DEBUG: Analyze vision feature variability
if self.training and torch.rand(1).item() < 0.1: # 10% of training steps for more frequent debugging