From 01ce5d7af17fedf11dbf94651c5eedba17c577ed Mon Sep 17 00:00:00 2001 From: Maxime Ellerbach Date: Thu, 21 May 2026 11:15:52 +0000 Subject: [PATCH] refactoring into using pre and post processor --- .../vla_jepa/configuration_vla_jepa.py | 6 +- .../vla_jepa/convert_vla_jepa_checkpoints.py | 27 +-- .../policies/vla_jepa/modeling_vla_jepa.py | 75 ++------- .../policies/vla_jepa/processor_vla_jepa.py | 56 ++++++- tests/policies/vla_jepa/conftest.py | 40 ++++- tests/policies/vla_jepa/test_vla_jepa.py | 155 +++++++++++++++--- tests/policies/vla_jepa/test_world_model.py | 23 ++- 7 files changed, 268 insertions(+), 114 deletions(-) diff --git a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py index 53dced2cd..9e1dd0ffe 100644 --- a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py @@ -1,7 +1,6 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import NormalizationMode @@ -19,8 +18,8 @@ class VLAJEPAConfig(PreTrainedConfig): normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { "VISUAL": NormalizationMode.IDENTITY, - "STATE": NormalizationMode.MEAN_STD, - "ACTION": NormalizationMode.MEAN_STD, + "STATE": NormalizationMode.IDENTITY, + "ACTION": NormalizationMode.MIN_MAX, } ) @@ -66,7 +65,6 @@ class VLAJEPAConfig(PreTrainedConfig): repeated_diffusion_steps: int = 8 # independent noise draws per batch item (CogACT-style) resize_images_to: tuple[int, int] | None = None - action_unnormalization_stats: dict[str, Any] | None = None binarize_gripper_action: bool = True clip_normalized_actions: bool = True torch_dtype: str = "bfloat16" diff --git a/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py b/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py index a586dc625..753291bb5 100644 --- a/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py +++ b/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py @@ -75,25 +75,26 @@ def _map_checkpoint_key(raw_key: str) -> str | None: return None -def _fetch_action_stats(api: "HfApi", source_repo_id: str, subfolder: str) -> dict | None: - """Try to download dataset_statistics.json and return the action stats dict.""" +def _fetch_action_stats(api: HfApi, source_repo_id: str, subfolder: str) -> dict | None: + """Download dataset_statistics.json and return the action stats dict.""" import json stats_file = f"{subfolder}/dataset_statistics.json" try: local = api.hf_hub_download(source_repo_id, stats_file) data = json.loads(Path(local).read_text()) - # The original repo nests stats under a robot key, e.g. {"franka": {"action": {...}}} + # Original repo nests stats under a robot key, e.g. {"franka": {"action": {...}}} for robot_key in data: if isinstance(data[robot_key], dict) and "action" in data[robot_key]: log.info(" Loaded action stats from %s (robot key: %s)", stats_file, robot_key) return data[robot_key]["action"] - log.warning(" %s found but no 'action' key under any robot — skipping action stats.", stats_file) + log.warning(" %s found but no 'action' key under any robot key — skipping action stats.", stats_file) except Exception as exc: # noqa: BLE001 - log.warning(" Could not fetch %s: %s — action_unnormalization_stats will be None.", stats_file, exc) + log.warning(" Could not fetch %s: %s — postprocessor will have no unnorm stats.", stats_file, exc) return None + # --------------------------------------------------------------------------- # Architecture — identical across all 4 variants (from config.json) # --------------------------------------------------------------------------- @@ -152,7 +153,6 @@ def _build_config( camera_keys: list[str], with_state: bool, enable_world_model: bool = True, - action_stats: dict | None = None, ): from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig @@ -167,7 +167,6 @@ def _build_config( "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), }, enable_world_model=enable_world_model, - action_unnormalization_stats=action_stats, binarize_gripper_action=True, clip_normalized_actions=True, **_ARCH, @@ -277,13 +276,15 @@ def main() -> None: log.info(" Skipped sample: %s", skipped_keys[:5]) log.info(" First 5 mapped keys: %s", list(mapped_sd)[:5]) - # Fetch action unnormalization stats from the source repo - action_stats = _fetch_action_stats(api, SOURCE_REPO_ID, subfolder) + # 3. Fetch action stats (min/max per dim) needed by the postprocessor unnormalizer + action_stats_raw = _fetch_action_stats(api, SOURCE_REPO_ID, subfolder) + # Wrap as {"action": {...}} for UnnormalizerProcessorStep + dataset_stats = {"action": action_stats_raw} if action_stats_raw is not None else None - # 3. Build config (no policy instantiation — avoids loading backbone from Hub) - config = _build_config(camera_keys, with_state, enable_world_model, action_stats) + # 4. Build config (no policy instantiation — avoids loading backbone from Hub) + config = _build_config(camera_keys, with_state, enable_world_model) - # 4. Save everything to a temp dir and upload in one shot + # 5. Save everything to a temp dir and upload in one shot api.create_repo(target_repo_id, repo_type="model", exist_ok=True) with tempfile.TemporaryDirectory() as tmp: save_dir = Path(tmp) @@ -293,7 +294,7 @@ def main() -> None: config._save_pretrained(save_dir) # writes config.json via draccus - preprocessor, postprocessor = make_vla_jepa_pre_post_processors(config) + preprocessor, postprocessor = make_vla_jepa_pre_post_processors(config, dataset_stats) preprocessor.save_pretrained(save_dir) # writes policy_preprocessor.json postprocessor.save_pretrained(save_dir) # writes policy_postprocessor.json diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index 5d05e03eb..2f0ebbb4b 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -405,7 +405,7 @@ class VLAJEPAPolicy(PreTrainedPolicy): # ---- Format Conversion: LeRobot → Native ---- - def _lerobot_to_native(self, batch: dict[str, Tensor]) -> list[dict]: + def _prepare_model_inputs(self, batch: dict[str, Tensor]) -> list[dict]: """ Convert LeRobot batch format to native VLA-JEPA examples format. @@ -524,45 +524,19 @@ class VLAJEPAPolicy(PreTrainedPolicy): return examples - # ---- Format Conversion: Native → LeRobot ---- - - def _native_to_lerobot(self, native_output: dict[str, Tensor]) -> tuple[Tensor, dict[str, float]]: - """ - Convert native VLA-JEPA output dict to LeRobot (loss, logs) format. - - Native output: - {"action_loss": Tensor, "wm_loss": Tensor} - or {"wm_loss": Tensor} (video-only mode) - - LeRobot output: - (total_loss: scalar Tensor, {"action_loss": float, "wm_loss": float, "loss": float}) - """ - logs: dict[str, float] = {} - total_loss = torch.tensor(0.0, device=self.config.device) - - if "action_loss" in native_output: - total_loss = total_loss + native_output["action_loss"] - logs["action_loss"] = native_output["action_loss"].detach().item() - - if "wm_loss" in native_output: - total_loss = total_loss + native_output["wm_loss"] - logs["wm_loss"] = native_output["wm_loss"].detach().item() - - logs["loss"] = ( - total_loss.detach().item() - if total_loss.item() != 0 - else (logs.get("wm_loss", 0.0) + logs.get("action_loss", 0.0)) - ) - - return total_loss, logs - # ---- LeRobot Policy Interface ---- def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: - """LeRobot train forward: convert → native forward → convert back.""" - examples = self._lerobot_to_native(batch) + """LeRobot train forward: convert → native forward → aggregate losses.""" + examples = self._prepare_model_inputs(batch) native_output = self.model.forward(examples) - return self._native_to_lerobot(native_output) + + total_loss = native_output.get("action_loss", torch.tensor(0.0)) + native_output.get( + "wm_loss", torch.tensor(0.0) + ) + logs = {k: v.detach().item() for k, v in native_output.items()} + logs["loss"] = total_loss.detach().item() + return total_loss, logs def get_optim_params(self) -> dict: return self.model.parameters() @@ -573,8 +547,7 @@ class VLAJEPAPolicy(PreTrainedPolicy): self.eval() self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) - # Convert to native format - examples = self._lerobot_to_native(batch) + examples = self._prepare_model_inputs(batch) batch_images = [ex["image"] for ex in examples] instructions = [ex["lang"] for ex in examples] @@ -582,35 +555,9 @@ class VLAJEPAPolicy(PreTrainedPolicy): if "state" in examples[0] and examples[0]["state"] is not None: state_np = np.stack([ex["state"] for ex in examples]) - # Call native predict actions_np = self.model.predict_action(batch_images, instructions, state_np) - - # Convert back to tensor on the right device - actions_np = self._unnormalize_actions(actions_np) return torch.from_numpy(actions_np).to(device=self.config.device, dtype=torch.float32) - def _unnormalize_actions(self, normalized_actions: np.ndarray) -> np.ndarray: - """Match starVLA's LIBERO action post-processing exactly.""" - stats = self.config.action_unnormalization_stats - if not stats: - return normalized_actions - - actions = normalized_actions.astype(np.float32, copy=True) - if self.config.clip_normalized_actions: - actions = np.clip(actions, -1.0, 1.0) - - if self.config.binarize_gripper_action and actions.shape[-1] >= 7: - actions[..., 6] = np.where(actions[..., 6] < 0.5, 0.0, 1.0) - - action_min = np.asarray(stats["min"], dtype=np.float32) - action_max = np.asarray(stats["max"], dtype=np.float32) - mask = np.asarray(stats.get("mask", np.ones_like(action_min, dtype=bool)), dtype=bool) - scaled = 0.5 * (actions + 1.0) * (action_max - action_min) + action_min - actions = np.where(mask, scaled, actions).astype(np.float32) - if self.config.binarize_gripper_action and actions.shape[-1] >= 7: - actions[..., 6] = 1.0 - 2.0 * (actions[..., 6] > 0.5) - return actions - @torch.no_grad() def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: """LeRobot select_action with action queue caching.""" diff --git a/src/lerobot/policies/vla_jepa/processor_vla_jepa.py b/src/lerobot/policies/vla_jepa/processor_vla_jepa.py index 5aab01c18..a455737e6 100644 --- a/src/lerobot/policies/vla_jepa/processor_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/processor_vla_jepa.py @@ -9,16 +9,56 @@ from lerobot.processor import ( AddBatchDimensionProcessorStep, ComplementaryDataProcessorStep, DeviceProcessorStep, + EnvTransition, NormalizerProcessorStep, PolicyAction, PolicyProcessorPipeline, + ProcessorStep, ProcessorStepRegistry, RenameObservationsProcessorStep, + TransitionKey, + UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +@ProcessorStepRegistry.register(name="vla_jepa_clip_actions") +class ClipActionsProcessorStep(ProcessorStep): + """Clips action tensor to [-1, 1] before unnormalization.""" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + action = transition.get(TransitionKey.ACTION) + if action is not None: + transition = dict(transition) + transition[TransitionKey.ACTION] = action.clamp(-1.0, 1.0) + return transition + + def transform_features(self, features): + return features + + +@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper") +class BinarizeGripperProcessorStep(ProcessorStep): + """Binarizes gripper dim (index 6) after unnormalization. + + Maps continuous value to {-1, 1}: > 0.5 → -1, <= 0.5 → 1 (matches starVLA convention). + Only applied when action has >= 7 dimensions. + """ + + def __call__(self, transition: EnvTransition) -> EnvTransition: + action = transition.get(TransitionKey.ACTION) + if action is not None and action.shape[-1] >= 7: + transition = dict(transition) + a = action.clone() + a[..., 6] = 1.0 - 2.0 * (a[..., 6] > 0.5).float() + transition[TransitionKey.ACTION] = a + return transition + + def transform_features(self, features): + return features + + def make_vla_jepa_pre_post_processors( config: VLAJEPAConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, @@ -37,9 +77,19 @@ def make_vla_jepa_pre_post_processors( stats=dataset_stats, ), ] - output_steps = [ - DeviceProcessorStep(device="cpu"), - ] + output_steps: list[ProcessorStep] = [] + if config.clip_normalized_actions: + output_steps.append(ClipActionsProcessorStep()) + output_steps.append( + UnnormalizerProcessorStep( + features=features, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ) + ) + if config.binarize_gripper_action: + output_steps.append(BinarizeGripperProcessorStep()) + output_steps.append(DeviceProcessorStep(device="cpu")) return ( PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( steps=input_steps, diff --git a/tests/policies/vla_jepa/conftest.py b/tests/policies/vla_jepa/conftest.py index f5703fd82..da7c38cca 100644 --- a/tests/policies/vla_jepa/conftest.py +++ b/tests/policies/vla_jepa/conftest.py @@ -111,6 +111,41 @@ def make_inference_batch( # --------------------------------------------------------------------------- +class _FakeLanguageLayer(nn.Module): + """Leaf module whose forward hook is captured by _qwen_last_decoder_hidden.""" + + def __init__(self, hidden_size: int) -> None: + super().__init__() + self._hidden_size = hidden_size + + def forward(self, hidden: Tensor, **_: object) -> tuple[Tensor, ...]: + return (hidden,) + + +class _FakeLanguageModel(nn.Module): + def __init__(self, hidden_size: int) -> None: + super().__init__() + self._hidden_size = hidden_size + self.layers = nn.ModuleList([_FakeLanguageLayer(hidden_size)]) + + def forward(self, input_ids: Tensor, **_: object) -> SimpleNamespace: + batch_size, seq_len = input_ids.shape + hidden = torch.zeros(batch_size, seq_len, self._hidden_size, device=input_ids.device) + self.layers[-1](hidden) + return SimpleNamespace() + + +class _FakeQwenInnerModel(nn.Module): + """Mimics the `.model.model` level that _qwen_last_decoder_hidden walks into.""" + + def __init__(self, hidden_size: int) -> None: + super().__init__() + self.language_model = _FakeLanguageModel(hidden_size) + + def forward(self, input_ids: Tensor, **kwargs: object) -> SimpleNamespace: + return self.language_model(input_ids) + + class _FakeQwenBackbone(nn.Module): def __init__(self, hidden_size: int) -> None: super().__init__() @@ -119,6 +154,7 @@ class _FakeQwenBackbone(nn.Module): hidden_size=hidden_size, text_config=SimpleNamespace(hidden_size=hidden_size), ) + self.model = _FakeQwenInnerModel(hidden_size) @property def device(self) -> torch.device: @@ -189,7 +225,9 @@ class _FakeVideoEncoder(nn.Module): def __init__(self, hidden_size: int = 8, tubelet_size: int = 1) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(1)) - self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size) + # image_size must be >= patch_size (16) so the predictor grid is non-zero. + # Setting image_size=16 gives a 1x1 grid (1 patch per frame). + self.config = SimpleNamespace(hidden_size=hidden_size, tubelet_size=tubelet_size, image_size=16) @property def device(self) -> torch.device: diff --git a/tests/policies/vla_jepa/test_vla_jepa.py b/tests/policies/vla_jepa/test_vla_jepa.py index 3b6e4a1a6..37ea46da5 100644 --- a/tests/policies/vla_jepa/test_vla_jepa.py +++ b/tests/policies/vla_jepa/test_vla_jepa.py @@ -5,6 +5,7 @@ from __future__ import annotations import os from copy import deepcopy +import numpy as np import pytest import torch from torch import Tensor @@ -206,12 +207,11 @@ def test_reset_clears_action_queue(patch_vla_jepa_external_models: None) -> None # --------------------------------------------------------------------------- -def test_lerobot_to_native_training_format(patch_vla_jepa_external_models: None) -> None: - import numpy as np +def test_prepare_model_inputs_training_format(patch_vla_jepa_external_models: None) -> None: from PIL import Image policy = VLAJEPAPolicy(make_config()) - examples = policy._lerobot_to_native(make_train_batch()) + examples = policy._prepare_model_inputs(make_train_batch()) assert len(examples) == BATCH_SIZE for ex in examples: @@ -222,44 +222,35 @@ def test_lerobot_to_native_training_format(patch_vla_jepa_external_models: None) assert ex["state"].shape == (1, STATE_DIM) -def test_lerobot_to_native_inference_omits_action(patch_vla_jepa_external_models: None) -> None: +def test_prepare_model_inputs_inference_omits_action(patch_vla_jepa_external_models: None) -> None: policy = VLAJEPAPolicy(make_config()) - for ex in policy._lerobot_to_native(make_inference_batch()): + for ex in policy._prepare_model_inputs(make_inference_batch()): assert "action" not in ex assert "image" in ex and "video" in ex and "lang" in ex -def test_lerobot_to_native_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None: +def test_prepare_model_inputs_missing_task_uses_default(patch_vla_jepa_external_models: None) -> None: policy = VLAJEPAPolicy(make_config()) batch = make_inference_batch() del batch["task"] - examples = policy._lerobot_to_native(batch) + examples = policy._prepare_model_inputs(batch) assert all(isinstance(ex["lang"], str) and len(ex["lang"]) > 0 for ex in examples) -def test_lerobot_to_native_string_task_broadcast(patch_vla_jepa_external_models: None) -> None: +def test_prepare_model_inputs_string_task_broadcast(patch_vla_jepa_external_models: None) -> None: policy = VLAJEPAPolicy(make_config()) batch = make_inference_batch() batch["task"] = "open the drawer" - assert all(ex["lang"] == "open the drawer" for ex in policy._lerobot_to_native(batch)) + assert all(ex["lang"] == "open the drawer" for ex in policy._prepare_model_inputs(batch)) -def test_lerobot_to_native_no_state_omitted(patch_vla_jepa_external_models: None) -> None: +def test_prepare_model_inputs_no_state_omitted(patch_vla_jepa_external_models: None) -> None: from lerobot.utils.constants import OBS_STATE policy = VLAJEPAPolicy(make_config()) batch = make_inference_batch() del batch[OBS_STATE] - assert all("state" not in ex for ex in policy._lerobot_to_native(batch)) - - -def test_native_to_lerobot_both_losses(patch_vla_jepa_external_models: None) -> None: - policy = VLAJEPAPolicy(make_config()) - loss, logs = policy._native_to_lerobot({"action_loss": torch.tensor(0.5), "wm_loss": torch.tensor(0.1)}) - assert torch.isfinite(loss) - assert set(logs) == {"action_loss", "wm_loss", "loss"} - assert logs["action_loss"] == pytest.approx(0.5, abs=1e-5) - assert logs["wm_loss"] == pytest.approx(0.1, abs=1e-5) + assert all("state" not in ex for ex in policy._prepare_model_inputs(batch)) # --------------------------------------------------------------------------- @@ -355,3 +346,127 @@ def test_hub_libero_inference_shape() -> None: batch = _make_hub_inference_batch(policy) action = policy.select_action(batch) assert action.shape[-1] == policy.config.action_dim + + +# --------------------------------------------------------------------------- +# Postprocessor unnormalization tests +# +# These tests verify that the postprocessor pipeline (clip → unnorm → binarize) +# correctly applies MIN_MAX unnormalization after predict_action_chunk. +# --------------------------------------------------------------------------- + + +def _make_dataset_stats(action_dim: int = ACTION_DIM) -> dict: + """Returns sample dataset_stats with a simple [i, i+10] range per action dim.""" + from lerobot.utils.constants import ACTION + + return { + ACTION: { + "min": torch.tensor([float(i) for i in range(action_dim)], dtype=torch.float32), + "max": torch.tensor([float(i) + 10.0 for i in range(action_dim)], dtype=torch.float32), + } + } + + +@torch.no_grad() +def test_postprocessor_unnormalizes_actions(patch_vla_jepa_external_models: None) -> None: + """UnnormalizerProcessorStep with MIN_MAX produces the correct inverse of MIN_MAX normalization.""" + from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + from lerobot.processor import UnnormalizerProcessorStep + from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + from lerobot.utils.constants import ACTION + + dataset_stats = _make_dataset_stats() + + rng = np.random.default_rng(7) + actions_np = rng.uniform(-1.0, 1.0, (2, ACTION_HORIZON, ACTION_DIM)).astype(np.float32) + + a_min = dataset_stats[ACTION]["min"].numpy() + a_max = dataset_stats[ACTION]["max"].numpy() + expected = (actions_np + 1.0) / 2.0 * (a_max - a_min) + a_min + + features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))} + unnorm_step = UnnormalizerProcessorStep( + features=features, + norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX}, + stats=dataset_stats, + ) + + actions_tensor = torch.from_numpy(actions_np) + transition = policy_action_to_transition(actions_tensor) + result = transition_to_policy_action(unnorm_step(transition)).numpy() + + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6) + + +@torch.no_grad() +def test_postprocessor_clip_clamps_before_unnorm(patch_vla_jepa_external_models: None) -> None: + """ClipActionsProcessorStep clamps to [-1, 1] before unnormalization.""" + from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + from lerobot.processor import UnnormalizerProcessorStep + from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + from lerobot.policies.vla_jepa.processor_vla_jepa import ClipActionsProcessorStep + from lerobot.utils.constants import ACTION + + dataset_stats = _make_dataset_stats() + a_min = dataset_stats[ACTION]["min"].numpy() + a_max = dataset_stats[ACTION]["max"].numpy() + + # Deliberately out-of-range inputs + actions_np = np.array([[[2.0] * ACTION_DIM, [-3.0] * ACTION_DIM]], dtype=np.float32) + clipped = np.clip(actions_np, -1.0, 1.0) + expected = (clipped + 1.0) / 2.0 * (a_max - a_min) + a_min + + features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))} + clip_step = ClipActionsProcessorStep() + unnorm_step = UnnormalizerProcessorStep( + features=features, + norm_map={FeatureType.ACTION: NormalizationMode.MIN_MAX}, + stats=dataset_stats, + ) + + transition = policy_action_to_transition(torch.from_numpy(actions_np)) + transition = clip_step(transition) + result = transition_to_policy_action(unnorm_step(transition)).numpy() + + np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-6) + + +@torch.no_grad() +def test_postprocessor_applied_after_predict_action_chunk( + patch_vla_jepa_external_models: None, monkeypatch: pytest.MonkeyPatch +) -> None: + """predict_action_chunk returns raw actions; the postprocessor applies unnormalization. + + Verifies the split: predict_action_chunk returns normalized actions, and calling the + postprocessor on them produces the correctly unnormalized result. + """ + from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors + + raw_actions = np.zeros((BATCH_SIZE, ACTION_HORIZON, ACTION_DIM), dtype=np.float32) + + cfg = make_config() + cfg.clip_normalized_actions = False + cfg.binarize_gripper_action = False + policy = VLAJEPAPolicy(cfg) + policy.eval() + monkeypatch.setattr(policy.model, "predict_action", lambda *a, **kw: raw_actions.copy()) + + dataset_stats = _make_dataset_stats() + _, postprocessor = make_vla_jepa_pre_post_processors(cfg, dataset_stats) + + batch = make_inference_batch() + chunk = policy.predict_action_chunk(batch) + + # predict_action_chunk returns raw (normalized) actions + assert torch.allclose(chunk, torch.zeros_like(chunk), atol=1e-6), ( + "predict_action_chunk should return raw actions without unnormalization applied." + ) + + # Postprocessor applies unnormalization: 0 → (0+1)/2 * (max-min) + min = 5 + i + unnormed = postprocessor(chunk) + from lerobot.utils.constants import ACTION + a_min = dataset_stats[ACTION]["min"].numpy() + a_max = dataset_stats[ACTION]["max"].numpy() + expected_first = 0.5 * (0.0 + 1.0) * (a_max[0] - a_min[0]) + a_min[0] + assert unnormed[0, 0, 0].item() == pytest.approx(expected_first, abs=1e-5) diff --git a/tests/policies/vla_jepa/test_world_model.py b/tests/policies/vla_jepa/test_world_model.py index 0077efb3b..0c341b993 100644 --- a/tests/policies/vla_jepa/test_world_model.py +++ b/tests/policies/vla_jepa/test_world_model.py @@ -15,10 +15,15 @@ _ACTION_EMBED_DIM = 8 def _make_predictor( embed_dim: int = 8, action_embed_dim: int = _ACTION_EMBED_DIM, - predictor_embed_dim: int = 16, + predictor_embed_dim: int = 24, num_action_tokens: int = 2, + tokens_per_frame: int = 1, ) -> ActionConditionedVideoPredictor: return ActionConditionedVideoPredictor( + num_frames=1, + img_size=(1, tokens_per_frame), + patch_size=1, + tubelet_size=1, embed_dim=embed_dim, action_embed_dim=action_embed_dim, predictor_embed_dim=predictor_embed_dim, @@ -38,16 +43,16 @@ def _make_predictor( ], ) def test_predictor_output_shape(batch: int, num_steps: int, tokens_per_frame: int, embed_dim: int) -> None: - predictor = _make_predictor(embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM) - frame_tokens = torch.randn(batch, num_steps, tokens_per_frame, embed_dim) - action_tokens = torch.randn(batch, num_steps, 2, _ACTION_EMBED_DIM) + predictor = _make_predictor(embed_dim=embed_dim, action_embed_dim=_ACTION_EMBED_DIM, tokens_per_frame=tokens_per_frame) + frame_tokens = torch.randn(batch, num_steps * tokens_per_frame, embed_dim) + action_tokens = torch.randn(batch, num_steps * 2, _ACTION_EMBED_DIM) out = predictor(frame_tokens, action_tokens) - assert tuple(out.shape) == (batch, num_steps, tokens_per_frame, embed_dim) + assert tuple(out.shape) == (batch, num_steps * tokens_per_frame, embed_dim) assert torch.isfinite(out).all() def test_predictor_step_mismatch_raises() -> None: - predictor = _make_predictor() - frame_tokens = torch.randn(2, 3, 4, 8) - with pytest.raises(ValueError, match="Expected 3 action steps"): - predictor(frame_tokens, torch.randn(2, 2, 2, 8)) + predictor = _make_predictor(tokens_per_frame=4) + frame_tokens = torch.randn(2, 3 * 4, 8) # 3 steps, 4 tokens each + with pytest.raises(RuntimeError): + predictor(frame_tokens, torch.randn(2, 2 * 2, 8)) # 2 steps → mismatch