mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
simplify config for multitask dit by merging and flattening everything, then adding comments to denote where some parameters are only used for specific objectives
This commit is contained in:
@@ -68,9 +68,7 @@ class CLIPVisionEncoder(nn.Module):
|
||||
|
||||
|
||||
class CLIPTextEncoder(nn.Module):
|
||||
"""Supports any HuggingFace CLIP model. The encoder weights are frozen,
|
||||
and a learnable projection layer maps the CLIP embeddings to the desired dimension.
|
||||
"""
|
||||
"""CLIP text encoder with frozen weights and a learnable projection layer."""
|
||||
|
||||
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
|
||||
super().__init__()
|
||||
@@ -126,21 +124,20 @@ class ObservationEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
vision_config = config.observation_encoder.vision
|
||||
|
||||
self._setup_preprocessing(vision_config)
|
||||
self._setup_preprocessing(config)
|
||||
|
||||
if config.image_features:
|
||||
self.num_cameras = len(config.image_features)
|
||||
self.camera_names = list(config.image_features.keys()) # Preserve ordering
|
||||
self.camera_names = list(config.image_features.keys())
|
||||
|
||||
if vision_config.use_separate_encoder_per_camera:
|
||||
if config.use_separate_encoder_per_camera:
|
||||
self.vision_encoders = nn.ModuleList(
|
||||
[CLIPVisionEncoder(model_name=vision_config.model_name) for _ in self.camera_names]
|
||||
[CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
|
||||
)
|
||||
self.vision_encoder = None
|
||||
else:
|
||||
self.vision_encoder = CLIPVisionEncoder(model_name=vision_config.model_name)
|
||||
self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name)
|
||||
self.vision_encoders = None
|
||||
else:
|
||||
self.vision_encoder = None
|
||||
@@ -158,9 +155,8 @@ class ObservationEncoder(nn.Module):
|
||||
else:
|
||||
self.env_state_dim = 0
|
||||
|
||||
text_config = config.observation_encoder.text
|
||||
self.text_dim = config.transformer.hidden_dim
|
||||
self.text_encoder = CLIPTextEncoder(model_name=text_config.model, projection_dim=self.text_dim)
|
||||
self.text_dim = config.hidden_dim
|
||||
self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)
|
||||
|
||||
self._setup_vector_output()
|
||||
|
||||
@@ -173,22 +169,23 @@ class ObservationEncoder(nn.Module):
|
||||
|
||||
return images
|
||||
|
||||
def _setup_preprocessing(self, vision_config):
|
||||
def _setup_preprocessing(self, config):
|
||||
"""Setup image preprocessing transforms."""
|
||||
if vision_config.resize_shape is not None:
|
||||
if config.image_resize_shape is not None:
|
||||
self.do_resize = True
|
||||
self.resize = torchvision.transforms.Resize(
|
||||
size=vision_config.resize_shape,
|
||||
size=config.image_resize_shape,
|
||||
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
|
||||
antialias=True,
|
||||
)
|
||||
else:
|
||||
self.do_resize = False
|
||||
if vision_config.crop_shape is not None:
|
||||
|
||||
if config.image_crop_shape is not None:
|
||||
self.do_crop = True
|
||||
self.center_crop = torchvision.transforms.CenterCrop(vision_config.crop_shape)
|
||||
if vision_config.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(vision_config.crop_shape)
|
||||
self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape)
|
||||
if config.image_crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
@@ -199,7 +196,7 @@ class ObservationEncoder(nn.Module):
|
||||
|
||||
# Vision features - get CLS token feature dimension
|
||||
if self.vision_encoder is not None or self.vision_encoders is not None:
|
||||
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders.values()))
|
||||
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders))
|
||||
|
||||
# Get output shape from encoder (deterministic for CLS tokens)
|
||||
feature_map_shape = encoder_to_check.get_output_shape()
|
||||
@@ -233,8 +230,7 @@ class ObservationEncoder(nn.Module):
|
||||
# Shape is (B, N, C, H, W) - add time dimension
|
||||
images = images.unsqueeze(1) # (B, 1, N, C, H, W)
|
||||
|
||||
vision_config = self.config.observation_encoder.vision
|
||||
if vision_config.use_separate_encoder_per_camera:
|
||||
if self.config.use_separate_encoder_per_camera:
|
||||
# Process each camera with its own encoder
|
||||
camera_features = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user