iterate on review:

This commit is contained in:
Jade Choghari
2025-11-18 16:12:09 +01:00
parent a068618faf
commit 0ed2f87fba

View File

@@ -243,22 +243,21 @@ class LiberoProcessorStep(ObservationProcessorStep):
robot_state = processed_obs.pop("observation.robot_state")
# Extract components
eef_pos = robot_state["eef"]["pos"] # (3,)
eef_quat = robot_state["eef"]["quat"] # (4,)
gripper_qpos = robot_state["gripper"]["qpos"] # (2,)
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
eef_quat = robot_state["eef"]["quat"] # (B, 4,)
gripper_qpos = robot_state["gripper"]["qpos"] # (B, 2,)
# Convert quaternion to axis-angle
eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
# Concatenate into a single state vector
state = np.concatenate((eef_pos, eef_axisangle, gripper_qpos), axis=-1)
state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
# Convert to tensor
state_tensor = torch.from_numpy(state).float()
if state_tensor.dim() == 1:
state_tensor = state_tensor.unsqueeze(0)
# ensure float32
state = state.float()
if state.dim() == 1:
state = state.unsqueeze(0)
processed_obs[OBS_STATE] = state_tensor
processed_obs[OBS_STATE] = state
return processed_obs
def transform_features(
@@ -267,7 +266,26 @@ class LiberoProcessorStep(ObservationProcessorStep):
"""
Transforms feature keys from the LIBERO format to the LeRobot standard.
"""
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features}
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
# copy over non-STATE features
for ft, feats in features.items():
if ft != PipelineFeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
state_feats = {}
# add our new flattened state
state_feats["observation.state"] = PolicyFeature(
key="observation.state",
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
dtype="float32",
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
)
new_features[PipelineFeatureType.STATE] = state_feats
return new_features
def observation(self, observation):