mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 21:01:26 +00:00
iterate on review:
This commit is contained in:
@@ -243,22 +243,21 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
robot_state = processed_obs.pop("observation.robot_state")
|
||||
|
||||
# Extract components
|
||||
eef_pos = robot_state["eef"]["pos"] # (3,)
|
||||
eef_quat = robot_state["eef"]["quat"] # (4,)
|
||||
gripper_qpos = robot_state["gripper"]["qpos"] # (2,)
|
||||
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
|
||||
eef_quat = robot_state["eef"]["quat"] # (B, 4,)
|
||||
gripper_qpos = robot_state["gripper"]["qpos"] # (B, 2,)
|
||||
|
||||
# Convert quaternion to axis-angle
|
||||
eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
|
||||
|
||||
# Concatenate into a single state vector
|
||||
state = np.concatenate((eef_pos, eef_axisangle, gripper_qpos), axis=-1)
|
||||
state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
|
||||
|
||||
# Convert to tensor
|
||||
state_tensor = torch.from_numpy(state).float()
|
||||
if state_tensor.dim() == 1:
|
||||
state_tensor = state_tensor.unsqueeze(0)
|
||||
# ensure float32
|
||||
state = state.float()
|
||||
if state.dim() == 1:
|
||||
state = state.unsqueeze(0)
|
||||
|
||||
processed_obs[OBS_STATE] = state_tensor
|
||||
processed_obs[OBS_STATE] = state
|
||||
return processed_obs
|
||||
|
||||
def transform_features(
|
||||
@@ -267,7 +266,26 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Transforms feature keys from the LIBERO format to the LeRobot standard.
|
||||
"""
|
||||
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features}
|
||||
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
|
||||
|
||||
# copy over non-STATE features
|
||||
for ft, feats in features.items():
|
||||
if ft != PipelineFeatureType.STATE:
|
||||
new_features[ft] = feats.copy()
|
||||
|
||||
# rebuild STATE features
|
||||
state_feats = {}
|
||||
|
||||
# add our new flattened state
|
||||
state_feats["observation.state"] = PolicyFeature(
|
||||
key="observation.state",
|
||||
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
|
||||
dtype="float32",
|
||||
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
|
||||
)
|
||||
|
||||
new_features[PipelineFeatureType.STATE] = state_feats
|
||||
|
||||
return new_features
|
||||
|
||||
def observation(self, observation):
|
||||
|
||||
Reference in New Issue
Block a user