Lots of changes to make the existing weights work; the pre- and post-processing still needs a massive refactor.

This commit is contained in:
Maxime Ellerbach
2026-05-20 15:24:59 +00:00
committed by Maximellerbach
parent c6bf11b2d5
commit 999cc625d6
7 changed files with 612 additions and 218 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from collections import OrderedDict
from dataclasses import dataclass
from typing import TYPE_CHECKING
@@ -33,17 +34,6 @@ def swish(x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
class _MLP2(nn.Module):
"""Two-layer GELU MLP with layer1/layer2 attribute names matching the original checkpoint."""
def __init__(self, in_dim: int, hidden_dim: int, out_dim: int) -> None:
super().__init__()
self.layer1 = nn.Linear(in_dim, hidden_dim)
self.layer2 = nn.Linear(hidden_dim, out_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.layer2(F.gelu(self.layer1(x)))
class SinusoidalPositionalEncoding(nn.Module):
def __init__(self, embedding_dim: int):
@@ -108,10 +98,11 @@ class BasicTransformerBlock(nn.Module):
num_attention_heads: int,
attention_head_dim: int,
dropout: float,
cross_attention_dim: int | None,
cross_attention_dim: int,
is_cross_attention: bool = True,
) -> None:
super().__init__()
self.is_cross_attention = cross_attention_dim is not None
self.is_cross_attention = is_cross_attention
self.norm1 = AdaLayerNorm(dim)
self.attn1 = Attention(
query_dim=dim,
@@ -132,7 +123,8 @@ class BasicTransformerBlock(nn.Module):
temb: torch.Tensor,
) -> torch.Tensor:
attn_input = self.norm1(hidden_states, temb)
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=encoder_hidden_states)
attention_context = encoder_hidden_states if self.is_cross_attention else None
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context)
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
return hidden_states
@@ -160,10 +152,10 @@ class DiT(ModelMixin, ConfigMixin):
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
# Even blocks attend to context (cross-attention), odd blocks are self-attention.
cross_attention_dim=cross_attention_dim if i % 2 == 0 else None,
cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim,
is_cross_attention=layer_idx % 2 == 0,
)
for i in range(num_layers)
for layer_idx in range(num_layers)
]
)
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
@@ -179,8 +171,7 @@ class DiT(ModelMixin, ConfigMixin):
temb = self.timestep_encoder(timestep)
x = hidden_states
for block in self.transformer_blocks:
es = encoder_hidden_states if block.is_cross_attention else None
x = block(x, encoder_hidden_states=es, temb=temb)
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
return self.proj_out_2(x)
@@ -205,34 +196,49 @@ class VLAJEPAActionHead(nn.Module):
super().__init__()
preset = DIT_PRESETS[config.action_model_type]
self.config = config
num_heads = preset.num_attention_heads
head_dim = preset.attention_head_dim
num_heads = config.action_num_heads or preset.num_attention_heads
head_dim = config.action_attention_head_dim or preset.attention_head_dim
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
self.input_embedding_dim = inner_dim
self.action_horizon = config.chunk_size
self.num_inference_timesteps = config.num_inference_timesteps
hidden_size = config.action_hidden_size
self.model = DiT(
num_attention_heads=num_heads,
attention_head_dim=head_dim,
output_dim=config.action_hidden_size,
output_dim=hidden_size,
num_layers=config.action_num_layers,
dropout=config.action_dropout,
cross_attention_dim=cross_attention_dim,
)
# action_encoder/decoder and state_encoder use action_hidden_size (DiT output dim).
# action_encoder and state_encoder produce inner_dim-sized tokens (DiT input width).
# action_decoder takes DiT output (action_hidden_size) and produces action_dim predictions.
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
self.action_decoder = _MLP2(config.action_hidden_size, config.action_hidden_size, config.action_dim)
self.state_encoder = (
_MLP2(config.state_dim, config.action_hidden_size, inner_dim) if config.state_dim > 0 else None
self.action_decoder = nn.Sequential(
OrderedDict([
("layer1", nn.Linear(hidden_size, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, config.action_dim)),
])
)
self.state_encoder = (
nn.Sequential(
OrderedDict([
("layer1", nn.Linear(config.state_dim, hidden_size)),
("relu", nn.ReLU()),
("layer2", nn.Linear(hidden_size, inner_dim)),
])
)
if config.state_dim > 0
else None
)
self.future_tokens = nn.Embedding(
config.num_embodied_action_tokens_per_instruction, inner_dim
)
self.position_embedding = nn.Embedding(
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
inner_dim,
)
# future_tokens and position_embedding operate at inner_dim (DiT input width),
# not at action_hidden_size (DiT output width).
self.future_tokens = nn.Embedding(config.num_target_vision_tokens, inner_dim)
self.position_embedding = nn.Embedding(config.action_max_seq_len, inner_dim)
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@@ -44,6 +45,8 @@ class VLAJEPAConfig(PreTrainedConfig):
action_hidden_size: int = 1024
action_model_type: str = "DiT-B"
action_num_layers: int = 16
action_num_heads: int | None = None
action_attention_head_dim: int | None = None
action_dropout: float = 0.2
action_num_timestep_buckets: int = 1000
action_noise_beta_alpha: float = 1.5
@@ -63,6 +66,9 @@ class VLAJEPAConfig(PreTrainedConfig):
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
resize_images_to: tuple[int, int] | None = None
action_unnormalization_stats: dict[str, Any] | None = None
binarize_gripper_action: bool = True
clip_normalized_actions: bool = True
torch_dtype: str = "bfloat16"
optimizer_lr: float = 1e-4

View File

@@ -25,36 +25,75 @@ Image keys SimplerEnv: OXE Bridge/RT1 are single-camera ✓ confirmed
from __future__ import annotations
import logging
import tempfile
from pathlib import Path
import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file as save_safetensors
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
# ---------------------------------------------------------------------------
# Top-level settings
# ---------------------------------------------------------------------------
SOURCE_REPO_ID = "ginwind/VLA-JEPA"
TARGET_ORG = "lerobot"
TARGET_ORG = "maximellerbach"
COLLECTION_TITLE = "VLA-JEPA"
COLLECTION_DESCRIPTION = (
"VLA-JEPA model checkpoints (LIBERO, Pretrain, SimplerEnv) converted from .pt to safetensors via LeRobot."
)
# Remap state-dict key prefixes before loading into the LeRobot policy.
# E.g. {"": "model."} prepends "model." to every key.
# Leave empty if keys already match — the first run's log will tell you.
KEY_PREFIX_REMAP: dict[str, str] = {
# Specific rules must come before the "" catch-all (dict order is preserved).
"qwen_vl_interface.": "model.qwen.",
"vj_encoder.": "model.video_encoder.",
"vj_predictor.": "model.video_predictor.",
# Everything else (action_model.*) just needs the "model." wrapper.
"": "model.",
}
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Key mapping — mirrors todo_converter.py map_key() so both converters
# produce identical safetensors layouts that match the LeRobot action_head code.
# ---------------------------------------------------------------------------
def _normalize_source_key(key: str) -> str:
return key[len("module."):] if key.startswith("module.") else key
def _map_checkpoint_key(raw_key: str) -> str | None:
    """Map an original VLA-JEPA state-dict key to the LeRobot vla_jepa layout.

    The key is first passed through ``_normalize_source_key``; its top-level
    prefix is then rewritten according to the table below.

    Returns:
        The remapped key, or ``None`` when the key starts with none of the
        known prefixes.
    """
    key = _normalize_source_key(raw_key)
    # (source prefix, target prefix); the first matching entry wins.
    prefix_table = (
        ("qwen_vl_interface.", "model.qwen."),
        ("vj_encoder.", "model.video_encoder."),
        ("vj_predictor.", "model.video_predictor."),
        # The action model keeps its sub-key names, so only the top-level
        # "model." wrapper is prepended.
        ("action_model.", "model.action_model."),
    )
    for source_prefix, target_prefix in prefix_table:
        if key.startswith(source_prefix):
            return target_prefix + key[len(source_prefix):]
    return None
def _fetch_action_stats(api: "HfApi", source_repo_id: str, subfolder: str) -> dict | None:
    """Best-effort download of ``<subfolder>/dataset_statistics.json``.

    Args:
        api: Hub client; only ``api.hf_hub_download(repo_id, filename)`` is used.
        source_repo_id: Repository that holds the statistics file.
        subfolder: Folder inside the repository containing the file.

    Returns:
        The ``"action"`` entry of the first top-level value that is a dict
        containing an ``"action"`` key, or ``None`` when the file cannot be
        downloaded/parsed or holds no such entry. Failures are logged as
        warnings and never raised.
    """
    import json

    stats_file = f"{subfolder}/dataset_statistics.json"
    # Keep the try body limited to download + parse so that only genuine
    # fetch/decode failures are reported as "Could not fetch".
    try:
        local = api.hf_hub_download(source_repo_id, stats_file)
        # Decode explicitly as UTF-8 instead of the platform default encoding.
        data = json.loads(Path(local).read_text(encoding="utf-8"))
    except Exception as exc:  # noqa: BLE001 - deliberate best-effort boundary
        log.warning(" Could not fetch %s: %s — action_unnormalization_stats will be None.", stats_file, exc)
        return None

    # The original repo nests stats under a robot key, e.g. {"franka": {"action": {...}}}
    # A non-dict payload (e.g. a JSON list) simply has no robot entries.
    if isinstance(data, dict):
        for robot_key, robot_stats in data.items():
            if isinstance(robot_stats, dict) and "action" in robot_stats:
                log.info(" Loaded action stats from %s (robot key: %s)", stats_file, robot_key)
                return robot_stats["action"]
    log.warning(" %s found but no 'action' key under any robot — skipping action stats.", stats_file)
    return None
# ---------------------------------------------------------------------------
# Architecture — identical across all 4 variants (from config.json)
# ---------------------------------------------------------------------------
@@ -69,6 +108,9 @@ _ARCH = {
"num_inference_timesteps": 4,
"action_hidden_size": 1024,
"action_model_type": "DiT-B",
# Explicit dims matching DiT-B preset and ginwind checkpoint shape
"action_num_heads": 12,
"action_attention_head_dim": 64,
"action_num_layers": 16,
"action_dropout": 0.2,
"repeated_diffusion_steps": 8,
@@ -76,9 +118,6 @@ _ARCH = {
"action_noise_beta_beta": 1.0,
"action_noise_s": 0.999,
"action_num_timestep_buckets": 1000,
# Action head embedding params (from original config.json)
"num_target_vision_tokens": 32,
"action_max_seq_len": 1024,
# World model predictor (12 blocks, confirmed from checkpoint)
"predictor_depth": 12,
}
@@ -109,7 +148,12 @@ _OXE_CAMS = [
# ---------------------------------------------------------------------------
def _build_config(camera_keys: list[str], with_state: bool, enable_world_model: bool = True):
def _build_config(
camera_keys: list[str],
with_state: bool,
enable_world_model: bool = True,
action_stats: dict | None = None,
):
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
@@ -123,6 +167,9 @@ def _build_config(camera_keys: list[str], with_state: bool, enable_world_model:
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
},
enable_world_model=enable_world_model,
action_unnormalization_stats=action_stats,
binarize_gripper_action=True,
clip_normalized_actions=True,
**_ARCH,
)
cfg.validate_features()
@@ -153,20 +200,6 @@ def extract_state_dict(ckpt: object) -> dict[str, torch.Tensor]:
return {k: v for k, v in sd.items() if isinstance(v, torch.Tensor)}
def remap_keys(sd: dict[str, torch.Tensor], remap: dict[str, str]) -> dict[str, torch.Tensor]:
if not remap:
return sd
out = {}
for k, v in sd.items():
new_k = k
for old, new in remap.items():
if k.startswith(old):
new_k = new + k[len(old) :]
break
out[new_k] = v
return out
def subfolder_of(pt_path: str) -> str | None:
for part in Path(pt_path).parts:
if part in VARIANTS:
@@ -229,37 +262,48 @@ def main() -> None:
ckpt = torch.load(local_pt, map_location="cpu") # nosec B614
sd = extract_state_dict(ckpt)
sd = remap_keys(sd, KEY_PREFIX_REMAP)
log.info(" %d tensors extracted", len(sd))
log.info(" First 5 keys: %s", list(sd)[:5])
# 3. Build policy
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
# Map source key names → LeRobot layout (strips a leading "module." and rewrites the top-level prefix; keys with no known prefix are skipped).
mapped_sd: dict[str, torch.Tensor] = {}
skipped_keys: list[str] = []
for raw_key, value in sd.items():
target_key = _map_checkpoint_key(raw_key)
if target_key is None:
skipped_keys.append(raw_key)
else:
mapped_sd[target_key] = value
log.info(" %d tensors mapped, %d skipped", len(mapped_sd), len(skipped_keys))
if skipped_keys:
log.info(" Skipped sample: %s", skipped_keys[:5])
log.info(" First 5 mapped keys: %s", list(mapped_sd)[:5])
config = _build_config(camera_keys, with_state, enable_world_model)
policy = VLAJEPAPolicy(config)
# Fetch action unnormalization stats from the source repo
action_stats = _fetch_action_stats(api, SOURCE_REPO_ID, subfolder)
# 4. Load weights
missing, unexpected = policy.load_state_dict(sd, strict=False)
# 3. Build config (no policy instantiation — avoids loading backbone from Hub)
config = _build_config(camera_keys, with_state, enable_world_model, action_stats)
def _prefix_summary(keys: list[str]) -> dict[str, int]:
from collections import Counter
return dict(Counter(".".join(k.split(".")[:3]) for k in keys).most_common())
if missing:
log.warning(" Missing (%d) by prefix: %s", len(missing), _prefix_summary(missing))
if unexpected:
log.warning(" Unexpected (%d) by prefix: %s", len(unexpected), _prefix_summary(unexpected))
if not missing and not unexpected:
log.info(" State dict loaded cleanly.")
# 5. Push to hub (writes model.safetensors + config.json)
# 4. Save everything to a temp dir and upload in one shot
api.create_repo(target_repo_id, repo_type="model", exist_ok=True)
commit_url = policy.push_to_hub(
repo_id=target_repo_id,
commit_message=f"Convert {Path(pt_filename).name} to safetensors",
)
with tempfile.TemporaryDirectory() as tmp:
save_dir = Path(tmp)
log.info(" Saving model.safetensors …")
save_safetensors(mapped_sd, save_dir / "model.safetensors")
config._save_pretrained(save_dir) # writes config.json via draccus
preprocessor, postprocessor = make_vla_jepa_pre_post_processors(config)
preprocessor.save_pretrained(save_dir) # writes policy_preprocessor.json
postprocessor.save_pretrained(save_dir) # writes policy_postprocessor.json
log.info(" Uploading …")
commit_url = api.upload_folder(
folder_path=save_dir,
repo_id=target_repo_id,
repo_type="model",
commit_message=f"Convert {Path(pt_filename).name} to safetensors",
)
log.info(" Uploaded → %s", commit_url)
# 6. Add to collection

View File

@@ -9,6 +9,7 @@ import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from PIL import Image
from safetensors.torch import load_file as load_safetensors_file
from torch import Tensor, nn
from lerobot.policies.pretrained import PreTrainedPolicy, T
@@ -72,9 +73,18 @@ class VLAJEPAModel(nn.Module):
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
)
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
num_views = max(len(config.image_features), 1)
num_views = max(1, len(config.image_features))
tubelet_size = self.video_encoder.config.tubelet_size
image_size = getattr(self.video_encoder.config, "image_size", None)
if image_size is None:
first_image_shape = next(iter(config.image_features.values())).shape
image_size = first_image_shape[-1]
self.video_predictor = ActionConditionedVideoPredictor(
embed_dim=num_views * self.video_encoder.config.hidden_size,
num_frames=config.num_video_frames // tubelet_size,
img_size=(image_size, image_size),
patch_size=16,
tubelet_size=1,
embed_dim=self.video_encoder.config.hidden_size * num_views,
action_embed_dim=self.qwen.model.config.hidden_size,
predictor_embed_dim=self.video_encoder.config.hidden_size,
depth=config.predictor_depth,
@@ -91,17 +101,56 @@ class VLAJEPAModel(nn.Module):
self.qwen.requires_grad_(False)
# Build prompt placeholders.
# Original uses num_frames // tubelet_size - 1 action token groups for the world model predictor.
# This matches the number of context temporal positions after tubelet compression.
n_wm_action_groups = max(1, self.config.num_video_frames // self.config.jepa_tubelet_size - 1)
# Use the encoder's actual tubelet_size when available (world model enabled),
# otherwise fall back to config.
_tubelet_size = (
self.video_encoder.config.tubelet_size
if config.enable_world_model
else self.config.jepa_tubelet_size
)
num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1
self.replace_prompt = "".join(
token * self.config.num_action_tokens_per_timestep
for token in self.action_tokens[:n_wm_action_groups]
for token in self.action_tokens[:num_action_prompt_steps]
)
self.embodied_replace_prompt = (
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
)
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return Qwen's final decoder-layer output before the final RMSNorm.
starVLA trained its downstream heads on the legacy transformers-4.57
`hidden_states[-1]` value, which is the last decoder layer output before
Qwen's final RMSNorm. Newer transformers versions expose `hidden_states[-1]`
as the post-norm last hidden state, so capture the layer output directly.
"""
captured: dict[str, torch.Tensor] = {}
language_model = self.qwen.model.model.language_model
def capture_last_layer_output(
_module: nn.Module,
_inputs: tuple[torch.Tensor, ...],
output: torch.Tensor | tuple[torch.Tensor, ...],
) -> None:
captured["last_hidden"] = output[0] if isinstance(output, tuple) else output
return None
handle = language_model.layers[-1].register_forward_hook(capture_last_layer_output)
try:
self.qwen.model.model(
**qwen_inputs,
output_hidden_states=False,
output_attentions=False,
return_dict=True,
)
finally:
handle.remove()
if "last_hidden" not in captured:
raise RuntimeError("Failed to capture Qwen last decoder hidden states.")
return captured["last_hidden"]
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
@@ -160,13 +209,7 @@ class VLAJEPAModel(nn.Module):
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
qwen_outputs = self.qwen.model(
**qwen_inputs,
output_hidden_states=True,
output_attentions=False,
return_dict=True,
)
last_hidden = qwen_outputs.hidden_states[-1] # [B, seq_len, H]
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
if self.config.enable_world_model:
@@ -211,20 +254,16 @@ class VLAJEPAModel(nn.Module):
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
gt_states = video_embeddings[:, tokens_per_frame:, :]
d_emb = input_states.shape[-1]
input_states_4d = input_states.view(b, t_enc_ctx, tokens_per_frame, d_emb)
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_actions:
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1)
act_4d = action_tokens[:, :expected_actions].view(
b, t_enc_ctx, self.config.num_action_tokens_per_timestep, -1
)
pred_4d = self.video_predictor(input_states_4d.float(), act_4d.float())
predicted_states = pred_4d.reshape(b, -1, d_emb)
predicted_states = self.video_predictor(
input_states.float(),
action_tokens[:, :expected_actions].float(),
)
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
@@ -242,15 +281,14 @@ class VLAJEPAModel(nn.Module):
state_tensor = None
if state is not None:
state_tensor = torch.tensor(
np.array(state), device=last_hidden.device, dtype=torch.float32
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
) # [B, 1, state_dim]
# repeated_diffusion_steps: draw R independent noise samples per batch item (CogACT-style).
# Effectively multiplies data efficiency of the action head by R with no extra Qwen/JEPA cost.
num_repeated = self.config.repeated_diffusion_steps
embodied_rep = embodied_action_tokens.float().repeat(num_repeated, 1, 1)
actions_rep = actions_target.repeat(num_repeated, 1, 1)
state_rep = state_tensor.repeat(num_repeated, 1, 1) if state_tensor is not None else None
repeated_diffusion_steps = self.config.repeated_diffusion_steps
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
if state_tensor is not None:
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
action_is_pad_rep = None
if action_is_pad is not None:
@@ -263,9 +301,11 @@ class VLAJEPAModel(nn.Module):
]
) # [B, T_full]
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
action_is_pad_rep = pad_tensor.repeat(num_repeated, 1) # [B*R, action_horizon]
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
action_loss = self.action_model(embodied_rep, actions_rep, state_rep, action_is_pad_rep)
action_loss = self.action_model(
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
)
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
@@ -289,6 +329,14 @@ class VLAJEPAModel(nn.Module):
Returns:
np.ndarray [B, action_horizon, action_dim] — predicted actions.
"""
if self.config.resize_images_to is not None:
height, width = self.config.resize_images_to
resampling = getattr(Image, "Resampling", Image).BOX
batch_images = [
[image.resize((width, height), resample=resampling) for image in sample_images]
for sample_images in batch_images
]
qwen_inputs = self.qwen.build_inputs(
images=batch_images,
instructions=instructions,
@@ -302,27 +350,19 @@ class VLAJEPAModel(nn.Module):
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
qwen_outputs = self.qwen.model(
**qwen_inputs,
output_hidden_states=True,
output_attentions=False,
return_dict=True,
)
last_hidden = qwen_outputs.hidden_states[-1]
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
b, _, h = last_hidden.shape
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
state_tensor = None
if state is not None:
state_tensor = torch.from_numpy(np.array(state)).to(
device=last_hidden.device, dtype=torch.float32
device=last_hidden.device, dtype=last_hidden.dtype
)
with torch.autocast(device_type=device_type, dtype=torch.float32):
# Cast embodied tokens to float32 for action model compatibility
pred_actions = self.action_model.predict_action(
embodied_action_tokens.float(), state_tensor
) # [B, action_horizon, action_dim]
pred_actions = self.action_model.predict_action(
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
) # [B, action_horizon, action_dim]
return pred_actions.detach().cpu().numpy()
@@ -546,8 +586,31 @@ class VLAJEPAPolicy(PreTrainedPolicy):
actions_np = self.model.predict_action(batch_images, instructions, state_np)
# Convert back to tensor on the right device
actions_np = self._unnormalize_actions(actions_np)
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
def _unnormalize_actions(self, normalized_actions: np.ndarray) -> np.ndarray:
"""Match starVLA's LIBERO action post-processing exactly."""
stats = self.config.action_unnormalization_stats
if not stats:
return normalized_actions
actions = normalized_actions.astype(np.float32, copy=True)
if self.config.clip_normalized_actions:
actions = np.clip(actions, -1.0, 1.0)
if self.config.binarize_gripper_action and actions.shape[-1] >= 7:
actions[..., 6] = np.where(actions[..., 6] < 0.5, 0.0, 1.0)
action_min = np.asarray(stats["min"], dtype=np.float32)
action_max = np.asarray(stats["max"], dtype=np.float32)
mask = np.asarray(stats.get("mask", np.ones_like(action_min, dtype=bool)), dtype=bool)
scaled = 0.5 * (actions + 1.0) * (action_max - action_min) + action_min
actions = np.where(mask, scaled, actions).astype(np.float32)
if self.config.binarize_gripper_action and actions.shape[-1] >= 7:
actions[..., 6] = 1.0 - 2.0 * (actions[..., 6] > 0.5)
return actions
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot select_action with action queue caching."""

View File

@@ -14,7 +14,6 @@ from lerobot.processor import (
PolicyProcessorPipeline,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
@@ -31,7 +30,6 @@ def make_vla_jepa_pre_post_processors(
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
VLAJEPANewLineProcessor(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features=features,
@@ -40,11 +38,6 @@ def make_vla_jepa_pre_post_processors(
),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
@@ -64,20 +57,7 @@ def make_vla_jepa_pre_post_processors(
@ProcessorStepRegistry.register(name="vla_jepa_new_line_processor")
class VLAJEPANewLineProcessor(ComplementaryDataProcessorStep):
def complementary_data(self, complementary_data):
if "task" not in complementary_data:
return complementary_data
task = complementary_data["task"]
if task is None:
return complementary_data
new_complementary_data = dict(complementary_data)
if isinstance(task, str):
if not task.endswith("\n"):
new_complementary_data["task"] = f"{task}\n"
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
return new_complementary_data
return complementary_data
def transform_features(self, features):
return features

View File

@@ -39,7 +39,10 @@ class Qwen3VLInterface(torch.nn.Module):
return torch.bfloat16
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
# starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4,
# independent of vj2 num_action_tokens_per_timestep. Keeping this count
# is required for Qwen embedding/lm_head checkpoint shapes to match.
max_action_tokens = self.config.chunk_size * 4
tokenizer = self.processor.tokenizer
action_tokens = []
action_token_ids = []

View File

@@ -5,59 +5,298 @@ import torch.nn.functional as F # noqa: N812
from torch import nn
def build_block_causal_attention_mask(num_steps: int, tokens_per_step: int, cond_tokens: int) -> torch.Tensor:
total_tokens = num_steps * (tokens_per_step + cond_tokens)
mask = torch.full((total_tokens, total_tokens), float("-inf"))
for current_step in range(num_steps):
row_start = current_step * (tokens_per_step + cond_tokens)
row_end = row_start + tokens_per_step + cond_tokens
mask[row_start:row_end, :row_end] = 0
def build_action_block_causal_attention_mask(
    num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1
) -> torch.Tensor:
    """Build a frame-wise (block) causal boolean mask.

    Each frame contributes ``add_tokens + grid_height * grid_width`` consecutive
    token positions. Entry ``[i, j]`` is ``True`` when the frame of token ``j``
    is not later than the frame of token ``i``.

    Args:
        num_frames: Number of frames.
        grid_height: Patch-grid height per frame.
        grid_width: Patch-grid width per frame.
        add_tokens: Extra tokens per frame in addition to the patch grid.

    Returns:
        A CPU ``torch.bool`` tensor of shape ``(num_tokens, num_tokens)`` with
        ``num_tokens = num_frames * tokens_per_frame``.
    """
    tokens_per_frame = add_tokens + grid_height * grid_width
    # Frame index of every token position, e.g. [0, 0, 1, 1, ...].
    frame_of_token = torch.arange(num_frames).repeat_interleave(tokens_per_frame)
    query_frame = frame_of_token.unsqueeze(1)
    key_frame = frame_of_token.unsqueeze(0)

    # The look-back window spans all frames, so it never drops past context;
    # it is kept explicit to mirror the windowed formulation.
    local_window_time = num_frames
    not_future = key_frame <= query_frame
    inside_window = key_frame > query_frame - local_window_time
    return not_future & inside_window
class _Attention(nn.Module):
def __init__(self, embed_dim: int, num_heads: int) -> None:
def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    """Apply a rotary position encoding to a 4-D tensor.

    Args:
        x: 4-D tensor whose last dimension is the (even) embedding size.
        pos: Positions; combined with the frequency table by an outer product.
            # assumes pos aligns with x's token axis (dim 2) — TODO confirm

    Returns:
        ``x * cos + rotate(x) * sin`` with the same shape as ``x``.

    Raises:
        ValueError: If the last dimension of ``x`` is odd.
    """
    _batch, _heads, _tokens, dim = x.size()
    if dim % 2 != 0:
        raise ValueError("Embedding dimension must be even for rotary position encoding.")

    # Inverse frequencies 1 / 10000^(k / (dim / 2)) for k = 0 .. dim/2 - 1.
    exponents = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
    exponents /= dim / 2.0
    inv_freq = 1.0 / 10000**exponents

    # Outer product: one angle per (position, frequency) pair.
    angles = torch.einsum("..., f -> ... f", pos, inv_freq)
    # The half-width tables are tiled twice along the last axis to reach full
    # width; this layout is kept exactly because it determines the outputs.
    sin_table = angles.sin().squeeze(-1).repeat(1, 1, 1, 2)
    cos_table = angles.cos().squeeze(-1).repeat(1, 1, 1, 2)

    # Pair adjacent channels (a, b) and map each pair to (-b, a).
    first, second = x.unflatten(-1, (-1, 2)).unbind(dim=-1)
    rotated = torch.stack((-second, first), dim=-1).flatten(-2)
    return x * cos_table + rotated * sin_table
class DropPath(nn.Module):
def __init__(self, drop_prob: float = 0.0) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
self.proj = nn.Linear(embed_dim, embed_dim, bias=True)
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
b, n, c = x.shape
qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
return self.proj(x.transpose(1, 2).reshape(b, n, c))
class _MLP(nn.Module):
def __init__(self, embed_dim: int, mlp_ratio: float) -> None:
super().__init__()
hidden = int(embed_dim * mlp_ratio)
self.fc1 = nn.Linear(embed_dim, hidden)
self.fc2 = nn.Linear(hidden, embed_dim)
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.fc2(F.gelu(self.fc1(x)))
if self.drop_prob == 0.0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
return x.div(keep_prob) * random_tensor
class _PredictorBlock(nn.Module):
def __init__(self, embed_dim: int, num_heads: int, mlp_ratio: float) -> None:
class MLP(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int | None = None,
out_features: int | None = None,
act_layer: type[nn.Module] = nn.GELU,
drop: float = 0.0,
) -> None:
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = _Attention(embed_dim, num_heads)
self.norm2 = nn.LayerNorm(embed_dim)
self.mlp = _MLP(embed_dim, mlp_ratio)
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
x = x + self.attn(self.norm1(x), attn_mask=attn_mask)
return x + self.mlp(self.norm2(x))
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class ACRoPEAttention(nn.Module):
    """Multi-head self-attention with axis-wise rotary embeddings and per-frame conditioning tokens.

    Each head is cut into three equally sized, even-width slices that are
    passed through ``rotate_queries_or_keys`` with the frame, row and column
    position of every token; channels left over after the three slices are not
    rotated. When ``action_tokens > 0`` the sequence is read as groups of
    ``action_tokens + grid_height * grid_width`` tokens with the conditioning
    tokens first, and those tokens are rotated along the frame axis only.

    Args:
        dim: Model width; split evenly across ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        qk_scale: Overrides the default ``head_dim ** -0.5`` logit scale.
            NOTE(review): only the manual (non-SDPA) path reads it.
        attn_drop: Dropout on the attention weights in the manual path.
        proj_drop: Dropout after the output projection; also used as the
            ``dropout_p`` of the SDPA path.
        use_sdpa: Use ``F.scaled_dot_product_attention`` when ``True``.
        is_causal: Forwarded to ``F.scaled_dot_product_attention``.
        grid_size: Reference grid size; row/column positions are rescaled by
            ``grid_size / grid_height`` and ``grid_size / grid_width``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_scale: float | None = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        use_sdpa: bool = True,
        is_causal: bool = False,
        grid_size: int = 16,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop_prob = proj_drop
        self.proj_drop = nn.Dropout(proj_drop)
        self.use_sdpa = use_sdpa
        # Largest even width that fits three times into one head.
        axis_dim = int(2 * ((self.head_dim // 3) // 2))
        self.d_dim = axis_dim
        self.h_dim = axis_dim
        self.w_dim = axis_dim
        self.grid_size = grid_size
        self.is_causal = is_causal

    @staticmethod
    def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Return the frame index of each flat token id."""
        return ids // int(height * width)

    def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Return the row index of each flat token id inside its frame."""
        frame_ids = self._get_frame_pos(ids, height, width)
        ids = ids - int(height * width) * frame_ids
        return ids // width

    def separate_positions(
        self, ids: torch.Tensor, height: int, width: int
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split flat token ids into floating-point (frame, row, column) positions."""
        frame_ids = self._get_frame_pos(ids, height, width)
        height_ids = self._get_height_pos(ids, height, width)
        width_ids = ids - int(height * width) * frame_ids - width * height_ids
        # Multiplying by 1.0 promotes integer ids to floating point.
        return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
        attn_mask: torch.Tensor | None = None,
        num_frames: int | None = None,
        grid_height: int | None = None,
        grid_width: int | None = None,
        action_tokens: int = 0,
    ) -> torch.Tensor:
        """Attend over ``x`` and return a tensor of the same ``[B, N, C]`` shape.

        Args:
            x: Input tokens, ``[B, N, C]``.
            mask: Optional flat token ids used instead of ``0 .. T*H*W - 1``
                for the frame tokens (assumed ``[B, N]`` — TODO confirm).
            attn_mask: Optional mask handed to
                ``F.scaled_dot_product_attention``; when given, the SDPA path
                is used even if ``use_sdpa`` is ``False``.
            num_frames: Number of frames ``T`` in the sequence.
            grid_height: Number of token rows per frame.
            grid_width: Number of token columns per frame.
            action_tokens: Number of conditioning tokens at the start of each
                frame group.

        Raises:
            ValueError: If ``num_frames``, ``grid_height`` or ``grid_width``
                is missing.
        """
        batch_size, num_tokens, channels = x.size()
        if num_frames is None or grid_height is None or grid_width is None:
            raise ValueError("num_frames, grid_height and grid_width are required.")

        # Position of every frame token along the three axes.
        if mask is not None:
            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
        else:
            mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device)
        d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)

        # Express spatial positions in units of the reference grid.
        h_mask *= self.grid_size / grid_height
        w_mask *= self.grid_size / grid_width

        if action_tokens > 0:
            # [B, T, A + H*W, C]: conditioning tokens lead each frame group.
            x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels)
            frame_pos = torch.arange(num_frames, device=x.device)

            action_q, action_k, action_v = [], [], []
            for idx in range(action_tokens):
                action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
                qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
                q, k, v = qkv[0], qkv[1], qkv[2]
                # Conditioning tokens only carry a frame position.
                qd = rotate_queries_or_keys(q[..., : self.d_dim], pos=frame_pos)
                kd = rotate_queries_or_keys(k[..., : self.d_dim], pos=frame_pos)
                qr = q[..., self.d_dim :]
                kr = k[..., self.d_dim :]
                action_q.append(torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1))
                action_k.append(torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1))
                action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))

            action_q = torch.cat(action_q, dim=3).flatten(2, 3)
            action_k = torch.cat(action_k, dim=3).flatten(2, 3)
            action_v = torch.cat(action_v, dim=3).flatten(2, 3)
            x = x[:, :, action_tokens:, :].flatten(1, 2)

        # Frame tokens: project, then rotate one slice per axis.
        qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        offset = 0
        qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask)
        kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask)
        offset += self.d_dim
        qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask)
        kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask)
        offset += self.h_dim
        qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask)
        kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask)
        offset += self.w_dim

        if offset < self.head_dim:
            # Channels beyond the three rotated slices pass through unchanged.
            q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1)
            k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1)
        else:
            q = torch.cat([qd, qh, qw], dim=-1)
            k = torch.cat([kd, kh, kw], dim=-1)

        if action_tokens > 0:

            def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor:
                """Put the conditioning tokens back in front of each frame's tokens."""
                frame_tokens = frame_tokens.view(
                    batch_size, self.num_heads, num_frames, grid_height * grid_width, -1
                )
                action_token_values = action_token_values.view(
                    batch_size, self.num_heads, num_frames, action_tokens, -1
                )
                return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3)

            q = merge(q, action_q)
            k = merge(k, action_k)
            v = merge(v, action_v)

        if attn_mask is not None or self.use_sdpa:
            # The functional SDPA call applies dropout whenever dropout_p > 0,
            # independent of module mode, so it has to be disabled explicitly
            # outside of training.
            dropout_p = self.proj_drop_prob if self.training else 0.0
            x = F.scaled_dot_product_attention(
                q, k, v, dropout_p=dropout_p, is_causal=self.is_causal, attn_mask=attn_mask
            )
        else:
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
        x = self.proj(x)
        return self.proj_drop(x)
class ACBlock(nn.Module):
    """Pre-norm transformer layer: AC-RoPE self-attention followed by an MLP.

    Both sub-layers are wrapped in a residual connection whose branch goes
    through ``drop_path`` (an identity when ``drop_path`` is not positive).

    Args:
        dim: Token width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the attention's qkv projection has a bias.
        qk_scale: Optional attention logit scale override.
        drop: Dropout probability for the attention projection and the MLP.
        attn_drop: Dropout probability for the attention weights.
        drop_path: Probability given to ``DropPath`` for both residual branches.
        norm_layer: Callable building a normalisation layer from ``dim``.
        use_sdpa: Forwarded to the attention module.
        is_causal: Forwarded to the attention module.
        grid_size: Forwarded to the attention module.
        use_rope: Must be ``True``; any other value is rejected.

    Raises:
        ValueError: If ``use_rope`` is false.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_scale: float | None = None,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        norm_layer: type[nn.Module] = nn.LayerNorm,
        use_sdpa: bool = True,
        is_causal: bool = False,
        grid_size: int = 16,
        use_rope: bool = True,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        if not use_rope:
            raise ValueError("JEVLA1 world predictor uses AC RoPE attention.")

        self.attn = ACRoPEAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            use_sdpa=use_sdpa,
            is_causal=is_causal,
            grid_size=grid_size,
        )

        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=hidden_width,
            act_layer=nn.GELU,
            drop=drop,
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
        num_frames: int | None = None,
        grid_height: int | None = None,
        grid_width: int | None = None,
        action_tokens: int = 0,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-layers and return the updated tokens."""
        attended = self.attn(
            self.norm1(x),
            mask=None,
            attn_mask=attn_mask,
            num_frames=num_frames,
            grid_height=grid_height,
            grid_width=grid_width,
            action_tokens=action_tokens,
        )
        x = x + self.drop_path(attended)

        fed_forward = self.mlp(self.norm2(x))
        return x + self.drop_path(fed_forward)
class ActionConditionedVideoPredictor(nn.Module):
"""JEVLA1-compatible action-conditioned V-JEPA predictor."""
# Constructor: builds the input projections, the stack of predictor blocks and
# the output norm/projection.
# NOTE(review): this span is a rendered diff with indentation stripped; lines
# from before and after the change appear side by side without +/- markers.
def __init__(
self,
num_frames: int,
img_size: tuple[int, int],
patch_size: int,
tubelet_size: int,
embed_dim: int,
action_embed_dim: int,
predictor_embed_dim: int,
# NOTE(review): the next line is a diff hunk header, not code. Parameters
# between predictor_embed_dim and num_heads are hidden here; `depth`, used
# further down, is presumably one of them - confirm against the full file.
@@ -65,40 +304,93 @@ class ActionConditionedVideoPredictor(nn.Module):
num_heads: int,
mlp_ratio: float,
num_action_tokens_per_step: int,
use_extrinsics: bool = False,
) -> None:
super().__init__()
# NOTE(review): the two assignments below are repeated a few lines later with
# an explicit bias=True; they look like the pre-change side of the diff.
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim)
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim)
# Stored only; nothing in the visible part of the class reads this flag.
self.is_frame_causal = True
self.use_extrinsics = use_extrinsics
# Linear maps from encoder/action width into the predictor width.
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
# NOTE(review): state_encoder is not called by the forward visible in this
# chunk; presumably kept so checkpoint keys still load - confirm.
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
# Takes one input feature fewer than the action encoder.
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
self.img_height, self.img_width = img_size
self.patch_size = patch_size
self.num_frames = num_frames
self.tubelet_size = tubelet_size
# Token grid per frame: image size in whole patches along each axis.
self.grid_height = self.img_height // self.patch_size
self.grid_width = self.img_width // self.patch_size
self.predictor_blocks = nn.ModuleList(
# NOTE(review): the next line references _PredictorBlock while the list after
# it builds ACBlock; it looks like the removed side of the diff.
[_PredictorBlock(predictor_embed_dim, num_heads, mlp_ratio) for _ in range(depth)]
[
ACBlock(
dim=predictor_embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
# LayerNorm with eps=1e-6 instead of the nn.LayerNorm default.
norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6),
# The attention's reference grid size is taken from the grid height only.
grid_size=self.grid_height,
use_rope=True,
)
for _ in range(depth)
]
)
# NOTE(review): the two assignments below are repeated right after with
# eps=1e-6 / bias=True; they look like the pre-change side of the diff.
self.predictor_norm = nn.LayerNorm(predictor_embed_dim)
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim)
self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6)
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
self.num_action_tokens_per_step = num_action_tokens_per_step
@property
def norm(self) -> nn.LayerNorm:
    """Alias for ``predictor_norm``."""
    return self.predictor_norm


@property
def proj(self) -> nn.Linear:
    """Alias for ``predictor_proj``."""
    return self.predictor_proj


def forward(
    self,
    frame_tokens: torch.Tensor,
    action_tokens: torch.Tensor,
    extrinsics: torch.Tensor | None = None,
) -> torch.Tensor:
    """Predict frame tokens from context frame tokens and per-frame conditioning tokens.

    Args:
        frame_tokens: Context tokens, ``[B, T*H*W, D]`` with
            ``H = grid_height`` and ``W = grid_width``.
        action_tokens: Conditioning tokens, ``[B, T*A, D_action]``; they are
            regrouped into ``A`` tokens per frame.
        extrinsics: Per-frame extrinsics, encoded into one extra conditioning
            token per frame; required when ``use_extrinsics`` is set and
            ignored otherwise.

    Returns:
        Predicted frame tokens, ``[B, T*H*W, embed_dim]``; the conditioning
        tokens are dropped before the final norm and projection.

    Raises:
        ValueError: If the number of frame tokens is not a positive multiple
            of ``grid_height * grid_width``, or if extrinsics are required but
            missing.
    """
    # starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D].
    x = self.predictor_embed(frame_tokens)
    batch_size, num_context_tokens, hidden_dim = x.size()
    tokens_per_frame = self.grid_height * self.grid_width
    num_frames, leftover = divmod(num_context_tokens, tokens_per_frame)
    if num_frames == 0 or leftover:
        # Fail early with a readable message instead of an opaque view() error.
        raise ValueError(
            f"Expected a positive multiple of {tokens_per_frame} frame tokens, got {num_context_tokens}."
        )

    actions = self.action_encoder(action_tokens)
    actions = actions.view(batch_size, num_frames, -1, hidden_dim)
    cond_tokens = actions.shape[2]
    x = x.view(batch_size, num_frames, tokens_per_frame, hidden_dim)

    # Each frame group becomes [conditioning tokens..., frame tokens...].
    if self.use_extrinsics:
        if extrinsics is None:
            raise ValueError("extrinsics are required when use_extrinsics=True.")
        cond_tokens += 1
        extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2)
        x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2)
    else:
        x = torch.cat([actions, x], dim=2).flatten(1, 2)

    attn_mask = build_action_block_causal_attention_mask(
        num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens
    )
    attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)

    for block in self.predictor_blocks:
        x = block(
            x,
            attn_mask=attn_mask,
            num_frames=num_frames,
            grid_height=self.grid_height,
            grid_width=self.grid_width,
            action_tokens=cond_tokens,
        )

    # Drop the conditioning tokens again; only frame tokens are projected.
    x = x.view(batch_size, num_frames, cond_tokens + tokens_per_frame, hidden_dim)
    x = x[:, :, cond_tokens:, :].flatten(1, 2)
    x = self.predictor_norm(x)
    return self.predictor_proj(x)