mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
refactor(constants, processor): standardize action and observation keys across multiple files (#1808)
- Added new constants for truncated and done states in constants.py. - Updated references to action and observation keys in pipeline_features.py, converters.py, hil_processor.py, tokenizer_processor.py, and robot_kinematic_processor.py to use the new constants for improved readability and maintainability.
This commit is contained in:
@@ -8,6 +8,7 @@ import torch
|
||||
import torchvision.transforms.functional as F # noqa: N812
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import ACTION
|
||||
from lerobot.processor.pipeline import (
|
||||
ComplementaryDataProcessor,
|
||||
EnvTransition,
|
||||
@@ -22,6 +23,8 @@ from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
GRIPPER_KEY = "gripper"
|
||||
DISCRETE_PENALTY_KEY = "discrete_penalty"
|
||||
TELEOP_ACTION_KEY = "teleop_action"
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
|
||||
@@ -33,7 +36,7 @@ class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor):
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data["teleop_action"] = self.teleop_device.get_action()
|
||||
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
|
||||
return new_complementary_data
|
||||
|
||||
|
||||
@@ -141,7 +144,7 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor):
|
||||
if current_gripper_pos is None:
|
||||
return complementary_data
|
||||
|
||||
gripper_action = action[f"action.{GRIPPER_KEY}.pos"]
|
||||
gripper_action = action[f"{ACTION}.{GRIPPER_KEY}.pos"]
|
||||
gripper_action_normalized = gripper_action / self.max_gripper_pos
|
||||
|
||||
# Normalize gripper state and action
|
||||
@@ -156,7 +159,7 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor):
|
||||
|
||||
# Create new complementary data with penalty info
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data["discrete_penalty"] = gripper_penalty
|
||||
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
|
||||
|
||||
return new_complementary_data
|
||||
|
||||
@@ -187,7 +190,7 @@ class InterventionActionProcessor(ProcessorStep):
|
||||
# Get intervention signals from complementary data
|
||||
info = transition.get(TransitionKey.INFO, {})
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
teleop_action = complementary_data.get("teleop_action", {})
|
||||
teleop_action = complementary_data.get(TELEOP_ACTION_KEY, {})
|
||||
is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False)
|
||||
terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False)
|
||||
success = info.get(TeleopEvents.SUCCESS, False)
|
||||
@@ -200,12 +203,12 @@ class InterventionActionProcessor(ProcessorStep):
|
||||
if isinstance(teleop_action, dict):
|
||||
# Convert teleop_action dict to tensor format
|
||||
action_list = [
|
||||
teleop_action.get("action.delta_x", 0.0),
|
||||
teleop_action.get("action.delta_y", 0.0),
|
||||
teleop_action.get("action.delta_z", 0.0),
|
||||
teleop_action.get(f"{ACTION}.delta_x", 0.0),
|
||||
teleop_action.get(f"{ACTION}.delta_y", 0.0),
|
||||
teleop_action.get(f"{ACTION}.delta_z", 0.0),
|
||||
]
|
||||
if self.use_gripper:
|
||||
action_list.append(teleop_action.get("gripper", 1.0))
|
||||
action_list.append(teleop_action.get(GRIPPER_KEY, 1.0))
|
||||
elif isinstance(teleop_action, np.ndarray):
|
||||
action_list = teleop_action.tolist()
|
||||
else:
|
||||
@@ -229,7 +232,7 @@ class InterventionActionProcessor(ProcessorStep):
|
||||
|
||||
# Update complementary data with teleop action
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
complementary_data["teleop_action"] = new_transition.get(TransitionKey.ACTION)
|
||||
complementary_data[TELEOP_ACTION_KEY] = new_transition.get(TransitionKey.ACTION)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return new_transition
|
||||
|
||||
Reference in New Issue
Block a user