refactor(processors): enhance transform_features method across multiple processors (#1849)

* refactor(processors): enhance transform_features method across multiple processors

- Updated the transform_features method in various processors to operate on a copy of the features dictionary, so that the original features passed in by the caller are not mutated.
- Added handling for new feature keys and removed obsolete ones in the MapTensorToDeltaActionDict, JointVelocityProcessor, and others.
- Improved readability and maintainability by following consistent patterns in feature transformation.

* refactor(processors): standardize action and observation keys in delta_action_processor and joint_observations_processor

- Updated action and observation keys to use constants for improved readability and maintainability.
- Refactored the transform_features method in multiple processors to ensure consistent handling of feature keys.
- Enhanced error handling by raising exceptions for missing required components in action and observation processing.
- Removed obsolete code and improved overall structure for better clarity.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor(processors): remove unused import in joint_observations_processor

* refactor(processors): simplify transform_features method in delta_action_processor

* refactor(processors): streamline transform_features method in ImageCropResizeProcessor

* refactor(processors): improve error handling and streamline transform_features method in phone_processor

- Raised a ValueError for missing position and rotation in action to enhance error handling.

* refactor(processors): enhance error handling in JointVelocityProcessor

- Raised a ValueError in the observation method when the current joint positions are missing, to improve error handling and ensure the integrity of the transform_features method.

* refactor(processors): simplify transform_features method in robot kinematic processors

* refactor(processors): standardize action keys in phone_processor

* fix(processor): RKP (robot kinematic processor) feature key obs -> act

---------

Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
Adil Zouitine
2025-09-03 16:54:41 +02:00
committed by GitHub
parent 2fcc358e98
commit 4ebe482a7e
6 changed files with 111 additions and 93 deletions

View File

@@ -148,12 +148,12 @@ class EEReferenceAndDelta(ActionProcessor):
features.pop(f"{ACTION}.target_wy", None)
features.pop(f"{ACTION}.target_wz", None)
features[f"{ACTION}.ee.x"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.ee.y"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.ee.z"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.ee.wx"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.ee.wy"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.ee.wz"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.ee.x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.ee.y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.ee.z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.ee.wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.ee.wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{ACTION}.ee.wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
return features
@@ -189,7 +189,9 @@ class EEBoundsAndSafety(ActionProcessor):
wz = act.get(f"{ACTION}.ee.wz", None)
if None in (x, y, z, wx, wy, wz):
return act
raise ValueError(
"Missing required end-effector pose components: x, y, z, wx, wy, wz must all be present in action"
)
pos = np.array([x, y, z], dtype=float)
twist = np.array([wx, wy, wz], dtype=float)
@@ -221,6 +223,8 @@ class EEBoundsAndSafety(ActionProcessor):
self._last_twist = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# check if features as f"{ACTION}.ee.{x,y,z,wx,wy,wz}"
return features
@@ -290,7 +294,9 @@ class InverseKinematicsEEToJoints(ProcessorStep):
new_act = dict(act)
for i, name in enumerate(self.motor_names):
if name == "gripper":
new_act[f"{OBS_STATE}.gripper.pos"] = float(raw["gripper"])
# TODO(pepijn): Investigate if this is correct
# Do we want an observation key in the action field?
new_act[f"{ACTION}.gripper.pos"] = float(raw["gripper"])
else:
new_act[f"{ACTION}.{name}.pos"] = float(q_target[i])
new_transition[TransitionKey.ACTION] = new_act
@@ -299,10 +305,9 @@ class InverseKinematicsEEToJoints(ProcessorStep):
return new_transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features[f"{OBS_STATE}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.gripper.pos"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
for name in self.motor_names:
features[f"{ACTION}.{name}.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.{name}.pos"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
return features
@@ -340,13 +345,12 @@ class GripperVelocityToJoint(ProcessorStep):
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if f"{ACTION}.gripper" not in act:
return new_transition
raise ValueError(f"Required action key '{ACTION}.gripper' not found in transition")
if "gripper" not in self.motor_names:
new_act = dict(act)
new_act.pop(f"{ACTION}.gripper", None)
new_transition[TransitionKey.ACTION] = new_act
return new_transition
raise ValueError(
f"Required motor name 'gripper' not found in self.motor_names={self.motor_names}"
)
if self.discrete_gripper:
# Discrete gripper actions are in [0, 1, 2]
@@ -377,7 +381,9 @@ class GripperVelocityToJoint(ProcessorStep):
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features.pop(f"{ACTION}.gripper", None)
features[f"{ACTION}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{ACTION}.gripper.pos"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
features[f"{OBS_STATE}.gripper.pos"] = PolicyFeature(type=FeatureType.STATE, shape=(1,))
return features
@@ -403,7 +409,7 @@ class ForwardKinematicsJointsToEE(ObservationProcessor):
def observation(self, obs: dict) -> dict:
if not all(f"{OBS_STATE}.{n}.pos" in obs for n in self.motor_names):
return obs
raise ValueError(f"Missing required joint positions for motors: {self.motor_names}")
q = np.array([obs[f"{OBS_STATE}.{n}.pos"] for n in self.motor_names], dtype=float)
t = self.kinematics.forward_kinematics(q)
@@ -421,7 +427,7 @@ class ForwardKinematicsJointsToEE(ObservationProcessor):
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We specify the dataset features of this step that we want to be stored in the dataset
for k in ["x", "y", "z", "wx", "wy", "wz"]:
features[f"{OBS_STATE}.ee.{k}"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features[f"{OBS_STATE}.ee.{k}"] = PolicyFeature(type=FeatureType.STATE, shape=(1,))
return features