mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
lots of changes to make existing weights work, need to massively refactor the pre and post processing
This commit is contained in:
committed by
Maximellerbach
parent
c6bf11b2d5
commit
999cc625d6
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -33,17 +34,6 @@ def swish(x: torch.Tensor) -> torch.Tensor:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class _MLP2(nn.Module):
|
||||
"""Two-layer GELU MLP with layer1/layer2 attribute names matching the original checkpoint."""
|
||||
|
||||
def __init__(self, in_dim: int, hidden_dim: int, out_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.layer1 = nn.Linear(in_dim, hidden_dim)
|
||||
self.layer2 = nn.Linear(hidden_dim, out_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.layer2(F.gelu(self.layer1(x)))
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
def __init__(self, embedding_dim: int):
|
||||
@@ -108,10 +98,11 @@ class BasicTransformerBlock(nn.Module):
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout: float,
|
||||
cross_attention_dim: int | None,
|
||||
cross_attention_dim: int,
|
||||
is_cross_attention: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.is_cross_attention = cross_attention_dim is not None
|
||||
self.is_cross_attention = is_cross_attention
|
||||
self.norm1 = AdaLayerNorm(dim)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
@@ -132,7 +123,8 @@ class BasicTransformerBlock(nn.Module):
|
||||
temb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
attn_input = self.norm1(hidden_states, temb)
|
||||
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=encoder_hidden_states)
|
||||
attention_context = encoder_hidden_states if self.is_cross_attention else None
|
||||
hidden_states = hidden_states + self.attn1(attn_input, encoder_hidden_states=attention_context)
|
||||
hidden_states = hidden_states + self.ff(self.norm2(hidden_states))
|
||||
return hidden_states
|
||||
|
||||
@@ -160,10 +152,10 @@ class DiT(ModelMixin, ConfigMixin):
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dropout=dropout,
|
||||
# Even blocks attend to context (cross-attention), odd blocks are self-attention.
|
||||
cross_attention_dim=cross_attention_dim if i % 2 == 0 else None,
|
||||
cross_attention_dim=cross_attention_dim if layer_idx % 2 == 0 else self.inner_dim,
|
||||
is_cross_attention=layer_idx % 2 == 0,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
for layer_idx in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, eps=1e-6, elementwise_affine=False)
|
||||
@@ -179,8 +171,7 @@ class DiT(ModelMixin, ConfigMixin):
|
||||
temb = self.timestep_encoder(timestep)
|
||||
x = hidden_states
|
||||
for block in self.transformer_blocks:
|
||||
es = encoder_hidden_states if block.is_cross_attention else None
|
||||
x = block(x, encoder_hidden_states=es, temb=temb)
|
||||
x = block(x, encoder_hidden_states=encoder_hidden_states, temb=temb)
|
||||
shift, scale = self.proj_out_1(F.silu(temb)).chunk(2, dim=-1)
|
||||
x = self.norm_out(x) * (1 + scale[:, None]) + shift[:, None]
|
||||
return self.proj_out_2(x)
|
||||
@@ -205,34 +196,49 @@ class VLAJEPAActionHead(nn.Module):
|
||||
super().__init__()
|
||||
preset = DIT_PRESETS[config.action_model_type]
|
||||
self.config = config
|
||||
num_heads = preset.num_attention_heads
|
||||
head_dim = preset.attention_head_dim
|
||||
num_heads = config.action_num_heads or preset.num_attention_heads
|
||||
head_dim = config.action_attention_head_dim or preset.attention_head_dim
|
||||
inner_dim = num_heads * head_dim # e.g. DiT-B: 12 × 64 = 768
|
||||
|
||||
self.input_embedding_dim = inner_dim
|
||||
self.action_horizon = config.chunk_size
|
||||
self.num_inference_timesteps = config.num_inference_timesteps
|
||||
|
||||
hidden_size = config.action_hidden_size
|
||||
self.model = DiT(
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=head_dim,
|
||||
output_dim=config.action_hidden_size,
|
||||
output_dim=hidden_size,
|
||||
num_layers=config.action_num_layers,
|
||||
dropout=config.action_dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
# action_encoder/decoder and state_encoder use action_hidden_size (DiT output dim).
|
||||
# action_encoder and state_encoder produce inner_dim-sized tokens (DiT input width).
|
||||
# action_decoder takes DiT output (action_hidden_size) and produces action_dim predictions.
|
||||
self.action_encoder = ActionEncoder(config.action_dim, inner_dim)
|
||||
self.action_decoder = _MLP2(config.action_hidden_size, config.action_hidden_size, config.action_dim)
|
||||
self.state_encoder = (
|
||||
_MLP2(config.state_dim, config.action_hidden_size, inner_dim) if config.state_dim > 0 else None
|
||||
self.action_decoder = nn.Sequential(
|
||||
OrderedDict([
|
||||
("layer1", nn.Linear(hidden_size, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, config.action_dim)),
|
||||
])
|
||||
)
|
||||
self.state_encoder = (
|
||||
nn.Sequential(
|
||||
OrderedDict([
|
||||
("layer1", nn.Linear(config.state_dim, hidden_size)),
|
||||
("relu", nn.ReLU()),
|
||||
("layer2", nn.Linear(hidden_size, inner_dim)),
|
||||
])
|
||||
)
|
||||
if config.state_dim > 0
|
||||
else None
|
||||
)
|
||||
self.future_tokens = nn.Embedding(
|
||||
config.num_embodied_action_tokens_per_instruction, inner_dim
|
||||
)
|
||||
self.position_embedding = nn.Embedding(
|
||||
max(1024, config.chunk_size + config.num_action_tokens_per_timestep + 4),
|
||||
inner_dim,
|
||||
)
|
||||
# future_tokens and position_embedding operate at inner_dim (DiT input width),
|
||||
# not at action_hidden_size (DiT output width).
|
||||
self.future_tokens = nn.Embedding(config.num_target_vision_tokens, inner_dim)
|
||||
self.position_embedding = nn.Embedding(config.action_max_seq_len, inner_dim)
|
||||
self.beta_dist = Beta(config.action_noise_beta_alpha, config.action_noise_beta_beta)
|
||||
|
||||
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
@@ -44,6 +45,8 @@ class VLAJEPAConfig(PreTrainedConfig):
|
||||
action_hidden_size: int = 1024
|
||||
action_model_type: str = "DiT-B"
|
||||
action_num_layers: int = 16
|
||||
action_num_heads: int | None = None
|
||||
action_attention_head_dim: int | None = None
|
||||
action_dropout: float = 0.2
|
||||
action_num_timestep_buckets: int = 1000
|
||||
action_noise_beta_alpha: float = 1.5
|
||||
@@ -63,6 +66,9 @@ class VLAJEPAConfig(PreTrainedConfig):
|
||||
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
|
||||
|
||||
resize_images_to: tuple[int, int] | None = None
|
||||
action_unnormalization_stats: dict[str, Any] | None = None
|
||||
binarize_gripper_action: bool = True
|
||||
clip_normalized_actions: bool = True
|
||||
torch_dtype: str = "bfloat16"
|
||||
|
||||
optimizer_lr: float = 1e-4
|
||||
|
||||
@@ -25,36 +25,75 @@ Image keys SimplerEnv: OXE Bridge/RT1 are single-camera ✓ confirmed
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
from safetensors.torch import save_file as save_safetensors
|
||||
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level settings
|
||||
# ---------------------------------------------------------------------------
|
||||
SOURCE_REPO_ID = "ginwind/VLA-JEPA"
|
||||
TARGET_ORG = "lerobot"
|
||||
TARGET_ORG = "maximellerbach"
|
||||
COLLECTION_TITLE = "VLA-JEPA"
|
||||
COLLECTION_DESCRIPTION = (
|
||||
"VLA-JEPA model checkpoints (LIBERO, Pretrain, SimplerEnv) converted from .pt to safetensors via LeRobot."
|
||||
)
|
||||
|
||||
# Remap state-dict key prefixes before loading into the LeRobot policy.
|
||||
# E.g. {"": "model."} prepends "model." to every key.
|
||||
# Leave empty if keys already match — the first run's log will tell you.
|
||||
KEY_PREFIX_REMAP: dict[str, str] = {
|
||||
# Specific rules must come before the "" catch-all (dict order is preserved).
|
||||
"qwen_vl_interface.": "model.qwen.",
|
||||
"vj_encoder.": "model.video_encoder.",
|
||||
"vj_predictor.": "model.video_predictor.",
|
||||
# Everything else (action_model.*) just needs the "model." wrapper.
|
||||
"": "model.",
|
||||
}
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Key mapping — mirrors todo_converter.py map_key() so both converters
|
||||
# produce identical safetensors layouts that match the LeRobot action_head code.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _normalize_source_key(key: str) -> str:
|
||||
return key[len("module."):] if key.startswith("module.") else key
|
||||
|
||||
|
||||
def _map_checkpoint_key(raw_key: str) -> str | None:
|
||||
"""Map original VLA-JEPA state-dict keys to LeRobot vla_jepa layout."""
|
||||
key = _normalize_source_key(raw_key)
|
||||
|
||||
if key.startswith("qwen_vl_interface."):
|
||||
return "model.qwen." + key[len("qwen_vl_interface."):]
|
||||
if key.startswith("vj_encoder."):
|
||||
return "model.video_encoder." + key[len("vj_encoder."):]
|
||||
if key.startswith("vj_predictor."):
|
||||
return "model.video_predictor." + key[len("vj_predictor."):]
|
||||
if key.startswith("action_model."):
|
||||
# LeRobot code uses the same sub-key names as the source checkpoint,
|
||||
# so only the top-level "model." prefix needs to be added.
|
||||
return "model." + key
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_action_stats(api: "HfApi", source_repo_id: str, subfolder: str) -> dict | None:
|
||||
"""Try to download dataset_statistics.json and return the action stats dict."""
|
||||
import json
|
||||
|
||||
stats_file = f"{subfolder}/dataset_statistics.json"
|
||||
try:
|
||||
local = api.hf_hub_download(source_repo_id, stats_file)
|
||||
data = json.loads(Path(local).read_text())
|
||||
# The original repo nests stats under a robot key, e.g. {"franka": {"action": {...}}}
|
||||
for robot_key in data:
|
||||
if isinstance(data[robot_key], dict) and "action" in data[robot_key]:
|
||||
log.info(" Loaded action stats from %s (robot key: %s)", stats_file, robot_key)
|
||||
return data[robot_key]["action"]
|
||||
log.warning(" %s found but no 'action' key under any robot — skipping action stats.", stats_file)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(" Could not fetch %s: %s — action_unnormalization_stats will be None.", stats_file, exc)
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Architecture — identical across all 4 variants (from config.json)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -69,6 +108,9 @@ _ARCH = {
|
||||
"num_inference_timesteps": 4,
|
||||
"action_hidden_size": 1024,
|
||||
"action_model_type": "DiT-B",
|
||||
# Explicit dims matching DiT-B preset and ginwind checkpoint shape
|
||||
"action_num_heads": 12,
|
||||
"action_attention_head_dim": 64,
|
||||
"action_num_layers": 16,
|
||||
"action_dropout": 0.2,
|
||||
"repeated_diffusion_steps": 8,
|
||||
@@ -76,9 +118,6 @@ _ARCH = {
|
||||
"action_noise_beta_beta": 1.0,
|
||||
"action_noise_s": 0.999,
|
||||
"action_num_timestep_buckets": 1000,
|
||||
# Action head embedding params (from original config.json)
|
||||
"num_target_vision_tokens": 32,
|
||||
"action_max_seq_len": 1024,
|
||||
# World model predictor (12 blocks, confirmed from checkpoint)
|
||||
"predictor_depth": 12,
|
||||
}
|
||||
@@ -109,7 +148,12 @@ _OXE_CAMS = [
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_config(camera_keys: list[str], with_state: bool, enable_world_model: bool = True):
|
||||
def _build_config(
|
||||
camera_keys: list[str],
|
||||
with_state: bool,
|
||||
enable_world_model: bool = True,
|
||||
action_stats: dict | None = None,
|
||||
):
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
@@ -123,6 +167,9 @@ def _build_config(camera_keys: list[str], with_state: bool, enable_world_model:
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
},
|
||||
enable_world_model=enable_world_model,
|
||||
action_unnormalization_stats=action_stats,
|
||||
binarize_gripper_action=True,
|
||||
clip_normalized_actions=True,
|
||||
**_ARCH,
|
||||
)
|
||||
cfg.validate_features()
|
||||
@@ -153,20 +200,6 @@ def extract_state_dict(ckpt: object) -> dict[str, torch.Tensor]:
|
||||
return {k: v for k, v in sd.items() if isinstance(v, torch.Tensor)}
|
||||
|
||||
|
||||
def remap_keys(sd: dict[str, torch.Tensor], remap: dict[str, str]) -> dict[str, torch.Tensor]:
|
||||
if not remap:
|
||||
return sd
|
||||
out = {}
|
||||
for k, v in sd.items():
|
||||
new_k = k
|
||||
for old, new in remap.items():
|
||||
if k.startswith(old):
|
||||
new_k = new + k[len(old) :]
|
||||
break
|
||||
out[new_k] = v
|
||||
return out
|
||||
|
||||
|
||||
def subfolder_of(pt_path: str) -> str | None:
|
||||
for part in Path(pt_path).parts:
|
||||
if part in VARIANTS:
|
||||
@@ -229,37 +262,48 @@ def main() -> None:
|
||||
ckpt = torch.load(local_pt, map_location="cpu") # nosec B614
|
||||
|
||||
sd = extract_state_dict(ckpt)
|
||||
sd = remap_keys(sd, KEY_PREFIX_REMAP)
|
||||
log.info(" %d tensors extracted", len(sd))
|
||||
log.info(" First 5 keys: %s", list(sd)[:5])
|
||||
|
||||
# 3. Build policy
|
||||
from lerobot.policies.vla_jepa.modeling_vla_jepa import VLAJEPAPolicy
|
||||
# Map source key names → LeRobot layout (handles layer1→w1, transformer_blocks→blocks, etc.)
|
||||
mapped_sd: dict[str, torch.Tensor] = {}
|
||||
skipped_keys: list[str] = []
|
||||
for raw_key, value in sd.items():
|
||||
target_key = _map_checkpoint_key(raw_key)
|
||||
if target_key is None:
|
||||
skipped_keys.append(raw_key)
|
||||
else:
|
||||
mapped_sd[target_key] = value
|
||||
log.info(" %d tensors mapped, %d skipped", len(mapped_sd), len(skipped_keys))
|
||||
if skipped_keys:
|
||||
log.info(" Skipped sample: %s", skipped_keys[:5])
|
||||
log.info(" First 5 mapped keys: %s", list(mapped_sd)[:5])
|
||||
|
||||
config = _build_config(camera_keys, with_state, enable_world_model)
|
||||
policy = VLAJEPAPolicy(config)
|
||||
# Fetch action unnormalization stats from the source repo
|
||||
action_stats = _fetch_action_stats(api, SOURCE_REPO_ID, subfolder)
|
||||
|
||||
# 4. Load weights
|
||||
missing, unexpected = policy.load_state_dict(sd, strict=False)
|
||||
# 3. Build config (no policy instantiation — avoids loading backbone from Hub)
|
||||
config = _build_config(camera_keys, with_state, enable_world_model, action_stats)
|
||||
|
||||
def _prefix_summary(keys: list[str]) -> dict[str, int]:
|
||||
from collections import Counter
|
||||
|
||||
return dict(Counter(".".join(k.split(".")[:3]) for k in keys).most_common())
|
||||
|
||||
if missing:
|
||||
log.warning(" Missing (%d) by prefix: %s", len(missing), _prefix_summary(missing))
|
||||
if unexpected:
|
||||
log.warning(" Unexpected (%d) by prefix: %s", len(unexpected), _prefix_summary(unexpected))
|
||||
if not missing and not unexpected:
|
||||
log.info(" State dict loaded cleanly.")
|
||||
|
||||
# 5. Push to hub (writes model.safetensors + config.json)
|
||||
# 4. Save everything to a temp dir and upload in one shot
|
||||
api.create_repo(target_repo_id, repo_type="model", exist_ok=True)
|
||||
commit_url = policy.push_to_hub(
|
||||
repo_id=target_repo_id,
|
||||
commit_message=f"Convert {Path(pt_filename).name} to safetensors",
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
save_dir = Path(tmp)
|
||||
|
||||
log.info(" Saving model.safetensors …")
|
||||
save_safetensors(mapped_sd, save_dir / "model.safetensors")
|
||||
|
||||
config._save_pretrained(save_dir) # writes config.json via draccus
|
||||
|
||||
preprocessor, postprocessor = make_vla_jepa_pre_post_processors(config)
|
||||
preprocessor.save_pretrained(save_dir) # writes policy_preprocessor.json
|
||||
postprocessor.save_pretrained(save_dir) # writes policy_postprocessor.json
|
||||
|
||||
log.info(" Uploading …")
|
||||
commit_url = api.upload_folder(
|
||||
folder_path=save_dir,
|
||||
repo_id=target_repo_id,
|
||||
repo_type="model",
|
||||
commit_message=f"Convert {Path(pt_filename).name} to safetensors",
|
||||
)
|
||||
log.info(" Uploaded → %s", commit_url)
|
||||
|
||||
# 6. Add to collection
|
||||
|
||||
@@ -9,6 +9,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file as load_safetensors_file
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
@@ -72,9 +73,18 @@ class VLAJEPAModel(nn.Module):
|
||||
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
|
||||
num_views = max(len(config.image_features), 1)
|
||||
num_views = max(1, len(config.image_features))
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
image_size = getattr(self.video_encoder.config, "image_size", None)
|
||||
if image_size is None:
|
||||
first_image_shape = next(iter(config.image_features.values())).shape
|
||||
image_size = first_image_shape[-1]
|
||||
self.video_predictor = ActionConditionedVideoPredictor(
|
||||
embed_dim=num_views * self.video_encoder.config.hidden_size,
|
||||
num_frames=config.num_video_frames // tubelet_size,
|
||||
img_size=(image_size, image_size),
|
||||
patch_size=16,
|
||||
tubelet_size=1,
|
||||
embed_dim=self.video_encoder.config.hidden_size * num_views,
|
||||
action_embed_dim=self.qwen.model.config.hidden_size,
|
||||
predictor_embed_dim=self.video_encoder.config.hidden_size,
|
||||
depth=config.predictor_depth,
|
||||
@@ -91,17 +101,56 @@ class VLAJEPAModel(nn.Module):
|
||||
self.qwen.requires_grad_(False)
|
||||
|
||||
# Build prompt placeholders.
|
||||
# Original uses num_frames // tubelet_size - 1 action token groups for the world model predictor.
|
||||
# This matches the number of context temporal positions after tubelet compression.
|
||||
n_wm_action_groups = max(1, self.config.num_video_frames // self.config.jepa_tubelet_size - 1)
|
||||
# Use the encoder's actual tubelet_size when available (world model enabled),
|
||||
# otherwise fall back to config.
|
||||
_tubelet_size = (
|
||||
self.video_encoder.config.tubelet_size
|
||||
if config.enable_world_model
|
||||
else self.config.jepa_tubelet_size
|
||||
)
|
||||
num_action_prompt_steps = self.config.num_video_frames // _tubelet_size - 1
|
||||
self.replace_prompt = "".join(
|
||||
token * self.config.num_action_tokens_per_timestep
|
||||
for token in self.action_tokens[:n_wm_action_groups]
|
||||
for token in self.action_tokens[:num_action_prompt_steps]
|
||||
)
|
||||
self.embodied_replace_prompt = (
|
||||
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
|
||||
)
|
||||
|
||||
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return Qwen's final decoder-layer output before the final RMSNorm.
|
||||
|
||||
starVLA trained its downstream heads on the legacy transformers-4.57
|
||||
`hidden_states[-1]` value, which is the last decoder layer output before
|
||||
Qwen's final RMSNorm. Newer transformers versions expose `hidden_states[-1]`
|
||||
as the post-norm last hidden state, so capture the layer output directly.
|
||||
"""
|
||||
captured: dict[str, torch.Tensor] = {}
|
||||
language_model = self.qwen.model.model.language_model
|
||||
|
||||
def capture_last_layer_output(
|
||||
_module: nn.Module,
|
||||
_inputs: tuple[torch.Tensor, ...],
|
||||
output: torch.Tensor | tuple[torch.Tensor, ...],
|
||||
) -> None:
|
||||
captured["last_hidden"] = output[0] if isinstance(output, tuple) else output
|
||||
return None
|
||||
|
||||
handle = language_model.layers[-1].register_forward_hook(capture_last_layer_output)
|
||||
try:
|
||||
self.qwen.model.model(
|
||||
**qwen_inputs,
|
||||
output_hidden_states=False,
|
||||
output_attentions=False,
|
||||
return_dict=True,
|
||||
)
|
||||
finally:
|
||||
handle.remove()
|
||||
|
||||
if "last_hidden" not in captured:
|
||||
raise RuntimeError("Failed to capture Qwen last decoder hidden states.")
|
||||
return captured["last_hidden"]
|
||||
|
||||
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
|
||||
|
||||
def forward(self, examples: list[dict]) -> dict[str, Tensor]:
|
||||
@@ -160,13 +209,7 @@ class VLAJEPAModel(nn.Module):
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
qwen_outputs = self.qwen.model(
|
||||
**qwen_inputs,
|
||||
output_hidden_states=True,
|
||||
output_attentions=False,
|
||||
return_dict=True,
|
||||
)
|
||||
last_hidden = qwen_outputs.hidden_states[-1] # [B, seq_len, H]
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
|
||||
if self.config.enable_world_model:
|
||||
@@ -211,20 +254,16 @@ class VLAJEPAModel(nn.Module):
|
||||
|
||||
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
|
||||
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
||||
d_emb = input_states.shape[-1]
|
||||
|
||||
input_states_4d = input_states.view(b, t_enc_ctx, tokens_per_frame, d_emb)
|
||||
|
||||
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
|
||||
if action_tokens.shape[1] < expected_actions:
|
||||
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
||||
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||
act_4d = action_tokens[:, :expected_actions].view(
|
||||
b, t_enc_ctx, self.config.num_action_tokens_per_timestep, -1
|
||||
)
|
||||
|
||||
pred_4d = self.video_predictor(input_states_4d.float(), act_4d.float())
|
||||
predicted_states = pred_4d.reshape(b, -1, d_emb)
|
||||
predicted_states = self.video_predictor(
|
||||
input_states.float(),
|
||||
action_tokens[:, :expected_actions].float(),
|
||||
)
|
||||
|
||||
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
||||
|
||||
@@ -242,15 +281,14 @@ class VLAJEPAModel(nn.Module):
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.tensor(
|
||||
np.array(state), device=last_hidden.device, dtype=torch.float32
|
||||
np.array(state), device=last_hidden.device, dtype=last_hidden.dtype
|
||||
) # [B, 1, state_dim]
|
||||
|
||||
# repeated_diffusion_steps: draw R independent noise samples per batch item (CogACT-style).
|
||||
# Effectively multiplies data efficiency of the action head by R with no extra Qwen/JEPA cost.
|
||||
num_repeated = self.config.repeated_diffusion_steps
|
||||
embodied_rep = embodied_action_tokens.float().repeat(num_repeated, 1, 1)
|
||||
actions_rep = actions_target.repeat(num_repeated, 1, 1)
|
||||
state_rep = state_tensor.repeat(num_repeated, 1, 1) if state_tensor is not None else None
|
||||
repeated_diffusion_steps = self.config.repeated_diffusion_steps
|
||||
actions_target = actions_target.repeat(repeated_diffusion_steps, 1, 1)
|
||||
embodied_action_tokens = embodied_action_tokens.repeat(repeated_diffusion_steps, 1, 1)
|
||||
if state_tensor is not None:
|
||||
state_tensor = state_tensor.repeat(repeated_diffusion_steps, 1, 1)
|
||||
|
||||
action_is_pad_rep = None
|
||||
if action_is_pad is not None:
|
||||
@@ -263,9 +301,11 @@ class VLAJEPAModel(nn.Module):
|
||||
]
|
||||
) # [B, T_full]
|
||||
pad_tensor = pad_tensor[:, -action_horizon:] # [B, action_horizon]
|
||||
action_is_pad_rep = pad_tensor.repeat(num_repeated, 1) # [B*R, action_horizon]
|
||||
action_is_pad_rep = pad_tensor.repeat(repeated_diffusion_steps, 1) # [B*R, action_horizon]
|
||||
|
||||
action_loss = self.action_model(embodied_rep, actions_rep, state_rep, action_is_pad_rep)
|
||||
action_loss = self.action_model(
|
||||
embodied_action_tokens, actions_target, state_tensor, action_is_pad_rep
|
||||
)
|
||||
|
||||
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
|
||||
|
||||
@@ -289,6 +329,14 @@ class VLAJEPAModel(nn.Module):
|
||||
Returns:
|
||||
np.ndarray [B, action_horizon, action_dim] — predicted actions.
|
||||
"""
|
||||
if self.config.resize_images_to is not None:
|
||||
height, width = self.config.resize_images_to
|
||||
resampling = getattr(Image, "Resampling", Image).BOX
|
||||
batch_images = [
|
||||
[image.resize((width, height), resample=resampling) for image in sample_images]
|
||||
for sample_images in batch_images
|
||||
]
|
||||
|
||||
qwen_inputs = self.qwen.build_inputs(
|
||||
images=batch_images,
|
||||
instructions=instructions,
|
||||
@@ -302,27 +350,19 @@ class VLAJEPAModel(nn.Module):
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
qwen_outputs = self.qwen.model(
|
||||
**qwen_inputs,
|
||||
output_hidden_states=True,
|
||||
output_attentions=False,
|
||||
return_dict=True,
|
||||
)
|
||||
last_hidden = qwen_outputs.hidden_states[-1]
|
||||
last_hidden = self._qwen_last_decoder_hidden(qwen_inputs) # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
|
||||
state_tensor = None
|
||||
if state is not None:
|
||||
state_tensor = torch.from_numpy(np.array(state)).to(
|
||||
device=last_hidden.device, dtype=torch.float32
|
||||
device=last_hidden.device, dtype=last_hidden.dtype
|
||||
)
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
||||
# Cast embodied tokens to float32 for action model compatibility
|
||||
pred_actions = self.action_model.predict_action(
|
||||
embodied_action_tokens.float(), state_tensor
|
||||
) # [B, action_horizon, action_dim]
|
||||
pred_actions = self.action_model.predict_action(
|
||||
embodied_action_tokens.float(), state_tensor.float() if state_tensor is not None else None
|
||||
) # [B, action_horizon, action_dim]
|
||||
|
||||
return pred_actions.detach().cpu().numpy()
|
||||
|
||||
@@ -546,8 +586,31 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
actions_np = self.model.predict_action(batch_images, instructions, state_np)
|
||||
|
||||
# Convert back to tensor on the right device
|
||||
actions_np = self._unnormalize_actions(actions_np)
|
||||
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
|
||||
|
||||
def _unnormalize_actions(self, normalized_actions: np.ndarray) -> np.ndarray:
|
||||
"""Match starVLA's LIBERO action post-processing exactly."""
|
||||
stats = self.config.action_unnormalization_stats
|
||||
if not stats:
|
||||
return normalized_actions
|
||||
|
||||
actions = normalized_actions.astype(np.float32, copy=True)
|
||||
if self.config.clip_normalized_actions:
|
||||
actions = np.clip(actions, -1.0, 1.0)
|
||||
|
||||
if self.config.binarize_gripper_action and actions.shape[-1] >= 7:
|
||||
actions[..., 6] = np.where(actions[..., 6] < 0.5, 0.0, 1.0)
|
||||
|
||||
action_min = np.asarray(stats["min"], dtype=np.float32)
|
||||
action_max = np.asarray(stats["max"], dtype=np.float32)
|
||||
mask = np.asarray(stats.get("mask", np.ones_like(action_min, dtype=bool)), dtype=bool)
|
||||
scaled = 0.5 * (actions + 1.0) * (action_max - action_min) + action_min
|
||||
actions = np.where(mask, scaled, actions).astype(np.float32)
|
||||
if self.config.binarize_gripper_action and actions.shape[-1] >= 7:
|
||||
actions[..., 6] = 1.0 - 2.0 * (actions[..., 6] > 0.5)
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""LeRobot select_action with action queue caching."""
|
||||
|
||||
@@ -14,7 +14,6 @@ from lerobot.processor import (
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
@@ -31,7 +30,6 @@ def make_vla_jepa_pre_post_processors(
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
VLAJEPANewLineProcessor(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features=features,
|
||||
@@ -40,11 +38,6 @@ def make_vla_jepa_pre_post_processors(
|
||||
),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
return (
|
||||
@@ -64,20 +57,7 @@ def make_vla_jepa_pre_post_processors(
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_new_line_processor")
|
||||
class VLAJEPANewLineProcessor(ComplementaryDataProcessorStep):
|
||||
def complementary_data(self, complementary_data):
|
||||
if "task" not in complementary_data:
|
||||
return complementary_data
|
||||
|
||||
task = complementary_data["task"]
|
||||
if task is None:
|
||||
return complementary_data
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
if isinstance(task, str):
|
||||
if not task.endswith("\n"):
|
||||
new_complementary_data["task"] = f"{task}\n"
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
|
||||
return new_complementary_data
|
||||
return complementary_data
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
@@ -39,7 +39,10 @@ class Qwen3VLInterface(torch.nn.Module):
|
||||
return torch.bfloat16
|
||||
|
||||
def expand_tokenizer(self) -> tuple[list[str], list[int], int]:
|
||||
max_action_tokens = self.config.chunk_size * self.config.num_action_tokens_per_timestep
|
||||
# starVLA/JEVLA checkpoints expand action tokens as action_horizon * 4,
|
||||
# independent of vj2 num_action_tokens_per_timestep. Keeping this count
|
||||
# is required for Qwen embedding/lm_head checkpoint shapes to match.
|
||||
max_action_tokens = self.config.chunk_size * 4
|
||||
tokenizer = self.processor.tokenizer
|
||||
action_tokens = []
|
||||
action_token_ids = []
|
||||
|
||||
@@ -5,59 +5,298 @@ import torch.nn.functional as F # noqa: N812
|
||||
from torch import nn
|
||||
|
||||
|
||||
def build_block_causal_attention_mask(num_steps: int, tokens_per_step: int, cond_tokens: int) -> torch.Tensor:
|
||||
total_tokens = num_steps * (tokens_per_step + cond_tokens)
|
||||
mask = torch.full((total_tokens, total_tokens), float("-inf"))
|
||||
for current_step in range(num_steps):
|
||||
row_start = current_step * (tokens_per_step + cond_tokens)
|
||||
row_end = row_start + tokens_per_step + cond_tokens
|
||||
mask[row_start:row_end, :row_end] = 0
|
||||
def build_action_block_causal_attention_mask(
|
||||
num_frames: int, grid_height: int, grid_width: int, add_tokens: int = 1
|
||||
) -> torch.Tensor:
|
||||
tokens_per_frame = add_tokens + grid_height * grid_width
|
||||
num_tokens = num_frames * tokens_per_frame
|
||||
mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)
|
||||
mask_block = torch.ones(tokens_per_frame, tokens_per_frame, dtype=torch.bool)
|
||||
local_window_time = num_frames
|
||||
|
||||
for current_frame in range(num_frames):
|
||||
first_context_frame = max(0, current_frame - local_window_time + 1)
|
||||
for context_frame in range(first_context_frame, current_frame + 1):
|
||||
row = slice(current_frame * tokens_per_frame, (current_frame + 1) * tokens_per_frame)
|
||||
col = slice(context_frame * tokens_per_frame, (context_frame + 1) * tokens_per_frame)
|
||||
mask[row, col] = mask_block
|
||||
return mask
|
||||
|
||||
|
||||
class _Attention(nn.Module):
|
||||
def __init__(self, embed_dim: int, num_heads: int) -> None:
|
||||
def rotate_queries_or_keys(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, dim = x.size()
|
||||
if dim % 2 != 0:
|
||||
raise ValueError("Embedding dimension must be even for rotary position encoding.")
|
||||
|
||||
omega = torch.arange(dim // 2, dtype=x.dtype, device=x.device)
|
||||
omega /= dim / 2.0
|
||||
omega = 1.0 / 10000**omega
|
||||
freqs = torch.einsum("..., f -> ... f", pos, omega)
|
||||
emb_sin = freqs.sin().squeeze(-1).repeat(1, 1, 1, 2)
|
||||
emb_cos = freqs.cos().squeeze(-1).repeat(1, 1, 1, 2)
|
||||
|
||||
y = x.unflatten(-1, (-1, 2))
|
||||
y1, y2 = y.unbind(dim=-1)
|
||||
y = torch.stack((-y2, y1), dim=-1).flatten(-2)
|
||||
return x * emb_cos + y * emb_sin
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
def __init__(self, drop_prob: float = 0.0) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
|
||||
self.proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
||||
|
||||
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
|
||||
b, n, c = x.shape
|
||||
qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
return self.proj(x.transpose(1, 2).reshape(b, n, c))
|
||||
|
||||
|
||||
class _MLP(nn.Module):
|
||||
def __init__(self, embed_dim: int, mlp_ratio: float) -> None:
|
||||
super().__init__()
|
||||
hidden = int(embed_dim * mlp_ratio)
|
||||
self.fc1 = nn.Linear(embed_dim, hidden)
|
||||
self.fc2 = nn.Linear(hidden, embed_dim)
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.fc2(F.gelu(self.fc1(x)))
|
||||
if self.drop_prob == 0.0 or not self.training:
|
||||
return x
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_()
|
||||
return x.div(keep_prob) * random_tensor
|
||||
|
||||
|
||||
class _PredictorBlock(nn.Module):
|
||||
def __init__(self, embed_dim: int, num_heads: int, mlp_ratio: float) -> None:
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: int | None = None,
|
||||
out_features: int | None = None,
|
||||
act_layer: type[nn.Module] = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(embed_dim)
|
||||
self.attn = _Attention(embed_dim, num_heads)
|
||||
self.norm2 = nn.LayerNorm(embed_dim)
|
||||
self.mlp = _MLP(embed_dim, mlp_ratio)
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
|
||||
x = x + self.attn(self.norm1(x), attn_mask=attn_mask)
|
||||
return x + self.mlp(self.norm2(x))
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class ACRoPEAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_scale: float | None = None,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
use_sdpa: bool = True,
|
||||
is_causal: bool = False,
|
||||
grid_size: int = 16,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = qk_scale or self.head_dim**-0.5
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop_prob = proj_drop
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.use_sdpa = use_sdpa
|
||||
self.d_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.h_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.w_dim = int(2 * ((self.head_dim // 3) // 2))
|
||||
self.grid_size = grid_size
|
||||
self.is_causal = is_causal
|
||||
|
||||
@staticmethod
|
||||
def _get_frame_pos(ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
return ids // int(height * width)
|
||||
|
||||
def _get_height_pos(self, ids: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
frame_ids = self._get_frame_pos(ids, height, width)
|
||||
ids = ids - int(height * width) * frame_ids
|
||||
return ids // width
|
||||
|
||||
def separate_positions(
|
||||
self, ids: torch.Tensor, height: int, width: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
frame_ids = self._get_frame_pos(ids, height, width)
|
||||
height_ids = self._get_height_pos(ids, height, width)
|
||||
width_ids = ids - int(height * width) * frame_ids - width * height_ids
|
||||
return 1.0 * frame_ids, 1.0 * height_ids, 1.0 * width_ids
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: torch.Tensor | None = None,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
grid_height: int | None = None,
|
||||
grid_width: int | None = None,
|
||||
action_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
batch_size, num_tokens, channels = x.size()
|
||||
if num_frames is None or grid_height is None or grid_width is None:
|
||||
raise ValueError("num_frames, grid_height and grid_width are required.")
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1)
|
||||
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
|
||||
else:
|
||||
mask = torch.arange(int(num_frames * grid_height * grid_width), device=x.device)
|
||||
d_mask, h_mask, w_mask = self.separate_positions(mask, grid_height, grid_width)
|
||||
|
||||
h_mask *= self.grid_size / grid_height
|
||||
w_mask *= self.grid_size / grid_width
|
||||
|
||||
if action_tokens > 0:
|
||||
x = x.view(batch_size, -1, action_tokens + grid_height * grid_width, channels)
|
||||
action_q, action_k, action_v = [], [], []
|
||||
for idx in range(action_tokens):
|
||||
action_token = x[:, :, idx : idx + 1, :].flatten(1, 2)
|
||||
qkv = self.qkv(action_token).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
qd = rotate_queries_or_keys(q[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device))
|
||||
kd = rotate_queries_or_keys(k[..., : self.d_dim], pos=torch.arange(num_frames, device=x.device))
|
||||
qr = q[..., self.d_dim :]
|
||||
kr = k[..., self.d_dim :]
|
||||
action_q.append(torch.cat([qd, qr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1))
|
||||
action_k.append(torch.cat([kd, kr], dim=-1).view(batch_size, self.num_heads, num_frames, 1, -1))
|
||||
action_v.append(v.view(batch_size, self.num_heads, num_frames, 1, -1))
|
||||
|
||||
action_q = torch.cat(action_q, dim=3).flatten(2, 3)
|
||||
action_k = torch.cat(action_k, dim=3).flatten(2, 3)
|
||||
action_v = torch.cat(action_v, dim=3).flatten(2, 3)
|
||||
x = x[:, :, action_tokens:, :].flatten(1, 2)
|
||||
|
||||
qkv = self.qkv(x).unflatten(-1, (3, self.num_heads, -1)).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
offset = 0
|
||||
qd = rotate_queries_or_keys(q[..., offset : offset + self.d_dim], pos=d_mask)
|
||||
kd = rotate_queries_or_keys(k[..., offset : offset + self.d_dim], pos=d_mask)
|
||||
offset += self.d_dim
|
||||
qh = rotate_queries_or_keys(q[..., offset : offset + self.h_dim], pos=h_mask)
|
||||
kh = rotate_queries_or_keys(k[..., offset : offset + self.h_dim], pos=h_mask)
|
||||
offset += self.h_dim
|
||||
qw = rotate_queries_or_keys(q[..., offset : offset + self.w_dim], pos=w_mask)
|
||||
kw = rotate_queries_or_keys(k[..., offset : offset + self.w_dim], pos=w_mask)
|
||||
offset += self.w_dim
|
||||
|
||||
if offset < self.head_dim:
|
||||
q = torch.cat([qd, qh, qw, q[..., offset:]], dim=-1)
|
||||
k = torch.cat([kd, kh, kw, k[..., offset:]], dim=-1)
|
||||
else:
|
||||
q = torch.cat([qd, qh, qw], dim=-1)
|
||||
k = torch.cat([kd, kh, kw], dim=-1)
|
||||
|
||||
if action_tokens > 0:
|
||||
|
||||
def merge(frame_tokens: torch.Tensor, action_token_values: torch.Tensor) -> torch.Tensor:
|
||||
frame_tokens = frame_tokens.view(
|
||||
batch_size, self.num_heads, num_frames, grid_height * grid_width, -1
|
||||
)
|
||||
action_token_values = action_token_values.view(
|
||||
batch_size, self.num_heads, num_frames, action_tokens, -1
|
||||
)
|
||||
return torch.cat([action_token_values, frame_tokens], dim=3).flatten(2, 3)
|
||||
|
||||
q = merge(q, action_q)
|
||||
k = merge(k, action_k)
|
||||
v = merge(v, action_v)
|
||||
|
||||
if attn_mask is not None or self.use_sdpa:
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v, dropout_p=self.proj_drop_prob, is_causal=self.is_causal, attn_mask=attn_mask
|
||||
)
|
||||
else:
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
|
||||
x = x.transpose(1, 2).reshape(batch_size, num_tokens, channels)
|
||||
x = self.proj(x)
|
||||
return self.proj_drop(x)
|
||||
|
||||
|
||||
class ACBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = True,
|
||||
qk_scale: float | None = None,
|
||||
drop: float = 0.0,
|
||||
attn_drop: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
norm_layer: type[nn.Module] = nn.LayerNorm,
|
||||
use_sdpa: bool = True,
|
||||
is_causal: bool = False,
|
||||
grid_size: int = 16,
|
||||
use_rope: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
if not use_rope:
|
||||
raise ValueError("JEVLA1 world predictor uses AC RoPE attention.")
|
||||
self.attn = ACRoPEAttention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
use_sdpa=use_sdpa,
|
||||
is_causal=is_causal,
|
||||
grid_size=grid_size,
|
||||
proj_drop=drop,
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = MLP(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=nn.GELU,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
grid_height: int | None = None,
|
||||
grid_width: int | None = None,
|
||||
action_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
y = self.norm1(x)
|
||||
y = self.attn(
|
||||
y,
|
||||
mask=None,
|
||||
attn_mask=attn_mask,
|
||||
num_frames=num_frames,
|
||||
grid_height=grid_height,
|
||||
grid_width=grid_width,
|
||||
action_tokens=action_tokens,
|
||||
)
|
||||
x = x + self.drop_path(y)
|
||||
y = self.norm2(x)
|
||||
return x + self.drop_path(self.mlp(y))
|
||||
|
||||
|
||||
class ActionConditionedVideoPredictor(nn.Module):
|
||||
"""JEVLA1-compatible action-conditioned V-JEPA predictor."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_frames: int,
|
||||
img_size: tuple[int, int],
|
||||
patch_size: int,
|
||||
tubelet_size: int,
|
||||
embed_dim: int,
|
||||
action_embed_dim: int,
|
||||
predictor_embed_dim: int,
|
||||
@@ -65,40 +304,93 @@ class ActionConditionedVideoPredictor(nn.Module):
|
||||
num_heads: int,
|
||||
mlp_ratio: float,
|
||||
num_action_tokens_per_step: int,
|
||||
use_extrinsics: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim)
|
||||
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim)
|
||||
self.is_frame_causal = True
|
||||
self.use_extrinsics = use_extrinsics
|
||||
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
|
||||
self.action_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
||||
self.state_encoder = nn.Linear(action_embed_dim, predictor_embed_dim, bias=True)
|
||||
self.extrinsics_encoder = nn.Linear(action_embed_dim - 1, predictor_embed_dim, bias=True)
|
||||
|
||||
self.img_height, self.img_width = img_size
|
||||
self.patch_size = patch_size
|
||||
self.num_frames = num_frames
|
||||
self.tubelet_size = tubelet_size
|
||||
self.grid_height = self.img_height // self.patch_size
|
||||
self.grid_width = self.img_width // self.patch_size
|
||||
|
||||
self.predictor_blocks = nn.ModuleList(
|
||||
[_PredictorBlock(predictor_embed_dim, num_heads, mlp_ratio) for _ in range(depth)]
|
||||
[
|
||||
ACBlock(
|
||||
dim=predictor_embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=True,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
norm_layer=lambda dim: nn.LayerNorm(dim, eps=1e-6),
|
||||
grid_size=self.grid_height,
|
||||
use_rope=True,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
self.predictor_norm = nn.LayerNorm(predictor_embed_dim)
|
||||
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim)
|
||||
self.predictor_norm = nn.LayerNorm(predictor_embed_dim, eps=1e-6)
|
||||
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
|
||||
self.num_action_tokens_per_step = num_action_tokens_per_step
|
||||
|
||||
def forward(self, frame_tokens: torch.Tensor, action_tokens: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_steps, tokens_per_frame, _ = frame_tokens.shape
|
||||
_, action_steps, _, _ = action_tokens.shape
|
||||
if action_steps != num_steps:
|
||||
raise ValueError(f"Expected {num_steps} action steps, got {action_steps}.")
|
||||
@property
|
||||
def norm(self) -> nn.LayerNorm:
|
||||
return self.predictor_norm
|
||||
|
||||
frame_tokens = self.predictor_embed(frame_tokens)
|
||||
action_tokens = self.action_encoder(action_tokens)
|
||||
fused_steps = [
|
||||
torch.cat([action_tokens[:, step], frame_tokens[:, step]], dim=1) for step in range(num_steps)
|
||||
]
|
||||
fused = torch.cat(fused_steps, dim=1)
|
||||
@property
|
||||
def proj(self) -> nn.Linear:
|
||||
return self.predictor_proj
|
||||
|
||||
attn_mask = build_block_causal_attention_mask(
|
||||
num_steps=num_steps,
|
||||
tokens_per_step=tokens_per_frame,
|
||||
cond_tokens=self.num_action_tokens_per_step,
|
||||
).to(device=fused.device, dtype=fused.dtype)
|
||||
def forward(
|
||||
self,
|
||||
frame_tokens: torch.Tensor,
|
||||
action_tokens: torch.Tensor,
|
||||
extrinsics: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# starVLA input convention: frame_tokens [B, T*H*W, D], actions [B, T*A, D].
|
||||
x = self.predictor_embed(frame_tokens)
|
||||
batch_size, num_context_tokens, hidden_dim = x.size()
|
||||
num_frames = num_context_tokens // (self.grid_height * self.grid_width)
|
||||
|
||||
actions = self.action_encoder(action_tokens)
|
||||
actions = actions.view(batch_size, num_frames, -1, hidden_dim)
|
||||
cond_tokens = actions.shape[2]
|
||||
|
||||
x = x.view(batch_size, num_frames, self.grid_height * self.grid_width, hidden_dim)
|
||||
if self.use_extrinsics:
|
||||
if extrinsics is None:
|
||||
raise ValueError("extrinsics are required when use_extrinsics=True.")
|
||||
cond_tokens += 1
|
||||
extrinsic_tokens = self.extrinsics_encoder(extrinsics).unsqueeze(2)
|
||||
x = torch.cat([actions, extrinsic_tokens, x], dim=2).flatten(1, 2)
|
||||
else:
|
||||
x = torch.cat([actions, x], dim=2).flatten(1, 2)
|
||||
|
||||
attn_mask = build_action_block_causal_attention_mask(
|
||||
num_frames, self.grid_height, self.grid_width, add_tokens=cond_tokens
|
||||
)
|
||||
attn_mask = attn_mask[: x.size(1), : x.size(1)].to(x.device, non_blocking=True)
|
||||
|
||||
for block in self.predictor_blocks:
|
||||
fused = block(fused, attn_mask=attn_mask)
|
||||
x = block(
|
||||
x,
|
||||
attn_mask=attn_mask,
|
||||
num_frames=num_frames,
|
||||
grid_height=self.grid_height,
|
||||
grid_width=self.grid_width,
|
||||
action_tokens=cond_tokens,
|
||||
)
|
||||
|
||||
fused = self.predictor_norm(fused)
|
||||
fused = fused.view(batch_size, num_steps, self.num_action_tokens_per_step + tokens_per_frame, -1)
|
||||
predicted_frame_tokens = fused[:, :, self.num_action_tokens_per_step :, :]
|
||||
return self.predictor_proj(predicted_frame_tokens)
|
||||
x = x.view(batch_size, num_frames, cond_tokens + self.grid_height * self.grid_width, hidden_dim)
|
||||
x = x[:, :, cond_tokens:, :].flatten(1, 2)
|
||||
x = self.predictor_norm(x)
|
||||
return self.predictor_proj(x)
|
||||
|
||||
Reference in New Issue
Block a user