Fix DINOv3 feature extraction (eval-mode frozen encoders, inference-mode forward, mean-pooled patch tokens)

This commit is contained in:
Pepijn
2025-08-31 19:41:16 +02:00
parent 16e82fd29f
commit e1d433cbfc

View File

@@ -149,6 +149,10 @@ class RLearNPolicy(PreTrainedPolicy):
p.requires_grad = False
for p in self.text_model.parameters():
p.requires_grad = False
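# Put the frozen encoders in eval mode (no dropout, stable outputs).
# NOTE(review): a later .train() call on the parent module propagates to submodules and
# would switch these back to train mode unless train() is overridden — confirm.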
self.vision_model.eval()
self.text_model.eval()
# x_transformers Decoder (matching ReWiND exactly)
self.decoder = Decoder(
@@ -373,17 +377,23 @@ class RLearNPolicy(PreTrainedPolicy):
inputs = {k: v.to(device) for k, v in inputs.items()}
# Process in batch through DINOv3 model
# DEBUGGING: Disable inference mode to check if it's causing caching issues
# context_manager = torch.inference_mode() if not self.training else nullcontext()
# with context_manager:
vision_outputs = self.vision_model(**inputs)
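# Run the frozen encoder forward under inference mode (no autograd tracking).
# NOTE(review): tensors created under torch.inference_mode() cannot be saved for backward;
# if these features feed trainable layers during training, torch.no_grad() may be required — confirm.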
with torch.inference_mode():
vision_outputs = self.vision_model(**inputs)
# Use pooler_output from DINOv3 (better than CLS token)
if hasattr(vision_outputs, 'pooler_output') and vision_outputs.pooler_output is not None:
# Prefer mean-pooled patch tokens over pooler/CLS to ensure input-dependent variation
if hasattr(vision_outputs, 'last_hidden_state') and vision_outputs.last_hidden_state is not None:
tokens = vision_outputs.last_hidden_state # (BT, N_tokens, D)
if tokens.dim() == 3 and tokens.shape[1] > 1:
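# Drop the token at index 0 (CLS) and average over the remaining tokens.
# NOTE(review): if the checkpoint places register tokens right after CLS, they are
# included in this mean alongside the patch tokens — confirm the token layout.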
vision_features_flat = tokens[:, 1:, :].mean(dim=1)
else:
# Fallback to first token if only one token is present
vision_features_flat = tokens[:, 0]
elif hasattr(vision_outputs, 'pooler_output') and vision_outputs.pooler_output is not None:
vision_features_flat = vision_outputs.pooler_output # (BT, D)
else:
# Fallback to last hidden state CLS token if pooler_output not available
vision_features_flat = vision_outputs.last_hidden_state[:, 0] # (BT, D)
raise RuntimeError("DINOv3 outputs do not contain last_hidden_state or pooler_output")
# Reshape to (B, T, D) for analysis
vision_features = rearrange(vision_features_flat, '(b t) d -> b t d', b=B, t=T)