Files
lerobot-clone/src/lerobot/policies/multi_task_dit/modules/observation_encoder.py
Bryson Jones 14a7a4d7d4 Add multitask diffusion transformer policy
Add multitask diffusion transformer policy
2025-11-12 16:20:59 -08:00

397 lines
14 KiB
Python

#!/usr/bin/env python
# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Observation encoding for Multi-Task DiT policy.
Handles vision encoding, text encoding, robot state, and environment state.
"""
from abc import ABC, abstractmethod
import einops
import timm
import torch
import torch.nn as nn
import torchvision
from torch import Tensor
from transformers import CLIPTextModel, CLIPTokenizer
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
class BaseVisionEncoder(ABC):
    """Interface that every vision backbone wrapper in this module implements.

    Subclasses must provide both ``forward`` and ``get_output_shape``; the class
    cannot be instantiated until both are overridden.
    """

    @abstractmethod
    def forward(self, x: Tensor) -> Tensor:
        """Map an RGB image batch ``x`` to a feature-map tensor."""
        ...

    @abstractmethod
    def get_output_shape(self) -> tuple:
        """Return the per-sample feature shape as ``(C', H', W')``."""
        ...
class DinoV3Encoder(nn.Module, BaseVisionEncoder):
    """Global image encoder wrapping a pretrained timm DinoV3 backbone.

    Only the first token of the backbone's ``forward_features`` output (the CLS
    token) is kept. It is returned as a ``(B, embed_dim, 1, 1)`` tensor so callers
    can treat it like a 1x1 feature map.

    Args:
        config: Vision config; ``config.backbone`` is the timm model name to load.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model_name = config.backbone
        # num_classes=0 builds the timm model without its classification head.
        self.model = timm.create_model(self.model_name, pretrained=True, num_classes=0)
        # Leading non-patch tokens in the sequence: 1 CLS + 4 register tokens.
        self.num_non_spatial_tokens = 5
        self.embed_dim = self.model.embed_dim

    def forward(self, x: Tensor) -> Tensor:
        """Return the CLS embedding of ``x`` reshaped to ``(B, embed_dim, 1, 1)``."""
        cls_embedding = self.model.forward_features(x)[:, 0]
        n, d = cls_embedding.shape
        return cls_embedding.reshape(n, d, 1, 1)

    def get_output_shape(self) -> tuple:
        """Return ``(embed_dim, 1, 1)``, the per-sample shape produced by ``forward``."""
        return (self.embed_dim, 1, 1)
class CLIPEncoder(nn.Module, BaseVisionEncoder):
    """Global image encoder wrapping a pretrained timm CLIP vision backbone.

    The CLS token (first token of ``forward_features``) is returned as a
    ``(B, embed_dim, 1, 1)`` tensor. Resizing and cropping are expected to have
    been applied by ``ObservationEncoder`` before ``forward`` is called.

    Args:
        config: Vision config; ``config.backbone`` is the timm model name to load.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model_name = config.backbone
        # num_classes=0 removes the classification head so we get raw features.
        self.model = timm.create_model(self.model_name, pretrained=True, num_classes=0)
        # CLIP has a single CLS token and, unlike DinoV3, no register tokens.
        self.num_non_spatial_tokens = 1
        self.embed_dim = self.model.embed_dim

    def forward(self, x: Tensor) -> Tensor:
        """Return the CLS embedding of ``x`` reshaped to ``(B, embed_dim, 1, 1)``."""
        cls_embedding = self.model.forward_features(x)[:, 0]
        n, d = cls_embedding.shape
        return cls_embedding.reshape(n, d, 1, 1)

    def get_output_shape(self) -> tuple:
        """Return ``(embed_dim, 1, 1)``, the per-sample shape produced by ``forward``."""
        return (self.embed_dim, 1, 1)
def create_vision_encoder(config) -> BaseVisionEncoder:
    """Instantiate the vision encoder that matches ``config.backbone``.

    The lowercased backbone name is checked against each key of
    ``VISION_ENCODER_REGISTRY`` (case-insensitively, by substring). The first
    matching entry, in registration order, is instantiated with ``config``. This
    means encoders added through ``register_vision_encoder`` are picked up as
    well as the built-in DinoV3 and CLIP entries.

    Args:
        config: Vision config exposing a ``backbone`` model name.

    Returns:
        An instance of the matching encoder class.

    Raises:
        ValueError: If no registered key occurs in the backbone name.
    """
    backbone_name = config.backbone.lower()
    for key, encoder_class in VISION_ENCODER_REGISTRY.items():
        if key.lower() in backbone_name:
            return encoder_class(config)
    supported = ", ".join(VISION_ENCODER_REGISTRY)
    raise ValueError(
        f"Unsupported vision backbone: {config.backbone}. "
        f"Currently supported backbone names must contain one of: {supported}"
    )
# Registry of vision encoder classes keyed by identifier; new entries can be
# added at runtime with register_vision_encoder().
VISION_ENCODER_REGISTRY: dict[str, type] = dict(
    dinov3=DinoV3Encoder,
    clip=CLIPEncoder,
)
def register_vision_encoder(name: str, encoder_class: type):
    """Add (or overwrite) an entry in ``VISION_ENCODER_REGISTRY``.

    Args:
        name: Identifier to store the encoder class under.
        encoder_class: Class implementing the ``BaseVisionEncoder`` interface.
    """
    VISION_ENCODER_REGISTRY.update({name: encoder_class})
def get_registered_encoders() -> dict[str, type]:
    """Return a shallow copy of ``VISION_ENCODER_REGISTRY``.

    Mutating the returned dict does not affect the registry itself.

    Returns:
        Dictionary mapping encoder names to encoder classes.
    """
    return dict(VISION_ENCODER_REGISTRY)
class CLIPTextEncoder(nn.Module):
    """Frozen Hugging Face CLIP text model followed by a trainable linear projection.

    Args:
        model_name: Identifier passed to ``CLIPTokenizer.from_pretrained`` and
            ``CLIPTextModel.from_pretrained``.
        projection_dim: Output width of the learnable projection layer.
    """

    def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
        super().__init__()
        self.model_name = model_name
        self.projection_dim = projection_dim

        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.text_encoder = CLIPTextModel.from_pretrained(model_name)
        # The CLIP weights stay fixed; only self.projection is trainable.
        self.text_encoder.requires_grad_(False)

        self.text_embed_dim = self.text_encoder.config.hidden_size
        self.projection = nn.Linear(self.text_embed_dim, projection_dim)

    def forward(self, text: str | list[str]) -> Tensor:
        """Embed one prompt or a list of prompts.

        Args:
            text: A single string or a list of strings.

        Returns:
            Projected features of shape ``(B, projection_dim)``.
        """
        prompts = [text] if isinstance(text, str) else text
        device = next(self.parameters()).device
        tokens = self.tokenizer(prompts, padding=True, truncation=True, return_tensors="pt")
        tokens = {name: tensor.to(device) for name, tensor in tokens.items()}

        # No autograd graph is needed through the frozen CLIP model.
        with torch.no_grad():
            pooled = self.text_encoder(**tokens).pooler_output  # (B, text_embed_dim)

        return self.projection(pooled)
class ObservationEncoder(nn.Module):
    """Builds the flat conditioning vector from images, state, and task text.

    For each observation step, features are concatenated in this order:
    1. robot state;
    2. per-camera vision features, in ``config.image_features`` order;
    3. environment state;
    4. the projected task-text embedding.

    The per-step vectors are then flattened across the ``n_obs_steps`` axis.
    ``self.conditioning_dim`` holds the expected width of that flattened vector.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        vision_config = config.observation_encoder.vision
        self._setup_preprocessing(vision_config)

        if config.image_features:
            self.num_cameras = len(config.image_features)
            self.camera_names = list(config.image_features.keys())  # Preserve ordering
            if vision_config.use_separate_encoder_per_camera:
                # One independent encoder per camera, indexed in camera_names order.
                self.vision_encoders = nn.ModuleList(
                    [create_vision_encoder(vision_config) for _ in self.camera_names]
                )
                self.vision_encoder = None
            else:
                self.vision_encoder = create_vision_encoder(vision_config)
                self.vision_encoders = None
        else:
            self.vision_encoder = None
            self.vision_encoders = None
            self.camera_names = []
            self.num_cameras = 0

        robot_state_feature = getattr(config, "robot_state_feature", None)
        self.robot_state_dim = robot_state_feature.shape[0] if robot_state_feature else 0
        env_state_feature = getattr(config, "env_state_feature", None)
        self.env_state_dim = env_state_feature.shape[0] if env_state_feature else 0

        text_config = config.observation_encoder.text
        self.text_dim = config.transformer.hidden_dim
        self.text_encoder = CLIPTextEncoder(model_name=text_config.model, projection_dim=self.text_dim)

        self._setup_vector_output()

    def _apply_preprocessing(self, images: Tensor) -> Tensor:
        """Resize and/or crop ``images``; cropping is random only in training mode."""
        if self.do_resize:
            images = self.resize(images)
        if self.do_crop:
            images = self.maybe_random_crop(images) if self.training else self.center_crop(images)
        return images

    def _setup_preprocessing(self, vision_config):
        """Build the resize and crop transforms configured in ``vision_config``."""
        if vision_config.resize_shape is not None:
            self.do_resize = True
            self.resize = torchvision.transforms.Resize(
                size=vision_config.resize_shape,
                interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
                antialias=True,
            )
        else:
            self.do_resize = False

        if vision_config.crop_shape is not None:
            self.do_crop = True
            self.center_crop = torchvision.transforms.CenterCrop(vision_config.crop_shape)
            if vision_config.crop_is_random:
                self.maybe_random_crop = torchvision.transforms.RandomCrop(vision_config.crop_shape)
            else:
                # Without random cropping, training uses the same center crop as eval.
                self.maybe_random_crop = self.center_crop
        else:
            self.do_crop = False

    def _setup_vector_output(self):
        """Compute ``self.conditioning_dim``, the width of the vector returned by ``encode``."""
        total_dim = 0

        if self.vision_encoder is not None:
            reference_encoder = self.vision_encoder
        elif self.vision_encoders is not None:
            # nn.ModuleList has no .values(); every per-camera encoder is built from
            # the same vision config, so the first one is representative.
            reference_encoder = self.vision_encoders[0]
        else:
            reference_encoder = None

        if reference_encoder is not None:
            c, h, w = reference_encoder.get_output_shape()
            # For CLS-token encoders this is embed_dim * 1 * 1.
            total_dim += c * h * w * self.num_cameras

        total_dim += self.robot_state_dim
        total_dim += self.env_state_dim
        total_dim += self.text_dim

        # Every per-step feature is repeated for each stacked observation step.
        self.conditioning_dim = total_dim * self.config.n_obs_steps

    def _encode_images(self, images: Tensor, batch_size: int, n_obs_steps: int) -> Tensor:
        """Encode stacked camera images into features of shape (B, n_obs_steps, num_cameras * feature_dim)."""
        # Expected layout is (B, n_obs_steps, num_cameras, C, H, W); a 5-D input is
        # treated as (B, num_cameras, C, H, W) with the time axis squeezed out.
        if images.ndim == 5:
            images = images.unsqueeze(1)

        if self.vision_encoders is not None:
            camera_features = []
            for cam_idx, encoder in enumerate(self.vision_encoders):
                cam_images = einops.rearrange(images[:, :, cam_idx], "b s c h w -> (b s) c h w")
                cam_images = self._apply_preprocessing(cam_images)
                cam_feats = encoder(cam_images).flatten(start_dim=1)
                camera_features.append(
                    einops.rearrange(cam_feats, "(b s) f -> b s f", b=batch_size, s=n_obs_steps)
                )
            return torch.cat(camera_features, dim=-1)

        # Shared encoder: fold cameras into the batch axis, then unfold and concatenate.
        images_flat = einops.rearrange(images, "b s n c h w -> (b s n) c h w")
        images_flat = self._apply_preprocessing(images_flat)
        visual_features = self.vision_encoder(images_flat).flatten(start_dim=1)
        return einops.rearrange(
            visual_features, "(b s n) f -> b s (n f)", b=batch_size, s=n_obs_steps, n=self.num_cameras
        )

    def encode(self, batch: dict) -> Tensor:
        """Encode observations into a flat conditioning vector.

        Args:
            batch: Observation dict. ``batch[OBS_STATE]`` is always read (its first two
                dims give batch size and observation steps); ``OBS_IMAGES``,
                ``OBS_ENV_STATE`` and ``"task"`` are read when the corresponding
                component is configured.

        Returns:
            Tensor of shape ``(B, n_obs_steps * per_step_feature_dim)``.
        """
        # assumes OBS_STATE is (B, n_obs_steps, state_dim) — TODO confirm for n_obs_steps=1
        batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
        conditioning_feats = []

        # Only include the state when it is counted in conditioning_dim (mirrors env state).
        if self.robot_state_dim > 0:
            conditioning_feats.append(batch[OBS_STATE])

        if self.vision_encoder is not None or self.vision_encoders is not None:
            conditioning_feats.append(self._encode_images(batch[OBS_IMAGES], batch_size, n_obs_steps))

        if self.env_state_dim > 0 and OBS_ENV_STATE in batch:
            conditioning_feats.append(batch[OBS_ENV_STATE])

        # NOTE(review): conditioning_dim always counts text_dim, so a batch without
        # "task" yields a narrower vector than advertised — confirm callers always set it.
        if self.text_encoder is not None and "task" in batch:
            text_features = self.text_encoder(batch["task"])  # (B, text_dim)
            # Repeat the task embedding across the observation-step axis.
            text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
            conditioning_feats.append(text_features)

        combined_features = torch.cat(conditioning_feats, dim=-1)  # (B, n_obs_steps, total_feature_dim)
        return combined_features.flatten(start_dim=1)  # (B, n_obs_steps * total_feature_dim)