mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
refactor(processors): enhance transform_features method across multiple processors (#1849)
* refactor(processors): enhance transform_features method across multiple processors - Updated the transform_features method in various processors to utilize a copy of the features dictionary, ensuring immutability of the original features. - Added handling for new feature keys and removed obsolete ones in the MapTensorToDeltaActionDict, JointVelocityProcessor, and others. - Improved readability and maintainability by following consistent patterns in feature transformation. * refactor(processors): standardize action and observation keys in delta_action_processor and joint_observations_processor - Updated action and observation keys to use constants for improved readability and maintainability. - Refactored the transform_features method in multiple processors to ensure consistent handling of feature keys. - Enhanced error handling by raising exceptions for missing required components in action and observation processing. - Removed obsolete code and improved overall structure for better clarity. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(processors): remove unused import in joint_observations_processor * refactor(processors): simplify transform_features method in delta_action_processor * refactor(processors): streamline transform_features method in ImageCropResizeProcessor * refactor(processors): improve error handling and streamline transform_features method in phone_processor - Raised a ValueError for missing position and rotation in action to enhance error handling. * refactor(processors): enhance error handling in JointVelocityProcessor - Added a ValueError raise for missing current joint positions in the observation method to improve error handling and ensure the integrity of the transform_features method. * refactor(processors): simplify transform_features method in robot kinematic processors * refactor(processors): standardize action keys in phone_processor * fix(processor): RKP feature obs -> act --------- Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
@@ -153,4 +153,5 @@ class ToBatchProcessor(ProcessorStep):
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# NOTE: We ignore the batch dimension when transforming features
|
||||
return features
|
||||
|
||||
@@ -19,6 +19,7 @@ from dataclasses import dataclass
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION
|
||||
|
||||
from .pipeline import ActionProcessor, ProcessorStepRegistry
|
||||
|
||||
@@ -30,23 +31,28 @@ class MapTensorToDeltaActionDict(ActionProcessor):
|
||||
Map a tensor to a delta action dictionary.
|
||||
"""
|
||||
|
||||
use_gripper: bool = True
|
||||
|
||||
def action(self, action: Tensor) -> dict:
|
||||
if isinstance(action, dict):
|
||||
return action
|
||||
if action.dim() > 1:
|
||||
action = action.squeeze(0)
|
||||
|
||||
# TODO (maractingi): add rotation
|
||||
delta_action = {
|
||||
"action.delta_x": action[0],
|
||||
"action.delta_y": action[1],
|
||||
"action.delta_z": action[2],
|
||||
f"{ACTION}.delta_x": action[0],
|
||||
f"{ACTION}.delta_y": action[1],
|
||||
f"{ACTION}.delta_z": action[2],
|
||||
}
|
||||
if action.shape[0] > 3:
|
||||
delta_action["action.gripper"] = action[3]
|
||||
if self.use_gripper:
|
||||
delta_action[f"{ACTION}.gripper"] = action[3]
|
||||
return delta_action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features[f"{ACTION}.delta_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.delta_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.delta_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
if self.use_gripper:
|
||||
features[f"{ACTION}.gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
return features
|
||||
|
||||
|
||||
@@ -86,10 +92,10 @@ class MapDeltaActionToRobotAction(ActionProcessor):
|
||||
def action(self, action: dict) -> dict:
|
||||
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
|
||||
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
|
||||
delta_x = action.pop("action.delta_x", 0.0)
|
||||
delta_y = action.pop("action.delta_y", 0.0)
|
||||
delta_z = action.pop("action.delta_z", 0.0)
|
||||
gripper = action.pop("action.gripper", 1.0) # Default to "stay" (1.0)
|
||||
delta_x = action.pop(f"{ACTION}.delta_x", 0.0)
|
||||
delta_y = action.pop(f"{ACTION}.delta_y", 0.0)
|
||||
delta_z = action.pop(f"{ACTION}.delta_z", 0.0)
|
||||
gripper = action.pop(f"{ACTION}.gripper", 1.0) # Default to "stay" (1.0)
|
||||
|
||||
# Determine if the teleoperator is actively providing input
|
||||
# Consider enabled if any significant movement delta is detected
|
||||
@@ -109,31 +115,31 @@ class MapDeltaActionToRobotAction(ActionProcessor):
|
||||
|
||||
# Update action with robot target format
|
||||
action = {
|
||||
"action.enabled": enabled,
|
||||
"action.target_x": scaled_delta_x,
|
||||
"action.target_y": scaled_delta_y,
|
||||
"action.target_z": scaled_delta_z,
|
||||
"action.target_wx": target_wx,
|
||||
"action.target_wy": target_wy,
|
||||
"action.target_wz": target_wz,
|
||||
"action.gripper": float(gripper),
|
||||
f"{ACTION}.enabled": enabled,
|
||||
f"{ACTION}.target_x": scaled_delta_x,
|
||||
f"{ACTION}.target_y": scaled_delta_y,
|
||||
f"{ACTION}.target_z": scaled_delta_z,
|
||||
f"{ACTION}.target_wx": target_wx,
|
||||
f"{ACTION}.target_wy": target_wy,
|
||||
f"{ACTION}.target_wz": target_wz,
|
||||
f"{ACTION}.gripper": float(gripper),
|
||||
}
|
||||
|
||||
return action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transform features to match output format."""
|
||||
# Update features to reflect the new action format
|
||||
features.update(
|
||||
{
|
||||
"action.enabled": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_x": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_y": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_z": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_wx": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_wy": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.target_wz": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
"action.gripper": PolicyFeature(type=FeatureType.ACTION, shape=(1,)),
|
||||
}
|
||||
)
|
||||
features.pop(f"{ACTION}.delta_x", None)
|
||||
features.pop(f"{ACTION}.delta_y", None)
|
||||
features.pop(f"{ACTION}.delta_z", None)
|
||||
features.pop(f"{ACTION}.gripper", None)
|
||||
|
||||
features[f"{ACTION}.enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
return features
|
||||
|
||||
@@ -4,10 +4,13 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_STATE
|
||||
from lerobot.processor.pipeline import (
|
||||
ObservationProcessor,
|
||||
ProcessorStepRegistry,
|
||||
)
|
||||
from lerobot.robots import Robot
|
||||
|
||||
from .pipeline import ObservationProcessor, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("joint_velocity_processor")
|
||||
@@ -20,10 +23,10 @@ class JointVelocityProcessor(ObservationProcessor):
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
# Get current joint positions (assuming they're in observation.state)
|
||||
current_positions = observation.get("observation.state")
|
||||
current_positions = observation.get(OBS_STATE)
|
||||
if current_positions is None:
|
||||
# TODO(steven): if we get here, then the transform_features method will not hold
|
||||
return observation
|
||||
raise ValueError(f"{OBS_STATE} is not in observation")
|
||||
|
||||
# Initialize last joint positions if not already set
|
||||
if self.last_joint_positions is None:
|
||||
@@ -40,7 +43,7 @@ class JointVelocityProcessor(ObservationProcessor):
|
||||
|
||||
# Create new observation dict
|
||||
new_observation = dict(observation)
|
||||
new_observation["observation.state"] = extended_state
|
||||
new_observation[OBS_STATE] = extended_state
|
||||
|
||||
return new_observation
|
||||
|
||||
@@ -53,12 +56,12 @@ class JointVelocityProcessor(ObservationProcessor):
|
||||
self.last_joint_positions = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
if "observation.state" in features:
|
||||
original_feature = features["observation.state"]
|
||||
if OBS_STATE in features:
|
||||
original_feature = features[OBS_STATE]
|
||||
# Double the shape to account for positions + velocities
|
||||
new_shape = (original_feature.shape[0] * 2,) + original_feature.shape[1:]
|
||||
|
||||
features["observation.state"] = PolicyFeature(type=original_feature.type, shape=new_shape)
|
||||
features[OBS_STATE] = PolicyFeature(type=original_feature.type, shape=new_shape)
|
||||
return features
|
||||
|
||||
|
||||
@@ -72,14 +75,15 @@ class MotorCurrentProcessor(ObservationProcessor):
|
||||
def observation(self, observation: dict) -> dict:
|
||||
# Get current values from robot state
|
||||
if self.robot is None:
|
||||
return observation
|
||||
raise ValueError("Robot is not set")
|
||||
|
||||
present_current_dict = self.robot.bus.sync_read("Present_Current") # type: ignore[attr-defined]
|
||||
motor_currents = torch.tensor(
|
||||
[present_current_dict[name] for name in self.robot.bus.motors], # type: ignore[attr-defined]
|
||||
dtype=torch.float32,
|
||||
).unsqueeze(0)
|
||||
|
||||
current_state = observation.get("observation.state")
|
||||
current_state = observation.get(OBS_STATE)
|
||||
if current_state is None:
|
||||
return observation
|
||||
|
||||
@@ -87,15 +91,13 @@ class MotorCurrentProcessor(ObservationProcessor):
|
||||
|
||||
# Create new observation dict
|
||||
new_observation = dict(observation)
|
||||
new_observation["observation.state"] = extended_state
|
||||
new_observation[OBS_STATE] = extended_state
|
||||
|
||||
return new_observation
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
if "observation.state" in features and self.robot is not None:
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
|
||||
original_feature = features["observation.state"]
|
||||
if OBS_STATE in features and self.robot is not None:
|
||||
original_feature = features[OBS_STATE]
|
||||
# Add motor current dimensions to the original state shape
|
||||
num_motors = 0
|
||||
if hasattr(self.robot, "bus") and hasattr(self.robot.bus, "motors"): # type: ignore[attr-defined]
|
||||
@@ -103,5 +105,5 @@ class MotorCurrentProcessor(ObservationProcessor):
|
||||
|
||||
if num_motors > 0:
|
||||
new_shape = (original_feature.shape[0] + num_motors,) + original_feature.shape[1:]
|
||||
features["observation.state"] = PolicyFeature(type=original_feature.type, shape=new_shape)
|
||||
features[OBS_STATE] = PolicyFeature(type=original_feature.type, shape=new_shape)
|
||||
return features
|
||||
|
||||
@@ -148,12 +148,12 @@ class EEReferenceAndDelta(ActionProcessor):
|
||||
features.pop(f"{ACTION}.target_wy", None)
|
||||
features.pop(f"{ACTION}.target_wz", None)
|
||||
|
||||
features[f"{ACTION}.ee.x"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.ee.y"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.ee.z"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.ee.wx"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.ee.wy"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.ee.wz"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.ee.x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.ee.y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.ee.z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.ee.wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.ee.wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.ee.wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
return features
|
||||
|
||||
|
||||
@@ -189,7 +189,9 @@ class EEBoundsAndSafety(ActionProcessor):
|
||||
wz = act.get(f"{ACTION}.ee.wz", None)
|
||||
|
||||
if None in (x, y, z, wx, wy, wz):
|
||||
return act
|
||||
raise ValueError(
|
||||
"Missing required end-effector pose components: x, y, z, wx, wy, wz must all be present in action"
|
||||
)
|
||||
|
||||
pos = np.array([x, y, z], dtype=float)
|
||||
twist = np.array([wx, wy, wz], dtype=float)
|
||||
@@ -221,6 +223,8 @@ class EEBoundsAndSafety(ActionProcessor):
|
||||
self._last_twist = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# check if features as f"{ACTION}.ee.{x,y,z,wx,wy,wz}"
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@@ -290,7 +294,9 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
new_act = dict(act)
|
||||
for i, name in enumerate(self.motor_names):
|
||||
if name == "gripper":
|
||||
new_act[f"{OBS_STATE}.gripper.pos"] = float(raw["gripper"])
|
||||
# TODO(pepijn): Investigate if this is correct
|
||||
# Do we want an observation key in the action field?
|
||||
new_act[f"{ACTION}.gripper.pos"] = float(raw["gripper"])
|
||||
else:
|
||||
new_act[f"{ACTION}.{name}.pos"] = float(q_target[i])
|
||||
new_transition[TransitionKey.ACTION] = new_act
|
||||
@@ -299,10 +305,9 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features[f"{OBS_STATE}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.gripper.pos"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for name in self.motor_names:
|
||||
features[f"{ACTION}.{name}.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.{name}.pos"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
|
||||
return features
|
||||
|
||||
@@ -340,13 +345,12 @@ class GripperVelocityToJoint(ProcessorStep):
|
||||
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
if f"{ACTION}.gripper" not in act:
|
||||
return new_transition
|
||||
raise ValueError(f"Required action key '{ACTION}.gripper' not found in transition")
|
||||
|
||||
if "gripper" not in self.motor_names:
|
||||
new_act = dict(act)
|
||||
new_act.pop(f"{ACTION}.gripper", None)
|
||||
new_transition[TransitionKey.ACTION] = new_act
|
||||
return new_transition
|
||||
raise ValueError(
|
||||
f"Required motor name 'gripper' not found in self.motor_names={self.motor_names}"
|
||||
)
|
||||
|
||||
if self.discrete_gripper:
|
||||
# Discrete gripper actions are in [0, 1, 2]
|
||||
@@ -377,7 +381,9 @@ class GripperVelocityToJoint(ProcessorStep):
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features.pop(f"{ACTION}.gripper", None)
|
||||
features[f"{ACTION}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.gripper.pos"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{OBS_STATE}.gripper.pos"] = PolicyFeature(type=FeatureType.STATE, shape=(1,))
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@@ -403,7 +409,7 @@ class ForwardKinematicsJointsToEE(ObservationProcessor):
|
||||
|
||||
def observation(self, obs: dict) -> dict:
|
||||
if not all(f"{OBS_STATE}.{n}.pos" in obs for n in self.motor_names):
|
||||
return obs
|
||||
raise ValueError(f"Missing required joint positions for motors: {self.motor_names}")
|
||||
|
||||
q = np.array([obs[f"{OBS_STATE}.{n}.pos"] for n in self.motor_names], dtype=float)
|
||||
t = self.kinematics.forward_kinematics(q)
|
||||
@@ -421,7 +427,7 @@ class ForwardKinematicsJointsToEE(ObservationProcessor):
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We specify the dataset features of this step that we want to be stored in the dataset
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz"]:
|
||||
features[f"{OBS_STATE}.ee.{k}"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{OBS_STATE}.ee.{k}"] = PolicyFeature(type=FeatureType.STATE, shape=(1,))
|
||||
return features
|
||||
|
||||
|
||||
|
||||
@@ -459,7 +459,9 @@ def make_processors(
|
||||
if cfg.processor.inverse_kinematics is not None and kinematics_solver is not None:
|
||||
# Add EE bounds and safety processor
|
||||
inverse_kinematics_steps = [
|
||||
MapTensorToDeltaActionDict(),
|
||||
MapTensorToDeltaActionDict(
|
||||
use_gripper=cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else False
|
||||
),
|
||||
MapDeltaActionToRobotAction(),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION
|
||||
from lerobot.processor import ActionProcessor, ProcessorStepRegistry
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneOS
|
||||
|
||||
@@ -48,13 +49,13 @@ class MapPhoneActionToRobotAction(ActionProcessor):
|
||||
|
||||
def action(self, act: dict) -> dict:
|
||||
# Pop them from the action
|
||||
enabled = bool(act.pop("action.phone.enabled", 0))
|
||||
pos = act.pop("action.phone.pos", None)
|
||||
rot = act.pop("action.phone.rot", None)
|
||||
inputs = act.pop("action.phone.raw_inputs", {})
|
||||
enabled = bool(act.pop(f"{ACTION}.phone.enabled", 0))
|
||||
pos = act.pop(f"{ACTION}.phone.pos", None)
|
||||
rot = act.pop(f"{ACTION}.phone.rot", None)
|
||||
inputs = act.pop(f"{ACTION}.phone.raw_inputs", {})
|
||||
|
||||
if pos is None or rot is None:
|
||||
return act
|
||||
raise ValueError("pos and rot must be present in action")
|
||||
|
||||
rotvec = rot.as_rotvec() # Absolute orientation as rotvec
|
||||
|
||||
@@ -69,28 +70,28 @@ class MapPhoneActionToRobotAction(ActionProcessor):
|
||||
) # Positive if a is pressed, negative if b is pressed, 0 if both or neither are pressed
|
||||
|
||||
# For some actions we need to invert the axis
|
||||
act["action.enabled"] = enabled
|
||||
act["action.target_x"] = -pos[1] if enabled else 0.0
|
||||
act["action.target_y"] = pos[0] if enabled else 0.0
|
||||
act["action.target_z"] = pos[2] if enabled else 0.0
|
||||
act["action.target_wx"] = rotvec[1] if enabled else 0.0
|
||||
act["action.target_wy"] = rotvec[0] if enabled else 0.0
|
||||
act["action.target_wz"] = -rotvec[2] if enabled else 0.0
|
||||
act["action.gripper"] = gripper # Still send gripper action when disabled
|
||||
act[f"{ACTION}.enabled"] = enabled
|
||||
act[f"{ACTION}.target_x"] = -pos[1] if enabled else 0.0
|
||||
act[f"{ACTION}.target_y"] = pos[0] if enabled else 0.0
|
||||
act[f"{ACTION}.target_z"] = pos[2] if enabled else 0.0
|
||||
act[f"{ACTION}.target_wx"] = rotvec[1] if enabled else 0.0
|
||||
act[f"{ACTION}.target_wy"] = rotvec[0] if enabled else 0.0
|
||||
act[f"{ACTION}.target_wz"] = -rotvec[2] if enabled else 0.0
|
||||
act[f"{ACTION}.gripper"] = gripper # Still send gripper action when disabled
|
||||
return act
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features.pop("action.phone.enabled", None)
|
||||
features.pop("action.phone.pos", None)
|
||||
features.pop("action.phone.rot", None)
|
||||
features.pop("action.phone.raw_inputs", None)
|
||||
features.pop(f"{ACTION}.phone.enabled", None)
|
||||
features.pop(f"{ACTION}.phone.pos", None)
|
||||
features.pop(f"{ACTION}.phone.rot", None)
|
||||
features.pop(f"{ACTION}.phone.raw_inputs", None)
|
||||
|
||||
features["action.enabled"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.target_x"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.target_y"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.target_z"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.target_wx"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.target_wy"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.target_wz"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features["action.gripper"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
features[f"{ACTION}.enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
return features
|
||||
|
||||
Reference in New Issue
Block a user