mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
remove dino vision encoder and simplify text and vision encoders by removing inheritance structure
This commit is contained in:
@@ -19,99 +19,44 @@
|
||||
Handles vision encoding, text encoding, robot state, and environment state.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import einops
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from torch import Tensor
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
|
||||
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
class BaseVisionEncoder(ABC):
    """Abstract base class for vision encoders.

    Subclasses must override both :meth:`forward` and
    :meth:`get_output_shape`; until they do, instantiation raises
    ``TypeError`` because both methods are declared abstract.
    """

    @abstractmethod
    def forward(self, x: Tensor) -> Tensor:
        """Encode RGB image to feature maps.

        Args:
            x: Input image tensor.

        Returns:
            The encoded feature tensor.
        """

    @abstractmethod
    def get_output_shape(self) -> tuple[int, int, int]:
        """Get the output shape (C', H', W')."""
|
||||
|
||||
|
||||
class DinoV3Encoder(nn.Module, BaseVisionEncoder):
    """DinoV3 vision encoder that represents an image by its CLS token.

    The backbone is created through ``timm`` from ``config.backbone`` and the
    first output token is returned as a ``(B, C, 1, 1)`` feature map.
    """

    def __init__(self, config):
        """Create the timm backbone named by ``config.backbone``.

        Args:
            config: Configuration object; only its ``backbone`` attribute is
                read here. The object itself is kept on ``self.config``.
        """
        super().__init__()
        self.config = config
        self.model_name = config.backbone

        # 1 CLS + 4 register
        self.num_non_spatial_tokens = 5

        # Pretrained timm backbone; num_classes=0 is passed so the model is
        # built for feature extraction rather than classification.
        self.model = timm.create_model(self.model_name, pretrained=True, num_classes=0)
        self.embed_dim = self.model.embed_dim

    def forward(self, x: Tensor) -> Tensor:
        """Encode RGB image to feature maps.

        Args:
            x: Image batch handed straight to the backbone's
                ``forward_features``.

        Returns:
            The CLS token reshaped to ``(B, embed_dim, 1, 1)``.
        """
        tokens = self.model.forward_features(x)  # (B, total_tokens, embed_dim)

        # Keep only the first token (CLS).
        first_token = tokens[:, 0]  # (B, embed_dim)
        batch_size, channels = first_token.shape

        # Spatial layout (B, C, H, W) with H = W = 1 for compatibility.
        return first_token.reshape(batch_size, channels, 1, 1)

    def get_output_shape(self) -> tuple:
        """Return the per-sample output shape ``(embed_dim, 1, 1)``."""
        return (self.embed_dim, 1, 1)
|
||||
|
||||
|
||||
class CLIPEncoder(nn.Module, BaseVisionEncoder):
|
||||
class CLIPVisionEncoder(nn.Module):
|
||||
"""CLIP vision encoder using the CLS token for global image representation."""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, model_name: str):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model_name = config.backbone
|
||||
self.model_name = model_name
|
||||
|
||||
# Create the timm model
|
||||
self.model = timm.create_model(
|
||||
self.model_name,
|
||||
pretrained=True,
|
||||
num_classes=0, # Remove classification head, we want features
|
||||
)
|
||||
# Load CLIP vision model from transformers
|
||||
self.model = CLIPVisionModel.from_pretrained(self.model_name)
|
||||
|
||||
# CLIP models have 1 CLS token (no register tokens like DinoV3)
|
||||
# CLIP models have 1 CLS token
|
||||
self.num_non_spatial_tokens = 1
|
||||
|
||||
# Get embed_dim from model config
|
||||
self.embed_dim = self.model.embed_dim
|
||||
self.embed_dim = self.model.config.hidden_size
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Encode RGB image to CLS token.
|
||||
|
||||
Preprocessing (resize, crop) is handled by ObservationEncoder
|
||||
"""
|
||||
# Extract all features
|
||||
features = self.model.forward_features(x) # (B, total_tokens, embed_dim)
|
||||
# Extract features using CLIPVisionModel
|
||||
# Input: (B, C, H, W) - already preprocessed
|
||||
outputs = self.model(pixel_values=x, output_hidden_states=False)
|
||||
|
||||
# Use only the CLS token (first token)
|
||||
cls_token = features[:, 0] # (B, embed_dim)
|
||||
# Extract CLS token from last_hidden_state (first token)
|
||||
# last_hidden_state shape: (B, sequence_length, hidden_size)
|
||||
cls_token = outputs.last_hidden_state[:, 0] # (B, embed_dim)
|
||||
b, embed_dim = cls_token.shape
|
||||
|
||||
# Reshape to spatial format (B, C, H, W) with H=W=1 for compatibility
|
||||
@@ -122,46 +67,6 @@ class CLIPEncoder(nn.Module, BaseVisionEncoder):
|
||||
return (self.embed_dim, 1, 1)
|
||||
|
||||
|
||||
def create_vision_encoder(config) -> BaseVisionEncoder:
    """Create a vision encoder from config.

    The encoder type is detected from the backbone name: the first key of
    ``VISION_ENCODER_REGISTRY`` found as a substring of the lower-cased
    ``config.backbone`` selects the encoder class. Encoders added with
    ``register_vision_encoder`` are therefore honoured as well.

    Args:
        config: Configuration object exposing a ``backbone`` string; it is
            passed unchanged to the selected encoder's constructor.

    Returns:
        An instance of the matching encoder class.

    Raises:
        ValueError: If no registered key occurs in the backbone name.
    """
    backbone_name = config.backbone.lower()

    # "clip" is tried first to keep the precedence of the former if/elif
    # chain; sorted() is stable, so every other key keeps registration order.
    candidates = sorted(VISION_ENCODER_REGISTRY.items(), key=lambda item: item[0] != "clip")
    for key, encoder_cls in candidates:
        if key.lower() in backbone_name:
            return encoder_cls(config)

    # Listed in registration order, which yields "'dinov3' or 'clip'" for the
    # default registry.
    supported = " or ".join(repr(key) for key in VISION_ENCODER_REGISTRY)
    raise ValueError(
        f"Unsupported vision backbone: {config.backbone}. "
        f"Currently supported: any timm model with {supported} in the name"
    )
|
||||
|
||||
|
||||
# Registry of vision encoder classes keyed by short name, kept for easy extension.
VISION_ENCODER_REGISTRY: dict[str, type] = dict(
    dinov3=DinoV3Encoder,
    clip=CLIPEncoder,
)
|
||||
|
||||
|
||||
def register_vision_encoder(name: str, encoder_class: type):
    """Register a new vision encoder type.

    Args:
        name: Key under which the class is stored in
            ``VISION_ENCODER_REGISTRY``; an existing entry with the same
            key is overwritten.
        encoder_class: Encoder class to store.
    """
    VISION_ENCODER_REGISTRY.update({name: encoder_class})
|
||||
|
||||
|
||||
def get_registered_encoders() -> dict[str, type]:
    """Get all registered vision encoder types.

    Returns:
        A shallow copy of ``VISION_ENCODER_REGISTRY``, so changes to the
        returned dict do not affect the registry itself.
    """
    return dict(VISION_ENCODER_REGISTRY)
|
||||
|
||||
|
||||
class CLIPTextEncoder(nn.Module):
|
||||
"""Supports any HuggingFace CLIP model. The encoder weights are frozen,
|
||||
and a learnable projection layer maps the CLIP embeddings to the desired dimension.
|
||||
@@ -231,11 +136,11 @@ class ObservationEncoder(nn.Module):
|
||||
|
||||
if vision_config.use_separate_encoder_per_camera:
|
||||
self.vision_encoders = nn.ModuleList(
|
||||
[create_vision_encoder(vision_config) for _ in self.camera_names]
|
||||
[CLIPVisionEncoder(model_name=vision_config.model_name) for _ in self.camera_names]
|
||||
)
|
||||
self.vision_encoder = None
|
||||
else:
|
||||
self.vision_encoder = create_vision_encoder(vision_config)
|
||||
self.vision_encoder = CLIPVisionEncoder(model_name=vision_config.model_name)
|
||||
self.vision_encoders = None
|
||||
else:
|
||||
self.vision_encoder = None
|
||||
@@ -290,7 +195,6 @@ class ObservationEncoder(nn.Module):
|
||||
self.do_crop = False
|
||||
|
||||
def _setup_vector_output(self):
|
||||
"""Setup for vector output."""
|
||||
total_dim = 0
|
||||
|
||||
# Vision features - get CLS token feature dimension
|
||||
@@ -384,11 +288,8 @@ class ObservationEncoder(nn.Module):
|
||||
text_features = self.text_encoder(batch["task"]) # (B, text_dim)
|
||||
# Expand across temporal dimension to match other features
|
||||
text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1) # (B, T, text_dim)
|
||||
print("Text features shape after unsqueeze and expand:", text_features.shape)
|
||||
conditioning_feats.append(text_features)
|
||||
|
||||
for vec in conditioning_feats:
|
||||
print(f"Conditioning feature shape: {vec.shape}")
|
||||
combined_features = torch.cat(conditioning_feats, dim=-1) # (B, n_obs_steps, total_feature_dim)
|
||||
|
||||
return combined_features.flatten(start_dim=1) # (B, n_obs_steps * total_feature_dim)
|
||||
|
||||
Reference in New Issue
Block a user