refactoring into using pre and post processor

This commit is contained in:
Maxime Ellerbach
2026-05-21 11:15:52 +00:00
committed by Maximellerbach
parent 83ef59e020
commit 01ce5d7af1
7 changed files with 268 additions and 114 deletions

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@@ -19,8 +18,8 @@ class VLAJEPAConfig(PreTrainedConfig):
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.MIN_MAX,
}
)
@@ -66,7 +65,6 @@ class VLAJEPAConfig(PreTrainedConfig):
repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style)
resize_images_to: tuple[int, int] | None = None
action_unnormalization_stats: dict[str, Any] | None = None
binarize_gripper_action: bool = True
clip_normalized_actions: bool = True
torch_dtype: str = "bfloat16"

View File

@@ -75,25 +75,26 @@ def _map_checkpoint_key(raw_key: str) -> str | None:
return None
def _fetch_action_stats(api: "HfApi", source_repo_id: str, subfolder: str) -> dict | None:
"""Try to download dataset_statistics.json and return the action stats dict."""
def _fetch_action_stats(api: HfApi, source_repo_id: str, subfolder: str) -> dict | None:
"""Download dataset_statistics.json and return the action stats dict."""
import json
stats_file = f"{subfolder}/dataset_statistics.json"
try:
local = api.hf_hub_download(source_repo_id, stats_file)
data = json.loads(Path(local).read_text())
# The original repo nests stats under a robot key, e.g. {"franka": {"action": {...}}}
# Original repo nests stats under a robot key, e.g. {"franka": {"action": {...}}}
for robot_key in data:
if isinstance(data[robot_key], dict) and "action" in data[robot_key]:
log.info(" Loaded action stats from %s (robot key: %s)", stats_file, robot_key)
return data[robot_key]["action"]
log.warning(" %s found but no 'action' key under any robot — skipping action stats.", stats_file)
log.warning(" %s found but no 'action' key under any robot key — skipping action stats.", stats_file)
except Exception as exc: # noqa: BLE001
log.warning(" Could not fetch %s: %saction_unnormalization_stats will be None.", stats_file, exc)
log.warning(" Could not fetch %s: %spostprocessor will have no unnorm stats.", stats_file, exc)
return None
# ---------------------------------------------------------------------------
# Architecture — identical across all 4 variants (from config.json)
# ---------------------------------------------------------------------------
@@ -152,7 +153,6 @@ def _build_config(
camera_keys: list[str],
with_state: bool,
enable_world_model: bool = True,
action_stats: dict | None = None,
):
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
@@ -167,7 +167,6 @@ def _build_config(
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
},
enable_world_model=enable_world_model,
action_unnormalization_stats=action_stats,
binarize_gripper_action=True,
clip_normalized_actions=True,
**_ARCH,
@@ -277,13 +276,15 @@ def main() -> None:
log.info(" Skipped sample: %s", skipped_keys[:5])
log.info(" First 5 mapped keys: %s", list(mapped_sd)[:5])
# Fetch action unnormalization stats from the source repo
action_stats = _fetch_action_stats(api, SOURCE_REPO_ID, subfolder)
# 3. Fetch action stats (min/max per dim) needed by the postprocessor unnormalizer
action_stats_raw = _fetch_action_stats(api, SOURCE_REPO_ID, subfolder)
# Wrap as {"action": {...}} for UnnormalizerProcessorStep
dataset_stats = {"action": action_stats_raw} if action_stats_raw is not None else None
# 3. Build config (no policy instantiation — avoids loading backbone from Hub)
config = _build_config(camera_keys, with_state, enable_world_model, action_stats)
# 4. Build config (no policy instantiation — avoids loading backbone from Hub)
config = _build_config(camera_keys, with_state, enable_world_model)
# 4. Save everything to a temp dir and upload in one shot
# 5. Save everything to a temp dir and upload in one shot
api.create_repo(target_repo_id, repo_type="model", exist_ok=True)
with tempfile.TemporaryDirectory() as tmp:
save_dir = Path(tmp)
@@ -293,7 +294,7 @@ def main() -> None:
config._save_pretrained(save_dir) # writes config.json via draccus
preprocessor, postprocessor = make_vla_jepa_pre_post_processors(config)
preprocessor, postprocessor = make_vla_jepa_pre_post_processors(config, dataset_stats)
preprocessor.save_pretrained(save_dir) # writes policy_preprocessor.json
postprocessor.save_pretrained(save_dir) # writes policy_postprocessor.json

View File

@@ -405,7 +405,7 @@ class VLAJEPAPolicy(PreTrainedPolicy):
# ---- Format Conversion: LeRobot → Native ----
def _lerobot_to_native(self, batch: dict[str, Tensor]) -> list[dict]:
def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]:
"""
Convert LeRobot batch format to native VLA-JEPA examples format.
@@ -524,45 +524,19 @@ class VLAJEPAPolicy(PreTrainedPolicy):
return examples
# ---- Format Conversion: Native → LeRobot ----
def _native_to_lerobot(self, native_output: dict[str, Tensor]) -> tuple[Tensor, dict[str, float]]:
"""
Convert native VLA-JEPA output dict to LeRobot (loss, logs) format.
Native output:
{"action_loss": Tensor, "wm_loss": Tensor}
or {"wm_loss": Tensor} (video-only mode)
LeRobot output:
(total_loss: scalar Tensor, {"action_loss": float, "wm_loss": float, "loss": float})
"""
logs: dict[str, float] = {}
total_loss = torch.tensor(0.0, device=self.config.device)
if "action_loss" in native_output:
total_loss = total_loss + native_output["action_loss"]
logs["action_loss"] = native_output["action_loss"].detach().item()
if "wm_loss" in native_output:
total_loss = total_loss + native_output["wm_loss"]
logs["wm_loss"] = native_output["wm_loss"].detach().item()
logs["loss"] = (
total_loss.detach().item()
if total_loss.item() != 0
else (logs.get("wm_loss", 0.0) + logs.get("action_loss", 0.0))
)
return total_loss, logs
# ---- LeRobot Policy Interface ----
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""LeRobot train forward: convert → native forward → convert back."""
examples = self._lerobot_to_native(batch)
"""LeRobot train forward: convert → native forward → aggregate losses."""
examples = self._prepare_model_inputs(batch)
native_output = self.model.forward(examples)
return self._native_to_lerobot(native_output)
total_loss = native_output.get("action_loss", torch.tensor(0.0)) + native_output.get(
"wm_loss", torch.tensor(0.0)
)
logs = {k: v.detach().item() for k, v in native_output.items()}
logs["loss"] = total_loss.detach().item()
return total_loss, logs
def get_optim_params(self) -> dict:
return self.model.parameters()
@@ -573,8 +547,7 @@ class VLAJEPAPolicy(PreTrainedPolicy):
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
# Convert to native format
examples = self._lerobot_to_native(batch)
examples = self._prepare_model_inputs(batch)
batch_images = [ex["image"] for ex in examples]
instructions = [ex["lang"] for ex in examples]
@@ -582,35 +555,9 @@ class VLAJEPAPolicy(PreTrainedPolicy):
if "state" in examples[0] and examples[0]["state"] is not None:
state_np = np.stack([ex["state"] for ex in examples])
# Call native predict
actions_np = self.model.predict_action(batch_images, instructions, state_np)
# Convert back to tensor on the right device
actions_np = self._unnormalize_actions(actions_np)
return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32)
def _unnormalize_actions(self, normalized_actions: np.ndarray) -> np.ndarray:
"""Match starVLA's LIBERO action post-processing exactly."""
stats = self.config.action_unnormalization_stats
if not stats:
return normalized_actions
actions = normalized_actions.astype(np.float32, copy=True)
if self.config.clip_normalized_actions:
actions = np.clip(actions, -1.0, 1.0)
if self.config.binarize_gripper_action and actions.shape[-1] >= 7:
actions[..., 6] = np.where(actions[..., 6] < 0.5, 0.0, 1.0)
action_min = np.asarray(stats["min"], dtype=np.float32)
action_max = np.asarray(stats["max"], dtype=np.float32)
mask = np.asarray(stats.get("mask", np.ones_like(action_min, dtype=bool)), dtype=bool)
scaled = 0.5 * (actions + 1.0) * (action_max - action_min) + action_min
actions = np.where(mask, scaled, actions).astype(np.float32)
if self.config.binarize_gripper_action and actions.shape[-1] >= 7:
actions[..., 6] = 1.0 - 2.0 * (actions[..., 6] > 0.5)
return actions
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""LeRobot select_action with action queue caching."""

View File

@@ -9,16 +9,56 @@ from lerobot.processor import (
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeviceProcessorStep,
EnvTransition,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
@ProcessorStepRegistry.register(name="vla_jepa_clip_actions")
class ClipActionsProcessorStep(ProcessorStep):
"""Clips action tensor to [-1, 1] before unnormalization."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None:
transition = dict(transition)
transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0)
return transition
def transform_features(self, features):
return features
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
class BinarizeGripperProcessorStep(ProcessorStep):
"""Binarizes gripper dim (index 6) after unnormalization.
Maps continuous value to {-1, 1}: > 0.5 → -1, <= 0.5 → 1 (matches starVLA convention).
Only applied when action has >= 7 dimensions.
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None and action.shape[-1] >= 7:
transition = dict(transition)
a = action.clone()
a[..., 6] = 1.0 - 2.0 * (a[..., 6] > 0.5).float()
transition[TransitionKey.ACTION] = a
return transition
def transform_features(self, features):
return features
def make_vla_jepa_pre_post_processors(
config: VLAJEPAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
@@ -37,9 +77,19 @@ def make_vla_jepa_pre_post_processors(
stats=dataset_stats,
),
]
output_steps = [
DeviceProcessorStep(device="cpu"),
]
output_steps: list[ProcessorStep] = []
if config.clip_normalized_actions:
output_steps.append(ClipActionsProcessorStep())
output_steps.append(
UnnormalizerProcessorStep(
features=features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
)
)
if config.binarize_gripper_action:
output_steps.append(BinarizeGripperProcessorStep())
output_steps.append(DeviceProcessorStep(device="cpu"))
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,

View File

@@ -111,6 +111,41 @@ def make_inference_batch(
# ---------------------------------------------------------------------------
class _FakeLanguageLayer(nn.Module):
"""Leaf module whose forward hook is captured by _qwen_last_decoder_hidden."""
def __init__(self, hidden_size: int) -> None:
super().__init__()
self._hidden_size = hidden_size
def forward(self, hidden: Tensor, **_: object) -> tuple[Tensor, ...]:
return (hidden,)
class _FakeLanguageModel(nn.Module):
def __init__(self, hidden_size: int) -> None:
super().__init__()
self._hidden_size = hidden_size
self.layers = nn.ModuleList([_FakeLanguageLayer(hidden_size)])
def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace:
batch_size, seq_len = input_ids.shape
hidden = torch.zeros(batch_size, seq_len, self._hidden_size, device=input_ids.device)
self.layers[-1](hidden)
return SimpleNamespace()
class _FakeQwenInnerModel(nn.Module):
"""Mimics the `.model.model` level that _qwen_last_decoder_hidden walks into."""
def __init__(self, hidden_size: int) -> None:
super().__init__()
self.language_model = _FakeLanguageModel(hidden_size)
def forward(self, input_ids: Tensor, **kwargs: object) -> SimpleNamespace:
return self.language_model(input_ids)
class _FakeQwenBackbone(nn.Module):
def __init__(self, hidden_size: int) -> None:
super().__init__()
@@ -119,6 +154,7 @@ class _FakeQwenBackbone(nn.Module):
hidden_size=hidden_size,
text_config=SimpleNamespace(hidden_size=hidden_size),
)
self.model = _FakeQwenInnerModel(hidden_size)
@property
def device(self) -> torch.device:
@@ -189,7 +225,9 @@ class _FakeVideoEncoder(nn.Module):
def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(1))
self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size)
# image_size must be >= patch_size (16) so the predictor grid is non-zero.
# Setting image_size=16 gives a 1x1 grid (1 patch per frame).
self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size, image_size=16)
@property
def device(self) -> torch.device:

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import os
from copy import deepcopy
import numpy as np
import pytest
import torch
from torch import Tensor
@@ -206,12 +207,11 @@ def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None
# ---------------------------------------------------------------------------
def test_lerobot_to_native_training_format(patch_vla_jepa_external_models: None) -> None:
import numpy as np
def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None:
from PIL import Image
policy = VLAJEPAPolicy(make_config())
examples = policy._lerobot_to_native(make_train_batch())
examples = policy._prepare_model_inputs(make_train_batch())
assert len(examples) == BATCH_SIZE
for ex in examples:
@@ -222,44 +222,35 @@ def test_lerobot_to_native_training_format(patch_vla_jepa_external_models: None)
assert ex["state"].shape == (1, STATE_DIM)
def test_lerobot_to_native_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
for ex in policy._lerobot_to_native(make_inference_batch()):
for ex in policy._prepare_model_inputs(make_inference_batch()):
assert "action" not in ex
assert "image" in ex and "video" in ex and "lang" in ex
def test_lerobot_to_native_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
del batch["task"]
examples = policy._lerobot_to_native(batch)
examples = policy._prepare_model_inputs(batch)
assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples)
def test_lerobot_to_native_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
batch["task"] = "open the drawer"
assert all(ex["lang"] == "open the drawer" for ex in policy._lerobot_to_native(batch))
assert all(ex["lang"] == "open the drawer" for ex in policy._prepare_model_inputs(batch))
def test_lerobot_to_native_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None:
from lerobot.utils.constants import OBS_STATE
policy = VLAJEPAPolicy(make_config())
batch = make_inference_batch()
del batch[OBS_STATE]
assert all("state" not in ex for ex in policy._lerobot_to_native(batch))
def test_native_to_lerobot_both_losses(patch_vla_jepa_external_models: None) -> None:
policy = VLAJEPAPolicy(make_config())
loss, logs = policy._native_to_lerobot({"action_loss": torch.tensor(0.5), "wm_loss": torch.tensor(0.1)})
assert torch.isfinite(loss)
assert set(logs) == {"action_loss", "wm_loss", "loss"}
assert logs["action_loss"] == pytest.approx(0.5, abs=1e-5)
assert logs["wm_loss"] == pytest.approx(0.1, abs=1e-5)
assert all("state" not in ex for ex in policy._prepare_model_inputs(batch))
# ---------------------------------------------------------------------------
@@ -355,3 +346,127 @@ def test_hub_libero_inference_shape() -> None:
batch = _make_hub_inference_batch(policy)
action = policy.select_action(batch)
assert action.shape[-1] == policy.config.action_dim
# ---------------------------------------------------------------------------
# Postprocessor unnormalization tests
#
# These tests verify that the postprocessor pipeline (clip → unnorm → binarize)
# correctly applies MIN_MAX unnormalization after predict_action_chunk.
# ---------------------------------------------------------------------------
def _make_dataset_stats(action_dim: int = ACTION_DIM) -> dict:
"""Returns sample dataset_stats with a simple [i, i+10] range per action dim."""
from lerobot.utils.constants import ACTION
return {
ACTION: {
"min": torch.tensor([float(i) for i in range(action_dim)], dtype=torch.float32),
"max": torch.tensor([float(i) + 10.0 for i in range(action_dim)], dtype=torch.float32),
}
}
@torch.no_grad()
def test_postprocessor_unnormalizes_actions(patch_vla_jepa_external_models: None) -> None:
"""UnnormalizerProcessorStep with MIN_MAX produces the correct inverse of MIN_MAX normalization."""
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor import UnnormalizerProcessorStep
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import ACTION
dataset_stats = _make_dataset_stats()
rng = np.random.default_rng(7)
actions_np = rng.uniform(-1.0, 1.0, (2, ACTION_HORIZON, ACTION_DIM)).astype(np.float32)
a_min = dataset_stats[ACTION]["min"].numpy()
a_max = dataset_stats[ACTION]["max"].numpy()
expected = (actions_np + 1.0) / 2.0 * (a_max - a_min) + a_min
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
unnorm_step = UnnormalizerProcessorStep(
features=features,
norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
stats=dataset_stats,
)
actions_tensor = torch.from_numpy(actions_np)
transition = policy_action_to_transition(actions_tensor)
result = transition_to_policy_action(unnorm_step(transition)).numpy()
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6)
@torch.no_grad()
def test_postprocessor_clip_clamps_before_unnorm(patch_vla_jepa_external_models: None) -> None:
"""ClipActionsProcessorStep clamps to [-1, 1] before unnormalization."""
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor import UnnormalizerProcessorStep
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.policies.vla_jepa.processor_vla_jepa import ClipActionsProcessorStep
from lerobot.utils.constants import ACTION
dataset_stats = _make_dataset_stats()
a_min = dataset_stats[ACTION]["min"].numpy()
a_max = dataset_stats[ACTION]["max"].numpy()
# Deliberately out-of-range inputs
actions_np = np.array([[[2.0] * ACTION_DIM, [-3.0] * ACTION_DIM]], dtype=np.float32)
clipped = np.clip(actions_np, -1.0, 1.0)
expected = (clipped + 1.0) / 2.0 * (a_max - a_min) + a_min
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
clip_step = ClipActionsProcessorStep()
unnorm_step = UnnormalizerProcessorStep(
features=features,
norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX},
stats=dataset_stats,
)
transition = policy_action_to_transition(torch.from_numpy(actions_np))
transition = clip_step(transition)
result = transition_to_policy_action(unnorm_step(transition)).numpy()
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6)
@torch.no_grad()
def test_postprocessor_applied_after_predict_action_chunk(
patch_vla_jepa_external_models: None, monkeypatch: pytest.MonkeyPatch
) -> None:
"""predict_action_chunk returns raw actions; the postprocessor applies unnormalization.
Verifies the split: predict_action_chunk returns normalized actions, and calling the
postprocessor on them produces the correctly unnormalized result.
"""
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
raw_actions = np.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=np.float32)
cfg = make_config()
cfg.clip_normalized_actions = False
cfg.binarize_gripper_action = False
policy = VLAJEPAPolicy(cfg)
policy.eval()
monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.copy())
dataset_stats = _make_dataset_stats()
_, postprocessor = make_vla_jepa_pre_post_processors(cfg, dataset_stats)
batch = make_inference_batch()
chunk = policy.predict_action_chunk(batch)
# predict_action_chunk returns raw (normalized) actions
assert torch.allclose(chunk, torch.zeros_like(chunk), atol=1e-6), (
"predict_action_chunk should return raw actions without unnormalization applied."
)
# Postprocessor applies unnormalization: 0 → (0+1)/2 * (max-min) + min = 5 + i
unnormed = postprocessor(chunk)
from lerobot.utils.constants import ACTION
a_min = dataset_stats[ACTION]["min"].numpy()
a_max = dataset_stats[ACTION]["max"].numpy()
expected_first = 0.5 * (0.0 + 1.0) * (a_max[0] - a_min[0]) + a_min[0]
assert unnormed[0, 0, 0].item() == pytest.approx(expected_first, abs=1e-5)

View File

@@ -15,10 +15,15 @@ _ACTION_EMBED_DIM = 8
def _make_predictor(
embed_dim: int = 8,
action_embed_dim: int = _ACTION_EMBED_DIM,
predictor_embed_dim: int = 16,
predictor_embed_dim: int = 24,
num_action_tokens: int = 2,
tokens_per_frame: int = 1,
) -> ActionConditionedVideoPredictor:
return ActionConditionedVideoPredictor(
num_frames=1,
img_size=(1, tokens_per_frame),
patch_size=1,
tubelet_size=1,
embed_dim=embed_dim,
action_embed_dim=action_embed_dim,
predictor_embed_dim=predictor_embed_dim,
@@ -38,16 +43,16 @@ def _make_predictor(
],
)
def test_predictor_output_shape(batch: int, num_steps: int, tokens_per_frame: int, embed_dim: int) -> None:
predictor = _make_predictor(embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM)
frame_tokens = torch.randn(batch, num_steps, tokens_per_frame, embed_dim)
action_tokens = torch.randn(batch, num_steps, 2, _ACTION_EMBED_DIM)
predictor = _make_predictor(embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM, tokens_per_frame=tokens_per_frame)
frame_tokens = torch.randn(batch, num_steps * tokens_per_frame, embed_dim)
action_tokens = torch.randn(batch, num_steps * 2, _ACTION_EMBED_DIM)
out = predictor(frame_tokens, action_tokens)
assert tuple(out.shape) == (batch, num_steps, tokens_per_frame, embed_dim)
assert tuple(out.shape) == (batch, num_steps * tokens_per_frame, embed_dim)
assert torch.isfinite(out).all()
def test_predictor_step_mismatch_raises() -> None:
predictor = _make_predictor()
frame_tokens = torch.randn(2, 3, 4, 8)
with pytest.raises(ValueError, match="Expected 3 action steps"):
predictor(frame_tokens, torch.randn(2, 2, 2, 8))
predictor = _make_predictor(tokens_per_frame=4)
frame_tokens = torch.randn(2, 3 * 4, 8) # 3 steps, 4 tokens each
with pytest.raises(RuntimeError):
predictor(frame_tokens, torch.randn(2, 2 * 2, 8)) # 2 steps → mismatch