From 602b8e66a6640b3b38f7ad74682e1eb972b608c4 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sun, 22 Feb 2026 16:11:52 +0100 Subject: [PATCH] fix multi gpu processor bug --- .../processor/delta_action_processor.py | 17 +- src/lerobot/scripts/lerobot_train.py | 88 +++- tests/policies/test_delta_actions.py | 376 ++++++++++++++---- 3 files changed, 392 insertions(+), 89 deletions(-) diff --git a/src/lerobot/processor/delta_action_processor.py b/src/lerobot/processor/delta_action_processor.py index 9e43ad0c1..5a05ace98 100644 --- a/src/lerobot/processor/delta_action_processor.py +++ b/src/lerobot/processor/delta_action_processor.py @@ -210,8 +210,21 @@ class DeltaActionsProcessorStep(ProcessorStep): def _build_mask(self, action_dim: int) -> list[bool]: if not self.exclude_joints or self.action_names is None: return [True] * action_dim - exclude = set(self.exclude_joints) - return [n not in exclude for n in self.action_names] + + exclude_tokens = [str(name).lower() for name in self.exclude_joints if name] + if not exclude_tokens: + return [True] * action_dim + + mask = [] + for name in self.action_names[:action_dim]: + action_name = str(name).lower() + is_excluded = any(token == action_name or token in action_name for token in exclude_tokens) + mask.append(not is_excluded) + + if len(mask) < action_dim: + mask.extend([True] * (action_dim - len(mask))) + + return mask def __call__(self, transition: EnvTransition) -> EnvTransition: observation = transition.get(TransitionKey.OBSERVATION, {}) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 890225b37..353b3bb63 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -211,6 +211,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): torch.backends.cuda.matmul.allow_tf32 = True # Dataset loading synchronization: main process downloads first to avoid race conditions + delta_action_stats = None if is_main_process: logging.info("Creating dataset") dataset = make_dataset(cfg) @@ -220,13 +221,27 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): import numpy as np from lerobot.datasets.compute_stats import get_feature_stats - from lerobot.processor.delta_action_processor import to_delta_actions + from lerobot.processor.delta_action_processor import DeltaActionsProcessorStep, to_delta_actions chunk_size = cfg.policy.chunk_size hf = dataset.hf_dataset total_frames = len(hf) - max_samples = min(500_000, total_frames - chunk_size) - indices = np.random.choice(total_frames - chunk_size, max_samples, replace=False) + sample_upper_bound = total_frames - chunk_size + if sample_upper_bound <= 0: + raise ValueError( + f"Cannot compute delta action stats: total_frames={total_frames}, chunk_size={chunk_size}" + ) + + max_samples = min(100_000, sample_upper_bound) + indices = np.random.choice(sample_upper_bound, max_samples, replace=False) + + action_names = dataset.meta.features.get("action", {}).get("names") + delta_mask_step = DeltaActionsProcessorStep( + enabled=True, + exclude_joints=getattr(cfg.policy, "delta_exclude_joints", []), + action_names=action_names, + ) + delta_mask = delta_mask_step._build_mask(dataset.meta.features["action"]["shape"][0]) logging.info( f"use_delta_actions is enabled — computing delta action stats " f"from {max_samples} chunk samples (chunk_size={chunk_size})" @@ -245,13 +260,16 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): actions = torch.tensor(np.stack([np.asarray(a) for a in chunk_data["action"]])).float() state = torch.tensor(np.asarray(chunk_data["observation.state"][0])).float() - mask = [True] * actions.shape[-1] - delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0) + delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), delta_mask).squeeze(0) all_delta_actions.append(delta.numpy()) + if not all_delta_actions: + raise RuntimeError("Failed to compute delta action stats: no valid chunks found.") + all_delta = np.concatenate(all_delta_actions, axis=0) delta_stats = get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1) - dataset.meta.stats["action"] = delta_stats + delta_action_stats = delta_stats + dataset.meta.stats["action"] = delta_action_stats norm_type = "UNKNOWN" if hasattr(cfg.policy, "normalization_mapping"): @@ -259,8 +277,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): action_norm = cfg.policy.normalization_mapping.get("ACTION", None) norm_type = action_norm.value if action_norm else "UNKNOWN" + excluded_dims = len(delta_mask) - sum(delta_mask) logging.info( f"Delta action stats ({len(all_delta_actions)} chunks, {len(all_delta)} values, norm={norm_type}): " + f"delta_dims={sum(delta_mask)}/{len(delta_mask)} (excluded={excluded_dims}), " f"mean={np.abs(delta_stats['mean']).mean():.4f}, std={delta_stats['std'].mean():.4f}, " f"q01={delta_stats['q01'].mean():.4f}, q99={delta_stats['q99'].mean():.4f}" ) @@ -274,6 +294,15 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): if not is_main_process: dataset = make_dataset(cfg) + # Ensure all ranks use the exact same delta action stats. + if getattr(cfg.policy, "use_delta_actions", False): + if accelerator.num_processes > 1 and torch.distributed.is_initialized(): + stats_list = [delta_action_stats] + torch.distributed.broadcast_object_list(stats_list, src=0) + delta_action_stats = stats_list[0] + if delta_action_stats is not None: + dataset.meta.stats["action"] = delta_action_stats + # Create environment used for evaluating checkpoints during training on simulation data. # On real-world data, no need to create an environment as evaluations are done outside train.py, # using the eval.py instead, with gym_dora environment and dora-rs. @@ -299,10 +328,22 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): # Wait for all processes to finish policy creation before continuing accelerator.wait_for_everyone() + processor_pretrained_path = cfg.policy.pretrained_path + if ( + getattr(cfg.policy, "use_delta_actions", False) + and processor_pretrained_path is not None + and not cfg.resume + ): + logging.warning( + "use_delta_actions=true with pretrained processors can skip delta transforms if " + "the checkpoint processors do not define them. Building processors from current policy config." + ) + processor_pretrained_path = None + # Create processors - only provide dataset_stats if not resuming from saved processors processor_kwargs = {} postprocessor_kwargs = {} - if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path: + if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path: # Only provide dataset_stats when not resuming from saved processor state processor_kwargs["dataset_stats"] = dataset.meta.stats @@ -310,7 +351,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): if cfg.policy.type == "sarm": processor_kwargs["dataset_meta"] = dataset.meta - if cfg.policy.pretrained_path is not None: + if processor_pretrained_path is not None: processor_kwargs["preprocessor_overrides"] = { "device_processor": {"device": device.type}, "normalizer_processor": { @@ -332,7 +373,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): preprocessor, postprocessor = make_pre_post_processors( policy_cfg=cfg.policy, - pretrained_path=cfg.policy.pretrained_path, + pretrained_path=processor_pretrained_path, **processor_kwargs, **postprocessor_kwargs, ) @@ -450,7 +491,36 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): for _ in range(step, cfg.steps): start_time = time.perf_counter() batch = next(dl_iter) + + # Debug logging for first few steps and periodically + if is_main_process and (step < 3 or (cfg.log_freq > 0 and step % (cfg.log_freq * 10) == 0)): + action = batch.get("action") + state = batch.get("observation.state") + if action is not None and state is not None: + logging.info( + f"[DEBUG step={step}] PRE-PROCESSOR — " + f"action: shape={tuple(action.shape)}, mean={action.mean():.4f}, std={action.std():.4f}, " + f"min={action.min():.4f}, max={action.max():.4f} | " + f"state: shape={tuple(state.shape)}, mean={state.mean():.4f}" + ) + batch = preprocessor(batch) + + if is_main_process and (step < 3 or (cfg.log_freq > 0 and step % (cfg.log_freq * 10) == 0)): + action = batch.get("action") + state = batch.get("observation.state") + if action is not None: + logging.info( + f"[DEBUG step={step}] POST-PROCESSOR — " + f"action: shape={tuple(action.shape)}, mean={action.mean():.4f}, std={action.std():.4f}, " + f"min={action.min():.4f}, max={action.max():.4f}" + ) + if state is not None: + logging.info( + f"[DEBUG step={step}] POST-PROCESSOR — " + f"state: shape={tuple(state.shape)}, mean={state.mean():.4f}, std={state.std():.4f}" + ) + train_tracker.dataloading_s = time.perf_counter() - start_time train_tracker, output_dict = update_policy( diff --git a/tests/policies/test_delta_actions.py b/tests/policies/test_delta_actions.py index f37bd6461..b2ec3f488 100644 --- a/tests/policies/test_delta_actions.py +++ b/tests/policies/test_delta_actions.py @@ -1,124 +1,344 @@ -"""Tests for delta action transforms using a local dummy dataset.""" +"""Tests for delta action transforms — full pipeline validation. + +Tests the complete flow matching OpenPI: + raw actions → DeltaActions → Normalize(delta_stats) → model → Unnormalize → AbsoluteActions + +Uses real dataset: lerobot-data-collection/dagger_final_1_21 +""" import numpy as np import pytest import torch +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.datasets.compute_stats import get_feature_stats +from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.processor import TransitionKey, batch_to_transition from lerobot.processor.delta_action_processor import ( + AbsoluteActionsProcessorStep, DeltaActionsProcessorStep, to_absolute_actions, to_delta_actions, ) +from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep from lerobot.utils.constants import ACTION, OBS_STATE -ACTION_DIM = 14 -STATE_DIM = 14 +CHUNK_SIZE = 10 +REPO_ID = "lerobot-data-collection/dagger_final_1_21" -@pytest.fixture -def dataset(tmp_path, empty_lerobot_dataset_factory): - features = { - "action": {"dtype": "float32", "shape": (ACTION_DIM,), "names": None}, - "observation.state": {"dtype": "float32", "shape": (STATE_DIM,), "names": None}, - } - ds = empty_lerobot_dataset_factory(root=tmp_path / "delta_test", features=features) - for ep in range(2): - for _ in range(5): - ds.add_frame( - { - "action": np.random.randn(ACTION_DIM).astype(np.float32), - "observation.state": np.random.randn(STATE_DIM).astype(np.float32), - "task": f"task_{ep}", - } - ) - ds.save_episode() - ds.finalize() - return ds +@pytest.fixture(scope="module") +def dataset(): + return LeRobotDataset(REPO_ID, episodes=[0]) -def _collate(dataset, indices): - items = [dataset[i] for i in indices] - batch = {} - for key in items[0]: - vals = [item[key] for item in items] - if isinstance(vals[0], torch.Tensor): - batch[key] = torch.stack(vals) - else: - batch[key] = vals - return batch +@pytest.fixture(scope="module") +def action_dim(dataset): + return dataset.meta.features["action"]["shape"][0] -def test_roundtrip_3d(dataset): - """Delta then absolute on real data should recover original actions.""" - batch = _collate(dataset, range(4)) - actions = batch[ACTION].unsqueeze(1).expand(-1, 10, -1).clone() - state = batch[OBS_STATE] - mask = [True] * actions.shape[-1] +def _build_action_chunks(dataset, chunk_size, max_chunks=50): + """Build action chunks from hf_dataset, like the training script does.""" + hf = dataset.hf_dataset + total = len(hf) + all_ep = torch.tensor([int(hf[i]["episode_index"]) for i in range(total)]) + chunks, states = [], [] + for i in range(total - chunk_size + 1): + if all_ep[i] != all_ep[i + chunk_size - 1]: + continue + chunk_actions = torch.stack([hf[i + k]["action"] for k in range(chunk_size)]).float() + state = hf[i]["observation.state"].float() + chunks.append(chunk_actions) + states.append(state) + if len(chunks) >= max_chunks: + break + assert len(chunks) > 0, f"No valid chunks found. total={total}, ep_indices={all_ep.tolist()}" + return torch.stack(chunks), torch.stack(states) - delta = to_delta_actions(actions, state, mask) - recovered = to_absolute_actions(delta, state, mask) + +def _compute_delta_chunk_stats(action_chunks, states, mask): + all_deltas = [] + for actions, state in zip(action_chunks, states): + delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0) + all_deltas.append(delta.numpy()) + all_delta = np.concatenate(all_deltas, axis=0) + return get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1) + + +# --- Basic roundtrip tests --- + +def test_roundtrip_3d(action_dim): + actions = torch.randn(4, CHUNK_SIZE, action_dim) + state = torch.randn(4, action_dim) + mask = [True] * action_dim + recovered = to_absolute_actions(to_delta_actions(actions, state, mask), state, mask) torch.testing.assert_close(recovered, actions) -def test_roundtrip_2d(dataset): - """Works with (B, action_dim) shaped actions too.""" - batch = _collate(dataset, range(4)) - actions = batch[ACTION] - state = batch[OBS_STATE] - mask = [True] * actions.shape[-1] - - delta = to_delta_actions(actions, state, mask) - recovered = to_absolute_actions(delta, state, mask) +def test_roundtrip_2d(action_dim): + actions = torch.randn(4, action_dim) + state = torch.randn(4, action_dim) + mask = [True] * action_dim + recovered = to_absolute_actions(to_delta_actions(actions, state, mask), state, mask) torch.testing.assert_close(recovered, actions) -def test_delta_changes_all_dims(dataset): - """All dims should change when mask is all True.""" - batch = _collate(dataset, range(4)) - actions = batch[ACTION].unsqueeze(1) - state = batch[OBS_STATE] - mask = [True] * actions.shape[-1] - - delta = to_delta_actions(actions, state, mask) - assert (delta - actions).abs().sum() > 0 - - -def test_no_mutation(dataset): - """Original tensors should not be modified.""" - batch = _collate(dataset, range(2)) - actions = batch[ACTION].unsqueeze(1) +def test_no_mutation(action_dim): + actions = torch.randn(2, CHUNK_SIZE, action_dim) original = actions.clone() - state = batch[OBS_STATE] - mask = [True] * actions.shape[-1] - - to_delta_actions(actions, state, mask) + state = torch.randn(2, action_dim) + to_delta_actions(actions, state, [True] * action_dim) torch.testing.assert_close(actions, original) -def test_processor_step_roundtrip(dataset): +def test_exclude_joints_supports_partial_name_matching(): + names = [ + "right_joint_1.pos", + "right_gripper.pos", + "left_joint_1.pos", + "left_gripper.pos", + ] + step = DeltaActionsProcessorStep(enabled=True, exclude_joints=["gripper"], action_names=names) + assert step._build_mask(len(names)) == [True, False, True, False] + + +# --- Chunk-level delta stats test --- + +def test_chunk_stats_have_larger_std_than_frame_stats(dataset, action_dim): + """Chunk-level delta stats should have larger std than per-frame delta stats.""" + action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE) + mask = [True] * action_dim + + chunk_stats = _compute_delta_chunk_stats(action_chunks, states, mask) + + # Per-frame stats + hf = dataset.hf_dataset + n = min(500, len(hf)) + frame_actions = torch.stack([hf[i]["action"] for i in range(n)]).float() + frame_states = torch.stack([hf[i]["observation.state"] for i in range(n)]).float() + frame_deltas = to_delta_actions(frame_actions, frame_states, mask).numpy() + frame_stats = get_feature_stats(frame_deltas, axis=0, keepdims=frame_deltas.ndim == 1) + + assert chunk_stats["std"].mean() >= frame_stats["std"].mean(), ( + f"Chunk std ({chunk_stats['std'].mean():.4f}) should be >= " + f"frame std ({frame_stats['std'].mean():.4f})" + ) + + +# --- Full pipeline roundtrip: delta → normalize → unnormalize → absolute --- + +def test_full_pipeline_roundtrip(dataset, action_dim): + """Test the complete OpenPI pipeline: delta → normalize → unnormalize → absolute.""" + action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE) + mask = [True] * action_dim + + delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask) + stats = {ACTION: {k: v for k, v in delta_stats.items()}} + + features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))} + norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD} + + delta_step = DeltaActionsProcessorStep(enabled=True) + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + absolute_step = AbsoluteActionsProcessorStep(enabled=True, delta_step=delta_step) + + original_actions = action_chunks[0].unsqueeze(0) + state = states[0].unsqueeze(0) + + batch = {ACTION: original_actions, OBS_STATE: state} + transition = batch_to_transition(batch) + + # Forward: delta → normalize + t1 = delta_step(transition) + t2 = normalizer(t1) + + normalized_action = t2[TransitionKey.ACTION] + assert normalized_action.abs().mean() < 10, ( + f"Normalized actions should be in reasonable range, got mean abs {normalized_action.abs().mean():.2f}" + ) + + # Reverse: unnormalize → absolute + t3 = unnormalizer(t2) + t4 = absolute_step(t3) + + recovered_actions = t4[TransitionKey.ACTION] + torch.testing.assert_close(recovered_actions, original_actions, atol=1e-4, rtol=1e-4) + + +def test_normalized_delta_values_are_reasonable(dataset, action_dim): + """With correct chunk stats, normalized delta actions should be in a reasonable range.""" + action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE) + mask = [True] * action_dim + + delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask) + mean = torch.tensor(delta_stats["mean"]).float() + std = torch.tensor(delta_stats["std"]).float() + + all_normalized = [] + for actions, state in zip(action_chunks, states): + delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0) + normalized = (delta - mean) / (std + 1e-6) + all_normalized.append(normalized) + + all_normalized = torch.cat(all_normalized, dim=0) + + pct_in_range = (all_normalized.abs() < 5).float().mean() + assert pct_in_range > 0.9, ( + f"Only {pct_in_range*100:.1f}% of normalized values in [-5, 5], expected >90%" + ) + + assert all_normalized.mean().abs() < 1.0, ( + f"Mean of normalized deltas is {all_normalized.mean():.2f}, expected near 0" + ) + + +def test_processor_step_roundtrip(dataset, action_dim): """DeltaActionsProcessorStep applies delta; to_absolute_actions recovers original.""" - batch = _collate(dataset, range(4)) + hf = dataset.hf_dataset + batch = { + ACTION: torch.stack([hf[i]["action"] for i in range(4)]), + OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]), + } original_actions = batch[ACTION].clone() transition = batch_to_transition(batch) step = DeltaActionsProcessorStep(enabled=True) delta_transition = step(transition) - - delta_actions = delta_transition[TransitionKey.ACTION] - assert not torch.allclose(delta_actions, original_actions) + assert not torch.allclose(delta_transition[TransitionKey.ACTION], original_actions) state = transition[TransitionKey.OBSERVATION][OBS_STATE] - mask = [True] * original_actions.shape[-1] - recovered = to_absolute_actions(delta_actions, state, mask) + mask = [True] * action_dim + recovered = to_absolute_actions(delta_transition[TransitionKey.ACTION], state, mask) torch.testing.assert_close(recovered, original_actions) -def test_processor_step_disabled_is_noop(dataset): +def test_processor_step_disabled_is_noop(dataset, action_dim): """enabled=False should be a no-op.""" - batch = _collate(dataset, range(2)) + hf = dataset.hf_dataset + batch = { + ACTION: torch.stack([hf[i]["action"] for i in range(2)]), + OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(2)]), + } original = batch[ACTION].clone() transition = batch_to_transition(batch) - result = DeltaActionsProcessorStep(enabled=False)(transition) torch.testing.assert_close(result[TransitionKey.ACTION], original) + + +# --- Training batch shape validation --- + +def test_delta_with_action_chunks(dataset, action_dim): + """Verify delta works correctly with (B, chunk_size, action_dim) shaped actions.""" + action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE) + + # Simulate a training batch: actions=(B, chunk_size, action_dim), state=(B, state_dim) + batch_actions = action_chunks[:4] # (4, chunk_size, action_dim) + batch_states = states[:4] # (4, state_dim) + + mask = [True] * action_dim + delta = to_delta_actions(batch_actions, batch_states, mask) + + # First action in each chunk should be close to zero (action[t] - state[t] ≈ small) + first_deltas = delta[:, 0, :] # (B, action_dim) + assert first_deltas.abs().mean() < delta.abs().mean(), ( + f"First action in chunk should have smaller delta than average. " + f"First: {first_deltas.abs().mean():.4f}, Average: {delta.abs().mean():.4f}" + ) + + # Later actions should have larger deltas + last_deltas = delta[:, -1, :] # (B, action_dim) + assert last_deltas.abs().mean() >= first_deltas.abs().mean(), ( + f"Last action in chunk should have >= delta than first. " + f"Last: {last_deltas.abs().mean():.4f}, First: {first_deltas.abs().mean():.4f}" + ) + + # Roundtrip + recovered = to_absolute_actions(delta, batch_states, mask) + torch.testing.assert_close(recovered, batch_actions) + + +def test_delta_stats_match_actual_data_distribution(dataset, action_dim): + """Verify computed stats match the actual delta distribution.""" + action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE) + mask = [True] * action_dim + + # Compute stats like the training script does + delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask) + + # Also compute directly + all_deltas = [] + for actions, state in zip(action_chunks, states): + delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0) + all_deltas.append(delta) + all_deltas_tensor = torch.cat(all_deltas, dim=0) + + # Compare mean + actual_mean = all_deltas_tensor.mean(dim=0).numpy() + np.testing.assert_allclose(delta_stats["mean"], actual_mean, atol=0.01) + + # Compare std + actual_std = all_deltas_tensor.std(dim=0).numpy() + np.testing.assert_allclose(delta_stats["std"], actual_std, atol=0.1) + + # Verify q01 < mean < q99 + assert (delta_stats["q01"] < delta_stats["mean"]).all(), "q01 should be < mean" + assert (delta_stats["mean"] < delta_stats["q99"]).all(), "mean should be < q99" + + +def test_quantile_normalization_roundtrip(dataset, action_dim): + """Full roundtrip with QUANTILES normalization (what OpenPI uses for pi05).""" + action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE) + mask = [True] * action_dim + + delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask) + stats = {ACTION: {k: v for k, v in delta_stats.items()}} + + features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))} + norm_map = {FeatureType.ACTION: NormalizationMode.QUANTILES} + + delta_step = DeltaActionsProcessorStep(enabled=True) + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + absolute_step = AbsoluteActionsProcessorStep(enabled=True, delta_step=delta_step) + + original_actions = action_chunks[0].unsqueeze(0) + state = states[0].unsqueeze(0) + + batch = {ACTION: original_actions, OBS_STATE: state} + transition = batch_to_transition(batch) + + # Forward: delta → quantile normalize + t1 = delta_step(transition) + t2 = normalizer(t1) + + normalized = t2[TransitionKey.ACTION] + # Most values should be in [-1, 1] with quantile normalization + pct_in_range = (normalized.abs() < 2).float().mean() + assert pct_in_range > 0.5, ( + f"Only {pct_in_range*100:.1f}% in [-2, 2] after quantile norm, expected >50%" + ) + + # Reverse: unnormalize → absolute + t3 = unnormalizer(t2) + t4 = absolute_step(t3) + + recovered = t4[TransitionKey.ACTION] + torch.testing.assert_close(recovered, original_actions, atol=1e-3, rtol=1e-3) + + +def test_state_not_modified_by_delta(dataset, action_dim): + """State should never be modified by the delta processor.""" + hf = dataset.hf_dataset + batch = { + ACTION: torch.stack([hf[i]["action"] for i in range(4)]), + OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]), + } + original_state = batch[OBS_STATE].clone() + transition = batch_to_transition(batch) + + step = DeltaActionsProcessorStep(enabled=True) + result = step(transition) + + result_state = result[TransitionKey.OBSERVATION][OBS_STATE] + torch.testing.assert_close(result_state, original_state)