simplify config for multitask dit by merging and flattening everything, then adding comments to denote where some parameters are only used for specific objectives

This commit is contained in:
Bryson Jones
2025-12-10 11:45:59 -08:00
parent cdacc090cd
commit 103230c64c
7 changed files with 242 additions and 454 deletions

View File

@@ -68,9 +68,7 @@ class CLIPVisionEncoder(nn.Module):
class CLIPTextEncoder(nn.Module):
"""Supports any HuggingFace CLIP model. The encoder weights are frozen,
and a learnable projection layer maps the CLIP embeddings to the desired dimension.
"""
"""CLIP text encoder with frozen weights and a learnable projection layer."""
def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
super().__init__()
@@ -126,21 +124,20 @@ class ObservationEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
vision_config = config.observation_encoder.vision
self._setup_preprocessing(vision_config)
self._setup_preprocessing(config)
if config.image_features:
self.num_cameras = len(config.image_features)
self.camera_names = list(config.image_features.keys()) # Preserve ordering
self.camera_names = list(config.image_features.keys())
if vision_config.use_separate_encoder_per_camera:
if config.use_separate_encoder_per_camera:
self.vision_encoders = nn.ModuleList(
[CLIPVisionEncoder(model_name=vision_config.model_name) for _ in self.camera_names]
[CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
)
self.vision_encoder = None
else:
self.vision_encoder = CLIPVisionEncoder(model_name=vision_config.model_name)
self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name)
self.vision_encoders = None
else:
self.vision_encoder = None
@@ -158,9 +155,8 @@ class ObservationEncoder(nn.Module):
else:
self.env_state_dim = 0
text_config = config.observation_encoder.text
self.text_dim = config.transformer.hidden_dim
self.text_encoder = CLIPTextEncoder(model_name=text_config.model, projection_dim=self.text_dim)
self.text_dim = config.hidden_dim
self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)
self._setup_vector_output()
@@ -173,22 +169,23 @@ class ObservationEncoder(nn.Module):
return images
def _setup_preprocessing(self, vision_config):
def _setup_preprocessing(self, config):
"""Setup image preprocessing transforms."""
if vision_config.resize_shape is not None:
if config.image_resize_shape is not None:
self.do_resize = True
self.resize = torchvision.transforms.Resize(
size=vision_config.resize_shape,
size=config.image_resize_shape,
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
antialias=True,
)
else:
self.do_resize = False
if vision_config.crop_shape is not None:
if config.image_crop_shape is not None:
self.do_crop = True
self.center_crop = torchvision.transforms.CenterCrop(vision_config.crop_shape)
if vision_config.crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(vision_config.crop_shape)
self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape)
if config.image_crop_is_random:
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape)
else:
self.maybe_random_crop = self.center_crop
else:
@@ -199,7 +196,7 @@ class ObservationEncoder(nn.Module):
# Vision features - get CLS token feature dimension
if self.vision_encoder is not None or self.vision_encoders is not None:
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders.values()))
encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders))
# Get output shape from encoder (deterministic for CLS tokens)
feature_map_shape = encoder_to_check.get_output_shape()
@@ -233,8 +230,7 @@ class ObservationEncoder(nn.Module):
# Shape is (B, N, C, H, W) - add time dimension
images = images.unsqueeze(1) # (B, 1, N, C, H, W)
vision_config = self.config.observation_encoder.vision
if vision_config.use_separate_encoder_per_camera:
if self.config.use_separate_encoder_per_camera:
# Process each camera with its own encoder
camera_features = []