diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index c6399226a..90c7d6d5d 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -393,8 +393,24 @@ class RLearNPolicy(PreTrainedPolicy): else: raise RuntimeError("DINOv3 outputs do not contain last_hidden_state or pooler_output") - # Reshape to (B, T, D) for analysis - vision_features = rearrange(vision_features_flat, '(b t) d -> b t d', b=B, t=T) + # Robustly reshape to (B, T, D): detect correct flatten order by maximizing temporal variance + try: + cand1 = rearrange(vision_features_flat, '(b t) d -> b t d', b=B, t=T) + cand2 = rearrange(vision_features_flat, '(t b) d -> b t d', t=T, b=B) + # Compute mean temporal difference per sample + def mean_time_diff(x): + if T <= 1: + return torch.tensor(0.0, device=x.device) + diffs = (x[:, 1:, :] - x[:, :-1, :]).pow(2).sum(dim=-1).sqrt() + return diffs.mean() + diff1 = mean_time_diff(cand1) + diff2 = mean_time_diff(cand2) + vision_features = cand1 if diff1 >= diff2 else cand2 + if self.training and torch.rand(1).item() < 0.05: + print(f"DINO reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}") + except Exception: + # Fallback to default + vision_features = rearrange(vision_features_flat, '(b t) d -> b t d', b=B, t=T) # DEBUG: Analyze vision feature variability if self.training and torch.rand(1).item() < 0.1: # 10% of training steps for more frequent debugging