mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
397 lines
14 KiB
Python
397 lines
14 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Observation encoding for Multi-Task DiT policy.
|
|
|
|
Handles vision encoding, text encoding, robot state, and environment state.
|
|
"""
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
import einops
|
|
import timm
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchvision
|
|
from torch import Tensor
|
|
from transformers import CLIPTextModel, CLIPTokenizer
|
|
|
|
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
|
|
|
|
|
class BaseVisionEncoder(ABC):
    """Interface that every vision backbone wrapper must implement.

    Concrete encoders turn an RGB image tensor into a feature map whose
    per-image shape ``(C', H', W')`` is reported by ``get_output_shape``.
    """

    @abstractmethod
    def forward(self, x: Tensor) -> Tensor:
        """Map the RGB image tensor ``x`` to encoder feature maps."""

    @abstractmethod
    def get_output_shape(self) -> tuple:
        """Return the per-image output shape as ``(C', H', W')``."""
|
|
|
|
|
|
class DinoV3Encoder(nn.Module, BaseVisionEncoder):
    """DinoV3 backbone (created through timm) that summarizes an image by its CLS token.

    The CLS embedding is returned as a ``(B, embed_dim, 1, 1)`` map so that callers
    can treat it like any other spatial feature map.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model_name = config.backbone

        # Pretrained timm backbone; num_classes=0 requests no classification head.
        self.model = timm.create_model(self.model_name, pretrained=True, num_classes=0)

        # Non-spatial tokens in the sequence: 1 CLS + 4 register tokens.
        # (Stored for callers; not used inside this class.)
        self.num_non_spatial_tokens = 5
        self.embed_dim = self.model.embed_dim

    def forward(self, x: Tensor) -> Tensor:
        """Return the CLS token of ``x`` reshaped to ``(B, embed_dim, 1, 1)``."""
        tokens = self.model.forward_features(x)  # (B, total_tokens, embed_dim)
        cls_embedding = tokens[:, 0]  # CLS is the first token
        batch, channels = cls_embedding.shape
        # Add two singleton spatial axes so the result looks like a 1x1 feature map.
        return cls_embedding.reshape(batch, channels, 1, 1)

    def get_output_shape(self) -> tuple:
        """Shape of one encoded image: ``(embed_dim, 1, 1)``."""
        return (self.embed_dim, 1, 1)
|
|
|
|
|
|
class CLIPEncoder(nn.Module, BaseVisionEncoder):
    """CLIP image backbone (created through timm) summarized by its CLS token.

    Resizing and cropping are the caller's responsibility (``ObservationEncoder``
    does them); this module only runs the backbone and keeps the CLS embedding,
    shaped as a ``(B, embed_dim, 1, 1)`` map.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model_name = config.backbone

        self.model = timm.create_model(
            self.model_name,
            pretrained=True,
            num_classes=0,  # no classification head: only features are needed
        )

        # CLIP has a single CLS token and, unlike DinoV3, no register tokens.
        self.num_non_spatial_tokens = 1

        self.embed_dim = self.model.embed_dim

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` and return its CLS token as a ``(B, embed_dim, 1, 1)`` tensor."""
        cls_embedding = self.model.forward_features(x)[:, 0]  # (B, embed_dim)
        batch, channels = cls_embedding.shape
        return cls_embedding.reshape(batch, channels, 1, 1)

    def get_output_shape(self) -> tuple:
        """Shape of one encoded image: ``(embed_dim, 1, 1)``."""
        return (self.embed_dim, 1, 1)
|
|
|
|
|
|
def create_vision_encoder(config) -> BaseVisionEncoder:
    """Instantiate the vision encoder whose registry key appears in ``config.backbone``.

    The lower-cased backbone name is matched by substring against the keys of
    ``VISION_ENCODER_REGISTRY``. Keys are tried in reverse insertion order, so
    encoders added later via ``register_vision_encoder`` take priority over the
    built-ins. Among the built-ins, ``"clip"`` is checked before ``"dinov3"``.

    Args:
        config: Vision config. Only ``config.backbone`` is read here; the whole
            object is forwarded to the selected encoder's constructor.

    Returns:
        A new encoder instance constructed with ``config``.

    Raises:
        ValueError: If no registered key occurs in the backbone name.
    """
    backbone_name = config.backbone.lower()

    # Newest registrations first: user-registered encoders override built-ins,
    # and for the built-ins this reproduces the original clip -> dinov3 order.
    for key in reversed(VISION_ENCODER_REGISTRY):
        if key.lower() in backbone_name:
            return VISION_ENCODER_REGISTRY[key](config)

    supported = ", ".join(sorted(VISION_ENCODER_REGISTRY))
    raise ValueError(
        f"Unsupported vision backbone: {config.backbone}. "
        f"Currently supported encoder types: {supported}"
    )
|
|
|
|
|
|
# Encoder classes keyed by identifier; extend it with ``register_vision_encoder``.
VISION_ENCODER_REGISTRY: dict[str, type] = dict(
    dinov3=DinoV3Encoder,
    clip=CLIPEncoder,
)
|
|
|
|
|
|
def register_vision_encoder(name: str, encoder_class: type):
    """Add, or overwrite, an entry in the module-level vision encoder registry.

    Args:
        name: Key under which ``encoder_class`` is stored.
        encoder_class: Class implementing the ``BaseVisionEncoder`` interface.
    """
    VISION_ENCODER_REGISTRY.update({name: encoder_class})
|
|
|
|
|
|
def get_registered_encoders() -> dict[str, type]:
    """Return a shallow copy of the vision encoder registry.

    Mutating the returned dict does not affect the registry itself.

    Returns:
        A new dict mapping encoder names to their classes.
    """
    return dict(VISION_ENCODER_REGISTRY)
|
|
|
|
|
|
class CLIPTextEncoder(nn.Module):
|
|
"""CLIP text encoder with frozen weights and learnable projection."""
|
|
|
|
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
|
|
super().__init__()
|
|
|
|
self.model_name = model_name
|
|
self.projection_dim = projection_dim
|
|
|
|
# Load CLIP text encoder and tokenizer
|
|
self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
|
|
self.text_encoder = CLIPTextModel.from_pretrained(model_name)
|
|
|
|
# Freeze all CLIP text encoder parameters
|
|
for param in self.text_encoder.parameters():
|
|
param.requires_grad = False
|
|
|
|
self.text_embed_dim = self.text_encoder.config.hidden_size
|
|
|
|
# Learnable projection layer (always present, only trainable component)
|
|
self.projection = nn.Linear(self.text_embed_dim, projection_dim)
|
|
|
|
def forward(self, text: str | list[str]) -> Tensor:
|
|
"""Encode text to feature vectors.
|
|
|
|
Args:
|
|
text: Single string or list of strings
|
|
|
|
Returns:
|
|
Text features of shape (B, projection_dim)
|
|
"""
|
|
# handle single string input
|
|
if isinstance(text, str):
|
|
text = [text]
|
|
|
|
text_inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
|
|
|
|
text_inputs = {k: v.to(next(self.parameters()).device) for k, v in text_inputs.items()}
|
|
|
|
# encode text through CLIP (frozen)
|
|
with torch.no_grad():
|
|
outputs = self.text_encoder(**text_inputs)
|
|
# Extract pooled output (EOS token embedding)
|
|
clip_features = outputs.pooler_output # (B, text_embed_dim)
|
|
|
|
# project to desired dimension (trainable)
|
|
projected_features = self.projection(clip_features) # (B, projection_dim)
|
|
|
|
return projected_features
|
|
|
|
|
|
class ObservationEncoder(nn.Module):
    """Builds the flat conditioning vector from all observation modalities.

    For each observation step, features are concatenated in this order:
    robot state, camera features (one block per camera, in the camera-index order
    of ``batch[OBS_IMAGES]``), environment state, then projected task-text
    features. The per-step vectors of all ``n_obs_steps`` are flattened into one
    vector per sample whose length is ``self.conditioning_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        vision_config = config.observation_encoder.vision

        self._setup_preprocessing(vision_config)

        if config.image_features:
            self.num_cameras = len(config.image_features)
            self.camera_names = list(config.image_features.keys())  # Preserve ordering

            if vision_config.use_separate_encoder_per_camera:
                # One independently parameterized encoder per camera.
                self.vision_encoders = nn.ModuleList(
                    [create_vision_encoder(vision_config) for _ in self.camera_names]
                )
                self.vision_encoder = None
            else:
                # A single encoder shared by every camera.
                self.vision_encoder = create_vision_encoder(vision_config)
                self.vision_encoders = None
        else:
            self.vision_encoder = None
            self.vision_encoders = None
            self.camera_names = []
            self.num_cameras = 0

        if hasattr(config, "robot_state_feature") and config.robot_state_feature:
            self.robot_state_dim = config.robot_state_feature.shape[0]
        else:
            self.robot_state_dim = 0

        if hasattr(config, "env_state_feature") and config.env_state_feature:
            self.env_state_dim = config.env_state_feature.shape[0]
        else:
            self.env_state_dim = 0

        text_config = config.observation_encoder.text
        # Text features are projected to the transformer hidden width.
        self.text_dim = config.transformer.hidden_dim
        self.text_encoder = CLIPTextEncoder(model_name=text_config.model, projection_dim=self.text_dim)

        self._setup_vector_output()

    def _apply_preprocessing(self, images: Tensor) -> Tensor:
        """Resize and/or crop ``images`` according to the configured transforms.

        During training, ``maybe_random_crop`` is used (a random crop if
        ``crop_is_random`` was set, otherwise the center crop). In eval mode the
        center crop is always used.
        """
        if self.do_resize:
            images = self.resize(images)
        if self.do_crop:
            images = self.maybe_random_crop(images) if self.training else self.center_crop(images)

        return images

    def _setup_preprocessing(self, vision_config):
        """Create the resize/crop transforms requested by ``vision_config``."""
        if vision_config.resize_shape is not None:
            self.do_resize = True
            self.resize = torchvision.transforms.Resize(
                size=vision_config.resize_shape,
                interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
                antialias=True,
            )
        else:
            self.do_resize = False
        if vision_config.crop_shape is not None:
            self.do_crop = True
            self.center_crop = torchvision.transforms.CenterCrop(vision_config.crop_shape)
            if vision_config.crop_is_random:
                self.maybe_random_crop = torchvision.transforms.RandomCrop(vision_config.crop_shape)
            else:
                self.maybe_random_crop = self.center_crop
        else:
            self.do_crop = False

    def _reference_vision_encoder(self):
        """Return the shared encoder, else the first per-camera encoder, else ``None``."""
        if self.vision_encoder is not None:
            return self.vision_encoder
        # nn.ModuleList supports indexing/len but has no .values(); index it directly.
        if self.vision_encoders is not None and len(self.vision_encoders) > 0:
            return self.vision_encoders[0]
        return None

    def _setup_vector_output(self):
        """Compute ``self.conditioning_dim``, the length of the vector ``encode`` returns."""
        per_step_dim = 0

        # Vision features: flattened encoder output per camera (embed_dim for a 1x1 CLS map).
        vision_encoder = self._reference_vision_encoder()
        if vision_encoder is not None:
            c, h, w = vision_encoder.get_output_shape()
            per_step_dim += c * h * w * self.num_cameras

        # State features
        per_step_dim += self.robot_state_dim
        per_step_dim += self.env_state_dim

        # Text features
        per_step_dim += self.text_dim

        # encode() flattens the time axis, so every observation step contributes.
        self.conditioning_dim = per_step_dim * self.config.n_obs_steps

    def _encode_images(self, images: Tensor, batch_size: int, n_obs_steps: int) -> Tensor:
        """Encode ``(B, S, N, C, H, W)`` camera images into ``(B, S, N * feature_dim)``."""
        if self.vision_encoders is not None:
            # Separate encoder per camera; features concatenated in camera order.
            per_camera = []
            for cam_idx, encoder in enumerate(self.vision_encoders):
                # (B, S, C, H, W) -> (B*S, C, H, W)
                cam_images = einops.rearrange(images[:, :, cam_idx], "b s c h w -> (b s) c h w")
                cam_images = self._apply_preprocessing(cam_images)
                cam_features = encoder(cam_images).flatten(start_dim=1)
                per_camera.append(
                    einops.rearrange(cam_features, "(b s) f -> b s f", b=batch_size, s=n_obs_steps)
                )
            return torch.cat(per_camera, dim=-1)

        # Shared encoder: fold cameras into the batch axis and encode in one pass.
        images_flat = einops.rearrange(images, "b s n c h w -> (b s n) c h w")
        images_flat = self._apply_preprocessing(images_flat)
        visual_features = self.vision_encoder(images_flat).flatten(start_dim=1)
        # (B*S*N, F) -> (B, S, N*F)
        return einops.rearrange(
            visual_features, "(b s n) f -> b s (n f)", b=batch_size, s=n_obs_steps, n=self.num_cameras
        )

    def encode(self, batch: dict) -> Tensor:
        """Encode a batch of observations into a flat conditioning vector.

        Args:
            batch: Observation dict. ``batch[OBS_STATE]`` must be present; its first
                two dims give ``(batch_size, n_obs_steps)``. ``batch[OBS_IMAGES]``
                is read when vision encoders exist, ``batch[OBS_ENV_STATE]`` when an
                env-state feature is configured, and ``batch["task"]`` (a str or a
                list of str) for the text features.

        Returns:
            Tensor of shape ``(B, n_obs_steps * per_step_feature_dim)``.
        """
        batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
        conditioning_feats = []

        # Include robot state only when configured, so the width matches conditioning_dim.
        if self.robot_state_dim > 0:
            conditioning_feats.append(batch[OBS_STATE])

        if self.vision_encoder is not None or self.vision_encoders is not None:
            images = batch[OBS_IMAGES]  # expected (B, n_obs_steps, num_cameras, C, H, W)
            if images.dim() == 5:
                # n_obs_steps == 1 with the time axis squeezed: (B, N, C, H, W) -> (B, 1, N, C, H, W)
                images = images.unsqueeze(1)
            conditioning_feats.append(self._encode_images(images, batch_size, n_obs_steps))

        if self.env_state_dim > 0 and OBS_ENV_STATE in batch:
            conditioning_feats.append(batch[OBS_ENV_STATE])

        if self.text_encoder is not None and "task" in batch:
            text_features = self.text_encoder(batch["task"])  # (B, text_dim)
            # Repeat across observation steps so it concatenates with per-step features.
            conditioning_feats.append(text_features.unsqueeze(1).expand(-1, n_obs_steps, -1))

        combined_features = torch.cat(conditioning_feats, dim=-1)  # (B, n_obs_steps, per_step_dim)
        return combined_features.flatten(start_dim=1)  # (B, n_obs_steps * per_step_dim)
|