From 8b9fada80ff5733f7dc86df53658dba0acd21a9f Mon Sep 17 00:00:00 2001 From: Bryson Jones Date: Fri, 21 Nov 2025 14:31:35 -0800 Subject: [PATCH] expand the observation encoder to support different size encoders for vision and text --- .../configuration_multi_task_dit.py | 26 +++++++++++-------- .../modules/observation_encoder.py | 24 ++++++++--------- 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py index 5a6a71f37..e4b576445 100644 --- a/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py +++ b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py @@ -232,20 +232,18 @@ class DinoV3EncoderConfig(VisionEncoderConfig): DinoV3 is a self-supervised Vision Transformer trained by Meta. CLS token usage and spatial feature extraction are handled automatically. - Available backbones: + Any timm model with "dinov3" in the name can be used. Examples: - vit_base_patch16_dinov3.lvd1689m (768 dims) + - vit_large_patch16_dinov3.lvd1689m (1024 dims) """ backbone: str = "vit_base_patch16_dinov3.lvd1689m" def __post_init__(self): super().__post_init__() - # Validate backbone name - valid_backbones = [ - "vit_base_patch16_dinov3.lvd1689m", - ] - if self.backbone not in valid_backbones: - raise ValueError(f"backbone must be one of {valid_backbones}, got '{self.backbone}'") + # Validate that backbone name contains "dinov3" to ensure correct encoder type + if "dinov3" not in self.backbone.lower(): + raise ValueError(f"backbone must be a DinoV3 model (contain 'dinov3'), got '{self.backbone}'") @VisionEncoderConfig.register_subclass("clip") @@ -258,17 +256,18 @@ class CLIPVisionEncoderConfig(VisionEncoderConfig): CLIP's internal preprocessing (resize to 224x224) can be overridden by setting resize_shape and crop_shape. - Available backbones: + Any timm model with "clip" in the name can be used. 
Examples: - vit_base_patch16_clip_224.openai (default, 768 dims, 14x14 patches for 224x224) + - vit_large_patch14_clip_224.openai (1024 dims) """ backbone: str = "vit_base_patch16_clip_224.openai" def __post_init__(self): super().__post_init__() - # Validate backbone name + # Validate that backbone name contains "clip" to ensure correct encoder type if "clip" not in self.backbone.lower(): - raise ValueError(f"backbone must be a CLIP model, got '{self.backbone}'") + raise ValueError(f"backbone must be a CLIP model (contain 'clip'), got '{self.backbone}'") @dataclass @@ -294,14 +293,19 @@ class CLIPTextEncoderConfig(TextEncoderConfig): used to condition the policy. The text embeddings are processed by a learnable projection layer before being concatenated into the conditioning vector. + + Any HuggingFace CLIP model can be used. Examples: + - openai/clip-vit-base-patch16 (default) + - openai/clip-vit-large-patch14 """ model: str = "openai/clip-vit-base-patch16" def __post_init__(self): super().__post_init__() + # Validate that model name contains "clip" to ensure correct encoder type if "clip" not in self.model.lower(): - raise ValueError(f"CLIP text encoder requires a CLIP model. Got '{self.model}'") + raise ValueError(f"CLIP text encoder requires a CLIP model (contain 'clip'). Got '{self.model}'") @dataclass diff --git a/src/lerobot/policies/multi_task_dit/modules/observation_encoder.py b/src/lerobot/policies/multi_task_dit/modules/observation_encoder.py index de39d8ee8..1dd7ec43d 100644 --- a/src/lerobot/policies/multi_task_dit/modules/observation_encoder.py +++ b/src/lerobot/policies/multi_task_dit/modules/observation_encoder.py @@ -123,6 +123,11 @@ class CLIPEncoder(nn.Module, BaseVisionEncoder): def create_vision_encoder(config) -> BaseVisionEncoder: + """Create a vision encoder from config. + + Supports any timm model with "clip" or "dinov3" in the backbone name. + The encoder type is automatically detected based on the backbone name. 
+ """ backbone_name = config.backbone.lower() # Check if it's a CLIP model @@ -136,7 +141,7 @@ def create_vision_encoder(config) -> BaseVisionEncoder: else: raise ValueError( f"Unsupported vision backbone: {config.backbone}. " - f"Currently supported: DinoV3 models and CLIP models" + f"Currently supported: any timm model with 'dinov3' or 'clip' in the name" ) @@ -148,26 +153,19 @@ VISION_ENCODER_REGISTRY: dict[str, type] = { def register_vision_encoder(name: str, encoder_class: type): - """Register a new vision encoder type. - - Args: - name: Identifier for the encoder type - encoder_class: Class implementing BaseVisionEncoder interface - """ + """Register a new vision encoder type.""" VISION_ENCODER_REGISTRY[name] = encoder_class def get_registered_encoders() -> dict[str, type]: - """Get all registered vision encoder types. - - Returns: - Dictionary mapping encoder names to classes - """ + """Get all registered vision encoder types.""" return VISION_ENCODER_REGISTRY.copy() class CLIPTextEncoder(nn.Module): - """CLIP text encoder with frozen weights and learnable projection.""" + """Supports any HuggingFace CLIP model. The encoder weights are frozen, + and a learnable projection layer maps the CLIP embeddings to the desired dimension. + """ def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512): super().__init__()