mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
refactor(processors): enhance transform_features method across multiple processors (#1849)
* refactor(processors): enhance transform_features method across multiple processors - Updated the transform_features method in various processors to utilize a copy of the features dictionary, ensuring immutability of the original features. - Added handling for new feature keys and removed obsolete ones in the MapTensorToDeltaActionDict, JointVelocityProcessor, and others. - Improved readability and maintainability by following consistent patterns in feature transformation. * refactor(processors): standardize action and observation keys in delta_action_processor and joint_observations_processor - Updated action and observation keys to use constants for improved readability and maintainability. - Refactored the transform_features method in multiple processors to ensure consistent handling of feature keys. - Enhanced error handling by raising exceptions for missing required components in action and observation processing. - Removed obsolete code and improved overall structure for better clarity. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(processors): remove unused import in joint_observations_processor * refactor(processors): simplify transform_features method in delta_action_processor * refactor(processors): streamline transform_features method in ImageCropResizeProcessor * refactor(processors): improve error handling and streamline transform_features method in phone_processor - Raised a ValueError for missing position and rotation in action to enhance error handling. * refactor(processors): enhance error handling in JointVelocityProcessor - Added a ValueError raise for missing current joint positions in the observation method to improve error handling and ensure the integrity of the transform_features method. * refactor(processors): simplify transform_features method in robot kinematic processors * refactor(processors): standardize action keys in phone_processor * fix(processor): RKP feature obs -> act --------- Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
@@ -19,6 +19,7 @@ from dataclasses import dataclass
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION
|
||||
|
||||
from .pipeline import ActionProcessor, ProcessorStepRegistry
|
||||
|
||||
@@ -30,23 +31,28 @@ class MapTensorToDeltaActionDict(ActionProcessor):
|
||||
Map a tensor to a delta action dictionary.
|
||||
"""
|
||||
|
||||
use_gripper: bool = True
|
||||
|
||||
def action(self, action: Tensor) -> dict:
|
||||
if isinstance(action, dict):
|
||||
return action
|
||||
if action.dim() > 1:
|
||||
action = action.squeeze(0)
|
||||
|
||||
# TODO (maractingi): add rotation
|
||||
delta_action = {
|
||||
"action.delta_x": action[0],
|
||||
"action.delta_y": action[1],
|
||||
"action.delta_z": action[2],
|
||||
f"{ACTION}.delta_x": action[0],
|
||||
f"{ACTION}.delta_y": action[1],
|
||||
f"{ACTION}.delta_z": action[2],
|
||||
}
|
||||
if action.shape[0] > 3:
|
||||
delta_action["action.gripper"] = action[3]
|
||||
if self.use_gripper:
|
||||
delta_action[f"{ACTION}.gripper"] = action[3]
|
||||
return delta_action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features[f"{ACTION}.delta_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.delta_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.delta_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
if self.use_gripper:
|
||||
features[f"{ACTION}.gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
return features
|
||||
|
||||
|
||||
@@ -86,10 +92,10 @@ class MapDeltaActionToRobotAction(ActionProcessor):
|
||||
def action(self, action: dict) -> dict:
|
||||
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
|
||||
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
|
||||
delta_x = action.pop("action.delta_x", 0.0)
|
||||
delta_y = action.pop("action.delta_y", 0.0)
|
||||
delta_z = action.pop("action.delta_z", 0.0)
|
||||
gripper = action.pop("action.gripper", 1.0) # Default to "stay" (1.0)
|
||||
delta_x = action.pop(f"{ACTION}.delta_x", 0.0)
|
||||
delta_y = action.pop(f"{ACTION}.delta_y", 0.0)
|
||||
delta_z = action.pop(f"{ACTION}.delta_z", 0.0)
|
||||
gripper = action.pop(f"{ACTION}.gripper", 1.0) # Default to "stay" (1.0)
|
||||
|
||||
# Determine if the teleoperator is actively providing input
|
||||
# Consider enabled if any significant movement delta is detected
|
||||
@@ -109,31 +115,31 @@ class MapDeltaActionToRobotAction(ActionProcessor):
|
||||
|
||||
# Update action with robot target format
|
||||
action = {
|
||||
"action.enabled": enabled,
|
||||
"action.target_x": scaled_delta_x,
|
||||
"action.target_y": scaled_delta_y,
|
||||
"action.target_z": scaled_delta_z,
|
||||
"action.target_wx": target_wx,
|
||||
"action.target_wy": target_wy,
|
||||
"action.target_wz": target_wz,
|
||||
"action.gripper": float(gripper),
|
||||
f"{ACTION}.enabled": enabled,
|
||||
f"{ACTION}.target_x": scaled_delta_x,
|
||||
f"{ACTION}.target_y": scaled_delta_y,
|
||||
f"{ACTION}.target_z": scaled_delta_z,
|
||||
f"{ACTION}.target_wx": target_wx,
|
||||
f"{ACTION}.target_wy": target_wy,
|
||||
f"{ACTION}.target_wz": target_wz,
|
||||
f"{ACTION}.gripper": float(gripper),
|
||||
}
|
||||
|
||||
return action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transform features to match output format."""
|
||||
# Update features to reflect the new action format
|
||||
features.update(
|
||||
{
|
||||
"action.enabled": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_x": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_y": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_z": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_wx": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_wy": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_wz": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.gripper": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
}
|
||||
)
|
||||
features.pop(f"{ACTION}.delta_x", None)
|
||||
features.pop(f"{ACTION}.delta_y", None)
|
||||
features.pop(f"{ACTION}.delta_z", None)
|
||||
features.pop(f"{ACTION}.gripper", None)
|
||||
|
||||
features[f"{ACTION}.enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
return features
|
||||
|
||||
Reference in New Issue
Block a user