mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
siglip again
This commit is contained in:
@@ -38,9 +38,9 @@ class RLearNConfig(PreTrainedConfig):
|
||||
lightweight temporal aggregator + head.
|
||||
"""
|
||||
|
||||
# Encoders - Using DINOv3 for vision and SigLIP2 for text
|
||||
vision_model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m"
|
||||
text_model_name: str = "google/siglip2-base-patch16-224"
|
||||
# Encoders - Use SigLIP2 for both vision and text (shared checkpoint)
|
||||
vision_model_name: str = "google/siglip2-base-patch16-512"
|
||||
text_model_name: str = "google/siglip2-base-patch16-512"
|
||||
freeze_backbones: bool = True
|
||||
|
||||
# Sequence length: number of past frames (including the current one) used by the temporal model
|
||||
|
||||
@@ -41,8 +41,8 @@ from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
|
||||
class RLearNPolicy(PreTrainedPolicy):
|
||||
"""Video-language conditioned reward model following ReWiND architecture: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11.
|
||||
|
||||
- Visual encoder: frozen DinoV3 encoder, returns per-frame embeddings.
|
||||
- Text encoder: frozen SigLIP2, returns a language embedding.
|
||||
- Visual encoder: frozen SigLIP2 vision tower, returns per-frame patch embeddings.
|
||||
- Text encoder: frozen SigLIP2 text tower, returns language token embeddings.
|
||||
|
||||
"""
|
||||
|
||||
@@ -54,16 +54,15 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
self.config = config
|
||||
self.episode_data_index = episode_data_index # Store episode boundaries for progress calculation
|
||||
|
||||
# Encoders - DINOv3 for vision, SigLIP2 for text
|
||||
from transformers import AutoProcessor, AutoModel, AutoImageProcessor
|
||||
# Encoders - SigLIP2 shared checkpoint for vision and text
|
||||
from transformers import AutoProcessor, AutoModel
|
||||
|
||||
# Load DINOv3 processor and model for vision
|
||||
self.vision_processor = AutoImageProcessor.from_pretrained(config.vision_model_name)
|
||||
self.vision_model = AutoModel.from_pretrained(config.vision_model_name)
|
||||
|
||||
# Load SigLIP2 processor and model for text
|
||||
self.text_processor = AutoProcessor.from_pretrained(config.text_model_name, use_fast=True)
|
||||
self.text_model = AutoModel.from_pretrained(config.text_model_name)
|
||||
# Shared processor handles both images and text
|
||||
self.processor = AutoProcessor.from_pretrained(config.vision_model_name, use_fast=True)
|
||||
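# Shared model exposes .vision_model and .text_model (NOTE(review): self.text_model below is bound to the full model, not .text_model — confirm this is intended)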
|
||||
self.siglip_model = AutoModel.from_pretrained(config.vision_model_name)
|
||||
self.vision_model = self.siglip_model.vision_model
|
||||
self.text_model = self.siglip_model
|
||||
|
||||
# Move encoders to GPU if available
|
||||
if torch.cuda.is_available():
|
||||
@@ -71,11 +70,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
self.text_model = self.text_model.to('cuda')
|
||||
|
||||
# Get hidden sizes from models
|
||||
# DINOv3-ViTL16 has hidden_size directly in config
|
||||
self.vision_hidden = getattr(self.vision_model.config, 'hidden_size', 1024) # DINOv3-large default
|
||||
|
||||
# SigLIP2 text model
|
||||
th = getattr(getattr(self.text_model, 'config', None), 'text_config', None)
|
||||
# SigLIP2 hidden sizes
|
||||
self.vision_hidden = getattr(getattr(self.siglip_model, 'config', None), 'vision_config', None)
|
||||
self.vision_hidden = getattr(self.vision_hidden, 'hidden_size', getattr(self.vision_model.config, 'hidden_size', 768))
|
||||
th = getattr(getattr(self.siglip_model, 'config', None), 'text_config', None)
|
||||
self.text_hidden = getattr(th, 'hidden_size', 512)
|
||||
|
||||
# Freeze encoders if requested
|
||||
@@ -201,12 +199,11 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
flat_numpy = flat_uint8.permute(0, 2, 3, 1).cpu().numpy() # (BT, H, W, C) uint8
|
||||
images_list = [flat_numpy[i] for i in range(B * T)]
|
||||
|
||||
# Process through DINOv3 processor and model
|
||||
# Disable center-crop to preserve motion near borders; allow default rescale/normalize
|
||||
inputs = self.vision_processor(images=images_list, return_tensors="pt", do_center_crop=False)
|
||||
# Process images with SigLIP2 processor
|
||||
inputs = self.processor(images=images_list, return_tensors="pt")
|
||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||
|
||||
# Process in batch through DINOv3 model
|
||||
# Process in batch through SigLIP2 vision tower
|
||||
vision_outputs = self.vision_model(**inputs)
|
||||
|
||||
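# Prefer patch tokens from last_hidden_state (exclude CLS at index 0 — NOTE(review): this dates from DINOv3; the SigLIP2 vision tower emits no CLS token, so confirm the slice does not drop a real patch)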
|
||||
@@ -659,8 +656,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
embeddings: (B, L, D); mask: (B, L) True for valid tokens.
|
||||
"""
|
||||
# Optimized: Process all commands in batch (much faster than individual processing)
|
||||
proc = self.text_processor(
|
||||
text=commands,
|
||||
proc = self.processor(
|
||||
text=commands,
|
||||
return_tensors='pt',
|
||||
padding='max_length',
|
||||
max_length=64,
|
||||
|
||||
Reference in New Issue
Block a user