mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
Compare commits
24 Commits
user/khali
...
test/night
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
03829342e2 | ||
|
|
d861d97e87 | ||
|
|
3f80a52728 | ||
|
|
095856b06a | ||
|
|
563f42bdb1 | ||
|
|
8fff0fde7c | ||
|
|
04de496547 | ||
|
|
baf9b50365 | ||
|
|
a0fdbf037a | ||
|
|
c085531b17 | ||
|
|
c7c6205332 | ||
|
|
4e54be1334 | ||
|
|
fde9d08281 | ||
|
|
46044fed75 | ||
|
|
975dcad918 | ||
|
|
d0b58190da | ||
|
|
9a5ab8ffab | ||
|
|
7541d72130 | ||
|
|
0317a15bf1 | ||
|
|
f138e5948a | ||
|
|
8fef4ddab8 | ||
|
|
18d9cb5ac4 | ||
|
|
5095ab0845 | ||
|
|
dac1efd13d |
2
.github/workflows/full_tests.yml
vendored
2
.github/workflows/full_tests.yml
vendored
@@ -173,6 +173,8 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Fix ptxas permissions
|
||||
run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
- name: Run pytest on GPU
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
|
||||
2
.github/workflows/nightly.yml
vendored
2
.github/workflows/nightly.yml
vendored
@@ -188,7 +188,7 @@ jobs:
|
||||
- name: Verify GPU availability
|
||||
run: |
|
||||
nvidia-smi
|
||||
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
|
||||
python -c "import torch; print(f'PyTorch version: {torch.__version__}'); print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
|
||||
|
||||
- name: Run multi-GPU training tests
|
||||
# TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
|
||||
|
||||
25
AI_POLICY.md
Normal file
25
AI_POLICY.md
Normal file
@@ -0,0 +1,25 @@
|
||||
# AI Usage Policy
|
||||
|
||||
The LeRobot project welcomes contributions from everyone, and we have a few guidelines regarding AI usage to ensure high code quality, clear communication, and a healthy open-source ecosystem:
|
||||
|
||||
- **Please disclose significant AI assistance.** If you used AI tools (e.g., Copilot, Claude, Cursor, ChatGPT) to generate a substantial portion of your code or text, let us know in your PR description. Transparency helps us review your changes more effectively.
|
||||
- **Own your code (The Human-in-the-Loop).** You must fully understand all the changes you are proposing. If you cannot explain what your AI-assisted code does or how it interacts with LeRobot's broader architecture, please take the time to learn and test it before submitting.
|
||||
- **Keep issues and discussions focused.** You are welcome to use AI to help draft issues or PR descriptions, but please review and edit them carefully before posting. AI can often be overly verbose; trimming the noise and getting straight to the point helps our maintainers address your needs faster.
|
||||
|
||||
Our core maintainers also use AI tools to aid their workflows, but they do so while bringing deep contextual knowledge of the LeRobot codebase to validate the output. We ask all contributors to apply that same level of rigor.
|
||||
|
||||
## Remember the Human Maintainers
|
||||
|
||||
Please remember that LeRobot is maintained by a dedicated team of humans.
|
||||
|
||||
Every discussion, issue, and pull request is read and reviewed by real people. While AI tools can generate thousands of lines of code in seconds, reviewing that code still takes human time and energy. Submitting unverified or low-effort AI output puts an unfair burden on our maintainers.
|
||||
|
||||
Today, the quality of the AI output still heavily depends on the developer driving the tool. We ask that you respect our maintainers' time by thoroughly vetting, testing, and refining your submissions.
|
||||
|
||||
## AI is Welcome Here
|
||||
|
||||
LeRobot operates at the cutting edge of AI and robotics, and many of our maintainers actively embrace AI coding assistants as valuable productivity tools. We are a pro-AI project!
|
||||
|
||||
Our reason for having an AI policy is not an anti-AI stance. Rather, it exists to ensure that AI is used to enhance human contributions, not replace them with unverified noise. It's about how the tools are used, not the tools themselves.
|
||||
|
||||
We value the unique human insight you bring to the LeRobot community. Let AI empower your workflow, but always let your own judgment take the wheel.
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community. Answering questions, helping others, reaching out, and improving the documentation are immensely valuable.
|
||||
|
||||
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md).
|
||||
Whichever way you choose to contribute, please be mindful to respect our [code of conduct](./CODE_OF_CONDUCT.md) and our [AI policy](./AI_POLICY.md).
|
||||
|
||||
## Ways to Contribute
|
||||
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
include src/lerobot/templates/lerobot_modelcard_template.md
|
||||
include src/lerobot/datasets/card_template.md
|
||||
include src/lerobot/envs/metaworld_config.json
|
||||
|
||||
@@ -85,6 +85,8 @@ RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
|
||||
|
||||
RUN uv pip install --no-cache ".[all]"
|
||||
|
||||
RUN chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
|
||||
# Copy the rest of the application source code
|
||||
# Make sure to have the git-LFS files for testing
|
||||
COPY --chown=user_lerobot:user_lerobot . .
|
||||
|
||||
@@ -17,10 +17,10 @@
|
||||
title: Train RL in Simulation
|
||||
- local: multi_gpu_training
|
||||
title: Multi GPU training
|
||||
- local: hil_collection
|
||||
title: Human In the Loop Data Collection
|
||||
- local: peft_training
|
||||
title: Training with PEFT (e.g., LoRA)
|
||||
- local: rename_map
|
||||
title: Using Rename Map and Empty Cameras
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: lerobot-dataset-v3
|
||||
|
||||
@@ -1,237 +0,0 @@
|
||||
# Human-In-the-Loop Data Collection
|
||||
|
||||
Human-In-the-Loop (HIL) data collection lets you improve a trained policy by deploying it on a real robot while a human operator monitors and intervenes when needed. The intervention data — recovery movements and corrections — is recorded alongside autonomous segments, producing a richer training dataset that teaches the policy how to handle failures.
|
||||
|
||||
---
|
||||
|
||||
## Why Human-In-the-Loop?
|
||||
|
||||
Standard behavioral cloning trains policies on successful demonstrations only. During deployment, small errors can compound and push the robot into states never seen during training (distribution shift). HIL data collection addresses this by:
|
||||
|
||||
- Running the trained policy on the real robot
|
||||
- Having a human intervene when the robot is about to fail
|
||||
- Recording the human's recovery and correction as training data
|
||||
- Fine-tuning the policy on the combined dataset
|
||||
|
||||
This produces a policy that not only knows how to perform the task, but also how to recover when things go wrong.
|
||||
|
||||
---
|
||||
|
||||
## How It Works
|
||||
|
||||
During a HIL session, the human operator follows this loop within each episode:
|
||||
|
||||
1. **Watch** the policy run autonomously
|
||||
2. **Pause** when failure is imminent — the robot holds its position
|
||||
3. **Take control** — teleoperate the robot back to a good state (recovery), then correct the behavior
|
||||
4. **Return control to the policy** — the policy resumes autonomous execution
|
||||
5. Repeat steps 2–4 as many times as needed during the episode
|
||||
6. **End the episode** when the task is complete, save and move on to the next rollout
|
||||
|
||||
Both autonomous and human-controlled segments are recorded. The policy and human can alternate control multiple times within a single episode, and the episode continues from the current state after each handoff (no reset required just because intervention happened). This captures autonomous execution, recovery, and correction in one continuous trajectory. After collection, the combined dataset (original demonstrations + HIL data) is used to fine-tune the policy.
|
||||
|
||||
This process can be repeated iteratively: deploy, collect, fine-tune, repeat — each round targeting the current policy's failure modes.
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────────┐
|
||||
│ Policy v0 (trained on demos) │
|
||||
│ ↓ │
|
||||
│ HIL Collection (target current failure modes) → Fine-tune → Policy v1 │
|
||||
│ ↓ │
|
||||
│ HIL Collection (target new failure modes) → Fine-tune → Policy v2 │
|
||||
│ ↓ │
|
||||
│ ... (repeat until satisfactory performance) │
|
||||
└─────────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Hardware Requirements
|
||||
|
||||
### Teleoperator Requirements
|
||||
|
||||
The HIL data collection scripts require **teleoperators with active motors** that can:
|
||||
|
||||
- Enable/disable torque programmatically
|
||||
- Move to target positions (to mirror the robot state when pausing)
|
||||
|
||||
**Compatible teleoperators:**
|
||||
|
||||
- `so101_leader` - SO-101 Leader Arm
|
||||
- `openarms_mini` - OpenArms Mini (via third-party plugin)
|
||||
|
||||
---
|
||||
|
||||
## Scripts
|
||||
|
||||
Two scripts are provided depending on your policy's inference speed:
|
||||
|
||||
| Script | Use Case | Models |
|
||||
| ---------------------------- | ------------------------------------------ | --------------------- |
|
||||
| `hil_data_collection.py` | Standard synchronous inference | ACT, Diffusion Policy |
|
||||
| `hil_data_collection_rtc.py` | Real-Time Chunking for high-latency models | Pi0, Pi0.5, SmolVLA |
|
||||
|
||||
---
|
||||
|
||||
## Step-by-Step Guide
|
||||
|
||||
### Step 1: Pre-train a Base Policy
|
||||
|
||||
First, train a policy on your demonstration dataset:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/demo-dataset \
|
||||
--policy.type=pi0 \
|
||||
--output_dir=outputs/pretrain \
|
||||
--batch_size=32 \
|
||||
--steps=50000
|
||||
```
|
||||
|
||||
### Step 2: Collect HIL Data
|
||||
|
||||
**Standard inference (ACT, Diffusion Policy):**
|
||||
|
||||
```bash
|
||||
python examples/rac/hil_data_collection.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=your-username/hil-dataset \
|
||||
--dataset.single_task="Pick up the cube and place it in the bowl" \
|
||||
--dataset.num_episodes=50
|
||||
```
|
||||
|
||||
**With RTC for large models (Pi0, Pi0.5, SmolVLA):**
|
||||
|
||||
For models with high inference latency, use the RTC script for smooth execution:
|
||||
|
||||
```bash
|
||||
python examples/rac/hil_data_collection_rtc.py \
|
||||
--robot.type=so100_follower \
|
||||
--teleop.type=so100_leader \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=your-username/hil-rtc-dataset \
|
||||
--dataset.single_task="Pick up the cube" \
|
||||
--rtc.execution_horizon=20 \
|
||||
--interpolation=true
|
||||
```
|
||||
|
||||
**Controls (Conceptual):**
|
||||
|
||||
The interaction model is:
|
||||
|
||||
- **Pause input**: pause autonomous policy execution
|
||||
- **Takeover input**: transfer control to the human operator and record intervention data
|
||||
- **Return-to-policy input**: hand control back to the policy and continue the same episode
|
||||
- **Episode control inputs**: save/re-record/stop/reset as needed
|
||||
|
||||
Exact key/pedal bindings can differ across scripts and hardware integrations. Use each script's printed controls as the source of truth for the concrete mapping on your setup.
|
||||
|
||||
**The HIL Protocol:**
|
||||
|
||||
1. Watch the policy run autonomously (teleop is idle/free)
|
||||
2. When you see imminent failure, trigger the **pause input**
|
||||
- Policy stops
|
||||
- Teleoperator moves to match robot position (torque enabled)
|
||||
- No frames recorded during pause
|
||||
3. Trigger the **takeover input** to take control
|
||||
- Teleoperator torque disabled, free to move
|
||||
- **Recovery**: Teleoperate the robot back to a good state
|
||||
- **Correction**: Correct the behavior
|
||||
- All movements are recorded
|
||||
4. Trigger the **return-to-policy input**
|
||||
- Policy resumes autonomous execution from the current state
|
||||
- You can intervene again at any time (repeat steps 2–4)
|
||||
5. End and save the episode when the task is complete (or episode time limit is reached)
|
||||
6. **Reset**: Teleop moves to robot position, you can move the robot to the starting position
|
||||
7. Start the next episode
|
||||
|
||||
**Foot Pedal Setup (Linux):**
|
||||
|
||||
If using a USB foot pedal (PCsensor FootSwitch), ensure access:
|
||||
|
||||
```bash
|
||||
sudo setfacl -m u:$USER:rw /dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd
|
||||
```
|
||||
|
||||
### Step 3: Fine-tune the Policy
|
||||
|
||||
Fine-tune on the combined demonstration + HIL data:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/hil-dataset \
|
||||
--policy.type=pi0 \
|
||||
--policy.pretrained_path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--output_dir=outputs/hil_finetune \
|
||||
--steps=20000
|
||||
```
|
||||
|
||||
Then deploy the fine-tuned policy and repeat from Step 2 to target its remaining failure modes.
|
||||
|
||||
---
|
||||
|
||||
## Tips for Effective HIL Collection
|
||||
|
||||
### When to Intervene
|
||||
|
||||
Intervene when you see:
|
||||
|
||||
- Robot about to make an irreversible mistake
|
||||
- Robot hesitating or showing uncertain behavior
|
||||
- Robot deviating from the expected trajectory
|
||||
|
||||
### Recovery: Teleoperating Back to a Good State
|
||||
|
||||
During recovery, teleoperate the robot back to a state where:
|
||||
|
||||
- The robot is in a familiar, in-distribution configuration
|
||||
- The current subtask can still be completed
|
||||
- The recovery trajectory itself is informative training data
|
||||
|
||||
### Quality of Corrections
|
||||
|
||||
During correction:
|
||||
|
||||
- Provide **confident, clean** trajectories
|
||||
- Complete the current subtask fully
|
||||
- Don't overcorrect or add unnecessary movements
|
||||
|
||||
---
|
||||
|
||||
## Related Work
|
||||
|
||||
This HIL data collection approach builds on ideas from interactive imitation learning, including DAgger (Ross et al., 2011), HG-DAgger (Kelly et al., 2019), RaC (Hu et al., 2025), and RECAP (Physical Intelligence, 2025). See those works for a deeper treatment of the theory behind human-in-the-loop policy improvement.
|
||||
|
||||
```bibtex
|
||||
@article{ross2011dagger,
|
||||
title={A Reduction of Imitation Learning and Structured Prediction to No-Regret Online Learning},
|
||||
author={Ross, Stéphane and Gordon, Geoffrey and Bagnell, Drew},
|
||||
journal={Proceedings of the Fourteenth International Conference on Artificial Intelligence and Statistics},
|
||||
year={2011}
|
||||
}
|
||||
|
||||
@article{kelly2019hgdagger,
|
||||
title={HG-DAgger: Interactive Imitation Learning with Human Experts},
|
||||
author={Kelly, Michael and Sidrane, Chelsea and Driggs-Campbell, Katherine and Kochenderfer, Mykel J},
|
||||
journal={arXiv preprint arXiv:1810.02890},
|
||||
year={2019}
|
||||
}
|
||||
|
||||
@article{hu2025rac,
|
||||
title={RaC: Robot Learning for Long-Horizon Tasks by Scaling Recovery and Correction},
|
||||
author={Hu, Zheyuan and Wu, Robyn and Enock, Naveen and Li, Jasmine and Kadakia, Riya and Erickson, Zackory and Kumar, Aviral},
|
||||
journal={arXiv preprint arXiv:2509.07953},
|
||||
year={2025}
|
||||
}
|
||||
|
||||
@article{pi2025recap,
|
||||
title={π0.6: a VLA That Learns From Experience},
|
||||
author={Physical Intelligence},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
145
docs/source/rename_map.mdx
Normal file
145
docs/source/rename_map.mdx
Normal file
@@ -0,0 +1,145 @@
|
||||
# Understanding the Rename Map and Empty Cameras
|
||||
|
||||
When you train or evaluate a robot policy, your **dataset** or **environment** hands you observations under one set of keys (e.g. `observation.images.front`, `observation.images.eagle`), while your **policy** was built to expect another (e.g. `observation.images.image`, `observation.images.image2`). The rename map is how you bridge that gap without changing the policy or the data source.
|
||||
|
||||
This guide explains why it exists, how to use it in training and evaluation, and when to use **empty cameras** so you can fine-tune multi-camera policies on datasets that have fewer views.
|
||||
|
||||
---
|
||||
|
||||
## Why observation keys don’t always match
|
||||
|
||||
Policies have a fixed set of **input feature names** (often coming from a pretrained config). For example:
|
||||
|
||||
- **XVLA-base** expects three image keys: `observation.images.image`, `observation.images.image2`, `observation.images.image3`.
|
||||
- **pi0-fast-libero** might expect `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb`.
|
||||
|
||||
Your dataset or sim might use completely different names: `observation.images.front`, `observation.images.eagle`, `observation.images.glove` (e.g. [svla_so100_sorting](https://huggingface.co/datasets/lerobot/svla_so100_sorting)). Or your eval env (e.g. LIBERO) might return `observation.images.image` and `observation.images.image2`.
|
||||
|
||||
Rather than renaming columns in the dataset or editing the policy code, LeRobot lets you pass a **rename map**: a dictionary that says “when you see this key in the data, treat it as this key for the policy.” Renaming is applied in the preprocessing pipeline so the policy always receives the keys it expects.
|
||||
|
||||
---
|
||||
|
||||
## How the rename map works
|
||||
|
||||
The rename map is a dictionary:
|
||||
|
||||
- **Keys** = observation keys as produced by your **dataset** (training) or **environment** (evaluation).
|
||||
- **Values** = the observation keys your **policy** expects.
|
||||
|
||||
Only keys listed in the map are renamed; everything else is left as-is. Under the hood, the [RenameObservationsProcessorStep](https://github.com/huggingface/lerobot/blob/main/src/lerobot/processor/rename_processor.py) runs in the preprocessor and rewrites observation keys (and keeps normalization stats aligned) so the batch matches the policy’s `input_features`.
|
||||
|
||||
You can use the same idea for **training** (dataset → policy) and **evaluation** (env → policy).
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/jadechoghari/images/resolve/main/rename-map.png"
|
||||
alt="Rename map: mapping dataset or environment observation keys to policy input keys"
|
||||
style="max-width: 100%; height: auto;"
|
||||
/>
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
## Option 1: Use a rename map (recommended)
|
||||
|
||||
You pass the mapping on the command line so dataset/env keys are renamed to what the policy expects. No need to change the policy repo or the data.
|
||||
|
||||
### Training example: XVLA on a dataset with different camera names
|
||||
|
||||
Suppose you fine-tune [lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base) on a dataset whose images are stored under `observation.images.front`, `observation.images.eagle`, and `observation.images.glove`. XVLA expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`. Map the dataset keys to the policy keys:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=YOUR_DATASET \
|
||||
--output_dir=./outputs/xvla_training \
|
||||
--job_name=xvla_training \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
--policy.repo_id="HF_USER/xvla-your-robot" \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.action_mode=auto \
|
||||
--steps=20000 \
|
||||
--policy.device=cuda \
|
||||
--policy.freeze_vision_encoder=false \
|
||||
--policy.freeze_language_encoder=false \
|
||||
--policy.train_policy_transformer=true \
|
||||
--policy.train_soft_prompts=true \
|
||||
--rename_map='{"observation.images.front": "observation.images.image", "observation.images.eagle": "observation.images.image2", "observation.images.glove": "observation.images.image3"}'
|
||||
```
|
||||
|
||||
Order of entries in the map doesn’t matter; each dataset key is renamed to the corresponding policy key.
|
||||
|
||||
### Evaluation example: Policy trained on different camera names than the env
|
||||
|
||||
You trained (or downloaded) a policy that expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb` (e.g. [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero)), but your evaluation environment (e.g. LIBERO) returns `observation.images.image` and `observation.images.image2`. Tell the eval script how to rename env keys to policy keys:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/pi0fast-libero \
|
||||
--env.type=libero \
|
||||
... \
|
||||
--rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
|
||||
```
|
||||
|
||||
So: **key = what the env gives, value = what the policy expects.** Same convention as in training.
|
||||
|
||||
---
|
||||
|
||||
## Option 2: Change the policy config (no rename map)
|
||||
|
||||
If you prefer not to pass a rename map every time, you can **edit the policy’s `config.json`** so that its expected observation keys match your dataset or environment. For example, change the policy’s visual input keys to `observation.images.front`, `observation.images.eagle`, `observation.images.glove` to match your dataset, or to `observation.images.image` / `observation.images.image2` to match LIBERO.
|
||||
|
||||
- **Training:** If the dataset’s camera keys match the (modified) policy config, you don’t need a rename map.
|
||||
- **Evaluation:** If the env’s keys match the (modified) policy config, you don’t need a rename map for eval either.
|
||||
|
||||
The tradeoff: you’re changing the policy repo or your local checkpoint. That’s fine if you’re only ever using that one dataset or env; a rename map keeps the same policy usable across multiple data sources without touching the config.
|
||||
|
||||
---
|
||||
|
||||
## When you have fewer cameras than the policy expects: empty cameras
|
||||
|
||||
Some policies (e.g. XVLA) are built for a fixed number of image inputs (e.g. three). Your dataset might only have **two** cameras. You still want to fine-tune without changing the model architecture.
|
||||
|
||||
LeRobot supports this with **empty cameras**: the config declares extra “slots” that the policy expects, but the dataset (or env) does not provide. Those slots are filled with placeholder keys and typically zero or masked inputs so the policy can run with fewer real views.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/jadechoghari/images/resolve/main/empty_cam.png"
|
||||
alt="Empty cameras: using placeholder slots when the dataset has fewer views than the policy expects"
|
||||
style="max-width: 100%; height: auto;"
|
||||
/>
|
||||
</p>
|
||||
|
||||
- In the policy config (e.g. [xvla-base config.json](https://huggingface.co/lerobot/xvla-base/blob/main/config.json)), `empty_cameras` is the number of these extra slots (default `0`).
|
||||
- For each slot, the config adds an observation key of the form:
|
||||
`observation.images.empty_camera_0`, `observation.images.empty_camera_1`, …
|
||||
|
||||
Example: XVLA-base has three visual inputs and `empty_cameras=0`. Your dataset has only two images. Set **`empty_cameras=1`**. Then:
|
||||
|
||||
1. The config gains a third visual key: `observation.images.empty_camera_0`.
|
||||
2. You still use the rename map (or matching config keys) for the two real cameras.
|
||||
3. The third view is treated as “empty” (no corresponding dataset key); the policy ignores or masks it as needed.
|
||||
|
||||
So you fine-tune on two observations only, and the third visual input is effectively unused. You do **not** need to add a fake third image to your dataset.
|
||||
|
||||
---
|
||||
|
||||
## Where the rename map is used in the codebase
|
||||
|
||||
- **Training** ([`lerobot_train.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_train.py)): `rename_map` is passed into `make_policy(..., rename_map=cfg.rename_map)` and into the preprocessor as `rename_observations_processor: {"rename_map": cfg.rename_map}`. Batches from the dataset are renamed before being fed to the policy.
|
||||
- **Evaluation** ([`lerobot_eval.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_eval.py)): Same idea—`rename_map` is passed to `make_policy` and to the preprocessor so env observations are renamed before the policy sees them.
|
||||
- **Processor** ([`rename_processor.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/processor/rename_processor.py)): `RenameObservationsProcessorStep` does the actual key renaming and updates feature metadata so normalization stats stay consistent with the renamed keys.
|
||||
|
||||
If you see a feature mismatch error (“Missing features” / “Extra features”), the error message suggests using `--rename_map` with a mapping from your data’s keys to the policy’s expected keys.
|
||||
|
||||
---
|
||||
|
||||
## Quick reference
|
||||
|
||||
| Goal | What to do |
|
||||
| ------------------------------------- | ---------------------------------------------------------------------------------------------------------- |
|
||||
| Dataset keys ≠ policy keys (training) | `--rename_map='{"dataset_key": "policy_key", ...}'` |
|
||||
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
|
||||
| Fewer cameras than policy expects | Set `empty_cameras` in the policy config (e.g. `1` when you have 2 real cameras and the policy expects 3). |
|
||||
| Avoid passing a rename map | Edit the policy’s `config.json` so its observation keys match your dataset or env. |
|
||||
|
||||
The rename map keeps your pipeline flexible: one policy, many data sources, no code changes—just a small dictionary on the command line or in your config.
|
||||
@@ -57,7 +57,7 @@ class DatasetReplayConfig:
|
||||
repo_id: str
|
||||
# Episode to replay.
|
||||
episode: int
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second. By default, uses the policy fps.
|
||||
fps: int = 30
|
||||
|
||||
@@ -1,351 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Human-in-the-Loop (HIL) Data Collection with Policy Rollout.
|
||||
|
||||
Implements the RaC paradigm (Hu et al., 2025) for LeRobot with standard synchronous
|
||||
inference. For large models with high inference latency, use hil_data_collection_rtc.py.
|
||||
|
||||
The workflow:
|
||||
1. Policy runs autonomously
|
||||
2. Press SPACE to pause - robot holds position
|
||||
3. Press 'c' to take control - human provides RECOVERY + CORRECTION
|
||||
4. Press → to end episode (save and continue to next)
|
||||
5. Reset, then do next rollout
|
||||
|
||||
Keyboard Controls:
|
||||
SPACE - Pause policy (robot holds position, no recording)
|
||||
c - Take control (start correction, recording resumes)
|
||||
→ - End episode (save and continue to next)
|
||||
← - Re-record episode
|
||||
ESC - Stop recording and push dataset to hub
|
||||
|
||||
Usage:
|
||||
python examples/rac/hil_data_collection.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=my_user/hil_dataset \
|
||||
--dataset.single_task="Pick up the cube"
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from hil_utils import (
|
||||
HILDatasetConfig,
|
||||
init_keyboard_listener,
|
||||
make_identity_processors,
|
||||
print_controls,
|
||||
reset_loop,
|
||||
teleop_disable_torque,
|
||||
teleop_smooth_move_to,
|
||||
)
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
|
||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import is_headless, predict_action
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILConfig:
|
||||
robot: RobotConfig
|
||||
teleop: TeleoperatorConfig
|
||||
dataset: HILDatasetConfig
|
||||
policy: PreTrainedConfig | None = None
|
||||
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
|
||||
display_data: bool = True
|
||||
play_sounds: bool = True
|
||||
resume: bool = False
|
||||
device: str = "cuda"
|
||||
|
||||
def __post_init__(self):
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
if self.policy is None:
|
||||
raise ValueError("policy.path is required")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
return ["policy"]
|
||||
|
||||
|
||||
@safe_stop_image_writer
|
||||
def rollout_loop(
|
||||
robot: Robot,
|
||||
teleop: Teleoperator,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
dataset: LeRobotDataset,
|
||||
events: dict,
|
||||
cfg: HILConfig,
|
||||
):
|
||||
"""Rollout loop with standard synchronous inference."""
|
||||
fps = cfg.dataset.fps
|
||||
device = get_safe_torch_device(cfg.device)
|
||||
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
frame_buffer = []
|
||||
teleop_disable_torque(teleop)
|
||||
|
||||
was_paused = False
|
||||
waiting_for_takeover = False
|
||||
last_action: dict[str, Any] | None = None
|
||||
robot_action: dict[str, Any] = {}
|
||||
action_keys = sorted(robot.action_features.keys())
|
||||
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
control_interval = interpolator.get_control_interval(fps)
|
||||
|
||||
timestamp = 0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < cfg.dataset.episode_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
break
|
||||
|
||||
# Transition to paused state
|
||||
if events["policy_paused"] and not was_paused:
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {
|
||||
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
|
||||
}
|
||||
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
events["start_next_episode"] = False
|
||||
waiting_for_takeover = True
|
||||
was_paused = True
|
||||
interpolator.reset()
|
||||
|
||||
# Takeover
|
||||
if waiting_for_takeover and events["start_next_episode"]:
|
||||
teleop_disable_torque(teleop)
|
||||
events["start_next_episode"] = False
|
||||
events["correction_active"] = True
|
||||
waiting_for_takeover = False
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
|
||||
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
|
||||
|
||||
if events["correction_active"]:
|
||||
robot_action = teleop.get_action()
|
||||
robot.send_action(robot_action)
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
|
||||
|
||||
elif waiting_for_takeover or events["policy_paused"]:
|
||||
if last_action:
|
||||
robot.send_action(last_action)
|
||||
|
||||
else:
|
||||
# Policy execution with optional interpolation
|
||||
if interpolator.needs_new_action():
|
||||
action_values = predict_action(
|
||||
observation=obs_frame,
|
||||
policy=policy,
|
||||
device=device,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=cfg.dataset.single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
robot_action = make_robot_action(action_values, dataset.features)
|
||||
action_tensor = torch.tensor([robot_action[k] for k in action_keys])
|
||||
interpolator.add(action_tensor)
|
||||
|
||||
interp_action = interpolator.get()
|
||||
if interp_action is not None:
|
||||
robot_action = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
|
||||
robot.send_action(robot_action)
|
||||
last_action = robot_action
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
|
||||
|
||||
if cfg.display_data and robot_action:
|
||||
log_rerun_data(observation=obs_filtered, action=robot_action)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_time := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_time)
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
teleop_disable_torque(teleop)
|
||||
|
||||
for frame in frame_buffer:
|
||||
dataset.add_frame(frame)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def hil_collect(cfg: HILConfig) -> LeRobotDataset:
|
||||
"""Main HIL data collection function."""
|
||||
init_logging()
|
||||
logger.info(pformat(cfg.__dict__))
|
||||
|
||||
if cfg.display_data:
|
||||
init_rerun(session_name="hil_collection")
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
|
||||
teleop_proc, obs_proc = make_identity_processors()
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_proc,
|
||||
initial_features=create_initial_features(action=robot.action_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=obs_proc,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
)
|
||||
|
||||
dataset = None
|
||||
listener = None
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
if hasattr(robot, "cameras") and robot.cameras:
|
||||
dataset.start_image_writer(
|
||||
num_processes=cfg.dataset.num_image_writer_processes,
|
||||
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
else:
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot.name,
|
||||
features=dataset_features,
|
||||
use_videos=cfg.dataset.video,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
|
||||
* len(robot.cameras if hasattr(robot, "cameras") else []),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
|
||||
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
print_controls(rtc=False)
|
||||
print(f" Policy: {cfg.policy.pretrained_path}")
|
||||
print(f" Task: {cfg.dataset.single_task}")
|
||||
print(f" Interpolation: {cfg.interpolation_multiplier}x\n")
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
recorded = 0
|
||||
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
log_say(f"Episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
|
||||
rollout_loop(
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
cfg=cfg,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording", cfg.play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded += 1
|
||||
|
||||
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
reset_loop(robot, teleop, events, cfg.dataset.fps)
|
||||
|
||||
finally:
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
|
||||
if dataset:
|
||||
dataset.finalize()
|
||||
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
if teleop.is_connected:
|
||||
teleop.disconnect()
|
||||
|
||||
if not is_headless() and listener:
|
||||
listener.stop()
|
||||
|
||||
if cfg.dataset.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
register_third_party_plugins()
|
||||
hil_collect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,513 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Human-in-the-Loop (HIL) Data Collection with Real-Time Chunking (RTC).
|
||||
|
||||
Implements the RaC paradigm (Hu et al., 2025) with RTC for large flow-matching models
|
||||
(Pi0, Pi0.5, SmolVLA) that have high inference latency. RTC generates action chunks
|
||||
asynchronously in a background thread for smooth robot control.
|
||||
|
||||
For fast models (ACT, Diffusion), use hil_data_collection.py instead.
|
||||
|
||||
The workflow:
|
||||
1. Policy runs autonomously with RTC
|
||||
2. Press SPACE to pause - robot holds position
|
||||
3. Press 'c' to take control - human provides RECOVERY + CORRECTION
|
||||
4. Press → to end episode (save and continue to next)
|
||||
5. Reset, then do next rollout
|
||||
|
||||
Keyboard Controls:
|
||||
SPACE - Pause policy (robot holds position, no recording)
|
||||
c - Take control (start correction, recording resumes)
|
||||
→ - End episode (save and continue to next)
|
||||
← - Re-record episode
|
||||
ESC - Stop recording and push dataset to hub
|
||||
|
||||
Usage:
|
||||
python examples/rac/hil_data_collection_rtc.py \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--teleop.type=so100_leader \
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
--policy.path=outputs/train/pi0_policy/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=my_user/hil_rtc_dataset \
|
||||
--dataset.single_task="Pick up the cube" \
|
||||
--rtc.execution_horizon=20
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pprint import pformat
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from hil_utils import (
|
||||
HILDatasetConfig,
|
||||
init_keyboard_listener,
|
||||
make_identity_processors,
|
||||
print_controls,
|
||||
reset_loop,
|
||||
teleop_disable_torque,
|
||||
teleop_smooth_move_to,
|
||||
)
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
|
||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleoperator_from_config
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import is_headless
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import init_logging, log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILRTCConfig:
|
||||
robot: RobotConfig
|
||||
teleop: TeleoperatorConfig
|
||||
dataset: HILDatasetConfig
|
||||
policy: PreTrainedConfig | None = None
|
||||
rtc: RTCConfig = field(default_factory=lambda: RTCConfig(enabled=True, execution_horizon=20))
|
||||
interpolation_multiplier: int = 2 # Control rate multiplier (1=off, 2=2x, 3=3x)
|
||||
display_data: bool = True
|
||||
play_sounds: bool = True
|
||||
resume: bool = False
|
||||
device: str = "cuda"
|
||||
use_torch_compile: bool = False # First compile takes minutes, disable for real-time
|
||||
|
||||
def __post_init__(self):
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
if self.policy is None:
|
||||
raise ValueError("policy.path is required")
|
||||
self.rtc.enabled = True
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
return ["policy"]
|
||||
|
||||
|
||||
class ThreadSafeRobot:
|
||||
"""Thread-safe wrapper for robot operations."""
|
||||
|
||||
def __init__(self, robot: Robot):
|
||||
self._robot = robot
|
||||
self._lock = Lock()
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
return self._robot.get_observation()
|
||||
|
||||
def send_action(self, action: dict) -> None:
|
||||
with self._lock:
|
||||
self._robot.send_action(action)
|
||||
|
||||
@property
|
||||
def observation_features(self) -> dict:
|
||||
return self._robot.observation_features
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict:
|
||||
return self._robot.action_features
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._robot.name
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str:
|
||||
return self._robot.robot_type
|
||||
|
||||
@property
|
||||
def cameras(self):
|
||||
return getattr(self._robot, "cameras", {})
|
||||
|
||||
|
||||
def rtc_inference_thread(
|
||||
policy: PreTrainedPolicy,
|
||||
obs_holder: dict,
|
||||
hw_features: dict,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
queue_holder: dict,
|
||||
shutdown_event: Event,
|
||||
policy_active: Event,
|
||||
cfg: HILRTCConfig,
|
||||
):
|
||||
"""Background thread for RTC action chunk generation."""
|
||||
latency_tracker = LatencyTracker()
|
||||
time_per_chunk = 1.0 / cfg.dataset.fps
|
||||
threshold = 30
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if not policy_active.is_set():
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
queue = queue_holder.get("queue")
|
||||
obs = obs_holder.get("obs")
|
||||
if queue is None or obs is None:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
if queue.qsize() <= threshold:
|
||||
try:
|
||||
current_time = time.perf_counter()
|
||||
idx_before = queue.get_action_index()
|
||||
prev_actions = queue.get_left_over()
|
||||
|
||||
latency = latency_tracker.max()
|
||||
delay = math.ceil(latency / time_per_chunk) if latency else 0
|
||||
|
||||
obs_batch = build_dataset_frame(hw_features, obs, prefix="observation")
|
||||
for name in obs_batch:
|
||||
obs_batch[name] = torch.from_numpy(obs_batch[name])
|
||||
if "image" in name:
|
||||
obs_batch[name] = obs_batch[name].float() / 255
|
||||
obs_batch[name] = obs_batch[name].permute(2, 0, 1).contiguous()
|
||||
obs_batch[name] = obs_batch[name].unsqueeze(0).to(cfg.device)
|
||||
|
||||
obs_batch["task"] = [cfg.dataset.single_task]
|
||||
obs_batch["robot_type"] = obs_holder.get("robot_type", "unknown")
|
||||
|
||||
preprocessed = preprocessor(obs_batch)
|
||||
actions = policy.predict_action_chunk(
|
||||
preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
|
||||
)
|
||||
|
||||
original = actions.squeeze(0).clone()
|
||||
processed = postprocessor(actions).squeeze(0)
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
latency_tracker.add(new_latency)
|
||||
queue.merge(original, processed, new_delay, idx_before)
|
||||
logger.debug(f"[RTC] Inference latency={new_latency:.2f}s, queue={queue.qsize()}")
|
||||
except Exception as e:
|
||||
logger.error(f"[RTC] Error: {e}")
|
||||
time.sleep(0.5)
|
||||
else:
|
||||
time.sleep(0.01)
|
||||
|
||||
|
||||
@safe_stop_image_writer
|
||||
def rollout_loop(
|
||||
robot: ThreadSafeRobot,
|
||||
teleop: Teleoperator,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
dataset: LeRobotDataset,
|
||||
events: dict,
|
||||
cfg: HILRTCConfig,
|
||||
queue_holder: dict,
|
||||
obs_holder: dict,
|
||||
policy_active: Event,
|
||||
hw_features: dict,
|
||||
):
|
||||
"""Rollout loop with RTC for asynchronous inference."""
|
||||
fps = cfg.dataset.fps
|
||||
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
frame_buffer = []
|
||||
teleop_disable_torque(teleop)
|
||||
|
||||
was_paused = False
|
||||
waiting_for_takeover = False
|
||||
last_action: dict[str, Any] | None = None
|
||||
action_keys = [k for k in robot.action_features if k.endswith(".pos")]
|
||||
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
control_interval = interpolator.get_control_interval(fps)
|
||||
|
||||
robot_action: dict[str, Any] = {}
|
||||
timestamp = 0
|
||||
start_t = time.perf_counter()
|
||||
|
||||
while timestamp < cfg.dataset.episode_time_s:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
break
|
||||
|
||||
# Transition to paused state
|
||||
if events["policy_paused"] and not was_paused:
|
||||
policy_active.clear()
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {
|
||||
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
|
||||
}
|
||||
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
events["start_next_episode"] = False
|
||||
waiting_for_takeover = True
|
||||
was_paused = True
|
||||
interpolator.reset()
|
||||
|
||||
# Takeover
|
||||
if waiting_for_takeover and events["start_next_episode"]:
|
||||
teleop_disable_torque(teleop)
|
||||
events["start_next_episode"] = False
|
||||
events["correction_active"] = True
|
||||
waiting_for_takeover = False
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
|
||||
obs_frame = build_dataset_frame(dataset.features, obs_filtered, prefix=OBS_STR)
|
||||
|
||||
obs_holder["obs"] = obs_filtered
|
||||
|
||||
if events["correction_active"]:
|
||||
robot_action = teleop.get_action()
|
||||
robot.send_action(robot_action)
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
|
||||
|
||||
elif waiting_for_takeover or events["policy_paused"]:
|
||||
if last_action:
|
||||
robot.send_action(last_action)
|
||||
|
||||
else:
|
||||
# Policy execution with RTC
|
||||
if not policy_active.is_set():
|
||||
policy_active.set()
|
||||
|
||||
queue = queue_holder["queue"]
|
||||
|
||||
if interpolator.needs_new_action():
|
||||
new_action = queue.get() if queue else None
|
||||
if new_action is not None:
|
||||
interpolator.add(new_action.cpu())
|
||||
|
||||
action_tensor = interpolator.get()
|
||||
if action_tensor is not None:
|
||||
robot_action = {
|
||||
k: action_tensor[i].item() for i, k in enumerate(action_keys) if i < len(action_tensor)
|
||||
}
|
||||
robot.send_action(robot_action)
|
||||
last_action = robot_action
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
frame_buffer.append({**obs_frame, **action_frame, "task": cfg.dataset.single_task})
|
||||
|
||||
if cfg.display_data and robot_action:
|
||||
log_rerun_data(observation=obs_filtered, action=robot_action)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_time := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_time)
|
||||
timestamp = time.perf_counter() - start_t
|
||||
|
||||
policy_active.clear()
|
||||
teleop_disable_torque(teleop)
|
||||
|
||||
for frame in frame_buffer:
|
||||
dataset.add_frame(frame)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def hil_rtc_collect(cfg: HILRTCConfig) -> LeRobotDataset:
|
||||
"""Main HIL data collection function with RTC."""
|
||||
init_logging()
|
||||
logger.info(pformat(cfg.__dict__))
|
||||
|
||||
if cfg.display_data:
|
||||
init_rerun(session_name="hil_rtc_collection")
|
||||
|
||||
robot_raw = make_robot_from_config(cfg.robot)
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
|
||||
teleop_proc, obs_proc = make_identity_processors()
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_proc,
|
||||
initial_features=create_initial_features(action=robot_raw.action_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=obs_proc,
|
||||
initial_features=create_initial_features(observation=robot_raw.observation_features),
|
||||
use_videos=cfg.dataset.video,
|
||||
),
|
||||
)
|
||||
|
||||
dataset = None
|
||||
listener = None
|
||||
shutdown_event = Event()
|
||||
policy_active = Event()
|
||||
rtc_thread = None
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
if hasattr(robot_raw, "cameras") and robot_raw.cameras:
|
||||
dataset.start_image_writer(
|
||||
num_processes=cfg.dataset.num_image_writer_processes,
|
||||
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot_raw.cameras),
|
||||
)
|
||||
else:
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot_raw.name,
|
||||
features=dataset_features,
|
||||
use_videos=cfg.dataset.video,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
|
||||
* len(robot_raw.cameras if hasattr(robot_raw, "cameras") else []),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
)
|
||||
|
||||
# Load policy with RTC
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
policy_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
if hasattr(policy_config, "compile_model"):
|
||||
policy_config.compile_model = cfg.use_torch_compile
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=policy_config)
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
if hasattr(policy, "init_rtc_processor"):
|
||||
policy.init_rtc_processor()
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
robot_raw.connect()
|
||||
robot = ThreadSafeRobot(robot_raw)
|
||||
teleop.connect()
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
queue_holder = {"queue": ActionQueue(cfg.rtc)}
|
||||
obs_holder = {"obs": None, "robot_type": robot.robot_type}
|
||||
hw_features = hw_to_dataset_features(robot_raw.observation_features, "observation")
|
||||
|
||||
rtc_thread = Thread(
|
||||
target=rtc_inference_thread,
|
||||
args=(
|
||||
policy,
|
||||
obs_holder,
|
||||
hw_features,
|
||||
preprocessor,
|
||||
postprocessor,
|
||||
queue_holder,
|
||||
shutdown_event,
|
||||
policy_active,
|
||||
cfg,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
rtc_thread.start()
|
||||
|
||||
print_controls(rtc=True)
|
||||
print(f" Policy: {cfg.policy.pretrained_path}")
|
||||
print(f" Task: {cfg.dataset.single_task}")
|
||||
print(f" Interpolation: {cfg.interpolation_multiplier}x\n")
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
recorded = 0
|
||||
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
log_say(f"Episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
|
||||
queue_holder["queue"] = ActionQueue(cfg.rtc)
|
||||
|
||||
rollout_loop(
|
||||
robot=robot,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
cfg=cfg,
|
||||
queue_holder=queue_holder,
|
||||
obs_holder=obs_holder,
|
||||
policy_active=policy_active,
|
||||
hw_features=hw_features,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording", cfg.play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded += 1
|
||||
|
||||
if recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
reset_loop(robot, teleop, events, cfg.dataset.fps)
|
||||
|
||||
finally:
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
|
||||
shutdown_event.set()
|
||||
policy_active.clear()
|
||||
|
||||
if rtc_thread and rtc_thread.is_alive():
|
||||
rtc_thread.join(timeout=2.0)
|
||||
|
||||
if dataset:
|
||||
dataset.finalize()
|
||||
|
||||
if robot_raw.is_connected:
|
||||
robot_raw.disconnect()
|
||||
if teleop.is_connected:
|
||||
teleop.disconnect()
|
||||
|
||||
if not is_headless() and listener:
|
||||
listener.stop()
|
||||
|
||||
if cfg.dataset.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
register_third_party_plugins()
|
||||
hil_rtc_collect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,206 +0,0 @@
|
||||
"""Shared utilities for Human-in-the-Loop data collection scripts."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.processor import (
|
||||
IdentityProcessorStep,
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots import Robot
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.utils.control_utils import is_headless
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILDatasetConfig:
|
||||
repo_id: str
|
||||
single_task: str
|
||||
root: str | Path | None = None
|
||||
fps: int = 30
|
||||
episode_time_s: float = 120
|
||||
num_episodes: int = 50
|
||||
video: bool = True
|
||||
push_to_hub: bool = True
|
||||
private: bool = False
|
||||
tags: list[str] | None = None
|
||||
num_image_writer_processes: int = 0
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
video_encoding_batch_size: int = 1
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
def teleop_has_motor_control(teleop: Teleoperator) -> bool:
|
||||
"""Check if teleoperator has motor control capabilities."""
|
||||
return hasattr(teleop, "bus") and hasattr(teleop.bus, "disable_torque")
|
||||
|
||||
|
||||
def teleop_disable_torque(teleop: Teleoperator) -> None:
|
||||
"""Disable teleop torque if supported."""
|
||||
if teleop_has_motor_control(teleop):
|
||||
teleop.bus.disable_torque()
|
||||
|
||||
|
||||
def teleop_enable_torque(teleop: Teleoperator) -> None:
|
||||
"""Enable teleop torque if supported."""
|
||||
if teleop_has_motor_control(teleop):
|
||||
teleop.bus.enable_torque()
|
||||
|
||||
|
||||
def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50):
|
||||
"""Smoothly move teleop to target position if motor control is available."""
|
||||
if not teleop_has_motor_control(teleop):
|
||||
logger.warning("Teleop does not support motor control - cannot mirror robot position")
|
||||
return
|
||||
|
||||
teleop_enable_torque(teleop)
|
||||
current = teleop.get_action()
|
||||
steps = int(duration_s * fps)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {}
|
||||
for k in current:
|
||||
if k in target_pos:
|
||||
interp[k] = current[k] * (1 - t) + target_pos[k] * t
|
||||
else:
|
||||
interp[k] = current[k]
|
||||
teleop.bus.sync_write("Goal_Position", {k.replace(".pos", ""): v for k, v in interp.items()})
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
"""Initialize keyboard listener with HIL controls."""
|
||||
events = {
|
||||
"exit_early": False,
|
||||
"rerecord_episode": False,
|
||||
"stop_recording": False,
|
||||
"policy_paused": False,
|
||||
"correction_active": False,
|
||||
"in_reset": False,
|
||||
"start_next_episode": False,
|
||||
}
|
||||
|
||||
if is_headless():
|
||||
logger.warning("Headless environment - keyboard controls unavailable")
|
||||
return None, events
|
||||
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if events["in_reset"]:
|
||||
if key in [keyboard.Key.space, keyboard.Key.right]:
|
||||
print("\n[HIL] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
elif hasattr(key, "char") and key.char == "c":
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("[HIL] ESC - Stop recording, pushing to hub...")
|
||||
events["stop_recording"] = True
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
if key == keyboard.Key.space:
|
||||
if not events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[HIL] ⏸ PAUSED - Press 'c' to take control")
|
||||
events["policy_paused"] = True
|
||||
elif hasattr(key, "char") and key.char == "c":
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[HIL] ▶ Taking control...")
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
print("[HIL] → End episode")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("[HIL] ← Re-record episode")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
print("[HIL] ESC - Stop recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
except Exception as e:
|
||||
print(f"Key error: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
return listener, events
|
||||
|
||||
|
||||
def make_identity_processors():
|
||||
"""Create identity processors for recording."""
|
||||
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
return teleop_proc, obs_proc
|
||||
|
||||
|
||||
def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int):
|
||||
"""Reset period where human repositions environment."""
|
||||
print("\n" + "=" * 60)
|
||||
print(" [HIL] RESET")
|
||||
print("=" * 60)
|
||||
|
||||
events["in_reset"] = True
|
||||
events["start_next_episode"] = False
|
||||
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
||||
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
print(" Press any key to enable teleoperation")
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
precise_sleep(0.05)
|
||||
|
||||
if events["stop_recording"]:
|
||||
return
|
||||
|
||||
events["start_next_episode"] = False
|
||||
teleop_disable_torque(teleop)
|
||||
print(" Teleop enabled - press any key to start episode")
|
||||
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
loop_start = time.perf_counter()
|
||||
action = teleop.get_action()
|
||||
robot.send_action(action)
|
||||
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
|
||||
|
||||
events["in_reset"] = False
|
||||
events["start_next_episode"] = False
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
|
||||
|
||||
def print_controls(rtc: bool = False):
|
||||
"""Print control instructions."""
|
||||
print("\n" + "=" * 60)
|
||||
print(" Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else ""))
|
||||
print("=" * 60)
|
||||
print()
|
||||
print(" Controls:")
|
||||
print(" SPACE - Pause policy")
|
||||
print(" c - Take control")
|
||||
print(" → - End episode")
|
||||
print(" ESC - Stop and push to hub")
|
||||
print("=" * 60 + "\n")
|
||||
@@ -83,7 +83,9 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.processor.factory import (
|
||||
make_default_robot_action_processor,
|
||||
make_default_robot_observation_processor,
|
||||
@@ -149,7 +151,6 @@ class RTCDemoConfig(HubMixin):
|
||||
# Demo parameters
|
||||
duration: float = 30.0 # Duration to run the demo (seconds)
|
||||
fps: float = 10.0 # Action execution frequency (Hz)
|
||||
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
|
||||
|
||||
# Compute device
|
||||
device: str | None = None # Device to run on (cuda, cpu, auto)
|
||||
@@ -350,22 +351,20 @@ def actor_control(
|
||||
logger.info("[ACTOR] Starting actor thread")
|
||||
|
||||
action_count = 0
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
action_interval = interpolator.get_control_interval(cfg.fps)
|
||||
action_interval = 1.0 / cfg.fps
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if interpolator.needs_new_action():
|
||||
new_action = action_queue.get()
|
||||
if new_action is not None:
|
||||
interpolator.add(new_action.cpu())
|
||||
# Try to get an action from the queue with timeout
|
||||
action = action_queue.get()
|
||||
|
||||
action = interpolator.get()
|
||||
if action is not None:
|
||||
action = action.cpu()
|
||||
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
|
||||
action_processed = robot_action_processor((action_dict, None))
|
||||
robot.send_action(action_processed)
|
||||
|
||||
action_count += 1
|
||||
|
||||
dt_s = time.perf_counter() - start_time
|
||||
|
||||
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.4.4"
|
||||
version = "0.4.5"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
dynamic = ["readme"]
|
||||
license = { text = "Apache-2.0" }
|
||||
@@ -76,7 +76,7 @@ dependencies = [
|
||||
"pyserial>=3.5,<4.0",
|
||||
"wandb>=0.24.0,<0.25.0",
|
||||
|
||||
"torch>=2.2.1,<2.11.0", # TODO: Bump dependency
|
||||
"torch==2.10.0",
|
||||
"torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency
|
||||
"torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency
|
||||
|
||||
@@ -214,6 +214,9 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
|
||||
@@ -49,23 +49,18 @@ import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
from lerobot.robots import (
|
||||
RobotConfig, # noqa: F401
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
so_follower,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
from .configs import RobotClientConfig
|
||||
from .constants import SUPPORTED_ROBOTS
|
||||
from .helpers import (
|
||||
Action,
|
||||
FPSTracker,
|
||||
@@ -485,8 +480,9 @@ class RobotClient:
|
||||
def async_client(cfg: RobotClientConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
if cfg.robot.type not in SUPPORTED_ROBOTS:
|
||||
raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
|
||||
# TODO: Assert if checking robot support is still needed with the plugin system
|
||||
# if cfg.robot.type not in SUPPORTED_ROBOTS:
|
||||
# raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
|
||||
|
||||
client = RobotClient(cfg)
|
||||
|
||||
@@ -512,4 +508,5 @@ def async_client(cfg: RobotClientConfig):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
register_third_party_plugins()
|
||||
async_client() # run the client
|
||||
|
||||
@@ -27,7 +27,7 @@ class DatasetConfig:
|
||||
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
|
||||
# datasets are provided.
|
||||
repo_id: str
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | None = None
|
||||
episodes: list[int] | None = None
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
|
||||
@@ -7,6 +7,13 @@
|
||||
|
||||
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
{% if repo_id is defined and repo_id %}
|
||||
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ repo_id }}">
|
||||
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
|
||||
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
|
||||
</a>
|
||||
{% endif %}
|
||||
|
||||
## Dataset Description
|
||||
|
||||
{{ dataset_description | default("", true) }}
|
||||
|
||||
@@ -567,20 +567,22 @@ def _copy_and_reindex_data(
|
||||
def _keep_episodes_from_video_with_av(
|
||||
input_path: Path,
|
||||
output_path: Path,
|
||||
episodes_to_keep: list[tuple[float, float]],
|
||||
episodes_to_keep: list[tuple[int, int]],
|
||||
fps: float,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
) -> None:
|
||||
"""Keep only specified episodes from a video file using PyAV.
|
||||
|
||||
This function decodes frames from specified time ranges and re-encodes them with
|
||||
This function decodes frames from specified frame ranges and re-encodes them with
|
||||
properly reset timestamps to ensure monotonic progression.
|
||||
|
||||
Args:
|
||||
input_path: Source video file path.
|
||||
output_path: Destination video file path.
|
||||
episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep.
|
||||
episodes_to_keep: List of (start_frame, end_frame) tuples for episodes to keep.
|
||||
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
|
||||
is inclusive and end_frame is exclusive.
|
||||
fps: Frame rate of the video.
|
||||
vcodec: Video codec to use for encoding.
|
||||
pix_fmt: Pixel format for output video.
|
||||
@@ -622,9 +624,10 @@ def _keep_episodes_from_video_with_av(
|
||||
|
||||
# Create set of (start, end) ranges for fast lookup.
|
||||
# Convert to a sorted list for efficient checking.
|
||||
time_ranges = sorted(episodes_to_keep)
|
||||
frame_ranges = sorted(episodes_to_keep)
|
||||
|
||||
# Track frame index for setting PTS and current range being processed.
|
||||
src_frame_count = 0
|
||||
frame_count = 0
|
||||
range_idx = 0
|
||||
|
||||
@@ -634,21 +637,20 @@ def _keep_episodes_from_video_with_av(
|
||||
if frame is None:
|
||||
continue
|
||||
|
||||
# Get frame timestamp.
|
||||
frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0
|
||||
|
||||
# Check if frame is in any of our desired time ranges.
|
||||
# Check if frame is in any of our desired frame ranges.
|
||||
# Skip ranges that have already passed.
|
||||
while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]:
|
||||
while range_idx < len(frame_ranges) and src_frame_count >= frame_ranges[range_idx][1]:
|
||||
range_idx += 1
|
||||
|
||||
# If we've passed all ranges, stop processing.
|
||||
if range_idx >= len(time_ranges):
|
||||
if range_idx >= len(frame_ranges):
|
||||
break
|
||||
|
||||
# Check if frame is in current range.
|
||||
start_ts, end_ts = time_ranges[range_idx]
|
||||
if frame_time < start_ts:
|
||||
start_frame = frame_ranges[range_idx][0]
|
||||
|
||||
if src_frame_count < start_frame:
|
||||
src_frame_count += 1
|
||||
continue
|
||||
|
||||
# Frame is in range - create a new frame with reset timestamps.
|
||||
@@ -661,6 +663,7 @@ def _keep_episodes_from_video_with_av(
|
||||
for pkt in v_out.encode(new_frame):
|
||||
out.mux(pkt)
|
||||
|
||||
src_frame_count += 1
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder.
|
||||
@@ -749,15 +752,17 @@ def _copy_and_reindex_videos(
|
||||
f"videos/{video_key}/to_timestamp"
|
||||
]
|
||||
else:
|
||||
# Build list of time ranges to keep, in sorted order.
|
||||
# Build list of frame ranges to keep, in sorted order.
|
||||
sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x])
|
||||
episodes_to_keep_ranges: list[tuple[float, float]] = []
|
||||
|
||||
episodes_to_keep_ranges: list[tuple[int, int]] = []
|
||||
for old_idx in sorted_keep_episodes:
|
||||
src_ep = src_dataset.meta.episodes[old_idx]
|
||||
from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
|
||||
to_ts = src_ep[f"videos/{video_key}/to_timestamp"]
|
||||
episodes_to_keep_ranges.append((from_ts, to_ts))
|
||||
from_frame = round(src_ep[f"videos/{video_key}/from_timestamp"] * src_dataset.meta.fps)
|
||||
to_frame = round(src_ep[f"videos/{video_key}/to_timestamp"] * src_dataset.meta.fps)
|
||||
assert src_ep["length"] == to_frame - from_frame, (
|
||||
f"Episode length mismatch: {src_ep['length']} vs {to_frame - from_frame}"
|
||||
)
|
||||
episodes_to_keep_ranges.append((from_frame, to_frame))
|
||||
|
||||
# Use PyAV filters to efficiently re-encode only the desired segments.
|
||||
assert src_dataset.meta.video_path is not None
|
||||
|
||||
@@ -664,11 +664,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for the README).
|
||||
|
||||
Args:
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
|
||||
will be stored under root/repo_id.
|
||||
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
|
||||
set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
|
||||
'~/.cache/huggingface/lerobot'.
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset.
|
||||
root (Path | None, optional): Local directory where the dataset will be downloaded and
|
||||
stored. If set, all dataset files will be stored directly under this path. If not set, the
|
||||
dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the
|
||||
HF_LEROBOT_HOME environment variable).
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list. Defaults to None.
|
||||
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
|
||||
@@ -747,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Check if cached dataset contains all requested episodes
|
||||
if not self._check_cached_episodes_sufficient():
|
||||
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
except (FileNotFoundError, NotADirectoryError):
|
||||
if is_valid_version(self.revision):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download(download_videos)
|
||||
@@ -839,7 +839,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
hub_api.upload_folder(**upload_kwargs)
|
||||
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
||||
tags=tags, dataset_info=self.meta.info, license=license, repo_id=self.repo_id, **card_kwargs
|
||||
)
|
||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
||||
|
||||
@@ -1771,11 +1771,12 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
||||
extra_keys = set(ds.features).difference(intersection_features)
|
||||
logging.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
if extra_keys:
|
||||
logging.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
@@ -227,16 +227,17 @@ def decode_video_frames_torchvision(
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
)
|
||||
if not is_within_tol.all():
|
||||
raise FrameTimestampError(
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
" It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
" This might be due to synchronization issues with timestamps during data collection."
|
||||
" To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
@@ -248,7 +249,11 @@ def decode_video_frames_torchvision(
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
|
||||
assert len(timestamps) == len(closest_frames)
|
||||
if len(timestamps) != len(closest_frames):
|
||||
raise FrameTimestampError(
|
||||
f"Number of retrieved frames ({len(closest_frames)}) does not match "
|
||||
f"number of queried timestamps ({len(timestamps)})"
|
||||
)
|
||||
return closest_frames
|
||||
|
||||
|
||||
@@ -353,15 +358,16 @@ def decode_video_frames_torchcodec(
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
)
|
||||
if not is_within_tol.all():
|
||||
raise FrameTimestampError(
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
" It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
" This might be due to synchronization issues with timestamps during data collection."
|
||||
" To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
|
||||
@@ -55,10 +55,16 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||
mode).
|
||||
resize_shape: (H, W) shape to resize images to as a preprocessing step for the vision
|
||||
backbone. If None, no resizing is done and the original image resolution is used.
|
||||
crop_ratio: Ratio in (0, 1] used to derive the crop size from resize_shape
|
||||
(crop_h = int(resize_shape[0] * crop_ratio), likewise for width).
|
||||
Set to 1.0 to disable cropping. Only takes effect when resize_shape is not None.
|
||||
crop_shape: (H, W) shape to crop images to. When resize_shape is set and crop_ratio < 1.0,
|
||||
this is computed automatically. Can also be set directly for legacy configs that use
|
||||
crop-only (without resize). If None and no derivation applies, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center
|
||||
crop in eval mode).
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
@@ -114,7 +120,9 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
vision_backbone: str = "resnet18"
|
||||
crop_shape: tuple[int, int] | None = (84, 84)
|
||||
resize_shape: tuple[int, int] | None = None
|
||||
crop_ratio: float = 1.0
|
||||
crop_shape: tuple[int, int] | None = None
|
||||
crop_is_random: bool = True
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
@@ -139,6 +147,10 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
# Inference
|
||||
num_inference_steps: int | None = None
|
||||
|
||||
# Optimization
|
||||
compile_model: bool = False
|
||||
compile_mode: str = "reduce-overhead"
|
||||
|
||||
# Loss computation
|
||||
do_mask_loss_for_padding: bool = False
|
||||
|
||||
@@ -171,6 +183,25 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
f"Got {self.noise_scheduler_type}."
|
||||
)
|
||||
|
||||
if self.resize_shape is not None and (
|
||||
len(self.resize_shape) != 2 or any(d <= 0 for d in self.resize_shape)
|
||||
):
|
||||
raise ValueError(f"`resize_shape` must be a pair of positive integers. Got {self.resize_shape}.")
|
||||
if not (0 < self.crop_ratio <= 1.0):
|
||||
raise ValueError(f"`crop_ratio` must be in (0, 1]. Got {self.crop_ratio}.")
|
||||
|
||||
if self.resize_shape is not None:
|
||||
if self.crop_ratio < 1.0:
|
||||
self.crop_shape = (
|
||||
int(self.resize_shape[0] * self.crop_ratio),
|
||||
int(self.resize_shape[1] * self.crop_ratio),
|
||||
)
|
||||
else:
|
||||
# Explicitly disable cropping for resize+ratio path when crop_ratio == 1.0.
|
||||
self.crop_shape = None
|
||||
if self.crop_shape is not None and (self.crop_shape[0] <= 0 or self.crop_shape[1] <= 0):
|
||||
raise ValueError(f"`crop_shape` must have positive dimensions. Got {self.crop_shape}.")
|
||||
|
||||
# Check that the horizon size and U-Net downsampling is compatible.
|
||||
# U-Net downsamples by 2 with each stage.
|
||||
downsampling_factor = 2 ** len(self.down_dims)
|
||||
@@ -198,13 +229,12 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
if len(self.image_features) == 0 and self.env_state_feature is None:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
if self.crop_shape is not None:
|
||||
if self.resize_shape is None and self.crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
f"`{key}`."
|
||||
f"`crop_shape` should fit within the image shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for `{key}`."
|
||||
)
|
||||
|
||||
# Check that all input images have the same shape.
|
||||
|
||||
@@ -142,6 +142,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
for key in self.config.image_features:
|
||||
if self.config.n_obs_steps == 1 and batch[key].ndim == 4:
|
||||
batch[key] = batch[key].unsqueeze(1)
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
# no output_dict so returning None
|
||||
@@ -182,6 +185,11 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||
|
||||
if config.compile_model:
|
||||
# Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
|
||||
# common in diffusion inference.
|
||||
self.unet = torch.compile(self.unet, mode=config.compile_mode)
|
||||
|
||||
self.noise_scheduler = _make_noise_scheduler(
|
||||
config.noise_scheduler_type,
|
||||
num_train_timesteps=config.num_train_timesteps,
|
||||
@@ -446,12 +454,18 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
def __init__(self, config: DiffusionConfig):
|
||||
super().__init__()
|
||||
# Set up optional preprocessing.
|
||||
if config.crop_shape is not None:
|
||||
if config.resize_shape is not None:
|
||||
self.resize = torchvision.transforms.Resize(config.resize_shape)
|
||||
else:
|
||||
self.resize = None
|
||||
|
||||
crop_shape = config.crop_shape
|
||||
if crop_shape is not None:
|
||||
self.do_crop = True
|
||||
# Always use center crop for eval
|
||||
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
|
||||
self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
|
||||
if config.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
@@ -477,13 +491,16 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
|
||||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
# The dummy input should take the number of image channels from `config.image_features` and it should
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.image_features`.
|
||||
# The dummy shape mirrors the runtime preprocessing order: resize -> crop.
|
||||
|
||||
# Note: we have a check in the config class to make sure all images have the same shape.
|
||||
images_shape = next(iter(config.image_features.values())).shape
|
||||
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
if config.crop_shape is not None:
|
||||
dummy_shape_h_w = config.crop_shape
|
||||
elif config.resize_shape is not None:
|
||||
dummy_shape_h_w = config.resize_shape
|
||||
else:
|
||||
dummy_shape_h_w = images_shape[1:]
|
||||
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||
|
||||
@@ -499,7 +516,10 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
Returns:
|
||||
(B, D) image feature.
|
||||
"""
|
||||
# Preprocess: maybe crop (if it was set up in the __init__).
|
||||
# Preprocess: resize if configured, then crop if configured.
|
||||
|
||||
if self.resize is not None:
|
||||
x = self.resize(x)
|
||||
if self.do_crop:
|
||||
if self.training: # noqa: SIM108
|
||||
x = self.maybe_random_crop(x)
|
||||
|
||||
@@ -1,117 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Action interpolation for smoother robot control.
|
||||
|
||||
Provides configurable Nx control rate by interpolating between consecutive actions.
|
||||
Useful with RTC and action-chunking policies to reduce jerkiness.
|
||||
"""
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class ActionInterpolator:
|
||||
"""Interpolates between consecutive actions for smoother control.
|
||||
|
||||
When enabled with multiplier N, produces N actions per policy action
|
||||
by linearly interpolating between the previous and current action.
|
||||
|
||||
Example with multiplier=3:
|
||||
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
|
||||
|
||||
This effectively multiplies the control rate for smoother motion.
|
||||
|
||||
Usage:
|
||||
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
|
||||
|
||||
# In control loop:
|
||||
if interpolator.needs_new_action():
|
||||
new_action = queue.get()
|
||||
if new_action:
|
||||
interpolator.add(new_action.cpu())
|
||||
|
||||
action = interpolator.get()
|
||||
if action:
|
||||
robot.send_action(action)
|
||||
"""
|
||||
|
||||
def __init__(self, multiplier: int = 1):
|
||||
"""Initialize the interpolator.
|
||||
|
||||
Args:
|
||||
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
|
||||
"""
|
||||
if multiplier < 1:
|
||||
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
|
||||
self.multiplier = multiplier
|
||||
self._prev: Tensor | None = None
|
||||
self._buffer: list[Tensor] = []
|
||||
self._idx = 0
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
"""Whether interpolation is active (multiplier > 1)."""
|
||||
return self.multiplier > 1
|
||||
|
||||
def reset(self):
|
||||
"""Reset interpolation state (call between episodes)."""
|
||||
self._prev = None
|
||||
self._buffer = []
|
||||
self._idx = 0
|
||||
|
||||
def needs_new_action(self) -> bool:
|
||||
"""Check if a new action is needed from the queue."""
|
||||
return self._idx >= len(self._buffer)
|
||||
|
||||
def add(self, action: Tensor) -> None:
|
||||
"""Add a new action and compute interpolated sequence.
|
||||
|
||||
Args:
|
||||
action: New action tensor from policy/queue (already on CPU).
|
||||
"""
|
||||
if self.multiplier > 1 and self._prev is not None:
|
||||
self._buffer = []
|
||||
for i in range(1, self.multiplier + 1):
|
||||
t = i / self.multiplier
|
||||
interp = self._prev + t * (action - self._prev)
|
||||
self._buffer.append(interp)
|
||||
else:
|
||||
self._buffer = [action]
|
||||
self._prev = action
|
||||
self._idx = 0
|
||||
|
||||
def get(self) -> Tensor | None:
|
||||
"""Get the next interpolated action.
|
||||
|
||||
Returns:
|
||||
Next action tensor, or None if buffer is exhausted.
|
||||
"""
|
||||
if self._idx >= len(self._buffer):
|
||||
return None
|
||||
action = self._buffer[self._idx]
|
||||
self._idx += 1
|
||||
return action
|
||||
|
||||
def get_control_interval(self, fps: float) -> float:
|
||||
"""Get the control interval based on interpolation multiplier.
|
||||
|
||||
Args:
|
||||
fps: Base frames per second.
|
||||
|
||||
Returns:
|
||||
Control interval in seconds (divided by multiplier).
|
||||
"""
|
||||
return 1.0 / (fps * self.multiplier)
|
||||
@@ -277,9 +277,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
|
||||
# When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss
|
||||
if self.dataset_meta is not None:
|
||||
episodes_df = None
|
||||
if self.sparse_subtask_names != ["task"]:
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
|
||||
# Generate sparse targets
|
||||
if self.sparse_temporal_proportions is not None:
|
||||
|
||||
@@ -106,6 +106,9 @@ class SmolVLAConfig(PreTrainedConfig):
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
compile_model: bool = False # Whether to use torch.compile for model optimization
|
||||
compile_mode: str = "max-autotune" # Torch compile mode
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@@ -593,6 +593,12 @@ class VLAFlowMatching(nn.Module):
|
||||
self.prefix_length = self.config.prefix_length
|
||||
self.rtc_processor = rtc_processor
|
||||
|
||||
# Compile model if requested
|
||||
if config.compile_model:
|
||||
torch.set_float32_matmul_precision("high")
|
||||
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
|
||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
|
||||
@@ -77,7 +77,6 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
print(f"Loading {model_id} weights ...")
|
||||
self.vlm = AutoModelForImageTextToText.from_pretrained(
|
||||
model_id,
|
||||
device_map=device,
|
||||
torch_dtype="bfloat16",
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
||||
@@ -56,6 +56,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
openarm_mini,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
)
|
||||
|
||||
@@ -61,6 +61,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
openarm_mini,
|
||||
so_leader,
|
||||
)
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
@@ -74,8 +74,6 @@ from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.cameras import ( # noqa: F401
|
||||
CameraConfig, # noqa: F401
|
||||
)
|
||||
@@ -92,7 +90,6 @@ from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import (
|
||||
PolicyAction,
|
||||
@@ -128,6 +125,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
openarm_mini,
|
||||
reachy2_teleoperator,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
@@ -157,7 +155,7 @@ class DatasetRecordConfig:
|
||||
repo_id: str
|
||||
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
|
||||
single_task: str
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second.
|
||||
fps: int = 30
|
||||
@@ -228,9 +226,6 @@ class RecordConfig:
|
||||
play_sounds: bool = True
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
# Action interpolation multiplier for smoother policy control (1=off, 2=2x, 3=3x)
|
||||
# Only applies when using a policy (not teleop)
|
||||
interpolation_multiplier: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
@@ -303,7 +298,6 @@ def record_loop(
|
||||
control_time_s: int | None = None,
|
||||
single_task: str | None = None,
|
||||
display_data: bool = False,
|
||||
interpolator: ActionInterpolator | None = None,
|
||||
display_compressed_images: bool = False,
|
||||
):
|
||||
if dataset is not None and dataset.fps != fps:
|
||||
@@ -340,14 +334,7 @@ def record_loop(
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
# Reset interpolator if provided
|
||||
if interpolator is not None:
|
||||
interpolator.reset()
|
||||
|
||||
# Calculate control interval based on interpolation
|
||||
use_interpolation = interpolator is not None and interpolator.enabled and policy is not None
|
||||
control_interval = interpolator.get_control_interval(fps) if interpolator else 1 / fps
|
||||
|
||||
no_action_count = 0
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < control_time_s:
|
||||
@@ -368,58 +355,24 @@ def record_loop(
|
||||
|
||||
# Get action from either policy or teleop
|
||||
if policy is not None and preprocessor is not None and postprocessor is not None:
|
||||
# With interpolation: only call policy when interpolator needs new action
|
||||
if use_interpolation:
|
||||
# Get action keys from robot
|
||||
action_keys = sorted(robot.action_features)
|
||||
action_values = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
|
||||
if interpolator.needs_new_action():
|
||||
action_values = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
act_processed_policy = make_robot_action(action_values, dataset.features)
|
||||
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
|
||||
|
||||
# Convert to tensor for interpolator
|
||||
action_tensor = torch.tensor([robot_action_to_send[k] for k in action_keys])
|
||||
interpolator.add(action_tensor)
|
||||
|
||||
# Get interpolated action
|
||||
interp_action = interpolator.get()
|
||||
if interp_action is not None:
|
||||
robot_action_to_send = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
|
||||
action_values = robot_action_to_send
|
||||
else:
|
||||
# No action available yet, skip this iteration
|
||||
continue
|
||||
else:
|
||||
action_values = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
|
||||
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
|
||||
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
|
||||
|
||||
elif policy is None and isinstance(teleop, Teleoperator):
|
||||
act = teleop.get_action()
|
||||
|
||||
# Applies a pipeline to the raw teleop action, default is IdentityProcessor
|
||||
act_processed_teleop = teleop_action_processor((act, obs))
|
||||
action_values = act_processed_teleop
|
||||
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
|
||||
|
||||
elif policy is None and isinstance(teleop, list):
|
||||
arm_action = teleop_arm.get_action()
|
||||
@@ -428,15 +381,23 @@ def record_loop(
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_action)
|
||||
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
act_processed_teleop = teleop_action_processor((act, obs))
|
||||
else:
|
||||
no_action_count += 1
|
||||
if no_action_count == 1 or no_action_count % 10 == 0:
|
||||
logging.warning(
|
||||
"No policy or teleoperator provided, skipping action generation. "
|
||||
"This is likely to happen when resetting the environment without a teleop device. "
|
||||
"The robot won't be at its rest position at the start of the next episode."
|
||||
)
|
||||
continue
|
||||
|
||||
# Applies a pipeline to the action, default is IdentityProcessor
|
||||
if policy is not None and act_processed_policy is not None:
|
||||
action_values = act_processed_policy
|
||||
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
|
||||
else:
|
||||
action_values = act_processed_teleop
|
||||
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
|
||||
else:
|
||||
logging.info(
|
||||
"No policy or teleoperator provided, skipping action generation."
|
||||
"This is likely to happen when resetting the environment without a teleop device."
|
||||
"The robot won't be at its rest position at the start of the next episode."
|
||||
)
|
||||
continue
|
||||
|
||||
# Send action to robot
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
@@ -457,7 +418,7 @@ def record_loop(
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
|
||||
sleep_time_s: float = control_interval - dt_s
|
||||
sleep_time_s: float = 1 / fps - dt_s
|
||||
if sleep_time_s < 0:
|
||||
logging.warning(
|
||||
f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
@@ -544,7 +505,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
preprocessor = None
|
||||
postprocessor = None
|
||||
interpolator = None
|
||||
if cfg.policy is not None:
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
@@ -555,10 +515,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
# Create interpolator for smoother policy control
|
||||
if cfg.interpolation_multiplier > 1:
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
logging.info(f"Action interpolation enabled: {cfg.interpolation_multiplier}x control rate")
|
||||
|
||||
robot.connect()
|
||||
if teleop is not None:
|
||||
@@ -590,7 +546,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
interpolator=interpolator,
|
||||
display_compressed_images=display_compressed_images,
|
||||
)
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ class DatasetReplayConfig:
|
||||
repo_id: str
|
||||
# Episode to replay.
|
||||
episode: int
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second. By default, uses the policy fps.
|
||||
fps: int = 30
|
||||
|
||||
@@ -43,6 +43,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_mini,
|
||||
so_leader,
|
||||
)
|
||||
|
||||
@@ -51,6 +52,7 @@ COMPATIBLE_DEVICES = [
|
||||
"koch_leader",
|
||||
"omx_follower",
|
||||
"omx_leader",
|
||||
"openarm_mini",
|
||||
"so100_follower",
|
||||
"so100_leader",
|
||||
"so101_follower",
|
||||
|
||||
@@ -94,6 +94,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
openarm_mini,
|
||||
reachy2_teleoperator,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
|
||||
@@ -24,6 +24,7 @@ import torch
|
||||
from accelerate import Accelerator
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
@@ -51,6 +52,7 @@ from lerobot.utils.utils import (
|
||||
format_big_number,
|
||||
has_method,
|
||||
init_logging,
|
||||
inside_slurm,
|
||||
)
|
||||
|
||||
|
||||
@@ -378,10 +380,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
"dataloading_s": AverageMeter("data_s", ":.3f"),
|
||||
}
|
||||
|
||||
# Use effective batch size for proper epoch calculation in distributed training
|
||||
# Keep global batch size for logging; MetricsTracker handles world size internally.
|
||||
effective_batch_size = cfg.batch_size * accelerator.num_processes
|
||||
train_tracker = MetricsTracker(
|
||||
effective_batch_size,
|
||||
cfg.batch_size,
|
||||
dataset.num_frames,
|
||||
dataset.num_episodes,
|
||||
train_metrics,
|
||||
@@ -390,6 +392,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
progbar = tqdm(
|
||||
total=cfg.steps - step,
|
||||
desc="Training",
|
||||
unit="step",
|
||||
disable=inside_slurm(),
|
||||
position=0,
|
||||
leave=True,
|
||||
)
|
||||
logging.info(
|
||||
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
|
||||
)
|
||||
@@ -414,6 +424,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
# increment `step` here.
|
||||
step += 1
|
||||
if is_main_process:
|
||||
progbar.update(1)
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
@@ -507,6 +519,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if is_main_process:
|
||||
progbar.close()
|
||||
|
||||
if eval_env:
|
||||
close_envs(eval_env)
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -12,18 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Real-Time Chunking (RTC) utilities for action-chunking policies."""
|
||||
from .config_openarm_mini import OpenArmMiniConfig
|
||||
from .openarm_mini import OpenArmMini
|
||||
|
||||
from lerobot.policies.rtc.action_interpolator import ActionInterpolator
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
|
||||
__all__ = [
|
||||
"ActionInterpolator",
|
||||
"ActionQueue",
|
||||
"LatencyTracker",
|
||||
"RTCConfig",
|
||||
"RTCProcessor",
|
||||
]
|
||||
__all__ = ["OpenArmMini", "OpenArmMiniConfig"]
|
||||
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("openarm_mini")
|
||||
@dataclass
|
||||
class OpenArmMiniConfig(TeleoperatorConfig):
|
||||
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
|
||||
|
||||
port_right: str = "/dev/ttyUSB0"
|
||||
port_left: str = "/dev/ttyUSB1"
|
||||
|
||||
use_degrees: bool = True
|
||||
296
src/lerobot/teleoperators/openarm_mini/openarm_mini.py
Normal file
296
src/lerobot/teleoperators/openarm_mini/openarm_mini.py
Normal file
@@ -0,0 +1,296 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_openarm_mini import OpenArmMiniConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Motors whose direction is inverted during readout
|
||||
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5"]
|
||||
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
|
||||
|
||||
|
||||
class OpenArmMini(Teleoperator):
|
||||
"""
|
||||
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
|
||||
|
||||
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
|
||||
"""
|
||||
|
||||
config_class = OpenArmMiniConfig
|
||||
name = "openarm_mini"
|
||||
|
||||
def __init__(self, config: OpenArmMiniConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
norm_mode_body = MotorNormMode.DEGREES
|
||||
|
||||
motors_right = {
|
||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||
"joint_4": Motor(4, "sts3215", norm_mode_body),
|
||||
"joint_5": Motor(5, "sts3215", norm_mode_body),
|
||||
"joint_6": Motor(6, "sts3215", norm_mode_body),
|
||||
"joint_7": Motor(7, "sts3215", norm_mode_body),
|
||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
}
|
||||
|
||||
motors_left = {
|
||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||
"joint_4": Motor(4, "sts3215", norm_mode_body),
|
||||
"joint_5": Motor(5, "sts3215", norm_mode_body),
|
||||
"joint_6": Motor(6, "sts3215", norm_mode_body),
|
||||
"joint_7": Motor(7, "sts3215", norm_mode_body),
|
||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
}
|
||||
|
||||
cal_right = {
|
||||
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
|
||||
}
|
||||
cal_left = {
|
||||
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
|
||||
}
|
||||
|
||||
self.bus_right = FeetechMotorsBus(
|
||||
port=self.config.port_right,
|
||||
motors=motors_right,
|
||||
calibration=cal_right,
|
||||
)
|
||||
|
||||
self.bus_left = FeetechMotorsBus(
|
||||
port=self.config.port_left,
|
||||
motors=motors_left,
|
||||
calibration=cal_left,
|
||||
)
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus_right.motors:
|
||||
features[f"right_{motor}.pos"] = float
|
||||
for motor in self.bus_left.motors:
|
||||
features[f"left_{motor}.pos"] = float
|
||||
return features
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus_right.is_connected and self.bus_left.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
logger.info(f"Connecting right arm on {self.config.port_right}...")
|
||||
self.bus_right.connect()
|
||||
logger.info(f"Connecting left arm on {self.config.port_left}...")
|
||||
self.bus_left.connect()
|
||||
|
||||
if calibrate:
|
||||
self.calibrate()
|
||||
|
||||
self.configure()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
"""
|
||||
Run calibration procedure for OpenArm Mini.
|
||||
|
||||
1. Disable torque
|
||||
2. Ask user to position arms in hanging position with grippers closed
|
||||
3. Set this as zero position via half-turn homing
|
||||
4. Interactive gripper calibration (open/close positions)
|
||||
5. Save calibration
|
||||
"""
|
||||
if self.calibration:
|
||||
user_input = input(
|
||||
f"Press ENTER to use existing calibration for {self.id}, "
|
||||
f"or type 'c' and press ENTER to run new calibration: "
|
||||
)
|
||||
if user_input.strip().lower() != "c":
|
||||
logger.info(f"Using existing calibration for {self.id}")
|
||||
cal_right = {
|
||||
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
|
||||
}
|
||||
cal_left = {
|
||||
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
|
||||
}
|
||||
self.bus_right.write_calibration(cal_right)
|
||||
self.bus_left.write_calibration(cal_left)
|
||||
return
|
||||
|
||||
logger.info(f"\nRunning calibration for {self}")
|
||||
|
||||
self._calibrate_arm("right", self.bus_right)
|
||||
self._calibrate_arm("left", self.bus_left)
|
||||
|
||||
self._save_calibration()
|
||||
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
|
||||
|
||||
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
|
||||
"""Calibrate a single arm with Feetech motors."""
|
||||
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
|
||||
|
||||
bus.disable_torque()
|
||||
|
||||
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
|
||||
for motor in bus.motors:
|
||||
bus.write("Phase", motor, 12)
|
||||
|
||||
for motor in bus.motors:
|
||||
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
input(
|
||||
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
|
||||
"Position the arm in the following configuration:\n"
|
||||
" - Arm hanging straight down\n"
|
||||
" - Gripper closed\n"
|
||||
"Press ENTER when ready..."
|
||||
)
|
||||
|
||||
homing_offsets = bus.set_half_turn_homings()
|
||||
logger.info(f"{arm_name.capitalize()} arm zero position set.")
|
||||
|
||||
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
|
||||
|
||||
if self.calibration is None:
|
||||
self.calibration = {}
|
||||
|
||||
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
|
||||
max_res = motor_resolution - 1
|
||||
|
||||
for motor_name, motor in bus.motors.items():
|
||||
prefixed_name = f"{arm_name}_{motor_name}"
|
||||
|
||||
if motor_name == "gripper":
|
||||
input(
|
||||
f"\nGripper Calibration ({arm_name.upper()} arm)\n"
|
||||
f"Step 1: CLOSE the gripper fully\n"
|
||||
f"Press ENTER when gripper is closed..."
|
||||
)
|
||||
closed_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||
logger.info(f" Gripper closed position recorded: {closed_pos}")
|
||||
|
||||
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
|
||||
open_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||
logger.info(f" Gripper open position recorded: {open_pos}")
|
||||
|
||||
if closed_pos < open_pos:
|
||||
range_min = int(closed_pos)
|
||||
range_max = int(open_pos)
|
||||
drive_mode = 0
|
||||
else:
|
||||
range_min = int(open_pos)
|
||||
range_max = int(closed_pos)
|
||||
drive_mode = 1
|
||||
|
||||
logger.info(
|
||||
f" {prefixed_name}: range set to [{range_min}, {range_max}] "
|
||||
f"(0=closed, 100=open, drive_mode={drive_mode})"
|
||||
)
|
||||
else:
|
||||
range_min = 0
|
||||
range_max = max_res
|
||||
drive_mode = 0
|
||||
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
|
||||
|
||||
self.calibration[prefixed_name] = MotorCalibration(
|
||||
id=motor.id,
|
||||
drive_mode=drive_mode,
|
||||
homing_offset=homing_offsets[motor_name],
|
||||
range_min=range_min,
|
||||
range_max=range_max,
|
||||
)
|
||||
|
||||
cal_for_bus = {
|
||||
k.replace(f"{arm_name}_", ""): v
|
||||
for k, v in self.calibration.items()
|
||||
if k.startswith(f"{arm_name}_")
|
||||
}
|
||||
bus.write_calibration(cal_for_bus)
|
||||
|
||||
def configure(self) -> None:
|
||||
self.bus_right.disable_torque()
|
||||
self.bus_right.configure_motors()
|
||||
for motor in self.bus_right.motors:
|
||||
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
self.bus_left.disable_torque()
|
||||
self.bus_left.configure_motors()
|
||||
for motor in self.bus_left.motors:
|
||||
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
print("\nSetting up RIGHT arm motors...")
|
||||
for motor in reversed(self.bus_right.motors):
|
||||
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
|
||||
self.bus_right.setup_motor(motor)
|
||||
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
|
||||
|
||||
print("\nSetting up LEFT arm motors...")
|
||||
for motor in reversed(self.bus_left.motors):
|
||||
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
|
||||
self.bus_left.setup_motor(motor)
|
||||
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
"""Get current action from both arms (read positions from all motors)."""
|
||||
start = time.perf_counter()
|
||||
|
||||
right_positions = self.bus_right.sync_read("Present_Position")
|
||||
left_positions = self.bus_left.sync_read("Present_Position")
|
||||
|
||||
action: dict[str, Any] = {}
|
||||
for motor, val in right_positions.items():
|
||||
action[f"right_{motor}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
|
||||
for motor, val in left_positions.items():
|
||||
action[f"left_{motor}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
|
||||
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return action
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError("Feedback is not yet implemented for OpenArm Mini.")
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.bus_right.disconnect()
|
||||
self.bus_left.disconnect()
|
||||
logger.info(f"{self} disconnected.")
|
||||
@@ -95,6 +95,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
|
||||
from .bi_openarm_leader import BiOpenArmLeader
|
||||
|
||||
return BiOpenArmLeader(config)
|
||||
elif config.type == "openarm_mini":
|
||||
from .openarm_mini import OpenArmMini
|
||||
|
||||
return OpenArmMini(config)
|
||||
else:
|
||||
try:
|
||||
return cast("Teleoperator", make_device_from_device_class(config))
|
||||
|
||||
@@ -189,7 +189,7 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
# Check if dataset_name starts with "eval_" but policy is missing
|
||||
if dataset_name.startswith("eval_") and policy_cfg is None:
|
||||
raise ValueError(
|
||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})."
|
||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
|
||||
)
|
||||
|
||||
# Check if dataset_name does not start with "eval_" but policy is provided
|
||||
|
||||
@@ -104,9 +104,10 @@ class MetricsTracker:
|
||||
self.metrics = metrics
|
||||
|
||||
self.steps = initial_step
|
||||
world_size = accelerator.num_processes if accelerator else 1
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
# can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
|
||||
self.samples = self.steps * self._batch_size
|
||||
self.samples = self.steps * self._batch_size * world_size
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
self.accelerator = accelerator
|
||||
@@ -132,7 +133,8 @@ class MetricsTracker:
|
||||
Updates metrics that depend on 'step' for one step.
|
||||
"""
|
||||
self.steps += 1
|
||||
self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
|
||||
world_size = self.accelerator.num_processes if self.accelerator else 1
|
||||
self.samples += self._batch_size * world_size
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -24,6 +24,11 @@ def mock_metrics():
|
||||
return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
|
||||
|
||||
|
||||
class MockAccelerator:
|
||||
def __init__(self, num_processes: int):
|
||||
self.num_processes = num_processes
|
||||
|
||||
|
||||
def test_average_meter_initialization():
|
||||
meter = AverageMeter("loss", ":.2f")
|
||||
assert meter.name == "loss"
|
||||
@@ -82,6 +87,37 @@ def test_metrics_tracker_step(mock_metrics):
|
||||
assert tracker.epochs == tracker.samples / 1000
|
||||
|
||||
|
||||
def test_metrics_tracker_initialization_with_accelerator(mock_metrics):
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32,
|
||||
num_frames=1000,
|
||||
num_episodes=50,
|
||||
metrics=mock_metrics,
|
||||
initial_step=10,
|
||||
accelerator=MockAccelerator(num_processes=2),
|
||||
)
|
||||
assert tracker.steps == 10
|
||||
assert tracker.samples == 10 * 32 * 2
|
||||
assert tracker.episodes == tracker.samples / (1000 / 50)
|
||||
assert tracker.epochs == tracker.samples / 1000
|
||||
|
||||
|
||||
def test_metrics_tracker_step_with_accelerator(mock_metrics):
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32,
|
||||
num_frames=1000,
|
||||
num_episodes=50,
|
||||
metrics=mock_metrics,
|
||||
initial_step=5,
|
||||
accelerator=MockAccelerator(num_processes=2),
|
||||
)
|
||||
tracker.step()
|
||||
assert tracker.steps == 6
|
||||
assert tracker.samples == (5 * 32 * 2) + (32 * 2)
|
||||
assert tracker.episodes == tracker.samples / (1000 / 50)
|
||||
assert tracker.epochs == tracker.samples / 1000
|
||||
|
||||
|
||||
def test_metrics_tracker_getattr(mock_metrics):
|
||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
|
||||
assert tracker.loss == mock_metrics["loss"]
|
||||
|
||||
Reference in New Issue
Block a user