From 9dcb407ba7fe55c3ab8acc0ddd7a2e46f5ca666d Mon Sep 17 00:00:00 2001 From: Pepijn Date: Mon, 1 Sep 2025 10:27:58 +0200 Subject: [PATCH] siglip again --- .../policies/rlearn/configuration_rlearn.py | 6 +-- .../policies/rlearn/modeling_rlearn.py | 41 +++++++++---------- 2 files changed, 22 insertions(+), 25 deletions(-) diff --git a/src/lerobot/policies/rlearn/configuration_rlearn.py b/src/lerobot/policies/rlearn/configuration_rlearn.py index 0989c0c89..5bb9141f7 100644 --- a/src/lerobot/policies/rlearn/configuration_rlearn.py +++ b/src/lerobot/policies/rlearn/configuration_rlearn.py @@ -38,9 +38,9 @@ class RLearNConfig(PreTrainedConfig): lightweight temporal aggregator + head. """ - # Encoders - Using DINOv3 for vision and SigLIP2 for text - vision_model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m" - text_model_name: str = "google/siglip2-base-patch16-224" + # Encoders - Use SigLIP2 for both vision and text (shared checkpoint) + vision_model_name: str = "google/siglip2-base-patch16-512" + text_model_name: str = "google/siglip2-base-patch16-512" freeze_backbones: bool = True # Sequence length, amount of past frames including current one to use in the temporal model diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 5e5c583c5..c61658735 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -41,8 +41,8 @@ from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig class RLearNPolicy(PreTrainedPolicy): """Video-language conditioned reward model following ReWiND architecture: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11. - - Visual encoder: frozen DinoV3 encoder, returns per-frame embeddings. - - Text encoder: frozen SigLIP2, returns a language embedding. + - Visual encoder: frozen SigLIP2 vision tower, returns per-frame patch embeddings. + - Text encoder: frozen SigLIP2 text tower, returns language token embeddings. """ @@ -54,16 +54,15 @@ class RLearNPolicy(PreTrainedPolicy): self.config = config self.episode_data_index = episode_data_index # Store episode boundaries for progress calculation - # Encoders - DINOv3 for vision, SigLIP2 for text - from transformers import AutoProcessor, AutoModel, AutoImageProcessor + # Encoders - SigLIP2 shared checkpoint for vision and text + from transformers import AutoProcessor, AutoModel - # Load DINOv3 processor and model for vision - self.vision_processor = AutoImageProcessor.from_pretrained(config.vision_model_name) - self.vision_model = AutoModel.from_pretrained(config.vision_model_name) - - # Load SigLIP2 processor and model for text - self.text_processor = AutoProcessor.from_pretrained(config.text_model_name, use_fast=True) - self.text_model = AutoModel.from_pretrained(config.text_model_name) + # Shared processor handles both images and text + self.processor = AutoProcessor.from_pretrained(config.vision_model_name, use_fast=True) + # Shared model exposes .vision_model and .text_model + self.siglip_model = AutoModel.from_pretrained(config.vision_model_name) + self.vision_model = self.siglip_model.vision_model + self.text_model = self.siglip_model # Move encoders to GPU if available if torch.cuda.is_available(): @@ -71,11 +70,10 @@ class RLearNPolicy(PreTrainedPolicy): self.text_model = self.text_model.to('cuda') # Get hidden sizes from models - # DINOv3-ViTL16 has hidden_size directly in config - self.vision_hidden = getattr(self.vision_model.config, 'hidden_size', 1024) # DINOv3-large default - - # SigLIP2 text model - th = getattr(getattr(self.text_model, 'config', None), 'text_config', None) + # SigLIP2 hidden sizes + self.vision_hidden = getattr(getattr(self.siglip_model, 'config', None), 'vision_config', None) + self.vision_hidden = getattr(self.vision_hidden, 'hidden_size', getattr(self.vision_model.config, 'hidden_size', 768)) + th = getattr(getattr(self.siglip_model, 'config', None), 'text_config', None) self.text_hidden = getattr(th, 'hidden_size', 512) # Freeze encoders if requested @@ -201,12 +199,11 @@ class RLearNPolicy(PreTrainedPolicy): flat_numpy = flat_uint8.permute(0, 2, 3, 1).cpu().numpy() # (BT, H, W, C) uint8 images_list = [flat_numpy[i] for i in range(B * T)] - # Process through DINOv3 processor and model - # Disable center-crop to preserve motion near borders; allow default rescale/normalize - inputs = self.vision_processor(images=images_list, return_tensors="pt", do_center_crop=False) + # Process images with SigLIP2 processor + inputs = self.processor(images=images_list, return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} - # Process in batch through DINOv3 model + # Process in batch through SigLIP2 vision tower vision_outputs = self.vision_model(**inputs) # Prefer patch tokens from last_hidden_state (exclude CLS at index 0) @@ -659,8 +656,8 @@ class RLearNPolicy(PreTrainedPolicy): embeddings: (B, L, D); mask: (B, L) True for valid tokens. """ # Optimized: Process all commands in batch (much faster than individual processing) - proc = self.text_processor( - text=commands, + proc = self.processor( + text=commands, return_tensors='pt', padding='max_length', max_length=64,