diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 25f832147..e8f1db5b1 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -149,6 +149,10 @@ class RLearNPolicy(PreTrainedPolicy): p.requires_grad = False for p in self.text_model.parameters(): p.requires_grad = False + + # Ensure frozen encoders run in eval mode (no dropout, stable outputs) + self.vision_model.eval() + self.text_model.eval() # x_transformers Decoder (matching ReWiND exactly) self.decoder = Decoder( @@ -373,17 +377,23 @@ class RLearNPolicy(PreTrainedPolicy): inputs = {k: v.to(device) for k, v in inputs.items()} # Process in batch through DINOv3 model - # DEBUGGING: Disable inference mode to check if it's causing caching issues - # context_manager = torch.inference_mode() if not self.training else nullcontext() - # with context_manager: - vision_outputs = self.vision_model(**inputs) + # Use inference mode for stable, fast frozen encoder forward + with torch.inference_mode(): + vision_outputs = self.vision_model(**inputs) - # Use pooler_output from DINOv3 (better than CLS token) - if hasattr(vision_outputs, 'pooler_output') and vision_outputs.pooler_output is not None: + # Prefer mean-pooled patch tokens over pooler/CLS to ensure input-dependent variation + if hasattr(vision_outputs, 'last_hidden_state') and vision_outputs.last_hidden_state is not None: + tokens = vision_outputs.last_hidden_state # (BT, N_tokens, D) + if tokens.dim() == 3 and tokens.shape[1] > 1: + # Exclude CLS/reg token at index 0, average over patch tokens + vision_features_flat = tokens[:, 1:, :].mean(dim=1) + else: + # Fallback to first token if only one token is present + vision_features_flat = tokens[:, 0] + elif hasattr(vision_outputs, 'pooler_output') and vision_outputs.pooler_output is not None: vision_features_flat = vision_outputs.pooler_output # (BT, D) else: - # Fallback to last hidden state CLS token if pooler_output not available - vision_features_flat = vision_outputs.last_hidden_state[:, 0] # (BT, D) + raise RuntimeError("DINOv3 outputs do not contain last_hidden_state or pooler_output") # Reshape to (B, T, D) for analysis vision_features = rearrange(vision_features_flat, '(b t) d -> b t d', b=B, t=T)