refactor(constants, processor): standardize action and observation keys across multiple files (#1808)

- Added new constants for truncated and done states in constants.py.
- Replaced hard-coded action and observation key strings in pipeline_features.py, converters.py, hil_processor.py, tokenizer_processor.py, and robot_kinematic_processor.py with named constants (e.g. ACTION, GRIPPER_KEY, TELEOP_ACTION_KEY) for improved readability and maintainability.
This commit is contained in:
Adil Zouitine
2025-08-31 22:53:13 +02:00
committed by GitHub
parent 574a708950
commit 08fb310eaa
6 changed files with 123 additions and 114 deletions

View File

@@ -8,6 +8,7 @@ import torch
import torchvision.transforms.functional as F # noqa: N812
from lerobot.configs.types import PolicyFeature
from lerobot.constants import ACTION
from lerobot.processor.pipeline import (
ComplementaryDataProcessor,
EnvTransition,
@@ -22,6 +23,8 @@ from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
GRIPPER_KEY = "gripper"
DISCRETE_PENALTY_KEY = "discrete_penalty"
TELEOP_ACTION_KEY = "teleop_action"
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
@@ -33,7 +36,7 @@ class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor):
def complementary_data(self, complementary_data: dict) -> dict:
new_complementary_data = dict(complementary_data)
new_complementary_data["teleop_action"] = self.teleop_device.get_action()
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
return new_complementary_data
@@ -141,7 +144,7 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor):
if current_gripper_pos is None:
return complementary_data
gripper_action = action[f"action.{GRIPPER_KEY}.pos"]
gripper_action = action[f"{ACTION}.{GRIPPER_KEY}.pos"]
gripper_action_normalized = gripper_action / self.max_gripper_pos
# Normalize gripper state and action
@@ -156,7 +159,7 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor):
# Create new complementary data with penalty info
new_complementary_data = dict(complementary_data)
new_complementary_data["discrete_penalty"] = gripper_penalty
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
return new_complementary_data
@@ -187,7 +190,7 @@ class InterventionActionProcessor(ProcessorStep):
# Get intervention signals from complementary data
info = transition.get(TransitionKey.INFO, {})
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
teleop_action = complementary_data.get("teleop_action", {})
teleop_action = complementary_data.get(TELEOP_ACTION_KEY, {})
is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False)
terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False)
success = info.get(TeleopEvents.SUCCESS, False)
@@ -200,12 +203,12 @@ class InterventionActionProcessor(ProcessorStep):
if isinstance(teleop_action, dict):
# Convert teleop_action dict to tensor format
action_list = [
teleop_action.get("action.delta_x", 0.0),
teleop_action.get("action.delta_y", 0.0),
teleop_action.get("action.delta_z", 0.0),
teleop_action.get(f"{ACTION}.delta_x", 0.0),
teleop_action.get(f"{ACTION}.delta_y", 0.0),
teleop_action.get(f"{ACTION}.delta_z", 0.0),
]
if self.use_gripper:
action_list.append(teleop_action.get("gripper", 1.0))
action_list.append(teleop_action.get(GRIPPER_KEY, 1.0))
elif isinstance(teleop_action, np.ndarray):
action_list = teleop_action.tolist()
else:
@@ -229,7 +232,7 @@ class InterventionActionProcessor(ProcessorStep):
# Update complementary data with teleop action
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
complementary_data["teleop_action"] = new_transition.get(TransitionKey.ACTION)
complementary_data[TELEOP_ACTION_KEY] = new_transition.get(TransitionKey.ACTION)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return new_transition