expand the observation encoder to support different-size encoders for vision and text

This commit is contained in:
Bryson Jones
2025-11-21 14:31:35 -08:00
parent ab97d5c019
commit 8b9fada80f
2 changed files with 26 additions and 24 deletions

View File

@@ -123,6 +123,11 @@ class CLIPEncoder(nn.Module, BaseVisionEncoder):
def create_vision_encoder(config) -> BaseVisionEncoder:
"""Create a vision encoder from config.
Supports any timm model with "clip" or "dinov3" in the backbone name.
The encoder type is automatically detected based on the backbone name.
"""
backbone_name = config.backbone.lower()
# Check if it's a CLIP model
@@ -136,7 +141,7 @@ def create_vision_encoder(config) -> BaseVisionEncoder:
else:
raise ValueError(
f"Unsupported vision backbone: {config.backbone}. "
f"Currently supported: DinoV3 models and CLIP models"
f"Currently supported: any timm model with 'dinov3' or 'clip' in the name"
)
@@ -148,26 +153,19 @@ VISION_ENCODER_REGISTRY: dict[str, type] = {
def register_vision_encoder(name: str, encoder_class: type):
"""Register a new vision encoder type.
Args:
name: Identifier for the encoder type
encoder_class: Class implementing BaseVisionEncoder interface
"""
"""Register a new vision encoder type."""
VISION_ENCODER_REGISTRY[name] = encoder_class
def get_registered_encoders() -> dict[str, type]:
"""Get all registered vision encoder types.
Returns:
Dictionary mapping encoder names to classes
"""
"""Get all registered vision encoder types."""
return VISION_ENCODER_REGISTRY.copy()
class CLIPTextEncoder(nn.Module):
"""CLIP text encoder with frozen weights and learnable projection."""
"""Supports any HuggingFace CLIP model. The encoder weights are frozen,
and a learnable projection layer maps the CLIP embeddings to the desired dimension.
"""
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
super().__init__()