lerobot-clone/src/lerobot/processor/relative_action_processor.py
Pepijn 15934d8d08 feat(policies): add relative action support for pi0, pi0.5, and pi0_fast (#2970)
* Add option for pi family models to train with relative actions (relative to state)

* formatting

* add recomputation of stats and option to compute delta stats

* normalize after delta conversion

* only recompute state for stats

* calculate chunk-based stats

* sample 100k

* load from parquet

* sample 1m

* stats per chunk

* fix

* use quantiles

* stats for entire dataset

* fix

* max 1m frames

* compute before dist

* fix multi gpu processor bug

* Fix RTC with delta actions and OpenArms motor_type wiring

* feat: align pi0_fast delta actions with pi0/pi05 and add RTC integration tests

- Add delta_exclude_joints and action_feature_names to PI0FastConfig
- Move to_absolute_actions from modeling to processor pipeline for pi0_fast
- Add delta action detection and logging to eval_with_real_robot.py
- Add delta actions documentation to pi0 and pi05 READMEs
- Fix ruff lint issues in test_delta_actions.py
- Add test_rtc_delta_actions.py (24 tests) covering:
  - ActionQueue with delta vs absolute actions
  - RTC denoise step with delta leftovers
  - Full pipeline roundtrip (delta → RTC → absolute)
  - State rebasing approximation bounds
  - Non-delta policy compatibility
  - Multi-chunk consistency

* chore: clean up test comments, add OpenPI attribution, remove debug logging

- Replace decorative comment separators in test files with plain section headers
- Add attribution comments for 1e-6 epsilon in normalize_processor.py (from OpenPI)
- Remove debug logging blocks from lerobot_train.py

* refactor: extract compute_delta_action_stats into compute_stats.py

Move the ~70-line inline delta action stats block from lerobot_train.py
into a dedicated function in compute_stats.py, where all other stats
computation already lives. The training script now calls it in 6 lines.

* refactor: remove unused get_processed_left_over from ActionQueue

This method was never called outside of tests. Leftover actions for RTC
guidance are always retrieved via get_left_over() (delta/original space).

* revert: remove logging-only changes from eval_with_real_robot.py

The delta actions detection helper and log message added no functional
value — the script already handles delta policies correctly via the
processor pipeline.

* refactor: use ACTION/OBS_STATE constants instead of hardcoded strings

Replace hardcoded "action" and "observation.state" with ACTION and
OBS_STATE from utils.constants in compute_stats.py, dataset_tools.py,
and lerobot_train.py.

* style: remove stray blank lines in training loop

* refactor: move delta action stats to preprocessing step, remove on-the-fly computation

- Remove on-the-fly compute_delta_action_stats from lerobot_train.py
- Rewrite recompute_stats to delegate action stats to compute_delta_action_stats
  (chunk-based sampling matching what the model sees during training)
- Add chunk_size parameter to recompute_stats for delta action computation
- Add delta actions documentation to pi0.mdx and pi05.mdx

* feat: add recompute_stats CLI operation to lerobot-edit-dataset

* fix(tests): relax quantile normalization test tolerance for 1e-6 epsilon

* chore: remove agents_memory/pr_details.md from repo

* refactor: rename delta actions to relative actions throughout

What OpenPI calls "DeltaActions" is actually UMI's "relative trajectory"
representation: each action in the chunk is an offset from the current
state, not from the previous action. This avoids error accumulation.
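
A small numerical sketch of the distinction, using made-up values and NumPy. The variable names and numbers are illustrative only and do not come from the codebase:

```python
import numpy as np

# Hypothetical 1-D example: current joint position and a 4-step absolute action chunk.
state = np.array([0.50])
absolute = np.array([[0.52], [0.55], [0.59], [0.64]])

# Relative (UMI-style): every step is an offset from the current state.
relative = absolute - state

# Delta: every step is an offset from the previous target (the first step from the state).
previous = np.concatenate([state[None, :], absolute[:-1]], axis=0)
delta = absolute - previous

# Inject the same prediction error into the first step of each representation.
noise = np.zeros_like(absolute)
noise[0] = 0.01

# Relative: each step is recovered independently, so only step 0 is off.
err_relative = (relative + noise + state) - absolute

# Delta: recovery is a cumulative sum, so the step-0 error carries into every later step.
err_delta = (state + np.cumsum(delta + noise, axis=0)) - absolute
```

With these values, `err_relative` is nonzero only at the first step, while `err_delta` carries the same 0.01 error through all four steps.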

Renamed across all source, tests, docs, and CLI:
- DeltaActionsProcessorStep → RelativeActionsProcessorStep
- to_delta_actions → to_relative_actions
- use_delta_actions → use_relative_actions
- delta_exclude_joints → relative_exclude_joints
- compute_delta_action_stats → compute_relative_action_stats
- delta_action_processor.py → relative_action_processor.py
- test_delta_actions.py → test_relative_actions.py

Kept as-is: AbsoluteActionsProcessorStep (converts TO absolute),
registry ID "delta_actions_processor" (backward compat), and unrelated
delta references (IK pipeline, Robosuite, RA-BC metrics, gym envs).

* docs: add Action Representations guide

Dedicated page explaining absolute, relative, and delta actions with
numerical examples, joint vs EE space, and how to use kinematics
pipelines and the relative action processor. References UMI paper
(Chi et al., 2024) for the terminology.

* docs: remove redundant OpenPI naming note from action representations

* docs: remove opinionated OpenPI reference from delta actions section

* docs: replace ASCII diagram with UMI paper figure

* docs: remove OpenPI reference from action representations

* docs: use HF-hosted image instead of local asset

* docs: clarify figure attribution

* revert: restore original normalization epsilon behavior

The 1e-6 unconditional epsilon change perturbed all normalized values,
breaking backward compatibility tests. The original approach (1e-8 eps
for MEAN_STD, conditional torch.where for QUANTILES) already handles
division by zero correctly without affecting non-degenerate cases.
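
A rough sketch of the two guards described above. The function names, the [-1, 1] quantile mapping, and the fallback value used in the quantile branch are assumptions made for illustration; the actual normalize_processor.py may differ in detail:

```python
import torch


def normalize_mean_std(x, mean, std, eps=1e-8):
    # Hypothetical helper: a tiny epsilon is always added, which is negligible
    # unless std is (close to) zero.
    return (x - mean) / (std + eps)


def normalize_quantiles(x, q_low, q_high, eps=1e-8):
    # Hypothetical helper: the denominator is replaced only where the range is
    # exactly zero, so non-degenerate dimensions are left untouched.
    denom = q_high - q_low
    denom = torch.where(denom == 0, torch.full_like(denom, eps), denom)
    return 2.0 * (x - q_low) / denom - 1.0


x = torch.tensor([5.0, 5.0])
q_low = torch.tensor([0.0, 5.0])
q_high = torch.tensor([10.0, 5.0])  # second dim is degenerate (zero range)
out = normalize_quantiles(x, q_low, q_high)
```

The first dimension maps to exactly 0.0 with no epsilon perturbation, and the degenerate second dimension stays finite instead of dividing by zero.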

* fix: restore delta_action_processor.py used by phone/RL teleop

The rename commit incorrectly deleted delta_action_processor.py and
duplicated its classes into relative_action_processor.py. Restore the
original file and import from it instead.

* fix(processor): address PR #2970 review comments

- Remove shebang from relative_action_processor.py (library module, not script)
- Add device alignment in to_relative_actions/to_absolute_actions so _last_state
  on CPU doesn't cause cross-device errors when actions are on CUDA
- Rename delta_step → relative_step in AbsoluteActionsProcessorStep for naming
  consistency; update factory.py, all processor files, and tests
- Expand _reconnect_relative_absolute_steps docstring to explain why post-hoc
  rewiring is needed after deserialization
- Fix off-by-one in compute_stats.py: sample_upper_bound = total_frames - chunk_size + 1
  so last valid start index is included and total_frames == chunk_size is not rejected
- Remove redundant NOTE comment in processor_pi05.py (duplicated two lines below)
- Fix pi0_fast processor ordering: move relative_step before NormalizerProcessorStep
  so normalizer sees delta actions (matching pi0/pi05); flip postprocessor to
  unnormalize → absolute accordingly. Relative stats are now required for all pi models
- Revert use_relative_joint_actions_aloha → use_delta_joint_actions_aloha in
  configuration_smolvla.py (preserve existing public API)
- Update action_representations.mdx: add missing joint to 6-DOF example, fix
  'based on a figure', clarify pi family ordering, add RTC compatibility section
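
The off-by-one fix in the list above can be illustrated with a minimal sketch. The helper name `valid_chunk_starts` is hypothetical, and the sketch assumes `sample_upper_bound` is used as an exclusive upper limit on the chunk start index:

```python
def valid_chunk_starts(total_frames: int, chunk_size: int) -> range:
    """Start indices i such that frames [i, i + chunk_size) all exist."""
    # Exclusive upper bound: the last valid start is total_frames - chunk_size.
    sample_upper_bound = total_frames - chunk_size + 1
    return range(max(sample_upper_bound, 0))


# 10 frames, chunks of 4: starts 0..6, and the chunk at 6 covers frames 6, 7, 8, 9.
starts = valid_chunk_starts(total_frames=10, chunk_size=4)

# Exactly one chunk fits when total_frames == chunk_size.
single = valid_chunk_starts(total_frames=4, chunk_size=4)

# No chunk fits when the dataset is shorter than one chunk.
none = valid_chunk_starts(total_frames=3, chunk_size=4)
```

Without the `+ 1`, the start index 6 would be dropped in the first case and the second case would yield no valid chunk at all.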

* update rtc link

* feat: compute relative action stats over full dataset with optional parallelism

Remove the 100k sample cap from compute_relative_action_stats and process
all valid chunks. Vectorize with numpy (pre-load actions/states, fancy
indexing + broadcasting) for a large speedup over the per-index HF dataset
loop. Add num_workers param for thread-based parallelism (numpy releases
the GIL). Update docs to show --push_to_hub for recompute_stats.
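
A minimal sketch of the fancy-indexing and broadcasting idea, not the actual compute_relative_action_stats implementation. The function name `relative_chunk_stats` is hypothetical, and the sketch assumes the state and action share the same dimensionality and ignores episode boundaries and the exclusion mask:

```python
import numpy as np


def relative_chunk_stats(actions: np.ndarray, states: np.ndarray, chunk_size: int) -> dict:
    """Mean/std of (action chunk - state at chunk start) over all valid chunks."""
    total_frames, action_dim = actions.shape
    starts = np.arange(total_frames - chunk_size + 1)        # (N,)
    idx = starts[:, None] + np.arange(chunk_size)[None, :]   # (N, T) frame indices
    chunks = actions[idx]                                    # (N, T, D) via fancy indexing
    relative = chunks - states[starts][:, None, :]           # state broadcast over T
    flat = relative.reshape(-1, action_dim)
    return {"mean": flat.mean(axis=0), "std": flat.std(axis=0)}


# Toy data: the action ramps by 1 per frame and the state equals the action at each frame,
# so every chunk of size 3 has relative offsets [0, 1, 2].
actions = np.arange(6, dtype=np.float64).reshape(6, 1)
stats = relative_chunk_stats(actions, states=actions.copy(), chunk_size=3)
```

Building the `(N, T)` index matrix once replaces a Python loop over start indices with a single array gather.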

* style: apply ruff formatting to compute_stats.py

* testing on real robot

* style: fix ruff format and remove redundant .keys() calls
2026-04-01 12:59:12 +02:00


# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any

import torch
from torch import Tensor

from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import OBS_STATE

from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
from .pipeline import ProcessorStep, ProcessorStepRegistry

# Re-export for backward compatibility
__all__ = [
    "MapDeltaActionToRobotActionStep",
    "MapTensorToDeltaActionDictStep",
    "RelativeActionsProcessorStep",
    "AbsoluteActionsProcessorStep",
    "to_relative_actions",
    "to_absolute_actions",
]


def to_relative_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
    """Convert absolute actions to relative: relative = action - state (for masked dims).

    Args:
        actions: (B, T, action_dim) or (B, action_dim).
        state: (B, state_dim). Broadcast across time dimension.
        mask: Which dims to convert. Can be shorter than action_dim.
    """
    mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
    dims = mask_t.shape[0]
    # Align state to the same device/dtype as actions. _last_state is cached before
    # DeviceProcessorStep moves the transition, so it can be on CPU while actions are on CUDA.
    if state.device != actions.device or state.dtype != actions.dtype:
        state = state.to(device=actions.device, dtype=actions.dtype)
    state_offset = state[..., :dims] * mask_t
    if actions.ndim == 3:
        state_offset = state_offset.unsqueeze(-2)
    actions = actions.clone()
    actions[..., :dims] -= state_offset
    return actions


def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
    """Convert relative actions back to absolute: absolute = relative + state (for masked dims).

    Args:
        actions: (B, T, action_dim) or (B, action_dim).
        state: (B, state_dim). Broadcast across time dimension.
        mask: Which dims to convert. Can be shorter than action_dim.
    """
    mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
    dims = mask_t.shape[0]
    # Align state to the same device/dtype as actions. _last_state is cached before
    # DeviceProcessorStep moves the transition, so it can be on CPU while actions are on CUDA.
    if state.device != actions.device or state.dtype != actions.dtype:
        state = state.to(device=actions.device, dtype=actions.dtype)
    state_offset = state[..., :dims] * mask_t
    if actions.ndim == 3:
        state_offset = state_offset.unsqueeze(-2)
    actions = actions.clone()
    actions[..., :dims] += state_offset
    return actions


@ProcessorStepRegistry.register("delta_actions_processor")
@dataclass
class RelativeActionsProcessorStep(ProcessorStep):
    """Converts absolute actions to relative actions (action -= state) for masked dimensions.

    Mirrors OpenPI's DeltaActions transform. Applied during preprocessing so the model
    trains on relative offsets instead of absolute positions.

    Caches the last seen state so a paired AbsoluteActionsProcessorStep can reverse
    the conversion during postprocessing.

    Attributes:
        enabled: Whether to apply the relative conversion.
        exclude_joints: Joint names to keep absolute (not converted to relative).
        action_names: Action dimension names from dataset metadata, used to build
            the mask from exclude_joints. If None, all dims are converted.
    """

    enabled: bool = False
    exclude_joints: list[str] = field(default_factory=list)
    action_names: list[str] | None = None
    _last_state: torch.Tensor | None = field(default=None, init=False, repr=False)

    def _build_mask(self, action_dim: int) -> list[bool]:
        if not self.exclude_joints or self.action_names is None:
            return [True] * action_dim

        exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
        if not exclude_tokens:
            return [True] * action_dim

        mask = []
        for name in self.action_names[:action_dim]:
            action_name = str(name).lower()
            is_excluded = any(token == action_name or token in action_name for token in exclude_tokens)
            mask.append(not is_excluded)

        if len(mask) < action_dim:
            mask.extend([True] * (action_dim - len(mask)))
        return mask

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        observation = transition.get(TransitionKey.OBSERVATION, {})
        state = observation.get(OBS_STATE) if observation else None

        # Always cache state for the paired AbsoluteActionsProcessorStep
        if state is not None:
            self._last_state = state

        if not self.enabled:
            return transition

        new_transition = transition.copy()
        action = new_transition.get(TransitionKey.ACTION)
        if action is None or state is None:
            return new_transition

        mask = self._build_mask(action.shape[-1])
        new_transition[TransitionKey.ACTION] = to_relative_actions(action, state, mask)
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {
            "enabled": self.enabled,
            "exclude_joints": self.exclude_joints,
            "action_names": self.action_names,
        }

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        return features


@ProcessorStepRegistry.register("absolute_actions_processor")
@dataclass
class AbsoluteActionsProcessorStep(ProcessorStep):
    """Converts relative actions back to absolute actions (action += state) for masked dimensions.

    Mirrors OpenPI's AbsoluteActions transform. Applied during postprocessing so
    predicted relative offsets are converted back to absolute positions for execution.

    Reads the cached state and the dimension mask from its paired
    RelativeActionsProcessorStep, so excluded joints are left untouched.

    Attributes:
        enabled: Whether to apply the absolute conversion.
        relative_step: Reference to the paired RelativeActionsProcessorStep that caches state.
    """

    enabled: bool = False
    relative_step: RelativeActionsProcessorStep | None = field(default=None, repr=False)

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        if not self.enabled:
            return transition

        if self.relative_step is None:
            raise RuntimeError(
                "AbsoluteActionsProcessorStep requires a paired RelativeActionsProcessorStep "
                "but relative_step is None. Ensure relative_step is set when constructing the postprocessor."
            )

        if self.relative_step._last_state is None:
            raise RuntimeError(
                "AbsoluteActionsProcessorStep requires state from RelativeActionsProcessorStep "
                "but no state has been cached. Ensure the preprocessor runs before the postprocessor."
            )

        new_transition = transition.copy()
        action = new_transition.get(TransitionKey.ACTION)
        if action is None:
            return new_transition

        # Reuse the paired step's mask so the same dimensions are converted back.
        mask = self.relative_step._build_mask(action.shape[-1])
        new_transition[TransitionKey.ACTION] = to_absolute_actions(
            action, self.relative_step._last_state, mask
        )
        return new_transition

    def get_config(self) -> dict[str, Any]:
        return {"enabled": self.enabled}

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        return features
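
A standalone sketch of the masked round trip performed by to_relative_actions and to_absolute_actions. To stay runnable without lerobot installed, it uses a hypothetical helper `shift_by_state` that mirrors the logic of both functions; the 3-DOF values and the choice to keep the last dimension absolute are made up for illustration:

```python
import torch


def shift_by_state(actions, state, mask, sign):
    """Hypothetical mirror of to_relative_actions (sign=-1) / to_absolute_actions (sign=+1)."""
    mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
    dims = mask_t.shape[0]
    offset = state.to(device=actions.device, dtype=actions.dtype)[..., :dims] * mask_t
    if actions.ndim == 3:
        offset = offset.unsqueeze(-2)  # broadcast the state over the time dimension
    out = actions.clone()
    out[..., :dims] += sign * offset
    return out


# Made-up 3-DOF example; the last dim (e.g. a gripper) is excluded and stays absolute.
state = torch.tensor([[0.1, 0.2, 0.9]])        # (B=1, state_dim=3)
actions = torch.tensor([[[0.2, 0.2, 1.0],
                         [0.3, 0.1, 0.0]]])    # (B=1, T=2, action_dim=3)
mask = [True, True, False]

relative = shift_by_state(actions, state, mask, sign=-1)   # preprocessing direction
restored = shift_by_state(relative, state, mask, sign=+1)  # postprocessing direction
```

The first two dimensions of `relative` are offsets from `state`, the third is unchanged, and `restored` matches the original `actions`, provided the same state and mask are used in both directions.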