diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index 6ce34a683..2613d291b 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -243,22 +243,21 @@ class LiberoProcessorStep(ObservationProcessorStep): robot_state = processed_obs.pop("observation.robot_state") # Extract components - eef_pos = robot_state["eef"]["pos"] # (3,) - eef_quat = robot_state["eef"]["quat"] # (4,) - gripper_qpos = robot_state["gripper"]["qpos"] # (2,) + eef_pos = robot_state["eef"]["pos"] # (B, 3,) + eef_quat = robot_state["eef"]["quat"] # (B, 4,) + gripper_qpos = robot_state["gripper"]["qpos"] # (B, 2,) # Convert quaternion to axis-angle eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3) - # Concatenate into a single state vector - state = np.concatenate((eef_pos, eef_axisangle, gripper_qpos), axis=-1) + state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1) - # Convert to tensor - state_tensor = torch.from_numpy(state).float() - if state_tensor.dim() == 1: - state_tensor = state_tensor.unsqueeze(0) + # ensure float32 + state = state.float() + if state.dim() == 1: + state = state.unsqueeze(0) - processed_obs[OBS_STATE] = state_tensor + processed_obs[OBS_STATE] = state return processed_obs def transform_features( @@ -267,7 +266,26 @@ class LiberoProcessorStep(ObservationProcessorStep): """ Transforms feature keys from the LIBERO format to the LeRobot standard. """ - new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features} + new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {} + + # copy over non-STATE features + for ft, feats in features.items(): + if ft != PipelineFeatureType.STATE: + new_features[ft] = feats.copy() + + # rebuild STATE features + state_feats = {} + + # add our new flattened state + state_feats["observation.state"] = PolicyFeature( + key="observation.state", + shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)] + dtype="float32", + description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."), + ) + + new_features[PipelineFeatureType.STATE] = state_feats + return new_features def observation(self, observation):