remove dino vision encoder and simplify text and vision encoders by removing inheritance structure

This commit is contained in:
Bryson Jones
2025-12-10 11:09:37 -08:00
parent 55e19ff9a7
commit adabb37af6
4 changed files with 42 additions and 186 deletions

View File

@@ -19,99 +19,44 @@
Handles vision encoding, text encoding, robot state, and environment state.
"""
from abc import ABC, abstractmethod
import einops
import timm
import torch
import torch.nn as nn
import torchvision
from torch import Tensor
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
class BaseVisionEncoder(ABC):
"""Abstract base class for vision encoders."""
@abstractmethod
def forward(self, x: Tensor) -> Tensor:
"""Encode RGB image to feature maps."""
pass
@abstractmethod
def get_output_shape(self) -> tuple:
"""Get the output shape (C', H', W')."""
pass
class DinoV3Encoder(nn.Module, BaseVisionEncoder):
"""DinoV3 vision encoder using the CLS token for global image representation."""
def __init__(self, config):
super().__init__()
self.config = config
self.model_name = config.backbone
# Create the timm model
self.model = timm.create_model(
self.model_name,
pretrained=True,
num_classes=0,
)
self.num_non_spatial_tokens = 5 # 1 CLS + 4 register
self.embed_dim = self.model.embed_dim
def forward(self, x: Tensor) -> Tensor:
"""Encode RGB image to feature maps."""
# Extract all features
features = self.model.forward_features(x) # (B, total_tokens, embed_dim)
# Use only the CLS token (first token)
cls_token = features[:, 0] # (B, embed_dim)
b, embed_dim = cls_token.shape
# Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility
cls_features = cls_token.reshape(b, embed_dim, 1, 1)
return cls_features
def get_output_shape(self) -> tuple:
return (self.embed_dim, 1, 1)
class CLIPEncoder(nn.Module, BaseVisionEncoder):
class CLIPVisionEncoder(nn.Module):
"""CLIP vision encoder using the CLS token for global image representation."""
def __init__(self, config):
def __init__(self, model_name: str):
super().__init__()
self.config = config
self.model_name = config.backbone
self.model_name = model_name
# Create the timm model
self.model = timm.create_model(
self.model_name,
pretrained=True,
num_classes=0, # Remove classification head, we want features
)
# Load CLIP vision model from transformers
self.model = CLIPVisionModel.from_pretrained(self.model_name)
# CLIP models have 1 CLS token (no register tokens like DinoV3)
# CLIP models have 1 CLS token
self.num_non_spatial_tokens = 1
# Get embed_dim from model config
self.embed_dim = self.model.embed_dim
self.embed_dim = self.model.config.hidden_size
def forward(self, x: Tensor) -> Tensor:
"""Encode RGB image to CLS token.
Preprocessing (resize, crop) is handled by ObservationEncoder
"""
# Extract all features
features = self.model.forward_features(x) # (B, total_tokens, embed_dim)
# Extract features using CLIPVisionModel
# Input: (B, C, H, W) - already preprocessed
outputs = self.model(pixel_values=x, output_hidden_states=False)
# Use only the CLS token (first token)
cls_token = features[:, 0] # (B, embed_dim)
# Extract CLS token from last_hidden_state (first token)
# last_hidden_state shape: (B, sequence_length, hidden_size)
cls_token = outputs.last_hidden_state[:, 0] # (B, embed_dim)
b, embed_dim = cls_token.shape
# Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility
@@ -122,46 +67,6 @@ class CLIPEncoder(nn.Module, BaseVisionEncoder):
return (self.embed_dim, 1, 1)
def create_vision_encoder(config) -> BaseVisionEncoder:
"""Create a vision encoder from config.
Supports any timm model with "clip" or "dinov3" in the backbone name.
The encoder type is automatically detected based on the backbone name.
"""
backbone_name = config.backbone.lower()
# Check if it's a CLIP model
if "clip" in backbone_name:
return CLIPEncoder(config)
# Check if it's a DinoV3 model
elif "dinov3" in backbone_name:
return DinoV3Encoder(config)
else:
raise ValueError(
f"Unsupported vision backbone: {config.backbone}. "
f"Currently supported: any timm model with 'dinov3' or 'clip' in the name"
)
# Registry for easy extension
VISION_ENCODER_REGISTRY: dict[str, type] = {
"dinov3": DinoV3Encoder,
"clip": CLIPEncoder,
}
def register_vision_encoder(name: str, encoder_class: type):
"""Register a new vision encoder type."""
VISION_ENCODER_REGISTRY[name] = encoder_class
def get_registered_encoders() -> dict[str, type]:
"""Get all registered vision encoder types."""
return VISION_ENCODER_REGISTRY.copy()
class CLIPTextEncoder(nn.Module):
"""Supports any HuggingFace CLIP model. The encoder weights are frozen,
and a learnable projection layer maps the CLIP embeddings to the desired dimension.
@@ -231,11 +136,11 @@ class ObservationEncoder(nn.Module):
if vision_config.use_separate_encoder_per_camera:
self.vision_encoders = nn.ModuleList(
[create_vision_encoder(vision_config) for _ in self.camera_names]
[CLIPVisionEncoder(model_name=vision_config.model_name) for _ in self.camera_names]
)
self.vision_encoder = None
else:
self.vision_encoder = create_vision_encoder(vision_config)
self.vision_encoder = CLIPVisionEncoder(model_name=vision_config.model_name)
self.vision_encoders = None
else:
self.vision_encoder = None
@@ -290,7 +195,6 @@ class ObservationEncoder(nn.Module):
self.do_crop = False
def _setup_vector_output(self):
"""Setup for vector output."""
total_dim = 0
# Vision features - get CLS token feature dimension
@@ -384,11 +288,8 @@ class ObservationEncoder(nn.Module):
text_features = self.text_encoder(batch["task"]) # (B, text_dim)
# Expand across temporal dimension to match other features
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1) # (B, T, text_dim)
print("Text features shape after unsqueeze and expand:", text_features.shape)
conditioning_feats.append(text_features)
for vec in conditioning_feats:
print(f"Conditioning feature shape: {vec.shape}")
combined_features = torch.cat(conditioning_feats, dim=-1) # (B, n_obs_steps, total_feature_dim)
return combined_features.flatten(start_dim=1) # (B, n_obs_steps * total_feature_dim)