mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
fix dinov3
This commit is contained in:
@@ -149,6 +149,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
p.requires_grad = False
|
||||
for p in self.text_model.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
# Ensure frozen encoders run in eval mode (no dropout, stable outputs)
|
||||
self.vision_model.eval()
|
||||
self.text_model.eval()
|
||||
|
||||
# x_transformers Decoder (matching ReWiND exactly)
|
||||
self.decoder = Decoder(
|
||||
@@ -373,17 +377,23 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||
|
||||
# Process in batch through DINOv3 model
|
||||
# DEBUGGING: Disable inference mode to check if it's causing caching issues
|
||||
# context_manager = torch.inference_mode() if not self.training else nullcontext()
|
||||
# with context_manager:
|
||||
vision_outputs = self.vision_model(**inputs)
|
||||
# Use inference mode for stable, fast frozen encoder forward
|
||||
with torch.inference_mode():
|
||||
vision_outputs = self.vision_model(**inputs)
|
||||
|
||||
# Use pooler_output from DINOv3 (better than CLS token)
|
||||
if hasattr(vision_outputs, 'pooler_output') and vision_outputs.pooler_output is not None:
|
||||
# Prefer mean-pooled patch tokens over pooler/CLS to ensure input-dependent variation
|
||||
if hasattr(vision_outputs, 'last_hidden_state') and vision_outputs.last_hidden_state is not None:
|
||||
tokens = vision_outputs.last_hidden_state # (BT, N_tokens, D)
|
||||
if tokens.dim() == 3 and tokens.shape[1] > 1:
|
||||
# Exclude CLS/reg token at index 0, average over patch tokens
|
||||
vision_features_flat = tokens[:, 1:, :].mean(dim=1)
|
||||
else:
|
||||
# Fallback to first token if only one token is present
|
||||
vision_features_flat = tokens[:, 0]
|
||||
elif hasattr(vision_outputs, 'pooler_output') and vision_outputs.pooler_output is not None:
|
||||
vision_features_flat = vision_outputs.pooler_output # (BT, D)
|
||||
else:
|
||||
# Fallback to last hidden state CLS token if pooler_output not available
|
||||
vision_features_flat = vision_outputs.last_hidden_state[:, 0] # (BT, D)
|
||||
raise RuntimeError("DINOv3 outputs do not contain last_hidden_state or pooler_output")
|
||||
|
||||
# Reshape to (B, T, D) for analysis
|
||||
vision_features = rearrange(vision_features_flat, '(b t) d -> b t d', b=B, t=T)
|
||||
|
||||
Reference in New Issue
Block a user