mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
expand the observation encoder to support differnt size encoders for vision and text
This commit is contained in:
@@ -123,6 +123,11 @@ class CLIPEncoder(nn.Module, BaseVisionEncoder):
|
||||
|
||||
|
||||
def create_vision_encoder(config) -> BaseVisionEncoder:
|
||||
"""Create a vision encoder from config.
|
||||
|
||||
Supports any timm model with "clip" or "dinov3" in the backbone name.
|
||||
The encoder type is automatically detected based on the backbone name.
|
||||
"""
|
||||
backbone_name = config.backbone.lower()
|
||||
|
||||
# Check if it's a CLIP model
|
||||
@@ -136,7 +141,7 @@ def create_vision_encoder(config) -> BaseVisionEncoder:
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported vision backbone: {config.backbone}. "
|
||||
f"Currently supported: DinoV3 models and CLIP models"
|
||||
f"Currently supported: any timm model with 'dinov3' or 'clip' in the name"
|
||||
)
|
||||
|
||||
|
||||
@@ -148,26 +153,19 @@ VISION_ENCODER_REGISTRY: dict[str, type] = {
|
||||
|
||||
|
||||
def register_vision_encoder(name: str, encoder_class: type):
|
||||
"""Register a new vision encoder type.
|
||||
|
||||
Args:
|
||||
name: Identifier for the encoder type
|
||||
encoder_class: Class implementing BaseVisionEncoder interface
|
||||
"""
|
||||
"""Register a new vision encoder type."""
|
||||
VISION_ENCODER_REGISTRY[name] = encoder_class
|
||||
|
||||
|
||||
def get_registered_encoders() -> dict[str, type]:
|
||||
"""Get all registered vision encoder types.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping encoder names to classes
|
||||
"""
|
||||
"""Get all registered vision encoder types."""
|
||||
return VISION_ENCODER_REGISTRY.copy()
|
||||
|
||||
|
||||
class CLIPTextEncoder(nn.Module):
|
||||
"""CLIP text encoder with frozen weights and learnable projection."""
|
||||
"""Supports any HuggingFace CLIP model. The encoder weights are frozen,
|
||||
and a learnable projection layer maps the CLIP embeddings to the desired dimension.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user