From 999cc625d6a533a9150b408ea7d4e4a3829c0e68 Mon Sep 17 00:00:00 2001 From: Maxime Ellerbach Date: Wed, 20 May 2026 15:24:59 +0000 Subject: [PATCH] lots of changes to make existing weights work, need to massively refactor the pre and post processing --- src/lerobot/policies/vla_jepa/action_head.py | 70 +-- .../vla_jepa/configuration_vla_jepa.py | 6 + .../vla_jepa/convert_vla_jepa_checkpoints.py | 158 ++++--- .../policies/vla_jepa/modeling_vla_jepa.py | 149 +++++-- .../policies/vla_jepa/processor_vla_jepa.py | 22 +- .../policies/vla_jepa/qwen_interface.py | 5 +- src/lerobot/policies/vla_jepa/world_model.py | 420 +++++++++++++++--- 7 files changed, 612 insertions(+), 218 deletions(-) diff --git a/src/lerobot/policies/vla_jepa/action_head.py b/src/lerobot/policies/vla_jepa/action_head.py index 430c9cfe9..200ecdd91 100644 --- a/src/lerobot/policies/vla_jepa/action_head.py +++ b/src/lerobot/policies/vla_jepa/action_head.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import OrderedDict from dataclasses import dataclass from typing import TYPE_CHECKING @@ -33,17 +34,6 @@ def swish(x: torch.Tensor) -> torch.Tensor: return x * torch.sigmoid(x) -class _MLP2(nn.Module): - """Two-layer GELU MLP with layer1/layer2 attribute names matching the original checkpoint.""" - - def __init__(self, in_dim: int, hidden_dim: int, out_dim: int) -> None: - super().__init__() - self.layer1 = nn.Linear(in_dim, hidden_dim) - self.layer2 = nn.Linear(hidden_dim, out_dim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.layer2(F.gelu(self.layer1(x))) - class SinusoidalPositionalEncoding(nn.Module): def __init__(self, embedding_dim: int): @@ -108,10 +98,11 @@ class BasicTransformerBlock(nn.Module): num_attention_heads: int, attention_head_dim: int, dropout: float, - cross_attention_dim: int | None, + cross_attention_dim: int, + is_cross_attention: bool = True, ) -> None: super().__init__() - self.is_cross_attention = cross_attention_dim is not None + self.is_cross_attention = is_cross_attention self.norm1 = AdaLayerNorm(dim) self.attn1 = Attention( query_dim=dim, @@ -132,7 +123,8 @@ class BasicTransformerBlock(nn.Module): temb: torch.Tensor, ) -> torch.Tensor: attn_input = self.norm1(hidden_states, temb) - hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=encoder_hidden_states) + attention_context = encoder_hidden_states if self.is_cross_attention else None + hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context) hidden_states = hidden_states + self.ff(self.norm2(hidden_states)) return hidden_states @@ -160,10 +152,10 @@ class DiT(ModelMixin, ConfigMixin): num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, dropout=dropout, - # Even blocks attend to context (cross-attention), odd blocks are self-attention. - cross_attention_dim=cross_attention_dim if i % 2 == 0 else None, + cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim, + is_cross_attention=layer_idx % 2 == 0, ) - for i in range(num_layers) + for layer_idx in range(num_layers) ] ) self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False) @@ -179,8 +171,7 @@ class DiT(ModelMixin, ConfigMixin): temb = self.timestep_encoder(timestep) x = hidden_states for block in self.transformer_blocks: - es = encoder_hidden_states if block.is_cross_attention else None - x = block(x, encoder_hidden_states=es, temb=temb) + x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb) shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1) x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None] return self.proj_out_2(x) @@ -205,34 +196,49 @@ class VLAJEPAActionHead(nn.Module): super().__init__() preset = DIT_PRESETS[config.action_model_type] self.config = config - num_heads = preset.num_attention_heads - head_dim = preset.attention_head_dim + num_heads = config.action_num_heads or preset.num_attention_heads + head_dim = config.action_attention_head_dim or preset.attention_head_dim inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768 self.input_embedding_dim = inner_dim self.action_horizon = config.chunk_size self.num_inference_timesteps = config.num_inference_timesteps + hidden_size = config.action_hidden_size self.model = DiT( num_attention_heads=num_heads, attention_head_dim=head_dim, - output_dim=config.action_hidden_size, + output_dim=hidden_size, num_layers=config.action_num_layers, dropout=config.action_dropout, cross_attention_dim=cross_attention_dim, ) - # action_encoder/decoder and state_encoder use action_hidden_size (DiT output dim). - # action_encoder and state_encoder produce inner_dim-sized tokens (DiT input width). - # action_decoder takes DiT output (action_hidden_size) and produces action_dim predictions. self.action_encoder = ActionEncoder(config.action_dim, inner_dim) - self.action_decoder = _MLP2(config.action_hidden_size, config.action_hidden_size, config.action_dim) - self.state_encoder = ( - _MLP2(config.state_dim, config.action_hidden_size, inner_dim) if config.state_dim > 0 else None + self.action_decoder = nn.Sequential( + OrderedDict([ + ("layer1", nn.Linear(hidden_size, hidden_size)), + ("relu", nn.ReLU()), + ("layer2", nn.Linear(hidden_size, config.action_dim)), + ]) + ) + self.state_encoder = ( + nn.Sequential( + OrderedDict([ + ("layer1", nn.Linear(config.state_dim, hidden_size)), + ("relu", nn.ReLU()), + ("layer2", nn.Linear(hidden_size, inner_dim)), + ]) + ) + if config.state_dim > 0 + else None + ) + self.future_tokens = nn.Embedding( + config.num_embodied_action_tokens_per_instruction, inner_dim + ) + self.position_embedding = nn.Embedding( + max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4), + inner_dim, ) - # future_tokens and position_embedding operate at inner_dim (DiT input width), - # not at action_hidden_size (DiT output width). - self.future_tokens = nn.Embedding(config.num_target_vision_tokens, inner_dim) - self.position_embedding = nn.Embedding(config.action_max_seq_len, inner_dim) self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta) def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: diff --git a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py index 9bcff66ea..53dced2cd 100644 --- a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field +from typing import Any from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import NormalizationMode @@ -44,6 +45,8 @@ class VLAJEPAConfig(PreTrainedConfig): action_hidden_size: int = 1024 action_model_type: str = "DiT-B" action_num_layers: int = 16 + action_num_heads: int | None = None + action_attention_head_dim: int | None = None action_dropout: float = 0.2 action_num_timestep_buckets: int = 1000 action_noise_beta_alpha: float = 1.5 @@ -63,6 +66,9 @@ class VLAJEPAConfig(PreTrainedConfig): repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style) resize_images_to: tuple[int, int] | None = None + action_unnormalization_stats: dict[str, Any] | None = None + binarize_gripper_action: bool = True + clip_normalized_actions: bool = True torch_dtype: str = "bfloat16" optimizer_lr: float = 1e-4 diff --git a/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py b/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py index d6f444645..a586dc625 100644 --- a/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py +++ b/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py @@ -25,36 +25,75 @@ Image keys SimplerEnv: OXE Bridge/RT1 are single-camera ✓ confirmed from __future__ import annotations import logging +import tempfile from pathlib import Path import torch from huggingface_hub import HfApi +from safetensors.torch import save_file as save_safetensors +from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors # --------------------------------------------------------------------------- # Top-level settings # --------------------------------------------------------------------------- SOURCE_REPO_ID = "ginwind/VLA-JEPA" -TARGET_ORG = "lerobot" +TARGET_ORG = "maximellerbach" COLLECTION_TITLE = "VLA-JEPA" COLLECTION_DESCRIPTION = ( "VLA-JEPA model checkpoints (LIBERO, Pretrain, SimplerEnv) converted from .pt to safetensors via LeRobot." ) -# Remap state-dict key prefixes before loading into the LeRobot policy. -# E.g. {"": "model."} prepends "model." to every key. -# Leave empty if keys already match — the first run's log will tell you. -KEY_PREFIX_REMAP: dict[str, str] = { - # Specific rules must come before the "" catch-all (dict order is preserved). - "qwen_vl_interface.": "model.qwen.", - "vj_encoder.": "model.video_encoder.", - "vj_predictor.": "model.video_predictor.", - # Everything else (action_model.*) just needs the "model." wrapper. - "": "model.", -} - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") log = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Key mapping — mirrors todo_converter.py map_key() so both converters +# produce identical safetensors layouts that match the LeRobot action_head code. +# --------------------------------------------------------------------------- + + +def _normalize_source_key(key: str) -> str: + return key[len("module."):] if key.startswith("module.") else key + + +def _map_checkpoint_key(raw_key: str) -> str | None: + """Map original VLA-JEPA state-dict keys to LeRobot vla_jepa layout.""" + key = _normalize_source_key(raw_key) + + if key.startswith("qwen_vl_interface."): + return "model.qwen." + key[len("qwen_vl_interface."):] + if key.startswith("vj_encoder."): + return "model.video_encoder." + key[len("vj_encoder."):] + if key.startswith("vj_predictor."): + return "model.video_predictor." + key[len("vj_predictor."):] + if key.startswith("action_model."): + # LeRobot code uses the same sub-key names as the source checkpoint, + # so only the top-level "model." prefix needs to be added. + return "model." + key + + return None + + +def _fetch_action_stats(api: "HfApi", source_repo_id: str, subfolder: str) -> dict | None: + """Try to download dataset_statistics.json and return the action stats dict.""" + import json + + stats_file = f"{subfolder}/dataset_statistics.json" + try: + local = api.hf_hub_download(source_repo_id, stats_file) + data = json.loads(Path(local).read_text()) + # The original repo nests stats under a robot key, e.g. {"franka": {"action": {...}}} + for robot_key in data: + if isinstance(data[robot_key], dict) and "action" in data[robot_key]: + log.info(" Loaded action stats from %s (robot key: %s)", stats_file, robot_key) + return data[robot_key]["action"] + log.warning(" %s found but no 'action' key under any robot — skipping action stats.", stats_file) + except Exception as exc: # noqa: BLE001 + log.warning(" Could not fetch %s: %s — action_unnormalization_stats will be None.", stats_file, exc) + return None + + # --------------------------------------------------------------------------- # Architecture — identical across all 4 variants (from config.json) # --------------------------------------------------------------------------- @@ -69,6 +108,9 @@ _ARCH = { "num_inference_timesteps": 4, "action_hidden_size": 1024, "action_model_type": "DiT-B", + # Explicit dims matching DiT-B preset and ginwind checkpoint shape + "action_num_heads": 12, + "action_attention_head_dim": 64, "action_num_layers": 16, "action_dropout": 0.2, "repeated_diffusion_steps": 8, @@ -76,9 +118,6 @@ _ARCH = { "action_noise_beta_beta": 1.0, "action_noise_s": 0.999, "action_num_timestep_buckets": 1000, - # Action head embedding params (from original config.json) - "num_target_vision_tokens": 32, - "action_max_seq_len": 1024, # World model predictor (12 blocks, confirmed from checkpoint) "predictor_depth": 12, } @@ -109,7 +148,12 @@ _OXE_CAMS = [ # --------------------------------------------------------------------------- -def _build_config(camera_keys: list[str], with_state: bool, enable_world_model: bool = True): +def _build_config( + camera_keys: list[str], + with_state: bool, + enable_world_model: bool = True, + action_stats: dict | None = None, +): from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig @@ -123,6 +167,9 @@ def _build_config(camera_keys: list[str], with_state: bool, enable_world_model: "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), }, enable_world_model=enable_world_model, + action_unnormalization_stats=action_stats, + binarize_gripper_action=True, + clip_normalized_actions=True, **_ARCH, ) cfg.validate_features() @@ -153,20 +200,6 @@ def extract_state_dict(ckpt: object) -> dict[str, torch.Tensor]: return {k: v for k, v in sd.items() if isinstance(v, torch.Tensor)} -def remap_keys(sd: dict[str, torch.Tensor], remap: dict[str, str]) -> dict[str, torch.Tensor]: - if not remap: - return sd - out = {} - for k, v in sd.items(): - new_k = k - for old, new in remap.items(): - if k.startswith(old): - new_k = new + k[len(old) :] - break - out[new_k] = v - return out - - def subfolder_of(pt_path: str) -> str | None: for part in Path(pt_path).parts: if part in VARIANTS: @@ -229,37 +262,48 @@ def main() -> None: ckpt = torch.load(local_pt, map_location="cpu") # nosec B614 sd = extract_state_dict(ckpt) - sd = remap_keys(sd, KEY_PREFIX_REMAP) - log.info(" %d tensors extracted", len(sd)) - log.info(" First 5 keys: %s", list(sd)[:5]) - # 3. Build policy - from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy + # Map source key names → LeRobot layout (handles layer1→w1, transformer_blocks→blocks, etc.) + mapped_sd: dict[str, torch.Tensor] = {} + skipped_keys: list[str] = [] + for raw_key, value in sd.items(): + target_key = _map_checkpoint_key(raw_key) + if target_key is None: + skipped_keys.append(raw_key) + else: + mapped_sd[target_key] = value + log.info(" %d tensors mapped, %d skipped", len(mapped_sd), len(skipped_keys)) + if skipped_keys: + log.info(" Skipped sample: %s", skipped_keys[:5]) + log.info(" First 5 mapped keys: %s", list(mapped_sd)[:5]) - config = _build_config(camera_keys, with_state, enable_world_model) - policy = VLAJEPAPolicy(config) + # Fetch action unnormalization stats from the source repo + action_stats = _fetch_action_stats(api, SOURCE_REPO_ID, subfolder) - # 4. Load weights - missing, unexpected = policy.load_state_dict(sd, strict=False) + # 3. Build config (no policy instantiation — avoids loading backbone from Hub) + config = _build_config(camera_keys, with_state, enable_world_model, action_stats) - def _prefix_summary(keys: list[str]) -> dict[str, int]: - from collections import Counter - - return dict(Counter(".".join(k.split(".")[:3]) for k in keys).most_common()) - - if missing: - log.warning(" Missing (%d) by prefix: %s", len(missing), _prefix_summary(missing)) - if unexpected: - log.warning(" Unexpected (%d) by prefix: %s", len(unexpected), _prefix_summary(unexpected)) - if not missing and not unexpected: - log.info(" State dict loaded cleanly.") - - # 5. Push to hub (writes model.safetensors + config.json) + # 4. Save everything to a temp dir and upload in one shot api.create_repo(target_repo_id, repo_type="model", exist_ok=True) - commit_url = policy.push_to_hub( - repo_id=target_repo_id, - commit_message=f"Convert {Path(pt_filename).name} to safetensors", - ) + with tempfile.TemporaryDirectory() as tmp: + save_dir = Path(tmp) + + log.info(" Saving model.safetensors …") + save_safetensors(mapped_sd, save_dir / "model.safetensors") + + config._save_pretrained(save_dir) # writes config.json via draccus + + preprocessor, postprocessor = make_vla_jepa_pre_post_processors(config) + preprocessor.save_pretrained(save_dir) # writes policy_preprocessor.json + postprocessor.save_pretrained(save_dir) # writes policy_postprocessor.json + + log.info(" Uploading …") + commit_url = api.upload_folder( + folder_path=save_dir, + repo_id=target_repo_id, + repo_type="model", + commit_message=f"Convert {Path(pt_filename).name} to safetensors", + ) log.info(" Uploaded → %s", commit_url) # 6. Add to collection diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index 5f2cf8a9d..5d05e03eb 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -9,6 +9,7 @@ import numpy as np import torch import torch.nn.functional as F # noqa: N812 from PIL import Image +from safetensors.torch import load_file as load_safetensors_file from torch import Tensor, nn from lerobot.policies.pretrained import PreTrainedPolicy, T @@ -72,9 +73,18 @@ class VLAJEPAModel(nn.Module): torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype), ) self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name) - num_views = max(len(config.image_features), 1) + num_views = max(1, len(config.image_features)) + tubelet_size = self.video_encoder.config.tubelet_size + image_size = getattr(self.video_encoder.config, "image_size", None) + if image_size is None: + first_image_shape = next(iter(config.image_features.values())).shape + image_size = first_image_shape[-1] self.video_predictor = ActionConditionedVideoPredictor( - embed_dim=num_views * self.video_encoder.config.hidden_size, + num_frames=config.num_video_frames // tubelet_size, + img_size=(image_size, image_size), + patch_size=16, + tubelet_size=1, + embed_dim=self.video_encoder.config.hidden_size * num_views, action_embed_dim=self.qwen.model.config.hidden_size, predictor_embed_dim=self.video_encoder.config.hidden_size, depth=config.predictor_depth, @@ -91,17 +101,56 @@ class VLAJEPAModel(nn.Module): self.qwen.requires_grad_(False) # Build prompt placeholders. - # Original uses num_frames // tubelet_size - 1 action token groups for the world model predictor. - # This matches the number of context temporal positions after tubelet compression. - n_wm_action_groups = max(1, self.config.num_video_frames // self.config.jepa_tubelet_size - 1) + # Use the encoder's actual tubelet_size when available (world model enabled), + # otherwise fall back to config. + _tubelet_size = ( + self.video_encoder.config.tubelet_size + if config.enable_world_model + else self.config.jepa_tubelet_size + ) + num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1 self.replace_prompt = "".join( token * self.config.num_action_tokens_per_timestep - for token in self.action_tokens[:n_wm_action_groups] + for token in self.action_tokens[:num_action_prompt_steps] ) self.embodied_replace_prompt = ( self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction ) + def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor: + """Return Qwen's final decoder-layer output before the final RMSNorm. + + starVLA trained its downstream heads on the legacy transformers-4.57 + `hidden_states[-1]` value, which is the last decoder layer output before + Qwen's final RMSNorm. Newer transformers versions expose `hidden_states[-1]` + as the post-norm last hidden state, so capture the layer output directly. + """ + captured: dict[str, torch.Tensor] = {} + language_model = self.qwen.model.model.language_model + + def capture_last_layer_output( + _module: nn.Module, + _inputs: tuple[torch.Tensor, ...], + output: torch.Tensor | tuple[torch.Tensor, ...], + ) -> None: + captured["last_hidden"] = output[0] if isinstance(output, tuple) else output + return None + + handle = language_model.layers[-1].register_forward_hook(capture_last_layer_output) + try: + self.qwen.model.model( + **qwen_inputs, + output_hidden_states=False, + output_attentions=False, + return_dict=True, + ) + finally: + handle.remove() + + if "last_hidden" not in captured: + raise RuntimeError("Failed to capture Qwen last decoder hidden states.") + return captured["last_hidden"] + # ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ---- def forward(self, examples: list[dict]) -> dict[str, Tensor]: @@ -160,13 +209,7 @@ class VLAJEPAModel(nn.Module): device_type = next(self.parameters()).device.type with torch.autocast(device_type=device_type, dtype=torch.bfloat16): - qwen_outputs = self.qwen.model( - **qwen_inputs, - output_hidden_states=True, - output_attentions=False, - return_dict=True, - ) - last_hidden = qwen_outputs.hidden_states[-1] # [B, seq_len, H] + last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H] b, _, h = last_hidden.shape if self.config.enable_world_model: @@ -211,20 +254,16 @@ class VLAJEPAModel(nn.Module): input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :] gt_states = video_embeddings[:, tokens_per_frame:, :] - d_emb = input_states.shape[-1] - - input_states_4d = input_states.view(b, t_enc_ctx, tokens_per_frame, d_emb) expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep if action_tokens.shape[1] < expected_actions: pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1) action_tokens = torch.cat([action_tokens, pad], dim=1) - act_4d = action_tokens[:, :expected_actions].view( - b, t_enc_ctx, self.config.num_action_tokens_per_timestep, -1 - ) - pred_4d = self.video_predictor(input_states_4d.float(), act_4d.float()) - predicted_states = pred_4d.reshape(b, -1, d_emb) + predicted_states = self.video_predictor( + input_states.float(), + action_tokens[:, :expected_actions].float(), + ) wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean") @@ -242,15 +281,14 @@ class VLAJEPAModel(nn.Module): state_tensor = None if state is not None: state_tensor = torch.tensor( - np.array(state), device=last_hidden.device, dtype=torch.float32 + np.array(state), device=last_hidden.device, dtype=last_hidden.dtype ) # [B, 1, state_dim] - # repeated_diffusion_steps: draw R independent noise samples per batch item (CogACT-style). - # Effectively multiplies data efficiency of the action head by R with no extra Qwen/JEPA cost. - num_repeated = self.config.repeated_diffusion_steps - embodied_rep = embodied_action_tokens.float().repeat(num_repeated, 1, 1) - actions_rep = actions_target.repeat(num_repeated, 1, 1) - state_rep = state_tensor.repeat(num_repeated, 1, 1) if state_tensor is not None else None + repeated_diffusion_steps = self.config.repeated_diffusion_steps + actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1) + embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1) + if state_tensor is not None: + state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1) action_is_pad_rep = None if action_is_pad is not None: @@ -263,9 +301,11 @@ class VLAJEPAModel(nn.Module): ] ) # [B, T_full] pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon] - action_is_pad_rep = pad_tensor.repeat(num_repeated, 1) # [B*R, action_horizon] + action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon] - action_loss = self.action_model(embodied_rep, actions_rep, state_rep, action_is_pad_rep) + action_loss = self.action_model( + embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep + ) return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight} @@ -289,6 +329,14 @@ class VLAJEPAModel(nn.Module): Returns: np.ndarray [B, action_horizon, action_dim] — predicted actions. """ + if self.config.resize_images_to is not None: + height, width = self.config.resize_images_to + resampling = getattr(Image, "Resampling", Image).BOX + batch_images = [ + [image.resize((width, height), resample=resampling) for image in sample_images] + for sample_images in batch_images + ] + qwen_inputs = self.qwen.build_inputs( images=batch_images, instructions=instructions, @@ -302,27 +350,19 @@ class VLAJEPAModel(nn.Module): device_type = next(self.parameters()).device.type with torch.autocast(device_type=device_type, dtype=torch.bfloat16): - qwen_outputs = self.qwen.model( - **qwen_inputs, - output_hidden_states=True, - output_attentions=False, - return_dict=True, - ) - last_hidden = qwen_outputs.hidden_states[-1] + last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H] b, _, h = last_hidden.shape embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h) state_tensor = None if state is not None: state_tensor = torch.from_numpy(np.array(state)).to( - device=last_hidden.device, dtype=torch.float32 + device=last_hidden.device, dtype=last_hidden.dtype ) - with torch.autocast(device_type=device_type, dtype=torch.float32): - # Cast embodied tokens to float32 for action model compatibility - pred_actions = self.action_model.predict_action( - embodied_action_tokens.float(), state_tensor - ) # [B, action_horizon, action_dim] + pred_actions = self.action_model.predict_action( + embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None + ) # [B, action_horizon, action_dim] return pred_actions.detach().cpu().numpy() @@ -546,8 +586,31 @@ class VLAJEPAPolicy(PreTrainedPolicy): actions_np = self.model.predict_action(batch_images, instructions, state_np) # Convert back to tensor on the right device + actions_np = self._unnormalize_actions(actions_np) return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32) + def _unnormalize_actions(self, normalized_actions: np.ndarray) -> np.ndarray: + """Match starVLA's LIBERO action post-processing exactly.""" + stats = self.config.action_unnormalization_stats + if not stats: + return normalized_actions + + actions = normalized_actions.astype(np.float32, copy=True) + if self.config.clip_normalized_actions: + actions = np.clip(actions, -1.0, 1.0) + + if self.config.binarize_gripper_action and actions.shape[-1] >= 7: + actions[..., 6] = np.where(actions[..., 6] < 0.5, 0.0, 1.0) + + action_min = np.asarray(stats["min"], dtype=np.float32) + action_max = np.asarray(stats["max"], dtype=np.float32) + mask = np.asarray(stats.get("mask", np.ones_like(action_min, dtype=bool)), dtype=bool) + scaled = 0.5 * (actions + 1.0) * (action_max - action_min) + action_min + actions = np.where(mask, scaled, actions).astype(np.float32) + if self.config.binarize_gripper_action and actions.shape[-1] >= 7: + actions[..., 6] = 1.0 - 2.0 * (actions[..., 6] > 0.5) + return actions + @torch.no_grad() def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: """LeRobot select_action with action queue caching.""" diff --git a/src/lerobot/policies/vla_jepa/processor_vla_jepa.py b/src/lerobot/policies/vla_jepa/processor_vla_jepa.py index acd6ea2b6..5aab01c18 100644 --- a/src/lerobot/policies/vla_jepa/processor_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/processor_vla_jepa.py @@ -14,7 +14,6 @@ from lerobot.processor import ( PolicyProcessorPipeline, ProcessorStepRegistry, RenameObservationsProcessorStep, - UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME @@ -31,7 +30,6 @@ def make_vla_jepa_pre_post_processors( input_steps = [ RenameObservationsProcessorStep(rename_map={}), AddBatchDimensionProcessorStep(), - VLAJEPANewLineProcessor(), DeviceProcessorStep(device=config.device), NormalizerProcessorStep( features=features, @@ -40,11 +38,6 @@ def make_vla_jepa_pre_post_processors( ), ] output_steps = [ - UnnormalizerProcessorStep( - features=config.output_features, - norm_map=config.normalization_mapping, - stats=dataset_stats, - ), DeviceProcessorStep(device="cpu"), ] return ( @@ -64,20 +57,7 @@ def make_vla_jepa_pre_post_processors( @ProcessorStepRegistry.register(name="vla_jepa_new_line_processor") class VLAJEPANewLineProcessor(ComplementaryDataProcessorStep): def complementary_data(self, complementary_data): - if "task" not in complementary_data: - return complementary_data - - task = complementary_data["task"] - if task is None: - return complementary_data - - new_complementary_data = dict(complementary_data) - if isinstance(task, str): - if not task.endswith("\n"): - new_complementary_data["task"] = f"{task}\n" - elif isinstance(task, list) and all(isinstance(t, str) for t in task): - new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task] - return new_complementary_data + return complementary_data def transform_features(self, features): return features diff --git a/src/lerobot/policies/vla_jepa/qwen_interface.py b/src/lerobot/policies/vla_jepa/qwen_interface.py index c4cf64ab9..1031f837b 100644 --- a/src/lerobot/policies/vla_jepa/qwen_interface.py +++ b/src/lerobot/policies/vla_jepa/qwen_interface.py @@ -39,7 +39,10 @@ class Qwen3VLInterface(torch.nn.Module): return torch.bfloat16 def expand_tokenizer(self) -> tuple[list[str], list[int], int]: - max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep + # starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4, + # independent of vj2 num_action_tokens_per_timestep. Keeping this count + # is required for Qwen embedding/lm_head checkpoint shapes to match. + max_action_tokens = self.config.chunk_size * 4 tokenizer = self.processor.tokenizer action_tokens = [] action_token_ids = [] diff --git a/src/lerobot/policies/vla_jepa/world_model.py b/src/lerobot/policies/vla_jepa/world_model.py index 4a398e7df..1df495e82 100644 --- a/src/lerobot/policies/vla_jepa/world_model.py +++ b/src/lerobot/policies/vla_jepa/world_model.py @@ -5,59 +5,298 @@ import torch.nn.functional as F # noqa: N812 from torch import nn -def build_block_causal_attention_mask(num_steps: int, tokens_per_step: int, cond_tokens: int) -> torch.Tensor: - total_tokens = num_steps * (tokens_per_step + cond_tokens) - mask = torch.full((total_tokens, total_tokens), float("-inf")) - for current_step in range(num_steps): - row_start = current_step * (tokens_per_step + cond_tokens) - row_end = row_start + tokens_per_step + cond_tokens - mask[row_start:row_end, :row_end] = 0 +def build_action_block_causal_attention_mask( + num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1 +) -> torch.Tensor: + tokens_per_frame = add_tokens + grid_height * grid_width + num_tokens = num_frames * tokens_per_frame + mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool) + mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool) + local_window_time = num_frames + + for current_frame in range(num_frames): + first_context_frame = max(0, current_frame - local_window_time + 1) + for context_frame in range(first_context_frame, current_frame + 1): + row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame) + col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame) + mask[row, col] = mask_block return mask -class _Attention(nn.Module): - def __init__(self, embed_dim: int, num_heads: int) -> None: +def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: + _, _, _, dim = x.size() + if dim % 2 != 0: + raise ValueError("Embedding dimension must be even for rotary position encoding.") + + omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device) + omega /= dim / 2.0 + omega = 1.0 / 10000**omega + freqs = torch.einsum("..., f -> ... f", pos, omega) + emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2) + emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2) + + y = x.unflatten(-1, (-1, 2)) + y1, y2 = y.unbind(dim=-1) + y = torch.stack((-y2, y1), dim=-1).flatten(-2) + return x * emb_cos + y * emb_sin + + +class DropPath(nn.Module): + def __init__(self, drop_prob: float = 0.0) -> None: super().__init__() - self.num_heads = num_heads - self.head_dim = embed_dim // num_heads - self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True) - self.proj = nn.Linear(embed_dim, embed_dim, bias=True) - - def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor: - b, n, c = x.shape - qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) - x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) - return self.proj(x.transpose(1, 2).reshape(b, n, c)) - - -class _MLP(nn.Module): - def __init__(self, embed_dim: int, mlp_ratio: float) -> None: - super().__init__() - hidden = int(embed_dim * mlp_ratio) - self.fc1 = nn.Linear(embed_dim, hidden) - self.fc2 = nn.Linear(hidden, embed_dim) + self.drop_prob = drop_prob def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.fc2(F.gelu(self.fc1(x))) + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() + return x.div(keep_prob) * random_tensor -class _PredictorBlock(nn.Module): - def __init__(self, embed_dim: int, num_heads: int, mlp_ratio: float) -> None: +class MLP(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: int | None = None, + out_features: int | None = None, + act_layer: type[nn.Module] = nn.GELU, + drop: float = 0.0, + ) -> None: super().__init__() - self.norm1 = nn.LayerNorm(embed_dim) - self.attn = _Attention(embed_dim, num_heads) - self.norm2 = nn.LayerNorm(embed_dim) - self.mlp = _MLP(embed_dim, mlp_ratio) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) - def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor: - x = x + self.attn(self.norm1(x), attn_mask=attn_mask) - return x + self.mlp(self.norm2(x)) + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class ACRoPEAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: float | None = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + use_sdpa: bool = True, + is_causal: bool = False, + grid_size: int = 16, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = qk_scale or self.head_dim**-0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop_prob = proj_drop + self.proj_drop = nn.Dropout(proj_drop) + self.use_sdpa = use_sdpa + self.d_dim = int(2 * ((self.head_dim // 3) // 2)) + self.h_dim = int(2 * ((self.head_dim // 3) // 2)) + self.w_dim = int(2 * ((self.head_dim // 3) // 2)) + self.grid_size = grid_size + self.is_causal = is_causal + + @staticmethod + def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor: + return ids // int(height * width) + + def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor: + frame_ids = self._get_frame_pos(ids, height, width) + ids = ids - int(height * width) * frame_ids + return ids // width + + def separate_positions( + self, ids: torch.Tensor, height: int, width: int + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + frame_ids = self._get_frame_pos(ids, height, width) + height_ids = self._get_height_pos(ids, height, width) + width_ids = ids - int(height * width) * frame_ids - width * height_ids + return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor | None = None, + attn_mask: torch.Tensor | None = None, + num_frames: int | None = None, + grid_height: int | None = None, + grid_width: int | None = None, + action_tokens: int = 0, + ) -> torch.Tensor: + batch_size, num_tokens, channels = x.size() + if num_frames is None or grid_height is None or grid_width is None: + raise ValueError("num_frames, grid_height and grid_width are required.") + + if mask is not None: + mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1) + d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width) + else: + mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device) + d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width) + + h_mask *= self.grid_size / grid_height + w_mask *= self.grid_size / grid_width + + if action_tokens > 0: + x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels) + action_q, action_k, action_v = [], [], [] + for idx in range(action_tokens): + action_token = x[:, :, idx : idx + 1, :].flatten(1, 2) + qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + qd = rotate_queries_or_keys(q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)) + kd = rotate_queries_or_keys(k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device)) + qr = q[..., self.d_dim :] + kr = k[..., self.d_dim :] + action_q.append(torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)) + action_k.append(torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1)) + action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1)) + + action_q = torch.cat(action_q, dim=3).flatten(2, 3) + action_k = torch.cat(action_k, dim=3).flatten(2, 3) + action_v = torch.cat(action_v, dim=3).flatten(2, 3) + x = x[:, :, action_tokens:, :].flatten(1, 2) + + qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + offset = 0 + qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask) + kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask) + offset += self.d_dim + qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask) + kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask) + offset += self.h_dim + qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask) + kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask) + offset += self.w_dim + + if offset < self.head_dim: + q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1) + k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1) + else: + q = torch.cat([qd, qh, qw], dim=-1) + k = torch.cat([kd, kh, kw], dim=-1) + + if action_tokens > 0: + + def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor: + frame_tokens = frame_tokens.view( + batch_size, self.num_heads, num_frames, grid_height * grid_width, -1 + ) + action_token_values = action_token_values.view( + batch_size, self.num_heads, num_frames, action_tokens, -1 + ) + return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3) + + q = merge(q, action_q) + k = merge(k, action_k) + v = merge(v, action_v) + + if attn_mask is not None or self.use_sdpa: + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask + ) + else: + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels) + x = self.proj(x) + return self.proj_drop(x) + + +class ACBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_scale: float | None = None, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + norm_layer: type[nn.Module] = nn.LayerNorm, + use_sdpa: bool = True, + is_causal: bool = False, + grid_size: int = 16, + use_rope: bool = True, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + if not use_rope: + raise ValueError("JEVLA1 world predictor uses AC RoPE attention.") + self.attn = ACRoPEAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + use_sdpa=use_sdpa, + is_causal=is_causal, + grid_size=grid_size, + proj_drop=drop, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp = MLP( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=nn.GELU, + drop=drop, + ) + + def forward( + self, + x: torch.Tensor, + attn_mask: torch.Tensor | None = None, + num_frames: int | None = None, + grid_height: int | None = None, + grid_width: int | None = None, + action_tokens: int = 0, + ) -> torch.Tensor: + y = self.norm1(x) + y = self.attn( + y, + mask=None, + attn_mask=attn_mask, + num_frames=num_frames, + grid_height=grid_height, + grid_width=grid_width, + action_tokens=action_tokens, + ) + x = x + self.drop_path(y) + y = self.norm2(x) + return x + self.drop_path(self.mlp(y)) class ActionConditionedVideoPredictor(nn.Module): + """JEVLA1-compatible action-conditioned V-JEPA predictor.""" + def __init__( self, + num_frames: int, + img_size: tuple[int, int], + patch_size: int, + tubelet_size: int, embed_dim: int, action_embed_dim: int, predictor_embed_dim: int, @@ -65,40 +304,93 @@ class ActionConditionedVideoPredictor(nn.Module): num_heads: int, mlp_ratio: float, num_action_tokens_per_step: int, + use_extrinsics: bool = False, ) -> None: super().__init__() - self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim) - self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim) + self.is_frame_causal = True + self.use_extrinsics = use_extrinsics + self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True) + self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True) + self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True) + self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True) + + self.img_height, self.img_width = img_size + self.patch_size = patch_size + self.num_frames = num_frames + self.tubelet_size = tubelet_size + self.grid_height = self.img_height // self.patch_size + self.grid_width = self.img_width // self.patch_size + self.predictor_blocks = nn.ModuleList( - [_PredictorBlock(predictor_embed_dim, num_heads, mlp_ratio) for _ in range(depth)] + [ + ACBlock( + dim=predictor_embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=True, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6), + grid_size=self.grid_height, + use_rope=True, + ) + for _ in range(depth) + ] ) - self.predictor_norm = nn.LayerNorm(predictor_embed_dim) - self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim) + self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6) + self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True) self.num_action_tokens_per_step = num_action_tokens_per_step - def forward(self, frame_tokens: torch.Tensor, action_tokens: torch.Tensor) -> torch.Tensor: - batch_size, num_steps, tokens_per_frame, _ = frame_tokens.shape - _, action_steps, _, _ = action_tokens.shape - if action_steps != num_steps: - raise ValueError(f"Expected {num_steps} action steps, got {action_steps}.") + @property + def norm(self) -> nn.LayerNorm: + return self.predictor_norm - frame_tokens = self.predictor_embed(frame_tokens) - action_tokens = self.action_encoder(action_tokens) - fused_steps = [ - torch.cat([action_tokens[:, step], frame_tokens[:, step]], dim=1) for step in range(num_steps) - ] - fused = torch.cat(fused_steps, dim=1) + @property + def proj(self) -> nn.Linear: + return self.predictor_proj - attn_mask = build_block_causal_attention_mask( - num_steps=num_steps, - tokens_per_step=tokens_per_frame, - cond_tokens=self.num_action_tokens_per_step, - ).to(device=fused.device, dtype=fused.dtype) + def forward( + self, + frame_tokens: torch.Tensor, + action_tokens: torch.Tensor, + extrinsics: torch.Tensor | None = None, + ) -> torch.Tensor: + # starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D]. + x = self.predictor_embed(frame_tokens) + batch_size, num_context_tokens, hidden_dim = x.size() + num_frames = num_context_tokens // (self.grid_height * self.grid_width) + + actions = self.action_encoder(action_tokens) + actions = actions.view(batch_size, num_frames, -1, hidden_dim) + cond_tokens = actions.shape[2] + + x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim) + if self.use_extrinsics: + if extrinsics is None: + raise ValueError("extrinsics are required when use_extrinsics=True.") + cond_tokens += 1 + extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2) + x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2) + else: + x = torch.cat([actions, x], dim=2).flatten(1, 2) + + attn_mask = build_action_block_causal_attention_mask( + num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens + ) + attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True) for block in self.predictor_blocks: - fused = block(fused, attn_mask=attn_mask) + x = block( + x, + attn_mask=attn_mask, + num_frames=num_frames, + grid_height=self.grid_height, + grid_width=self.grid_width, + action_tokens=cond_tokens, + ) - fused = self.predictor_norm(fused) - fused = fused.view(batch_size, num_steps, self.num_action_tokens_per_step + tokens_per_frame, -1) - predicted_frame_tokens = fused[:, :, self.num_action_tokens_per_step :, :] - return self.predictor_proj(predicted_frame_tokens) + x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim) + x = x[:, :, cond_tokens:, :].flatten(1, 2) + x = self.predictor_norm(x) + return self.predictor_proj(x)