siglip again

This commit is contained in:
Pepijn
2025-09-01 10:27:58 +02:00
parent 5eb5bf7164
commit 9dcb407ba7
2 changed files with 22 additions and 25 deletions

View File

@@ -38,9 +38,9 @@ class RLearNConfig(PreTrainedConfig):
lightweight temporal aggregator + head.
"""
# Encoders - Using DINOv3 for vision and SigLIP2 for text
vision_model_name: str = "facebook/dinov3-vitb16-pretrain-lvd1689m"
text_model_name: str = "google/siglip2-base-patch16-224"
# Encoders - Use SigLIP2 for both vision and text (shared checkpoint)
vision_model_name: str = "google/siglip2-base-patch16-512"
text_model_name: str = "google/siglip2-base-patch16-512"
freeze_backbones: bool = True
# Sequence length, amount of past frames including current one to use in the temporal model

View File

@@ -41,8 +41,8 @@ from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
class RLearNPolicy(PreTrainedPolicy):
"""Video-language conditioned reward model following ReWiND architecture: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11.
- Visual encoder: frozen DinoV3 encoder, returns per-frame embeddings.
- Text encoder: frozen SigLIP2, returns a language embedding.
- Visual encoder: frozen SigLIP2 vision tower, returns per-frame patch embeddings.
- Text encoder: frozen SigLIP2 text tower, returns language token embeddings.
"""
@@ -54,16 +54,15 @@ class RLearNPolicy(PreTrainedPolicy):
self.config = config
self.episode_data_index = episode_data_index # Store episode boundaries for progress calculation
# Encoders - DINOv3 for vision, SigLIP2 for text
from transformers import AutoProcessor, AutoModel, AutoImageProcessor
# Encoders - SigLIP2 shared checkpoint for vision and text
from transformers import AutoProcessor, AutoModel
# Load DINOv3 processor and model for vision
self.vision_processor = AutoImageProcessor.from_pretrained(config.vision_model_name)
self.vision_model = AutoModel.from_pretrained(config.vision_model_name)
# Load SigLIP2 processor and model for text
self.text_processor = AutoProcessor.from_pretrained(config.text_model_name, use_fast=True)
self.text_model = AutoModel.from_pretrained(config.text_model_name)
# Shared processor handles both images and text
self.processor = AutoProcessor.from_pretrained(config.vision_model_name, use_fast=True)
# Shared model exposes .vision_model and .text_model
self.siglip_model = AutoModel.from_pretrained(config.vision_model_name)
self.vision_model = self.siglip_model.vision_model
self.text_model = self.siglip_model
# Move encoders to GPU if available
if torch.cuda.is_available():
@@ -71,11 +70,10 @@ class RLearNPolicy(PreTrainedPolicy):
self.text_model = self.text_model.to('cuda')
# Get hidden sizes from models
# DINOv3-ViTL16 has hidden_size directly in config
self.vision_hidden = getattr(self.vision_model.config, 'hidden_size', 1024) # DINOv3-large default
# SigLIP2 text model
th = getattr(getattr(self.text_model, 'config', None), 'text_config', None)
# SigLIP2 hidden sizes
self.vision_hidden = getattr(getattr(self.siglip_model, 'config', None), 'vision_config', None)
self.vision_hidden = getattr(self.vision_hidden, 'hidden_size', getattr(self.vision_model.config, 'hidden_size', 768))
th = getattr(getattr(self.siglip_model, 'config', None), 'text_config', None)
self.text_hidden = getattr(th, 'hidden_size', 512)
# Freeze encoders if requested
@@ -201,12 +199,11 @@ class RLearNPolicy(PreTrainedPolicy):
flat_numpy = flat_uint8.permute(0, 2, 3, 1).cpu().numpy() # (BT, H, W, C) uint8
images_list = [flat_numpy[i] for i in range(B * T)]
# Process through DINOv3 processor and model
# Disable center-crop to preserve motion near borders; allow default rescale/normalize
inputs = self.vision_processor(images=images_list, return_tensors="pt", do_center_crop=False)
# Process images with SigLIP2 processor
inputs = self.processor(images=images_list, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
# Process in batch through DINOv3 model
# Process in batch through SigLIP2 vision tower
vision_outputs = self.vision_model(**inputs)
# Prefer patch tokens from last_hidden_state (exclude CLS at index 0)
@@ -659,8 +656,8 @@ class RLearNPolicy(PreTrainedPolicy):
embeddings: (B, L, D); mask: (B, L) True for valid tokens.
"""
# Optimized: Process all commands in batch (much faster than individual processing)
proc = self.text_processor(
text=commands,
proc = self.processor(
text=commands,
return_tensors='pt',
padding='max_length',
max_length=64,