mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
Compare commits
131 Commits
fix/images
...
ci/convert
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
332ca4ccc5 | ||
|
|
fc43246942 | ||
|
|
793ad86fc9 | ||
|
|
a6dbb65917 | ||
|
|
6c7169c4af | ||
|
|
f125d5e3bf | ||
|
|
75dcfd4886 | ||
|
|
ff3cbaa872 | ||
|
|
ce793cde64 | ||
|
|
029c4a9a76 | ||
|
|
d893bf1e30 | ||
|
|
8c796b39f5 | ||
|
|
4ebe482a7e | ||
|
|
2fcc358e98 | ||
|
|
b052843f08 | ||
|
|
ebb464c255 | ||
|
|
2914ae2a96 | ||
|
|
645c87e3a9 | ||
|
|
2c802ac134 | ||
|
|
15ffc01fb3 | ||
|
|
a837685bf8 | ||
|
|
d32b76cc66 | ||
|
|
08fb310eaa | ||
|
|
574a708950 | ||
|
|
ce665160ae | ||
|
|
35c5d43255 | ||
|
|
95c1e32aa5 | ||
|
|
e4db65a127 | ||
|
|
0053defa2e | ||
|
|
fd5d8b3d5f | ||
|
|
5bf82f8229 | ||
|
|
5ca3920611 | ||
|
|
8bde9d0ab7 | ||
|
|
abcbc16126 | ||
|
|
e4fd30a8d4 | ||
|
|
5f759b1637 | ||
|
|
6a75b4761a | ||
|
|
e5ade5565d | ||
|
|
0524551f52 | ||
|
|
862bc7ef85 | ||
|
|
d38792d6e5 | ||
|
|
db3cf0158c | ||
|
|
0535f2a59a | ||
|
|
2805ae347c | ||
|
|
28ef6fcd14 | ||
|
|
7fc7ec75bb | ||
|
|
87890cbf38 | ||
|
|
5326ffe77e | ||
|
|
a1734cf575 | ||
|
|
82f300e880 | ||
|
|
3e7c9d7afc | ||
|
|
e9cb779eab | ||
|
|
8ff95be04c | ||
|
|
f02ce69df0 | ||
|
|
1feb7b5d88 | ||
|
|
fbe9009db2 | ||
|
|
c0013b130b | ||
|
|
c4763f61a1 | ||
|
|
b95c219d96 | ||
|
|
9b1138171e | ||
|
|
023b8f3466 | ||
|
|
1cad87ebd2 | ||
|
|
99de7567e6 | ||
|
|
21baa8fa02 | ||
|
|
8b4a5368b3 | ||
|
|
f5c6b03b61 | ||
|
|
e7be2fd113 | ||
|
|
b632490b4b | ||
|
|
9a9c7208d2 | ||
|
|
427b97d198 | ||
|
|
2c2bb1e8bf | ||
|
|
4b24f94225 | ||
|
|
670a278cbc | ||
|
|
fc74001202 | ||
|
|
f14ac5d486 | ||
|
|
7bd0d62ce5 | ||
|
|
7eccefe235 | ||
|
|
b72274066e | ||
|
|
20f2910b63 | ||
|
|
fd4ae3466b | ||
|
|
7beb040e8e | ||
|
|
05bd18f453 | ||
|
|
8077456c00 | ||
|
|
5595887fd0 | ||
|
|
41959389b6 | ||
|
|
2c4e888c7f | ||
|
|
5ced72e6b8 | ||
|
|
907023f9f7 | ||
|
|
4ba23ea029 | ||
|
|
409ac0baca | ||
|
|
699363f9fc | ||
|
|
ae7a54de57 | ||
|
|
fb9139b882 | ||
|
|
9fe3a3fb17 | ||
|
|
26cb9a24c3 | ||
|
|
77106697c3 | ||
|
|
75bc44c166 | ||
|
|
f2b79656eb | ||
|
|
14c2ece004 | ||
|
|
35612c61e1 | ||
|
|
f7bb3e2d90 | ||
|
|
1e0d667a22 | ||
|
|
33969a0337 | ||
|
|
fa26290e8c | ||
|
|
e9f7f5127b | ||
|
|
097842c70f | ||
|
|
3b8a3a32a0 | ||
|
|
1c56779dd9 | ||
|
|
83a4338f8b | ||
|
|
730c7b2f35 | ||
|
|
116059a43e | ||
|
|
b08149a113 | ||
|
|
c227107f60 | ||
|
|
01dc289f3d | ||
|
|
6830ca7645 | ||
|
|
ed42c71fc3 | ||
|
|
e0139065bd | ||
|
|
e509f255af | ||
|
|
e2fcd140b0 | ||
|
|
2a7a0e6129 | ||
|
|
9f33791b19 | ||
|
|
453e0a995f | ||
|
|
8ebf79c494 | ||
|
|
8774aec304 | ||
|
|
ac742c9f0d | ||
|
|
cd13f1ecfd | ||
|
|
9aa632968f | ||
|
|
62caaf07b0 | ||
|
|
3355f04ca6 | ||
|
|
769f531603 | ||
|
|
f6c7287ae7 |
@@ -4,7 +4,13 @@ In this tutorial you will go through the full Human-in-the-Loop Sample-Efficient
|
||||
|
||||
HIL-SERL is a sample-efficient reinforcement learning algorithm that combines human demonstrations with online learning and human interventions. The approach starts from a small set of human demonstrations, uses them to train a reward classifier, and then employs an actor-learner architecture where humans can intervene during policy execution to guide exploration and correct unsafe behaviors. In this tutorial, you'll use a gamepad to provide interventions and control the robot during the learning process.
|
||||
|
||||
It combines three key ingredients: 1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point. 2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour. 3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe.
|
||||
It combines three key ingredients:
|
||||
|
||||
1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point.
|
||||
|
||||
2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour.
|
||||
|
||||
3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe.
|
||||
|
||||
Together these elements let HIL-SERL reach near-perfect task success and faster cycle times than imitation-only baselines.
|
||||
|
||||
@@ -56,30 +62,243 @@ pip install -e ".[hilserl]"
|
||||
|
||||
### Understanding Configuration
|
||||
|
||||
The training process begins with proper configuration for the HILSerl environment. The configuration class of interest is `HILSerlRobotEnvConfig` in `lerobot/envs/configs.py`. Which is defined as:
|
||||
The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/scripts/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
class GymManipulatorConfig:
|
||||
env: HILSerlRobotEnvConfig # Environment configuration (nested)
|
||||
dataset: DatasetConfig # Dataset recording/replay configuration (nested)
|
||||
mode: str | None = None # "record", "replay", or None (for training)
|
||||
device: str = "cpu" # Compute device
|
||||
|
||||
class HILSerlRobotEnvConfig(EnvConfig):
|
||||
robot: RobotConfig | None = None # Main robot agent (defined in `lerobot/robots`)
|
||||
teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm, (defined in `lerobot/teleoperators`)
|
||||
wrapper: EnvTransformConfig | None = None # Environment wrapper settings; check `lerobot/scripts/server/gym_manipulator.py`
|
||||
fps: int = 10 # Control frequency
|
||||
teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm
|
||||
processor: HILSerlProcessorConfig # Processing pipeline configuration (nested)
|
||||
name: str = "real_robot" # Environment name
|
||||
mode: str = None # "record", "replay", or None (for training)
|
||||
repo_id: str | None = None # LeRobot dataset repository ID
|
||||
dataset_root: str | None = None # Local dataset root (optional)
|
||||
task: str = "" # Task identifier
|
||||
num_episodes: int = 10 # Number of episodes for recording
|
||||
episode: int = 0 # episode index for replay
|
||||
device: str = "cuda" # Compute device
|
||||
push_to_hub: bool = True # Whether to push the recorded datasets to Hub
|
||||
pretrained_policy_name_or_path: str | None = None # For policy loading
|
||||
reward_classifier_pretrained_path: str | None = None # For reward model
|
||||
number_of_steps_after_success: int = 0 # For reward classifier, collect more positive examples after a success to train a classifier
|
||||
task: str | None = None # Task identifier
|
||||
fps: int = 10 # Control frequency
|
||||
|
||||
# Nested processor configuration
|
||||
class HILSerlProcessorConfig:
|
||||
control_mode: str = "gamepad" # Control mode
|
||||
observation: ObservationConfig | None = None # Observation processing settings
|
||||
image_preprocessing: ImagePreprocessingConfig | None = None # Image crop/resize settings
|
||||
gripper: GripperConfig | None = None # Gripper control and penalty settings
|
||||
reset: ResetConfig | None = None # Environment reset and timing settings
|
||||
inverse_kinematics: InverseKinematicsConfig | None = None # IK processing settings
|
||||
reward_classifier: RewardClassifierConfig | None = None # Reward classifier settings
|
||||
max_gripper_pos: float | None = 100.0 # Maximum gripper position
|
||||
|
||||
# Sub-configuration classes
|
||||
class ObservationConfig:
|
||||
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
|
||||
add_current_to_observation: bool = False # Add motor currents to state
|
||||
add_ee_pose_to_observation: bool = False # Add end-effector pose to state
|
||||
display_cameras: bool = False # Display camera feeds during execution
|
||||
|
||||
class ImagePreprocessingConfig:
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None # Image cropping parameters
|
||||
resize_size: tuple[int, int] | None = None # Target image size
|
||||
|
||||
class GripperConfig:
|
||||
use_gripper: bool = True # Enable gripper control
|
||||
gripper_penalty: float = 0.0 # Penalty for inappropriate gripper usage
|
||||
gripper_penalty_in_reward: bool = False # Include gripper penalty in reward
|
||||
|
||||
class ResetConfig:
|
||||
fixed_reset_joint_positions: Any | None = None # Joint positions for reset
|
||||
reset_time_s: float = 5.0 # Time to wait during reset
|
||||
control_time_s: float = 20.0 # Maximum episode duration
|
||||
terminate_on_success: bool = True # Whether to terminate episodes on success detection
|
||||
|
||||
class InverseKinematicsConfig:
|
||||
urdf_path: str | None = None # Path to robot URDF file
|
||||
target_frame_name: str | None = None # End-effector frame name
|
||||
end_effector_bounds: dict[str, list[float]] | None = None # EE workspace bounds
|
||||
end_effector_step_sizes: dict[str, float] | None = None # EE step sizes per axis
|
||||
|
||||
class RewardClassifierConfig:
|
||||
pretrained_path: str | None = None # Path to pretrained reward classifier
|
||||
success_threshold: float = 0.5 # Success detection threshold
|
||||
success_reward: float = 1.0 # Reward value for successful episodes
|
||||
|
||||
# Dataset configuration
|
||||
class DatasetConfig:
|
||||
repo_id: str # LeRobot dataset repository ID
|
||||
task: str # Task identifier
|
||||
root: str | None = None # Local dataset root directory
|
||||
num_episodes_to_record: int = 5 # Number of episodes for recording
|
||||
replay_episode: int | None = None # Episode index for replay
|
||||
push_to_hub: bool = False # Whether to push datasets to Hub
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Processor Pipeline Architecture
|
||||
|
||||
HIL-SERL uses a modular processor pipeline architecture that processes robot observations and actions through a series of composable steps. The pipeline is divided into two main components:
|
||||
|
||||
#### Environment Processor Pipeline
|
||||
|
||||
The environment processor (`env_processor`) handles incoming observations and environment state:
|
||||
|
||||
1. **VanillaObservationProcessorStep**: Converts raw robot observations into standardized format
|
||||
2. **JointVelocityProcessorStep** (optional): Adds joint velocity information to observations
|
||||
3. **MotorCurrentProcessorStep** (optional): Adds motor current readings to observations
|
||||
4. **ForwardKinematicsJointsToEE** (optional): Computes end-effector pose from joint positions
|
||||
5. **ImageCropResizeProcessorStep** (optional): Crops and resizes camera images
|
||||
6. **TimeLimitProcessorStep** (optional): Enforces episode time limits
|
||||
7. **GripperPenaltyProcessorStep** (optional): Applies penalties for inappropriate gripper usage
|
||||
8. **RewardClassifierProcessorStep** (optional): Automated reward detection using vision models
|
||||
9. **AddBatchDimensionProcessorStep**: Converts data to batch format for neural network processing
|
||||
10. **DeviceProcessorStep**: Moves data to the specified compute device (CPU/GPU)
|
||||
|
||||
#### Action Processor Pipeline
|
||||
|
||||
The action processor (`action_processor`) handles outgoing actions and human interventions:
|
||||
|
||||
1. **AddTeleopActionAsComplimentaryDataStep**: Captures teleoperator actions for logging
|
||||
2. **AddTeleopEventsAsInfoStep**: Records intervention events and episode control signals
|
||||
3. **AddRobotObservationAsComplimentaryData**: Stores raw robot state for processing
|
||||
4. **InterventionActionProcessorStep**: Handles human interventions and episode termination
|
||||
5. **Inverse Kinematics Pipeline** (when enabled):
|
||||
- **MapDeltaActionToRobotActionStep**: Converts delta actions to robot action format
|
||||
- **EEReferenceAndDelta**: Computes end-effector reference and delta movements
|
||||
- **EEBoundsAndSafety**: Enforces workspace safety bounds
|
||||
- **InverseKinematicsEEToJoints**: Converts end-effector actions to joint targets
|
||||
- **GripperVelocityToJoint**: Handles gripper control commands
|
||||
|
||||
#### Configuration Examples
|
||||
|
||||
**Basic Observation Processing**:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"observation": {
|
||||
"add_joint_velocity_to_observation": true,
|
||||
"add_current_to_observation": false,
|
||||
"display_cameras": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Image Processing**:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"image_preprocessing": {
|
||||
"crop_params_dict": {
|
||||
"observation.images.front": [180, 250, 120, 150],
|
||||
"observation.images.side": [180, 207, 180, 200]
|
||||
},
|
||||
"resize_size": [128, 128]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Inverse Kinematics Setup**:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"inverse_kinematics": {
|
||||
"urdf_path": "path/to/robot.urdf",
|
||||
"target_frame_name": "end_effector",
|
||||
"end_effector_bounds": {
|
||||
"min": [0.16, -0.08, 0.03],
|
||||
"max": [0.24, 0.2, 0.1]
|
||||
},
|
||||
"end_effector_step_sizes": {
|
||||
"x": 0.02,
|
||||
"y": 0.02,
|
||||
"z": 0.02
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Advanced Observation Processing
|
||||
|
||||
The HIL-SERL framework supports additional observation processing features that can improve policy learning:
|
||||
|
||||
#### Joint Velocity Processing
|
||||
|
||||
Enable joint velocity estimation to provide the policy with motion information:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"observation": {
|
||||
"add_joint_velocity_to_observation": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This processor:
|
||||
|
||||
- Estimates joint velocities using finite differences between consecutive joint position readings
|
||||
- Adds velocity information to the observation state vector
|
||||
- Useful for policies that need motion awareness for dynamic tasks
|
||||
|
||||
#### Motor Current Processing
|
||||
|
||||
Monitor motor currents to detect contact forces and load conditions:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"observation": {
|
||||
"add_current_to_observation": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This processor:
|
||||
|
||||
- Reads motor current values from the robot's control system
|
||||
- Adds current measurements to the observation state vector
|
||||
- Helps detect contact events, object weights, and mechanical resistance
|
||||
- Useful for contact-rich manipulation tasks
|
||||
|
||||
#### Combined Observation Processing
|
||||
|
||||
You can enable multiple observation processing features simultaneously:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"observation": {
|
||||
"add_joint_velocity_to_observation": true,
|
||||
"add_current_to_observation": true,
|
||||
"add_ee_pose_to_observation": false,
|
||||
"display_cameras": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Note**: Enabling additional observation features increases the state space dimensionality, which may require adjusting your policy network architecture and potentially collecting more training data.
|
||||
|
||||
### Finding Robot Workspace Bounds
|
||||
|
||||
Before collecting demonstrations, you need to determine the appropriate operational bounds for your robot.
|
||||
@@ -130,22 +349,56 @@ With the bounds defined, you can safely collect demonstrations for training. Tra
|
||||
|
||||
Create a configuration file for recording demonstrations (or edit an existing one like [env_config_so100.json](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json)):
|
||||
|
||||
1. Set `mode` to `"record"`
|
||||
2. Specify a unique `repo_id` for your dataset (e.g., "username/task_name")
|
||||
3. Set `num_episodes` to the number of demonstrations you want to collect
|
||||
4. Set `crop_params_dict` to `null` initially (we'll determine crops later)
|
||||
5. Configure `robot`, `cameras`, and other hardware settings
|
||||
1. Set `mode` to `"record"` at the root level
|
||||
2. Specify a unique `repo_id` for your dataset in the `dataset` section (e.g., "username/task_name")
|
||||
3. Set `num_episodes_to_record` in the `dataset` section to the number of demonstrations you want to collect
|
||||
4. Set `env.processor.image_preprocessing.crop_params_dict` to `{}` initially (we'll determine crops later)
|
||||
5. Configure `env.robot`, `env.teleop`, and other hardware settings in the `env` section
|
||||
|
||||
Example configuration section:
|
||||
|
||||
```json
|
||||
"mode": "record",
|
||||
"repo_id": "username/pick_lift_cube",
|
||||
"dataset_root": null,
|
||||
"task": "pick_and_lift",
|
||||
"num_episodes": 15,
|
||||
"episode": 0,
|
||||
"push_to_hub": true
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "real_robot",
|
||||
"fps": 10,
|
||||
"processor": {
|
||||
"control_mode": "gamepad",
|
||||
"observation": {
|
||||
"display_cameras": false
|
||||
},
|
||||
"image_preprocessing": {
|
||||
"crop_params_dict": {},
|
||||
"resize_size": [128, 128]
|
||||
},
|
||||
"gripper": {
|
||||
"use_gripper": true,
|
||||
"gripper_penalty": 0.0
|
||||
},
|
||||
"reset": {
|
||||
"reset_time_s": 5.0,
|
||||
"control_time_s": 20.0
|
||||
}
|
||||
},
|
||||
"robot": {
|
||||
// ... robot configuration ...
|
||||
},
|
||||
"teleop": {
|
||||
// ... teleoperator configuration ...
|
||||
}
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "username/pick_lift_cube",
|
||||
"root": null,
|
||||
"task": "pick_and_lift",
|
||||
"num_episodes_to_record": 15,
|
||||
"replay_episode": 0,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record",
|
||||
"device": "cpu"
|
||||
}
|
||||
```
|
||||
|
||||
### Using a Teleoperation Device
|
||||
@@ -191,10 +444,20 @@ The gamepad provides a very convenient way to control the robot and the episode
|
||||
To setup the gamepad, you need to set the `control_mode` to `"gamepad"` and define the `teleop` section in the configuration file.
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"teleop": {
|
||||
"type": "gamepad",
|
||||
"use_gripper": true
|
||||
"type": "gamepad",
|
||||
"use_gripper": true
|
||||
},
|
||||
"processor": {
|
||||
"control_mode": "gamepad",
|
||||
"gripper": {
|
||||
"use_gripper": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
@@ -216,11 +479,21 @@ The SO101 leader arm has reduced gears that allows it to move and track the foll
|
||||
To setup the SO101 leader, you need to set the `control_mode` to `"leader"` and define the `teleop` section in the configuration file.
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"teleop": {
|
||||
"type": "so101_leader",
|
||||
"port": "/dev/tty.usbmodem585A0077921", # check your port number
|
||||
"use_degrees": true
|
||||
"type": "so101_leader",
|
||||
"port": "/dev/tty.usbmodem585A0077921",
|
||||
"use_degrees": true
|
||||
},
|
||||
"processor": {
|
||||
"control_mode": "leader",
|
||||
"gripper": {
|
||||
"use_gripper": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
In order to annotate the success/failure of the episode, **you will need** to use a keyboard to press `s` for success, `esc` for failure.
|
||||
@@ -251,7 +524,7 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/e
|
||||
|
||||
During recording:
|
||||
|
||||
1. The robot will reset to the initial position defined in the configuration file `fixed_reset_joint_positions`
|
||||
1. The robot will reset to the initial position defined in the configuration file `env.processor.reset.fixed_reset_joint_positions`
|
||||
2. Complete the task successfully
|
||||
3. The episode ends with a reward of 1 when you press the "success" button
|
||||
4. If the time limit is reached, or the fail button is pressed, the episode ends with a reward of 0
|
||||
@@ -310,11 +583,19 @@ observation.images.front: [180, 250, 120, 150]
|
||||
Add these crop parameters to your training configuration:
|
||||
|
||||
```json
|
||||
"crop_params_dict": {
|
||||
"observation.images.side": [180, 207, 180, 200],
|
||||
"observation.images.front": [180, 250, 120, 150]
|
||||
},
|
||||
"resize_size": [128, 128]
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"image_preprocessing": {
|
||||
"crop_params_dict": {
|
||||
"observation.images.side": [180, 207, 180, 200],
|
||||
"observation.images.front": [180, 250, 120, 150]
|
||||
},
|
||||
"resize_size": [128, 128]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Recommended image resolution**
|
||||
@@ -343,26 +624,52 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/r
|
||||
|
||||
**Key Parameters for Data Collection**
|
||||
|
||||
- **mode**: set it to `"record"` to collect a dataset
|
||||
- **repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub
|
||||
- **num_episodes**: Number of episodes to record
|
||||
- **number_of_steps_after_success**: Number of additional frames to record after a success (reward=1) is detected
|
||||
- **fps**: Number of frames per second to record
|
||||
- **push_to_hub**: Whether to push the dataset to the hub
|
||||
- **mode**: set it to `"record"` to collect a dataset (at root level)
|
||||
- **dataset.repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub
|
||||
- **dataset.num_episodes_to_record**: Number of episodes to record
|
||||
- **env.processor.reset.terminate_on_success**: Whether to automatically terminate episodes when success is detected (default: `true`)
|
||||
- **env.fps**: Number of frames per second to record
|
||||
- **dataset.push_to_hub**: Whether to push the dataset to the hub
|
||||
|
||||
The `number_of_steps_after_success` parameter is crucial as it allows you to collect more positive examples. When a success is detected, the system will continue recording for the specified number of steps while maintaining the reward=1 label. Otherwise, there won't be enough states in the dataset labeled to 1 to train a good classifier.
|
||||
The `env.processor.reset.terminate_on_success` parameter allows you to control episode termination behavior. When set to `false`, episodes will continue even after success is detected, allowing you to collect more positive examples with the reward=1 label. This is crucial for training reward classifiers as it provides more success state examples in your dataset. When set to `true` (default), episodes terminate immediately upon success detection.
|
||||
|
||||
**Important**: For reward classifier training, set `terminate_on_success: false` to collect sufficient positive examples. For regular HIL-SERL training, keep it as `true` to enable automatic episode termination when the task is completed successfully.
|
||||
|
||||
Example configuration section for data collection:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "real_robot",
|
||||
"fps": 10,
|
||||
"processor": {
|
||||
"reset": {
|
||||
"reset_time_s": 5.0,
|
||||
"control_time_s": 20.0,
|
||||
"terminate_on_success": false
|
||||
},
|
||||
"gripper": {
|
||||
"use_gripper": true
|
||||
}
|
||||
},
|
||||
"robot": {
|
||||
// ... robot configuration ...
|
||||
},
|
||||
"teleop": {
|
||||
// ... teleoperator configuration ...
|
||||
}
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"dataset_root": "data/your_dataset",
|
||||
"task": "reward_classifier_task",
|
||||
"num_episodes_to_record": 20,
|
||||
"replay_episode": null,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record",
|
||||
"repo_id": "hf_username/dataset_name",
|
||||
"dataset_root": "data/your_dataset",
|
||||
"num_episodes": 20,
|
||||
"push_to_hub": true,
|
||||
"fps": 10,
|
||||
"number_of_steps_after_success": 15
|
||||
"device": "cpu"
|
||||
}
|
||||
```
|
||||
|
||||
@@ -421,9 +728,17 @@ To use your trained reward classifier, configure the `HILSerlRobotEnvConfig` to
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
env_config = HILSerlRobotEnvConfig(
|
||||
reward_classifier_pretrained_path="path_to_your_pretrained_trained_model",
|
||||
# Other environment parameters
|
||||
config = GymManipulatorConfig(
|
||||
env=HILSerlRobotEnvConfig(
|
||||
processor=HILSerlProcessorConfig(
|
||||
reward_classifier=RewardClassifierConfig(
|
||||
pretrained_path="path_to_your_pretrained_trained_model"
|
||||
)
|
||||
),
|
||||
# Other environment parameters
|
||||
),
|
||||
dataset=DatasetConfig(...),
|
||||
mode=None # For training
|
||||
)
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
@@ -432,7 +747,18 @@ or set the argument in the json config file.
|
||||
|
||||
```json
|
||||
{
|
||||
"reward_classifier_pretrained_path": "path_to_your_pretrained_model"
|
||||
"env": {
|
||||
"processor": {
|
||||
"reward_classifier": {
|
||||
"pretrained_path": "path_to_your_pretrained_model",
|
||||
"success_threshold": 0.7,
|
||||
"success_reward": 1.0
|
||||
},
|
||||
"reset": {
|
||||
"terminate_on_success": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
@@ -32,9 +32,12 @@ To use `gym_hil` with LeRobot, you need to create a configuration file. An examp
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "hil",
|
||||
"name": "franka_sim",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"fps": 10
|
||||
},
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
@@ -45,28 +48,40 @@ Available tasks:
|
||||
- `PandaPickCubeGamepad-v0`: With gamepad control
|
||||
- `PandaPickCubeKeyboard-v0`: With keyboard control
|
||||
|
||||
### Gym Wrappers Configuration
|
||||
### Processor Configuration
|
||||
|
||||
```json
|
||||
"wrapper": {
|
||||
"gripper_penalty": -0.02,
|
||||
"control_time_s": 15.0,
|
||||
"use_gripper": true,
|
||||
"fixed_reset_joint_positions": [0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785],
|
||||
"end_effector_step_sizes": {
|
||||
"x": 0.025,
|
||||
"y": 0.025,
|
||||
"z": 0.025
|
||||
},
|
||||
"control_mode": "gamepad"
|
||||
{
|
||||
"env": {
|
||||
"processor": {
|
||||
"control_mode": "gamepad",
|
||||
"gripper": {
|
||||
"use_gripper": true,
|
||||
"gripper_penalty": -0.02
|
||||
},
|
||||
"reset": {
|
||||
"control_time_s": 15.0,
|
||||
"fixed_reset_joint_positions": [
|
||||
0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785
|
||||
]
|
||||
},
|
||||
"inverse_kinematics": {
|
||||
"end_effector_step_sizes": {
|
||||
"x": 0.025,
|
||||
"y": 0.025,
|
||||
"z": 0.025
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Important parameters:
|
||||
|
||||
- `gripper_penalty`: Penalty for excessive gripper movement
|
||||
- `use_gripper`: Whether to enable gripper control
|
||||
- `end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector
|
||||
- `gripper.gripper_penalty`: Penalty for excessive gripper movement
|
||||
- `gripper.use_gripper`: Whether to enable gripper control
|
||||
- `inverse_kinematics.end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector
|
||||
- `control_mode`: Set to `"gamepad"` to use a gamepad controller
|
||||
|
||||
## Running with HIL RL of LeRobot
|
||||
@@ -75,39 +90,50 @@ Important parameters:
|
||||
|
||||
To run the environment, set mode to null:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Recording a Dataset
|
||||
|
||||
To collect a dataset, set the mode to `record` whilst defining the repo_id and number of episodes to record:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0"
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "username/sim_dataset",
|
||||
"root": null,
|
||||
"task": "pick_cube",
|
||||
"num_episodes_to_record": 10,
|
||||
"replay_episode": null,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record"
|
||||
}
|
||||
```
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
### Training a Policy
|
||||
|
||||
To train a policy, checkout the configuration example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_gym_hil_env.json) and run the actor and learner servers:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.actor --config_path path/to/train_gym_hil_env.json
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
In a different terminal, run the learner server:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
```bash
|
||||
python -m lerobot.scripts.rl.learner --config_path path/to/train_gym_hil_env.json
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
The simulation environment provides a safe and repeatable way to develop and test your Human-In-the-Loop reinforcement learning components before deploying to real robots.
|
||||
|
||||
|
||||
@@ -519,11 +519,14 @@ from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.policies.factory import make_processor
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
||||
|
||||
# Create the robot configuration
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
@@ -535,7 +538,7 @@ robot_config = SO100FollowerConfig(
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Initialize the policy
|
||||
policy = ACTPolicy.from_pretrained("<hf_username>/<my_policy_repo_id>")
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
@@ -544,7 +547,7 @@ dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/eval_<dataset_repo_id>",
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
@@ -559,6 +562,12 @@ _init_rerun(session_name="recording")
|
||||
# Connect the robot
|
||||
robot.connect()
|
||||
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
)
|
||||
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
@@ -568,6 +577,8 @@ for episode_idx in range(NUM_EPISODES):
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
|
||||
@@ -24,11 +24,36 @@ pip install -e ".[hilserl]"
|
||||
|
||||
To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_gym_hil_il.json).
|
||||
|
||||
To teleoperate and collect a dataset, we need to modify this config file and you should add your `repo_id` here: `"repo_id": "il_gym",` and `"num_episodes": 30,` and make sure you set `mode` to `record`, "mode": "record".
|
||||
To teleoperate and collect a dataset, we need to modify this config file. Here's an example configuration for imitation learning data collection:
|
||||
|
||||
If you do not have a Nvidia GPU also change `"device": "cuda"` parameter in the config file (for example to `mps` for MacOS).
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"fps": 10
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "your_username/il_gym",
|
||||
"root": null,
|
||||
"task": "pick_cube",
|
||||
"num_episodes_to_record": 30,
|
||||
"replay_episode": null,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record",
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
|
||||
By default the config file assumes you use a controller. To use your keyboard please change the envoirment specified at `"task"` in the config file and set it to `"PandaPickCubeKeyboard-v0"`.
|
||||
Key configuration points:
|
||||
|
||||
- Set your `repo_id` in the `dataset` section: `"repo_id": "your_username/il_gym"`
|
||||
- Set `num_episodes_to_record: 30` to collect 30 demonstration episodes
|
||||
- Ensure `mode` is set to `"record"`
|
||||
- If you don't have an NVIDIA GPU, change `"device": "cuda"` to `"mps"` for macOS or `"cpu"`
|
||||
- To use keyboard instead of gamepad, change `"task"` to `"PandaPickCubeKeyboard-v0"`
|
||||
|
||||
Then we can run this command to start:
|
||||
|
||||
@@ -140,9 +165,32 @@ huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \
|
||||
|
||||
## Evaluate your policy in Sim
|
||||
|
||||
To evaluate your policy we have to use the config file that can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json).
|
||||
To evaluate your policy we have to use a configuration file. An example can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json).
|
||||
|
||||
Make sure to replace the `repo_id` with the dataset you trained on, for example `pepijn223/il_sim_dataset` and replace the `pretrained_policy_name_or_path` with your model id, for example `pepijn223/il_sim_model`
|
||||
Here's an example evaluation configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"fps": 10
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "your_username/il_sim_dataset",
|
||||
"dataset_root": null,
|
||||
"task": "pick_cube"
|
||||
},
|
||||
"pretrained_policy_name_or_path": "your_username/il_sim_model",
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
|
||||
Make sure to replace:
|
||||
|
||||
- `repo_id` with the dataset you trained on (e.g., `your_username/il_sim_dataset`)
|
||||
- `pretrained_policy_name_or_path` with your model ID (e.g., `your_username/il_sim_model`)
|
||||
|
||||
Then you can run this command to visualize your trained policy
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
@@ -11,12 +12,14 @@ NUM_EPISODES = 2
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
||||
|
||||
# Create the robot and teleoperator configurations
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
robot = LeKiwiClient(robot_config)
|
||||
|
||||
policy = ACTPolicy.from_pretrained("<hf_username>/<policy_repo_id>")
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
@@ -25,7 +28,7 @@ dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="<hf_username>/<eval_dataset_repo_id>",
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
@@ -43,6 +46,12 @@ listener, events = init_keyboard_listener()
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
)
|
||||
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
@@ -53,6 +62,8 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
|
||||
@@ -38,7 +38,7 @@ while True:
|
||||
keyboard_keys = keyboard.get_action()
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_keys)
|
||||
|
||||
log_rerun_data(observation, {**arm_action, **base_action})
|
||||
log_rerun_data(observation=observation, action={**arm_action, **base_action})
|
||||
|
||||
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
|
||||
|
||||
158
examples/phone_to_so100/evaluate.py
Normal file
158
examples/phone_to_so100/evaluate.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
AddRobotObservationAsComplimentaryData,
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Initialize the robot with degrees
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline(
|
||||
steps=[
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=lambda tr: tr,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to ee pose observation
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline(
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
# Build dataset action and gripper features
|
||||
action_ee_and_gripper = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_ee_to_joints_processor,
|
||||
initial_features={},
|
||||
use_videos=True,
|
||||
patterns=["action.ee", "action.gripper.pos", "observation.state.gripper.pos"],
|
||||
) # Get all ee action features + gripper pos action features
|
||||
|
||||
# Build dataset observation features
|
||||
obs_ee = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=robot.observation_features,
|
||||
use_videos=True,
|
||||
patterns=["observation.state.ee"],
|
||||
) # Get all ee observation features
|
||||
|
||||
dataset_features = combine_feature_dicts(obs_ee, action_ee_and_gripper)
|
||||
|
||||
print("All dataset features: ", dataset_features)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="recording_phone")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
|
||||
episode_idx = 0
|
||||
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
)
|
||||
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
dataset.save_episode()
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
dataset.push_to_hub()
|
||||
215
examples/phone_to_so100/record.py
Normal file
215
examples/phone_to_so100/record.py
Normal file
@@ -0,0 +1,215 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
action_to_transition,
|
||||
observation_to_transition,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.record import record_loop
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
AddRobotObservationAsComplimentaryData,
|
||||
EEBoundsAndSafety,
|
||||
EEReferenceAndDelta,
|
||||
ForwardKinematicsJointsToEE,
|
||||
GripperVelocityToJoint,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.teleoperators.phone.teleop_phone import Phone
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
|
||||
NUM_EPISODES = 10
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
RESET_TIME_SEC = 30
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
phone = Phone(teleop_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert phone action to ee pose action
|
||||
phone_to_robot_ee_pose_processor = RobotProcessorPipeline(
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.20,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
],
|
||||
to_transition=action_to_transition,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline(
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
GripperVelocityToJoint(
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
speed_factor=20.0,
|
||||
),
|
||||
],
|
||||
to_transition=lambda tr: tr,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to ee pose observation
|
||||
robot_joints_to_ee_pose = RobotProcessorPipeline(
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
# Build dataset ee action features
|
||||
action_ee = aggregate_pipeline_dataset_features(
|
||||
pipeline=phone_to_robot_ee_pose_processor,
|
||||
initial_features=phone.action_features,
|
||||
use_videos=True,
|
||||
patterns=["action.ee"],
|
||||
)
|
||||
|
||||
# Get gripper pos action features
|
||||
gripper = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_ee_to_joints_processor,
|
||||
initial_features={},
|
||||
use_videos=True,
|
||||
patterns=["action.gripper.pos", "observation.state.gripper.pos"],
|
||||
)
|
||||
|
||||
# Build dataset ee observation features
|
||||
observation_ee = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose,
|
||||
initial_features=robot.observation_features,
|
||||
use_videos=True,
|
||||
patterns=["observation.state.ee"],
|
||||
)
|
||||
|
||||
dataset_features = combine_feature_dicts(action_ee, gripper, observation_ee)
|
||||
|
||||
print("All dataset features: ", dataset_features)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
_init_rerun(session_name="recording_phone")
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
phone.connect()
|
||||
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
dataset.push_to_hub()
|
||||
81
examples/phone_to_so100/replay.py
Normal file
81
examples/phone_to_so100/replay.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import time
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import action_to_transition, transition_to_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
AddRobotObservationAsComplimentaryData,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
EPISODE_IDX = 0
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
robot = SO100Follower(robot_config)
|
||||
robot.connect()
|
||||
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert ee pose action to joint action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline(
|
||||
steps=[
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=False, # Because replay is open loop
|
||||
),
|
||||
],
|
||||
to_transition=action_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
robot_ee_to_joints_processor.reset()
|
||||
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(dataset.num_frames):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
ee_action = {
|
||||
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
|
||||
}
|
||||
|
||||
joint_action = robot_ee_to_joints_processor(ee_action)
|
||||
action_sent = robot.send_action(joint_action)
|
||||
|
||||
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
|
||||
robot.disconnect()
|
||||
93
examples/phone_to_so100/teleoperate.py
Normal file
93
examples/phone_to_so100/teleoperate.py
Normal file
@@ -0,0 +1,93 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specif
|
||||
|
||||
import time
|
||||
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.processor.converters import action_to_transition, transition_to_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
AddRobotObservationAsComplimentaryData,
|
||||
EEBoundsAndSafety,
|
||||
EEReferenceAndDelta,
|
||||
GripperVelocityToJoint,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.teleoperators.phone.teleop_phone import Phone
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
teleop_device = Phone(teleop_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./src/lerobot/teleoperators/sim/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert phone action to ee pose action to joint action
|
||||
phone_to_robot_joints_processor = RobotProcessorPipeline(
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
AddRobotObservationAsComplimentaryData(robot=robot),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
),
|
||||
GripperVelocityToJoint(
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
speed_factor=20.0,
|
||||
),
|
||||
],
|
||||
to_transition=action_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
teleop_device.connect()
|
||||
|
||||
print("Starting teleop loop. Move your phone to teleoperate the robot.")
|
||||
while True:
|
||||
# Get teleop observation
|
||||
phone_obs = teleop_device.get_action()
|
||||
|
||||
# Phone -> EE pose -> Joints transition
|
||||
joint_action = phone_to_robot_joints_processor(phone_obs)
|
||||
|
||||
if joint_action:
|
||||
robot.send_action(joint_action)
|
||||
|
||||
time.sleep(0.01)
|
||||
@@ -73,6 +73,7 @@ dependencies = [
|
||||
"pynput>=1.7.7",
|
||||
"pyserial>=3.5",
|
||||
"wandb>=0.20.0",
|
||||
"scipy>=1.15.2",
|
||||
|
||||
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
|
||||
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
|
||||
@@ -95,7 +96,7 @@ dependencies = [
|
||||
# Common
|
||||
pygame-dep = ["pygame>=2.5.1"]
|
||||
placo-dep = ["placo>=0.9.6"]
|
||||
transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency
|
||||
transformers-dep = ["transformers<=4.52.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
|
||||
|
||||
# Motors
|
||||
@@ -111,6 +112,7 @@ intelrealsense = [
|
||||
"pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'",
|
||||
"pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'",
|
||||
]
|
||||
phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"]
|
||||
# stretch = [
|
||||
# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
|
||||
# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
|
||||
@@ -152,7 +154,8 @@ all = [
|
||||
"lerobot[video_benchmark]",
|
||||
"lerobot[aloha]",
|
||||
"lerobot[pusht]",
|
||||
"lerobot[xarm]"
|
||||
"lerobot[xarm]",
|
||||
"lerobot[phone]",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -26,7 +26,7 @@ from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.optim.optimizers import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
@@ -53,7 +53,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
"""
|
||||
|
||||
n_obs_steps: int = 1
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
@@ -24,6 +24,7 @@ class FeatureType(str, Enum):
|
||||
ENV = "ENV"
|
||||
ACTION = "ACTION"
|
||||
REWARD = "REWARD"
|
||||
LANGUAGE = "LANGUAGE"
|
||||
|
||||
|
||||
class NormalizationMode(str, Enum):
|
||||
|
||||
@@ -21,8 +21,14 @@ OBS_ENV_STATE = "observation.environment_state"
|
||||
OBS_STATE = "observation.state"
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGES = "observation.images"
|
||||
OBS_LANGUAGE = "observation.language"
|
||||
ACTION = "action"
|
||||
REWARD = "next.reward"
|
||||
TRUNCATED = "next.truncated"
|
||||
DONE = "next.done"
|
||||
|
||||
OBS_LANGUAGE_TOKENS = "observation.language.tokens"
|
||||
OBS_LANGUAGE_ATTENTION_MASK = "observation.language.attention_mask"
|
||||
|
||||
ROBOTS = "robots"
|
||||
ROBOT_TYPE = "robot_type"
|
||||
@@ -39,6 +45,9 @@ OPTIMIZER_STATE = "optimizer_state.safetensors"
|
||||
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
|
||||
SCHEDULER_STATE = "scheduler_state.json"
|
||||
|
||||
PREPROCESSOR_DEFAULT_NAME = "robot_preprocessor"
|
||||
POSTPROCESSOR_DEFAULT_NAME = "robot_postprocessor"
|
||||
|
||||
if "LEROBOT_HOME" in os.environ:
|
||||
raise ValueError(
|
||||
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
|
||||
|
||||
95
src/lerobot/datasets/pipeline_features.py
Normal file
95
src/lerobot/datasets/pipeline_features.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.processor import DataProcessorPipeline
|
||||
|
||||
|
||||
def aggregate_pipeline_dataset_features(
|
||||
pipeline: DataProcessorPipeline,
|
||||
initial_features: dict[str, Any],
|
||||
*,
|
||||
use_videos: bool = True,
|
||||
patterns: Sequence[str] | None = None,
|
||||
) -> dict[str, dict]:
|
||||
"""
|
||||
Aggregates the pipeline's features and returns a features dict ready for the dataset,
|
||||
filtered to only those keys matching any of the given patterns (for action/state only).
|
||||
|
||||
- `initial_features`: raw camera specs, e.g. {"front": (h,w,c), ...}
|
||||
- `use_videos`: whether to treat image features as video streams
|
||||
- `patterns`: regexes to filter action & state features; images are included
|
||||
whenever use_videos=True, regardless of patterns.
|
||||
"""
|
||||
import re
|
||||
|
||||
# Gather everything the pipeline features specifies, seeded with hardware cams:
|
||||
all_features = pipeline.transform_features(initial_features)
|
||||
|
||||
# Helper to decide which action/state keys survive the `patterns` filter:
|
||||
def keep(key: str) -> bool:
|
||||
if patterns is None:
|
||||
return True
|
||||
return any(re.search(pat, key) for pat in patterns)
|
||||
|
||||
# Start with hardware dict, injecting initial cameras if videos are ON:
|
||||
hw: dict[str, dict[str, Any]] = {}
|
||||
if use_videos:
|
||||
cams = {
|
||||
name: shape
|
||||
for name, shape in initial_features.items()
|
||||
if isinstance(shape, tuple) and len(shape) == 3
|
||||
}
|
||||
if cams:
|
||||
hw["observation"] = dict(cams)
|
||||
|
||||
# Go over every feature from the pipeline and merge:
|
||||
for full_key, ty in all_features.items():
|
||||
if full_key.startswith(f"{ACTION}."):
|
||||
# action.<feat>
|
||||
if not keep(full_key):
|
||||
continue
|
||||
name = full_key[len(f"{ACTION}.") :]
|
||||
hw.setdefault(ACTION, {})[name] = ty
|
||||
|
||||
elif full_key.startswith(f"{OBS_STATE}."):
|
||||
# observation.state.<feat>
|
||||
if not keep(full_key):
|
||||
continue
|
||||
name = full_key[len(f"{OBS_STATE}.") :]
|
||||
hw.setdefault("observation", {})[name] = ty
|
||||
|
||||
elif full_key.startswith(f"{OBS_IMAGES}."):
|
||||
# observation.images.<cam>
|
||||
# images obey ONLY the use_videos flag, not patterns
|
||||
if not use_videos:
|
||||
continue
|
||||
name = full_key[len(f"{OBS_IMAGES}.") :]
|
||||
hw.setdefault("observation", {})[name] = ty
|
||||
|
||||
else:
|
||||
# anything else (e.g. policy-only features) is ignored here
|
||||
continue
|
||||
|
||||
out: dict[str, dict] = {}
|
||||
if ACTION in hw:
|
||||
out.update(hw_to_dataset_features(hw[ACTION], ACTION, use_videos))
|
||||
if "observation" in hw:
|
||||
out.update(hw_to_dataset_features(hw["observation"], "observation", use_videos))
|
||||
|
||||
return out
|
||||
@@ -470,6 +470,50 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
return policy_features
|
||||
|
||||
|
||||
def combine_feature_dicts(*dicts: dict) -> dict:
|
||||
"""
|
||||
Merge LeRobot grouped feature dicts.
|
||||
|
||||
- For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
|
||||
- For others (observation.images.*), last one wins (if they are identical).
|
||||
"""
|
||||
out: dict = {}
|
||||
for d in dicts:
|
||||
for key, value in d.items():
|
||||
if not isinstance(value, dict):
|
||||
out[key] = value
|
||||
continue
|
||||
|
||||
dtype = value.get("dtype")
|
||||
shape = value.get("shape")
|
||||
is_vector = (
|
||||
dtype not in ("image", "video", "string")
|
||||
and isinstance(shape, tuple)
|
||||
and len(shape) == 1
|
||||
and "names" in value
|
||||
)
|
||||
|
||||
if is_vector:
|
||||
# Initialize or retrieve the accumulating dict for this feature key
|
||||
target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})
|
||||
# Ensure consistent data types across merged entries
|
||||
if "dtype" in target and dtype != target["dtype"]:
|
||||
raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")
|
||||
|
||||
# Merge feature names: append only new ones to preserve order without duplicates
|
||||
seen = set(target["names"])
|
||||
for n in value["names"]:
|
||||
if n not in seen:
|
||||
target["names"].append(n)
|
||||
seen.add(n)
|
||||
# Recompute the shape to reflect the updated number of features
|
||||
target["shape"] = (len(target["names"]),)
|
||||
else:
|
||||
# For images/videos and non-1D entries: override with the latest definition
|
||||
out[key] = value
|
||||
return out
|
||||
|
||||
|
||||
def create_empty_dataset_info(
|
||||
codebase_version: str,
|
||||
fps: int,
|
||||
|
||||
@@ -161,35 +161,73 @@ class XarmEnv(EnvConfig):
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoRecordConfig:
|
||||
"""Configuration for video recording in ManiSkill environments."""
|
||||
|
||||
enabled: bool = False
|
||||
record_dir: str = "videos"
|
||||
trajectory_name: str = "trajectory"
|
||||
class ImagePreprocessingConfig:
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||
resize_size: tuple[int, int] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvTransformConfig:
|
||||
"""Configuration for environment wrappers."""
|
||||
class RewardClassifierConfig:
|
||||
"""Configuration for reward classification."""
|
||||
|
||||
pretrained_path: str | None = None
|
||||
success_threshold: float = 0.5
|
||||
success_reward: float = 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class InverseKinematicsConfig:
|
||||
"""Configuration for inverse kinematics processing."""
|
||||
|
||||
urdf_path: str | None = None
|
||||
target_frame_name: str | None = None
|
||||
end_effector_bounds: dict[str, list[float]] | None = None
|
||||
end_effector_step_sizes: dict[str, float] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ObservationConfig:
|
||||
"""Configuration for observation processing."""
|
||||
|
||||
# ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig)
|
||||
control_mode: str = "gamepad"
|
||||
display_cameras: bool = False
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_current_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||
resize_size: tuple[int, int] | None = None
|
||||
control_time_s: float = 20.0
|
||||
fixed_reset_joint_positions: Any | None = None
|
||||
reset_time_s: float = 5.0
|
||||
display_cameras: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class GripperConfig:
|
||||
"""Configuration for gripper control and penalties."""
|
||||
|
||||
use_gripper: bool = True
|
||||
gripper_quantization_threshold: float | None = 0.8
|
||||
gripper_penalty: float = 0.0
|
||||
gripper_penalty_in_reward: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResetConfig:
|
||||
"""Configuration for environment reset behavior."""
|
||||
|
||||
fixed_reset_joint_positions: Any | None = None
|
||||
reset_time_s: float = 5.0
|
||||
control_time_s: float = 20.0
|
||||
terminate_on_success: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILSerlProcessorConfig:
|
||||
"""Configuration for environment processing pipeline."""
|
||||
|
||||
control_mode: str = "gamepad"
|
||||
observation: ObservationConfig | None = None
|
||||
image_preprocessing: ImagePreprocessingConfig | None = None
|
||||
gripper: GripperConfig | None = None
|
||||
reset: ResetConfig | None = None
|
||||
inverse_kinematics: InverseKinematicsConfig | None = None
|
||||
reward_classifier: RewardClassifierConfig | None = None
|
||||
max_gripper_pos: float | None = 100.0
|
||||
|
||||
|
||||
@EnvConfig.register_subclass(name="gym_manipulator")
|
||||
@dataclass
|
||||
class HILSerlRobotEnvConfig(EnvConfig):
|
||||
@@ -197,77 +235,10 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
||||
|
||||
robot: RobotConfig | None = None
|
||||
teleop: TeleoperatorConfig | None = None
|
||||
wrapper: EnvTransformConfig | None = None
|
||||
fps: int = 10
|
||||
processor: HILSerlProcessorConfig = field(default_factory=HILSerlProcessorConfig)
|
||||
|
||||
name: str = "real_robot"
|
||||
mode: str | None = None # Either "record", "replay", None
|
||||
repo_id: str | None = None
|
||||
dataset_root: str | None = None
|
||||
task: str | None = ""
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: str | None = None
|
||||
reward_classifier_pretrained_path: str | None = None
|
||||
# For the reward classifier, to record more positive examples after a success
|
||||
number_of_steps_after_success: int = 0
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("hil")
|
||||
@dataclass
|
||||
class HILEnvConfig(EnvConfig):
|
||||
"""Configuration for the HIL environment."""
|
||||
|
||||
name: str = "PandaPickCube"
|
||||
task: str | None = "PandaPickCubeKeyboard-v0"
|
||||
use_viewer: bool = True
|
||||
gripper_penalty: float = 0.0
|
||||
use_gamepad: bool = True
|
||||
state_dim: int = 18
|
||||
action_dim: int = 4
|
||||
fps: int = 100
|
||||
episode_length: int = 100
|
||||
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"observation.image": OBS_IMAGE,
|
||||
"observation.state": OBS_STATE,
|
||||
}
|
||||
)
|
||||
################# args from hilserlrobotenv
|
||||
reward_classifier_pretrained_path: str | None = None
|
||||
robot_config: RobotConfig | None = None
|
||||
teleop_config: TeleoperatorConfig | None = None
|
||||
wrapper: EnvTransformConfig | None = None
|
||||
mode: str | None = None # Either "record", "replay", None
|
||||
repo_id: str | None = None
|
||||
dataset_root: str | None = None
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: str | None = None
|
||||
# For the reward classifier, to record more positive examples after a success
|
||||
number_of_steps_after_success: int = 0
|
||||
############################
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"use_viewer": self.use_viewer,
|
||||
"use_gamepad": self.use_gamepad,
|
||||
"gripper_penalty": self.gripper_penalty,
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
|
||||
|
||||
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
@@ -27,8 +27,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "xarm":
|
||||
return XarmEnv(**kwargs)
|
||||
elif env_type == "hil":
|
||||
return HILEnvConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
|
||||
@@ -15,6 +15,17 @@
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0.processor_pi0 import Pi0NewLineProcessor
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
|
||||
__all__ = [
|
||||
"ACTConfig",
|
||||
"DiffusionConfig",
|
||||
"PI0Config",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
]
|
||||
|
||||
@@ -35,7 +35,6 @@ from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.constants import ACTION, OBS_IMAGES
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
@@ -51,27 +50,16 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: ACTConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.model = ACT(config)
|
||||
|
||||
if config.temporal_ensemble_coeff is not None:
|
||||
@@ -137,23 +125,19 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
|
||||
|
||||
actions = self.model(batch)[0]
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
return actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
|
||||
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
l1_loss = (
|
||||
|
||||
70
src/lerobot/policies/act/processor_act.py
Normal file
70
src/lerobot/policies/act/processor_act.py
Normal file
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def make_act_pre_post_processors(
|
||||
config: ACTConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessorStep(rename_map={}),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
@@ -35,7 +35,6 @@ from torch import Tensor, nn
|
||||
|
||||
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
@@ -57,7 +56,6 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: DiffusionConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -70,14 +68,6 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
self._queues = None
|
||||
|
||||
@@ -106,9 +96,6 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.diffusion.generate_actions(batch)
|
||||
|
||||
# TODO(rcadene): make above methods return output dictionary?
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -137,7 +124,6 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
@@ -153,11 +139,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
# no output_dict so returning None
|
||||
return loss, None
|
||||
|
||||
70
src/lerobot/policies/diffusion/processor_diffusion.py
Normal file
70
src/lerobot/policies/diffusion/processor_diffusion.py
Normal file
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab,
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def make_diffusion_pre_post_processors(
|
||||
config: DiffusionConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessorStep(rename_map={}),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
@@ -14,9 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from __future__ import annotations
|
||||
|
||||
from torch import nn
|
||||
import logging
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import torch
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType
|
||||
@@ -34,9 +38,10 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor import PolicyProcessorPipeline, ProcessorKwargs
|
||||
|
||||
|
||||
def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
|
||||
if name == "tdmpc":
|
||||
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
@@ -101,6 +106,159 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
"""Keyword arguments for the processor config."""
|
||||
|
||||
preprocessor_config_filename: str | None
|
||||
postprocessor_config_filename: str | None
|
||||
preprocessor_overrides: dict[str, Any] | None
|
||||
postprocessor_overrides: dict[str, Any] | None
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
|
||||
preprocessor_kwargs: ProcessorKwargs | None
|
||||
postprocessor_kwargs: ProcessorKwargs | None
|
||||
|
||||
|
||||
def make_pre_post_processors(
|
||||
policy_cfg: PreTrainedConfig,
|
||||
pretrained_path: str | None = None,
|
||||
**kwargs: Unpack[ProcessorConfigKwargs],
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
"""Make a processor instance for a given policy type.
|
||||
|
||||
This function creates the appropriate processor configuration based on the policy type.
|
||||
Each policy type has its own processor with specific preprocessing steps.
|
||||
|
||||
Args:
|
||||
policy_cfg: The config of the policy to create a processor for (e.g., "act", "diffusion", etc.)
|
||||
pretrained_path: Optional path to load a pretrained processor from. If provided, loads
|
||||
the processor from this path instead of creating a new one.
|
||||
**kwargs: Additional keyword arguments passed to the processor creation.
|
||||
|
||||
Returns:
|
||||
Tuple of (input_processor, output_processor) for the policy.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the policy type doesn't have a processor implemented.
|
||||
"""
|
||||
if pretrained_path:
|
||||
# Extract preprocessor and postprocessor kwargs
|
||||
preprocessor_kwargs = kwargs.get("preprocessor_kwargs", {})
|
||||
postprocessor_kwargs = kwargs.get("postprocessor_kwargs", {})
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=kwargs.get("preprocessor_config_filename", "robot_preprocessor.json"),
|
||||
overrides=kwargs.get("preprocessor_overrides", {}),
|
||||
to_transition=preprocessor_kwargs.get("to_transition"),
|
||||
to_output=preprocessor_kwargs.get("to_output"),
|
||||
),
|
||||
PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=kwargs.get("postprocessor_config_filename", "robot_postprocessor.json"),
|
||||
overrides=kwargs.get("postprocessor_overrides", {}),
|
||||
to_transition=postprocessor_kwargs.get("to_transition"),
|
||||
to_output=postprocessor_kwargs.get("to_output"),
|
||||
),
|
||||
)
|
||||
|
||||
# Create a new processor based on policy type
|
||||
if isinstance(policy_cfg, TDMPCConfig):
|
||||
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors
|
||||
|
||||
processors = make_tdmpc_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, DiffusionConfig):
|
||||
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors
|
||||
|
||||
processors = make_diffusion_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, ACTConfig):
|
||||
from lerobot.policies.act.processor_act import make_act_pre_post_processors
|
||||
|
||||
processors = make_act_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, VQBeTConfig):
|
||||
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
|
||||
|
||||
processors = make_vqbet_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI0Config):
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors
|
||||
|
||||
processors = make_pi0_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI0FASTConfig):
|
||||
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors
|
||||
|
||||
processors = make_pi0fast_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SACConfig):
|
||||
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
|
||||
|
||||
processors = make_sac_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, RewardClassifierConfig):
|
||||
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
|
||||
|
||||
processors = make_classifier_processor(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SmolVLAConfig):
|
||||
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors
|
||||
|
||||
processors = make_smolvla_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
|
||||
postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
||||
|
||||
return processors
|
||||
|
||||
|
||||
def make_policy(
|
||||
cfg: PreTrainedConfig,
|
||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||
@@ -147,7 +305,6 @@ def make_policy(
|
||||
kwargs = {}
|
||||
if ds_meta is not None:
|
||||
features = dataset_to_policy_features(ds_meta.features)
|
||||
kwargs["dataset_stats"] = ds_meta.stats
|
||||
else:
|
||||
if not cfg.pretrained_path:
|
||||
logging.warning(
|
||||
@@ -155,6 +312,8 @@ def make_policy(
|
||||
"rather than a dataset. Normalization modules inside the policy will have infinite values "
|
||||
"by default without stats from a dataset."
|
||||
)
|
||||
if env_cfg is None:
|
||||
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
|
||||
features = env_to_policy_features(env_cfg)
|
||||
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
@@ -171,7 +330,7 @@ def make_policy(
|
||||
policy = policy_cls(**kwargs)
|
||||
|
||||
policy.to(cfg.device)
|
||||
assert isinstance(policy, nn.Module)
|
||||
assert isinstance(policy, torch.nn.Module)
|
||||
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
|
||||
@@ -1,420 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
def create_stats_buffers(
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
) -> dict[str, dict[str, nn.ParameterDict]]:
|
||||
"""
|
||||
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
|
||||
statistics.
|
||||
|
||||
Args: (see Normalize and Unnormalize)
|
||||
|
||||
Returns:
|
||||
dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
|
||||
`nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
|
||||
"""
|
||||
stats_buffers = {}
|
||||
|
||||
for key, ft in features.items():
|
||||
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
assert isinstance(norm_mode, NormalizationMode)
|
||||
|
||||
shape = tuple(ft.shape)
|
||||
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
# sanity checks
|
||||
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
|
||||
c, h, w = shape
|
||||
assert c < h and c < w, f"{key} is not channel first ({shape=})"
|
||||
# override image shape to be invariant to height and width
|
||||
shape = (c, 1, 1)
|
||||
|
||||
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
|
||||
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
|
||||
# we assert they are not infinity anymore.
|
||||
|
||||
buffer = {}
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
std = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"mean": nn.Parameter(mean, requires_grad=False),
|
||||
"std": nn.Parameter(std, requires_grad=False),
|
||||
}
|
||||
)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"min": nn.Parameter(min, requires_grad=False),
|
||||
"max": nn.Parameter(max, requires_grad=False),
|
||||
}
|
||||
)
|
||||
|
||||
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
||||
if stats:
|
||||
if isinstance(stats[key]["mean"], np.ndarray):
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||
else:
|
||||
type_ = type(stats[key]["mean"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
|
||||
|
||||
def _no_stats_error_str(name: str) -> str:
|
||||
return (
|
||||
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
|
||||
"pretrained model."
|
||||
)
|
||||
|
||||
|
||||
class Normalize(nn.Module):
|
||||
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
|
||||
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
|
||||
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
||||
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
|
||||
are their normalization modes among:
|
||||
- "mean_std": subtract the mean and divide by standard deviation.
|
||||
- "min_max": map to [-1, 1] range.
|
||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
|
||||
and values are dictionaries of statistic types and their values (e.g.
|
||||
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
||||
training the model for the first time, these statistics will overwrite the default buffers. If
|
||||
not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
self.stats = stats
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad()
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# TODO: Remove this shallow copy
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
# normalize to [0,1]
|
||||
batch[key] = (batch[key] - min) / (max - min + 1e-8)
|
||||
# normalize to [-1, 1]
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
|
||||
class Unnormalize(nn.Module):
|
||||
"""
|
||||
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
|
||||
original range used by the environment.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values
|
||||
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
|
||||
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
|
||||
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
|
||||
modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values
|
||||
are their normalization modes among:
|
||||
- "mean_std": subtract the mean and divide by standard deviation.
|
||||
- "min_max": map to [-1, 1] range.
|
||||
stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image")
|
||||
and values are dictionaries of statistic types and their values (e.g.
|
||||
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
|
||||
training the model for the first time, these statistics will overwrite the default buffers. If
|
||||
not provided, as expected for finetuning or evaluation, the default buffers should to be
|
||||
overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
|
||||
dataset is not needed to get the stats, since they are already in the policy state_dict.
|
||||
"""
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
self.stats = stats
|
||||
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
@torch.no_grad()
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = buffer["mean"]
|
||||
std = buffer["std"]
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max - min) + min
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
|
||||
# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization
|
||||
# and remove the `Normalize` and `Unnormalize` classes.
|
||||
def _initialize_stats_buffers(
|
||||
module: nn.Module,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
) -> None:
|
||||
"""Register statistics buffers (mean/std or min/max) on the given *module*.
|
||||
|
||||
The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`,
|
||||
but is factored out so it can be reused by both classes and stay in sync.
|
||||
"""
|
||||
for key, ft in features.items():
|
||||
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
shape: tuple[int, ...] = tuple(ft.shape)
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
# reduce spatial dimensions, keep channel dimension only
|
||||
c, *_ = shape
|
||||
shape = (c, 1, 1)
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
std = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
|
||||
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
|
||||
mean_data = stats[key]["mean"]
|
||||
std_data = stats[key]["std"]
|
||||
if isinstance(mean_data, torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
mean = mean_data.clone().to(dtype=torch.float32)
|
||||
std = std_data.clone().to(dtype=torch.float32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
module.register_buffer(f"{prefix}_mean", mean)
|
||||
module.register_buffer(f"{prefix}_std", std)
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
max_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
|
||||
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
|
||||
min_data = stats[key]["min"]
|
||||
max_data = stats[key]["max"]
|
||||
if isinstance(min_data, torch.Tensor):
|
||||
min_val = min_data.clone().to(dtype=torch.float32)
|
||||
max_val = max_data.clone().to(dtype=torch.float32)
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
module.register_buffer(f"{prefix}_min", min_val)
|
||||
module.register_buffer(f"{prefix}_max", max_val)
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
|
||||
class NormalizeBuffer(nn.Module):
|
||||
"""Same as `Normalize` but statistics are stored as registered buffers rather than parameters."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
|
||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch)
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = getattr(self, f"{prefix}_mean")
|
||||
std = getattr(self, f"{prefix}_std")
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = getattr(self, f"{prefix}_min")
|
||||
max_val = getattr(self, f"{prefix}_max")
|
||||
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8)
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class UnnormalizeBuffer(nn.Module):
|
||||
"""Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
|
||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# batch = dict(batch)
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
if norm_mode is NormalizationMode.IDENTITY:
|
||||
continue
|
||||
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = getattr(self, f"{prefix}_mean")
|
||||
std = getattr(self, f"{prefix}_std")
|
||||
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = getattr(self, f"{prefix}_min")
|
||||
max_val = getattr(self, f"{prefix}_max")
|
||||
assert not torch.isinf(min_val).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max_val).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max_val - min_val) + min_val
|
||||
continue
|
||||
|
||||
raise ValueError(norm_mode)
|
||||
|
||||
return batch
|
||||
@@ -56,18 +56,15 @@ from collections import deque
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi0.paligemma_with_expert import (
|
||||
PaliGemmaWithExpertConfig,
|
||||
PaliGemmaWithExpertModel,
|
||||
)
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import log_model_loading_keys
|
||||
from lerobot.utils.utils import get_safe_dtype, init_logging
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
@@ -223,28 +220,17 @@ class PI0Policy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: PI0Config,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.model = PI0FlowMatching(config)
|
||||
|
||||
self.reset()
|
||||
@@ -253,99 +239,6 @@ class PI0Policy(PreTrainedPolicy):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@classmethod
|
||||
def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
|
||||
"""
|
||||
Transform state dict keys to match expected model structure.
|
||||
|
||||
Transformations:
|
||||
- model.paligemma_with_expert.paligemma.language_model.lm_head ->
|
||||
model.paligemma_with_expert.paligemma.lm_head
|
||||
- model.paligemma_with_expert.paligemma.language_model.model ->
|
||||
model.paligemma_with_expert.paligemma.model.language_model
|
||||
- model.paligemma_with_expert.paligemma.vision_tower ->
|
||||
model.paligemma_with_expert.paligemma.model.vision_tower
|
||||
- model.paligemma_with_expert.paligemma.multi_modal_projector ->
|
||||
model.paligemma_with_expert.paligemma.model.multi_modal_projector
|
||||
|
||||
Also handles tied weights between lm_head.weight and
|
||||
embed_tokens.weight.
|
||||
"""
|
||||
import re
|
||||
|
||||
transformed_dict = {}
|
||||
|
||||
transformations = [
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"),
|
||||
".paligemma_with_expert.paligemma.lm_head",
|
||||
),
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"),
|
||||
".paligemma_with_expert.paligemma.model.language_model",
|
||||
),
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"),
|
||||
".paligemma_with_expert.paligemma.model.vision_tower",
|
||||
),
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"),
|
||||
".paligemma_with_expert.paligemma.model.multi_modal_projector",
|
||||
),
|
||||
]
|
||||
|
||||
for key, value in state_dict.items():
|
||||
new_key = key
|
||||
for pattern, replacement in transformations:
|
||||
new_key = pattern.sub(replacement, new_key)
|
||||
transformed_dict[new_key] = value
|
||||
|
||||
# Handle tied weights: lm_head.weight and embed_tokens.weight share memory
|
||||
lm_head_key = None
|
||||
embed_tokens_key = None
|
||||
|
||||
for key in transformed_dict:
|
||||
if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"):
|
||||
lm_head_key = key
|
||||
elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"):
|
||||
embed_tokens_key = key
|
||||
if lm_head_key and embed_tokens_key:
|
||||
break
|
||||
|
||||
if lm_head_key and not embed_tokens_key:
|
||||
embed_tokens_key = lm_head_key.replace(
|
||||
".lm_head.weight", ".model.language_model.embed_tokens.weight"
|
||||
)
|
||||
transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key]
|
||||
elif embed_tokens_key and not lm_head_key:
|
||||
lm_head_key = embed_tokens_key.replace(
|
||||
".model.language_model.embed_tokens.weight", ".lm_head.weight"
|
||||
)
|
||||
transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key]
|
||||
|
||||
return transformed_dict
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(
|
||||
cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool
|
||||
) -> "PI0Policy":
|
||||
"""Override to apply key transformations before loading."""
|
||||
from safetensors.torch import load_file
|
||||
|
||||
init_logging()
|
||||
# Load the state dict from file safely
|
||||
state_dict = load_file(model_file, device=map_location)
|
||||
|
||||
# Apply key transformations
|
||||
transformed_state_dict = cls._transform_state_dict_keys(state_dict)
|
||||
|
||||
# Load the transformed state dict
|
||||
msg = model.load_state_dict(transformed_state_dict, strict=strict)
|
||||
|
||||
# Log message
|
||||
log_model_loading_keys(msg.missing_keys, msg.unexpected_keys)
|
||||
return model
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@@ -377,14 +270,13 @@ class PI0Policy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._action_queue) == 0:
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
|
||||
actions = self.model.sample_actions(
|
||||
images, img_masks, lang_tokens, lang_masks, state, noise=noise
|
||||
@@ -394,8 +286,6 @@ class PI0Policy(PreTrainedPolicy):
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
@@ -410,12 +300,10 @@ class PI0Policy(PreTrainedPolicy):
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
actions = self.prepare_action(batch)
|
||||
actions_is_pad = batch.get("action_is_pad")
|
||||
|
||||
@@ -482,26 +370,6 @@ class PI0Policy(PreTrainedPolicy):
|
||||
|
||||
return images, img_masks
|
||||
|
||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
||||
"""Tokenize the text input"""
|
||||
device = batch[OBS_STATE].device
|
||||
tasks = batch["task"]
|
||||
|
||||
# PaliGemma prompt has to end with a new line
|
||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||
|
||||
tokenized_prompt = self.language_tokenizer.__call__(
|
||||
tasks,
|
||||
padding="max_length",
|
||||
padding_side="right",
|
||||
max_length=self.config.tokenizer_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
||||
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
def _pi_aloha_decode_state(self, state):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
@@ -567,7 +435,7 @@ class PI0FlowMatching(nn.Module):
|
||||
└──────────────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: PI0Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
|
||||
118
src/lerobot/policies/pi0/processor_pi0.py
Normal file
118
src/lerobot/policies/pi0/processor_pi0.py
Normal file
@@ -0,0 +1,118 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
|
||||
class Pi0NewLineProcessor(ComplementaryDataProcessorStep):
|
||||
"""Add a new line to the end of the task if it doesn't have one.
|
||||
This is required for the PaliGemma tokenizer.
|
||||
"""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
if "task" not in complementary_data:
|
||||
return complementary_data
|
||||
|
||||
task = complementary_data["task"]
|
||||
if task is None:
|
||||
return complementary_data
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
|
||||
# Handle both string and list of strings
|
||||
if isinstance(task, str):
|
||||
# Single string: add newline if not present
|
||||
if not task.endswith("\n"):
|
||||
new_complementary_data["task"] = f"{task}\n"
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
# List of strings: add newline to each if not present
|
||||
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
|
||||
# If task is neither string nor list of strings, leave unchanged
|
||||
|
||||
return new_complementary_data
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
def make_pi0_pre_post_processors(
|
||||
config: PI0Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
# Add remaining processors
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
Pi0NewLineProcessor(), # Add newlines before tokenization for PaliGemma
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
@@ -58,7 +58,6 @@ from transformers.cache_utils import HybridCache, StaticCache
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
@@ -146,14 +145,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
|
||||
self.model = PI0FAST(config)
|
||||
|
||||
@@ -221,8 +212,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._action_queue) == 0:
|
||||
@@ -235,8 +224,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||
] # self.config.max_action_dim # self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
@@ -249,8 +236,6 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss_dict = self.model.forward(batch)
|
||||
return loss_dict["loss"], loss_dict
|
||||
|
||||
|
||||
70
src/lerobot/policies/pi0fast/processor_pi0fast.py
Normal file
70
src/lerobot/policies/pi0fast/processor_pi0fast.py
Normal file
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def make_pi0fast_pre_post_processors(
|
||||
config: PI0Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
@@ -28,7 +28,6 @@ import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||
|
||||
from lerobot.policies.normalize import NormalizeBuffer
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
@@ -45,7 +44,6 @@ class SACPolicy(
|
||||
def __init__(
|
||||
self,
|
||||
config: SACConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
@@ -53,7 +51,6 @@ class SACPolicy(
|
||||
|
||||
# Determine action dimension and initialize all components
|
||||
continuous_action_dim = config.output_features["action"].shape[0]
|
||||
self._init_normalization(dataset_stats)
|
||||
self._init_encoders()
|
||||
self._init_critics(continuous_action_dim)
|
||||
self._init_actor(continuous_action_dim)
|
||||
@@ -88,8 +85,7 @@ class SACPolicy(
|
||||
|
||||
observations_features = None
|
||||
if self.shared_encoder and self.actor.encoder.has_images:
|
||||
# Cache and normalize image features
|
||||
observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True)
|
||||
observations_features = self.actor.encoder.get_cached_image_features(batch)
|
||||
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
|
||||
@@ -391,28 +387,12 @@ class SACPolicy(
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
def _init_normalization(self, dataset_stats):
|
||||
"""Initialize input/output normalization modules."""
|
||||
self.normalize_inputs = nn.Identity()
|
||||
self.normalize_targets = nn.Identity()
|
||||
if self.config.dataset_stats is not None:
|
||||
params = _convert_normalization_params_to_tensor(self.config.dataset_stats)
|
||||
self.normalize_inputs = NormalizeBuffer(
|
||||
self.config.input_features, self.config.normalization_mapping, params
|
||||
)
|
||||
stats = dataset_stats or params
|
||||
self.normalize_targets = NormalizeBuffer(
|
||||
self.config.output_features, self.config.normalization_mapping, stats
|
||||
)
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize shared or separate encoders for actor and critic."""
|
||||
self.shared_encoder = self.config.shared_encoder
|
||||
self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs)
|
||||
self.encoder_critic = SACObservationEncoder(self.config)
|
||||
self.encoder_actor = (
|
||||
self.encoder_critic
|
||||
if self.shared_encoder
|
||||
else SACObservationEncoder(self.config, self.normalize_inputs)
|
||||
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
|
||||
)
|
||||
|
||||
def _init_critics(self, continuous_action_dim):
|
||||
@@ -424,9 +404,7 @@ class SACPolicy(
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(
|
||||
encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets
|
||||
)
|
||||
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
@@ -434,9 +412,7 @@ class SACPolicy(
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(
|
||||
encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets
|
||||
)
|
||||
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
if self.config.use_torch_compile:
|
||||
@@ -490,10 +466,9 @@ class SACPolicy(
|
||||
class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None:
|
||||
def __init__(self, config: SACConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.input_normalization = input_normalizer
|
||||
self._init_image_layers()
|
||||
self._init_state_layers()
|
||||
self._compute_output_dim()
|
||||
@@ -568,11 +543,10 @@ class SACObservationEncoder(nn.Module):
|
||||
def forward(
|
||||
self, obs: dict[str, Tensor], cache: dict[str, Tensor] | None = None, detach: bool = False
|
||||
) -> Tensor:
|
||||
obs = self.input_normalization(obs)
|
||||
parts = []
|
||||
if self.has_images:
|
||||
if cache is None:
|
||||
cache = self.get_cached_image_features(obs, normalize=False)
|
||||
cache = self.get_cached_image_features(obs)
|
||||
parts.append(self._encode_images(cache, detach))
|
||||
if self.has_env:
|
||||
parts.append(self.env_encoder(obs["observation.environment_state"]))
|
||||
@@ -585,7 +559,7 @@ class SACObservationEncoder(nn.Module):
|
||||
"No parts to concatenate, you should have at least one image or environment state or state"
|
||||
)
|
||||
|
||||
def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]:
|
||||
def get_cached_image_features(self, obs: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Extract and optionally cache image features from observations.
|
||||
|
||||
This function processes image observations through the vision encoder once and returns
|
||||
@@ -597,26 +571,17 @@ class SACObservationEncoder(nn.Module):
|
||||
- The vision encoder forward pass is typically the main computational bottleneck during training and inference
|
||||
- Caching these features can provide 2-4x speedup in training and inference
|
||||
|
||||
Normalization behavior:
|
||||
- When called from inside forward(): set normalize=False since inputs are already normalized
|
||||
- When called from outside forward(): set normalize=True to ensure proper input normalization
|
||||
|
||||
Usage patterns:
|
||||
- Called in select_action() with normalize=True
|
||||
- Called in select_action()
|
||||
- Called in learner.py's get_observation_features() to pre-compute features for all policy components
|
||||
- Called internally by forward() with normalize=False
|
||||
- Called internally by forward()
|
||||
|
||||
Args:
|
||||
obs: Dictionary of observation tensors containing image keys
|
||||
normalize: Whether to normalize observations before encoding
|
||||
Set to True when calling directly from outside the encoder's forward method
|
||||
Set to False when calling from within forward() where inputs are already normalized
|
||||
|
||||
Returns:
|
||||
Dictionary mapping image keys to their corresponding encoded features
|
||||
"""
|
||||
if normalize:
|
||||
obs = self.input_normalization(obs)
|
||||
batched = torch.cat([obs[k] for k in self.image_keys], dim=0)
|
||||
out = self.image_encoder(batched)
|
||||
chunks = torch.chunk(out, len(self.image_keys), dim=0)
|
||||
@@ -747,7 +712,6 @@ class CriticEnsemble(nn.Module):
|
||||
Args:
|
||||
encoder (SACObservationEncoder): encoder for observations.
|
||||
ensemble (List[CriticHead]): list of critic heads.
|
||||
output_normalization (nn.Module): normalization layer for actions.
|
||||
init_final (float | None): optional initializer scale for final layers.
|
||||
|
||||
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
|
||||
@@ -757,13 +721,11 @@ class CriticEnsemble(nn.Module):
|
||||
self,
|
||||
encoder: SACObservationEncoder,
|
||||
ensemble: list[CriticHead],
|
||||
output_normalization: nn.Module,
|
||||
init_final: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.init_final = init_final
|
||||
self.output_normalization = output_normalization
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
def forward(
|
||||
@@ -775,11 +737,6 @@ class CriticEnsemble(nn.Module):
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
# NOTE: We normalize actions it helps for sample efficiency
|
||||
actions: dict[str, torch.tensor] = {"action": actions}
|
||||
# NOTE: Normalization layer took dict in input and outputs a dict that why
|
||||
actions = self.output_normalization(actions)["action"]
|
||||
actions = actions.to(device)
|
||||
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
|
||||
|
||||
71
src/lerobot/policies/sac/processor_sac.py
Normal file
71
src/lerobot/policies/sac/processor_sac.py
Normal file
@@ -0,0 +1,71 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def make_sac_pre_post_processors(
|
||||
config: SACConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessorStep(rename_map={}),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
@@ -20,7 +20,6 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.constants import OBS_IMAGE, REWARD
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
@@ -108,22 +107,12 @@ class Classifier(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: RewardClassifierConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
from transformers import AutoModel
|
||||
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Initialize normalization (standardized with the policy framework)
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# Set up encoder
|
||||
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
# Extract vision model if we're given a multimodal model
|
||||
@@ -247,10 +236,6 @@ class Classifier(PreTrainedPolicy):
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
"""Standard forward pass for training compatible with train.py."""
|
||||
# Normalize inputs if needed
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Extract images and labels
|
||||
images, labels = self.extract_images_and_labels(batch)
|
||||
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.processor import (
|
||||
DeviceProcessorStep,
|
||||
IdentityProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
)
|
||||
|
||||
|
||||
def make_classifier_processor(
|
||||
config: RewardClassifierConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
NormalizerProcessorStep(
|
||||
features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
NormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [DeviceProcessorStep(device="cpu"), IdentityProcessorStep()]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name="classifier_preprocessor",
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name="classifier_postprocessor",
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
@@ -53,21 +53,13 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from collections import deque
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.normalize import (
|
||||
Normalize,
|
||||
Unnormalize,
|
||||
)
|
||||
from lerobot.constants import ACTION, OBS_LANGUAGE, OBS_STATE
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
@@ -76,102 +68,6 @@ from lerobot.policies.utils import (
|
||||
)
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
|
||||
# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker
|
||||
_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_")
|
||||
|
||||
|
||||
def canonicalise(k: str) -> str:
|
||||
"""
|
||||
Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a
|
||||
normalisation-buffer key.
|
||||
"""
|
||||
return _VARIANT_RE.sub(".buffer_", k)
|
||||
|
||||
|
||||
def standardise_state_dict(
|
||||
checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True
|
||||
) -> tuple[dict[str, torch.Tensor], list[str]]:
|
||||
"""
|
||||
• Re-keys `checkpoint ` so that every entry matches the *reference* key set.
|
||||
• If several variant keys collapse to the same canonical name we keep the
|
||||
first one and log the collision.
|
||||
• Returns the new dict + a list of entries that could not be matched.
|
||||
"""
|
||||
out, collisions, unmatched = {}, {}, []
|
||||
|
||||
for k, v in checkpoint.items():
|
||||
canon = canonicalise(k)
|
||||
if canon in ref_keys:
|
||||
if canon in out: # duplicate after collapsing
|
||||
collisions.setdefault(canon, []).append(k)
|
||||
else:
|
||||
out[canon] = v
|
||||
else:
|
||||
unmatched.append(k)
|
||||
|
||||
if verbose:
|
||||
for canon, variants in collisions.items():
|
||||
print(f"[standardise_state_dict] '{canon}' ← {variants}")
|
||||
if unmatched:
|
||||
print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys")
|
||||
|
||||
out.update({k: checkpoint[k] for k in unmatched})
|
||||
return out, unmatched
|
||||
|
||||
|
||||
def rename_checkpoint_keys(checkpoint: dict, rename_str: str):
|
||||
"""
|
||||
Renames keys in a checkpoint dictionary based on the given rename string.
|
||||
|
||||
Args:
|
||||
checkpoint (dict): The checkpoint dictionary.
|
||||
rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2".
|
||||
|
||||
Returns:
|
||||
dict: The modified checkpoint with renamed keys.
|
||||
"""
|
||||
|
||||
rename_dict = dict(pair.split("//") for pair in rename_str.split(","))
|
||||
|
||||
new_checkpoint = {}
|
||||
for k, v in checkpoint.items():
|
||||
for old_key, new_key in rename_dict.items():
|
||||
if old_key in k:
|
||||
k = k.replace(old_key, new_key)
|
||||
new_checkpoint[k] = v
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def load_smolvla(
|
||||
model: torch.nn.Module,
|
||||
filename: str | os.PathLike,
|
||||
*,
|
||||
device: str = "cpu",
|
||||
checkpoint_keys_mapping: str = "",
|
||||
) -> torch.nn.Module:
|
||||
state_dict = safetensors.torch.load_file(filename, device=device)
|
||||
|
||||
# Optional user-supplied renames (e.g. "model._orig_mod.//model.")
|
||||
if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping:
|
||||
state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping)
|
||||
|
||||
state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys()))
|
||||
|
||||
# HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset
|
||||
norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs")
|
||||
state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)}
|
||||
|
||||
missing, unexpected = model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if not all(key.startswith(norm_keys) for key in missing) or unexpected:
|
||||
raise RuntimeError(
|
||||
"SmolVLA %d missing / %d unexpected keys",
|
||||
len(missing),
|
||||
len(unexpected),
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
@@ -326,28 +222,17 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: SmolVLAConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
|
||||
self.model = VLAFlowMatching(config)
|
||||
self.reset()
|
||||
|
||||
@@ -357,23 +242,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
# HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues
|
||||
@classmethod
|
||||
def _load_as_safetensor(
|
||||
cls,
|
||||
model: "SmolVLAPolicy",
|
||||
model_file: str,
|
||||
map_location: str,
|
||||
strict: bool,
|
||||
):
|
||||
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
|
||||
return load_smolvla(
|
||||
model,
|
||||
model_file,
|
||||
device=map_location,
|
||||
checkpoint_keys_mapping="model._orig_mod.//model.",
|
||||
)
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@@ -389,7 +257,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
|
||||
|
||||
@@ -397,8 +266,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
@@ -408,8 +275,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
return batch
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -450,11 +315,11 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE}.tokens"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE}.attention_mask"]
|
||||
actions = self.prepare_action(batch)
|
||||
actions_is_pad = batch.get("actions_id_pad")
|
||||
loss_dict = {}
|
||||
@@ -518,30 +383,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
img_masks.append(mask)
|
||||
return images, img_masks
|
||||
|
||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
||||
"""Tokenize the text input"""
|
||||
device = batch[OBS_STATE].device
|
||||
tasks = batch["task"]
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks]
|
||||
|
||||
if len(tasks) == 1:
|
||||
tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])]
|
||||
|
||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||
|
||||
tokenized_prompt = self.language_tokenizer.__call__(
|
||||
tasks,
|
||||
padding=self.config.pad_language_to,
|
||||
padding_side="right",
|
||||
max_length=self.config.tokenizer_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
||||
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
def _pi_aloha_decode_state(self, state):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
|
||||
111
src/lerobot/policies/smolvla/processor_smolvla.py
Normal file
111
src/lerobot/policies/smolvla/processor_smolvla.py
Normal file
@@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
ProcessorStepRegistry,
|
||||
RenameProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def make_smolvla_pre_post_processors(
|
||||
config: SmolVLAConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
SmolVLANewLineProcessor(),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name=config.vlm_model_name,
|
||||
padding=config.pad_language_to,
|
||||
padding_side="right",
|
||||
max_length=config.tokenizer_max_length,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
|
||||
class SmolVLANewLineProcessor(ComplementaryDataProcessorStep):
|
||||
"""Add a new line to the end of the task if it doesn't have one."""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
if "task" not in complementary_data:
|
||||
return complementary_data
|
||||
|
||||
task = complementary_data["task"]
|
||||
if task is None:
|
||||
return complementary_data
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
|
||||
# Handle both string and list of strings
|
||||
if isinstance(task, str):
|
||||
# Single string: add newline if not present
|
||||
if not task.endswith("\n"):
|
||||
new_complementary_data["task"] = f"{task}\n"
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
# List of strings: add newline to each if not present
|
||||
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
|
||||
# If task is neither string nor list of strings, leave unchanged
|
||||
|
||||
return new_complementary_data
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
@@ -36,7 +36,6 @@ import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
@@ -63,26 +62,19 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
config_class = TDMPCConfig
|
||||
name = "tdmpc"
|
||||
|
||||
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: TDMPCConfig,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.model = TDMPCTOLD(config)
|
||||
self.model_target = deepcopy(self.model)
|
||||
for param in self.model_target.parameters():
|
||||
@@ -137,7 +129,6 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
|
||||
actions = torch.clamp(actions, -1, +1)
|
||||
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -147,11 +138,12 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
|
||||
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -320,11 +312,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
info = {}
|
||||
|
||||
|
||||
70
src/lerobot/policies/tdmpc/processor_tdmpc.py
Normal file
70
src/lerobot/policies/tdmpc/processor_tdmpc.py
Normal file
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def make_tdmpc_pre_post_processors(
|
||||
config: TDMPCConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessorStep(rename_map={}),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
@@ -28,7 +28,6 @@ import torchvision
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
@@ -48,7 +47,6 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
config: VQBeTConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -61,14 +59,6 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.vqbet = VQBeTModel(config)
|
||||
|
||||
self.reset()
|
||||
@@ -128,7 +118,6 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
|
||||
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -142,10 +131,12 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
# NOTE: It's important that this happens after stacking the images into a single key.
|
||||
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||
if ACTION in batch:
|
||||
batch.pop(ACTION)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
@@ -165,10 +156,8 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
batch = self.normalize_targets(batch)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181)
|
||||
if not self.vqbet.action_head.vqvae_model.discretized.item():
|
||||
# loss: total loss of training RVQ
|
||||
|
||||
71
src/lerobot/policies/vqbet/processor_vqbet.py
Normal file
71
src/lerobot/policies/vqbet/processor_vqbet.py
Normal file
@@ -0,0 +1,71 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
|
||||
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
RenameProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def make_vqbet_pre_post_processors(
|
||||
config: VQBeTConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
preprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
postprocessor_kwargs: ProcessorKwargs | None = None,
|
||||
) -> tuple[PolicyProcessorPipeline, PolicyProcessorPipeline]:
|
||||
if preprocessor_kwargs is None:
|
||||
preprocessor_kwargs = {}
|
||||
if postprocessor_kwargs is None:
|
||||
postprocessor_kwargs = {}
|
||||
|
||||
input_steps = [
|
||||
RenameProcessorStep(rename_map={}), # Let the possibility to the user to rename the keys
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline(
|
||||
steps=input_steps,
|
||||
name=PREPROCESSOR_DEFAULT_NAME,
|
||||
**preprocessor_kwargs,
|
||||
),
|
||||
PolicyProcessorPipeline(
|
||||
steps=output_steps,
|
||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
||||
**postprocessor_kwargs,
|
||||
),
|
||||
)
|
||||
@@ -14,41 +14,90 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .device_processor import DeviceProcessor
|
||||
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
||||
from .observation_processor import VanillaObservationProcessor
|
||||
from .batch_processor import AddBatchDimensionProcessorStep
|
||||
from .converters import (
|
||||
batch_to_transition,
|
||||
create_transition,
|
||||
merge_transitions,
|
||||
transition_to_batch,
|
||||
transition_to_dataset_frame,
|
||||
)
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
||||
from .device_processor import DeviceProcessorStep
|
||||
from .gym_action_processor import Numpy2TorchActionProcessorStep, Torch2NumpyActionProcessorStep
|
||||
from .hil_processor import (
|
||||
AddTeleopActionAsComplimentaryDataStep,
|
||||
AddTeleopEventsAsInfoStep,
|
||||
GripperPenaltyProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
RewardClassifierProcessorStep,
|
||||
TimeLimitProcessorStep,
|
||||
)
|
||||
from .joint_observations_processor import JointVelocityProcessorStep, MotorCurrentProcessorStep
|
||||
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats
|
||||
from .observation_processor import VanillaObservationProcessorStep
|
||||
from .pipeline import (
|
||||
ActionProcessor,
|
||||
DoneProcessor,
|
||||
EnvTransition,
|
||||
IdentityProcessor,
|
||||
InfoProcessor,
|
||||
ObservationProcessor,
|
||||
ActionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DataProcessorPipeline,
|
||||
DoneProcessorStep,
|
||||
IdentityProcessorStep,
|
||||
InfoProcessorStep,
|
||||
ObservationProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorKwargs,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RewardProcessor,
|
||||
RobotProcessor,
|
||||
TransitionKey,
|
||||
TruncatedProcessor,
|
||||
RewardProcessorStep,
|
||||
RobotProcessorPipeline,
|
||||
TruncatedProcessorStep,
|
||||
)
|
||||
from .rename_processor import RenameProcessor
|
||||
from .rename_processor import RenameProcessorStep
|
||||
from .tokenizer_processor import TokenizerProcessorStep
|
||||
|
||||
__all__ = [
|
||||
"ActionProcessor",
|
||||
"DeviceProcessor",
|
||||
"DoneProcessor",
|
||||
"ActionProcessorStep",
|
||||
"AddTeleopActionAsComplimentaryDataStep",
|
||||
"AddTeleopEventsAsInfoStep",
|
||||
"ComplementaryDataProcessorStep",
|
||||
"batch_to_transition",
|
||||
"create_transition",
|
||||
"DeviceProcessorStep",
|
||||
"DoneProcessorStep",
|
||||
"EnvTransition",
|
||||
"IdentityProcessor",
|
||||
"InfoProcessor",
|
||||
"NormalizerProcessor",
|
||||
"UnnormalizerProcessor",
|
||||
"ObservationProcessor",
|
||||
"GripperPenaltyProcessorStep",
|
||||
"hotswap_stats",
|
||||
"IdentityProcessorStep",
|
||||
"ImageCropResizeProcessorStep",
|
||||
"InfoProcessorStep",
|
||||
"InterventionActionProcessorStep",
|
||||
"JointVelocityProcessorStep",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"merge_transitions",
|
||||
"MotorCurrentProcessorStep",
|
||||
"NormalizerProcessorStep",
|
||||
"Numpy2TorchActionProcessorStep",
|
||||
"ObservationProcessorStep",
|
||||
"PolicyProcessorPipeline",
|
||||
"ProcessorKwargs",
|
||||
"ProcessorStep",
|
||||
"ProcessorStepRegistry",
|
||||
"RenameProcessor",
|
||||
"RewardProcessor",
|
||||
"RobotProcessor",
|
||||
"RenameProcessorStep",
|
||||
"RewardClassifierProcessorStep",
|
||||
"RewardProcessorStep",
|
||||
"DataProcessorPipeline",
|
||||
"TimeLimitProcessorStep",
|
||||
"AddBatchDimensionProcessorStep",
|
||||
"RobotProcessorPipeline",
|
||||
"TokenizerProcessorStep",
|
||||
"Torch2NumpyActionProcessorStep",
|
||||
"transition_to_batch",
|
||||
"transition_to_dataset_frame",
|
||||
"TransitionKey",
|
||||
"TruncatedProcessor",
|
||||
"VanillaObservationProcessor",
|
||||
"TruncatedProcessorStep",
|
||||
"UnnormalizerProcessorStep",
|
||||
"VanillaObservationProcessorStep",
|
||||
]
|
||||
|
||||
159
src/lerobot/processor/batch_processor.py
Normal file
159
src/lerobot/processor/batch_processor.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
from .core import EnvTransition
|
||||
from .pipeline import (
|
||||
ActionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
ObservationProcessorStep,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_action")
|
||||
class AddBatchDimensionActionStep(ActionProcessorStep):
|
||||
"""Process action component in-place, adding batch dimension if needed."""
|
||||
|
||||
def action(self, action):
|
||||
if not isinstance(action, Tensor) or action.dim() != 1:
|
||||
return action
|
||||
|
||||
return action.unsqueeze(0)
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_observation")
|
||||
class AddBatchDimensionObservationStep(ObservationProcessorStep):
|
||||
"""Process observation component in-place, adding batch dimensions where needed."""
|
||||
|
||||
def observation(self, observation):
|
||||
# Process state observations - add batch dim if 1D
|
||||
for state_key in [OBS_STATE, OBS_ENV_STATE]:
|
||||
if state_key in observation:
|
||||
state_value = observation[state_key]
|
||||
if isinstance(state_value, Tensor) and state_value.dim() == 1:
|
||||
observation[state_key] = state_value.unsqueeze(0)
|
||||
|
||||
# Process single image observation - add batch dim if 3D
|
||||
if OBS_IMAGE in observation:
|
||||
image_value = observation[OBS_IMAGE]
|
||||
if isinstance(image_value, Tensor) and image_value.dim() == 3:
|
||||
observation[OBS_IMAGE] = image_value.unsqueeze(0)
|
||||
|
||||
# Process multiple image observations - add batch dim if 3D
|
||||
for key, value in observation.items():
|
||||
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
|
||||
observation[key] = value.unsqueeze(0)
|
||||
return observation
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data")
|
||||
class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
|
||||
"""Process complementary data in-place, handling task field batching."""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
# Process task field - wrap string in list to add batch dimension
|
||||
if "task" in complementary_data:
|
||||
task_value = complementary_data["task"]
|
||||
if isinstance(task_value, str):
|
||||
complementary_data["task"] = [task_value]
|
||||
|
||||
# Process index field - add batch dim if 0D
|
||||
if "index" in complementary_data:
|
||||
index_value = complementary_data["index"]
|
||||
if isinstance(index_value, Tensor) and index_value.dim() == 0:
|
||||
complementary_data["index"] = index_value.unsqueeze(0)
|
||||
|
||||
# Process task_index field - add batch dim if 0D
|
||||
if "task_index" in complementary_data:
|
||||
task_index_value = complementary_data["task_index"]
|
||||
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
|
||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||
return complementary_data
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="to_batch_processor")
|
||||
class AddBatchDimensionProcessorStep(ProcessorStep):
|
||||
"""Processor that adds batch dimensions to observations and actions when needed.
|
||||
|
||||
This processor ensures that observations and actions have proper batch dimensions for model processing:
|
||||
|
||||
- For state observations (observation.state, observation.environment_state):
|
||||
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
|
||||
|
||||
- For image observations (observation.image, observation.images.*):
|
||||
Adds batch dimension (unsqueeze at dim=0) if tensor is 3-dimensional (H, W, C)
|
||||
|
||||
- For actions:
|
||||
Adds batch dimension (unsqueeze at dim=0) if tensor is 1-dimensional
|
||||
|
||||
- For task field in complementary data:
|
||||
Wraps string task in a list to add batch dimension
|
||||
(task must be a string or list of strings)
|
||||
|
||||
This is useful when processing single transitions that need to be batched for
|
||||
model inference or when converting from unbatched environment outputs to
|
||||
batched model inputs.
|
||||
|
||||
The processor only modifies tensors that need batching and leaves already
|
||||
batched tensors unchanged.
|
||||
|
||||
Example:
|
||||
```python
|
||||
# State: (7,) -> (1, 7)
|
||||
# Image: (224, 224, 3) -> (1, 224, 224, 3)
|
||||
# Action: (4,) -> (1, 4)
|
||||
# Task: "pick_cube" -> ["pick_cube"]
|
||||
# Already batched: (1, 7) -> (1, 7) [unchanged]
|
||||
```
|
||||
"""
|
||||
|
||||
to_batch_action_processor: AddBatchDimensionActionStep = field(
|
||||
default_factory=AddBatchDimensionActionStep
|
||||
)
|
||||
to_batch_observation_processor: AddBatchDimensionObservationStep = field(
|
||||
default_factory=AddBatchDimensionObservationStep
|
||||
)
|
||||
to_batch_complementary_data_processor: AddBatchDimensionComplementaryDataStep = field(
|
||||
default_factory=AddBatchDimensionComplementaryDataStep
|
||||
)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = self.to_batch_action_processor(transition)
|
||||
transition = self.to_batch_observation_processor(transition)
|
||||
transition = self.to_batch_complementary_data_processor(transition)
|
||||
return transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# NOTE: We ignore the batch dimension when transforming features
|
||||
return features
|
||||
481
src/lerobot/processor/converters.py
Normal file
481
src/lerobot/processor/converters.py
Normal file
@@ -0,0 +1,481 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from functools import singledispatch
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
|
||||
|
||||
@singledispatch
|
||||
def to_tensor(
|
||||
value: Any,
|
||||
*,
|
||||
dtype: torch.dtype | None = torch.float32,
|
||||
device: torch.device | str | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Convert various data types to PyTorch tensors with configurable options.
|
||||
|
||||
This is a unified tensor conversion function using single dispatch to handle
|
||||
different input types appropriately.
|
||||
|
||||
Args:
|
||||
value: Input value to convert (tensor, array, scalar, sequence, etc.)
|
||||
dtype: Target tensor dtype. If None, preserves original dtype.
|
||||
device: Target device for the tensor.
|
||||
|
||||
Returns:
|
||||
PyTorch tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If the input type is not supported.
|
||||
"""
|
||||
raise TypeError(f"Unsupported type for tensor conversion: {type(value)}")
|
||||
|
||||
|
||||
@to_tensor.register(torch.Tensor)
|
||||
def _(value: torch.Tensor, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
|
||||
"""Handle existing PyTorch tensors."""
|
||||
if dtype is not None:
|
||||
value = value.to(dtype=dtype)
|
||||
if device is not None:
|
||||
value = value.to(device=device)
|
||||
return value
|
||||
|
||||
|
||||
@to_tensor.register(np.ndarray)
|
||||
def _(
|
||||
value: np.ndarray,
|
||||
*,
|
||||
dtype=torch.float32,
|
||||
device=None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Handle numpy arrays."""
|
||||
# Check for numpy scalars (0-dimensional arrays) and treat them as scalars
|
||||
if value.ndim == 0:
|
||||
# Numpy scalars should be converted to 0-dimensional tensors
|
||||
scalar_value = value.item()
|
||||
return torch.tensor(scalar_value, dtype=dtype, device=device)
|
||||
|
||||
# Create tensor from numpy array (torch.from_numpy handles contiguity automatically)
|
||||
tensor = torch.from_numpy(value)
|
||||
|
||||
# Apply dtype conversion if specified
|
||||
if dtype is not None:
|
||||
tensor = tensor.to(dtype=dtype)
|
||||
if device is not None:
|
||||
tensor = tensor.to(device=device)
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
@to_tensor.register(int)
|
||||
@to_tensor.register(float)
|
||||
@to_tensor.register(np.integer)
|
||||
@to_tensor.register(np.floating)
|
||||
def _(value, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
|
||||
"""Handle scalar values including numpy scalars."""
|
||||
return torch.tensor(value, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@to_tensor.register(list)
|
||||
@to_tensor.register(tuple)
|
||||
def _(value: Sequence, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
|
||||
"""Handle sequences (lists, tuples)."""
|
||||
return torch.tensor(value, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@to_tensor.register(dict)
|
||||
def _(value: dict, *, device=None, **kwargs) -> dict:
|
||||
"""Handle dictionaries by recursively converting values to tensors."""
|
||||
if not value:
|
||||
return {}
|
||||
|
||||
result = {}
|
||||
for key, sub_value in value.items():
|
||||
if sub_value is None:
|
||||
continue
|
||||
|
||||
if isinstance(sub_value, dict):
|
||||
# Recursively process nested dictionaries
|
||||
result[key] = to_tensor(
|
||||
sub_value,
|
||||
device=device,
|
||||
**kwargs,
|
||||
)
|
||||
continue
|
||||
|
||||
# Convert individual values to tensors
|
||||
result[key] = to_tensor(
|
||||
sub_value,
|
||||
device=device,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _from_tensor(x: torch.Tensor | Any) -> np.ndarray | float | int | Any:
|
||||
"""Convert tensor to numpy/scalar if needed."""
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x.item() if x.numel() == 1 else x.detach().cpu().numpy()
|
||||
return x
|
||||
|
||||
|
||||
def _is_image(arr: Any) -> bool:
|
||||
return isinstance(arr, np.ndarray) and arr.dtype == np.uint8 and arr.ndim == 3
|
||||
|
||||
|
||||
def _split_obs_to_state_and_images(obs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
state, images = {}, {}
|
||||
for k, v in obs.items():
|
||||
if "image" in k.lower() or _is_image(v):
|
||||
images[k] = v
|
||||
else:
|
||||
state[k] = v
|
||||
return state, images
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Private Helper Functions (Common Logic)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract complementary data (pad flags, task, index, task_index)."""
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key}
|
||||
|
||||
|
||||
def _merge_transitions(base: EnvTransition, other: EnvTransition) -> EnvTransition:
|
||||
"""Merge two transitions, with other taking precedence."""
|
||||
out = deepcopy(base)
|
||||
|
||||
for key in (
|
||||
TransitionKey.OBSERVATION,
|
||||
TransitionKey.ACTION,
|
||||
TransitionKey.INFO,
|
||||
TransitionKey.COMPLEMENTARY_DATA,
|
||||
):
|
||||
if other.get(key):
|
||||
out.setdefault(key, {}).update(deepcopy(other[key]))
|
||||
|
||||
for k in (TransitionKey.REWARD, TransitionKey.DONE, TransitionKey.TRUNCATED):
|
||||
if k in other:
|
||||
out[k] = other[k]
|
||||
return out
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Core Conversion Functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation: dict[str, Any] | None = None,
|
||||
action: dict[str, Any] | None = None,
|
||||
reward: float = 0.0,
|
||||
done: bool = False,
|
||||
truncated: bool = False,
|
||||
info: dict[str, Any] | None = None,
|
||||
complementary_data: dict[str, Any] | None = None,
|
||||
) -> EnvTransition:
|
||||
"""Create an EnvTransition with sensible defaults.
|
||||
|
||||
Args:
|
||||
observation: Observation dictionary.
|
||||
action: Action dictionary.
|
||||
reward: Scalar reward value.
|
||||
done: Episode termination flag.
|
||||
truncated: Episode truncation flag.
|
||||
info: Additional info dictionary.
|
||||
complementary_data: Complementary data dictionary.
|
||||
|
||||
Returns:
|
||||
Complete EnvTransition dictionary.
|
||||
"""
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: action,
|
||||
TransitionKey.REWARD: reward,
|
||||
TransitionKey.DONE: done,
|
||||
TransitionKey.TRUNCATED: truncated,
|
||||
TransitionKey.INFO: info if info is not None else {},
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {},
|
||||
}
|
||||
|
||||
|
||||
def action_to_transition(action: dict[str, Any]) -> EnvTransition: # action_to_transition
|
||||
"""
|
||||
Convert a raw teleop action dict into an EnvTransition under the ACTION TransitionKey.
|
||||
"""
|
||||
act_dict: dict[str, Any] = {}
|
||||
for k, v in action.items():
|
||||
# Check if the value is a type that should not be converted to a tensor.
|
||||
if isinstance(v, (Rotation, dict)):
|
||||
act_dict[f"{ACTION}.{k}"] = v
|
||||
continue
|
||||
|
||||
arr = np.array(v) if np.isscalar(v) else v
|
||||
act_dict[f"{ACTION}.{k}"] = to_tensor(arr)
|
||||
|
||||
return create_transition(observation={}, action=act_dict)
|
||||
|
||||
|
||||
# TODO(Adil, Pepijn): Overtime we can maybe add these converters to pipeline.py itself
|
||||
def observation_to_transition(observation: dict[str, Any]) -> EnvTransition:
|
||||
"""
|
||||
Convert a raw robot observation dict into an EnvTransition under the OBSERVATION TransitionKey.
|
||||
"""
|
||||
state, images = _split_obs_to_state_and_images(observation)
|
||||
|
||||
obs_dict: dict[str, Any] = {}
|
||||
for k, v in state.items():
|
||||
arr = np.array(v) if np.isscalar(v) else v
|
||||
obs_dict[f"{OBS_STATE}.{k}"] = to_tensor(arr)
|
||||
|
||||
for cam, img in images.items():
|
||||
obs_dict[f"{OBS_IMAGES}.{cam}"] = img
|
||||
|
||||
return create_transition(observation=obs_dict, action={})
|
||||
|
||||
|
||||
def transition_to_robot_action(transition: EnvTransition) -> dict[str, Any]:
|
||||
"""
|
||||
Converts a EnvTransition under the ACTION TransitionKey to a dict with keys ending in '.pos' for raw robot actions.
|
||||
"""
|
||||
out: dict[str, Any] = {}
|
||||
action_dict = transition.get(TransitionKey.ACTION) or {}
|
||||
|
||||
if action_dict is None:
|
||||
return out
|
||||
|
||||
for k, v in action_dict.items():
|
||||
if isinstance(k, str) and k.startswith(f"{ACTION}.") and k.endswith((".pos", ".vel")):
|
||||
out_key = k[len(f"{ACTION}.") :] # Strip the 'action.' prefix.
|
||||
out[out_key] = float(v)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def merge_transitions(transitions: Sequence[EnvTransition] | EnvTransition) -> EnvTransition:
|
||||
"""Merge multiple transitions or return single transition.
|
||||
|
||||
Args:
|
||||
transitions: Either a single transition or iterable of transitions.
|
||||
|
||||
Returns:
|
||||
Merged EnvTransition.
|
||||
"""
|
||||
|
||||
if not isinstance(transitions, Sequence): # Single transition
|
||||
return transitions
|
||||
|
||||
items = list(transitions)
|
||||
if not items:
|
||||
raise ValueError("merge_transitions() requires a non-empty sequence of transitions")
|
||||
|
||||
result = items[0]
|
||||
for t in items[1:]:
|
||||
result = _merge_transitions(result, t)
|
||||
return result
|
||||
|
||||
|
||||
def transition_to_dataset_frame(
|
||||
transitions_or_transition: EnvTransition | Sequence[EnvTransition], features: dict[str, dict]
|
||||
) -> dict[str, Any]:
|
||||
"""Convert a single EnvTransition or an iterable of them into a flat, dataset-friendly dictionary for training or evaluation.
|
||||
|
||||
Processes transitions according to the provided feature specification and returns
|
||||
data in the format expected by machine learning models and datasets.
|
||||
|
||||
Args:
|
||||
transitions_or_transition: Either a single EnvTransition dict or an iterable of them
|
||||
(which will be merged using merge_transitions).
|
||||
features: Feature specification dictionary with the following structure:
|
||||
- 'action': dict with 'names': list of action feature names
|
||||
- 'observation.state': dict with 'names': list of state feature names
|
||||
- keys starting with 'observation.images.' are passed through as-is
|
||||
|
||||
Returns:
|
||||
Flat dictionary containing:
|
||||
- numpy arrays for "observation.state" and "action" (vectorized from feature names)
|
||||
- any image tensors defined in features (passed through unchanged)
|
||||
- next.{reward,done,truncated} scalar values
|
||||
- info dict
|
||||
- *_is_pad flags and task from complementary_data
|
||||
"""
|
||||
action_names = features.get(ACTION, {}).get("names", [])
|
||||
obs_state_names = features.get(OBS_STATE, {}).get("names", [])
|
||||
image_keys = [k for k in features if k.startswith(OBS_IMAGES)]
|
||||
|
||||
tr = merge_transitions(transitions_or_transition)
|
||||
obs = tr.get(TransitionKey.OBSERVATION, {}) or {}
|
||||
act = tr.get(TransitionKey.ACTION, {}) or {}
|
||||
batch: dict[str, Any] = {}
|
||||
|
||||
# Images passthrough
|
||||
for k in image_keys:
|
||||
if k in obs:
|
||||
batch[k] = obs[k]
|
||||
|
||||
# Observation.state vector
|
||||
if obs_state_names:
|
||||
vals = [_from_tensor(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
|
||||
batch[OBS_STATE] = np.asarray(vals, dtype=np.float32)
|
||||
|
||||
# Action vector
|
||||
if action_names:
|
||||
vals = [_from_tensor(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
|
||||
batch[ACTION] = np.asarray(vals, dtype=np.float32)
|
||||
|
||||
# Add transition metadata
|
||||
if tr.get(TransitionKey.REWARD) is not None:
|
||||
reward_val = _from_tensor(tr[TransitionKey.REWARD])
|
||||
# Check if features expect array format, otherwise keep as scalar
|
||||
if REWARD in features and features[REWARD].get("shape") == (1,):
|
||||
batch[REWARD] = np.array([reward_val], dtype=np.float32)
|
||||
else:
|
||||
batch[REWARD] = reward_val
|
||||
|
||||
if tr.get(TransitionKey.DONE) is not None:
|
||||
done_val = _from_tensor(tr[TransitionKey.DONE])
|
||||
if DONE in features and features[DONE].get("shape") == (1,):
|
||||
batch[DONE] = np.array([done_val], dtype=bool)
|
||||
else:
|
||||
batch[DONE] = done_val
|
||||
|
||||
if tr.get(TransitionKey.TRUNCATED) is not None:
|
||||
truncated_val = _from_tensor(tr[TransitionKey.TRUNCATED])
|
||||
if TRUNCATED in features and features[TRUNCATED].get("shape") == (1,):
|
||||
batch[TRUNCATED] = np.array([truncated_val], dtype=bool)
|
||||
else:
|
||||
batch[TRUNCATED] = truncated_val
|
||||
|
||||
# Complementary data flags and task
|
||||
comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
if comp:
|
||||
# pad flags
|
||||
for k, v in comp.items():
|
||||
if k.endswith("_is_pad"):
|
||||
batch[k] = v
|
||||
# task label
|
||||
if comp.get("task") is not None:
|
||||
batch["task"] = comp["task"]
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
|
||||
"""Convert a batch dict coming from LeRobot replay/dataset code into an EnvTransition dictionary.
|
||||
|
||||
The function maps well known keys to the EnvTransition structure. Missing keys are
|
||||
filled with sane defaults (None or 0.0/False).
|
||||
|
||||
Keys recognised (case-sensitive):
|
||||
* "observation.*" (keys starting with "observation." are grouped into observation dict)
|
||||
* "action"
|
||||
* "next.reward"
|
||||
* "next.done"
|
||||
* "next.truncated"
|
||||
* "info"
|
||||
* "_is_pad" patterns (padding flags)
|
||||
* "task", "index", "task_index" (complementary data)
|
||||
|
||||
Additional keys are ignored so that existing dataloaders can carry extra
|
||||
metadata without breaking the processor.
|
||||
|
||||
Args:
|
||||
batch: Batch dictionary from datasets or dataloaders containing the above keys.
|
||||
|
||||
Returns:
|
||||
EnvTransition dictionary with properly structured transition data.
|
||||
"""
|
||||
|
||||
# Validate input type
|
||||
if not isinstance(batch, dict):
|
||||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
|
||||
|
||||
# Extract observation keys
|
||||
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
complementary_data = _extract_complementary_data(batch)
|
||||
|
||||
return create_transition(
|
||||
observation=observation_keys if observation_keys else None,
|
||||
action=batch.get("action"),
|
||||
reward=batch.get("next.reward", 0.0),
|
||||
done=batch.get("next.done", False),
|
||||
truncated=batch.get("next.truncated", False),
|
||||
info=batch.get("info", {}),
|
||||
complementary_data=complementary_data if complementary_data else None,
|
||||
)
|
||||
|
||||
|
||||
def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
"""Inverse of batch_to_transition. Returns a dict with canonical field names used throughout LeRobot.
|
||||
|
||||
Converts an EnvTransition back to the batch format expected by datasets, dataloaders,
|
||||
and other LeRobot components.
|
||||
|
||||
Output format:
|
||||
* "action": Action data from transition
|
||||
* "next.reward": Reward value (defaults to 0.0)
|
||||
* "next.done": Done flag (defaults to False)
|
||||
* "next.truncated": Truncated flag (defaults to False)
|
||||
* "info": Info dictionary (defaults to {})
|
||||
* Flattened observation keys (e.g., "observation.state", "observation.images.cam1")
|
||||
* Complementary data fields ("task", "index", "task_index", padding flags)
|
||||
|
||||
Args:
|
||||
transition: EnvTransition dictionary to convert.
|
||||
|
||||
Returns:
|
||||
Batch dictionary with canonical LeRobot field names suitable for dataloaders.
|
||||
"""
|
||||
batch = {
|
||||
"action": transition.get(TransitionKey.ACTION),
|
||||
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
|
||||
"next.done": transition.get(TransitionKey.DONE, False),
|
||||
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
|
||||
"info": transition.get(TransitionKey.INFO, {}),
|
||||
}
|
||||
|
||||
# Add complementary data
|
||||
comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
if comp_data:
|
||||
batch.update(comp_data)
|
||||
|
||||
# Flatten observation dict
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if isinstance(observation, dict):
|
||||
batch.update(observation)
|
||||
|
||||
return batch
|
||||
49
src/lerobot/processor/core.py
Normal file
49
src/lerobot/processor/core.py
Normal file
@@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class TransitionKey(str, Enum):
|
||||
"""Keys for accessing EnvTransition dictionary components."""
|
||||
|
||||
# TODO(Steven): Use consts
|
||||
OBSERVATION = "observation"
|
||||
ACTION = "action"
|
||||
REWARD = "reward"
|
||||
DONE = "done"
|
||||
TRUNCATED = "truncated"
|
||||
INFO = "info"
|
||||
COMPLEMENTARY_DATA = "complementary_data"
|
||||
|
||||
|
||||
EnvTransition = TypedDict(
|
||||
"EnvTransition",
|
||||
{
|
||||
TransitionKey.OBSERVATION.value: dict[str, Any] | None,
|
||||
TransitionKey.ACTION.value: Any | torch.Tensor | None,
|
||||
TransitionKey.REWARD.value: float | torch.Tensor | None,
|
||||
TransitionKey.DONE.value: bool | torch.Tensor | None,
|
||||
TransitionKey.TRUNCATED.value: bool | torch.Tensor | None,
|
||||
TransitionKey.INFO.value: dict[str, Any] | None,
|
||||
TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None,
|
||||
},
|
||||
)
|
||||
145
src/lerobot/processor/delta_action_processor.py
Normal file
145
src/lerobot/processor/delta_action_processor.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION
|
||||
|
||||
from .pipeline import ActionProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
|
||||
@dataclass
|
||||
class MapTensorToDeltaActionDictStep(ActionProcessorStep):
|
||||
"""
|
||||
Map a tensor to a delta action dictionary.
|
||||
"""
|
||||
|
||||
use_gripper: bool = True
|
||||
|
||||
def action(self, action: Tensor) -> dict:
|
||||
if action.dim() > 1:
|
||||
action = action.squeeze(0)
|
||||
|
||||
# TODO (maractingi): add rotation
|
||||
delta_action = {
|
||||
f"{ACTION}.delta_x": action[0],
|
||||
f"{ACTION}.delta_y": action[1],
|
||||
f"{ACTION}.delta_z": action[2],
|
||||
}
|
||||
if self.use_gripper:
|
||||
delta_action[f"{ACTION}.gripper"] = action[3]
|
||||
return delta_action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features[f"{ACTION}.delta_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.delta_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.delta_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
if self.use_gripper:
|
||||
features[f"{ACTION}.gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_delta_action_to_robot_action")
|
||||
@dataclass
|
||||
class MapDeltaActionToRobotActionStep(ActionProcessorStep):
|
||||
"""
|
||||
Map delta actions from teleoperators (gamepad, keyboard) to robot target actions
|
||||
for use with inverse kinematics processors.
|
||||
|
||||
Expected input ACTION keys:
|
||||
{
|
||||
"action.delta_x": float,
|
||||
"action.delta_y": float,
|
||||
"action.delta_z": float,
|
||||
"action.gripper": float (optional),
|
||||
}
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.enabled": bool,
|
||||
"action.target_x": float,
|
||||
"action.target_y": float,
|
||||
"action.target_z": float,
|
||||
"action.target_wx": float,
|
||||
"action.target_wy": float,
|
||||
"action.target_wz": float,
|
||||
"action.gripper": float,
|
||||
}
|
||||
"""
|
||||
|
||||
# Scale factors for delta movements
|
||||
position_scale: float = 1.0
|
||||
rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard
|
||||
noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise
|
||||
|
||||
def action(self, action: dict) -> dict:
|
||||
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
|
||||
# TODO (maractingi): changing this target_xyz naming convention from the teleop_devices
|
||||
delta_x = action.pop(f"{ACTION}.delta_x", 0.0)
|
||||
delta_y = action.pop(f"{ACTION}.delta_y", 0.0)
|
||||
delta_z = action.pop(f"{ACTION}.delta_z", 0.0)
|
||||
gripper = action.pop(f"{ACTION}.gripper", 1.0) # Default to "stay" (1.0)
|
||||
|
||||
# Determine if the teleoperator is actively providing input
|
||||
# Consider enabled if any significant movement delta is detected
|
||||
position_magnitude = (delta_x**2 + delta_y**2 + delta_z**2) ** 0.5 # Use Euclidean norm for position
|
||||
enabled = position_magnitude > self.noise_threshold # Small threshold to avoid noise
|
||||
|
||||
# Scale the deltas appropriately
|
||||
scaled_delta_x = delta_x * self.position_scale
|
||||
scaled_delta_y = delta_y * self.position_scale
|
||||
scaled_delta_z = delta_z * self.position_scale
|
||||
|
||||
# For gamepad/keyboard, we don't have rotation input, so set to 0
|
||||
# These could be extended in the future for more sophisticated teleoperators
|
||||
target_wx = 0.0
|
||||
target_wy = 0.0
|
||||
target_wz = 0.0
|
||||
|
||||
# Update action with robot target format
|
||||
action = {
|
||||
f"{ACTION}.enabled": enabled,
|
||||
f"{ACTION}.target_x": scaled_delta_x,
|
||||
f"{ACTION}.target_y": scaled_delta_y,
|
||||
f"{ACTION}.target_z": scaled_delta_z,
|
||||
f"{ACTION}.target_wx": target_wx,
|
||||
f"{ACTION}.target_wy": target_wy,
|
||||
f"{ACTION}.target_wz": target_wz,
|
||||
f"{ACTION}.gripper": float(gripper),
|
||||
}
|
||||
|
||||
return action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transform features to match output format."""
|
||||
features.pop(f"{ACTION}.delta_x", None)
|
||||
features.pop(f"{ACTION}.delta_y", None)
|
||||
features.pop(f"{ACTION}.delta_z", None)
|
||||
features.pop(f"{ACTION}.gripper", None)
|
||||
|
||||
features[f"{ACTION}.enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
return features
|
||||
@@ -19,64 +19,116 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||
from lerobot.utils.utils import get_safe_torch_device
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("device_processor")
|
||||
@dataclass
|
||||
class DeviceProcessor:
|
||||
"""Processes transitions by moving tensors to the specified device.
|
||||
class DeviceProcessorStep(ProcessorStep):
|
||||
"""Processes transitions by moving tensors to the specified device and optionally converting float dtypes.
|
||||
|
||||
This processor ensures that all tensors in the transition are moved to the
|
||||
specified device (CPU or GPU) before they are returned.
|
||||
specified device (CPU or GPU) before they are returned. It can also convert
|
||||
floating-point tensors to a specified dtype while preserving non-float types
|
||||
(int, long, bool, etc.).
|
||||
"""
|
||||
|
||||
device: torch.device = "cpu"
|
||||
device: str = "cpu"
|
||||
float_dtype: str | None = None
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
"float64": torch.float64,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"half": torch.float16,
|
||||
"float": torch.float32,
|
||||
"double": torch.float64,
|
||||
}
|
||||
|
||||
def __post_init__(self):
|
||||
self.device = get_safe_torch_device(self.device)
|
||||
self.tensor_device: torch.device = get_safe_torch_device(self.device)
|
||||
self.device = self.tensor_device.type # cuda might have changed to cuda:1
|
||||
self.non_blocking = "cuda" in str(self.device)
|
||||
|
||||
# Validate and convert float_dtype string to torch dtype
|
||||
if self.float_dtype is not None:
|
||||
if self.float_dtype not in self.DTYPE_MAPPING:
|
||||
raise ValueError(
|
||||
f"Invalid float_dtype '{self.float_dtype}'. Available options: {list(self.DTYPE_MAPPING.keys())}"
|
||||
)
|
||||
|
||||
self._target_float_dtype = self.DTYPE_MAPPING[self.float_dtype]
|
||||
else:
|
||||
self._target_float_dtype = None
|
||||
|
||||
def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Process a tensor by moving to device and optionally converting float dtype.
|
||||
|
||||
If the tensor is already on a GPU and we're configured for a GPU, it preserves
|
||||
that GPU placement (useful for multi-GPU training with Accelerate).
|
||||
Otherwise, it moves to the configured device.
|
||||
"""
|
||||
# Determine target device
|
||||
if tensor.is_cuda and self.tensor_device.type == "cuda":
|
||||
# Both tensor and target are on GPU - preserve tensor's GPU placement
|
||||
# This handles multi-GPU scenarios where Accelerate has already placed
|
||||
# tensors on the correct GPU for each process
|
||||
target_device = tensor.device
|
||||
else:
|
||||
# Either tensor is on CPU, or we're configured for CPU
|
||||
# In both cases, use the configured device
|
||||
target_device = self.tensor_device
|
||||
|
||||
# Only move if necessary
|
||||
if tensor.device != target_device:
|
||||
tensor = tensor.to(target_device, non_blocking=self.non_blocking)
|
||||
|
||||
# Convert float dtype if specified and tensor is floating point
|
||||
if self._target_float_dtype is not None and tensor.is_floating_point():
|
||||
tensor = tensor.to(dtype=self._target_float_dtype)
|
||||
|
||||
return tensor
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Create a copy of the transition
|
||||
new_transition = transition.copy()
|
||||
|
||||
# Process observation tensors
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is not None:
|
||||
new_observation = {
|
||||
k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in observation.items()
|
||||
}
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
simple_tensor_keys = [
|
||||
TransitionKey.ACTION,
|
||||
TransitionKey.REWARD,
|
||||
TransitionKey.DONE,
|
||||
TransitionKey.TRUNCATED,
|
||||
]
|
||||
|
||||
# Process action tensor
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and isinstance(action, torch.Tensor):
|
||||
new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking)
|
||||
dict_tensor_keys = [
|
||||
TransitionKey.OBSERVATION,
|
||||
TransitionKey.COMPLEMENTARY_DATA,
|
||||
]
|
||||
|
||||
# Process reward tensor
|
||||
reward = transition.get(TransitionKey.REWARD)
|
||||
if reward is not None and isinstance(reward, torch.Tensor):
|
||||
new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking)
|
||||
# Process simple tensors
|
||||
for key in simple_tensor_keys:
|
||||
value = transition.get(key)
|
||||
if isinstance(value, torch.Tensor):
|
||||
new_transition[key] = self._process_tensor(value)
|
||||
|
||||
# Process done tensor
|
||||
done = transition.get(TransitionKey.DONE)
|
||||
if done is not None and isinstance(done, torch.Tensor):
|
||||
new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking)
|
||||
|
||||
# Process truncated tensor
|
||||
truncated = transition.get(TransitionKey.TRUNCATED)
|
||||
if truncated is not None and isinstance(truncated, torch.Tensor):
|
||||
new_transition[TransitionKey.TRUNCATED] = truncated.to(
|
||||
self.device, non_blocking=self.non_blocking
|
||||
)
|
||||
# Process dictionary-like tensors
|
||||
for key in dict_tensor_keys:
|
||||
data_dict = transition.get(key)
|
||||
if data_dict is not None:
|
||||
new_data_dict = {
|
||||
k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in data_dict.items()
|
||||
}
|
||||
new_transition[key] = new_data_dict
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {"device": self.device}
|
||||
return {"device": self.device, "float_dtype": self.float_dtype}
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
72
src/lerobot/processor/gym_action_processor.py
Normal file
72
src/lerobot/processor/gym_action_processor.py
Normal file
@@ -0,0 +1,72 @@
|
||||
#! /usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
|
||||
from .converters import to_tensor
|
||||
from .pipeline import ActionProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("torch2numpy_action_processor")
|
||||
@dataclass
|
||||
class Torch2NumpyActionProcessorStep(ActionProcessorStep):
|
||||
"""Convert PyTorch tensor actions to NumPy arrays."""
|
||||
|
||||
squeeze_batch_dim: bool = True
|
||||
|
||||
def action(self, action: torch.Tensor) -> np.ndarray:
|
||||
if not isinstance(action, torch.Tensor):
|
||||
raise TypeError(
|
||||
f"Expected torch.Tensor or None, got {type(action).__name__}. "
|
||||
"Use appropriate processor for non-tensor actions."
|
||||
)
|
||||
|
||||
numpy_action = action.detach().cpu().numpy()
|
||||
|
||||
# Remove batch dimensions but preserve action dimensions
|
||||
# Only squeeze if there's a batch dimension (first dim == 1)
|
||||
if (
|
||||
self.squeeze_batch_dim
|
||||
and numpy_action.shape
|
||||
and len(numpy_action.shape) > 1
|
||||
and numpy_action.shape[0] == 1
|
||||
):
|
||||
numpy_action = numpy_action.squeeze(0)
|
||||
|
||||
return numpy_action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("numpy2torch_action_processor")
|
||||
@dataclass
|
||||
class Numpy2TorchActionProcessorStep(ActionProcessorStep):
|
||||
"""Convert NumPy array action to PyTorch tensor."""
|
||||
|
||||
def action(self, action: np.ndarray) -> torch.Tensor:
|
||||
if not isinstance(action, np.ndarray):
|
||||
raise TypeError(
|
||||
f"Expected np.ndarray or None, got {type(action).__name__}. "
|
||||
"Use appropriate processor for non-tensor actions."
|
||||
)
|
||||
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
|
||||
return torch_action
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
332
src/lerobot/processor/hil_processor.py
Normal file
332
src/lerobot/processor/hil_processor.py
Normal file
@@ -0,0 +1,332 @@
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as F # noqa: N812
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import ACTION
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .pipeline import (
|
||||
ComplementaryDataProcessorStep,
|
||||
InfoProcessorStep,
|
||||
ObservationProcessorStep,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
TruncatedProcessorStep,
|
||||
)
|
||||
|
||||
GRIPPER_KEY = "gripper"
|
||||
DISCRETE_PENALTY_KEY = "discrete_penalty"
|
||||
TELEOP_ACTION_KEY = "teleop_action"
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
|
||||
@dataclass
|
||||
class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep):
|
||||
"""Add teleoperator action to transition complementary data."""
|
||||
|
||||
teleop_device: Teleoperator
|
||||
|
||||
def complementary_data(self, complementary_data: dict) -> dict:
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
|
||||
return new_complementary_data
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_teleop_action_as_info")
|
||||
@dataclass
|
||||
class AddTeleopEventsAsInfoStep(InfoProcessorStep):
|
||||
"""Add teleoperator control events to transition info."""
|
||||
|
||||
teleop_device: Teleoperator
|
||||
|
||||
def info(self, info: dict) -> dict:
|
||||
new_info = dict(info)
|
||||
teleop_events = getattr(self.teleop_device, "get_teleop_events", lambda: {})()
|
||||
new_info.update(teleop_events)
|
||||
return new_info
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("image_crop_resize_processor")
|
||||
@dataclass
|
||||
class ImageCropResizeProcessorStep(ObservationProcessorStep):
|
||||
"""Crop and resize image observations."""
|
||||
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||
resize_size: tuple[int, int] | None = None
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
if self.resize_size is None and not self.crop_params_dict:
|
||||
return observation
|
||||
|
||||
new_observation = dict(observation)
|
||||
|
||||
# Process all image keys in the observation
|
||||
for key in observation:
|
||||
if "image" not in key:
|
||||
continue
|
||||
|
||||
image = observation[key]
|
||||
device = image.device
|
||||
# NOTE (maractingi): No mps kernel for crop and resize, so we need to move to cpu
|
||||
if device.type == "mps":
|
||||
image = image.cpu()
|
||||
# Crop if crop params are provided for this key
|
||||
if self.crop_params_dict is not None and key in self.crop_params_dict:
|
||||
crop_params = self.crop_params_dict[key]
|
||||
image = F.crop(image, *crop_params)
|
||||
if self.resize_size is not None:
|
||||
image = F.resize(image, self.resize_size)
|
||||
image = image.clamp(0.0, 1.0)
|
||||
new_observation[key] = image.to(device)
|
||||
|
||||
return new_observation
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"crop_params_dict": self.crop_params_dict,
|
||||
"resize_size": self.resize_size,
|
||||
}
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
if self.resize_size is None:
|
||||
return features
|
||||
for key in features:
|
||||
if "image" in key:
|
||||
features[key] = PolicyFeature(type=features[key].type, shape=self.resize_size)
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("time_limit_processor")
|
||||
class TimeLimitProcessorStep(TruncatedProcessorStep):
|
||||
"""Track episode steps and enforce time limits."""
|
||||
|
||||
max_episode_steps: int
|
||||
current_step: int = 0
|
||||
|
||||
def truncated(self, truncated):
|
||||
self.current_step += 1
|
||||
if self.current_step >= self.max_episode_steps:
|
||||
truncated = True
|
||||
# TODO (steven): missing an else truncated = False?
|
||||
return truncated
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"max_episode_steps": self.max_episode_steps,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
self.current_step = 0
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
"""Apply penalty for inappropriate gripper usage."""
|
||||
|
||||
penalty: float = -0.01
|
||||
max_gripper_pos: float = 30.0
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
"""Calculate gripper penalty and add to complementary data."""
|
||||
action = self.transition.get(TransitionKey.ACTION)
|
||||
|
||||
current_gripper_pos = complementary_data.get("raw_joint_positions", None).get(GRIPPER_KEY, None)
|
||||
if current_gripper_pos is None:
|
||||
return complementary_data
|
||||
|
||||
gripper_action = action[f"{ACTION}.{GRIPPER_KEY}.pos"]
|
||||
gripper_action_normalized = gripper_action / self.max_gripper_pos
|
||||
|
||||
# Normalize gripper state and action
|
||||
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
|
||||
|
||||
# Calculate penalty boolean as in original
|
||||
gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
|
||||
gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
|
||||
)
|
||||
|
||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||
|
||||
# Create new complementary data with penalty info
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
|
||||
|
||||
return new_complementary_data
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"penalty": self.penalty,
|
||||
"max_gripper_pos": self.max_gripper_pos,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the processor state."""
|
||||
self.last_gripper_state = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("intervention_action_processor")
|
||||
class InterventionActionProcessorStep(ProcessorStep):
|
||||
"""Handle human intervention actions and episode termination."""
|
||||
|
||||
use_gripper: bool = False
|
||||
terminate_on_success: bool = True
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is None:
|
||||
return transition
|
||||
|
||||
# Get intervention signals from complementary data
|
||||
info = transition.get(TransitionKey.INFO, {})
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
teleop_action = complementary_data.get(TELEOP_ACTION_KEY, {})
|
||||
is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False)
|
||||
terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False)
|
||||
success = info.get(TeleopEvents.SUCCESS, False)
|
||||
rerecord_episode = info.get(TeleopEvents.RERECORD_EPISODE, False)
|
||||
|
||||
new_transition = transition.copy()
|
||||
|
||||
# Override action if intervention is active
|
||||
if is_intervention and teleop_action is not None:
|
||||
if isinstance(teleop_action, dict):
|
||||
# Convert teleop_action dict to tensor format
|
||||
action_list = [
|
||||
teleop_action.get(f"{ACTION}.delta_x", 0.0),
|
||||
teleop_action.get(f"{ACTION}.delta_y", 0.0),
|
||||
teleop_action.get(f"{ACTION}.delta_z", 0.0),
|
||||
]
|
||||
if self.use_gripper:
|
||||
action_list.append(teleop_action.get(GRIPPER_KEY, 1.0))
|
||||
elif isinstance(teleop_action, np.ndarray):
|
||||
action_list = teleop_action.tolist()
|
||||
else:
|
||||
action_list = teleop_action
|
||||
|
||||
teleop_action_tensor = torch.tensor(action_list, dtype=action.dtype, device=action.device)
|
||||
new_transition[TransitionKey.ACTION] = teleop_action_tensor
|
||||
|
||||
# Handle episode termination
|
||||
new_transition[TransitionKey.DONE] = bool(terminate_episode) or (
|
||||
self.terminate_on_success and success
|
||||
)
|
||||
new_transition[TransitionKey.REWARD] = float(success)
|
||||
|
||||
# Update info with intervention metadata
|
||||
info = new_transition.get(TransitionKey.INFO, {})
|
||||
info[TeleopEvents.IS_INTERVENTION] = is_intervention
|
||||
info[TeleopEvents.RERECORD_EPISODE] = rerecord_episode
|
||||
info[TeleopEvents.SUCCESS] = success
|
||||
new_transition[TransitionKey.INFO] = info
|
||||
|
||||
# Update complementary data with teleop action
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
complementary_data[TELEOP_ACTION_KEY] = new_transition.get(TransitionKey.ACTION)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"use_gripper": self.use_gripper,
|
||||
"terminate_on_success": self.terminate_on_success,
|
||||
}
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("reward_classifier_processor")
|
||||
class RewardClassifierProcessorStep(ProcessorStep):
|
||||
"""Apply reward classification to image observations."""
|
||||
|
||||
pretrained_path: str | None = None
|
||||
device: str = "cpu"
|
||||
success_threshold: float = 0.5
|
||||
success_reward: float = 1.0
|
||||
terminate_on_success: bool = True
|
||||
|
||||
reward_classifier: Any = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize the reward classifier after dataclass initialization."""
|
||||
if self.pretrained_path is not None:
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
self.reward_classifier = Classifier.from_pretrained(self.pretrained_path)
|
||||
self.reward_classifier.to(self.device)
|
||||
self.reward_classifier.eval()
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
new_transition = transition.copy()
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None or self.reward_classifier is None:
|
||||
return new_transition
|
||||
|
||||
# Extract images from observation
|
||||
images = {key: value for key, value in observation.items() if "image" in key}
|
||||
|
||||
if not images:
|
||||
return new_transition
|
||||
|
||||
# Run reward classifier
|
||||
start_time = time.perf_counter()
|
||||
with torch.inference_mode():
|
||||
success = self.reward_classifier.predict_reward(images, threshold=self.success_threshold)
|
||||
|
||||
classifier_frequency = 1 / (time.perf_counter() - start_time)
|
||||
|
||||
# Calculate reward and termination
|
||||
reward = new_transition.get(TransitionKey.REWARD, 0.0)
|
||||
terminated = new_transition.get(TransitionKey.DONE, False)
|
||||
|
||||
if math.isclose(success, 1, abs_tol=1e-2):
|
||||
reward = self.success_reward
|
||||
if self.terminate_on_success:
|
||||
terminated = True
|
||||
|
||||
# Update transition
|
||||
new_transition[TransitionKey.REWARD] = reward
|
||||
new_transition[TransitionKey.DONE] = terminated
|
||||
|
||||
# Update info with classifier frequency
|
||||
info = new_transition.get(TransitionKey.INFO, {})
|
||||
info["reward_classifier_frequency"] = classifier_frequency
|
||||
new_transition[TransitionKey.INFO] = info
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"device": self.device,
|
||||
"success_threshold": self.success_threshold,
|
||||
"success_reward": self.success_reward,
|
||||
"terminate_on_success": self.terminate_on_success,
|
||||
}
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
109
src/lerobot/processor/joint_observations_processor.py
Normal file
109
src/lerobot/processor/joint_observations_processor.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_STATE
|
||||
from lerobot.processor.pipeline import (
|
||||
ObservationProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
)
|
||||
from lerobot.robots import Robot
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("joint_velocity_processor")
|
||||
class JointVelocityProcessorStep(ObservationProcessorStep):
|
||||
"""Add joint velocity information to observations."""
|
||||
|
||||
dt: float = 0.1
|
||||
|
||||
last_joint_positions: torch.Tensor | None = None
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
# Get current joint positions (assuming they're in observation.state)
|
||||
current_positions = observation.get(OBS_STATE)
|
||||
if current_positions is None:
|
||||
# TODO(steven): if we get here, then the transform_features method will not hold
|
||||
raise ValueError(f"{OBS_STATE} is not in observation")
|
||||
|
||||
# Initialize last joint positions if not already set
|
||||
if self.last_joint_positions is None:
|
||||
self.last_joint_positions = current_positions.clone()
|
||||
joint_velocities = torch.zeros_like(current_positions)
|
||||
else:
|
||||
# Compute velocities
|
||||
joint_velocities = (current_positions - self.last_joint_positions) / self.dt
|
||||
|
||||
self.last_joint_positions = current_positions.clone()
|
||||
|
||||
# Extend observation with velocities
|
||||
extended_state = torch.cat([current_positions, joint_velocities], dim=-1)
|
||||
|
||||
# Create new observation dict
|
||||
new_observation = dict(observation)
|
||||
new_observation[OBS_STATE] = extended_state
|
||||
|
||||
return new_observation
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"dt": self.dt,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
self.last_joint_positions = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
if OBS_STATE in features:
|
||||
original_feature = features[OBS_STATE]
|
||||
# Double the shape to account for positions + velocities
|
||||
new_shape = (original_feature.shape[0] * 2,) + original_feature.shape[1:]
|
||||
|
||||
features[OBS_STATE] = PolicyFeature(type=original_feature.type, shape=new_shape)
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("current_processor")
|
||||
class MotorCurrentProcessorStep(ObservationProcessorStep):
|
||||
"""Add motor current information to observations."""
|
||||
|
||||
robot: Robot | None = None
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
# Get current values from robot state
|
||||
if self.robot is None:
|
||||
raise ValueError("Robot is not set")
|
||||
|
||||
present_current_dict = self.robot.bus.sync_read("Present_Current") # type: ignore[attr-defined]
|
||||
motor_currents = torch.tensor(
|
||||
[present_current_dict[name] for name in self.robot.bus.motors], # type: ignore[attr-defined]
|
||||
dtype=torch.float32,
|
||||
).unsqueeze(0)
|
||||
|
||||
current_state = observation.get(OBS_STATE)
|
||||
if current_state is None:
|
||||
return observation
|
||||
|
||||
extended_state = torch.cat([current_state, motor_currents], dim=-1)
|
||||
|
||||
# Create new observation dict
|
||||
new_observation = dict(observation)
|
||||
new_observation[OBS_STATE] = extended_state
|
||||
|
||||
return new_observation
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
if OBS_STATE in features and self.robot is not None:
|
||||
original_feature = features[OBS_STATE]
|
||||
# Add motor current dimensions to the original state shape
|
||||
num_motors = 0
|
||||
if hasattr(self.robot, "bus") and hasattr(self.robot.bus, "motors"): # type: ignore[attr-defined]
|
||||
num_motors = len(self.robot.bus.motors) # type: ignore[attr-defined]
|
||||
|
||||
if num_motors > 0:
|
||||
new_shape = (original_feature.shape[0] + num_motors,) + original_feature.shape[1:]
|
||||
features[OBS_STATE] = PolicyFeature(type=original_feature.type, shape=new_shape)
|
||||
return features
|
||||
503
src/lerobot/processor/migrate_policy_normalization.py
Normal file
503
src/lerobot/processor/migrate_policy_normalization.py
Normal file
@@ -0,0 +1,503 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Generic script to migrate any policy model with normalization layers to the new pipeline-based system.
|
||||
|
||||
This script:
|
||||
1. Loads an existing pretrained policy model
|
||||
2. Extracts normalization statistics from the model
|
||||
3. Creates both preprocessor and postprocessor:
|
||||
- Preprocessor: normalizes both inputs (observations) and outputs (actions) for training
|
||||
- Postprocessor: unnormalizes outputs (actions) for inference
|
||||
4. Removes normalization layers from the model state_dict
|
||||
5. Saves the new model and both processors
|
||||
|
||||
Usage:
|
||||
python src/lerobot/processor/migrate_policy_normalization.py \
|
||||
--pretrained-path lerobot/act_aloha_sim_transfer_cube_human \
|
||||
--policy-type act \
|
||||
--push-to-hub
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file as load_safetensors
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
from .batch_processor import AddBatchDimensionProcessorStep
|
||||
from .device_processor import DeviceProcessorStep
|
||||
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
|
||||
from .pipeline import PolicyProcessorPipeline
|
||||
from .rename_processor import RenameProcessorStep
|
||||
|
||||
# Policy type to class mapping
|
||||
POLICY_CLASSES = {
|
||||
"act": "lerobot.policies.act.modeling_act.ACTPolicy",
|
||||
"diffusion": "lerobot.policies.diffusion.modeling_diffusion.DiffusionPolicy",
|
||||
"pi0": "lerobot.policies.pi0.modeling_pi0.PI0Policy",
|
||||
"pi0fast": "lerobot.policies.pi0fast.modeling_pi0fast.PI0FASTPolicy",
|
||||
"smolvla": "lerobot.policies.smolvla.modeling_smolvla.SmolVLAPolicy",
|
||||
"tdmpc": "lerobot.policies.tdmpc.modeling_tdmpc.TDMPCPolicy",
|
||||
"vqbet": "lerobot.policies.vqbet.modeling_vqbet.VQBeTPolicy",
|
||||
"sac": "lerobot.policies.sac.modeling_sac.SACPolicy",
|
||||
"classifier": "lerobot.policies.classifier.modeling_classifier.ClassifierPolicy",
|
||||
}
|
||||
|
||||
|
||||
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""Extract normalization statistics from model state_dict."""
|
||||
stats = {}
|
||||
|
||||
# Define patterns to match and their prefixes to remove
|
||||
normalization_patterns = [
|
||||
"normalize_inputs.buffer_",
|
||||
"unnormalize_outputs.buffer_",
|
||||
"normalize_targets.buffer_",
|
||||
"normalize.", # Must come after normalize_* patterns
|
||||
"unnormalize.", # Must come after unnormalize_* patterns
|
||||
"input_normalizer.",
|
||||
"output_normalizer.",
|
||||
]
|
||||
|
||||
# Process each key in state_dict
|
||||
for key, tensor in state_dict.items():
|
||||
# Try each pattern
|
||||
for pattern in normalization_patterns:
|
||||
if key.startswith(pattern):
|
||||
# Extract the remaining part after the pattern
|
||||
remaining = key[len(pattern) :]
|
||||
parts = remaining.split(".")
|
||||
|
||||
# Need at least feature name and stat type
|
||||
if len(parts) >= 2:
|
||||
# Last part is the stat type (mean, std, min, max, etc.)
|
||||
stat_type = parts[-1]
|
||||
# Everything else is the feature name
|
||||
feature_name = ".".join(parts[:-1]).replace("_", ".")
|
||||
|
||||
# Add to stats
|
||||
if feature_name not in stats:
|
||||
stats[feature_name] = {}
|
||||
stats[feature_name][stat_type] = tensor.clone()
|
||||
|
||||
# Only process the first matching pattern
|
||||
break
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def detect_features_and_norm_modes(
|
||||
config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]]
|
||||
) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]:
|
||||
"""Detect features and normalization modes from config and stats."""
|
||||
features = {}
|
||||
norm_modes = {}
|
||||
|
||||
# First, check if there's a normalization_mapping in the config
|
||||
if "normalization_mapping" in config:
|
||||
print(f"Found normalization_mapping in config: {config['normalization_mapping']}")
|
||||
# Extract normalization modes from config
|
||||
for feature_name, mode_str in config["normalization_mapping"].items():
|
||||
# Convert string to NormalizationMode enum
|
||||
if mode_str == "mean_std":
|
||||
mode = NormalizationMode.MEAN_STD
|
||||
elif mode_str == "min_max":
|
||||
mode = NormalizationMode.MIN_MAX
|
||||
else:
|
||||
print(f"Warning: Unknown normalization mode '{mode_str}' for feature '{feature_name}'")
|
||||
continue
|
||||
|
||||
# Determine feature type from feature name
|
||||
if "image" in feature_name or "visual" in feature_name:
|
||||
feature_type = FeatureType.VISUAL
|
||||
elif "state" in feature_name:
|
||||
feature_type = FeatureType.STATE
|
||||
elif "action" in feature_name:
|
||||
feature_type = FeatureType.ACTION
|
||||
else:
|
||||
feature_type = FeatureType.STATE
|
||||
|
||||
norm_modes[feature_type] = mode
|
||||
|
||||
# Try to extract from config
|
||||
if "features" in config:
|
||||
for key, feature_config in config["features"].items():
|
||||
shape = feature_config.get("shape", feature_config.get("dim"))
|
||||
shape = (shape,) if isinstance(shape, int) else tuple(shape)
|
||||
|
||||
# Determine feature type
|
||||
if "image" in key or "visual" in key:
|
||||
feature_type = FeatureType.VISUAL
|
||||
elif "state" in key:
|
||||
feature_type = FeatureType.STATE
|
||||
elif "action" in key:
|
||||
feature_type = FeatureType.ACTION
|
||||
else:
|
||||
feature_type = FeatureType.STATE # Default
|
||||
|
||||
features[key] = PolicyFeature(feature_type, shape)
|
||||
|
||||
# If no features in config, infer from stats
|
||||
if not features:
|
||||
for key, stat_dict in stats.items():
|
||||
# Get shape from any stat tensor
|
||||
tensor = next(iter(stat_dict.values()))
|
||||
shape = tuple(tensor.shape)
|
||||
|
||||
# Determine feature type based on key
|
||||
if "image" in key or "visual" in key or "pixels" in key:
|
||||
feature_type = FeatureType.VISUAL
|
||||
elif "state" in key or "joint" in key or "position" in key:
|
||||
feature_type = FeatureType.STATE
|
||||
elif "action" in key:
|
||||
feature_type = FeatureType.ACTION
|
||||
else:
|
||||
feature_type = FeatureType.STATE
|
||||
|
||||
features[key] = PolicyFeature(feature_type, shape)
|
||||
|
||||
# If normalization modes weren't in config, determine based on available stats
|
||||
if not norm_modes:
|
||||
for key, stat_dict in stats.items():
|
||||
if key in features:
|
||||
if "mean" in stat_dict and "std" in stat_dict:
|
||||
feature_type = features[key].type
|
||||
if feature_type not in norm_modes:
|
||||
norm_modes[feature_type] = NormalizationMode.MEAN_STD
|
||||
elif "min" in stat_dict and "max" in stat_dict:
|
||||
feature_type = features[key].type
|
||||
if feature_type not in norm_modes:
|
||||
norm_modes[feature_type] = NormalizationMode.MIN_MAX
|
||||
|
||||
# Default normalization modes if not detected
|
||||
if FeatureType.VISUAL not in norm_modes:
|
||||
norm_modes[FeatureType.VISUAL] = NormalizationMode.MEAN_STD
|
||||
if FeatureType.STATE not in norm_modes:
|
||||
norm_modes[FeatureType.STATE] = NormalizationMode.MIN_MAX
|
||||
if FeatureType.ACTION not in norm_modes:
|
||||
norm_modes[FeatureType.ACTION] = NormalizationMode.MEAN_STD
|
||||
|
||||
return features, norm_modes
|
||||
|
||||
|
||||
def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
"""Remove normalization layers from state_dict."""
|
||||
new_state_dict = {}
|
||||
|
||||
# Patterns to remove
|
||||
remove_patterns = [
|
||||
"normalize_inputs.",
|
||||
"unnormalize_outputs.",
|
||||
"normalize_targets.", # Added pattern for target normalization
|
||||
"normalize.",
|
||||
"unnormalize.",
|
||||
"input_normalizer.",
|
||||
"output_normalizer.",
|
||||
"normalizer.",
|
||||
]
|
||||
|
||||
for key, tensor in state_dict.items():
|
||||
should_remove = any(pattern in key for pattern in remove_patterns)
|
||||
if not should_remove:
|
||||
new_state_dict[key] = tensor
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
|
||||
"""Convert features from old format to PolicyFeature objects."""
|
||||
converted_features = {}
|
||||
|
||||
for key, feature_dict in features_dict.items():
|
||||
# Determine feature type based on key
|
||||
if "image" in key or "visual" in key:
|
||||
feature_type = FeatureType.VISUAL
|
||||
elif "state" in key:
|
||||
feature_type = FeatureType.STATE
|
||||
elif "action" in key:
|
||||
feature_type = FeatureType.ACTION
|
||||
else:
|
||||
feature_type = FeatureType.STATE
|
||||
|
||||
# Get shape from feature dict
|
||||
shape = feature_dict.get("shape", feature_dict.get("dim"))
|
||||
shape = (shape,) if isinstance(shape, int) else tuple(shape)
|
||||
|
||||
converted_features[key] = PolicyFeature(feature_type, shape)
|
||||
|
||||
return converted_features
|
||||
|
||||
|
||||
def load_model_from_hub(
|
||||
repo_id: str, revision: str = None
|
||||
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
|
||||
"""Load model state_dict and config from hub."""
|
||||
# Download files
|
||||
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
|
||||
|
||||
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
|
||||
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
|
||||
|
||||
# Load state_dict
|
||||
state_dict = load_safetensors(safetensors_path)
|
||||
|
||||
# Load config
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
with open(train_config_path) as f:
|
||||
train_config = json.load(f)
|
||||
|
||||
return state_dict, config, train_config
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Migrate policy models with normalization layers to new pipeline system"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained model (hub repo or local directory)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Output directory for migrated model (default: same as pretrained-path)",
|
||||
)
|
||||
parser.add_argument("--push-to-hub", action="store_true", help="Push migrated model to hub")
|
||||
parser.add_argument(
|
||||
"--hub-repo-id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Hub repository ID for pushing (default: same as pretrained-path)",
|
||||
)
|
||||
parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load")
|
||||
parser.add_argument("--private", action="store_true", help="Make the hub repository private")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load model and config
|
||||
print(f"Loading model from {args.pretrained_path}...")
|
||||
if os.path.isdir(args.pretrained_path):
|
||||
# Local directory
|
||||
state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors"))
|
||||
with open(os.path.join(args.pretrained_path, "config.json")) as f:
|
||||
config = json.load(f)
|
||||
with open(os.path.join(args.pretrained_path, "train_config.json")) as f:
|
||||
train_config = json.load(f)
|
||||
else:
|
||||
# Hub repository
|
||||
state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision)
|
||||
|
||||
# Extract normalization statistics
|
||||
print("Extracting normalization statistics...")
|
||||
stats = extract_normalization_stats(state_dict)
|
||||
|
||||
print(f"Found normalization statistics for: {list(stats.keys())}")
|
||||
|
||||
# Detect input features and normalization modes
|
||||
print("Detecting features and normalization modes...")
|
||||
features, norm_map = detect_features_and_norm_modes(config, stats)
|
||||
|
||||
print(f"Detected features: {list(features.keys())}")
|
||||
print(f"Normalization modes: {norm_map}")
|
||||
|
||||
# Remove normalization layers from state_dict
|
||||
print("Removing normalization layers from model...")
|
||||
new_state_dict = remove_normalization_layers(state_dict)
|
||||
|
||||
removed_keys = set(state_dict.keys()) - set(new_state_dict.keys())
|
||||
if removed_keys:
|
||||
print(f"Removed {len(removed_keys)} normalization layer keys")
|
||||
|
||||
# Determine output path
|
||||
if args.output_dir:
|
||||
output_dir = Path(args.output_dir)
|
||||
else:
|
||||
if os.path.isdir(args.pretrained_path):
|
||||
output_dir = Path(args.pretrained_path).parent / f"{Path(args.pretrained_path).name}_migrated"
|
||||
else:
|
||||
output_dir = Path(f"./{args.pretrained_path.replace('/', '_')}_migrated")
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Clean up config - remove normalization_mapping field
|
||||
cleaned_config = dict(config)
|
||||
if "normalization_mapping" in cleaned_config:
|
||||
print("Removing 'normalization_mapping' field from config")
|
||||
del cleaned_config["normalization_mapping"]
|
||||
policy_type = deepcopy(cleaned_config["type"])
|
||||
|
||||
del cleaned_config["type"]
|
||||
|
||||
# Instantiate the policy model with cleaned config and load the cleaned state dict
|
||||
print(f"Instantiating {policy_type} policy model...")
|
||||
policy_class_path = POLICY_CLASSES[policy_type]
|
||||
module_path, class_name = policy_class_path.rsplit(".", 1)
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
policy_class = getattr(module, class_name)
|
||||
|
||||
# Create config class instance
|
||||
config_module_path = module_path.replace("modeling", "configuration")
|
||||
config_module = importlib.import_module(config_module_path)
|
||||
# Handle special cases for config class names
|
||||
config_class_names = {
|
||||
"act": "ACTConfig",
|
||||
"diffusion": "DiffusionConfig",
|
||||
"pi0": "PI0Config",
|
||||
"pi0fast": "PI0FASTConfig",
|
||||
"smolvla": "SmolVLAConfig",
|
||||
"tdmpc": "TDMPCConfig",
|
||||
"vqbet": "VQBeTConfig",
|
||||
"sac": "SACConfig",
|
||||
"classifier": "ClassifierConfig",
|
||||
}
|
||||
config_class_name = config_class_names.get(policy_type, f"{policy_type.upper()}Config")
|
||||
config_class = getattr(config_module, config_class_name)
|
||||
|
||||
# Convert input_features and output_features to PolicyFeature objects - these are mandatory
|
||||
if "input_features" not in cleaned_config:
|
||||
raise ValueError("Missing mandatory 'input_features' in config")
|
||||
if "output_features" not in cleaned_config:
|
||||
raise ValueError("Missing mandatory 'output_features' in config")
|
||||
|
||||
cleaned_config["input_features"] = convert_features_to_policy_features(cleaned_config["input_features"])
|
||||
cleaned_config["output_features"] = convert_features_to_policy_features(cleaned_config["output_features"])
|
||||
|
||||
# Create config instance from cleaned config dict
|
||||
policy_config = config_class(**cleaned_config)
|
||||
|
||||
# Create policy instance - some policies expect dataset_stats
|
||||
policy = policy_class(policy_config)
|
||||
|
||||
# Load the cleaned state dict
|
||||
policy.load_state_dict(new_state_dict, strict=True)
|
||||
print("Successfully loaded cleaned state dict into policy model")
|
||||
|
||||
# Now create preprocessor and postprocessor with cleaned_config available
|
||||
print("Creating preprocessor and postprocessor...")
|
||||
# The pattern from existing processor factories:
|
||||
# - Preprocessor has two NormalizerProcessorSteps: one for input_features, one for output_features
|
||||
# - Postprocessor has one UnnormalizerProcessorStep for output_features only
|
||||
|
||||
# Get features from cleaned_config (now they're PolicyFeature objects)
|
||||
input_features = cleaned_config.get("input_features", {})
|
||||
output_features = cleaned_config.get("output_features", {})
|
||||
|
||||
# Create preprocessor with two normalizers (following the pattern from processor factories)
|
||||
preprocessor_steps = [
|
||||
RenameProcessorStep(rename_map={}),
|
||||
NormalizerProcessorStep(
|
||||
features={**input_features, **output_features},
|
||||
norm_map=norm_map,
|
||||
stats=stats,
|
||||
),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=policy_config.device),
|
||||
]
|
||||
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps, name="robot_preprocessor")
|
||||
|
||||
# Create postprocessor with unnormalizer for outputs only
|
||||
postprocessor_steps = [
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
UnnormalizerProcessorStep(features=output_features, norm_map=norm_map, stats=stats),
|
||||
]
|
||||
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps, name="robot_postprocessor")
|
||||
|
||||
# Determine hub repo ID if pushing to hub
|
||||
if args.push_to_hub:
|
||||
if args.hub_repo_id:
|
||||
hub_repo_id = args.hub_repo_id
|
||||
else:
|
||||
if not os.path.isdir(args.pretrained_path):
|
||||
# Use same repo with "_migrated" suffix
|
||||
hub_repo_id = f"{args.pretrained_path}_migrated"
|
||||
else:
|
||||
raise ValueError("--hub-repo-id must be specified when pushing local model to hub")
|
||||
else:
|
||||
hub_repo_id = None
|
||||
|
||||
# Save preprocessor and postprocessor to root directory
|
||||
print(f"Saving preprocessor to {output_dir}...")
|
||||
preprocessor.save_pretrained(output_dir)
|
||||
if args.push_to_hub:
|
||||
preprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
|
||||
|
||||
print(f"Saving postprocessor to {output_dir}...")
|
||||
postprocessor.save_pretrained(output_dir)
|
||||
if args.push_to_hub:
|
||||
postprocessor.push_to_hub(repo_id=hub_repo_id, private=args.private)
|
||||
|
||||
# Save model using the policy's save_pretrained method
|
||||
print(f"Saving model to {output_dir}...")
|
||||
policy.save_pretrained(
|
||||
output_dir, push_to_hub=args.push_to_hub, repo_id=hub_repo_id, private=args.private
|
||||
)
|
||||
|
||||
# Generate and save model card
|
||||
print("Generating model card...")
|
||||
# Get metadata from original config
|
||||
dataset_repo_id = train_config.get("repo_id", "unknown")
|
||||
license = config.get("license", "apache-2.0")
|
||||
|
||||
tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]
|
||||
tags = set(tags).union({"robotics", "lerobot", policy_type})
|
||||
tags = list(tags)
|
||||
|
||||
# Generate model card
|
||||
card = policy.generate_model_card(
|
||||
dataset_repo_id=dataset_repo_id, model_type=policy_type, license=license, tags=tags
|
||||
)
|
||||
|
||||
# Save model card locally
|
||||
card.save(str(output_dir / "README.md"))
|
||||
print(f"Model card saved to {output_dir / 'README.md'}")
|
||||
# Push model card to hub if requested
|
||||
if args.push_to_hub:
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(output_dir / "README.md"),
|
||||
path_in_repo="README.md",
|
||||
repo_id=hub_repo_id,
|
||||
repo_type="model",
|
||||
commit_message="Add model card for migrated model",
|
||||
)
|
||||
print("Model card pushed to hub")
|
||||
|
||||
print("\nMigration complete!")
|
||||
print(f"Migrated model saved to: {output_dir}")
|
||||
if args.push_to_hub:
|
||||
print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,179 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
|
||||
"""Convert numpy arrays and other types to torch tensors."""
|
||||
tensor_stats: dict[str, dict[str, Tensor]] = {}
|
||||
for key, sub in stats.items():
|
||||
tensor_stats[key] = {}
|
||||
for stat_name, value in sub.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
tensor_val = torch.from_numpy(value.astype(np.float32))
|
||||
elif isinstance(value, torch.Tensor):
|
||||
tensor_val = value.to(dtype=torch.float32)
|
||||
elif isinstance(value, (int, float, list, tuple)):
|
||||
tensor_val = torch.tensor(value, dtype=torch.float32)
|
||||
else:
|
||||
raise TypeError(f"Unsupported type for stats['{key}']['{stat_name}']: {type(value)}")
|
||||
tensor_stats[key][stat_name] = tensor_val
|
||||
return tensor_stats
|
||||
from .converters import to_tensor
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="normalizer_processor")
|
||||
class NormalizerProcessor:
|
||||
"""Normalizes observations and actions in a single processor step.
|
||||
class _NormalizationMixin:
|
||||
"""
|
||||
A mixin class providing core functionality for normalization and unnormalization.
|
||||
|
||||
This processor handles normalization of both observation and action tensors
|
||||
using either mean/std normalization or min/max scaling to a [-1, 1] range.
|
||||
|
||||
For each tensor key in the stats dictionary, the processor will:
|
||||
- Use mean/std normalization if those statistics are provided: (x - mean) / std
|
||||
- Use min/max scaling if those statistics are provided: 2 * (x - min) / (max - min) - 1
|
||||
|
||||
The processor can be configured to normalize only specific keys by setting
|
||||
the normalize_keys parameter.
|
||||
This class manages normalization statistics, their conversion to tensors, device placement,
|
||||
and the application of normalization transformations. It is designed to be inherited by
|
||||
concrete ProcessorStep implementations.
|
||||
"""
|
||||
|
||||
# Features and normalisation map are mandatory to match the design of normalize.py
|
||||
features: dict[str, PolicyFeature]
|
||||
norm_map: dict[FeatureType, NormalizationMode]
|
||||
|
||||
# Pre-computed statistics coming from dataset.meta.stats for instance.
|
||||
stats: dict[str, dict[str, Any]] | None = None
|
||||
|
||||
# Explicit subset of keys to normalise. If ``None`` every key (except
|
||||
# "action") found in ``stats`` will be normalised. Using a ``set`` makes
|
||||
# membership checks O(1).
|
||||
normalize_keys: set[str] | None = None
|
||||
|
||||
device: torch.device | str | None = None
|
||||
eps: float = 1e-8
|
||||
normalize_observation_keys: set[str] | None = None
|
||||
|
||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
dataset: LeRobotDataset,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[FeatureType, NormalizationMode],
|
||||
*,
|
||||
normalize_keys: set[str] | None = None,
|
||||
eps: float = 1e-8,
|
||||
) -> NormalizerProcessor:
|
||||
"""Factory helper that pulls statistics from a :class:`LeRobotDataset`.
|
||||
|
||||
The features and norm_map parameters are mandatory to match the design
|
||||
pattern used in normalize.py.
|
||||
"""
|
||||
|
||||
return cls(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=dataset.meta.stats,
|
||||
normalize_keys=normalize_keys,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Handle deserialization from JSON config
|
||||
if self.features and isinstance(list(self.features.values())[0], dict):
|
||||
# Features came from JSON - need to reconstruct PolicyFeature objects
|
||||
reconstructed_features = {}
|
||||
for key, ft_dict in self.features.items():
|
||||
reconstructed_features[key] = PolicyFeature(
|
||||
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||
)
|
||||
self.features = reconstructed_features
|
||||
# Robust JSON deserialization handling (guard empty maps)
|
||||
if self.features:
|
||||
first_val = next(iter(self.features.values()))
|
||||
if isinstance(first_val, dict):
|
||||
reconstructed = {}
|
||||
for key, ft_dict in self.features.items():
|
||||
reconstructed[key] = PolicyFeature(
|
||||
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||
)
|
||||
self.features = reconstructed
|
||||
|
||||
if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
|
||||
# norm_map came from JSON - need to reconstruct enum keys and values
|
||||
reconstructed_norm_map = {}
|
||||
for ft_type_str, norm_mode_str in self.norm_map.items():
|
||||
reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
|
||||
self.norm_map = reconstructed_norm_map
|
||||
if self.norm_map:
|
||||
# if keys are strings (JSON), rebuild enum map
|
||||
if all(isinstance(k, str) for k in self.norm_map.keys()):
|
||||
reconstructed = {}
|
||||
for ft_type_str, norm_mode_str in self.norm_map.items():
|
||||
reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
|
||||
self.norm_map = reconstructed
|
||||
|
||||
# Convert statistics once so we avoid repeated numpy→Tensor conversions
|
||||
# during runtime.
|
||||
# Convert stats to tensors and move to the target device once during initialization.
|
||||
self.stats = self.stats or {}
|
||||
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device)
|
||||
|
||||
# Ensure *normalize_keys* is a set for fast look-ups and compare by
|
||||
# value later when returning the configuration.
|
||||
if self.normalize_keys is not None and not isinstance(self.normalize_keys, set):
|
||||
self.normalize_keys = set(self.normalize_keys)
|
||||
def to(self, device: torch.device | str) -> _NormalizationMixin:
|
||||
"""Moves the processor's normalization stats to the specified device and returns self."""
|
||||
self.device = device
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device)
|
||||
return self
|
||||
|
||||
def _normalize_obs(self, observation):
|
||||
if observation is None:
|
||||
return None
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
flat: dict[str, Tensor] = {}
|
||||
for key, sub in self._tensor_stats.items():
|
||||
for stat_name, tensor in sub.items():
|
||||
flat[f"{key}.{stat_name}"] = tensor.cpu() # Always save to CPU
|
||||
return flat
|
||||
|
||||
# Decide which keys should be normalised for this call.
|
||||
if self.normalize_keys is not None:
|
||||
keys_to_norm = self.normalize_keys
|
||||
else:
|
||||
# Use feature map to skip action keys.
|
||||
keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION}
|
||||
|
||||
processed = dict(observation)
|
||||
for key in keys_to_norm:
|
||||
if key not in processed or key not in self._tensor_stats:
|
||||
continue
|
||||
|
||||
orig_val = processed[key]
|
||||
tensor = (
|
||||
orig_val.to(dtype=torch.float32)
|
||||
if isinstance(orig_val, torch.Tensor)
|
||||
else torch.as_tensor(orig_val, dtype=torch.float32)
|
||||
def load_state_dict(self, state: dict[str, Tensor]) -> None:
|
||||
self._tensor_stats.clear()
|
||||
for flat_key, tensor in state.items():
|
||||
key, stat_name = flat_key.rsplit(".", 1)
|
||||
# Load to the processor's configured device.
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
|
||||
dtype=torch.float32, device=self.device
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
|
||||
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
processed[key] = (tensor - mean) / (std + self.eps)
|
||||
elif "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||
return processed
|
||||
|
||||
def _normalize_action(self, action):
|
||||
if action is None or "action" not in self._tensor_stats:
|
||||
return action
|
||||
|
||||
tensor = (
|
||||
action.to(dtype=torch.float32)
|
||||
if isinstance(action, torch.Tensor)
|
||||
else torch.as_tensor(action, dtype=torch.float32)
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
return (tensor - mean) / (std + self.eps)
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
|
||||
action = self._normalize_action(transition.get(TransitionKey.ACTION))
|
||||
|
||||
# Create a new transition with normalized values
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
new_transition[TransitionKey.ACTION] = action
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
config = {
|
||||
@@ -183,45 +88,87 @@ class NormalizerProcessor:
|
||||
},
|
||||
"norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
|
||||
}
|
||||
if self.normalize_keys is not None:
|
||||
# Serialise as a list for YAML / JSON friendliness
|
||||
config["normalize_keys"] = sorted(self.normalize_keys)
|
||||
if self.normalize_observation_keys is not None:
|
||||
config["normalize_observation_keys"] = sorted(self.normalize_observation_keys)
|
||||
return config
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
flat = {}
|
||||
for key, sub in self._tensor_stats.items():
|
||||
for stat_name, tensor in sub.items():
|
||||
flat[f"{key}.{stat_name}"] = tensor
|
||||
return flat
|
||||
def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]:
|
||||
new_observation = dict(observation)
|
||||
for key, feature in self.features.items():
|
||||
if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys:
|
||||
continue
|
||||
if feature.type != FeatureType.ACTION and key in new_observation:
|
||||
tensor = torch.as_tensor(new_observation[key], dtype=torch.float32)
|
||||
new_observation[key] = self._apply_transform(tensor, key, feature.type, inverse=inverse)
|
||||
return new_observation
|
||||
|
||||
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
|
||||
self._tensor_stats.clear()
|
||||
for flat_key, tensor in state.items():
|
||||
key, stat_name = flat_key.rsplit(".", 1)
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor
|
||||
def _normalize_action(self, action: Any, inverse: bool) -> Tensor:
|
||||
tensor = torch.as_tensor(action, dtype=torch.float32)
|
||||
processed_action = self._apply_transform(tensor, "action", FeatureType.ACTION, inverse=inverse)
|
||||
return processed_action
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
def _apply_transform(
|
||||
self, tensor: Tensor, key: str, feature_type: FeatureType, *, inverse: bool = False
|
||||
) -> Tensor:
|
||||
"""Core logic to apply normalization or unnormalization."""
|
||||
norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
|
||||
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats:
|
||||
return tensor
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX):
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
# Ensure input tensor is on the same device as the stats.
|
||||
if self.device and tensor.device != self.device:
|
||||
tensor = tensor.to(self.device)
|
||||
|
||||
# For Accelerate compatibility: move stats to match input tensor device
|
||||
input_device = tensor.device
|
||||
stats = self._tensor_stats[key]
|
||||
tensor = tensor.to(dtype=torch.float32)
|
||||
|
||||
# Move stats to input device if needed
|
||||
stats_device = next(iter(stats.values())).device
|
||||
if stats_device != input_device:
|
||||
stats = to_tensor({key: self._tensor_stats[key]}, device=input_device)[key]
|
||||
|
||||
if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
# Avoid division by zero by adding a small epsilon.
|
||||
denom = std + self.eps
|
||||
if inverse:
|
||||
return tensor * std + mean
|
||||
return (tensor - mean) / denom
|
||||
|
||||
if norm_mode == NormalizationMode.MIN_MAX and "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
denom = max_val - min_val
|
||||
# When min_val == max_val, substitute the denominator with a small epsilon
|
||||
# to prevent division by zero. This consistently maps an input equal to
|
||||
# min_val to -1, ensuring a stable transformation.
|
||||
denom = torch.where(
|
||||
denom == 0, torch.tensor(self.eps, device=input_device, dtype=torch.float32), denom
|
||||
)
|
||||
if inverse:
|
||||
# Map from [-1, 1] back to [min, max]
|
||||
return (tensor + 1) / 2 * denom + min_val
|
||||
# Map from [min, max] to [-1, 1]
|
||||
return 2 * (tensor - min_val) / denom - 1
|
||||
|
||||
# If necessary stats are missing, return input unchanged.
|
||||
return tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="unnormalizer_processor")
|
||||
class UnnormalizerProcessor:
|
||||
"""Inverse normalisation for observations and actions.
|
||||
|
||||
Exactly mirrors :class:`NormalizerProcessor` but applies the inverse
|
||||
transform.
|
||||
@ProcessorStepRegistry.register(name="normalizer_processor")
|
||||
class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
|
||||
"""
|
||||
A processor that applies normalization to observations and actions in a transition.
|
||||
|
||||
features: dict[str, PolicyFeature]
|
||||
norm_map: dict[FeatureType, NormalizationMode]
|
||||
stats: dict[str, dict[str, Any]] | None = None
|
||||
|
||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
This class directly implements the normalization logic for both observation and action
|
||||
components of an `EnvTransition`, using statistics (mean/std or min/max) provided at
|
||||
initialization.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
@@ -229,103 +176,97 @@ class UnnormalizerProcessor:
|
||||
dataset: LeRobotDataset,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[FeatureType, NormalizationMode],
|
||||
) -> UnnormalizerProcessor:
|
||||
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats)
|
||||
|
||||
def __post_init__(self):
|
||||
# Handle deserialization from JSON config
|
||||
if self.features and isinstance(list(self.features.values())[0], dict):
|
||||
# Features came from JSON - need to reconstruct PolicyFeature objects
|
||||
reconstructed_features = {}
|
||||
for key, ft_dict in self.features.items():
|
||||
reconstructed_features[key] = PolicyFeature(
|
||||
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||
)
|
||||
self.features = reconstructed_features
|
||||
|
||||
if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
|
||||
# norm_map came from JSON - need to reconstruct enum keys and values
|
||||
reconstructed_norm_map = {}
|
||||
for ft_type_str, norm_mode_str in self.norm_map.items():
|
||||
reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
|
||||
self.norm_map = reconstructed_norm_map
|
||||
|
||||
self.stats = self.stats or {}
|
||||
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
||||
|
||||
def _unnormalize_obs(self, observation):
|
||||
if observation is None:
|
||||
return None
|
||||
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
|
||||
processed = dict(observation)
|
||||
for key in keys:
|
||||
if key not in processed or key not in self._tensor_stats:
|
||||
continue
|
||||
orig_val = processed[key]
|
||||
tensor = (
|
||||
orig_val.to(dtype=torch.float32)
|
||||
if isinstance(orig_val, torch.Tensor)
|
||||
else torch.as_tensor(orig_val, dtype=torch.float32)
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
processed[key] = tensor * std + mean
|
||||
elif "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||
return processed
|
||||
|
||||
def _unnormalize_action(self, action):
|
||||
if action is None or "action" not in self._tensor_stats:
|
||||
return action
|
||||
tensor = (
|
||||
action.to(dtype=torch.float32)
|
||||
if isinstance(action, torch.Tensor)
|
||||
else torch.as_tensor(action, dtype=torch.float32)
|
||||
*,
|
||||
normalize_observation_keys: set[str] | None = None,
|
||||
eps: float = 1e-8,
|
||||
device: torch.device | str | None = None,
|
||||
) -> NormalizerProcessorStep:
|
||||
return cls(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=dataset.meta.stats,
|
||||
normalize_observation_keys=normalize_observation_keys,
|
||||
eps=eps,
|
||||
device=device,
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
return tensor * std + mean
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
return (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
|
||||
action = self._unnormalize_action(transition.get(TransitionKey.ACTION))
|
||||
|
||||
# Create a new transition with unnormalized values
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
new_transition[TransitionKey.ACTION] = action
|
||||
|
||||
# Handle observation normalization.
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is not None:
|
||||
new_transition[TransitionKey.OBSERVATION] = self._normalize_observation(
|
||||
observation, inverse=False
|
||||
)
|
||||
|
||||
# Handle action normalization.
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is not None:
|
||||
new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=False)
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"features": {
|
||||
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
|
||||
},
|
||||
"norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
flat = {}
|
||||
for key, sub in self._tensor_stats.items():
|
||||
for stat_name, tensor in sub.items():
|
||||
flat[f"{key}.{stat_name}"] = tensor
|
||||
return flat
|
||||
|
||||
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
|
||||
self._tensor_stats.clear()
|
||||
for flat_key, tensor in state.items():
|
||||
key, stat_name = flat_key.rsplit(".", 1)
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="unnormalizer_processor")
|
||||
class UnnormalizerProcessorStep(_NormalizationMixin, ProcessorStep):
|
||||
"""
|
||||
A processor that applies unnormalization (the inverse of normalization) to
|
||||
observations and actions in a transition.
|
||||
|
||||
This is typically used to transform actions from a normalized policy output back into
|
||||
the original scale for execution in an environment.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
dataset: LeRobotDataset,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[FeatureType, NormalizationMode],
|
||||
*,
|
||||
device: torch.device | str | None = None,
|
||||
) -> UnnormalizerProcessorStep:
|
||||
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, device=device)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
new_transition = transition.copy()
|
||||
|
||||
# Handle observation unnormalization.
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is not None:
|
||||
new_transition[TransitionKey.OBSERVATION] = self._normalize_observation(observation, inverse=True)
|
||||
|
||||
# Handle action unnormalization.
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is not None:
|
||||
new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=True)
|
||||
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
def hotswap_stats(
|
||||
policy_processor: PolicyProcessorPipeline, stats: dict[str, dict[str, Any]]
|
||||
) -> PolicyProcessorPipeline:
|
||||
"""
|
||||
Replaces normalization statistics in a PolicyProcessor pipeline.
|
||||
|
||||
This function creates a deep copy of the provided `PolicyProcessorPipeline` and updates the
|
||||
statistics of any `NormalizerProcessorStep` or `UnnormalizerProcessorStep` steps within it.
|
||||
It's useful for adapting a trained policy to a new environment or dataset with
|
||||
different data distributions.
|
||||
"""
|
||||
rp = deepcopy(policy_processor)
|
||||
for step in rp.steps:
|
||||
if isinstance(step, _NormalizationMixin):
|
||||
step.stats = stats
|
||||
# Re-initialize tensor_stats on the correct device.
|
||||
step._tensor_stats = to_tensor(stats, device=step.device)
|
||||
return rp
|
||||
|
||||
@@ -22,12 +22,13 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry
|
||||
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="observation_processor")
|
||||
class VanillaObservationProcessor(ObservationProcessor):
|
||||
class VanillaObservationProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Processes environment observations into the LeRobot format by handling both images and states.
|
||||
|
||||
@@ -106,9 +107,8 @@ class VanillaObservationProcessor(ObservationProcessor):
|
||||
def observation(self, observation):
|
||||
return self._process_observation(observation)
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transforms feature keys to a standardized contract.
|
||||
|
||||
This method handles several renaming patterns:
|
||||
- Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
|
||||
- Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -13,19 +13,18 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import (
|
||||
ObservationProcessor,
|
||||
ProcessorStepRegistry,
|
||||
)
|
||||
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="rename_processor")
|
||||
class RenameProcessor(ObservationProcessor):
|
||||
class RenameProcessorStep(ObservationProcessorStep):
|
||||
"""Rename processor that renames keys in the observation."""
|
||||
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
@@ -43,9 +42,20 @@ class RenameProcessor(ObservationProcessor):
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"rename_map": self.rename_map}
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transforms:
|
||||
- Each key in the observation that appears in `rename_map` is renamed to its value.
|
||||
- Keys not in `rename_map` remain unchanged.
|
||||
"""
|
||||
return {self.rename_map.get(k, k): v for k, v in features.items()}
|
||||
|
||||
|
||||
def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]:
|
||||
"""Rename keys in the stats dictionary according to rename_map (defensive copy)."""
|
||||
if not stats:
|
||||
return {}
|
||||
renamed: dict[str, dict[str, Any]] = {}
|
||||
for old_key, sub_stats in stats.items():
|
||||
new_key = rename_map.get(old_key, old_key)
|
||||
renamed[new_key] = deepcopy(sub_stats) if sub_stats is not None else {}
|
||||
return renamed
|
||||
|
||||
246
src/lerobot/processor/tokenizer_processor.py
Normal file
246
src/lerobot/processor/tokenizer_processor.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""
|
||||
Tokenizer processor for handling text tokenization in robot transitions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoTokenizer
|
||||
else:
|
||||
AutoTokenizer = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="tokenizer_processor")
|
||||
class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
"""Tokenizes text tasks in complementary data using a huggingface tokenizer.
|
||||
|
||||
This processor handles tokenization of task strings found in the complementary_data
|
||||
using a specified pretrained tokenizer from Hugging Face. It adds tokenized versions
|
||||
to the observation data for model processing while preserving the original task string.
|
||||
|
||||
The processor supports both single strings and lists of strings as task inputs.
|
||||
|
||||
Args:
|
||||
tokenizer_name: Name of the pretrained tokenizer to load from Hugging Face Hub
|
||||
(e.g., "bert-base-uncased", "microsoft/DialoGPT-medium"). This will be used
|
||||
with AutoTokenizer.from_pretrained(). If tokenizer is provided, this is ignored.
|
||||
tokenizer: A tokenizer object (e.g., from transformers library) that implements
|
||||
the __call__ method. If provided, tokenizer_name is ignored. This parameter
|
||||
is not serialized and must be provided via overrides when loading.
|
||||
max_length: Maximum sequence length for tokenization. Defaults to 512.
|
||||
task_key: Key in complementary_data containing the task text. Defaults to "task".
|
||||
padding: Padding strategy for tokenization. Defaults to "max_length".
|
||||
truncation: Whether to truncate sequences longer than max_length. Defaults to True.
|
||||
|
||||
Examples:
|
||||
Using tokenizer name (auto-loaded):
|
||||
```python
|
||||
processor = TokenizerProcessorStep(tokenizer_name="bert-base-uncased", max_length=128)
|
||||
```
|
||||
|
||||
Using custom tokenizer object:
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
custom_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
processor = TokenizerProcessorStep(tokenizer=custom_tokenizer, max_length=128)
|
||||
```
|
||||
"""
|
||||
|
||||
tokenizer_name: str | None = None
|
||||
tokenizer: Any | None = None # Otherwise transformers is not available in the core dependencies
|
||||
max_length: int = 512
|
||||
task_key: str = "task"
|
||||
padding_side: str = "right"
|
||||
padding: str = "max_length"
|
||||
truncation: bool = True
|
||||
|
||||
# Internal tokenizer instance (not serialized)
|
||||
input_tokenizer: Any = field(default=None, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize the tokenizer from the provided tokenizer or tokenizer name."""
|
||||
if not _transformers_available:
|
||||
raise ImportError(
|
||||
"The 'transformers' library is not installed. "
|
||||
"Please install it with `pip install 'lerobot[transformers-dep]'` to use TokenizerProcessorStep."
|
||||
)
|
||||
|
||||
if self.tokenizer is not None:
|
||||
# Use provided tokenizer object directly
|
||||
self.input_tokenizer = self.tokenizer
|
||||
elif self.tokenizer_name is not None:
|
||||
if AutoTokenizer is None:
|
||||
raise ImportError("AutoTokenizer is not available")
|
||||
self.input_tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either 'tokenizer' or 'tokenizer_name' must be provided. "
|
||||
"Pass a tokenizer object directly or a tokenizer name to auto-load."
|
||||
)
|
||||
|
||||
def get_task(self, transition: EnvTransition) -> list[str] | None:
|
||||
"""Extract and normalize task from complementary data.
|
||||
|
||||
Args:
|
||||
transition: Input transition containing complementary_data.
|
||||
|
||||
Returns:
|
||||
List of task strings if task is present, None otherwise.
|
||||
"""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return None
|
||||
|
||||
if self.task_key not in complementary_data:
|
||||
return None
|
||||
|
||||
task = complementary_data[self.task_key]
|
||||
if task is None:
|
||||
return None
|
||||
|
||||
# Convert to list of strings
|
||||
if isinstance(task, str):
|
||||
return [task]
|
||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
||||
return task
|
||||
|
||||
return None
|
||||
|
||||
def observation(self, observation):
|
||||
"""Process the transition by tokenizing the task text.
|
||||
|
||||
Args:
|
||||
transition: Input transition containing complementary_data with task text.
|
||||
|
||||
Returns:
|
||||
Modified transition with tokenized task added to observation.
|
||||
|
||||
Raises:
|
||||
ValueError: If tokenizer initialization failed.
|
||||
"""
|
||||
task = self.get_task(self.transition)
|
||||
if task is None:
|
||||
return observation
|
||||
|
||||
# Tokenize the task (creates CPU tensors)
|
||||
tokenized_prompt = self._tokenize_text(task)
|
||||
|
||||
# Detect device from existing tensors in the transition
|
||||
target_device = self._detect_device(self.transition)
|
||||
|
||||
# Move tokenized tensors to match the device of other data
|
||||
if target_device is not None:
|
||||
tokenized_prompt = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_prompt.items()
|
||||
}
|
||||
|
||||
# Get or create observation dict
|
||||
new_observation = dict(observation)
|
||||
|
||||
# Add tokenized data to observation
|
||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
return new_observation
|
||||
|
||||
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
||||
"""Detect device from existing tensors in the transition.
|
||||
|
||||
This allows the tokenized tensors to match the device of other data,
|
||||
which is especially important for multi-GPU training with Accelerate.
|
||||
|
||||
Args:
|
||||
transition: The transition to search for existing tensors.
|
||||
|
||||
Returns:
|
||||
The device of the first tensor found, or None if no tensors exist.
|
||||
"""
|
||||
# Check observation tensors first (most likely to exist)
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation:
|
||||
for value in observation.values():
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.device
|
||||
|
||||
# Check action tensor
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if isinstance(action, torch.Tensor):
|
||||
return action.device
|
||||
|
||||
return None # No tensors found, keep on CPU
|
||||
|
||||
def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]:
|
||||
"""Tokenize text using the configured tokenizer.
|
||||
|
||||
Args:
|
||||
text: Text string or list of strings to tokenize.
|
||||
|
||||
Returns:
|
||||
Dictionary containing tokenized output with keys like 'input_ids', 'attention_mask'.
|
||||
"""
|
||||
return self.input_tokenizer(
|
||||
text,
|
||||
max_length=self.max_length,
|
||||
truncation=self.truncation,
|
||||
padding=self.padding,
|
||||
padding_side=self.padding_side,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization.
|
||||
|
||||
Note: Only tokenizer_name is saved, not the tokenizer object itself.
|
||||
When loading, provide the tokenizer via overrides if needed.
|
||||
"""
|
||||
config = {
|
||||
"max_length": self.max_length,
|
||||
"task_key": self.task_key,
|
||||
"padding_side": self.padding_side,
|
||||
"padding": self.padding,
|
||||
"truncation": self.truncation,
|
||||
}
|
||||
|
||||
# Only include tokenizer_name if it was used (not when tokenizer object was provided)
|
||||
# TODO(steven): Consider saving the name of the _tokenizer if it was loaded
|
||||
if self.tokenizer_name is not None and self.tokenizer is None:
|
||||
config["tokenizer_name"] = self.tokenizer_name
|
||||
|
||||
return config
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Add tokenized task features to the feature contract.
|
||||
|
||||
Args:
|
||||
features: Input feature dictionary.
|
||||
|
||||
Returns:
|
||||
Updated feature dictionary with tokenized task features added.
|
||||
"""
|
||||
# Add features for tokenized output if they don't exist
|
||||
# Standard tokenizer output includes tokens and attention_mask
|
||||
|
||||
if OBS_LANGUAGE_TOKENS not in features:
|
||||
features[OBS_LANGUAGE_TOKENS] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
|
||||
|
||||
if OBS_LANGUAGE_ATTENTION_MASK not in features:
|
||||
features[OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
return features
|
||||
@@ -59,7 +59,7 @@ lerobot-record \
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
@@ -72,10 +72,23 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import (
|
||||
IdentityProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
RobotProcessorPipeline,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
action_to_transition,
|
||||
observation_to_transition,
|
||||
transition_to_dataset_frame,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -149,6 +162,8 @@ class DatasetRecordConfig:
|
||||
# Number of episodes to record before batch encoding videos
|
||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||
video_encoding_batch_size: int = 1
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.single_task is None:
|
||||
@@ -187,6 +202,36 @@ class RecordConfig:
|
||||
return ["policy"]
|
||||
|
||||
|
||||
""" --------------- record_loop() data flow --------------------------
|
||||
[ Robot ]
|
||||
V
|
||||
[ robot.get_observation() ] ---> raw_obs
|
||||
V
|
||||
[ robot_observation_processor ] ---> obs_transition
|
||||
V
|
||||
.-----( ACTION LOGIC )------------------.
|
||||
V V
|
||||
[ From Teleoperator ] [ From Policy ]
|
||||
| |
|
||||
| [teleop.get_action] -> raw_action | [predict_action]
|
||||
| | | |
|
||||
| V | V
|
||||
| [teleop_action_processor] | |
|
||||
| | | |
|
||||
'---> teleop_transition '---> policy_transition
|
||||
| |
|
||||
'-------------------------.-------------'
|
||||
V
|
||||
[ robot_action_processor ] --> robot_action_to_send
|
||||
V
|
||||
[ robot.send_action() ] -- (Robot Executes)
|
||||
V
|
||||
( Transitions are merged & added to Dataset )
|
||||
V
|
||||
( Rerun Log / Loop Wait )
|
||||
"""
|
||||
|
||||
|
||||
@safe_stop_image_writer
|
||||
def record_loop(
|
||||
robot: Robot,
|
||||
@@ -195,15 +240,32 @@ def record_loop(
|
||||
dataset: LeRobotDataset | None = None,
|
||||
teleop: Teleoperator | list[Teleoperator] | None = None,
|
||||
policy: PreTrainedPolicy | None = None,
|
||||
preprocessor: PolicyProcessorPipeline | None = None,
|
||||
postprocessor: PolicyProcessorPipeline | None = None,
|
||||
control_time_s: int | None = None,
|
||||
teleop_action_processor: RobotProcessorPipeline | None = None, # runs after teleop
|
||||
robot_action_processor: RobotProcessorPipeline | None = None, # runs before robot
|
||||
robot_observation_processor: RobotProcessorPipeline | None = None, # runs after robot
|
||||
single_task: str | None = None,
|
||||
display_data: bool = False,
|
||||
):
|
||||
teleop_action_processor = teleop_action_processor or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()], to_transition=action_to_transition, to_output=lambda tr: tr
|
||||
)
|
||||
robot_action_processor = robot_action_processor or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()], to_transition=lambda tr: tr, to_output=transition_to_robot_action
|
||||
)
|
||||
robot_observation_processor = robot_observation_processor or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
if dataset is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
|
||||
|
||||
teleop_arm = teleop_keyboard = None
|
||||
if isinstance(teleop, list):
|
||||
if isinstance(teleop, list): # For LeKiwi
|
||||
teleop_keyboard = next((t for t in teleop if isinstance(t, KeyboardTeleop)), None)
|
||||
teleop_arm = next(
|
||||
(
|
||||
@@ -219,9 +281,20 @@ def record_loop(
|
||||
"For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot."
|
||||
)
|
||||
|
||||
# if policy is given it needs cleaning up
|
||||
if policy is not None:
|
||||
# Reset policy and processor if they are provided
|
||||
if policy is not None and preprocessor is not None and postprocessor is not None:
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
# Reset custom pipelines
|
||||
teleop_action_processor.reset()
|
||||
robot_action_processor.reset()
|
||||
robot_observation_processor.reset()
|
||||
|
||||
policy_transition = None
|
||||
teleop_transition = None
|
||||
obs_transition = None
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
@@ -232,51 +305,88 @@ def record_loop(
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
observation = robot.get_observation()
|
||||
# Get robot observation
|
||||
obs = robot.get_observation()
|
||||
|
||||
if policy is not None or dataset is not None:
|
||||
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
|
||||
# Applies a pipeline to the raw robot observation, default is IdentityProcessor
|
||||
obs_transition = robot_observation_processor(obs)
|
||||
|
||||
# Get action from either policy or teleop
|
||||
if policy is not None and preprocessor is not None and postprocessor is not None:
|
||||
if dataset is not None:
|
||||
observation_frame = transition_to_dataset_frame(
|
||||
obs_transition, dataset.features
|
||||
) # Convert the observation to the dataset format
|
||||
|
||||
if policy is not None:
|
||||
action_values = predict_action(
|
||||
observation_frame,
|
||||
policy,
|
||||
get_safe_torch_device(policy.config.device),
|
||||
policy.config.use_amp,
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
|
||||
elif policy is None and isinstance(teleop, Teleoperator):
|
||||
action = teleop.get_action()
|
||||
elif policy is None and isinstance(teleop, list):
|
||||
# TODO(pepijn, steven): clean the record loop for use of multiple robots (possibly with pipeline)
|
||||
|
||||
action_names = dataset.features["action"]["names"]
|
||||
policy_action = {f"action.{name}": float(action_values[i]) for i, name in enumerate(action_names)}
|
||||
policy_transition = {
|
||||
TransitionKey.ACTION: policy_action,
|
||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||
}
|
||||
|
||||
elif isinstance(teleop, Teleoperator):
|
||||
act = teleop.get_action()
|
||||
|
||||
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
|
||||
teleop_transition = teleop_action_processor(act)
|
||||
|
||||
elif isinstance(teleop, list):
|
||||
arm_action = teleop_arm.get_action()
|
||||
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
|
||||
|
||||
keyboard_action = teleop_keyboard.get_action()
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_action)
|
||||
|
||||
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
teleop_transition = teleop_action_processor(act)
|
||||
else:
|
||||
logging.info(
|
||||
"No policy or teleoperator provided, skipping action generation."
|
||||
"No policy or teleoperator provided, skipping action generation. "
|
||||
"This is likely to happen when resetting the environment without a teleop device."
|
||||
"The robot won't be at its rest position at the start of the next episode."
|
||||
)
|
||||
continue
|
||||
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset.
|
||||
sent_action = robot.send_action(action)
|
||||
# Applies a pipeline to the action, default is IdentityProcessor
|
||||
# IMPORTANT: action_pipeline.to_output must return a dict suitable for robot.send_action()
|
||||
if policy is not None and policy_transition is not None:
|
||||
robot_action_to_send = robot_action_processor(policy_transition)
|
||||
else:
|
||||
robot_action_to_send = robot_action_processor(teleop_transition)
|
||||
|
||||
# Send action to robot
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
|
||||
# TODO(pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
|
||||
_ = robot.send_action(robot_action_to_send)
|
||||
|
||||
# Write to dataset
|
||||
if dataset is not None:
|
||||
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
|
||||
frame = {**observation_frame, **action_frame}
|
||||
# If transition_to_dataset_frame is provided, use it to merge the transitions.
|
||||
merged = []
|
||||
if obs_transition is not None: # The observation from the robot
|
||||
merged.append(obs_transition)
|
||||
if teleop_transition is not None: # The action from teleop
|
||||
merged.append(teleop_transition)
|
||||
if policy_transition is not None: # The action from policy
|
||||
merged.append(policy_transition)
|
||||
frame = transition_to_dataset_frame(
|
||||
merged if len(merged) > 1 else merged[0], dataset.features
|
||||
) # Convert the observation to the dataset format
|
||||
dataset.add_frame(frame, task=single_task)
|
||||
|
||||
if display_data:
|
||||
log_rerun_data(observation, action)
|
||||
log_rerun_data([obs_transition, teleop_transition or policy_transition])
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
@@ -296,7 +406,15 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Add next.* features that are generated during recording
|
||||
transition_features = {
|
||||
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
|
||||
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
||||
"next.truncated": {"dtype": "bool", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
dataset_features = {**action_features, **obs_features, **transition_features}
|
||||
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
@@ -328,6 +446,18 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
# Load pretrained policy
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor = None
|
||||
postprocessor = None
|
||||
if cfg.policy is not None:
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
"rename_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
if teleop is not None:
|
||||
@@ -345,6 +475,8 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
fps=cfg.dataset.fps,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
|
||||
@@ -45,9 +45,10 @@ from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor import IdentityProcessorStep, RobotProcessorPipeline
|
||||
from lerobot.processor.converters import action_to_transition, transition_to_robot_action
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -83,13 +84,25 @@ class ReplayConfig:
|
||||
dataset: DatasetReplayConfig
|
||||
# Use vocal synthesis to read events.
|
||||
play_sounds: bool = True
|
||||
# Optional processor for actions before sending to robot
|
||||
robot_action_processor: RobotProcessorPipeline | None = None
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
@parser.wrap()
|
||||
def replay(cfg: ReplayConfig):
|
||||
init_logging()
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# Initialize robot action processor with default if not provided
|
||||
robot_action_processor = cfg.robot_action_processor or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=action_to_transition,
|
||||
to_output=transition_to_robot_action, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Reset processor
|
||||
robot_action_processor.reset()
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
@@ -104,7 +117,10 @@ def replay(cfg: ReplayConfig):
|
||||
for i, name in enumerate(dataset.features["action"]["names"]):
|
||||
action[name] = action_array[i]
|
||||
|
||||
robot.send_action(action)
|
||||
# Process action through robot action processor
|
||||
processed_action = robot_action_processor(action)
|
||||
|
||||
robot.send_action(processed_action) # type: ignore[arg-type]
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
busy_wait(1 / dataset.fps - dt_s)
|
||||
|
||||
@@ -14,6 +14,5 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_so100_follower import SO100FollowerConfig, SO100FollowerEndEffectorConfig
|
||||
from .config_so100_follower import SO100FollowerConfig
|
||||
from .so100_follower import SO100Follower
|
||||
from .so100_follower_end_effector import SO100FollowerEndEffector
|
||||
|
||||
@@ -39,35 +39,3 @@ class SO100FollowerConfig(RobotConfig):
|
||||
|
||||
# Set to `True` for backward compatibility with previous policies/dataset
|
||||
use_degrees: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("so100_follower_end_effector")
|
||||
@dataclass
|
||||
class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
|
||||
"""Configuration for the SO100FollowerEndEffector robot."""
|
||||
|
||||
# Path to URDF file for kinematics
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
urdf_path: str | None = None
|
||||
|
||||
# End-effector frame name in the URDF
|
||||
target_frame_name: str = "gripper_frame_link"
|
||||
|
||||
# Default bounds for the end-effector position (in meters)
|
||||
end_effector_bounds: dict[str, list[float]] = field(
|
||||
default_factory=lambda: {
|
||||
"min": [-1.0, -1.0, -1.0], # min x, y, z
|
||||
"max": [1.0, 1.0, 1.0], # max x, y, z
|
||||
}
|
||||
)
|
||||
|
||||
max_gripper_pos: float = 50
|
||||
|
||||
end_effector_step_sizes: dict[str, float] = field(
|
||||
default_factory=lambda: {
|
||||
"x": 0.02,
|
||||
"y": 0.02,
|
||||
"z": 0.02,
|
||||
}
|
||||
)
|
||||
|
||||
460
src/lerobot/robots/so100_follower/robot_kinematic_processor.py
Normal file
460
src/lerobot/robots/so100_follower/robot_kinematic_processor.py
Normal file
@@ -0,0 +1,460 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import (
|
||||
ActionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
EnvTransition,
|
||||
ObservationProcessorStep,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.robots.robot import Robot
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("ee_reference_and_delta")
|
||||
@dataclass
|
||||
class EEReferenceAndDelta(ActionProcessorStep):
|
||||
"""
|
||||
Compute the desired end-effector pose from the target pose and the current pose.
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
"complementary_data.raw_joint_positions": dict,
|
||||
}
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
end_effector_step_sizes: dict
|
||||
motor_names: list[str]
|
||||
use_latched_reference: bool = (
|
||||
True # If True, latch reference on enable; if False, always use current pose
|
||||
)
|
||||
|
||||
reference_ee_pose: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
_prev_enabled: bool = field(default=False, init=False, repr=False)
|
||||
_command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def action(self, action):
|
||||
new_action = action.copy()
|
||||
comp = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
|
||||
# Get joint positions from complimentary data
|
||||
raw = comp.get("raw_joint_positions", None)
|
||||
if raw is None:
|
||||
raise ValueError(
|
||||
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
|
||||
)
|
||||
|
||||
if "reference_joint_positions" in comp:
|
||||
q = comp["reference_joint_positions"]
|
||||
else:
|
||||
q = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
|
||||
|
||||
# Current pose from FK on measured joints
|
||||
t_curr = self.kinematics.forward_kinematics(q)
|
||||
|
||||
enabled = bool(new_action.pop(f"{ACTION}.enabled", 0))
|
||||
tx = float(new_action.pop(f"{ACTION}.target_x", 0.0))
|
||||
ty = float(new_action.pop(f"{ACTION}.target_y", 0.0))
|
||||
tz = float(new_action.pop(f"{ACTION}.target_z", 0.0))
|
||||
wx = float(new_action.pop(f"{ACTION}.target_wx", 0.0))
|
||||
wy = float(new_action.pop(f"{ACTION}.target_wy", 0.0))
|
||||
wz = float(new_action.pop(f"{ACTION}.target_wz", 0.0))
|
||||
|
||||
desired = None
|
||||
|
||||
if enabled:
|
||||
ref = t_curr
|
||||
if self.use_latched_reference:
|
||||
# Latched reference mode: latch reference at the rising edge
|
||||
if not self._prev_enabled or self.reference_ee_pose is None:
|
||||
self.reference_ee_pose = t_curr.copy()
|
||||
ref = self.reference_ee_pose if self.reference_ee_pose is not None else t_curr
|
||||
|
||||
delta_p = np.array(
|
||||
[
|
||||
tx * self.end_effector_step_sizes["x"],
|
||||
ty * self.end_effector_step_sizes["y"],
|
||||
tz * self.end_effector_step_sizes["z"],
|
||||
],
|
||||
dtype=float,
|
||||
)
|
||||
r_abs = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
|
||||
desired = np.eye(4, dtype=float)
|
||||
desired[:3, :3] = ref[:3, :3] @ r_abs
|
||||
desired[:3, 3] = ref[:3, 3] + delta_p
|
||||
|
||||
self._command_when_disabled = desired.copy()
|
||||
else:
|
||||
# While disabled, keep sending the same command to avoid drift.
|
||||
if self._command_when_disabled is None:
|
||||
# If we've never had an enabled command yet, freeze current FK pose once.
|
||||
self._command_when_disabled = t_curr.copy()
|
||||
desired = self._command_when_disabled.copy()
|
||||
|
||||
# Write action fields
|
||||
pos = desired[:3, 3]
|
||||
tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
|
||||
new_action[f"{ACTION}.ee.x"] = float(pos[0])
|
||||
new_action[f"{ACTION}.ee.y"] = float(pos[1])
|
||||
new_action[f"{ACTION}.ee.z"] = float(pos[2])
|
||||
new_action[f"{ACTION}.ee.wx"] = float(tw[0])
|
||||
new_action[f"{ACTION}.ee.wy"] = float(tw[1])
|
||||
new_action[f"{ACTION}.ee.wz"] = float(tw[2])
|
||||
|
||||
self._prev_enabled = enabled
|
||||
return new_action
|
||||
|
||||
def reset(self):
|
||||
self._prev_enabled = False
|
||||
self.reference_ee_pose = None
|
||||
self._command_when_disabled = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features.pop(f"{ACTION}.enabled", None)
|
||||
features.pop(f"{ACTION}.target_x", None)
|
||||
features.pop(f"{ACTION}.target_y", None)
|
||||
features.pop(f"{ACTION}.target_z", None)
|
||||
features.pop(f"{ACTION}.target_wx", None)
|
||||
features.pop(f"{ACTION}.target_wy", None)
|
||||
features.pop(f"{ACTION}.target_wz", None)
|
||||
|
||||
features[f"{ACTION}.ee.x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.ee.y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.ee.z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.ee.wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.ee.wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.ee.wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("ee_bounds_and_safety")
|
||||
@dataclass
|
||||
class EEBoundsAndSafety(ActionProcessorStep):
|
||||
"""
|
||||
Clip the end-effector pose to the bounds and check for jumps.
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
"""
|
||||
|
||||
end_effector_bounds: dict
|
||||
max_ee_step_m: float = 0.05
|
||||
max_ee_twist_step_rad: float = 0.20
|
||||
_last_pos: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def action(self, act: dict) -> dict:
|
||||
x = act.get(f"{ACTION}.ee.x", None)
|
||||
y = act.get(f"{ACTION}.ee.y", None)
|
||||
z = act.get(f"{ACTION}.ee.z", None)
|
||||
wx = act.get(f"{ACTION}.ee.wx", None)
|
||||
wy = act.get(f"{ACTION}.ee.wy", None)
|
||||
wz = act.get(f"{ACTION}.ee.wz", None)
|
||||
|
||||
if None in (x, y, z, wx, wy, wz):
|
||||
raise ValueError(
|
||||
"Missing required end-effector pose components: x, y, z, wx, wy, wz must all be present in action"
|
||||
)
|
||||
|
||||
pos = np.array([x, y, z], dtype=float)
|
||||
twist = np.array([wx, wy, wz], dtype=float)
|
||||
|
||||
# Clip position
|
||||
pos = np.clip(pos, self.end_effector_bounds["min"], self.end_effector_bounds["max"])
|
||||
|
||||
# Check for jumps in position
|
||||
if self._last_pos is not None:
|
||||
dpos = pos - self._last_pos
|
||||
n = float(np.linalg.norm(dpos))
|
||||
if n > self.max_ee_step_m and n > 0:
|
||||
pos = self._last_pos + dpos * (self.max_ee_step_m / n)
|
||||
raise ValueError(f"EE jump {n:.3f}m > {self.max_ee_step_m}m")
|
||||
|
||||
self._last_pos = pos
|
||||
self._last_twist = twist
|
||||
|
||||
act[f"{ACTION}.ee.x"] = float(pos[0])
|
||||
act[f"{ACTION}.ee.y"] = float(pos[1])
|
||||
act[f"{ACTION}.ee.z"] = float(pos[2])
|
||||
act[f"{ACTION}.ee.wx"] = float(twist[0])
|
||||
act[f"{ACTION}.ee.wy"] = float(twist[1])
|
||||
act[f"{ACTION}.ee.wz"] = float(twist[2])
|
||||
return act
|
||||
|
||||
def reset(self):
|
||||
self._last_pos = None
|
||||
self._last_twist = None
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# check if features as f"{ACTION}.ee.{x,y,z,wx,wy,wz}"
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints")
|
||||
@dataclass
|
||||
class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
"""
|
||||
Compute the desired joint positions from the desired end-effector pose.
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
"complementary_data.raw_joint_positions": dict,
|
||||
}
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.joint_name_1.pos": float,
|
||||
"action.joint_name_2.pos": float,
|
||||
...
|
||||
"action.joint_name_n.pos": float,
|
||||
}
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
motor_names: list[str]
|
||||
q_curr: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
initial_guess_current_joints: bool = True
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
new_transition = transition.copy()
|
||||
act = new_transition.get(TransitionKey.ACTION) or {}
|
||||
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
x = act.get(f"{ACTION}.ee.x", None)
|
||||
y = act.get(f"{ACTION}.ee.y", None)
|
||||
z = act.get(f"{ACTION}.ee.z", None)
|
||||
wx = act.get(f"{ACTION}.ee.wx", None)
|
||||
wy = act.get(f"{ACTION}.ee.wy", None)
|
||||
wz = act.get(f"{ACTION}.ee.wz", None)
|
||||
|
||||
if None in (x, y, z, wx, wy, wz):
|
||||
return new_transition
|
||||
|
||||
# Get joint positions from complimentary data
|
||||
raw = comp.get("raw_joint_positions", None)
|
||||
if raw is None:
|
||||
raise ValueError(
|
||||
"raw_joint_positions is not in complementary data and is required for EEReferenceAndDelta"
|
||||
)
|
||||
|
||||
if self.initial_guess_current_joints: # Use current joints as initial guess
|
||||
self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
|
||||
else: # Use previous ik solution as initial guess
|
||||
if self.q_curr is None:
|
||||
self.q_curr = np.array([float(raw[n]) for n in self.motor_names], dtype=float)
|
||||
|
||||
# Build desired 4x4 transform from pos + rotvec (twist)
|
||||
t_des = np.eye(4, dtype=float)
|
||||
t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
|
||||
t_des[:3, 3] = [x, y, z]
|
||||
|
||||
# Compute inverse kinematics
|
||||
q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des)
|
||||
self.q_curr = q_target
|
||||
|
||||
new_act = dict(act)
|
||||
for i, name in enumerate(self.motor_names):
|
||||
if name == "gripper":
|
||||
# TODO(pepijn): Investigate if this is correct
|
||||
# Do we want an observation key in the action field?
|
||||
new_act[f"{ACTION}.gripper.pos"] = float(raw["gripper"])
|
||||
else:
|
||||
new_act[f"{ACTION}.{name}.pos"] = float(q_target[i])
|
||||
new_transition[TransitionKey.ACTION] = new_act
|
||||
if not self.initial_guess_current_joints:
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features[f"{ACTION}.gripper.pos"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for name in self.motor_names:
|
||||
features[f"{ACTION}.{name}.pos"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
|
||||
return features
|
||||
|
||||
def reset(self):
|
||||
self.q_curr = None
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("gripper_velocity_to_joint")
|
||||
@dataclass
|
||||
class GripperVelocityToJoint(ProcessorStep):
|
||||
"""
|
||||
Convert the gripper velocity to a joint velocity.
|
||||
|
||||
Input ACTION keys:
|
||||
{
|
||||
"action.gripper": float,
|
||||
}
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.gripper.pos": float,
|
||||
}
|
||||
"""
|
||||
|
||||
motor_names: list[str]
|
||||
speed_factor: float = 20.0
|
||||
clip_min: float = 0.0
|
||||
clip_max: float = 100.0
|
||||
discrete_gripper: bool = False
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
new_transition = transition.copy()
|
||||
obs = new_transition.get(TransitionKey.OBSERVATION) or {}
|
||||
act = new_transition.get(TransitionKey.ACTION) or {}
|
||||
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
if f"{ACTION}.gripper" not in act:
|
||||
raise ValueError(f"Required action key '{ACTION}.gripper' not found in transition")
|
||||
|
||||
if "gripper" not in self.motor_names:
|
||||
raise ValueError(
|
||||
f"Required motor name 'gripper' not found in self.motor_names={self.motor_names}"
|
||||
)
|
||||
|
||||
if self.discrete_gripper:
|
||||
# Discrete gripper actions are in [0, 1, 2]
|
||||
# 0: open, 1: close, 2: stay
|
||||
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
|
||||
gripper_action = act.get(f"{ACTION}.gripper", 1.0)
|
||||
gripper_action = gripper_action - 1.0
|
||||
gripper_action *= self.clip_max
|
||||
act[f"{ACTION}.gripper"] = gripper_action
|
||||
|
||||
# Get current gripper position from complementary data
|
||||
raw = comp.get("raw_joint_positions") or {}
|
||||
curr_pos = float(raw.get("gripper"))
|
||||
|
||||
# Compute desired gripper velocity
|
||||
u = float(act.get(f"{ACTION}.gripper", 0.0))
|
||||
delta = u * float(self.speed_factor)
|
||||
gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max))
|
||||
|
||||
new_act = dict(act)
|
||||
new_act[f"{ACTION}.gripper.pos"] = gripper_pos
|
||||
new_act.pop(f"{ACTION}.gripper", None)
|
||||
new_transition[TransitionKey.ACTION] = new_act
|
||||
|
||||
obs[f"{OBS_STATE}.gripper.pos"] = curr_pos
|
||||
new_transition[TransitionKey.OBSERVATION] = obs
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features.pop(f"{ACTION}.gripper", None)
|
||||
features[f"{ACTION}.gripper.pos"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{OBS_STATE}.gripper.pos"] = PolicyFeature(type=FeatureType.STATE, shape=(1,))
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee")
|
||||
@dataclass
|
||||
class ForwardKinematicsJointsToEE(ObservationProcessorStep):
|
||||
"""
|
||||
Compute the end-effector pose from the joint positions.
|
||||
|
||||
Input OBSERVATION keys:
|
||||
{
|
||||
"observation.state.{joint_name_1,joint_name_2,...,joint_name_n}.pos": float,
|
||||
}
|
||||
|
||||
Output OBSERVATION keys:
|
||||
{
|
||||
"observation.state.ee.{x,y,z,wx,wy,wz}" : float
|
||||
}
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
motor_names: list[str]
|
||||
|
||||
def observation(self, obs: dict) -> dict:
|
||||
if not all(f"{OBS_STATE}.{n}.pos" in obs for n in self.motor_names):
|
||||
raise ValueError(f"Missing required joint positions for motors: {self.motor_names}")
|
||||
|
||||
q = np.array([obs[f"{OBS_STATE}.{n}.pos"] for n in self.motor_names], dtype=float)
|
||||
t = self.kinematics.forward_kinematics(q)
|
||||
pos = t[:3, 3]
|
||||
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
|
||||
|
||||
obs[f"{OBS_STATE}.ee.x"] = float(pos[0])
|
||||
obs[f"{OBS_STATE}.ee.y"] = float(pos[1])
|
||||
obs[f"{OBS_STATE}.ee.z"] = float(pos[2])
|
||||
obs[f"{OBS_STATE}.ee.wx"] = float(tw[0])
|
||||
obs[f"{OBS_STATE}.ee.wy"] = float(tw[1])
|
||||
obs[f"{OBS_STATE}.ee.wz"] = float(tw[2])
|
||||
return obs
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
# We specify the dataset features of this step that we want to be stored in the dataset
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz"]:
|
||||
features[f"{OBS_STATE}.ee.{k}"] = PolicyFeature(type=FeatureType.STATE, shape=(1,))
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("add_robot_observation")
|
||||
@dataclass
|
||||
class AddRobotObservationAsComplimentaryData(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
Read the robot's current observation and insert it into the transition as complementary data.
|
||||
|
||||
- Joint positions are added under complementary_data["raw_joint_positions"] as a dict:
|
||||
{ "<motor_name>": <float position>, ... }
|
||||
"""
|
||||
|
||||
robot: Robot
|
||||
|
||||
def complementary_data(self, comp: dict | None) -> dict:
|
||||
new_comp = dict(comp)
|
||||
obs = (
|
||||
self.robot.get_observation()
|
||||
) # todo(steven): why not self.trtansition.get(TransitionKey.OBSERVATION)?
|
||||
|
||||
new_comp["raw_joint_positions"] = {
|
||||
k.removesuffix(".pos"): float(v)
|
||||
for k, v in obs.items()
|
||||
if isinstance(k, str) and k.endswith(".pos")
|
||||
}
|
||||
return new_comp
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
@@ -1,200 +0,0 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.cameras import make_cameras_from_configs
|
||||
from lerobot.errors import DeviceNotConnectedError
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.motors import Motor, MotorNormMode
|
||||
from lerobot.motors.feetech import FeetechMotorsBus
|
||||
|
||||
from . import SO100Follower
|
||||
from .config_so100_follower import SO100FollowerEndEffectorConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SO100FollowerEndEffector(SO100Follower):
|
||||
"""
|
||||
SO100Follower robot with end-effector space control.
|
||||
|
||||
This robot inherits from SO100Follower but transforms actions from
|
||||
end-effector space to joint space before sending them to the motors.
|
||||
"""
|
||||
|
||||
config_class = SO100FollowerEndEffectorConfig
|
||||
name = "so100_follower_end_effector"
|
||||
|
||||
def __init__(self, config: SO100FollowerEndEffectorConfig):
|
||||
super().__init__(config)
|
||||
self.bus = FeetechMotorsBus(
|
||||
port=self.config.port,
|
||||
motors={
|
||||
"shoulder_pan": Motor(1, "sts3215", MotorNormMode.DEGREES),
|
||||
"shoulder_lift": Motor(2, "sts3215", MotorNormMode.DEGREES),
|
||||
"elbow_flex": Motor(3, "sts3215", MotorNormMode.DEGREES),
|
||||
"wrist_flex": Motor(4, "sts3215", MotorNormMode.DEGREES),
|
||||
"wrist_roll": Motor(5, "sts3215", MotorNormMode.DEGREES),
|
||||
"gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
},
|
||||
calibration=self.calibration,
|
||||
)
|
||||
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
self.config = config
|
||||
|
||||
# Initialize the kinematics module for the so100 robot
|
||||
if self.config.urdf_path is None:
|
||||
raise ValueError(
|
||||
"urdf_path must be provided in the configuration for end-effector control. "
|
||||
"Please set urdf_path in your SO100FollowerEndEffectorConfig."
|
||||
)
|
||||
|
||||
self.kinematics = RobotKinematics(
|
||||
urdf_path=self.config.urdf_path,
|
||||
target_frame_name=self.config.target_frame_name,
|
||||
)
|
||||
|
||||
# Store the bounds for end-effector position
|
||||
self.end_effector_bounds = self.config.end_effector_bounds
|
||||
|
||||
self.current_ee_pos = None
|
||||
self.current_joint_pos = None
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, Any]:
|
||||
"""
|
||||
Define action features for end-effector control.
|
||||
Returns dictionary with dtype, shape, and names.
|
||||
"""
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (4,),
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3},
|
||||
}
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Transform action from end-effector space to joint space and send to motors.
|
||||
|
||||
Args:
|
||||
action: Dictionary with keys 'delta_x', 'delta_y', 'delta_z' for end-effector control
|
||||
or a numpy array with [delta_x, delta_y, delta_z]
|
||||
|
||||
Returns:
|
||||
The joint-space action that was sent to the motors
|
||||
"""
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Convert action to numpy array if not already
|
||||
if isinstance(action, dict):
|
||||
if all(k in action for k in ["delta_x", "delta_y", "delta_z"]):
|
||||
delta_ee = np.array(
|
||||
[
|
||||
action["delta_x"] * self.config.end_effector_step_sizes["x"],
|
||||
action["delta_y"] * self.config.end_effector_step_sizes["y"],
|
||||
action["delta_z"] * self.config.end_effector_step_sizes["z"],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
if "gripper" not in action:
|
||||
action["gripper"] = [1.0]
|
||||
action = np.append(delta_ee, action["gripper"])
|
||||
else:
|
||||
logger.warning(
|
||||
f"Expected action keys 'delta_x', 'delta_y', 'delta_z', got {list(action.keys())}"
|
||||
)
|
||||
action = np.zeros(4, dtype=np.float32)
|
||||
|
||||
if self.current_joint_pos is None:
|
||||
# Read current joint positions
|
||||
current_joint_pos = self.bus.sync_read("Present_Position")
|
||||
self.current_joint_pos = np.array([current_joint_pos[name] for name in self.bus.motors])
|
||||
|
||||
# Calculate current end-effector position using forward kinematics
|
||||
if self.current_ee_pos is None:
|
||||
self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos)
|
||||
|
||||
# Set desired end-effector position by adding delta
|
||||
desired_ee_pos = np.eye(4)
|
||||
desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation
|
||||
|
||||
# Add delta to position and clip to bounds
|
||||
desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + action[:3]
|
||||
if self.end_effector_bounds is not None:
|
||||
desired_ee_pos[:3, 3] = np.clip(
|
||||
desired_ee_pos[:3, 3],
|
||||
self.end_effector_bounds["min"],
|
||||
self.end_effector_bounds["max"],
|
||||
)
|
||||
|
||||
# Compute inverse kinematics to get joint positions
|
||||
target_joint_values_in_degrees = self.kinematics.inverse_kinematics(
|
||||
self.current_joint_pos, desired_ee_pos
|
||||
)
|
||||
|
||||
# Create joint space action dictionary
|
||||
joint_action = {
|
||||
f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys())
|
||||
}
|
||||
|
||||
# Handle gripper separately if included in action
|
||||
# Gripper delta action is in the range 0 - 2,
|
||||
# We need to shift the action to the range -1, 1 so that we can expand it to -Max_gripper_pos, Max_gripper_pos
|
||||
joint_action["gripper.pos"] = np.clip(
|
||||
self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos,
|
||||
5,
|
||||
self.config.max_gripper_pos,
|
||||
)
|
||||
|
||||
self.current_ee_pos = desired_ee_pos.copy()
|
||||
self.current_joint_pos = target_joint_values_in_degrees.copy()
|
||||
self.current_joint_pos[-1] = joint_action["gripper.pos"]
|
||||
|
||||
# Send joint space action to parent class
|
||||
return super().send_action(joint_action)
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()}
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
def reset(self):
|
||||
self.current_ee_pos = None
|
||||
self.current_joint_pos = None
|
||||
@@ -29,10 +29,6 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
from .so100_follower import SO100Follower
|
||||
|
||||
return SO100Follower(config)
|
||||
elif config.type == "so100_follower_end_effector":
|
||||
from .so100_follower import SO100FollowerEndEffector
|
||||
|
||||
return SO100FollowerEndEffector(config)
|
||||
elif config.type == "so101_follower":
|
||||
from .so101_follower import SO101Follower
|
||||
|
||||
@@ -69,6 +65,7 @@ def make_robot_from_config(config: RobotConfig) -> Robot:
|
||||
raise ValueError(config.type)
|
||||
|
||||
|
||||
# TODO(pepijn): Move to pipeline step to make sure we don't have to do this in the robot code and send action to robot is clean for use in dataset
|
||||
def ensure_safe_goal_position(
|
||||
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float]
|
||||
) -> dict[str, float]:
|
||||
|
||||
@@ -62,9 +62,16 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.robots import so100_follower # noqa: F401
|
||||
from lerobot.scripts.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.scripts.rl.gym_manipulator import (
|
||||
create_transition,
|
||||
make_processors,
|
||||
make_robot_env,
|
||||
step_env_and_process_transition,
|
||||
)
|
||||
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.transport.utils import (
|
||||
bytes_to_state_dict,
|
||||
@@ -91,7 +98,6 @@ from lerobot.utils.utils import (
|
||||
|
||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||
|
||||
|
||||
#################################################
|
||||
# Main entry point #
|
||||
#################################################
|
||||
@@ -236,7 +242,8 @@ def act_with_policy(
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env = make_robot_env(cfg=cfg.env)
|
||||
online_env, teleop_device = make_robot_env(cfg=cfg.env)
|
||||
env_processor, action_processor = make_processors(online_env, teleop_device, cfg.env, cfg.policy.device)
|
||||
|
||||
set_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
@@ -257,6 +264,12 @@ def act_with_policy(
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
@@ -274,45 +287,71 @@ def act_with_policy(
|
||||
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||
return
|
||||
|
||||
if interaction_step >= cfg.policy.online_step_before_learning:
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = policy_timer.fps_last
|
||||
observation = {
|
||||
k: v for k, v in transition[TransitionKey.OBSERVATION].items() if k in cfg.policy.input_features
|
||||
}
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
# Extract observation from transition for policy
|
||||
action = policy.select_action(batch=observation)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
else:
|
||||
action = online_env.action_space.sample()
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
# Use the new step function
|
||||
new_transition = step_env_and_process_transition(
|
||||
env=online_env,
|
||||
transition=transition,
|
||||
action=action,
|
||||
env_processor=env_processor,
|
||||
action_processor=action_processor,
|
||||
)
|
||||
|
||||
# Extract values from processed transition
|
||||
next_observation = {
|
||||
k: v
|
||||
for k, v in new_transition[TransitionKey.OBSERVATION].items()
|
||||
if k in cfg.policy.input_features
|
||||
}
|
||||
|
||||
# Teleop action is the action that was executed in the environment
|
||||
# It is either the action from the teleop device or the action from the policy
|
||||
executed_action = new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"]
|
||||
|
||||
reward = new_transition[TransitionKey.REWARD]
|
||||
done = new_transition.get(TransitionKey.DONE, False)
|
||||
truncated = new_transition.get(TransitionKey.TRUNCATED, False)
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
# Increment total steps counter for intervention rate
|
||||
episode_total_steps += 1
|
||||
|
||||
# NOTE: We override the action if the intervention is True, because the action applied is the intervention action
|
||||
if "is_intervention" in info and info["is_intervention"]:
|
||||
# NOTE: The action space for demonstration before hand is with the full action space
|
||||
# but sometimes for example we want to deactivate the gripper
|
||||
action = info["action_intervention"]
|
||||
# Check for intervention from transition info
|
||||
intervention_info = new_transition[TransitionKey.INFO]
|
||||
if intervention_info.get(TeleopEvents.IS_INTERVENTION, False):
|
||||
episode_intervention = True
|
||||
# Increment intervention steps counter
|
||||
episode_intervention_steps += 1
|
||||
|
||||
complementary_info = {
|
||||
"discrete_penalty": torch.tensor(
|
||||
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
|
||||
),
|
||||
}
|
||||
# Create transition for learner (convert to old format)
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
state=obs,
|
||||
action=action,
|
||||
state=observation,
|
||||
action=executed_action,
|
||||
reward=reward,
|
||||
next_state=next_obs,
|
||||
next_state=next_observation,
|
||||
done=done,
|
||||
truncated=truncated, # TODO: (azouitine) Handle truncation properly
|
||||
complementary_info=info,
|
||||
truncated=truncated,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
)
|
||||
# assign obs to the next obs and continue the rollout
|
||||
obs = next_obs
|
||||
|
||||
# Update transition for next iteration
|
||||
transition = new_transition
|
||||
|
||||
if done or truncated:
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
@@ -347,12 +386,20 @@ def act_with_policy(
|
||||
)
|
||||
)
|
||||
|
||||
# Reset intervention counters
|
||||
# Reset intervention counters and environment
|
||||
sum_reward_episode = 0.0
|
||||
episode_intervention = False
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
# Reset environment and processors
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -75,6 +75,7 @@ from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.robots import so100_follower # noqa: F401
|
||||
from lerobot.scripts.rl import learner_service
|
||||
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.transport import services_pb2_grpc
|
||||
from lerobot.transport.utils import (
|
||||
MAX_MESSAGE_SIZE,
|
||||
@@ -1048,10 +1049,8 @@ def get_observation_features(
|
||||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True)
|
||||
next_observation_features = policy.actor.encoder.get_cached_image_features(
|
||||
next_observations, normalize=True
|
||||
)
|
||||
observation_features = policy.actor.encoder.get_cached_image_features(observations)
|
||||
next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
@@ -1176,7 +1175,7 @@ def process_transitions(
|
||||
|
||||
# Add to offline buffer if it's an intervention
|
||||
if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
|
||||
"is_intervention"
|
||||
TeleopEvents.IS_INTERVENTION
|
||||
):
|
||||
offline_replay_buffer.add(**transition)
|
||||
|
||||
|
||||
@@ -26,12 +26,13 @@ from torch.optim import Optimizer
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.utils import cycle
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
@@ -140,6 +141,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
|
||||
)
|
||||
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
@@ -149,6 +153,10 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
if cfg.resume:
|
||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||
preprocessor.from_pretrained(cfg.checkpoint_path, config_filename=f"{PREPROCESSOR_DEFAULT_NAME}.json")
|
||||
postprocessor.from_pretrained(
|
||||
cfg.checkpoint_path, config_filename=f"{POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
)
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
@@ -203,12 +211,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
batch = preprocessor(batch)
|
||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
|
||||
|
||||
train_tracker, output_dict = update_policy(
|
||||
train_tracker,
|
||||
policy,
|
||||
@@ -240,7 +245,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
|
||||
save_checkpoint(
|
||||
checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
@@ -284,6 +291,8 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
if cfg.policy.push_to_hub:
|
||||
policy.push_model_to_hub(cfg)
|
||||
preprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
postprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -56,11 +56,17 @@ import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from pprint import pformat
|
||||
|
||||
import draccus
|
||||
import rerun as rr
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.processor import IdentityProcessorStep, RobotProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
action_to_transition,
|
||||
observation_to_transition,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -97,39 +103,84 @@ class TeleoperateConfig:
|
||||
teleop_time_s: float | None = None
|
||||
# Display all cameras on screen
|
||||
display_data: bool = False
|
||||
# Optional processors for data transformation
|
||||
teleop_action_processor: RobotProcessorPipeline | None = None # runs after teleop
|
||||
robot_action_processor: RobotProcessorPipeline | None = None # runs before robot
|
||||
robot_observation_processor: RobotProcessorPipeline | None = None # runs after robot
|
||||
|
||||
|
||||
def teleop_loop(
|
||||
teleop: Teleoperator, robot: Robot, fps: int, display_data: bool = False, duration: float | None = None
|
||||
teleop: Teleoperator,
|
||||
robot: Robot,
|
||||
fps: int,
|
||||
display_data: bool = False,
|
||||
duration: float | None = None,
|
||||
teleop_action_processor: RobotProcessorPipeline | None = None,
|
||||
robot_action_processor: RobotProcessorPipeline | None = None,
|
||||
robot_observation_processor: RobotProcessorPipeline | None = None,
|
||||
):
|
||||
# Initialize processors with defaults if not provided
|
||||
teleop_action_processor = teleop_action_processor or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()], to_transition=action_to_transition, to_output=lambda tr: tr
|
||||
)
|
||||
robot_action_processor = robot_action_processor or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=lambda tr: tr,
|
||||
to_output=transition_to_robot_action, # type: ignore[arg-type]
|
||||
)
|
||||
robot_observation_processor = robot_observation_processor or RobotProcessorPipeline(
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=lambda tr: tr,
|
||||
)
|
||||
|
||||
# Reset processors
|
||||
teleop_action_processor.reset()
|
||||
robot_action_processor.reset()
|
||||
robot_observation_processor.reset()
|
||||
|
||||
display_len = max(len(key) for key in robot.action_features)
|
||||
start = time.perf_counter()
|
||||
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
action = teleop.get_action()
|
||||
if display_data:
|
||||
observation = robot.get_observation()
|
||||
log_rerun_data(observation, action)
|
||||
|
||||
robot.send_action(action)
|
||||
# Get teleop action
|
||||
raw_action = teleop.get_action()
|
||||
|
||||
# Process teleop action through pipeline
|
||||
teleop_transition = teleop_action_processor(raw_action)
|
||||
|
||||
# Process action for robot through pipeline
|
||||
robot_action_to_send = robot_action_processor(teleop_transition)
|
||||
|
||||
# Send processed action to robot (robot_action_processor.to_output should return dict[str, Any])
|
||||
robot.send_action(robot_action_to_send) # type: ignore[arg-type]
|
||||
|
||||
if display_data:
|
||||
# Get robot observation
|
||||
obs = robot.get_observation()
|
||||
# Process robot observation through pipeline
|
||||
obs_transition = robot_observation_processor(obs)
|
||||
log_rerun_data([obs_transition, teleop_transition])
|
||||
|
||||
print("\n" + "-" * (display_len + 10))
|
||||
print(f"{'NAME':<{display_len}} | {'NORM':>7}")
|
||||
# Display the final robot action that was sent
|
||||
for motor, value in robot_action_to_send.items():
|
||||
print(f"{motor:<{display_len}} | {value:>7.2f}")
|
||||
move_cursor_up(len(robot_action_to_send) + 5)
|
||||
|
||||
dt_s = time.perf_counter() - loop_start
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
loop_s = time.perf_counter() - loop_start
|
||||
|
||||
print("\n" + "-" * (display_len + 10))
|
||||
print(f"{'NAME':<{display_len}} | {'NORM':>7}")
|
||||
for motor, value in action.items():
|
||||
print(f"{motor:<{display_len}} | {value:>7.2f}")
|
||||
print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
|
||||
|
||||
if duration is not None and time.perf_counter() - start >= duration:
|
||||
return
|
||||
|
||||
move_cursor_up(len(action) + 5)
|
||||
|
||||
|
||||
@draccus.wrap()
|
||||
@parser.wrap()
|
||||
def teleoperate(cfg: TeleoperateConfig):
|
||||
init_logging()
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
@@ -143,7 +194,16 @@ def teleoperate(cfg: TeleoperateConfig):
|
||||
robot.connect()
|
||||
|
||||
try:
|
||||
teleop_loop(teleop, robot, cfg.fps, display_data=cfg.display_data, duration=cfg.teleop_time_s)
|
||||
teleop_loop(
|
||||
teleop=teleop,
|
||||
robot=robot,
|
||||
fps=cfg.fps,
|
||||
display_data=cfg.display_data,
|
||||
duration=cfg.teleop_time_s,
|
||||
teleop_action_processor=cfg.teleop_action_processor,
|
||||
robot_action_processor=cfg.robot_action_processor,
|
||||
robot_observation_processor=cfg.robot_observation_processor,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
|
||||
@@ -16,4 +16,4 @@
|
||||
|
||||
from .config import TeleoperatorConfig
|
||||
from .teleoperator import Teleoperator
|
||||
from .utils import make_teleoperator_from_config
|
||||
from .utils import TeleopEvents, make_teleoperator_from_config
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
import logging
|
||||
|
||||
from ..utils import TeleopEvents
|
||||
|
||||
|
||||
class InputController:
|
||||
"""Base class for input controllers that generate motion deltas."""
|
||||
@@ -134,10 +136,10 @@ class KeyboardController(InputController):
|
||||
return False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = True
|
||||
self.episode_end_status = "success"
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = True
|
||||
self.episode_end_status = "failure"
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -255,13 +257,13 @@ class GamepadController(InputController):
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.JOYBUTTONDOWN:
|
||||
if event.button == 3:
|
||||
self.episode_end_status = "success"
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
# A button (1) for failure
|
||||
elif event.button == 1:
|
||||
self.episode_end_status = "failure"
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
# X button (0) for rerecord
|
||||
elif event.button == 0:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
|
||||
# RB button (6) for closing gripper
|
||||
elif event.button == 6:
|
||||
@@ -451,11 +453,11 @@ class GamepadControllerHID(InputController):
|
||||
# Check if X/Square button (bit 5) is pressed for failure
|
||||
# Check if A/Cross button (bit 4) is pressed for rerecording
|
||||
if buttons & 1 << 7:
|
||||
self.episode_end_status = "success"
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif buttons & 1 << 5:
|
||||
self.episode_end_status = "failure"
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
elif buttons & 1 << 4:
|
||||
self.episode_end_status = "rerecord_episode"
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from typing import Any
|
||||
import numpy as np
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from ..utils import TeleopEvents
|
||||
from .configuration_gamepad import GamepadTeleopConfig
|
||||
|
||||
|
||||
@@ -93,9 +94,9 @@ class GamepadTeleop(Teleoperator):
|
||||
gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)
|
||||
|
||||
action_dict = {
|
||||
"delta_x": gamepad_action[0],
|
||||
"delta_y": gamepad_action[1],
|
||||
"delta_z": gamepad_action[2],
|
||||
"action.delta_x": gamepad_action[0],
|
||||
"action.delta_y": gamepad_action[1],
|
||||
"action.delta_z": gamepad_action[2],
|
||||
}
|
||||
|
||||
# Default gripper action is to stay
|
||||
@@ -107,6 +108,48 @@ class GamepadTeleop(Teleoperator):
|
||||
|
||||
return action_dict
|
||||
|
||||
def get_teleop_events(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get extra control events from the gamepad such as intervention status,
|
||||
episode termination, success indicators, etc.
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- is_intervention: bool - Whether human is currently intervening
|
||||
- terminate_episode: bool - Whether to terminate the current episode
|
||||
- success: bool - Whether the episode was successful
|
||||
- rerecord_episode: bool - Whether to rerecord the episode
|
||||
"""
|
||||
if self.gamepad is None:
|
||||
return {
|
||||
TeleopEvents.IS_INTERVENTION: False,
|
||||
TeleopEvents.TERMINATE_EPISODE: False,
|
||||
TeleopEvents.SUCCESS: False,
|
||||
TeleopEvents.RERECORD_EPISODE: False,
|
||||
}
|
||||
|
||||
# Update gamepad state to get fresh inputs
|
||||
self.gamepad.update()
|
||||
|
||||
# Check if intervention is active
|
||||
is_intervention = self.gamepad.should_intervene()
|
||||
|
||||
# Get episode end status
|
||||
episode_end_status = self.gamepad.get_episode_end_status()
|
||||
terminate_episode = episode_end_status in [
|
||||
TeleopEvents.RERECORD_EPISODE,
|
||||
TeleopEvents.FAILURE,
|
||||
]
|
||||
success = episode_end_status == TeleopEvents.SUCCESS
|
||||
rerecord_episode = episode_end_status == TeleopEvents.RERECORD_EPISODE
|
||||
|
||||
return {
|
||||
TeleopEvents.IS_INTERVENTION: is_intervention,
|
||||
TeleopEvents.TERMINATE_EPISODE: terminate_episode,
|
||||
TeleopEvents.SUCCESS: success,
|
||||
TeleopEvents.RERECORD_EPISODE: rerecord_episode,
|
||||
}
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from the gamepad."""
|
||||
if self.gamepad is not None:
|
||||
|
||||
@@ -24,6 +24,7 @@ from typing import Any
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from ..utils import TeleopEvents
|
||||
from .configuration_keyboard import KeyboardEndEffectorTeleopConfig, KeyboardTeleopConfig
|
||||
|
||||
PYNPUT_AVAILABLE = True
|
||||
@@ -167,25 +168,15 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (4,),
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3},
|
||||
"names": {"action.delta_x": 0, "action.delta_y": 1, "action.delta_z": 2, "action.gripper": 3},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (3,),
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2},
|
||||
"names": {"action.delta_x": 0, "action.delta_y": 1, "action.delta_z": 2},
|
||||
}
|
||||
|
||||
def _on_press(self, key):
|
||||
if hasattr(key, "char"):
|
||||
key = key.char
|
||||
self.event_queue.put((key, True))
|
||||
|
||||
def _on_release(self, key):
|
||||
if hasattr(key, "char"):
|
||||
key = key.char
|
||||
self.event_queue.put((key, False))
|
||||
|
||||
def get_action(self) -> dict[str, Any]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
@@ -226,12 +217,75 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
|
||||
self.current_pressed.clear()
|
||||
|
||||
action_dict = {
|
||||
"delta_x": delta_x,
|
||||
"delta_y": delta_y,
|
||||
"delta_z": delta_z,
|
||||
"action.delta_x": delta_x,
|
||||
"action.delta_y": delta_y,
|
||||
"action.delta_z": delta_z,
|
||||
}
|
||||
|
||||
if self.config.use_gripper:
|
||||
action_dict["gripper"] = gripper_action
|
||||
|
||||
return action_dict
|
||||
|
||||
def get_teleop_events(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get extra control events from the keyboard such as intervention status,
|
||||
episode termination, success indicators, etc.
|
||||
|
||||
Keyboard mappings:
|
||||
- Any movement keys pressed = intervention active
|
||||
- 's' key = success (terminate episode successfully)
|
||||
- 'r' key = rerecord episode (terminate and rerecord)
|
||||
- 'q' key = quit episode (terminate without success)
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- is_intervention: bool - Whether human is currently intervening
|
||||
- terminate_episode: bool - Whether to terminate the current episode
|
||||
- success: bool - Whether the episode was successful
|
||||
- rerecord_episode: bool - Whether to rerecord the episode
|
||||
"""
|
||||
if not self.is_connected:
|
||||
return {
|
||||
TeleopEvents.IS_INTERVENTION: False,
|
||||
TeleopEvents.TERMINATE_EPISODE: False,
|
||||
TeleopEvents.SUCCESS: False,
|
||||
TeleopEvents.RERECORD_EPISODE: False,
|
||||
}
|
||||
|
||||
# Check if any movement keys are currently pressed (indicates intervention)
|
||||
movement_keys = [
|
||||
keyboard.Key.up,
|
||||
keyboard.Key.down,
|
||||
keyboard.Key.left,
|
||||
keyboard.Key.right,
|
||||
keyboard.Key.shift,
|
||||
keyboard.Key.shift_r,
|
||||
keyboard.Key.ctrl_r,
|
||||
keyboard.Key.ctrl_l,
|
||||
]
|
||||
is_intervention = any(self.current_pressed.get(key, False) for key in movement_keys)
|
||||
|
||||
# Check for episode control commands from misc_keys_queue
|
||||
terminate_episode = False
|
||||
success = False
|
||||
rerecord_episode = False
|
||||
|
||||
# Process any pending misc keys
|
||||
while not self.misc_keys_queue.empty():
|
||||
key = self.misc_keys_queue.get_nowait()
|
||||
if key == "s":
|
||||
success = True
|
||||
elif key == "r":
|
||||
terminate_episode = True
|
||||
rerecord_episode = True
|
||||
elif key == "q":
|
||||
terminate_episode = True
|
||||
success = False
|
||||
|
||||
return {
|
||||
TeleopEvents.IS_INTERVENTION: is_intervention,
|
||||
TeleopEvents.TERMINATE_EPISODE: terminate_episode,
|
||||
TeleopEvents.SUCCESS: success,
|
||||
TeleopEvents.RERECORD_EPISODE: rerecord_episode,
|
||||
}
|
||||
|
||||
18
src/lerobot/teleoperators/phone/__init__.py
Normal file
18
src/lerobot/teleoperators/phone/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_phone import PhoneConfig
|
||||
from .phone import Phone
|
||||
36
src/lerobot/teleoperators/phone/config_phone.py
Normal file
36
src/lerobot/teleoperators/phone/config_phone.py
Normal file
@@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
class PhoneOS(Enum):
|
||||
ANDROID = "android"
|
||||
IOS = "ios"
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("phone")
|
||||
@dataclass
|
||||
class PhoneConfig(TeleoperatorConfig):
|
||||
phone_os: PhoneOS = PhoneOS.IOS
|
||||
camera_offset = np.array(
|
||||
[0.0, -0.02, 0.04]
|
||||
) # iPhone 14 Pro camera is 2cm off center and 4cm above center
|
||||
97
src/lerobot/teleoperators/phone/phone_processor.py
Normal file
97
src/lerobot/teleoperators/phone/phone_processor.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION
|
||||
from lerobot.processor import ActionProcessorStep, ProcessorStepRegistry
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneOS
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_phone_action_to_robot_action")
|
||||
@dataclass
|
||||
class MapPhoneActionToRobotAction(ActionProcessorStep):
|
||||
"""
|
||||
Map calibrated phone pose (actions) to the inputs for robot actions
|
||||
|
||||
Expected input ACTION keys:
|
||||
{
|
||||
"action.phone.enabled": bool,
|
||||
"action.phone.pos": np.ndarray,
|
||||
"action.phone.rot": Rotation,
|
||||
"action.phone.raw_inputs": dict,
|
||||
}
|
||||
|
||||
Output ACTION keys:
|
||||
{
|
||||
"action.enabled": bool,
|
||||
"action.ee.{x,y,z,wx,wy,wz}" : float
|
||||
"action.gripper": float,
|
||||
}
|
||||
"""
|
||||
|
||||
platform: PhoneOS
|
||||
_enabled_prev: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
def action(self, act: dict) -> dict:
|
||||
# Pop them from the action
|
||||
enabled = bool(act.pop(f"{ACTION}.phone.enabled", 0))
|
||||
pos = act.pop(f"{ACTION}.phone.pos", None)
|
||||
rot = act.pop(f"{ACTION}.phone.rot", None)
|
||||
inputs = act.pop(f"{ACTION}.phone.raw_inputs", {})
|
||||
|
||||
if pos is None or rot is None:
|
||||
raise ValueError("pos and rot must be present in action")
|
||||
|
||||
rotvec = rot.as_rotvec() # Absolute orientation as rotvec
|
||||
|
||||
# Map certain inputs to certain actions
|
||||
if self.platform == PhoneOS.IOS:
|
||||
gripper = float(inputs.get("a3", 0.0))
|
||||
else:
|
||||
a = float(inputs.get("reservedButtonA", 0.0))
|
||||
b = float(inputs.get("reservedButtonB", 0.0))
|
||||
gripper = (
|
||||
a - b
|
||||
) # Positive if a is pressed, negative if b is pressed, 0 if both or neither are pressed
|
||||
|
||||
# For some actions we need to invert the axis
|
||||
act[f"{ACTION}.enabled"] = enabled
|
||||
act[f"{ACTION}.target_x"] = -pos[1] if enabled else 0.0
|
||||
act[f"{ACTION}.target_y"] = pos[0] if enabled else 0.0
|
||||
act[f"{ACTION}.target_z"] = pos[2] if enabled else 0.0
|
||||
act[f"{ACTION}.target_wx"] = rotvec[1] if enabled else 0.0
|
||||
act[f"{ACTION}.target_wy"] = rotvec[0] if enabled else 0.0
|
||||
act[f"{ACTION}.target_wz"] = -rotvec[2] if enabled else 0.0
|
||||
act[f"{ACTION}.gripper"] = gripper # Still send gripper action when disabled
|
||||
return act
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features.pop(f"{ACTION}.phone.enabled", None)
|
||||
features.pop(f"{ACTION}.phone.pos", None)
|
||||
features.pop(f"{ACTION}.phone.rot", None)
|
||||
features.pop(f"{ACTION}.phone.raw_inputs", None)
|
||||
|
||||
features[f"{ACTION}.enabled"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_z"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[f"{ACTION}.gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
return features
|
||||
358
src/lerobot/teleoperators/phone/teleop_phone.py
Normal file
358
src/lerobot/teleoperators/phone/teleop_phone.py
Normal file
@@ -0,0 +1,358 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Docs:
|
||||
# hebi: https://docs.hebi.us/tools.html#mobile-io
|
||||
# teleop: https://github.com/SpesRobotics/teleop
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
import hebi
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
from teleop import Teleop
|
||||
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BasePhone:
|
||||
_enabled: bool = False
|
||||
_calib_pos: np.ndarray | None = None
|
||||
_calib_rot_inv: Rotation | None = None
|
||||
|
||||
def _reapply_position_calibration(self, pos: np.ndarray) -> None:
|
||||
self._calib_pos = pos.copy()
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return (self._calib_pos is not None) and (self._calib_rot_inv is not None)
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {
|
||||
"phone.pos": np.ndarray, # shape (3,)
|
||||
"phone.rot": Rotation, # scipy.spatial.transform.Rotation
|
||||
"phone.raw_inputs": dict, # analogs/buttons or webXR meta
|
||||
"phone.enabled": bool,
|
||||
}
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
# No haptic or other feedback implemented yet
|
||||
pass
|
||||
|
||||
def configure(self) -> None:
|
||||
# No additional configuration required for phone teleop
|
||||
pass
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# We could add haptic feedback (vibrations) here, but it's not implemented yet
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class IOSPhone(BasePhone, Teleoperator):
|
||||
name = "ios_phone"
|
||||
|
||||
def __init__(self, config: PhoneConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self._group = None
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._group is not None
|
||||
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
logger.info("Connecting to IPhone, make sure to open the HEBI Mobile I/O app.")
|
||||
lookup = hebi.Lookup()
|
||||
time.sleep(2.0)
|
||||
group = lookup.get_group_from_names(["HEBI"], ["mobileIO"])
|
||||
if group is None:
|
||||
raise RuntimeError("Mobile I/O not found — check name/family settings in the app.")
|
||||
self._group = group
|
||||
logger.info(f"{self} connected to HEBI group with {group.size} module(s).")
|
||||
|
||||
self.calibrate()
|
||||
|
||||
def calibrate(self) -> None:
|
||||
print(
|
||||
"Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)"
|
||||
)
|
||||
print("Press and hold B1 in the HEBI Mobile I/O app to capture this pose...\n")
|
||||
position, rotation = self._wait_for_capture_trigger()
|
||||
self._calib_pos = position.copy()
|
||||
self._calib_rot_inv = rotation.inv()
|
||||
self._enabled = False
|
||||
print("Calibration done\n")
|
||||
|
||||
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
|
||||
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
|
||||
while True:
|
||||
has_pose, position, rotation, fb_pose = self._read_current_pose()
|
||||
if not has_pose:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
io = getattr(fb_pose, "io", None)
|
||||
button_b = getattr(io, "b", None) if io is not None else None
|
||||
button_b1_pressed = False
|
||||
if button_b is not None:
|
||||
button_b1_pressed = bool(button_b.get_int(1))
|
||||
if button_b1_pressed:
|
||||
return position, rotation
|
||||
|
||||
time.sleep(0.01)
|
||||
|
||||
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
|
||||
fbk = self._group.get_next_feedback()
|
||||
pose = fbk[0]
|
||||
ar_pos = getattr(pose, "ar_position", None)
|
||||
ar_quat = getattr(pose, "ar_orientation", None)
|
||||
if ar_pos is None or ar_quat is None:
|
||||
return False, None, None, None
|
||||
# HEBI provides orientation in w, x, y, z format.
|
||||
# Scipy's Rotation expects x, y, z, w.
|
||||
quat_xyzw = np.concatenate((ar_quat[1:], [ar_quat[0]])) # wxyz to xyzw
|
||||
rot = Rotation.from_quat(quat_xyzw)
|
||||
pos = ar_pos - rot.apply(self.config.camera_offset)
|
||||
return True, pos, rot, pose
|
||||
|
||||
def get_action(self) -> dict:
|
||||
has_pose, raw_position, raw_rotation, fb_pose = self._read_current_pose()
|
||||
if not has_pose or not self.is_calibrated:
|
||||
return {}
|
||||
|
||||
# Collect raw inputs (B1 / analogs on iOS, move/scale on Android)
|
||||
raw_inputs: dict[str, float | int | bool] = {}
|
||||
io = getattr(fb_pose, "io", None)
|
||||
if io is not None:
|
||||
bank_a, bank_b = io.a, io.b
|
||||
if bank_a:
|
||||
for ch in range(1, 9):
|
||||
if bank_a.has_float(ch):
|
||||
raw_inputs[f"a{ch}"] = float(bank_a.get_float(ch))
|
||||
if bank_b:
|
||||
for ch in range(1, 9):
|
||||
if bank_b.has_int(ch):
|
||||
raw_inputs[f"b{ch}"] = int(bank_b.get_int(ch))
|
||||
elif hasattr(bank_b, "has_bool") and bank_b.has_bool(ch):
|
||||
raw_inputs[f"b{ch}"] = int(bank_b.get_bool(ch))
|
||||
|
||||
enable = bool(raw_inputs.get("b1", 0))
|
||||
|
||||
# Rising edge then re-capture calibration immediately from current raw pose
|
||||
if enable and not self._enabled:
|
||||
self._reapply_position_calibration(raw_position)
|
||||
|
||||
# Apply calibration
|
||||
pos_cal = self._calib_rot_inv.apply(raw_position - self._calib_pos)
|
||||
rot_cal = self._calib_rot_inv * raw_rotation
|
||||
|
||||
self._enabled = enable
|
||||
|
||||
return {
|
||||
"phone.pos": pos_cal,
|
||||
"phone.rot": rot_cal,
|
||||
"phone.raw_inputs": raw_inputs,
|
||||
"phone.enabled": self._enabled,
|
||||
}
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._group = None
|
||||
|
||||
|
||||
class AndroidPhone(BasePhone, Teleoperator):
|
||||
name = "android_phone"
|
||||
|
||||
def __init__(self, config: PhoneConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self._teleop = None
|
||||
self._teleop_thread = None
|
||||
self._latest_pose = None
|
||||
self._latest_message = None
|
||||
self._android_lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._teleop is not None
|
||||
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
logger.info("Starting teleop stream for Android...")
|
||||
self._teleop = Teleop()
|
||||
self._teleop.subscribe(self._android_callback)
|
||||
self._teleop_thread = threading.Thread(target=self._teleop.run, daemon=True)
|
||||
self._teleop_thread.start()
|
||||
logger.info(f"{self} connected, teleop stream started.")
|
||||
|
||||
self.calibrate()
|
||||
|
||||
def calibrate(self) -> None:
|
||||
print(
|
||||
"Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)"
|
||||
)
|
||||
print("Touch and move on the WebXR page to capture this pose...\n")
|
||||
|
||||
pos, rot = self._wait_for_capture_trigger()
|
||||
self._calib_pos = pos.copy()
|
||||
self._calib_rot_inv = rot.inv()
|
||||
self._enabled = False
|
||||
print("Calibration done\n")
|
||||
|
||||
def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]:
|
||||
"""Wait trigger for calibration: iOS: B1. Android: 'move'."""
|
||||
while True:
|
||||
with self._android_lock:
|
||||
msg = self._latest_message or {}
|
||||
|
||||
if bool(msg.get("move", False)):
|
||||
ok, pos, rot, _pose = self._read_current_pose()
|
||||
if ok:
|
||||
return pos, rot
|
||||
|
||||
time.sleep(0.01)
|
||||
|
||||
def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]:
|
||||
with self._android_lock:
|
||||
if self._latest_pose is None:
|
||||
return False, None, None, None
|
||||
p = self._latest_pose.copy()
|
||||
pose = self._latest_pose
|
||||
rot = Rotation.from_matrix(p[:3, :3])
|
||||
pos = p[:3, 3] - rot.apply(self.config.camera_offset)
|
||||
return True, pos, rot, pose
|
||||
|
||||
def _android_callback(self, pose: np.ndarray, message: dict) -> None:
|
||||
with self._android_lock:
|
||||
self._latest_pose = pose
|
||||
self._latest_message = message
|
||||
|
||||
def get_action(self) -> dict:
|
||||
ok, raw_pos, raw_rot, pose = self._read_current_pose()
|
||||
if not ok or not self.is_calibrated:
|
||||
return {}
|
||||
|
||||
# Collect raw inputs (B1 / analogs on iOS, move/scale on Android)
|
||||
raw_inputs: dict[str, float | int | bool] = {}
|
||||
msg = self._latest_message or {}
|
||||
raw_inputs["move"] = bool(msg.get("move", False))
|
||||
raw_inputs["scale"] = float(msg.get("scale", 1.0))
|
||||
raw_inputs["reservedButtonA"] = bool(msg.get("reservedButtonA", False))
|
||||
raw_inputs["reservedButtonB"] = bool(msg.get("reservedButtonB", False))
|
||||
|
||||
enable = bool(raw_inputs.get("move", False))
|
||||
|
||||
# Rising edge then re-capture calibration immediately from current raw pose
|
||||
if enable and not self._enabled:
|
||||
self._reapply_position_calibration(raw_pos)
|
||||
|
||||
# Apply calibration
|
||||
pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos)
|
||||
rot_cal = self._calib_rot_inv * raw_rot
|
||||
|
||||
self._enabled = enable
|
||||
|
||||
return {
|
||||
"phone.pos": pos_cal,
|
||||
"phone.rot": rot_cal,
|
||||
"phone.raw_inputs": raw_inputs,
|
||||
"phone.enabled": self._enabled,
|
||||
}
|
||||
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._teleop = None
|
||||
if self._teleop_thread and self._teleop_thread.is_alive():
|
||||
self._teleop_thread.join(timeout=1.0)
|
||||
self._teleop_thread = None
|
||||
self._latest_pose = None
|
||||
|
||||
|
||||
class Phone(Teleoperator):
|
||||
"""
|
||||
Phone-based teleoperator using ARKit (iOS via HEBI Mobile I/O App) or the teleop Python package (Android via WebXR API).
|
||||
For HEBI Mobile I/O we also expose 8 analog (a1-a8) and 8 digital (b1-b8) inputs.
|
||||
|
||||
Press and hold **B1** to enable teleoperation. While enabled, the first B1 press
|
||||
captures a reference pose and rotation, when disabled and pressed again the position is reapplied.
|
||||
"""
|
||||
|
||||
config_class = PhoneConfig
|
||||
name = "phone"
|
||||
|
||||
def __init__(self, config: PhoneConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self._phone_impl: Teleoperator
|
||||
|
||||
if self.config.phone_os == PhoneOS.IOS:
|
||||
self._phone_impl = IOSPhone(config)
|
||||
elif self.config.phone_os == PhoneOS.ANDROID:
|
||||
self._phone_impl = AndroidPhone(config)
|
||||
else:
|
||||
raise ValueError(f"Invalid config phone_os: {self.config.phone_os}")
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._phone_impl.is_connected
|
||||
|
||||
def connect(self) -> None:
|
||||
return self._phone_impl.connect()
|
||||
|
||||
def calibrate(self) -> None:
|
||||
return self._phone_impl.calibrate()
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self._phone_impl.is_calibrated
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._phone_impl.action_features
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return self._phone_impl.feedback_features
|
||||
|
||||
def configure(self) -> None:
|
||||
return self._phone_impl.configure()
|
||||
|
||||
def get_action(self) -> dict:
|
||||
return self._phone_impl.get_action()
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
return self._phone_impl.send_feedback(feedback)
|
||||
|
||||
def disconnect(self) -> None:
|
||||
return self._phone_impl.disconnect()
|
||||
@@ -12,10 +12,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from .config import TeleoperatorConfig
|
||||
from .teleoperator import Teleoperator
|
||||
|
||||
|
||||
class TeleopEvents(Enum):
|
||||
"""Shared constants for teleoperator events across teleoperators."""
|
||||
|
||||
SUCCESS = "success"
|
||||
FAILURE = "failure"
|
||||
RERECORD_EPISODE = "rerecord_episode"
|
||||
IS_INTERVENTION = "is_intervention"
|
||||
TERMINATE_EPISODE = "terminate_episode"
|
||||
|
||||
|
||||
def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator:
|
||||
if config.type == "keyboard":
|
||||
from .keyboard import KeyboardTeleop
|
||||
|
||||
@@ -31,6 +31,7 @@ from termcolor import colored
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import DEFAULT_FEATURES
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import PolicyProcessorPipeline, TransitionKey
|
||||
from lerobot.robots import Robot
|
||||
|
||||
|
||||
@@ -101,6 +102,8 @@ def predict_action(
|
||||
observation: dict[str, np.ndarray],
|
||||
policy: PreTrainedPolicy,
|
||||
device: torch.device,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
use_amp: bool,
|
||||
task: str | None = None,
|
||||
robot_type: str | None = None,
|
||||
@@ -122,10 +125,14 @@ def predict_action(
|
||||
observation["task"] = task if task else ""
|
||||
observation["robot_type"] = robot_type if robot_type else ""
|
||||
|
||||
observation = preprocessor(observation)
|
||||
|
||||
# Compute the next action with the policy
|
||||
# based on the current observation
|
||||
action = policy.select_action(observation)
|
||||
|
||||
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
|
||||
|
||||
# Remove batch dimension
|
||||
action = action.squeeze(0)
|
||||
|
||||
|
||||
@@ -58,6 +58,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
|
||||
|
||||
|
||||
_torch_available, _torch_version = is_package_available("torch", return_version=True)
|
||||
_transformers_available = is_package_available("transformers")
|
||||
_gym_xarm_available = is_package_available("gym_xarm")
|
||||
_gym_aloha_available = is_package_available("gym_aloha")
|
||||
_gym_pusht_available = is_package_available("gym_pusht")
|
||||
|
||||
@@ -32,6 +32,7 @@ from lerobot.datasets.utils import load_json, write_json
|
||||
from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state
|
||||
from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.utils.random_utils import load_rng_state, save_rng_state
|
||||
|
||||
|
||||
@@ -74,6 +75,8 @@ def save_checkpoint(
|
||||
policy: PreTrainedPolicy,
|
||||
optimizer: Optimizer,
|
||||
scheduler: LRScheduler | None = None,
|
||||
preprocessor: PolicyProcessorPipeline | None = None,
|
||||
postprocessor: PolicyProcessorPipeline | None = None,
|
||||
) -> None:
|
||||
"""This function creates the following directory structure:
|
||||
|
||||
@@ -81,7 +84,9 @@ def save_checkpoint(
|
||||
├── pretrained_model/
|
||||
│ ├── config.json # policy config
|
||||
│ ├── model.safetensors # policy weights
|
||||
│ └── train_config.json # train config
|
||||
│ ├── train_config.json # train config
|
||||
│ ├── processor.json # processor config (if preprocessor provided)
|
||||
│ └── step_*.safetensors # processor state files (if any)
|
||||
└── training_state/
|
||||
├── optimizer_param_groups.json # optimizer param groups
|
||||
├── optimizer_state.safetensors # optimizer state
|
||||
@@ -95,10 +100,15 @@ def save_checkpoint(
|
||||
policy (PreTrainedPolicy): The policy to save.
|
||||
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
cfg.save_pretrained(pretrained_dir)
|
||||
if preprocessor is not None:
|
||||
preprocessor.save_pretrained(pretrained_dir)
|
||||
if postprocessor is not None:
|
||||
postprocessor.save_pretrained(pretrained_dir)
|
||||
save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
||||
|
||||
|
||||
|
||||
@@ -12,12 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numbers
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
|
||||
from lerobot.processor import EnvTransition, TransitionKey
|
||||
|
||||
|
||||
def _init_rerun(session_name: str = "lerobot_control_loop") -> None:
|
||||
"""Initializes the Rerun SDK for visualizing the control loop."""
|
||||
@@ -28,19 +31,87 @@ def _init_rerun(session_name: str = "lerobot_control_loop") -> None:
|
||||
rr.spawn(memory_limit=memory_limit)
|
||||
|
||||
|
||||
def log_rerun_data(observation: dict[str | Any], action: dict[str | Any]):
|
||||
for obs, val in observation.items():
|
||||
if isinstance(val, float):
|
||||
rr.log(f"observation.{obs}", rr.Scalar(val))
|
||||
elif isinstance(val, np.ndarray):
|
||||
if val.ndim == 1:
|
||||
for i, v in enumerate(val):
|
||||
rr.log(f"observation.{obs}_{i}", rr.Scalar(float(v)))
|
||||
def _is_scalar(x):
|
||||
return (
|
||||
isinstance(x, numbers.Real)
|
||||
or isinstance(x, (np.integer, np.floating))
|
||||
or (isinstance(x, np.ndarray) and x.ndim == 0)
|
||||
)
|
||||
|
||||
|
||||
def log_rerun_data(
|
||||
data: list[dict[str | Any] | EnvTransition] | dict[str | Any] | EnvTransition | None = None,
|
||||
*,
|
||||
observation: dict[str, Any] | None = None,
|
||||
action: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
items = data if isinstance(data, list) else ([data] if data is not None else [])
|
||||
|
||||
obs = {} if observation is None else dict(observation)
|
||||
act = {} if action is None else dict(action)
|
||||
|
||||
for idx, item in enumerate(items):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
if any(isinstance(k, TransitionKey) for k in item.keys()):
|
||||
o = item.get(TransitionKey.OBSERVATION) or {}
|
||||
a = item.get(TransitionKey.ACTION) or {}
|
||||
if isinstance(o, dict):
|
||||
obs.update(o)
|
||||
if isinstance(a, dict):
|
||||
act.update(a)
|
||||
continue
|
||||
|
||||
keys = list(item.keys())
|
||||
has_obs = any(str(k).startswith("observation.") for k in keys)
|
||||
has_act = any(str(k).startswith("action.") for k in keys)
|
||||
|
||||
if has_obs or has_act:
|
||||
if has_obs:
|
||||
obs.update(item)
|
||||
if has_act:
|
||||
act.update(item)
|
||||
else:
|
||||
# No prefixes: assume first is observation, second is action, others are observation
|
||||
if idx == 0:
|
||||
obs.update(item)
|
||||
elif idx == 1:
|
||||
act.update(item)
|
||||
else:
|
||||
rr.log(f"observation.{obs}", rr.Image(val), static=True)
|
||||
for act, val in action.items():
|
||||
if isinstance(val, float):
|
||||
rr.log(f"action.{act}", rr.Scalar(val))
|
||||
elif isinstance(val, np.ndarray):
|
||||
for i, v in enumerate(val):
|
||||
rr.log(f"action.{act}_{i}", rr.Scalar(float(v)))
|
||||
obs.update(item)
|
||||
|
||||
for k, v in obs.items():
|
||||
if v is None:
|
||||
continue
|
||||
key = k if str(k).startswith("observation.") else f"observation.{k}"
|
||||
|
||||
if _is_scalar(v):
|
||||
rr.log(key, rr.Scalar(float(v)))
|
||||
elif isinstance(v, np.ndarray):
|
||||
arr = v
|
||||
# Convert CHW -> HWC when needed
|
||||
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
|
||||
arr = np.transpose(arr, (1, 2, 0))
|
||||
if arr.ndim == 1:
|
||||
for i, vi in enumerate(arr):
|
||||
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
|
||||
else:
|
||||
rr.log(key, rr.Image(arr), static=True)
|
||||
|
||||
for k, v in act.items():
|
||||
if v is None:
|
||||
continue
|
||||
key = k if str(k).startswith("action.") else f"action.{k}"
|
||||
|
||||
if _is_scalar(v):
|
||||
rr.log(key, rr.Scalar(float(v)))
|
||||
elif isinstance(v, np.ndarray):
|
||||
if v.ndim == 1:
|
||||
for i, vi in enumerate(v):
|
||||
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
|
||||
else:
|
||||
# Fall back to flattening higher-dimensional arrays
|
||||
flat = v.flatten()
|
||||
for i, vi in enumerate(flat):
|
||||
rr.log(f"{key}_{i}", rr.Scalar(float(vi)))
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -23,7 +23,8 @@ from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy, make_policy_config
|
||||
from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
|
||||
|
||||
@@ -37,7 +38,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
train_cfg.validate() # Needed for auto-setting some parameters
|
||||
|
||||
dataset = make_dataset(train_cfg)
|
||||
dataset_stats = dataset.meta.stats
|
||||
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor, postprocessor = make_pre_post_processors(train_cfg.policy, dataset_stats=dataset_stats)
|
||||
policy.train()
|
||||
|
||||
optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)
|
||||
@@ -49,7 +52,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
batch = preprocessor(batch)
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
if output_dict is not None:
|
||||
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
||||
output_dict["loss"] = loss
|
||||
@@ -96,7 +101,12 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
else:
|
||||
actions_queue = train_cfg.policy.n_action_repeats
|
||||
|
||||
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
|
||||
actions = {}
|
||||
for i in range(actions_queue):
|
||||
unnormalized_action = policy.select_action(obs).contiguous()
|
||||
action_robot = postprocessor({TransitionKey.ACTION: unnormalized_action}).get(TransitionKey.ACTION)
|
||||
actions[str(i)] = action_robot
|
||||
|
||||
return output_dict, grad_stats, param_stats, actions
|
||||
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
132
tests/datasets/test_dataset_utils.py
Normal file
132
tests/datasets/test_dataset_utils.py
Normal file
@@ -0,0 +1,132 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
card = create_lerobot_dataset_card()
|
||||
assert isinstance(card, DatasetCard)
|
||||
assert card.data.tags == ["LeRobot"]
|
||||
assert card.data.task_categories == ["robotics"]
|
||||
assert card.data.configs == [
|
||||
{
|
||||
"config_name": "default",
|
||||
"data_files": "data/*/*.parquet",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_with_tags():
|
||||
tags = ["tag1", "tag2"]
|
||||
card = create_lerobot_dataset_card(tags=tags)
|
||||
assert card.data.tags == ["LeRobot", "tag1", "tag2"]
|
||||
|
||||
|
||||
def test_calculate_episode_data_index():
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
"index": [0, 1, 2, 3, 4, 5],
|
||||
"episode_index": [0, 0, 1, 2, 2, 2],
|
||||
},
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
|
||||
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
|
||||
|
||||
|
||||
def test_merge_simple_vectors():
|
||||
g1 = {
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": ["ee.x", "ee.y"],
|
||||
}
|
||||
}
|
||||
g2 = {
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": ["ee.y", "ee.z"],
|
||||
}
|
||||
}
|
||||
|
||||
out = combine_feature_dicts(g1, g2)
|
||||
|
||||
assert "action" in out
|
||||
assert out["action"]["dtype"] == "float32"
|
||||
# Names merged with preserved order and de-dupuplication
|
||||
assert out["action"]["names"] == ["ee.x", "ee.y", "ee.z"]
|
||||
# Shape correctly recomputed from names length
|
||||
assert out["action"]["shape"] == (3,)
|
||||
|
||||
|
||||
def test_merge_multiple_groups_order_and_dedup():
|
||||
g1 = {"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}}
|
||||
g2 = {"action": {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}}
|
||||
g3 = {"action": {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}}
|
||||
|
||||
out = combine_feature_dicts(g1, g2, g3)
|
||||
|
||||
assert out["action"]["names"] == ["a", "b", "c", "d"]
|
||||
assert out["action"]["shape"] == (4,)
|
||||
|
||||
|
||||
def test_non_vector_last_wins_for_images():
|
||||
# Non-vector (images) with same name should be overwritten by the last image specified
|
||||
g1 = {
|
||||
"observation.images.front": {
|
||||
"dtype": "image",
|
||||
"shape": (3, 480, 640),
|
||||
"names": ["channels", "height", "width"],
|
||||
}
|
||||
}
|
||||
g2 = {
|
||||
"observation.images.front": {
|
||||
"dtype": "image",
|
||||
"shape": (3, 720, 1280),
|
||||
"names": ["channels", "height", "width"],
|
||||
}
|
||||
}
|
||||
|
||||
out = combine_feature_dicts(g1, g2)
|
||||
assert out["observation.images.front"]["shape"] == (3, 720, 1280)
|
||||
assert out["observation.images.front"]["dtype"] == "image"
|
||||
|
||||
|
||||
def test_dtype_mismatch_raises():
|
||||
g1 = {"action": {"dtype": "float32", "shape": (1,), "names": ["a"]}}
|
||||
g2 = {"action": {"dtype": "float64", "shape": (1,), "names": ["b"]}}
|
||||
|
||||
with pytest.raises(ValueError, match="dtype mismatch for 'action'"):
|
||||
_ = combine_feature_dicts(g1, g2)
|
||||
|
||||
|
||||
def test_non_dict_passthrough_last_wins():
|
||||
g1 = {"misc": 123}
|
||||
g2 = {"misc": 456}
|
||||
|
||||
out = combine_feature_dicts(g1, g2)
|
||||
# For non-dict entries the last one wins
|
||||
assert out["misc"] == 456
|
||||
@@ -1,55 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
card = create_lerobot_dataset_card()
|
||||
assert isinstance(card, DatasetCard)
|
||||
assert card.data.tags == ["LeRobot"]
|
||||
assert card.data.task_categories == ["robotics"]
|
||||
assert card.data.configs == [
|
||||
{
|
||||
"config_name": "default",
|
||||
"data_files": "data/*/*.parquet",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_with_tags():
|
||||
tags = ["tag1", "tag2"]
|
||||
card = create_lerobot_dataset_card(tags=tags)
|
||||
assert card.data.tags == ["LeRobot", "tag1", "tag2"]
|
||||
|
||||
|
||||
def test_calculate_episode_data_index():
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
"index": [0, 1, 2, 3, 4, 5],
|
||||
"episode_index": [0, 0, 1, 2, 2, 2],
|
||||
},
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
|
||||
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
|
||||
@@ -26,7 +26,7 @@ from safetensors.torch import load_file
|
||||
from lerobot import available_policies
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.utils import cycle, dataset_to_policy_features
|
||||
@@ -39,8 +39,8 @@ from lerobot.policies.factory import (
|
||||
get_policy_class,
|
||||
make_policy,
|
||||
make_policy_config,
|
||||
make_pre_post_processors,
|
||||
)
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||
@@ -151,6 +151,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
|
||||
# Check that we can make the policy object.
|
||||
dataset = make_dataset(train_cfg)
|
||||
preprocessor, _ = make_pre_post_processors(train_cfg.policy, None)
|
||||
policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
|
||||
assert isinstance(policy, PreTrainedPolicy)
|
||||
|
||||
@@ -224,6 +225,7 @@ def test_act_backbone_lr():
|
||||
assert cfg.policy.optimizer_lr_backbone == 0.001
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
preprocessor, _ = make_pre_post_processors(cfg.policy, None)
|
||||
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
assert len(optimizer.param_groups) == 2
|
||||
@@ -263,108 +265,6 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
|
||||
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
|
||||
def test_normalize(insert_temporal_dim):
|
||||
"""
|
||||
Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise
|
||||
an exception when the forward pass is called without the stats having been provided.
|
||||
|
||||
TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as
|
||||
expected.
|
||||
"""
|
||||
|
||||
input_features = {
|
||||
"observation.image": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 96, 96),
|
||||
),
|
||||
"observation.state": PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(10,),
|
||||
),
|
||||
}
|
||||
output_features = {
|
||||
"action": PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(5,),
|
||||
),
|
||||
}
|
||||
|
||||
norm_map = {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
dataset_stats = {
|
||||
"observation.image": {
|
||||
"mean": torch.randn(3, 1, 1),
|
||||
"std": torch.randn(3, 1, 1),
|
||||
"min": torch.randn(3, 1, 1),
|
||||
"max": torch.randn(3, 1, 1),
|
||||
},
|
||||
"observation.state": {
|
||||
"mean": torch.randn(10),
|
||||
"std": torch.randn(10),
|
||||
"min": torch.randn(10),
|
||||
"max": torch.randn(10),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.randn(5),
|
||||
"std": torch.randn(5),
|
||||
"min": torch.randn(5),
|
||||
"max": torch.randn(5),
|
||||
},
|
||||
}
|
||||
|
||||
bsize = 2
|
||||
input_batch = {
|
||||
"observation.image": torch.randn(bsize, 3, 96, 96),
|
||||
"observation.state": torch.randn(bsize, 10),
|
||||
}
|
||||
output_batch = {
|
||||
"action": torch.randn(bsize, 5),
|
||||
}
|
||||
|
||||
if insert_temporal_dim:
|
||||
tdim = 4
|
||||
|
||||
for key in input_batch:
|
||||
# [2,3,96,96] -> [2,tdim,3,96,96]
|
||||
input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1)
|
||||
|
||||
for key in output_batch:
|
||||
output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1)
|
||||
|
||||
# test without stats
|
||||
normalize = Normalize(input_features, norm_map, stats=None)
|
||||
with pytest.raises(AssertionError):
|
||||
normalize(input_batch)
|
||||
|
||||
# test with stats
|
||||
normalize = Normalize(input_features, norm_map, stats=dataset_stats)
|
||||
normalize(input_batch)
|
||||
|
||||
# test loading pretrained models
|
||||
new_normalize = Normalize(input_features, norm_map, stats=None)
|
||||
new_normalize.load_state_dict(normalize.state_dict())
|
||||
new_normalize(input_batch)
|
||||
|
||||
# test without stats
|
||||
unnormalize = Unnormalize(output_features, norm_map, stats=None)
|
||||
with pytest.raises(AssertionError):
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test with stats
|
||||
unnormalize = Unnormalize(output_features, norm_map, stats=dataset_stats)
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test loading pretrained models
|
||||
new_unnormalize = Unnormalize(output_features, norm_map, stats=None)
|
||||
new_unnormalize.load_state_dict(unnormalize.state_dict())
|
||||
unnormalize(output_batch)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multikey", [True, False])
|
||||
def test_multikey_construction(multikey: bool):
|
||||
"""
|
||||
@@ -464,6 +364,8 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
|
||||
NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact
|
||||
is out of date. For example, some PyTorch versions have different randomness, see this PR:
|
||||
https://github.com/huggingface/lerobot/pull/1127.
|
||||
NOTE: If the test don't pass and you don't change the policy, and note the dependencies version,
|
||||
and you changed your processor, you might have to update the test artifact.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
355
tests/processor/test_act_processor.py
Normal file
355
tests/processor/test_act_processor.py
Normal file
@@ -0,0 +1,355 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for ACT policy processor."""
|
||||
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.processor_act import make_act_pre_post_processors
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DataProcessorPipeline,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
RenameProcessorStep,
|
||||
TransitionKey,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
|
||||
|
||||
def create_transition(observation=None, action=None, **kwargs):
|
||||
"""Helper function to create a transition dictionary."""
|
||||
transition = {}
|
||||
if observation is not None:
|
||||
transition[TransitionKey.OBSERVATION] = observation
|
||||
if action is not None:
|
||||
transition[TransitionKey.ACTION] = action
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(TransitionKey, key.upper()):
|
||||
transition[getattr(TransitionKey, key.upper())] = value
|
||||
return transition
|
||||
|
||||
|
||||
def create_default_config():
|
||||
"""Create a default ACT configuration for testing."""
|
||||
config = ACTConfig()
|
||||
config.input_features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(7,)),
|
||||
}
|
||||
config.output_features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
config.device = "cpu"
|
||||
return config
|
||||
|
||||
|
||||
def create_default_stats():
|
||||
"""Create default dataset statistics for testing."""
|
||||
return {
|
||||
OBS_STATE: {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
ACTION: {"mean": torch.zeros(4), "std": torch.ones(4)},
|
||||
}
|
||||
|
||||
|
||||
def test_make_act_processor_basic():
|
||||
"""Test basic creation of ACT processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
assert postprocessor.name == "robot_postprocessor"
|
||||
|
||||
# Check steps in preprocessor
|
||||
assert len(preprocessor.steps) == 4
|
||||
assert isinstance(preprocessor.steps[0], RenameProcessorStep)
|
||||
assert isinstance(preprocessor.steps[1], NormalizerProcessorStep)
|
||||
assert isinstance(preprocessor.steps[2], AddBatchDimensionProcessorStep)
|
||||
assert isinstance(preprocessor.steps[3], DeviceProcessorStep)
|
||||
|
||||
# Check steps in postprocessor
|
||||
assert len(postprocessor.steps) == 2
|
||||
assert isinstance(postprocessor.steps[0], DeviceProcessorStep)
|
||||
assert isinstance(postprocessor.steps[1], UnnormalizerProcessorStep)
|
||||
|
||||
|
||||
def test_act_processor_normalization():
|
||||
"""Test that ACT processor correctly normalizes and unnormalizes data."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is normalized and batched
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 4)
|
||||
|
||||
# Process action through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is unnormalized
|
||||
assert postprocessed[TransitionKey.ACTION].shape == (1, 4)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_act_processor_cuda():
|
||||
"""Test ACT processor with CUDA device."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Create CPU data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is on CUDA
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
|
||||
assert processed[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
# Process through postprocessor
|
||||
action_transition = create_transition(action=processed[TransitionKey.ACTION])
|
||||
postprocessed = postprocessor(action_transition)
|
||||
|
||||
# Check that action is back on CPU
|
||||
assert postprocessed[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_act_processor_accelerate_scenario():
|
||||
"""Test ACT processor in simulated Accelerate scenario (data already on GPU)."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Simulate Accelerate: data already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
observation = {OBS_STATE: torch.randn(1, 7).to(device)} # Already batched and on GPU
|
||||
action = torch.randn(1, 4).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on same GPU (not moved unnecessarily)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
|
||||
def test_act_processor_multi_gpu():
|
||||
"""Test ACT processor with multi-GPU setup."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
|
||||
|
||||
# Simulate data on different GPU (like in multi-GPU training)
|
||||
device = torch.device("cuda:1")
|
||||
observation = {OBS_STATE: torch.randn(1, 7).to(device)}
|
||||
action = torch.randn(1, 4).to(device)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data stays on cuda:1 (not moved to cuda:0)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device == device
|
||||
assert processed[TransitionKey.ACTION].device == device
|
||||
|
||||
|
||||
def test_act_processor_without_stats():
|
||||
"""Test ACT processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(config, dataset_stats=None)
|
||||
|
||||
# Should still create processors, but normalization won't have stats
|
||||
assert preprocessor is not None
|
||||
assert postprocessor is not None
|
||||
|
||||
# Process should still work (but won't normalize without stats)
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed is not None
|
||||
|
||||
|
||||
def test_act_processor_save_and_load():
|
||||
"""Test saving and loading ACT processor."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save preprocessor
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = DataProcessorPipeline.from_pretrained(
|
||||
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
# Test that loaded processor works
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = loaded_preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7)
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 4)
|
||||
|
||||
|
||||
def test_act_processor_device_placement_preservation():
|
||||
"""Test that ACT processor preserves device placement correctly."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
# Test with CPU config
|
||||
config.device = "cpu"
|
||||
preprocessor, _ = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Process CPU data
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
|
||||
assert processed[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_act_processor_mixed_precision():
|
||||
"""Test ACT processor with mixed precision (float16)."""
|
||||
config = create_default_config()
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
# Modify the device processor to use float16
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Replace DeviceProcessorStep with one that uses float16
|
||||
modified_steps = []
|
||||
for step in preprocessor.steps:
|
||||
if isinstance(step, DeviceProcessorStep):
|
||||
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
|
||||
else:
|
||||
modified_steps.append(step)
|
||||
preprocessor.steps = modified_steps
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)}
|
||||
action = torch.randn(4, dtype=torch.float32)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
# Process through preprocessor
|
||||
processed = preprocessor(transition)
|
||||
|
||||
# Check that data is converted to float16
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16
|
||||
assert processed[TransitionKey.ACTION].dtype == torch.float16
|
||||
|
||||
|
||||
def test_act_processor_batch_consistency():
|
||||
"""Test that ACT processor handles different batch sizes correctly."""
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_act_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Test single sample (unbatched)
|
||||
observation = {OBS_STATE: torch.randn(7)}
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(observation, action)
|
||||
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1 # Batched
|
||||
|
||||
# Test already batched data
|
||||
observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8
|
||||
action_batched = torch.randn(8, 4)
|
||||
transition_batched = create_transition(observation_batched, action_batched)
|
||||
|
||||
processed_batched = preprocessor(transition_batched)
|
||||
assert processed_batched[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 8
|
||||
assert processed_batched[TransitionKey.ACTION].shape[0] == 8
|
||||
@@ -1,11 +1,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.processor.pipeline import (
|
||||
RobotProcessor,
|
||||
TransitionKey,
|
||||
_default_batch_to_transition,
|
||||
_default_transition_to_batch,
|
||||
)
|
||||
from lerobot.processor import DataProcessorPipeline, TransitionKey
|
||||
from lerobot.processor.converters import batch_to_transition, transition_to_batch
|
||||
|
||||
|
||||
def _dummy_batch():
|
||||
@@ -24,7 +20,7 @@ def _dummy_batch():
|
||||
|
||||
def test_observation_grouping_roundtrip():
|
||||
"""Test that observation.* keys are properly grouped and ungrouped."""
|
||||
proc = RobotProcessor([])
|
||||
proc = DataProcessorPipeline([])
|
||||
batch_in = _dummy_batch()
|
||||
batch_out = proc(batch_in)
|
||||
|
||||
@@ -48,7 +44,7 @@ def test_observation_grouping_roundtrip():
|
||||
|
||||
|
||||
def test_batch_to_transition_observation_grouping():
|
||||
"""Test that _default_batch_to_transition correctly groups observation.* keys."""
|
||||
"""Test that batch_to_transition correctly groups observation.* keys."""
|
||||
batch = {
|
||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
@@ -60,7 +56,7 @@ def test_batch_to_transition_observation_grouping():
|
||||
"info": {"episode": 42},
|
||||
}
|
||||
|
||||
transition = _default_batch_to_transition(batch)
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Check observation is a dict with all observation.* keys
|
||||
assert isinstance(transition[TransitionKey.OBSERVATION], dict)
|
||||
@@ -87,7 +83,7 @@ def test_batch_to_transition_observation_grouping():
|
||||
|
||||
|
||||
def test_transition_to_batch_observation_flattening():
|
||||
"""Test that _default_transition_to_batch correctly flattens observation dict."""
|
||||
"""Test that transition_to_batch correctly flattens observation dict."""
|
||||
observation_dict = {
|
||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
@@ -104,7 +100,7 @@ def test_transition_to_batch_observation_flattening():
|
||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||
}
|
||||
|
||||
batch = _default_transition_to_batch(transition)
|
||||
batch = transition_to_batch(transition)
|
||||
|
||||
# Check that observation.* keys are flattened back to batch
|
||||
assert "observation.image.top" in batch
|
||||
@@ -134,7 +130,7 @@ def test_no_observation_keys():
|
||||
"info": {"test": "no_obs"},
|
||||
}
|
||||
|
||||
transition = _default_batch_to_transition(batch)
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Observation should be None when no observation.* keys
|
||||
assert transition[TransitionKey.OBSERVATION] is None
|
||||
@@ -147,7 +143,7 @@ def test_no_observation_keys():
|
||||
assert transition[TransitionKey.INFO] == {"test": "no_obs"}
|
||||
|
||||
# Round trip should work
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch["action"] == "action_data"
|
||||
assert reconstructed_batch["next.reward"] == 2.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
@@ -159,7 +155,7 @@ def test_minimal_batch():
|
||||
"""Test with minimal batch containing only observation.* and action."""
|
||||
batch = {"observation.state": "minimal_state", "action": "minimal_action"}
|
||||
|
||||
transition = _default_batch_to_transition(batch)
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Check observation
|
||||
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
|
||||
@@ -173,7 +169,7 @@ def test_minimal_batch():
|
||||
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
|
||||
# Round trip
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch["observation.state"] == "minimal_state"
|
||||
assert reconstructed_batch["action"] == "minimal_action"
|
||||
assert reconstructed_batch["next.reward"] == 0.0
|
||||
@@ -186,7 +182,7 @@ def test_empty_batch():
|
||||
"""Test behavior with empty batch."""
|
||||
batch = {}
|
||||
|
||||
transition = _default_batch_to_transition(batch)
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# All fields should have defaults
|
||||
assert transition[TransitionKey.OBSERVATION] is None
|
||||
@@ -198,7 +194,7 @@ def test_empty_batch():
|
||||
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
|
||||
# Round trip
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch["action"] is None
|
||||
assert reconstructed_batch["next.reward"] == 0.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
@@ -219,8 +215,8 @@ def test_complex_nested_observation():
|
||||
"info": {"episode_length": 200, "success": True},
|
||||
}
|
||||
|
||||
transition = _default_batch_to_transition(batch)
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
transition = batch_to_transition(batch)
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
|
||||
# Check that all observation keys are preserved
|
||||
original_obs_keys = {k for k in batch if k.startswith("observation.")}
|
||||
@@ -254,7 +250,7 @@ def test_custom_converter():
|
||||
|
||||
def to_tr(batch):
|
||||
# Custom converter that modifies the reward
|
||||
tr = _default_batch_to_transition(batch)
|
||||
tr = batch_to_transition(batch)
|
||||
# Double the reward
|
||||
reward = tr.get(TransitionKey.REWARD, 0.0)
|
||||
new_tr = tr.copy()
|
||||
@@ -262,10 +258,10 @@ def test_custom_converter():
|
||||
return new_tr
|
||||
|
||||
def to_batch(tr):
|
||||
batch = _default_transition_to_batch(tr)
|
||||
batch = transition_to_batch(tr)
|
||||
return batch
|
||||
|
||||
processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch)
|
||||
processor = DataProcessorPipeline(steps=[], to_transition=to_tr, to_output=to_batch)
|
||||
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 4),
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user