mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 11:21:27 +00:00
Compare commits
49 Commits
codex/mode
...
feat/robom
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
015c88cf0d | ||
|
|
0164725af8 | ||
|
|
34274c6f70 | ||
|
|
f6a13b1338 | ||
|
|
9db9c35cb4 | ||
|
|
fe96b28c74 | ||
|
|
2438df1307 | ||
|
|
f218d5ab30 | ||
|
|
04125492e4 | ||
|
|
e963e5a0c4 | ||
|
|
26ff40ddd7 | ||
|
|
6d269b28c8 | ||
|
|
b607c8458e | ||
|
|
9e83510c99 | ||
|
|
1f7b03f5f2 | ||
|
|
cb8edf17e6 | ||
|
|
5699f6cbf4 | ||
|
|
0e6114ac36 | ||
|
|
c8ce413d73 | ||
|
|
82dffde7fa | ||
|
|
eaf0218bc8 | ||
|
|
a0e52d52fe | ||
|
|
e99c55af4b | ||
|
|
408e0ca763 | ||
|
|
ce24063efd | ||
|
|
82934719db | ||
|
|
401a217597 | ||
|
|
40094b0464 | ||
|
|
fdbfc015a2 | ||
|
|
d656da8ccc | ||
|
|
b5f65e5332 | ||
|
|
cd6b43ea7a | ||
|
|
2236bbe7a3 | ||
|
|
cb0a944941 | ||
|
|
8a3d64033f | ||
|
|
03ee50e08f | ||
|
|
ca87ccd941 | ||
|
|
77352c495c | ||
|
|
05a5223885 | ||
|
|
580d818aa9 | ||
|
|
587aa82021 | ||
|
|
12b88fce02 | ||
|
|
fc6c94c82a | ||
|
|
1add460678 | ||
|
|
4587c2b648 | ||
|
|
2236cdb302 | ||
|
|
7c2466979e | ||
|
|
39b966e20a | ||
|
|
ba27aab79c |
6
.github/workflows/benchmark_tests.yml
vendored
6
.github/workflows/benchmark_tests.yml
vendored
@@ -382,6 +382,7 @@ jobs:
|
|||||||
--policy.path=\"\$ROBOTWIN_POLICY\" \
|
--policy.path=\"\$ROBOTWIN_POLICY\" \
|
||||||
--env.type=robotwin \
|
--env.type=robotwin \
|
||||||
--env.task=\"\$ROBOTWIN_TASKS\" \
|
--env.task=\"\$ROBOTWIN_TASKS\" \
|
||||||
|
--env.max_parallel_tasks=5 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.use_async_envs=false \
|
--eval.use_async_envs=false \
|
||||||
@@ -482,6 +483,7 @@ jobs:
|
|||||||
--policy.path=lerobot/smolvla_robocasa \
|
--policy.path=lerobot/smolvla_robocasa \
|
||||||
--env.type=robocasa \
|
--env.type=robocasa \
|
||||||
--env.task=CloseFridge,OpenCabinet,OpenDrawer,TurnOnMicrowave,TurnOffStove,CloseToasterOvenDoor,SlideDishwasherRack,TurnOnSinkFaucet,NavigateKitchen,TurnOnElectricKettle \
|
--env.task=CloseFridge,OpenCabinet,OpenDrawer,TurnOnMicrowave,TurnOffStove,CloseToasterOvenDoor,SlideDishwasherRack,TurnOnSinkFaucet,NavigateKitchen,TurnOnElectricKettle \
|
||||||
|
--env.max_parallel_tasks=5 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.use_async_envs=false \
|
--eval.use_async_envs=false \
|
||||||
@@ -693,6 +695,7 @@ jobs:
|
|||||||
--env.task=\"\$ROBOMME_TASKS\" \
|
--env.task=\"\$ROBOMME_TASKS\" \
|
||||||
--env.dataset_split=test \
|
--env.dataset_split=test \
|
||||||
--env.task_ids=[0] \
|
--env.task_ids=[0] \
|
||||||
|
--env.max_parallel_tasks=5 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.use_async_envs=false \
|
--eval.use_async_envs=false \
|
||||||
@@ -800,6 +803,7 @@ jobs:
|
|||||||
--env.type=libero_plus \
|
--env.type=libero_plus \
|
||||||
--env.task=\"\$LIBERO_PLUS_SUITE\" \
|
--env.task=\"\$LIBERO_PLUS_SUITE\" \
|
||||||
--env.task_ids=\"\$LIBERO_PLUS_TASK_IDS\" \
|
--env.task_ids=\"\$LIBERO_PLUS_TASK_IDS\" \
|
||||||
|
--env.max_parallel_tasks=5 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.use_async_envs=false \
|
--eval.use_async_envs=false \
|
||||||
@@ -900,6 +904,8 @@ jobs:
|
|||||||
--policy.path=lerobot/smolvla_vlabench \
|
--policy.path=lerobot/smolvla_vlabench \
|
||||||
--env.type=vlabench \
|
--env.type=vlabench \
|
||||||
--env.task=select_fruit,select_toy,select_book,select_painting,select_drink,select_ingredient,select_billiards,select_poker,add_condiment,insert_flower \
|
--env.task=select_fruit,select_toy,select_book,select_painting,select_drink,select_ingredient,select_billiards,select_poker,add_condiment,insert_flower \
|
||||||
|
--env.episode_length=50 \
|
||||||
|
--env.max_parallel_tasks=5 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.use_async_envs=false \
|
--eval.use_async_envs=false \
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ jobs:
|
|||||||
github.event.workflow_run.event == 'pull_request' &&
|
github.event.workflow_run.event == 'pull_request' &&
|
||||||
github.event.workflow_run.conclusion == 'success' &&
|
github.event.workflow_run.conclusion == 'success' &&
|
||||||
github.repository == 'huggingface/lerobot'
|
github.repository == 'huggingface/lerobot'
|
||||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@9ad2de8582b56c017cb530c1165116d40433f1c6 # main
|
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||||
with:
|
with:
|
||||||
package_name: lerobot
|
package_name: lerobot
|
||||||
secrets:
|
secrets:
|
||||||
|
|||||||
4
.github/workflows/documentation.yml
vendored
4
.github/workflows/documentation.yml
vendored
@@ -55,7 +55,7 @@ jobs:
|
|||||||
github.repository == 'huggingface/lerobot'
|
github.repository == 'huggingface/lerobot'
|
||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||||
with:
|
with:
|
||||||
commit_sha: ${{ github.sha }}
|
commit_sha: ${{ github.sha }}
|
||||||
package: lerobot
|
package: lerobot
|
||||||
@@ -78,7 +78,7 @@ jobs:
|
|||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
pull-requests: write
|
pull-requests: write
|
||||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||||
with:
|
with:
|
||||||
commit_sha: ${{ github.event.pull_request.head.sha }}
|
commit_sha: ${{ github.event.pull_request.head.sha }}
|
||||||
pr_number: ${{ github.event.number }}
|
pr_number: ${{ github.event.number }}
|
||||||
|
|||||||
3
.github/workflows/release.yml
vendored
3
.github/workflows/release.yml
vendored
@@ -152,13 +152,14 @@ jobs:
|
|||||||
BASE_VERSION="${VERSION%%-*}"
|
BASE_VERSION="${VERSION%%-*}"
|
||||||
echo "Installing pre-release version $BASE_VERSION from TestPyPI..."
|
echo "Installing pre-release version $BASE_VERSION from TestPyPI..."
|
||||||
uv pip install \
|
uv pip install \
|
||||||
|
--torch-backend cpu \
|
||||||
--index-url https://test.pypi.org/simple/ \
|
--index-url https://test.pypi.org/simple/ \
|
||||||
--extra-index-url https://pypi.org/simple \
|
--extra-index-url https://pypi.org/simple \
|
||||||
--index-strategy unsafe-best-match \
|
--index-strategy unsafe-best-match \
|
||||||
"lerobot[all]==$BASE_VERSION"
|
"lerobot[all]==$BASE_VERSION"
|
||||||
else
|
else
|
||||||
echo "Installing release version $VERSION from PyPI..."
|
echo "Installing release version $VERSION from PyPI..."
|
||||||
uv pip install "lerobot[all]==$VERSION"
|
uv pip install --torch-backend cpu "lerobot[all]==$VERSION"
|
||||||
fi
|
fi
|
||||||
- name: Check lerobot version
|
- name: Check lerobot version
|
||||||
run: uv run python -c "import lerobot; print(lerobot.__version__)"
|
run: uv run python -c "import lerobot; print(lerobot.__version__)"
|
||||||
|
|||||||
16
.github/workflows/stale.yml
vendored
16
.github/workflows/stale.yml
vendored
@@ -19,19 +19,19 @@ on:
|
|||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
# Runs at 02:00
|
# Runs at 02:00
|
||||||
schedule:
|
# schedule:
|
||||||
- cron: "0 2 * * *"
|
# - cron: "0 2 * * *"
|
||||||
|
|
||||||
env:
|
env:
|
||||||
CLOSE_ISSUE_MESSAGE: >
|
CLOSE_ISSUE_MESSAGE: >
|
||||||
This issue was closed because it has been stalled for 14 days with no activity.
|
This issue was closed because it has been stalled for 30 days with no activity.
|
||||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||||
CLOSE_PR_MESSAGE: >
|
CLOSE_PR_MESSAGE: >
|
||||||
This PR was closed because it has been stalled for 21 days with no activity.
|
This PR was closed because it has been stalled for 30 days with no activity.
|
||||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||||
WARN_ISSUE_MESSAGE: >
|
WARN_ISSUE_MESSAGE: >
|
||||||
This issue has been automatically marked as stale because it has not had
|
This issue has been automatically marked as stale because it has not had
|
||||||
recent activity (6 months). It will be closed if no further activity occurs.
|
recent activity (1 year). It will be closed if no further activity occurs.
|
||||||
Any change, comment or update to this issue will reset this count.
|
Any change, comment or update to this issue will reset this count.
|
||||||
Thank you for your contributions.
|
Thank you for your contributions.
|
||||||
WARN_PR_MESSAGE: >
|
WARN_PR_MESSAGE: >
|
||||||
@@ -59,10 +59,10 @@ jobs:
|
|||||||
stale-pr-label: stale
|
stale-pr-label: stale
|
||||||
exempt-issue-labels: never-stale
|
exempt-issue-labels: never-stale
|
||||||
exempt-pr-labels: never-stale
|
exempt-pr-labels: never-stale
|
||||||
days-before-issue-stale: 180
|
days-before-issue-stale: 365
|
||||||
days-before-issue-close: 14
|
days-before-issue-close: 30
|
||||||
days-before-pr-stale: 365
|
days-before-pr-stale: 365
|
||||||
days-before-pr-close: 21
|
days-before-pr-close: 30
|
||||||
delete-branch: true
|
delete-branch: true
|
||||||
close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
|
close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
|
||||||
close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}
|
close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
This file provides guidance to AI agents when working with code in this repository.
|
This file provides guidance to AI agents when working with code in this repository.
|
||||||
|
|
||||||
|
> **User-facing help → [`AGENT_GUIDE.md`](./AGENT_GUIDE.md)** (SO-101 setup, recording, picking a policy, training duration, eval — with copy-pasteable commands).
|
||||||
|
|
||||||
## Project Overview
|
## Project Overview
|
||||||
|
|
||||||
LeRobot is a PyTorch-based library for real-world robotics, providing datasets, pretrained policies, and tools for training, evaluation, data collection, and robot control. It integrates with Hugging Face Hub for model/dataset sharing.
|
LeRobot is a PyTorch-based library for real-world robotics, providing datasets, pretrained policies, and tools for training, evaluation, data collection, and robot control. It integrates with Hugging Face Hub for model/dataset sharing.
|
||||||
|
|||||||
412
AGENT_GUIDE.md
Normal file
412
AGENT_GUIDE.md
Normal file
@@ -0,0 +1,412 @@
|
|||||||
|
# AGENT_GUIDE.md — LeRobot Helper for AI Agents & Users
|
||||||
|
|
||||||
|
This file is a practical, copy-paste-friendly companion for any AI agent (Cursor, Claude, ChatGPT, Codex, etc.) helping a user work with LeRobot. It complements [`AGENTS.md`](./AGENTS.md) (dev/contributor context) with **user-facing guidance**: how to start, what to train, how long, how to record, and how to calibrate an SO-101.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Start here — ask the user first (MANDATORY)
|
||||||
|
|
||||||
|
Before suggesting any command, an agent MUST ask the user at least these questions and wait for answers:
|
||||||
|
|
||||||
|
1. **What's your goal?** (e.g. "teach my SO-101 to fold a cloth", "train a policy on an existing HF dataset", "contribute a PR", "understand the codebase")
|
||||||
|
2. **What hardware do you have?**
|
||||||
|
- Robot: none / SO-100 / SO-101 / Koch / LeKiwi / Reachy / other
|
||||||
|
- Teleop: leader arm / phone / keyboard / gamepad / none
|
||||||
|
- Cameras: how many, resolution, fixed or moving?
|
||||||
|
3. **What machine will you train on?**
|
||||||
|
- GPU model + VRAM (e.g. "laptop 3060 6 GB", "RTX 4090 24 GB", "A100 80 GB", "CPU only")
|
||||||
|
- OS: macOS / Linux / Windows
|
||||||
|
4. **Skill level & time budget?** First time, some ML, experienced? Hours, days, a weekend?
|
||||||
|
5. **Do you already have a dataset?** Yes (HF repo id?) / no / want to record one
|
||||||
|
6. **How can I help right now?** (pick one concrete next step)
|
||||||
|
|
||||||
|
Only after you have answers, propose a concrete path. If something is ambiguous, ask again rather than guessing. Bias toward **the simplest thing that works** for the user's hardware and goal.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. LeRobot in 60 seconds
|
||||||
|
|
||||||
|
LeRobot = **datasets + policies + envs + robot control**, unified by a small set of strong abstractions.
|
||||||
|
|
||||||
|
- **`LeRobotDataset`** — episode-aware dataset (video or images + actions + state), loadable from the Hub or disk.
|
||||||
|
- **Policies** (`ACT`, `Diffusion`, `SmolVLA`, `π0`, `π0.5`, `Wall-X`, `X-VLA`, `VQ-BeT`, `TD-MPC`, …) — all inherit `PreTrainedPolicy` and can be pushed/pulled from the Hub.
|
||||||
|
- **Processors** — small composable transforms between dataset → policy → robot.
|
||||||
|
- **Envs** (sim) and **Robots** (real) — same action/observation contract so code swaps cleanly.
|
||||||
|
- **CLI** — `lerobot-record`, `lerobot-train`, `lerobot-eval`, `lerobot-teleoperate`, `lerobot-calibrate`, `lerobot-find-port`, `lerobot-setup-motors`, `lerobot-replay`.
|
||||||
|
|
||||||
|
See [`AGENTS.md`](./AGENTS.md) for repo architecture.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Quickstart paths (pick one)
|
||||||
|
|
||||||
|
### Path A — "I have an SO-101 and want my first trained policy"
|
||||||
|
|
||||||
|
Go to §4 (SO-101 end-to-end), then §5 (data tips), then §6 (pick a policy — likely **ACT**), then §7 (how long), then §8 (eval).
|
||||||
|
|
||||||
|
### Path B — "No hardware, I want to train on an existing dataset"
|
||||||
|
|
||||||
|
Skip §4. Pick a policy in §6, pick a duration in §7, then run `lerobot-train` per §4.9 with a Hub `--dataset.repo_id` and an `--env.type` for eval. Finish with §8.
|
||||||
|
|
||||||
|
### Path C — "I just want to understand the codebase"
|
||||||
|
|
||||||
|
Read §2 above, then `AGENTS.md` "Architecture", then open `src/lerobot/policies/act/` and `src/lerobot/datasets/lerobot_dataset.py` as canonical examples.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. SO-101 end-to-end cheat-sheet
|
||||||
|
|
||||||
|
Full details in [`docs/source/so101.mdx`](./docs/source/so101.mdx) and [`docs/source/il_robots.mdx`](./docs/source/il_robots.mdx). Minimum commands in order. Confirm arms are assembled + powered before issuing.
|
||||||
|
|
||||||
|
**4.1 Install**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install 'lerobot[feetech]' # SO-100/SO-101 motor stack
|
||||||
|
# pip install 'lerobot[all]' # everything
|
||||||
|
# pip install 'lerobot[aloha,pusht]' # specific features
|
||||||
|
# pip install 'lerobot[smolvla]' # add SmolVLA deps
|
||||||
|
git lfs install && git lfs pull
|
||||||
|
hf auth login # required to push datasets/policies
|
||||||
|
```
|
||||||
|
|
||||||
|
Contributors can alternatively use `uv sync --locked --extra feetech` (see `AGENTS.md`).
|
||||||
|
|
||||||
|
**4.2 Find USB ports** — run once per arm, unplug when prompted.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-find-port
|
||||||
|
```
|
||||||
|
|
||||||
|
macOS: `/dev/tty.usbmodem...`; Linux: `/dev/ttyACM0` (may need `sudo chmod 666 /dev/ttyACM0`).
|
||||||
|
|
||||||
|
**4.3 Setup motor IDs & baudrate** (one-time, per arm)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-setup-motors --robot.type=so101_follower --robot.port=<FOLLOWER_PORT>
|
||||||
|
lerobot-setup-motors --teleop.type=so101_leader --teleop.port=<LEADER_PORT>
|
||||||
|
```
|
||||||
|
|
||||||
|
**4.4 Calibrate** — center all joints, press Enter, sweep each joint through its full range. The `id` is the calibration key — reuse it everywhere.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-calibrate --robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.id=my_follower
|
||||||
|
lerobot-calibrate --teleop.type=so101_leader --teleop.port=<LEADER_PORT> --teleop.id=my_leader
|
||||||
|
```
|
||||||
|
|
||||||
|
**4.5 Teleoperate** (sanity check, no recording)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-teleoperate \
|
||||||
|
--robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.id=my_follower \
|
||||||
|
--teleop.type=so101_leader --teleop.port=<LEADER_PORT> --teleop.id=my_leader \
|
||||||
|
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
|
--display_data=true
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Feetech timeout / comms error on SO-100 / SO-101?** Before touching software, check the **red motor LEDs** on the daisy chain.
|
||||||
|
>
|
||||||
|
> - **All steady red, gripper → base chain** → wiring OK.
|
||||||
|
> - **One or more motors dark / chain stops mid-way** → wiring issue: reseat the 3-pin cables, check the controller-board power supply, and make sure each motor is fully clicked in.
|
||||||
|
> - **LEDs blinking** → the motor is in an **error state**: usually overload (forcing a joint past its limit) **or wrong power supply voltage**. SO-100 / SO-101 ship in two variants — a **5 V / 7.4 V** build and a **12 V** build — they are NOT interchangeable. Using a 12 V PSU on a 5 V / 7.4 V arm (or vice-versa) will trip this error; confirm your motor variant before powering up.
|
||||||
|
>
|
||||||
|
> Most "timeout" errors are physical, not code.
|
||||||
|
|
||||||
|
**4.6 Record a dataset** — keys: **→** next, **←** redo, **ESC** finish & upload.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
HF_USER=$(NO_COLOR=1 hf auth whoami | awk -F': *' 'NR==1 {print $2}')
|
||||||
|
|
||||||
|
lerobot-record \
|
||||||
|
--robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.id=my_follower \
|
||||||
|
--teleop.type=so101_leader --teleop.port=<LEADER_PORT> --teleop.id=my_leader \
|
||||||
|
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
|
--dataset.repo_id=${HF_USER}/my_task \
|
||||||
|
--dataset.single_task="<describe the task in one sentence>" \
|
||||||
|
--dataset.num_episodes=50 \
|
||||||
|
--dataset.episode_time_s=30 \
|
||||||
|
--dataset.reset_time_s=10 \
|
||||||
|
--display_data=true
|
||||||
|
```
|
||||||
|
|
||||||
|
**4.7 Visualize** — **always** do this before training. Look for missing frames, camera blur, unreachable targets, inconsistent object positions.
|
||||||
|
After upload: https://huggingface.co/spaces/lerobot/visualize_dataset → paste `${HF_USER}/my_task`. Works for **any LeRobot-formatted Hub dataset** — use it to scout other datasets, inspect episode quality, or debug your own data before retraining.
|
||||||
|
|
||||||
|
**4.8 Replay an episode** (sanity check)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-replay --robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.id=my_follower \
|
||||||
|
--dataset.repo_id=${HF_USER}/my_task --dataset.episode=0
|
||||||
|
```
|
||||||
|
|
||||||
|
**4.9 Train** (default: ACT — fastest, lowest memory). Apple silicon: `--policy.device=mps`. See §6/§7 for policy and duration.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--dataset.repo_id=${HF_USER}/my_task \
|
||||||
|
--policy.type=act \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--output_dir=outputs/train/act_my_task \
|
||||||
|
--job_name=act_my_task \
|
||||||
|
--batch_size=8 \
|
||||||
|
--wandb.enable=true \
|
||||||
|
--policy.repo_id=${HF_USER}/act_my_task
|
||||||
|
```
|
||||||
|
|
||||||
|
**4.10 Evaluate on the real robot** — compare success rate to a teleoperated baseline.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-record \
|
||||||
|
--robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.id=my_follower \
|
||||||
|
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
|
--dataset.repo_id=${HF_USER}/eval_my_task \
|
||||||
|
--dataset.single_task="<same task description as training>" \
|
||||||
|
--dataset.num_episodes=10 \
|
||||||
|
--policy.path=${HF_USER}/act_my_task
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Data collection tips (beginner → reliable policy)
|
||||||
|
|
||||||
|
Good data beats clever models. Adopt these defaults and deviate only with evidence.
|
||||||
|
|
||||||
|
### 5.1 Setup & ergonomics
|
||||||
|
|
||||||
|
- **Fix the rig and cameras** before touching the software. If the rig vibrates or the operator gets frustrated, fix that first — more bad data won't help.
|
||||||
|
- **Lighting matters more than resolution.** Diffuse, consistent light. Avoid moving shadows.
|
||||||
|
- **"Can you do the task from the camera view alone?"** If no, your cameras are wrong. Fix before recording.
|
||||||
|
- Enable **action interpolation** for rollouts when available for smoother trajectories.
|
||||||
|
|
||||||
|
### 5.2 Practice before you record
|
||||||
|
|
||||||
|
- Do 5–10 demos without recording. Build a deliberate, repeatable strategy.
|
||||||
|
- Hesitant or inconsistent demos teach the model hesitation.
|
||||||
|
|
||||||
|
### 5.3 Quality over speed
|
||||||
|
|
||||||
|
Deliberate, high-quality execution beats fast sloppy runs. Optimize for speed only **after** strategy is dialed in — never trade quality for it.
|
||||||
|
|
||||||
|
### 5.4 Consistency within and across episodes
|
||||||
|
|
||||||
|
Same grasp, approach vector, and timing. Coherent strategies are much easier to learn than wildly varying movements.
|
||||||
|
|
||||||
|
### 5.5 Start small, then extend (the golden rule)
|
||||||
|
|
||||||
|
- **First 50 episodes = constrained version** of the task: one object, fixed position, fixed camera setup, one operator.
|
||||||
|
- Train a quick ACT model. See what fails.
|
||||||
|
- **Then add diversity** along one axis at a time: more positions → more lighting → more objects → more operators.
|
||||||
|
- Don't try to collect the "perfect dataset" on day one. Iterate.
|
||||||
|
|
||||||
|
### 5.6 Policy choice for beginners
|
||||||
|
|
||||||
|
- **Laptop / first time / want results fast → ACT.** Works surprisingly well, trains fast even on a laptop GPU.
|
||||||
|
- **Bigger GPU / language-conditioned / multi-task → SmolVLA.** Unfreezing the vision encoder (see §7) is a big win here.
|
||||||
|
- Defer π0 / π0.5 / Wall-X / X-VLA until you have a proven ACT baseline and a 20+ GB GPU.
|
||||||
|
|
||||||
|
### 5.7 Recommended defaults for your first task
|
||||||
|
|
||||||
|
| Setting | Value |
|
||||||
|
| ---------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| Episodes | **50** to start, scale to 100–300 after first training |
|
||||||
|
| Episode length | 20–45 s (shorter is fine for grasp/place) |
|
||||||
|
| Reset time | 10 s |
|
||||||
|
| FPS | 30 |
|
||||||
|
| Cameras | **2 cameras recommended**: 1 fixed front + 1 wrist. Multi-view often outperforms single-view. A single fixed camera also works to keep things simple. |
|
||||||
|
| Task description | Short, specific, action-phrased sentence |
|
||||||
|
|
||||||
|
### 5.8 Troubleshooting signal
|
||||||
|
|
||||||
|
- Policy fails at one specific stage → record 10–20 more episodes **targeting that stage**.
|
||||||
|
- Policy flaps / oscillates → likely inconsistent demos, or need more training; re-record worst episodes (use **←** to redo).
|
||||||
|
- Policy ignores the object → camera framing or lighting issue, not a model issue.
|
||||||
|
|
||||||
|
See also: [What makes a good dataset](https://huggingface.co/blog/lerobot-datasets#what-makes-a-good-dataset).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Which policy should I train?
|
||||||
|
|
||||||
|
Match the policy to the user's **GPU memory** and **time budget**. Numbers below come from an internal profiling run (one training update per policy). They are **indicative only** — see caveats.
|
||||||
|
|
||||||
|
### 6.1 Profiling snapshot (indicative)
|
||||||
|
|
||||||
|
All policies typically train for **5–10 epochs** (see §7).
|
||||||
|
|
||||||
|
> **Human-facing version:** the [Compute Hardware Guide](./docs/source/hardware_guide.mdx) reuses the table below and adds a cloud-GPU tier guide and a Hugging Face Jobs pointer.
|
||||||
|
|
||||||
|
| Policy | Batch | Update (ms) | Peak GPU mem (GB) | Best for |
|
||||||
|
| ----------- | ----: | ----------: | ----------------: | ------------------------------------------------------------------------------------------------ |
|
||||||
|
| `act` | 4 | **83.9** | **0.94** | First-time users, laptops, single-task. Fast and reliable. |
|
||||||
|
| `diffusion` | 4 | 168.6 | 4.94 | Multi-modal action distributions; needs mid-range GPU. |
|
||||||
|
| `smolvla` | 1 | 357.8 | 3.93 | Language-conditioned, multi-task, small VLA. **Unfreeze vision encoder for big gains** (see §7). |
|
||||||
|
| `xvla` | 1 | 731.6 | 15.52 | Large VLA, multi-task. |
|
||||||
|
| `wall_x` | 1 | 716.5 | 15.95 | Large VLA with world-model objective. |
|
||||||
|
| `pi0` | 1 | 940.3 | 15.50 | Strong large VLA baseline (Physical Intelligence). |
|
||||||
|
| `pi05` | 1 | 1055.8 | 16.35 | Newer π policy; similar footprint to `pi0`. |
|
||||||
|
|
||||||
|
**Critical caveats:**
|
||||||
|
|
||||||
|
- **Optimizer:** measured with **SGD**. LeRobot's default is **AdamW**, which keeps extra optimizer state → **peak memory will be noticeably higher** with the default, especially for `pi0`, `pi05`, `wall_x`, `xvla`.
|
||||||
|
- **Batch size:** the large policies were profiled at batch 1. In practice use a **larger batch** for stable training (see §7.4). Memory scales roughly linearly with batch.
|
||||||
|
|
||||||
|
### 6.2 Decision rules
|
||||||
|
|
||||||
|
- **< 8 GB VRAM (laptop, 3060, M-series Mac):** → `act`. Maybe `diffusion` if you have ~6–8 GB free.
|
||||||
|
- **12–16 GB VRAM (4070/4080, A4000):** → `smolvla` with defaults, or `act`/`diffusion` with larger batch. `pi0`/`pi05`/`wall_x`/`xvla` feasible only with small batch + gradient accumulation.
|
||||||
|
- **24+ GB VRAM (3090/4090/A5000):** → any policy. Prefer `smolvla` (unfrozen) for multi-task; `act` for single-task grasp-and-place (still often the best ROI). Could experiment with `pi0` or `pi05` or `xvla`
|
||||||
|
- **80 GB (A100/H100):** → any, with healthy batch. `pi05`, `xvla`, `wall_x` become comfortable.
|
||||||
|
- **CPU only:** → don't train here. Use Google Colab (see [`docs/source/notebooks.mdx`](./docs/source/notebooks.mdx)) or a rented GPU.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. How long should I train?
|
||||||
|
|
||||||
|
Robotics imitation learning usually converges in a **few epochs over the dataset**, not hundreds of thousands of raw steps. Think **epochs first**, then translate to steps.
|
||||||
|
|
||||||
|
### 7.1 Rule of thumb
|
||||||
|
|
||||||
|
- **Typical total: 5–10 epochs.** Start at 5, eval, then decide if more helps.
|
||||||
|
- Very small datasets (< 30 episodes) may want slightly more epochs — but first, **collect more data**.
|
||||||
|
- VLAs with a pretrained vision backbone typically need **fewer** epochs than training from scratch.
|
||||||
|
|
||||||
|
### 7.2 Steps ↔ epochs conversion
|
||||||
|
|
||||||
|
```
|
||||||
|
total_frames = sum of frames over all episodes # e.g. 50 eps × 30 fps × 30 s ≈ 45,000
|
||||||
|
steps_per_epoch = ceil(total_frames / batch_size)
|
||||||
|
total_steps = epochs × steps_per_epoch
|
||||||
|
```
|
||||||
|
|
||||||
|
Examples for `--batch_size=8`:
|
||||||
|
|
||||||
|
| Dataset size | Frames | Steps / epoch | 5 epochs | 10 epochs |
|
||||||
|
| ----------------------- | ------: | ------------: | -------: | --------: |
|
||||||
|
| 50 eps × 30 s @ 30 fps | 45,000 | ~5,625 | 28k | 56k |
|
||||||
|
| 100 eps × 30 s @ 30 fps | 90,000 | ~11,250 | 56k | 113k |
|
||||||
|
| 300 eps × 30 s @ 30 fps | 270,000 | ~33,750 | 169k | 338k |
|
||||||
|
|
||||||
|
Pass the resulting total with `--steps=<N>`; eval at intermediate checkpoints (`outputs/train/.../checkpoints/`).
|
||||||
|
|
||||||
|
### 7.3 Per-policy starting points (single-task, ~50 episodes)
|
||||||
|
|
||||||
|
| Policy | Batch | Steps (first run) | Notes |
|
||||||
|
| -------------- | ----: | ----------------: | ----------------------------------------------------------------- |
|
||||||
|
| `act` | 8–16 | 30k–80k | Usually converges under 50k for single-task. |
|
||||||
|
| `diffusion` | 8–16 | 80k–150k | Benefits from longer training than ACT. |
|
||||||
|
| `smolvla` | 4–8 | 30k–80k | Pretrained VLM → converges fast. |
|
||||||
|
| `pi0` / `pi05` | 1–4 | 30k–80k | Memory-bound; use gradient accumulation for effective batch ≥ 16! |
|
||||||
|
|
||||||
|
### 7.4 Batch size guidance
|
||||||
|
|
||||||
|
- **Bigger batch is preferable** for stable gradients on teleop data.
|
||||||
|
- If GPU memory is the bottleneck, use **gradient accumulation** to raise _effective_ batch without raising peak memory.
|
||||||
|
- Scale **learning rate** gently with batch; most LeRobot defaults work fine for a 2–4× batch change.
|
||||||
|
|
||||||
|
### 7.5 Scale LR schedule & checkpoints with `--steps`
|
||||||
|
|
||||||
|
LeRobot's default schedulers (e.g. SmolVLA's cosine decay) use `scheduler_decay_steps=30_000`, which is sized for long training runs. When you shorten training (e.g. 5k–10k steps on a small dataset), **scale the scheduler down to match** — otherwise the LR stays near the peak and never decays. Same for checkpoint frequency.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train ... \
|
||||||
|
--steps=5000 \
|
||||||
|
--policy.scheduler_decay_steps=5000 \
|
||||||
|
--save_freq=5000
|
||||||
|
```
|
||||||
|
|
||||||
|
Rule of thumb: set `scheduler_decay_steps ≈ steps`, and `save_freq` to whatever granularity you want for eval (e.g. every 1k–5k steps). Match `scheduler_warmup_steps` proportionally if your run is very short.
|
||||||
|
|
||||||
|
### 7.6 SmolVLA: unfreeze the vision encoder for real gains
|
||||||
|
|
||||||
|
SmolVLA ships with `freeze_vision_encoder=True`. Unfreezing usually **improves performance substantially** on specialized tasks, at the cost of more VRAM and slower steps. Enable with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train ... --policy.type=smolvla \
|
||||||
|
--policy.freeze_vision_encoder=false \
|
||||||
|
--policy.train_expert_only=false
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7.7 Signals to stop / keep going
|
||||||
|
|
||||||
|
- Train loss plateaus → stop, save a Hub checkpoint.
|
||||||
|
- Train loss still dropping and you're under 10 epochs → keep going.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. Evaluation & benchmarks
|
||||||
|
|
||||||
|
Two flavors of evaluation:
|
||||||
|
|
||||||
|
### 8.1 Real-robot eval (SO-101, etc.)
|
||||||
|
|
||||||
|
Reuse `lerobot-record` with `--policy.path` to run the trained policy on-robot and save the run as an eval dataset. Convention: prefix the dataset with `eval_`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-record \
|
||||||
|
--robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.id=my_follower \
|
||||||
|
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
|
--dataset.repo_id=${HF_USER}/eval_my_task \
|
||||||
|
--dataset.single_task="<same task description used during training>" \
|
||||||
|
--dataset.num_episodes=10 \
|
||||||
|
--policy.path=${HF_USER}/act_my_task
|
||||||
|
```
|
||||||
|
|
||||||
|
Report success rate across episodes. Compare to a teleoperated baseline and to an earlier checkpoint to catch regressions.
|
||||||
|
|
||||||
|
### 8.2 Sim-benchmark eval
|
||||||
|
|
||||||
|
For policies trained on sim datasets (PushT, Aloha, LIBERO, MetaWorld, RoboCasa, …) use `lerobot-eval` against the matching `env.type`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-eval \
|
||||||
|
--policy.path=${HF_USER}/diffusion_pusht \
|
||||||
|
--env.type=pusht \
|
||||||
|
--eval.n_episodes=50 \
|
||||||
|
--eval.batch_size=10 \
|
||||||
|
--policy.device=cuda
|
||||||
|
```
|
||||||
|
|
||||||
|
- Use `--policy.path=outputs/train/.../checkpoints/<step>/pretrained_model` for local checkpoints.
|
||||||
|
- `--eval.n_episodes` should be ≥ 50 for a stable success-rate estimate.
|
||||||
|
- Available envs live in `src/lerobot/envs/`. See [`docs/source/libero.mdx`](./docs/source/libero.mdx), [`metaworld.mdx`](./docs/source/metaworld.mdx), [`robocasa.mdx`](./docs/source/robocasa.mdx), [`vlabench.mdx`](./docs/source/vlabench.mdx) for specific benchmarks.
|
||||||
|
- To add a new benchmark, see [`docs/source/adding_benchmarks.mdx`](./docs/source/adding_benchmarks.mdx) and [`envhub.mdx`](./docs/source/envhub.mdx).
|
||||||
|
|
||||||
|
### 8.2b Dockerfiles for benchmark eval
|
||||||
|
|
||||||
|
Benchmark envs have native dependencies that are painful to install locally. The repo ships **pre-baked Dockerfiles** for each supported benchmark — use these to run `lerobot-eval` in a reproducible environment:
|
||||||
|
|
||||||
|
| Benchmark | Dockerfile |
|
||||||
|
| ----------- | -------------------------------------------------------------------------------------- |
|
||||||
|
| LIBERO | [`docker/Dockerfile.benchmark.libero`](./docker/Dockerfile.benchmark.libero) |
|
||||||
|
| LIBERO+ | [`docker/Dockerfile.benchmark.libero_plus`](./docker/Dockerfile.benchmark.libero_plus) |
|
||||||
|
| MetaWorld | [`docker/Dockerfile.benchmark.metaworld`](./docker/Dockerfile.benchmark.metaworld) |
|
||||||
|
| RoboCasa | [`docker/Dockerfile.benchmark.robocasa`](./docker/Dockerfile.benchmark.robocasa) |
|
||||||
|
| RoboCerebra | [`docker/Dockerfile.benchmark.robocerebra`](./docker/Dockerfile.benchmark.robocerebra) |
|
||||||
|
| RoboMME | [`docker/Dockerfile.benchmark.robomme`](./docker/Dockerfile.benchmark.robomme) |
|
||||||
|
| RoboTwin | [`docker/Dockerfile.benchmark.robotwin`](./docker/Dockerfile.benchmark.robotwin) |
|
||||||
|
| VLABench | [`docker/Dockerfile.benchmark.vlabench`](./docker/Dockerfile.benchmark.vlabench) |
|
||||||
|
|
||||||
|
Build and run (adapt to your benchmark):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker build -f docker/Dockerfile.benchmark.robomme -t lerobot-bench-robomme .
|
||||||
|
docker run --gpus all --rm -it \
|
||||||
|
-v $HOME/.cache/huggingface:/root/.cache/huggingface \
|
||||||
|
lerobot-bench-robomme \
|
||||||
|
lerobot-eval --policy.path=<your_policy> --env.type=<env> --eval.n_episodes=50
|
||||||
|
```
|
||||||
|
|
||||||
|
See [`docker/README.md`](./docker/README.md) for base-image details.
|
||||||
|
|
||||||
|
### 8.3 Target success rates
|
||||||
|
|
||||||
|
Single-task grasp-and-place with 50 clean episodes: ACT should reach **> 70% success** on the training configuration. Less → data problem (see §5), not model problem. Expect a drop when generalizing to new positions — scale episodes or diversity to recover.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. Further reading & resources
|
||||||
|
|
||||||
|
- **Getting started:** [`installation.mdx`](./docs/source/installation.mdx) · [`il_robots.mdx`](./docs/source/il_robots.mdx) · [What makes a good dataset](https://huggingface.co/blog/lerobot-datasets)
|
||||||
|
- **Per-policy docs:** browse [`docs/source/*.mdx`](./docs/source/) (policies, hardware, benchmarks, advanced training).
|
||||||
|
- **Community:** [Discord](https://discord.com/invite/s3KuuzsPFb) · [Hub `LeRobot` tag](https://huggingface.co/datasets?other=LeRobot) · [Dataset visualizer](https://huggingface.co/spaces/lerobot/visualize_dataset)
|
||||||
|
|
||||||
|
> Keep this file current. If you learn a rule that would prevent a class of user mistakes, add it here and in [`AGENTS.md`](./AGENTS.md).
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
include src/lerobot/templates/lerobot_modelcard_template.md
|
include src/lerobot/templates/lerobot_modelcard_template.md
|
||||||
|
include src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md
|
||||||
include src/lerobot/datasets/card_template.md
|
include src/lerobot/datasets/card_template.md
|
||||||
include src/lerobot/envs/metaworld_config.json
|
include src/lerobot/envs/metaworld_config.json
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ lerobot-train \
|
|||||||
|
|
||||||
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub
|
||||||
|
|
||||||
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies).
|
For detailed policy setup guides, see the [Policy Documentation](https://huggingface.co/docs/lerobot/bring_your_own_policies). For GPU/RAM requirements and expected training time per policy, see the [Compute Hardware Guide](https://huggingface.co/docs/lerobot/hardware_guide).
|
||||||
|
|
||||||
## Inference & Evaluation
|
## Inference & Evaluation
|
||||||
|
|
||||||
|
|||||||
@@ -1,288 +0,0 @@
|
|||||||
# Video benchmark
|
|
||||||
|
|
||||||
## Questions
|
|
||||||
|
|
||||||
What is the optimal trade-off between:
|
|
||||||
|
|
||||||
- maximizing loading time with random access,
|
|
||||||
- minimizing memory space on disk,
|
|
||||||
- maximizing success rate of policies,
|
|
||||||
- compatibility across devices/platforms for decoding videos (e.g. video players, web browsers).
|
|
||||||
|
|
||||||
How to encode videos?
|
|
||||||
|
|
||||||
- Which video codec (`-vcodec`) to use? h264, h265, AV1?
|
|
||||||
- What pixel format to use (`-pix_fmt`)? `yuv444p` or `yuv420p`?
|
|
||||||
- How much compression (`-crf`)? No compression with `0`, intermediate compression with `25` or extreme with `50+`?
|
|
||||||
- Which frequency to chose for key frames (`-g`)? A key frame every `10` frames?
|
|
||||||
|
|
||||||
How to decode videos?
|
|
||||||
|
|
||||||
- Which `decoder`? `torchvision`, `torchaudio`, `ffmpegio`, `decord`, or `nvc`?
|
|
||||||
- What scenarios to use for the requesting timestamps during benchmark? (`timestamps_mode`)
|
|
||||||
|
|
||||||
## Variables
|
|
||||||
|
|
||||||
**Image content & size**
|
|
||||||
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
|
|
||||||
For these reasons, we run this benchmark on four representative datasets:
|
|
||||||
|
|
||||||
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
|
|
||||||
- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
|
||||||
- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera.
|
|
||||||
- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera.
|
|
||||||
|
|
||||||
Note: The datasets used for this benchmark need to be image datasets, not video datasets.
|
|
||||||
|
|
||||||
**Data augmentations**
|
|
||||||
We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.).
|
|
||||||
|
|
||||||
### Encoding parameters
|
|
||||||
|
|
||||||
| parameter | values |
|
|
||||||
| ----------- | ------------------------------------------------------------ |
|
|
||||||
| **vcodec** | `libx264`, `libx265`, `libsvtav1` |
|
|
||||||
| **pix_fmt** | `yuv444p`, `yuv420p` |
|
|
||||||
| **g** | `1`, `2`, `3`, `4`, `5`, `6`, `10`, `15`, `20`, `40`, `None` |
|
|
||||||
| **crf** | `0`, `5`, `10`, `15`, `20`, `25`, `30`, `40`, `50`, `None` |
|
|
||||||
|
|
||||||
Note that `crf` value might be interpreted differently by various video codecs. In other words, the same value used with one codec doesn't necessarily translate into the same compression level with another codec. In fact, the default value (`None`) isn't the same amongst the different video codecs. Importantly, it is also the case for many other ffmpeg arguments like `g` which specifies the frequency of the key frames.
|
|
||||||
|
|
||||||
For a comprehensive list and documentation of these parameters, see the ffmpeg documentation depending on the video codec used:
|
|
||||||
|
|
||||||
- h264: https://trac.ffmpeg.org/wiki/Encode/H.264
|
|
||||||
- h265: https://trac.ffmpeg.org/wiki/Encode/H.265
|
|
||||||
- AV1: https://trac.ffmpeg.org/wiki/Encode/AV1
|
|
||||||
|
|
||||||
### Decoding parameters
|
|
||||||
|
|
||||||
**Decoder**
|
|
||||||
We tested two video decoding backends from torchvision:
|
|
||||||
|
|
||||||
- `pyav`
|
|
||||||
- `video_reader` (requires to build torchvision from source)
|
|
||||||
|
|
||||||
**Requested timestamps**
|
|
||||||
Given the way video decoding works, once a keyframe has been loaded, the decoding of subsequent frames is fast.
|
|
||||||
This of course is affected by the `-g` parameter during encoding, which specifies the frequency of the keyframes. Given our typical use cases in robotics policies which might request a few timestamps in different random places, we want to replicate these use cases with the following scenarios:
|
|
||||||
|
|
||||||
- `1_frame`: 1 frame,
|
|
||||||
- `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`),
|
|
||||||
- `6_frames`: 6 consecutive frames (e.g. `[t + i / fps for i in range(6)]`)
|
|
||||||
|
|
||||||
Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
|
|
||||||
|
|
||||||
Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
|
|
||||||
|
|
||||||
- `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),
|
|
||||||
|
|
||||||
However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
|
|
||||||
|
|
||||||
## Metrics
|
|
||||||
|
|
||||||
**Data compression ratio (lower is better)**
|
|
||||||
`video_images_size_ratio` is the ratio of the memory space on disk taken by the encoded video over the memory space taken by the original images. For instance, `video_images_size_ratio=25%` means that the video takes 4 times less memory space on disk compared to the original images.
|
|
||||||
|
|
||||||
**Loading time ratio (lower is better)**
|
|
||||||
`video_images_load_time_ratio` is the ratio of the time it takes to decode frames from the video at a given timestamps over the time it takes to load the exact same original images. Lower is better. For instance, `video_images_load_time_ratio=200%` means that decoding from video is 2 times slower than loading the original images.
|
|
||||||
|
|
||||||
**Average Mean Square Error (lower is better)**
|
|
||||||
`avg_mse` is the average mean square error between each decoded frame and its corresponding original image over all requested timestamps, and also divided by the number of pixels in the image to be comparable when switching to different image sizes.
|
|
||||||
|
|
||||||
**Average Peak Signal to Noise Ratio (higher is better)**
|
|
||||||
`avg_psnr` measures the ratio between the maximum possible power of a signal and the power of corrupting noise that affects the fidelity of its representation. Higher PSNR indicates better quality.
|
|
||||||
|
|
||||||
**Average Structural Similarity Index Measure (higher is better)**
|
|
||||||
`avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.
|
|
||||||
|
|
||||||
One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
|
|
||||||
h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
|
|
||||||
|
|
||||||
- `yuv420p` is more widely supported across various platforms, including web browsers.
|
|
||||||
- `yuv444p` offers higher color fidelity but might not be supported as broadly.
|
|
||||||
|
|
||||||
<!-- **Loss of a pretrained policy (higher is better)** (not available)
|
|
||||||
`loss_pretrained` is the result of evaluating with the selected encoding/decoding settings a policy pretrained on original images. It is easier to understand than `avg_l2_error`.
|
|
||||||
|
|
||||||
**Success rate after retraining (higher is better)** (not available)
|
|
||||||
`success_rate` is the result of training and evaluating a policy with the selected encoding/decoding settings. It is the most difficult metric to get but also the very best. -->
|
|
||||||
|
|
||||||
## How the benchmark works
|
|
||||||
|
|
||||||
The benchmark evaluates both encoding and decoding of video frames on the first episode of each dataset.
|
|
||||||
|
|
||||||
**Encoding:** for each `vcodec` and `pix_fmt` pair, we use a default value for `g` and `crf` upon which we change a single value (either `g` or `crf`) to one of the specified values (we don't test every combination of those as this would be computationally too heavy).
|
|
||||||
This gives a unique set of encoding parameters which is used to encode the episode.
|
|
||||||
|
|
||||||
**Decoding:** Then, for each of those unique encodings, we iterate through every combination of the decoding parameters `backend` and `timestamps_mode`. For each of them, we record the metrics of a number of samples (given by `--num-samples`). This is parallelized for efficiency and the number of processes can be controlled with `--num-workers`. Ideally, it's best to have a `--num-samples` that is divisible by `--num-workers`.
|
|
||||||
|
|
||||||
Intermediate results saved for each `vcodec` and `pix_fmt` combination in csv tables.
|
|
||||||
These are then all concatenated to a single table ready for analysis.
|
|
||||||
|
|
||||||
## Caveats
|
|
||||||
|
|
||||||
We tried to measure the most impactful parameters for both encoding and decoding. However, for computational reasons we can't test out every combination.
|
|
||||||
|
|
||||||
Additional encoding parameters exist that are not included in this benchmark. In particular:
|
|
||||||
|
|
||||||
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
|
|
||||||
- `-tune` which allows to optimize the encoding for certain aspects (e.g. film quality, fast decoding, etc.).
|
|
||||||
|
|
||||||
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
|
|
||||||
|
|
||||||
Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
|
|
||||||
|
|
||||||
- `torchaudio`
|
|
||||||
- `ffmpegio`
|
|
||||||
- `decord`
|
|
||||||
- `nvc`
|
|
||||||
|
|
||||||
Note as well that since we are mostly interested in the performance at decoding time (also because encoding is done only once before uploading a dataset), we did not measure encoding times nor have any metrics regarding encoding.
|
|
||||||
However, besides the necessity to build ffmpeg from source, encoding did not pose any issue and it didn't take a significant amount of time during this benchmark.
|
|
||||||
|
|
||||||
## Install
|
|
||||||
|
|
||||||
Building ffmpeg from source is required to include libx265 and libaom/libsvtav1 (av1) video codecs ([compilation guide](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu)).
|
|
||||||
|
|
||||||
**Note:** While you still need to build torchvision with a conda-installed `ffmpeg<4.3` to use the `video_reader` decoder (as described in [#220](https://github.com/huggingface/lerobot/pull/220)), you also need another version which is custom-built with all the video codecs for encoding. For the script to then use that version, you can prepend the command above with `PATH="$HOME/bin:$PATH"`, which is where ffmpeg should be built.
|
|
||||||
|
|
||||||
## Adding a video decoder
|
|
||||||
|
|
||||||
Right now, we're only benchmarking the two video decoder available with torchvision: `pyav` and `video_reader`.
|
|
||||||
You can easily add a new decoder to benchmark by adding it to this function in the script:
|
|
||||||
|
|
||||||
```diff
|
|
||||||
def decode_video_frames(
|
|
||||||
video_path: str,
|
|
||||||
timestamps: list[float],
|
|
||||||
tolerance_s: float,
|
|
||||||
backend: str,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
if backend in ["pyav", "video_reader"]:
|
|
||||||
return decode_video_frames_torchvision(
|
|
||||||
video_path, timestamps, tolerance_s, backend
|
|
||||||
)
|
|
||||||
+ elif backend == ["your_decoder"]:
|
|
||||||
+ return your_decoder_function(
|
|
||||||
+ video_path, timestamps, tolerance_s, backend
|
|
||||||
+ )
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(backend)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Example
|
|
||||||
|
|
||||||
For a quick run, you can try these parameters:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python benchmark/video/run_video_benchmark.py \
|
|
||||||
--output-dir outputs/video_benchmark \
|
|
||||||
--repo-ids \
|
|
||||||
lerobot/pusht_image \
|
|
||||||
lerobot/aloha_mobile_shrimp_image \
|
|
||||||
--vcodec libx264 libx265 \
|
|
||||||
--pix-fmt yuv444p yuv420p \
|
|
||||||
--g 2 20 None \
|
|
||||||
--crf 10 40 None \
|
|
||||||
--timestamps-modes 1_frame 2_frames \
|
|
||||||
--backends pyav video_reader \
|
|
||||||
--num-samples 5 \
|
|
||||||
--num-workers 5 \
|
|
||||||
--save-frames 0
|
|
||||||
```
|
|
||||||
|
|
||||||
## Results
|
|
||||||
|
|
||||||
### Reproduce
|
|
||||||
|
|
||||||
We ran the benchmark with the following parameters:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# h264 and h265 encodings
|
|
||||||
python benchmark/video/run_video_benchmark.py \
|
|
||||||
--output-dir outputs/video_benchmark \
|
|
||||||
--repo-ids \
|
|
||||||
lerobot/pusht_image \
|
|
||||||
lerobot/aloha_mobile_shrimp_image \
|
|
||||||
lerobot/paris_street \
|
|
||||||
lerobot/kitchen \
|
|
||||||
--vcodec libx264 libx265 \
|
|
||||||
--pix-fmt yuv444p yuv420p \
|
|
||||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
|
||||||
--crf 0 5 10 15 20 25 30 40 50 None \
|
|
||||||
--timestamps-modes 1_frame 2_frames 6_frames \
|
|
||||||
--backends pyav video_reader \
|
|
||||||
--num-samples 50 \
|
|
||||||
--num-workers 5 \
|
|
||||||
--save-frames 1
|
|
||||||
|
|
||||||
# av1 encoding (only compatible with yuv420p and pyav decoder)
|
|
||||||
python benchmark/video/run_video_benchmark.py \
|
|
||||||
--output-dir outputs/video_benchmark \
|
|
||||||
--repo-ids \
|
|
||||||
lerobot/pusht_image \
|
|
||||||
lerobot/aloha_mobile_shrimp_image \
|
|
||||||
lerobot/paris_street \
|
|
||||||
lerobot/kitchen \
|
|
||||||
--vcodec libsvtav1 \
|
|
||||||
--pix-fmt yuv420p \
|
|
||||||
--g 1 2 3 4 5 6 10 15 20 40 None \
|
|
||||||
--crf 0 5 10 15 20 25 30 40 50 None \
|
|
||||||
--timestamps-modes 1_frame 2_frames 6_frames \
|
|
||||||
--backends pyav \
|
|
||||||
--num-samples 50 \
|
|
||||||
--num-workers 5 \
|
|
||||||
--save-frames 1
|
|
||||||
```
|
|
||||||
|
|
||||||
The full results are available [here](https://docs.google.com/spreadsheets/d/1OYJB43Qu8fC26k_OyoMFgGBBKfQRCi4BIuYitQnq3sw/edit?usp=sharing)
|
|
||||||
|
|
||||||
### Parameters selected for LeRobotDataset
|
|
||||||
|
|
||||||
Considering these results, we chose what we think is the best set of encoding parameter:
|
|
||||||
|
|
||||||
- vcodec: `libsvtav1`
|
|
||||||
- pix-fmt: `yuv420p`
|
|
||||||
- g: `2`
|
|
||||||
- crf: `30`
|
|
||||||
|
|
||||||
Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_reader` does not support it (and `pyav` doesn't require a custom build of `torchvision`).
|
|
||||||
|
|
||||||
### Summary
|
|
||||||
|
|
||||||
These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav`
|
|
||||||
|
|
||||||
| video_images_size_ratio | vcodec | pix_fmt | | | |
|
|
||||||
| --------------------------------- | ---------- | ------- | --------- | --------- | --------- |
|
|
||||||
| | libx264 | | libx265 | | libsvtav1 |
|
|
||||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
|
||||||
| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% |
|
|
||||||
| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% |
|
|
||||||
| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% |
|
|
||||||
| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% |
|
|
||||||
|
|
||||||
| video_images_load_time_ratio | vcodec | pix_fmt | | | |
|
|
||||||
| --------------------------------- | ------- | ------- | -------- | ------- | --------- |
|
|
||||||
| | libx264 | | libx265 | | libsvtav1 |
|
|
||||||
| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
|
||||||
| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 |
|
|
||||||
| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** |
|
|
||||||
| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** |
|
|
||||||
| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** |
|
|
||||||
|
|
||||||
| | | vcodec | pix_fmt | | | |
|
|
||||||
| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ |
|
|
||||||
| | | libx264 | | libx265 | | libsvtav1 |
|
|
||||||
| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p |
|
|
||||||
| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 |
|
|
||||||
| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 |
|
|
||||||
| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% |
|
|
||||||
| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** |
|
|
||||||
| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** |
|
|
||||||
| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** |
|
|
||||||
| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** |
|
|
||||||
| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** |
|
|
||||||
| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** |
|
|
||||||
| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** |
|
|
||||||
| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** |
|
|
||||||
| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** |
|
|
||||||
@@ -1,488 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""Assess the performance of video decoding in various configurations.
|
|
||||||
|
|
||||||
This script will benchmark different video encoding and decoding parameters.
|
|
||||||
See the provided README.md or run `python benchmark/video/run_video_benchmark.py --help` for usage info.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import datetime as dt
|
|
||||||
import itertools
|
|
||||||
import random
|
|
||||||
import shutil
|
|
||||||
from collections import OrderedDict
|
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
||||||
from pathlib import Path
|
|
||||||
from threading import Lock
|
|
||||||
|
|
||||||
import einops
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import PIL
|
|
||||||
import torch
|
|
||||||
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
||||||
from lerobot.datasets.video_utils import (
|
|
||||||
decode_video_frames,
|
|
||||||
encode_video_frames,
|
|
||||||
)
|
|
||||||
from lerobot.utils.constants import OBS_IMAGE
|
|
||||||
from lerobot.utils.utils import TimerManager
|
|
||||||
|
|
||||||
BASE_ENCODING = OrderedDict(
|
|
||||||
[
|
|
||||||
("vcodec", "libx264"),
|
|
||||||
("pix_fmt", "yuv444p"),
|
|
||||||
("g", 2),
|
|
||||||
("crf", None),
|
|
||||||
# TODO(aliberts): Add fastdecode
|
|
||||||
# ("fastdecode", 0),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(rcadene, aliberts): move to `utils.py` folder when we want to refactor
|
|
||||||
def parse_int_or_none(value) -> int | None:
|
|
||||||
if value.lower() == "none":
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
return int(value)
|
|
||||||
except ValueError as e:
|
|
||||||
raise argparse.ArgumentTypeError(f"Invalid int or None: {value}") from e
|
|
||||||
|
|
||||||
|
|
||||||
def check_datasets_formats(repo_ids: list) -> None:
|
|
||||||
for repo_id in repo_ids:
|
|
||||||
dataset = LeRobotDataset(repo_id)
|
|
||||||
if len(dataset.meta.video_keys) > 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_directory_size(directory: Path) -> int:
|
|
||||||
total_size = 0
|
|
||||||
for item in directory.rglob("*"):
|
|
||||||
if item.is_file():
|
|
||||||
total_size += item.stat().st_size
|
|
||||||
return total_size
|
|
||||||
|
|
||||||
|
|
||||||
def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
|
|
||||||
frames = []
|
|
||||||
for ts in timestamps:
|
|
||||||
idx = int(ts * fps)
|
|
||||||
frame = PIL.Image.open(imgs_dir / f"frame-{idx:06d}.png")
|
|
||||||
frame = torch.from_numpy(np.array(frame))
|
|
||||||
frame = frame.type(torch.float32) / 255
|
|
||||||
frame = einops.rearrange(frame, "h w c -> c h w")
|
|
||||||
frames.append(frame)
|
|
||||||
return torch.stack(frames)
|
|
||||||
|
|
||||||
|
|
||||||
def save_decoded_frames(
|
|
||||||
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
|
|
||||||
) -> None:
|
|
||||||
if save_dir.exists() and len(list(save_dir.glob("frame-*.png"))) == len(timestamps):
|
|
||||||
return
|
|
||||||
|
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
for i, ts in enumerate(timestamps):
|
|
||||||
idx = int(ts * fps)
|
|
||||||
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
|
|
||||||
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame-{idx:06d}_decoded.png")
|
|
||||||
shutil.copyfile(imgs_dir / f"frame-{idx:06d}.png", save_dir / f"frame-{idx:06d}_original.png")
|
|
||||||
|
|
||||||
|
|
||||||
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
|
||||||
episode_index = 0
|
|
||||||
ep_num_images = dataset.meta.episodes["length"][episode_index]
|
|
||||||
if imgs_dir.exists() and len(list(imgs_dir.glob("frame-*.png"))) == ep_num_images:
|
|
||||||
return
|
|
||||||
|
|
||||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
|
||||||
|
|
||||||
# We only save images from the first camera
|
|
||||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
|
||||||
imgs_dataset = hf_dataset.select_columns(img_keys[0])
|
|
||||||
|
|
||||||
for i, item in enumerate(
|
|
||||||
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
|
|
||||||
):
|
|
||||||
img = item[img_keys[0]]
|
|
||||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
|
||||||
|
|
||||||
if i >= ep_num_images - 1:
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
|
|
||||||
# Start at 5 to allow for 2_frames_4_space and 6_frames
|
|
||||||
idx = random.randint(5, ep_num_images - 1)
|
|
||||||
match timestamps_mode:
|
|
||||||
case "1_frame":
|
|
||||||
frame_indexes = [idx]
|
|
||||||
case "2_frames":
|
|
||||||
frame_indexes = [idx - 1, idx]
|
|
||||||
case "2_frames_4_space":
|
|
||||||
frame_indexes = [idx - 5, idx]
|
|
||||||
case "6_frames":
|
|
||||||
frame_indexes = [idx - i for i in range(6)][::-1]
|
|
||||||
case _:
|
|
||||||
raise ValueError(timestamps_mode)
|
|
||||||
|
|
||||||
return [idx / fps for idx in frame_indexes]
|
|
||||||
|
|
||||||
|
|
||||||
def benchmark_decoding(
|
|
||||||
imgs_dir: Path,
|
|
||||||
video_path: Path,
|
|
||||||
timestamps_mode: str,
|
|
||||||
backend: str,
|
|
||||||
ep_num_images: int,
|
|
||||||
fps: int,
|
|
||||||
num_samples: int = 50,
|
|
||||||
num_workers: int = 4,
|
|
||||||
save_frames: bool = False,
|
|
||||||
) -> dict:
|
|
||||||
def process_sample(sample: int, lock: Lock):
|
|
||||||
time_benchmark = TimerManager(log=False)
|
|
||||||
timestamps = sample_timestamps(timestamps_mode, ep_num_images, fps)
|
|
||||||
num_frames = len(timestamps)
|
|
||||||
result = {
|
|
||||||
"psnr_values": [],
|
|
||||||
"ssim_values": [],
|
|
||||||
"mse_values": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
with time_benchmark, lock:
|
|
||||||
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
|
|
||||||
result["load_time_video_ms"] = (time_benchmark.last * 1000) / num_frames
|
|
||||||
|
|
||||||
with time_benchmark:
|
|
||||||
original_frames = load_original_frames(imgs_dir, timestamps, fps)
|
|
||||||
result["load_time_images_ms"] = (time_benchmark.last * 1000) / num_frames
|
|
||||||
|
|
||||||
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
|
|
||||||
for i in range(num_frames):
|
|
||||||
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
|
|
||||||
result["psnr_values"].append(
|
|
||||||
peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
|
|
||||||
)
|
|
||||||
result["ssim_values"].append(
|
|
||||||
structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
|
|
||||||
)
|
|
||||||
|
|
||||||
if save_frames and sample == 0:
|
|
||||||
save_dir = video_path.with_suffix("") / f"{timestamps_mode}_{backend}"
|
|
||||||
save_decoded_frames(imgs_dir, save_dir, frames, timestamps, fps)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
load_times_video_ms = []
|
|
||||||
load_times_images_ms = []
|
|
||||||
mse_values = []
|
|
||||||
psnr_values = []
|
|
||||||
ssim_values = []
|
|
||||||
|
|
||||||
# A sample is a single set of decoded frames specified by timestamps_mode (e.g. a single frame, 2 frames, etc.).
|
|
||||||
# For each sample, we record metrics (loading time and quality metrics) which are then averaged over all samples.
|
|
||||||
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
|
|
||||||
# Use a single shared lock for all worker threads
|
|
||||||
shared_lock = Lock()
|
|
||||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
|
||||||
futures = [executor.submit(process_sample, i, shared_lock) for i in range(num_samples)]
|
|
||||||
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
|
|
||||||
result = future.result()
|
|
||||||
load_times_video_ms.append(result["load_time_video_ms"])
|
|
||||||
load_times_images_ms.append(result["load_time_images_ms"])
|
|
||||||
psnr_values.extend(result["psnr_values"])
|
|
||||||
ssim_values.extend(result["ssim_values"])
|
|
||||||
mse_values.extend(result["mse_values"])
|
|
||||||
|
|
||||||
avg_load_time_video_ms = float(np.array(load_times_video_ms).mean())
|
|
||||||
avg_load_time_images_ms = float(np.array(load_times_images_ms).mean())
|
|
||||||
video_images_load_time_ratio = avg_load_time_video_ms / avg_load_time_images_ms
|
|
||||||
|
|
||||||
return {
|
|
||||||
"avg_load_time_video_ms": avg_load_time_video_ms,
|
|
||||||
"avg_load_time_images_ms": avg_load_time_images_ms,
|
|
||||||
"video_images_load_time_ratio": video_images_load_time_ratio,
|
|
||||||
"avg_mse": float(np.mean(mse_values)),
|
|
||||||
"avg_psnr": float(np.mean(psnr_values)),
|
|
||||||
"avg_ssim": float(np.mean(ssim_values)),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def benchmark_encoding_decoding(
|
|
||||||
dataset: LeRobotDataset,
|
|
||||||
video_path: Path,
|
|
||||||
imgs_dir: Path,
|
|
||||||
encoding_cfg: dict,
|
|
||||||
decoding_cfg: dict,
|
|
||||||
num_samples: int,
|
|
||||||
num_workers: int,
|
|
||||||
save_frames: bool,
|
|
||||||
overwrite: bool = False,
|
|
||||||
seed: int = 1337,
|
|
||||||
) -> list[dict]:
|
|
||||||
fps = dataset.fps
|
|
||||||
|
|
||||||
if overwrite or not video_path.is_file():
|
|
||||||
tqdm.write(f"encoding {video_path}")
|
|
||||||
encode_video_frames(
|
|
||||||
imgs_dir=imgs_dir,
|
|
||||||
video_path=video_path,
|
|
||||||
fps=fps,
|
|
||||||
vcodec=encoding_cfg["vcodec"],
|
|
||||||
pix_fmt=encoding_cfg["pix_fmt"],
|
|
||||||
g=encoding_cfg.get("g"),
|
|
||||||
crf=encoding_cfg.get("crf"),
|
|
||||||
# fast_decode=encoding_cfg.get("fastdecode"),
|
|
||||||
overwrite=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
episode_index = 0
|
|
||||||
ep_num_images = dataset.meta.episodes["length"][episode_index]
|
|
||||||
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
|
|
||||||
num_pixels = width * height
|
|
||||||
video_size_bytes = video_path.stat().st_size
|
|
||||||
images_size_bytes = get_directory_size(imgs_dir)
|
|
||||||
video_images_size_ratio = video_size_bytes / images_size_bytes
|
|
||||||
|
|
||||||
random.seed(seed)
|
|
||||||
benchmark_table = []
|
|
||||||
for timestamps_mode in tqdm(
|
|
||||||
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
|
|
||||||
):
|
|
||||||
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
|
|
||||||
benchmark_row = benchmark_decoding(
|
|
||||||
imgs_dir,
|
|
||||||
video_path,
|
|
||||||
timestamps_mode,
|
|
||||||
backend,
|
|
||||||
ep_num_images,
|
|
||||||
fps,
|
|
||||||
num_samples,
|
|
||||||
num_workers,
|
|
||||||
save_frames,
|
|
||||||
)
|
|
||||||
benchmark_row.update(
|
|
||||||
**{
|
|
||||||
"repo_id": dataset.repo_id,
|
|
||||||
"resolution": f"{width} x {height}",
|
|
||||||
"num_pixels": num_pixels,
|
|
||||||
"video_size_bytes": video_size_bytes,
|
|
||||||
"images_size_bytes": images_size_bytes,
|
|
||||||
"video_images_size_ratio": video_images_size_ratio,
|
|
||||||
"timestamps_mode": timestamps_mode,
|
|
||||||
"backend": backend,
|
|
||||||
},
|
|
||||||
**encoding_cfg,
|
|
||||||
)
|
|
||||||
benchmark_table.append(benchmark_row)
|
|
||||||
|
|
||||||
return benchmark_table
|
|
||||||
|
|
||||||
|
|
||||||
def main(
|
|
||||||
output_dir: Path,
|
|
||||||
repo_ids: list[str],
|
|
||||||
vcodec: list[str],
|
|
||||||
pix_fmt: list[str],
|
|
||||||
g: list[int],
|
|
||||||
crf: list[int],
|
|
||||||
# fastdecode: list[int],
|
|
||||||
timestamps_modes: list[str],
|
|
||||||
backends: list[str],
|
|
||||||
num_samples: int,
|
|
||||||
num_workers: int,
|
|
||||||
save_frames: bool,
|
|
||||||
):
|
|
||||||
check_datasets_formats(repo_ids)
|
|
||||||
encoding_benchmarks = {
|
|
||||||
"g": g,
|
|
||||||
"crf": crf,
|
|
||||||
# "fastdecode": fastdecode,
|
|
||||||
}
|
|
||||||
decoding_benchmarks = {
|
|
||||||
"timestamps_modes": timestamps_modes,
|
|
||||||
"backends": backends,
|
|
||||||
}
|
|
||||||
headers = ["repo_id", "resolution", "num_pixels"]
|
|
||||||
headers += list(BASE_ENCODING.keys())
|
|
||||||
headers += [
|
|
||||||
"timestamps_mode",
|
|
||||||
"backend",
|
|
||||||
"video_size_bytes",
|
|
||||||
"images_size_bytes",
|
|
||||||
"video_images_size_ratio",
|
|
||||||
"avg_load_time_video_ms",
|
|
||||||
"avg_load_time_images_ms",
|
|
||||||
"video_images_load_time_ratio",
|
|
||||||
"avg_mse",
|
|
||||||
"avg_psnr",
|
|
||||||
"avg_ssim",
|
|
||||||
]
|
|
||||||
file_paths = []
|
|
||||||
for video_codec in tqdm(vcodec, desc="encodings (vcodec)"):
|
|
||||||
for pixel_format in tqdm(pix_fmt, desc="encodings (pix_fmt)", leave=False):
|
|
||||||
benchmark_table = []
|
|
||||||
for repo_id in tqdm(repo_ids, desc="encodings (datasets)", leave=False):
|
|
||||||
dataset = LeRobotDataset(repo_id)
|
|
||||||
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
|
|
||||||
# We only use the first episode
|
|
||||||
save_first_episode(imgs_dir, dataset)
|
|
||||||
for duet in [
|
|
||||||
dict(zip(encoding_benchmarks.keys(), unique_combination, strict=False))
|
|
||||||
for unique_combination in itertools.product(*encoding_benchmarks.values())
|
|
||||||
]:
|
|
||||||
encoding_cfg = BASE_ENCODING.copy()
|
|
||||||
encoding_cfg["vcodec"] = video_codec
|
|
||||||
encoding_cfg["pix_fmt"] = pixel_format
|
|
||||||
for key, value in duet.items():
|
|
||||||
encoding_cfg[key] = value
|
|
||||||
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
|
|
||||||
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
|
|
||||||
benchmark_table += benchmark_encoding_decoding(
|
|
||||||
dataset,
|
|
||||||
video_path,
|
|
||||||
imgs_dir,
|
|
||||||
encoding_cfg,
|
|
||||||
decoding_benchmarks,
|
|
||||||
num_samples,
|
|
||||||
num_workers,
|
|
||||||
save_frames,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save intermediate results
|
|
||||||
benchmark_df = pd.DataFrame(benchmark_table, columns=headers)
|
|
||||||
now = dt.datetime.now()
|
|
||||||
csv_path = (
|
|
||||||
output_dir
|
|
||||||
/ f"{now:%Y-%m-%d}_{now:%H-%M-%S}_{video_codec}_{pixel_format}_{num_samples}-samples.csv"
|
|
||||||
)
|
|
||||||
benchmark_df.to_csv(csv_path, header=True, index=False)
|
|
||||||
file_paths.append(csv_path)
|
|
||||||
del benchmark_df
|
|
||||||
|
|
||||||
# Concatenate all results
|
|
||||||
df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
|
|
||||||
concatenated_df = pd.concat(df_list, ignore_index=True)
|
|
||||||
concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
|
|
||||||
concatenated_df.to_csv(concatenated_path, header=True, index=False)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-dir",
|
|
||||||
type=Path,
|
|
||||||
default=Path("outputs/video_benchmark"),
|
|
||||||
help="Directory where the video benchmark outputs are written.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--repo-ids",
|
|
||||||
type=str,
|
|
||||||
nargs="*",
|
|
||||||
default=[
|
|
||||||
"lerobot/pusht_image",
|
|
||||||
"lerobot/aloha_mobile_shrimp_image",
|
|
||||||
"lerobot/paris_street",
|
|
||||||
"lerobot/kitchen",
|
|
||||||
],
|
|
||||||
help="Datasets repo-ids to test against. First episodes only are used. Must be images.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--vcodec",
|
|
||||||
type=str,
|
|
||||||
nargs="*",
|
|
||||||
default=["h264", "hevc", "libsvtav1"],
|
|
||||||
help="Video codecs to be tested",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--pix-fmt",
|
|
||||||
type=str,
|
|
||||||
nargs="*",
|
|
||||||
default=["yuv444p", "yuv420p"],
|
|
||||||
help="Pixel formats (chroma subsampling) to be tested",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--g",
|
|
||||||
type=parse_int_or_none,
|
|
||||||
nargs="*",
|
|
||||||
default=[1, 2, 3, 4, 5, 6, 10, 15, 20, 40, 100, None],
|
|
||||||
help="Group of pictures sizes to be tested.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--crf",
|
|
||||||
type=parse_int_or_none,
|
|
||||||
nargs="*",
|
|
||||||
default=[0, 5, 10, 15, 20, 25, 30, 40, 50, None],
|
|
||||||
help="Constant rate factors to be tested.",
|
|
||||||
)
|
|
||||||
# parser.add_argument(
|
|
||||||
# "--fastdecode",
|
|
||||||
# type=int,
|
|
||||||
# nargs="*",
|
|
||||||
# default=[0, 1],
|
|
||||||
# help="Use the fastdecode tuning option. 0 disables it. "
|
|
||||||
# "For libx264 and libx265/hevc, only 1 is possible. "
|
|
||||||
# "For libsvtav1, 1, 2 or 3 are possible values with a higher number meaning a faster decoding optimization",
|
|
||||||
# )
|
|
||||||
parser.add_argument(
|
|
||||||
"--timestamps-modes",
|
|
||||||
type=str,
|
|
||||||
nargs="*",
|
|
||||||
default=[
|
|
||||||
"1_frame",
|
|
||||||
"2_frames",
|
|
||||||
"2_frames_4_space",
|
|
||||||
"6_frames",
|
|
||||||
],
|
|
||||||
help="Timestamps scenarios to be tested.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--backends",
|
|
||||||
type=str,
|
|
||||||
nargs="*",
|
|
||||||
default=["torchcodec", "pyav"],
|
|
||||||
help="Torchvision decoding backend to be tested.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num-samples",
|
|
||||||
type=int,
|
|
||||||
default=50,
|
|
||||||
help="Number of samples for each encoding x decoding config.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num-workers",
|
|
||||||
type=int,
|
|
||||||
default=10,
|
|
||||||
help="Number of processes for parallelized sample processing.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--save-frames",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="Whether to save decoded frames or not. Enter a non-zero number for true.",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
main(**vars(args))
|
|
||||||
@@ -35,7 +35,7 @@ USER root
|
|||||||
ARG ROBOTWIN_SHA=0aeea2d669c0f8516f4d5785f0aa33ba812c14b4
|
ARG ROBOTWIN_SHA=0aeea2d669c0f8516f4d5785f0aa33ba812c14b4
|
||||||
RUN apt-get update \
|
RUN apt-get update \
|
||||||
&& apt-get install -y --no-install-recommends \
|
&& apt-get install -y --no-install-recommends \
|
||||||
cuda-nvcc-12-4 cuda-cudart-dev-12-4 \
|
cuda-nvcc-12-8 cuda-cudart-dev-12-8 \
|
||||||
libvulkan1 vulkan-tools \
|
libvulkan1 vulkan-tools \
|
||||||
&& mkdir -p /usr/share/vulkan/icd.d \
|
&& mkdir -p /usr/share/vulkan/icd.d \
|
||||||
&& echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.3.0"}}' \
|
&& echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.3.0"}}' \
|
||||||
@@ -56,11 +56,11 @@ RUN uv pip install --no-cache --no-build-isolation \
|
|||||||
"git+https://github.com/facebookresearch/pytorch3d.git@stable"
|
"git+https://github.com/facebookresearch/pytorch3d.git@stable"
|
||||||
|
|
||||||
# CuRobo — NVlabs motion generator; TORCH_CUDA_ARCH_LIST must be set or the
|
# CuRobo — NVlabs motion generator; TORCH_CUDA_ARCH_LIST must be set or the
|
||||||
# build aborts on an empty arch list. Pinned SHA for reproducibility.
|
# build aborts on an empty arch list. RoboTwin's own installer pins v0.7.8,
|
||||||
ARG CUROBO_SHA=ca941586c33b8482ed9c0e74d60f23efd64b516a
|
# which still exposes the v1 API (`curobo.types.math`) that RoboTwin imports.
|
||||||
|
ARG CUROBO_REF=v0.7.8
|
||||||
RUN cd ${ROBOTWIN_ROOT}/envs \
|
RUN cd ${ROBOTWIN_ROOT}/envs \
|
||||||
&& git clone https://github.com/NVlabs/curobo.git \
|
&& git clone --branch ${CUROBO_REF} --depth 1 https://github.com/NVlabs/curobo.git \
|
||||||
&& git -C curobo checkout ${CUROBO_SHA} \
|
|
||||||
&& cd curobo \
|
&& cd curobo \
|
||||||
&& TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;8.9;9.0" \
|
&& TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;8.9;9.0" \
|
||||||
uv pip install -e . --no-build-isolation --no-cache
|
uv pip install -e . --no-build-isolation --no-cache
|
||||||
@@ -111,7 +111,23 @@ EOF
|
|||||||
WORKDIR ${ROBOTWIN_ROOT}
|
WORKDIR ${ROBOTWIN_ROOT}
|
||||||
RUN python script/update_embodiment_config_path.py
|
RUN python script/update_embodiment_config_path.py
|
||||||
|
|
||||||
ENV PYTHONPATH="${ROBOTWIN_ROOT}:${PYTHONPATH}"
|
ENV PYTHONPATH="${ROBOTWIN_ROOT}"
|
||||||
|
|
||||||
|
# Fail the image build early if the CuRobo package layout regresses. Importing
|
||||||
|
# RoboTwin's planner here is too eager because CuRobo constructs CUDA-backed
|
||||||
|
# defaults at import time, while Docker builds don't have access to an NVIDIA
|
||||||
|
# driver.
|
||||||
|
RUN python - <<'EOF'
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from curobo.types.math import Pose
|
||||||
|
|
||||||
|
planner_src = (Path("/opt/robotwin/envs/robot/planner.py")).read_text()
|
||||||
|
assert "from curobo.types.math import Pose as CuroboPose" in planner_src
|
||||||
|
|
||||||
|
print("CuRobo import OK:", Pose.__name__)
|
||||||
|
print("RoboTwin planner import references curobo.types.math")
|
||||||
|
EOF
|
||||||
|
|
||||||
# Return to the lerobot source directory (set by base image) before overlaying.
|
# Return to the lerobot source directory (set by base image) before overlaying.
|
||||||
WORKDIR /lerobot
|
WORKDIR /lerobot
|
||||||
|
|||||||
@@ -18,9 +18,8 @@
|
|||||||
# docker build -f docker/Dockerfile.internal -t lerobot-internal .
|
# docker build -f docker/Dockerfile.internal -t lerobot-internal .
|
||||||
|
|
||||||
# Configure the base image for CI with GPU access
|
# Configure the base image for CI with GPU access
|
||||||
# TODO(Steven): Bump these versions
|
ARG CUDA_VERSION=12.8.1
|
||||||
ARG CUDA_VERSION=12.4.1
|
ARG OS_VERSION=24.04
|
||||||
ARG OS_VERSION=22.04
|
|
||||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||||
|
|
||||||
# Define Python version argument
|
# Define Python version argument
|
||||||
@@ -36,16 +35,13 @@ ENV DEBIAN_FRONTEND=noninteractive \
|
|||||||
|
|
||||||
# Install Python, system dependencies, and uv (as root)
|
# Install Python, system dependencies, and uv (as root)
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
software-properties-common build-essential git curl \
|
build-essential git curl \
|
||||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
libglib2.0-0 libgl1 libegl1 ffmpeg \
|
||||||
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
|
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
|
||||||
cmake pkg-config ninja-build \
|
cmake pkg-config ninja-build \
|
||||||
&& add-apt-repository -y ppa:deadsnakes/ppa \
|
python${PYTHON_VERSION} \
|
||||||
&& apt-get update \
|
python${PYTHON_VERSION}-venv \
|
||||||
&& apt-get install -y --no-install-recommends \
|
python${PYTHON_VERSION}-dev \
|
||||||
python${PYTHON_VERSION} \
|
|
||||||
python${PYTHON_VERSION}-venv \
|
|
||||||
python${PYTHON_VERSION}-dev \
|
|
||||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||||
&& mv /root/.local/bin/uv /usr/local/bin/uv \
|
&& mv /root/.local/bin/uv /usr/local/bin/uv \
|
||||||
&& useradd --create-home --shell /bin/bash user_lerobot \
|
&& useradd --create-home --shell /bin/bash user_lerobot \
|
||||||
|
|||||||
@@ -8,7 +8,7 @@
|
|||||||
- local: il_robots
|
- local: il_robots
|
||||||
title: Imitation Learning for Robots
|
title: Imitation Learning for Robots
|
||||||
- local: bring_your_own_policies
|
- local: bring_your_own_policies
|
||||||
title: Bring Your Own Policies
|
title: Adding a Policy
|
||||||
- local: integrate_hardware
|
- local: integrate_hardware
|
||||||
title: Bring Your Own Hardware
|
title: Bring Your Own Hardware
|
||||||
- local: hilserl
|
- local: hilserl
|
||||||
@@ -24,6 +24,12 @@
|
|||||||
- local: rename_map
|
- local: rename_map
|
||||||
title: Using Rename Map and Empty Cameras
|
title: Using Rename Map and Empty Cameras
|
||||||
title: "Tutorials"
|
title: "Tutorials"
|
||||||
|
- sections:
|
||||||
|
- local: hardware_guide
|
||||||
|
title: Compute Hardware Guide
|
||||||
|
- local: torch_accelerators
|
||||||
|
title: PyTorch accelerators
|
||||||
|
title: "Compute & Hardware"
|
||||||
- sections:
|
- sections:
|
||||||
- local: lerobot-dataset-v3
|
- local: lerobot-dataset-v3
|
||||||
title: Using LeRobotDataset
|
title: Using LeRobotDataset
|
||||||
@@ -47,6 +53,8 @@
|
|||||||
title: π₀-FAST (Pi0Fast)
|
title: π₀-FAST (Pi0Fast)
|
||||||
- local: pi05
|
- local: pi05
|
||||||
title: π₀.₅ (Pi05)
|
title: π₀.₅ (Pi05)
|
||||||
|
- local: eo1
|
||||||
|
title: EO-1
|
||||||
- local: groot
|
- local: groot
|
||||||
title: NVIDIA GR00T N1.5
|
title: NVIDIA GR00T N1.5
|
||||||
- local: xvla
|
- local: xvla
|
||||||
@@ -61,6 +69,8 @@
|
|||||||
title: SARM
|
title: SARM
|
||||||
title: "Reward Models"
|
title: "Reward Models"
|
||||||
- sections:
|
- sections:
|
||||||
|
- local: inference
|
||||||
|
title: Policy Deployment (lerobot-rollout)
|
||||||
- local: async
|
- local: async
|
||||||
title: Use Async Inference
|
title: Use Async Inference
|
||||||
- local: rtc
|
- local: rtc
|
||||||
@@ -138,10 +148,6 @@
|
|||||||
- local: cameras
|
- local: cameras
|
||||||
title: Cameras
|
title: Cameras
|
||||||
title: "Sensors"
|
title: "Sensors"
|
||||||
- sections:
|
|
||||||
- local: torch_accelerators
|
|
||||||
title: PyTorch accelerators
|
|
||||||
title: "Supported Hardware"
|
|
||||||
- sections:
|
- sections:
|
||||||
- local: notebooks
|
- local: notebooks
|
||||||
title: Notebooks
|
title: Notebooks
|
||||||
|
|||||||
@@ -1,60 +1,37 @@
|
|||||||
# Bring Your Own Policies
|
# Adding a Policy
|
||||||
|
|
||||||
This tutorial explains how to integrate your own custom policy implementations into the LeRobot ecosystem, allowing you to leverage all LeRobot tools for training, evaluation, and deployment while using your own algorithms.
|
This guide walks you through implementing a custom policy and getting it to work with LeRobot's training, evaluation, and deployment tools. There are two paths:
|
||||||
|
|
||||||
## Step 1: Create a Policy Package
|
- **Plugin (out-of-tree)** — ship your policy as a standalone `lerobot_policy_*` package. Faster, no PR required, easy to iterate. Right for experimentation, internal use, or when you want to publish independently.
|
||||||
|
- **In-tree (contributed to LeRobot)** — land your policy directly in `src/lerobot/policies/`. Requires a PR, but makes your policy a first-class citizen of the library.
|
||||||
|
|
||||||
Your custom policy should be organized as an installable Python package following LeRobot's plugin conventions.
|
The plugin route is usually the right starting point — promote to in-tree once the policy has stabilized and there's clear value in shipping it with the library.
|
||||||
|
|
||||||
### Package Structure
|
Either way, the building blocks are the same: a configuration class, a policy class, and a processor factory. The first half of this guide covers those shared pieces; the second half covers the path-specific scaffolding ([Path A](#path-a-out-of-tree-plugin), [Path B](#path-b-contributing-in-tree)).
|
||||||
|
|
||||||
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
|
A note on tone: robot-learning is an actively evolving field, and "what a policy looks like" can shift with each new architecture. The conventions described here exist because they let `lerobot-train` and `lerobot-eval` work uniformly across very different models. When a new policy genuinely doesn't fit them, raise it (in your PR, or an issue) — the conventions are not sacred.
|
||||||
|
|
||||||
```bash
|
---
|
||||||
lerobot_policy_my_custom_policy/
|
|
||||||
├── pyproject.toml
|
|
||||||
└── src/
|
|
||||||
└── lerobot_policy_my_custom_policy/
|
|
||||||
├── __init__.py
|
|
||||||
├── configuration_my_custom_policy.py
|
|
||||||
├── modeling_my_custom_policy.py
|
|
||||||
└── processor_my_custom_policy.py
|
|
||||||
```
|
|
||||||
|
|
||||||
### Package Configuration
|
## Anatomy of a policy
|
||||||
|
|
||||||
Set up your `pyproject.toml`:
|
Three building blocks make up every policy. The names below use `my_policy` as a placeholder — replace with your policy's name. That name is load-bearing: it must match the string you pass to `@PreTrainedConfig.register_subclass`, the `MyPolicy.name` class attribute, and the `make_<name>_pre_post_processors` factory function (more on each below).
|
||||||
|
|
||||||
```toml
|
### Configuration class
|
||||||
[project]
|
|
||||||
name = "lerobot_policy_my_custom_policy"
|
|
||||||
version = "0.1.0"
|
|
||||||
dependencies = [
|
|
||||||
# your policy-specific dependencies
|
|
||||||
]
|
|
||||||
requires-python = ">= 3.12"
|
|
||||||
|
|
||||||
[build-system]
|
Inherit from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and register your policy type. Here is a template — customize the parameters and methods as needed for your policy's architecture and training requirements.
|
||||||
build-backend = # your-build-backend
|
|
||||||
requires = # your-build-system
|
|
||||||
```
|
|
||||||
|
|
||||||
## Step 2: Define the Policy Configuration
|
|
||||||
|
|
||||||
Create a configuration class that inherits from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and registers your policy type:
|
|
||||||
Here is a template to get you started, customize the parameters and methods as needed for your policy's architecture and training requirements.
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# configuration_my_custom_policy.py
|
# configuration_my_policy.py
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from lerobot.configs import PreTrainedConfig
|
from lerobot.configs import PreTrainedConfig
|
||||||
from lerobot.optim import AdamWConfig
|
from lerobot.optim import AdamWConfig
|
||||||
from lerobot.optim import CosineDecayWithWarmupSchedulerConfig
|
from lerobot.optim import CosineDecayWithWarmupSchedulerConfig
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("my_custom_policy")
|
@PreTrainedConfig.register_subclass("my_policy")
|
||||||
@dataclass
|
@dataclass
|
||||||
class MyCustomPolicyConfig(PreTrainedConfig):
|
class MyPolicyConfig(PreTrainedConfig):
|
||||||
"""Configuration class for MyCustomPolicy.
|
"""Configuration class for MyPolicy.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
n_obs_steps: Number of observation steps to use as input
|
n_obs_steps: Number of observation steps to use as input
|
||||||
@@ -77,16 +54,20 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
|||||||
raise ValueError("n_action_steps cannot exceed horizon")
|
raise ValueError("n_action_steps cannot exceed horizon")
|
||||||
|
|
||||||
def validate_features(self) -> None:
|
def validate_features(self) -> None:
|
||||||
"""Validate input/output feature compatibility."""
|
"""Validate input/output feature compatibility.
|
||||||
|
|
||||||
|
Call this explicitly from your policy's __init__ — the base class does not.
|
||||||
|
"""
|
||||||
if not self.image_features:
|
if not self.image_features:
|
||||||
raise ValueError("MyCustomPolicy requires at least one image feature.")
|
raise ValueError("MyPolicy requires at least one image feature.")
|
||||||
if self.action_feature is None:
|
if self.action_feature is None:
|
||||||
raise ValueError("MyCustomPolicy requires 'action' in output_features.")
|
raise ValueError("MyPolicy requires 'action' in output_features.")
|
||||||
|
|
||||||
def get_optimizer_preset(self) -> AdamWConfig:
|
def get_optimizer_preset(self) -> AdamWConfig:
|
||||||
return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
|
return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
|
||||||
|
|
||||||
def get_scheduler_preset(self):
|
def get_scheduler_preset(self):
|
||||||
|
"""Return a LRSchedulerConfig from lerobot.optim, or None."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -101,8 +82,7 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def action_delta_indices(self) -> list[int]:
|
def action_delta_indices(self) -> list[int]:
|
||||||
"""Relative timestep offsets for the action chunk the dataset loader returns.
|
"""Relative timestep offsets for the action chunk the dataset loader returns."""
|
||||||
"""
|
|
||||||
return list(range(self.horizon))
|
return list(range(self.horizon))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -110,32 +90,34 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
|||||||
return None
|
return None
|
||||||
```
|
```
|
||||||
|
|
||||||
## Step 3: Implement the Policy Class
|
The string you pass to `@register_subclass` must match `MyPolicy.name` (next section) and is what users supply as `--policy.type` on the CLI. Default to `AdamW` from `lerobot.optim` for `get_optimizer_preset` unless you genuinely need otherwise.
|
||||||
|
|
||||||
Create your policy implementation by inheriting from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py):
|
### Policy class
|
||||||
|
|
||||||
|
Inherit from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py) and set two class attributes — both are checked by `__init_subclass__`:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# modeling_my_custom_policy.py
|
# modeling_my_policy.py
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from lerobot.policies import PreTrainedPolicy
|
from lerobot.policies import PreTrainedPolicy
|
||||||
from lerobot.utils.constants import ACTION
|
from lerobot.utils.constants import ACTION
|
||||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
from .configuration_my_policy import MyPolicyConfig
|
||||||
|
|
||||||
class MyCustomPolicy(PreTrainedPolicy):
|
class MyPolicy(PreTrainedPolicy):
|
||||||
config_class = MyCustomPolicyConfig # must match the string in @register_subclass
|
config_class = MyPolicyConfig # must match the string in @register_subclass
|
||||||
name = "my_custom_policy"
|
name = "my_policy"
|
||||||
|
|
||||||
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
|
def __init__(self, config: MyPolicyConfig, dataset_stats: dict[str, Any] = None):
|
||||||
super().__init__(config, dataset_stats)
|
super().__init__(config, dataset_stats)
|
||||||
config.validate_features() # not called automatically by the base class
|
config.validate_features() # not called automatically by the base class
|
||||||
self.config = config
|
self.config = config
|
||||||
self.model = ... # your nn.Module here
|
self.model = ... # your nn.Module here
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Reset episode state."""
|
"""Reset per-episode state. Called by lerobot-eval at the start of each episode."""
|
||||||
...
|
...
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
@@ -147,35 +129,51 @@ class MyCustomPolicy(PreTrainedPolicy):
|
|||||||
...
|
...
|
||||||
|
|
||||||
def select_action(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
|
def select_action(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
|
||||||
"""Return a single action for the current timestep (called at inference)."""
|
"""Return a single action for the current timestep (called every step at inference)."""
|
||||||
...
|
...
|
||||||
|
|
||||||
def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
def forward(self, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, dict | None]:
|
||||||
"""Compute the training loss.
|
"""Compute the training loss.
|
||||||
|
|
||||||
|
Returns `(loss, output_dict)`. `output_dict` may be `None`; everything in it must be
|
||||||
|
logging-friendly Python natives (no tensors with gradients).
|
||||||
|
|
||||||
`batch["action_is_pad"]` is a bool mask of shape (B, horizon) that marks
|
`batch["action_is_pad"]` is a bool mask of shape (B, horizon) that marks
|
||||||
timesteps padded because the episode ended before `horizon` steps, you
|
timesteps padded because the episode ended before `horizon` steps; you
|
||||||
can exclude those from your loss.
|
can exclude those from your loss.
|
||||||
"""
|
"""
|
||||||
actions = batch[ACTION]
|
actions = batch[ACTION]
|
||||||
action_is_pad = batch.get("action_is_pad")
|
action_is_pad = batch.get("action_is_pad")
|
||||||
...
|
...
|
||||||
return {"loss": ...}
|
return loss, {"some_loss_component": some_loss_component.item()}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Step 4: Add Data Processors
|
The methods called by the train/eval loops:
|
||||||
|
|
||||||
Create processor functions. For a concrete reference, see [processor_act.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [processor_diffusion.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
|
| Method | Used by | What it does |
|
||||||
|
| ----------------------------------------------------------------- | ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `reset() -> None` | `lerobot-eval` | Clear per-episode state at the start of each episode. |
|
||||||
|
| `select_action(batch, **kwargs) -> Tensor` | `lerobot-eval` | Return the next action `(B, action_dim)`. Called every step. |
|
||||||
|
| `predict_action_chunk(batch, **kwargs) -> Tensor` | the policy itself | Return an action chunk `(B, chunk_size, action_dim)`. Currently abstract on the base class — raise `NotImplementedError` if your policy doesn't chunk. |
|
||||||
|
| `forward(batch, reduction="mean") -> tuple[Tensor, dict \| None]` | `lerobot-train` | Return `(loss, output_dict)`. Accept `reduction="none"` if you want to support per-sample weighting. |
|
||||||
|
| `get_optim_params() -> dict` | the optimizer | Return `self.parameters()` for simple policies; return a named parameter dict for [multi-optimizer policies](https://github.com/huggingface/lerobot/blob/ecd38c50d7d15b4184cf42649ff1185ee2e11eeb/src/lerobot/policies/sac/modeling_sac.py#L61-L73). |
|
||||||
|
| `update() -> None` _(optional)_ | `lerobot-train` | Called after each optimizer step _if defined_. Use for EMA, target nets, replay buffers (TDMPC uses this). |
|
||||||
|
|
||||||
|
Batches are flat dictionaries keyed by the constants in [`lerobot.utils.constants`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/utils/constants.py): `OBS_STATE` (`observation.state.<motor>`), `OBS_IMAGES` (`observation.images.<camera>`), `OBS_LANGUAGE`, `ACTION`, etc. Reuse the constants — don't invent new prefixes.
|
||||||
|
|
||||||
|
### Processor functions
|
||||||
|
|
||||||
|
LeRobot uses `PolicyProcessorPipeline`s to normalize inputs and de-normalize outputs around your policy. For a concrete reference, see [`processor_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [`processor_diffusion.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# processor_my_custom_policy.py
|
# processor_my_policy.py
|
||||||
from typing import Any
|
from typing import Any
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||||
|
|
||||||
|
|
||||||
def make_my_custom_policy_pre_post_processors(
|
def make_my_policy_pre_post_processors(
|
||||||
config,
|
config,
|
||||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
@@ -187,11 +185,48 @@ def make_my_custom_policy_pre_post_processors(
|
|||||||
return preprocessor, postprocessor
|
return preprocessor, postprocessor
|
||||||
```
|
```
|
||||||
|
|
||||||
**Important - function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
|
**Important — function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
|
||||||
|
|
||||||
## Step 5: Package Initialization
|
---
|
||||||
|
|
||||||
Expose your classes in the package's `__init__.py`:
|
## Path A: Out-of-tree plugin
|
||||||
|
|
||||||
|
The fastest way to ship a policy: package it as a standalone Python distribution and install it alongside LeRobot. No PR required, you own the release cycle, and you can publish to PyPI under your own namespace.
|
||||||
|
|
||||||
|
### Package structure
|
||||||
|
|
||||||
|
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot_policy_my_policy/
|
||||||
|
├── pyproject.toml
|
||||||
|
└── src/
|
||||||
|
└── lerobot_policy_my_policy/
|
||||||
|
├── __init__.py
|
||||||
|
├── configuration_my_policy.py
|
||||||
|
├── modeling_my_policy.py
|
||||||
|
└── processor_my_policy.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### `pyproject.toml`
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[project]
|
||||||
|
name = "lerobot_policy_my_policy"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
# your policy-specific dependencies
|
||||||
|
]
|
||||||
|
requires-python = ">= 3.12"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
build-backend = # your-build-backend
|
||||||
|
requires = # your-build-system
|
||||||
|
```
|
||||||
|
|
||||||
|
### Package `__init__.py`
|
||||||
|
|
||||||
|
Expose your classes in the package's `__init__.py` and guard against missing `lerobot`:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# __init__.py
|
# __init__.py
|
||||||
@@ -204,44 +239,148 @@ except ImportError:
|
|||||||
"lerobot is not installed. Please install lerobot to use this policy package."
|
"lerobot is not installed. Please install lerobot to use this policy package."
|
||||||
)
|
)
|
||||||
|
|
||||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
from .configuration_my_policy import MyPolicyConfig
|
||||||
from .modeling_my_custom_policy import MyCustomPolicy
|
from .modeling_my_policy import MyPolicy
|
||||||
from .processor_my_custom_policy import make_my_custom_policy_pre_post_processors
|
from .processor_my_policy import make_my_policy_pre_post_processors
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MyCustomPolicyConfig",
|
"MyPolicyConfig",
|
||||||
"MyCustomPolicy",
|
"MyPolicy",
|
||||||
"make_my_custom_policy_pre_post_processors",
|
"make_my_policy_pre_post_processors",
|
||||||
]
|
]
|
||||||
```
|
```
|
||||||
|
|
||||||
## Step 6: Installation and Usage
|
### Install and use
|
||||||
|
|
||||||
### Install Your Policy Package
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd lerobot_policy_my_custom_policy
|
cd lerobot_policy_my_policy
|
||||||
pip install -e .
|
pip install -e .
|
||||||
|
|
||||||
# Or install from PyPI if published
|
# Or install from PyPI if published
|
||||||
pip install lerobot_policy_my_custom_policy
|
pip install lerobot_policy_my_policy
|
||||||
```
|
```
|
||||||
|
|
||||||
### Use Your Policy
|
|
||||||
|
|
||||||
Once installed, your policy automatically integrates with LeRobot's training and evaluation tools:
|
Once installed, your policy automatically integrates with LeRobot's training and evaluation tools:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-train \
|
lerobot-train \
|
||||||
--policy.type my_custom_policy \
|
--policy.type my_policy \
|
||||||
--env.type pusht \
|
--env.type pusht \
|
||||||
--steps 200000
|
--steps 200000
|
||||||
```
|
```
|
||||||
|
|
||||||
## Examples and Community Contributions
|
---
|
||||||
|
|
||||||
|
## Path B: Contributing in-tree
|
||||||
|
|
||||||
|
When your policy has stabilized and there's clear value in shipping it with the library, you can land it directly in LeRobot. Read the general [contribution guide](./contributing) and the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md) first — that's where you'll find the testing/quality expectations every PR has to meet (`pre-commit run -a`, `pytest`, the community-review rule, etc.). What's below is the policy-specific layer on top of that.
|
||||||
|
|
||||||
|
### In-tree layout
|
||||||
|
|
||||||
|
```
|
||||||
|
src/lerobot/policies/my_policy/
|
||||||
|
├── __init__.py # re-exports config + modeling + processor factory
|
||||||
|
├── configuration_my_policy.py # MyPolicyConfig + @register_subclass
|
||||||
|
├── modeling_my_policy.py # MyPolicy(PreTrainedPolicy)
|
||||||
|
├── processor_my_policy.py # make_my_policy_pre_post_processors
|
||||||
|
└── README.md # symlink → ../../../../docs/source/policy_my_policy_README.md
|
||||||
|
```
|
||||||
|
|
||||||
|
Two notes:
|
||||||
|
|
||||||
|
- The `README.md` next to the source is a **symlink** into `docs/source/policy_<name>_README.md` — the actual file lives under `docs/`. Existing policies (act, smolvla, diffusion, …) all do this; copy one of those symlinks. The policy README is conventionally minimal: paper link + BibTeX citation.
|
||||||
|
- The user-facing tutorial — what to install, how to train, hyperparameters, benchmark numbers — lives separately at `docs/source/<my_policy>.mdx` and is registered in `_toctree.yml` under "Policies".
|
||||||
|
|
||||||
|
The file names are load-bearing: the factory does lazy imports by name, and the processor is discovered by the `make_<policy_name>_pre_post_processors` convention.
|
||||||
|
|
||||||
|
### Wiring
|
||||||
|
|
||||||
|
Three places need to know about your policy. All by name.
|
||||||
|
|
||||||
|
1. **`policies/__init__.py`** — re-export `MyPolicyConfig` and add it to `__all__`. **Don't** re-export the modeling class; it loads lazily through the factory (so `import lerobot` stays fast).
|
||||||
|
2. **`factory.py:get_policy_class`** — add a branch returning `MyPolicy` from a lazy import.
|
||||||
|
3. **`factory.py:make_policy_config`** and **`factory.py:make_pre_post_processors`** — same idea, two more branches.
|
||||||
|
|
||||||
|
Mirror an existing policy that's structurally similar to yours; the diff is small.
|
||||||
|
|
||||||
|
### Heavy / optional dependencies
|
||||||
|
|
||||||
|
Most policies need a heavy backbone (transformers, diffusers, a specific VLM SDK). The convention is **two-step gating**: a `TYPE_CHECKING`-guarded import at module top, and a `require_package` runtime check in the constructor. [`modeling_diffusion.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/modeling_diffusion.py) is the canonical reference:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from lerobot.utils.import_utils import _diffusers_available, require_package
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _diffusers_available:
|
||||||
|
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||||
|
else:
|
||||||
|
DDIMScheduler = None # keeps the symbol bindable at import time
|
||||||
|
|
||||||
|
class DiffusionPolicy(PreTrainedPolicy):
|
||||||
|
def __init__(self, config):
|
||||||
|
require_package("diffusers", extra="diffusion")
|
||||||
|
super().__init__(config)
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
This way:
|
||||||
|
|
||||||
|
- `import lerobot.policies` keeps working without the extra installed (the symbol is just bound to `None`).
|
||||||
|
- Type checkers see the real symbol.
|
||||||
|
- Instantiating the policy without the extra raises a clear `ImportError` pointing at `pip install 'lerobot[diffusion]'`.
|
||||||
|
|
||||||
|
Add a matching extra to [`pyproject.toml`](https://github.com/huggingface/lerobot/blob/main/pyproject.toml) `[project.optional-dependencies]` and include it in the `all` extra so `pip install 'lerobot[all]'` keeps installing everything.
|
||||||
|
|
||||||
|
### Benchmarks and a published checkpoint
|
||||||
|
|
||||||
|
A new policy is much easier to review — and far more useful — when it ships with a working checkpoint and at least one number you can reproduce.
|
||||||
|
|
||||||
|
**Pick at least one in-tree benchmark.** LeRobot ships sim benchmarks with per-benchmark Docker images (LIBERO, LIBERO-plus, Meta-World, RoboTwin 2.0, RoboCasa365, RoboCerebra, RoboMME, VLABench and more). Pick the one that matches your policy's modality — VLAs usually go to LIBERO or VLABench; image-only BC to LIBERO or Meta-World. The full list lives under [Benchmarks](./libero) in the docs sidebar.
|
||||||
|
|
||||||
|
**Push the checkpoint & processors** to the Hub under `lerobot/<policy>_<benchmark>` (or your namespace if you don't have write access; a maintainer can mirror it). Use `PreTrainedPolicy.push_model_to_hub` so the repo gets `config.json`, `model.safetensors`, and a model card.
|
||||||
|
|
||||||
|
**Report results in your policy's MDX**, with the exact `lerobot-eval` command and hardware so anyone can re-run:
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
## Results
|
||||||
|
|
||||||
|
Evaluated on LIBERO with `lerobot/<policy>_libero`:
|
||||||
|
|
||||||
|
| Suite | Success rate | n_episodes |
|
||||||
|
| -------------- | -----------: | ---------: |
|
||||||
|
| libero_spatial | 87.5% | 50 |
|
||||||
|
| libero_object | 93.0% | 50 |
|
||||||
|
| libero_goal | 81.5% | 50 |
|
||||||
|
| libero_10 | 62.0% | 50 |
|
||||||
|
| **average** | **81.0%** | 200 |
|
||||||
|
|
||||||
|
Reproduce: `lerobot-eval --policy.path=lerobot/<policy>_libero --env.type=libero --env.task=libero_spatial --eval.n_episodes=50` (1× A100 40 GB).
|
||||||
|
```
|
||||||
|
|
||||||
|
Use `n_episodes ≥ 50` per suite for stable success-rate estimates.
|
||||||
|
|
||||||
|
If your policy is real-robot-only and no sim benchmark applies, swap the sim eval for: a public training dataset on the Hub, the `lerobot-train` command, the checkpoint, and a real-robot success rate over ≥10 episodes via `lerobot-rollout --policy.path=...`.
|
||||||
|
|
||||||
|
### PR checklist
|
||||||
|
|
||||||
|
The general expectations are in [`CONTRIBUTING.md`](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md) and the [PR template](https://github.com/huggingface/lerobot/blob/main/.github/PULL_REQUEST_TEMPLATE.md). On top of those, reviewers will look for:
|
||||||
|
|
||||||
|
- [ ] `MyPolicy` and `MyPolicyConfig` cover the surface above; `__init_subclass__` accepts the class.
|
||||||
|
- [ ] `factory.py` and `policies/__init__.py` are wired (lazy imports for modeling).
|
||||||
|
- [ ] `make_my_policy_pre_post_processors` follows the naming convention.
|
||||||
|
- [ ] Optional deps live behind a `[project.optional-dependencies]` extra and the `TYPE_CHECKING + require_package` guard.
|
||||||
|
- [ ] `tests/policies/` updated; backward-compat artifact committed & policy-specific tests.
|
||||||
|
- [ ] `src/lerobot/policies/<name>/README.md` symlinked into `docs/source/policy_<name>_README.md`; user-facing `docs/source/<name>.mdx` written and added to `_toctree.yml`.
|
||||||
|
- [ ] At least one reproducible benchmark eval in the policy MDX with a published checkpoint (sim benchmark, or real-robot dataset + checkpoint).
|
||||||
|
|
||||||
|
The fastest way to get a clean PR is to copy the directory of the existing policy closest to yours, rename, and replace contents method by method. Don't wait until everything is polished — open a draft PR early and iterate with us; reviewers would much rather give feedback on a half-finished branch than a fully-merged one.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Examples and community contributions
|
||||||
|
|
||||||
Check out these example policy implementations:
|
Check out these example policy implementations:
|
||||||
|
|
||||||
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) - Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
|
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) — Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
|
||||||
|
|
||||||
Share your policy implementations with the community! 🤗
|
Thanks for taking the time to bring a new policy into LeRobot. Every architecture that lands in `main` — and every plugin published by the community — makes the library a little more useful for the next person, and a little more representative of where robot learning is going. We're looking forward to seeing what you ship. 🤗
|
||||||
|
|||||||
168
docs/source/eo1.mdx
Normal file
168
docs/source/eo1.mdx
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
# EO-1
|
||||||
|
|
||||||
|
EO-1 is a **Vision-Language-Action policy for robot control**. The LeRobot implementation integrates EO-1 with the standard LeRobot training, evaluation, processor interface.
|
||||||
|
|
||||||
|
## Model Overview
|
||||||
|
|
||||||
|
EO-1 uses a Qwen2.5-VL backbone for vision-language understanding and adds a continuous flow-matching action head for robot control. The policy formats each robot-control sample as a multimodal conversation: camera images are passed to Qwen2.5-VL, the robot state is represented with EO-1 state tokens, and the future action chunk is represented with EO-1 action tokens.
|
||||||
|
|
||||||
|
<img
|
||||||
|
src="https://huggingface.co/datasets/HaomingSong/lerobot-documentation-images/resolve/main/lerobot/eo_pipeline.png"
|
||||||
|
alt="An overview of EO-1"
|
||||||
|
width="85%"
|
||||||
|
/>
|
||||||
|
|
||||||
|
During training, EO-1 learns to denoise continuous action chunks at the action-token positions. During inference, it samples an action chunk, returns continuous actions, and executes `n_action_steps` from the chunk before sampling again.
|
||||||
|
|
||||||
|
### What the LeRobot Integration Covers
|
||||||
|
|
||||||
|
- Standard `policy.type=eo1` configuration through LeRobot
|
||||||
|
- Qwen2.5-VL image and text preprocessing through policy processors
|
||||||
|
- Continuous flow-matching action prediction
|
||||||
|
- Checkpoint save/load through LeRobot policy APIs
|
||||||
|
- Training with `lerobot-train` and evaluation with `lerobot-eval`
|
||||||
|
|
||||||
|
The broader EO-1 project also includes interleaved vision-text-action pretraining and multimodal reasoning workflows. This page focuses on the LeRobot robot-control policy path.
|
||||||
|
|
||||||
|
## Installation Requirements
|
||||||
|
|
||||||
|
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||||
|
2. Install EO-1 dependencies by running:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e ".[eo1]"
|
||||||
|
```
|
||||||
|
|
||||||
|
3. If you want to train or evaluate on LIBERO, install the LIBERO dependencies too:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e ".[eo1,libero]"
|
||||||
|
```
|
||||||
|
|
||||||
|
EO-1 can use the standard PyTorch scaled-dot-product attention backend through `policy.attn_implementation=sdpa`. If your environment has a compatible `flash_attn` installation, you can request `policy.attn_implementation=flash_attention_2`.
|
||||||
|
|
||||||
|
## Data Requirements
|
||||||
|
|
||||||
|
EO-1 expects a LeRobot dataset with:
|
||||||
|
|
||||||
|
- At least one visual observation, for example `observation.images.image`
|
||||||
|
- `observation.state`
|
||||||
|
- `action`
|
||||||
|
- A language task instruction through the dataset `task` field
|
||||||
|
|
||||||
|
If your dataset uses different observation names, use `rename_map` to align them with the names expected by your training or evaluation setup.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use EO-1 in a LeRobot configuration, specify the policy type as:
|
||||||
|
|
||||||
|
```python
|
||||||
|
policy.type=eo1
|
||||||
|
```
|
||||||
|
|
||||||
|
By default, a new EO-1 policy initializes its backbone from:
|
||||||
|
|
||||||
|
```python
|
||||||
|
policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct
|
||||||
|
```
|
||||||
|
|
||||||
|
Once a LeRobot-format EO-1 checkpoint is available, load it with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
policy.path=your-org/your-eo1-checkpoint
|
||||||
|
```
|
||||||
|
|
||||||
|
## Training
|
||||||
|
|
||||||
|
### Training Command Example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--dataset.repo_id=your_org/your_dataset \
|
||||||
|
--policy.type=eo1 \
|
||||||
|
--policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct \
|
||||||
|
--policy.dtype=bfloat16 \
|
||||||
|
--policy.attn_implementation=sdpa \
|
||||||
|
--policy.gradient_checkpointing=false \
|
||||||
|
--output_dir=./outputs/eo1_training \
|
||||||
|
--job_name=eo1_training \
|
||||||
|
--steps=300000 \
|
||||||
|
--batch_size=16 \
|
||||||
|
--policy.device=cuda
|
||||||
|
```
|
||||||
|
|
||||||
|
### Key Training Parameters
|
||||||
|
|
||||||
|
| Parameter | Default | Description |
|
||||||
|
| -------------------------------------- | ----------------------------- | ----------------------------------------------------------------------- |
|
||||||
|
| `policy.vlm_base` | `Qwen/Qwen2.5-VL-3B-Instruct` | Qwen2.5-VL checkpoint used to initialize a new policy |
|
||||||
|
| `policy.dtype` | `auto` | Backbone dtype request: `auto`, `bfloat16`, or `float32` |
|
||||||
|
| `policy.attn_implementation` | `None` | Optional Qwen attention backend, such as `sdpa` |
|
||||||
|
| `policy.gradient_checkpointing` | `false` | Reduces memory usage during training |
|
||||||
|
| `policy.chunk_size` | `8` | Number of future actions predicted per chunk |
|
||||||
|
| `policy.n_action_steps` | `8` | Number of actions consumed from a sampled chunk |
|
||||||
|
| `policy.num_denoise_steps` | `10` | Number of flow-matching denoising steps used during sampling |
|
||||||
|
| `policy.max_state_dim` | `32` | State padding dimension |
|
||||||
|
| `policy.max_action_dim` | `32` | Action padding dimension |
|
||||||
|
| `policy.force_fp32_autocast` | `true` | Keeps the flow head in fp32 even when the backbone uses mixed precision |
|
||||||
|
| `policy.supervise_padding_action_dims` | `true` | Controls whether padded action dimensions are supervised |
|
||||||
|
| `policy.supervise_padding_actions` | `true` | Controls whether padded future action rows are supervised |
|
||||||
|
|
||||||
|
## Evaluation
|
||||||
|
|
||||||
|
EO-1 can be evaluated through `lerobot-eval` once you have a LeRobot-format checkpoint:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-eval \
|
||||||
|
--policy.path=your-org/your-eo1-checkpoint \
|
||||||
|
--env.type=libero \
|
||||||
|
--env.task=libero_object \
|
||||||
|
--eval.batch_size=1 \
|
||||||
|
--eval.n_episodes=20
|
||||||
|
```
|
||||||
|
|
||||||
|
For datasets or environments whose camera names differ from the checkpoint configuration, pass a `rename_map`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-eval \
|
||||||
|
--policy.path=your-org/your-eo1-checkpoint \
|
||||||
|
--env.type=libero \
|
||||||
|
--env.task=libero_object \
|
||||||
|
--rename_map='{"observation.images.image2":"observation.images.wrist_image"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Notes
|
||||||
|
|
||||||
|
### Image Processing
|
||||||
|
|
||||||
|
EO-1 uses the Qwen2.5-VL processor. The `policy.image_min_pixels` and `policy.image_max_pixels` settings control the image resizing bounds before the visual tokens are passed into the backbone.
|
||||||
|
|
||||||
|
### State and Action Dimensions
|
||||||
|
|
||||||
|
The policy pads state and action vectors to `policy.max_state_dim` and `policy.max_action_dim` before the EO-1 flow head. Predictions are cropped back to the original action dimension before being returned by the policy.
|
||||||
|
|
||||||
|
### Attention Backend
|
||||||
|
|
||||||
|
Use `policy.attn_implementation=sdpa` for a portable setup. Use `flash_attention_2` only when `flash_attn` is installed and compatible with your environment.
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- [EO-1 project](https://github.com/EO-Robotics/EO1)
|
||||||
|
- [EO-1 paper](https://arxiv.org/abs/2508.21112)
|
||||||
|
- [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@article{eo1,
|
||||||
|
title={EO-1: Interleaved Vision-Text-Action Pretraining for General Robot Control},
|
||||||
|
author={Delin Qu and Haoming Song and Qizhi Chen and Zhaoqing Chen and Xianqiang Gao and Xinyi Ye and Qi Lv and Modi Shi and Guanghui Ren and Cheng Ruan and Maoqing Yao and Haoran Yang and Jiacheng Bao and Bin Zhao and Dong Wang},
|
||||||
|
journal={arXiv preprint},
|
||||||
|
year={2025},
|
||||||
|
url={https://arxiv.org/abs/2508.21112}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream EO-1 model and dataset pages for the licenses of released EO-1 checkpoints and data.
|
||||||
98
docs/source/hardware_guide.mdx
Normal file
98
docs/source/hardware_guide.mdx
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
# Compute HW Guide for LeRobot Training
|
||||||
|
|
||||||
|
Rough sizing for training a LeRobot policy: how much VRAM each policy needs, what training time looks like, and where to run when local hardware isn't enough.
|
||||||
|
|
||||||
|
The numbers below are **indicative** — order-of-magnitude figures for picking hardware, not exact predictions. Throughput depends heavily on dataset I/O, image resolution, batch size, and number of GPUs.
|
||||||
|
|
||||||
|
## Memory by policy group
|
||||||
|
|
||||||
|
Policies cluster by backbone size; the groupings below give a single VRAM envelope per group instead of repeating numbers per policy. Memory scales roughly linearly with batch size; AdamW (the LeRobot default) carries optimizer state that adds ~30–100% over a forward+backward pass alone.
|
||||||
|
|
||||||
|
| Group | Policies | Peak VRAM (BS 8, AdamW) | Suitable starter GPUs |
|
||||||
|
| ---------- | ------------------------------------------- | ----------------------: | --------------------------------- |
|
||||||
|
| Light BC | `act`, `vqbet`, `tdmpc` | ~2–6GB | Laptop GPU (RTX 3060), L4, A10G |
|
||||||
|
| Diffusion | `diffusion`, `multi_task_dit` | ~8–14GB | RTX 4070+ / L4 / A10G |
|
||||||
|
| Small VLA | `smolvla` | ~10–16GB | RTX 4080+ / L4 / A10G |
|
||||||
|
| Large VLA | `pi0`, `pi0_fast`, `pi05`, `xvla`, `wall_x` | ~24–40GB | A100 40 GB+ (24 GB tight at BS 1) |
|
||||||
|
| Multimodal | `groot`, `eo1` | ~24–40GB | A100 40 GB+ |
|
||||||
|
| RL | `sac` | config-dep. | See [HIL-SERL guide](./hilserl) |
|
||||||
|
|
||||||
|
Memory-bound? Drop the batch size (~linear), use gradient accumulation to recover effective batch, or for SmolVLA leave `freeze_vision_encoder=True`.
|
||||||
|
|
||||||
|
## Training time
|
||||||
|
|
||||||
|
Robotics imitation learning typically converges in **5–10 epochs over the dataset**, not hundreds of thousands of raw steps. Once you know your epoch count, wall-clock is essentially:
|
||||||
|
|
||||||
|
```text
|
||||||
|
total_frames = sum of frames over all episodes # 50 ep × 30 fps × 30 s ≈ 45,000
|
||||||
|
steps_per_epoch = ceil(total_frames / (num_gpus × batch_size))
|
||||||
|
total_steps = epochs × steps_per_epoch
|
||||||
|
wall_clock ≈ total_steps × per_step_time
|
||||||
|
```
|
||||||
|
|
||||||
|
Per-step time depends on the policy and the GPU. The numbers in the table below are anchors — pick the row closest to your setup and scale linearly with `total_steps` if you train longer or shorter.
|
||||||
|
|
||||||
|
### Common scenarios
|
||||||
|
|
||||||
|
Indicative wall-clock for **5 epochs on a ~50-episode dataset (~45k frames at 30 fps × 30 s)**, default optimizer (AdamW), 640×480 images:
|
||||||
|
|
||||||
|
| Setup | Policy | Batch | Wall-clock |
|
||||||
|
| ------------------------------------ | -------------- | ----- | ---------: |
|
||||||
|
| Single RTX 4090 / RTX 3090 (24 GB) | `act` | 8 | ~30–60min |
|
||||||
|
| Single RTX 4090 / RTX 3090 (24 GB) | `diffusion` | 8 | ~2–4h |
|
||||||
|
| Single L4 / A10G (24 GB) | `act` | 8 | ~1–2h |
|
||||||
|
| Single L4 / A10G (24 GB) | `smolvla` | 4 | ~3–6h |
|
||||||
|
| Single A100 40 GB | `smolvla` | 16 | ~1–2h |
|
||||||
|
| Single A100 40 GB | `pi0` / `pi05` | 4 | ~4–8h |
|
||||||
|
| 4× H100 80 GB cluster (`accelerate`) | `diffusion` | 32 | ~30–60min |
|
||||||
|
| 4× H100 80 GB cluster (`accelerate`) | `smolvla` | 32 | ~1–2h |
|
||||||
|
| Apple Silicon M1/M2/M3 Max (MPS) | `act` | 4 | ~6–14h |
|
||||||
|
|
||||||
|
These are order-of-magnitude figures. Real runs deviate by ±50% depending on image resolution, dataset I/O, dataloader threading, and exact GPU SKU. They are useful as "is this run going to take an hour or a day?" intuition, not as SLAs.
|
||||||
|
|
||||||
|
### Multi-GPU matters a lot
|
||||||
|
|
||||||
|
`accelerate launch --num_processes=N` is the easiest way to cut training time. Each optimizer step processes `N × batch_size` samples in roughly the same wall-clock as a single-GPU step, so 4 GPUs ≈ 4× speedup for compute-bound runs. See the [Multi GPU training](./multi_gpu_training) guide for the full setup.
|
||||||
|
|
||||||
|
Reference data points on a 4×H100 80 GB cluster (`accelerate launch --num_processes=4`), 5000 steps, batch 32, AdamW, dataset [`imstevenpmwork/super_poulain_draft`](https://huggingface.co/datasets/imstevenpmwork/super_poulain_draft) (~50 episodes, ~640×480 images):
|
||||||
|
|
||||||
|
| Policy | Wall-clock | `update_s` | `dataloading_s` | GPU util | Notable flags |
|
||||||
|
| ----------- | ---------- | ---------: | --------------: | -------- | ------------------------------------------------------------------------------------------------------------------------------ |
|
||||||
|
| `diffusion` | 16m 17s | 0.167 | 0.015 | ~90% | defaults (training from scratch) |
|
||||||
|
| `smolvla` | 27m 49s | 0.312 | 0.011 | ~80% | `--policy.path=lerobot/smolvla_base`, `freeze_vision_encoder=false`, `train_expert_only=false` |
|
||||||
|
| `pi05` | 3h 41m | 2.548 | 0.014 | ~95% | `--policy.pretrained_path=lerobot/pi05_base`, `gradient_checkpointing=true`, `dtype=bfloat16`, vision encoder + expert trained |
|
||||||
|
|
||||||
|
The `dataloading_s` vs. `update_s` ratio is the diagnostic that matters: when `dataloading_s` approaches `update_s`, more GPUs stop helping — your dataloader is the bottleneck and you should look at `--num_workers`, image resolution, and disk speed before adding compute.
|
||||||
|
|
||||||
|
### Schedule and checkpoints
|
||||||
|
|
||||||
|
If you shorten training (e.g. 5k–10k steps on a small dataset), also shorten the LR schedule with `--policy.scheduler_decay_steps≈--steps`. Otherwise the LR stays near its peak and never decays. Same for `--save_freq`.
|
||||||
|
|
||||||
|
## Where to run
|
||||||
|
|
||||||
|
VRAM is the first filter. Within a tier, pick by budget and availability — the `$`–`$$$$` columns are relative; check current pricing on the provider you actually use.
|
||||||
|
|
||||||
|
| Class | VRAM | Tier | Comfortable for |
|
||||||
|
| -------------------------- | ----- | ------ | ----------------------------------------------------------- |
|
||||||
|
| RTX 3090 / 4090 (consumer) | 24 GB | `$` | Light BC, Diffusion, SmolVLA. Tight for VLAs at batch 1. |
|
||||||
|
| L4 / A10G (cloud) | 24 GB | `$–$$` | Same envelope; common on Google Cloud, RunPod, AWS `g5/g6`. |
|
||||||
|
| A100 40 GB | 40 GB | `$$$` | Any policy at reasonable batch sizes. |
|
||||||
|
| A100 80 GB / H100 80 GB | 80 GB | `$$$$` | Multi-GPU clusters; large batches for VLAs. |
|
||||||
|
| **CPU only** | — | — | Don't train. Use Colab or rent a GPU. |
|
||||||
|
|
||||||
|
### Hugging Face Jobs
|
||||||
|
|
||||||
|
[Hugging Face Jobs](https://huggingface.co/docs/hub/jobs) lets you run training on managed HF infrastructure, billed by the second. The repo publishes a ready-to-use image: **`huggingface/lerobot-gpu:latest`**, rebuilt **every night at 02:00 UTC from `main`** ([`docker_publish.yml`](https://github.com/huggingface/lerobot/blob/main/.github/workflows/docker_publish.yml)) — so it tracks the current state of the repo, not a tagged release.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
hf jobs run --flavor a10g-large huggingface/lerobot-gpu:latest \
|
||||||
|
bash -c "nvidia-smi && lerobot-train \
|
||||||
|
--policy.type=act --dataset.repo_id=<USER>/<DATASET> \
|
||||||
|
--policy.repo_id=<USER>/act_<task> --batch_size=8 --steps=50000"
|
||||||
|
```
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
|
||||||
|
- The leading `nvidia-smi` is a quick sanity check that CUDA is visible inside the container — useful to fail fast if the flavor or driver mismatched.
|
||||||
|
- The default Job timeout is 30 minutes; pass `--timeout 4h` (or longer) for real training.
|
||||||
|
- `--flavor` maps onto the table above: `t4-small`/`t4-medium` (T4, ACT only), `l4x1`/`l4x4` (L4 24 GB), `a10g-small/large/largex2/largex4` (A10G 24 GB scaled out), `a100-large` (A100). For the current full catalogue + pricing see [https://huggingface.co/docs/hub/jobs](https://huggingface.co/docs/hub/jobs).
|
||||||
@@ -50,30 +50,30 @@ This process can be repeated iteratively: deploy, collect, fine-tune, repeat. Ea
|
|||||||
|
|
||||||
### Teleoperator Requirements
|
### Teleoperator Requirements
|
||||||
|
|
||||||
The `examples/hil` HIL scripts require **teleoperators with active motors** that can:
|
The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with active motors** that can:
|
||||||
|
|
||||||
- Enable/disable torque programmatically
|
- Enable/disable torque programmatically
|
||||||
- Move to target positions (to mirror the robot state when pausing)
|
- Move to target positions (to mirror the robot state when pausing)
|
||||||
|
|
||||||
**Compatible teleoperators in the current `examples/hil` scripts:**
|
**Compatible teleoperators:**
|
||||||
|
|
||||||
- `openarm_mini` - OpenArm Mini
|
- `openarm_mini` - OpenArm Mini
|
||||||
- `so_leader` - SO100 / SO101 leader arm
|
- `so_leader` - SO100 / SO101 leader arm
|
||||||
|
|
||||||
> [!IMPORTANT]
|
> [!IMPORTANT]
|
||||||
> The provided `examples/hil` commands default to `bi_openarm_follower` + `openarm_mini`.
|
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
|
||||||
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
|
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Script
|
## Script
|
||||||
|
|
||||||
A single script handles both synchronous and RTC-based inference. Toggle RTC with `--rtc.enabled=true`:
|
Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. Select the inference backend with `--inference.type=sync|rtc`:
|
||||||
|
|
||||||
| Mode | Flag | Models |
|
| Mode | Flag | Models |
|
||||||
| ------------------------ | -------------------- | --------------------- |
|
| ------------------------ | ---------------------- | --------------------- |
|
||||||
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
|
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
|
||||||
| Real-Time Chunking (RTC) | `--rtc.enabled=true` | Pi0, Pi0.5, SmolVLA |
|
| Real-Time Chunking (RTC) | `--inference.type=rtc` | Pi0, Pi0.5, SmolVLA |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -97,7 +97,7 @@ python src/lerobot/scripts/lerobot_train.py \
|
|||||||
**Standard inference (ACT, Diffusion Policy):**
|
**Standard inference (ACT, Diffusion Policy):**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python examples/hil/hil_data_collection.py \
|
lerobot-rollout --strategy.type=dagger \
|
||||||
--robot.type=bi_openarm_follower \
|
--robot.type=bi_openarm_follower \
|
||||||
--robot.left_arm_config.port=can1 \
|
--robot.left_arm_config.port=can1 \
|
||||||
--robot.left_arm_config.side=left \
|
--robot.left_arm_config.side=left \
|
||||||
@@ -108,11 +108,10 @@ python examples/hil/hil_data_collection.py \
|
|||||||
--teleop.port_left=/dev/ttyACM0 \
|
--teleop.port_left=/dev/ttyACM0 \
|
||||||
--teleop.port_right=/dev/ttyACM1 \
|
--teleop.port_right=/dev/ttyACM1 \
|
||||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||||
--dataset.repo_id=your-username/hil-dataset \
|
--dataset.repo_id=your-username/rollout_hil_dataset \
|
||||||
--dataset.single_task="Fold the T-shirt properly" \
|
--dataset.single_task="Fold the T-shirt properly" \
|
||||||
--dataset.fps=30 \
|
--dataset.fps=30 \
|
||||||
--dataset.episode_time_s=1000 \
|
--strategy.num_episodes=50 \
|
||||||
--dataset.num_episodes=50 \
|
|
||||||
--interpolation_multiplier=2
|
--interpolation_multiplier=2
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -121,11 +120,11 @@ python examples/hil/hil_data_collection.py \
|
|||||||
For models with high inference latency, enable RTC for smooth execution:
|
For models with high inference latency, enable RTC for smooth execution:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python examples/hil/hil_data_collection.py \
|
lerobot-rollout --strategy.type=dagger \
|
||||||
--rtc.enabled=true \
|
--inference.type=rtc \
|
||||||
--rtc.execution_horizon=20 \
|
--inference.rtc.execution_horizon=20 \
|
||||||
--rtc.max_guidance_weight=5.0 \
|
--inference.rtc.max_guidance_weight=5.0 \
|
||||||
--rtc.prefix_attention_schedule=LINEAR \
|
--inference.rtc.prefix_attention_schedule=LINEAR \
|
||||||
--robot.type=bi_openarm_follower \
|
--robot.type=bi_openarm_follower \
|
||||||
--robot.left_arm_config.port=can1 \
|
--robot.left_arm_config.port=can1 \
|
||||||
--robot.left_arm_config.side=left \
|
--robot.left_arm_config.side=left \
|
||||||
@@ -136,11 +135,10 @@ python examples/hil/hil_data_collection.py \
|
|||||||
--teleop.port_left=/dev/ttyACM0 \
|
--teleop.port_left=/dev/ttyACM0 \
|
||||||
--teleop.port_right=/dev/ttyACM1 \
|
--teleop.port_right=/dev/ttyACM1 \
|
||||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||||
--dataset.repo_id=your-username/hil-rtc-dataset \
|
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
|
||||||
--dataset.single_task="Fold the T-shirt properly" \
|
--dataset.single_task="Fold the T-shirt properly" \
|
||||||
--dataset.fps=30 \
|
--dataset.fps=30 \
|
||||||
--dataset.episode_time_s=1000 \
|
--strategy.num_episodes=50 \
|
||||||
--dataset.num_episodes=50 \
|
|
||||||
--interpolation_multiplier=3
|
--interpolation_multiplier=3
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -235,7 +233,7 @@ This HIL data collection approach builds on ideas from interactive imitation lea
|
|||||||
|
|
||||||
- **HG-DAgger** (Kelly et al., 2019) made this practical for robotics: a human expert monitors the robot and only intervenes when needed, rather than labeling every state. The gating between autonomous and human control is exactly the pause → takeover → return-to-policy loop used in the scripts here.
|
- **HG-DAgger** (Kelly et al., 2019) made this practical for robotics: a human expert monitors the robot and only intervenes when needed, rather than labeling every state. The gating between autonomous and human control is exactly the pause → takeover → return-to-policy loop used in the scripts here.
|
||||||
|
|
||||||
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the HIL scripts in `examples/hil`.
|
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the DAgger strategy in `lerobot-rollout`.
|
||||||
|
|
||||||
- **π0.6/RECAP** (Physical Intelligence, 2025) applies the same iterative collect-and-finetune loop at scale with VLA models, showing that even large pretrained policies benefit substantially from targeted human corrections on their own failure modes. π0.6 is trained using RECAP.
|
- **π0.6/RECAP** (Physical Intelligence, 2025) applies the same iterative collect-and-finetune loop at scale with VLA models, showing that even large pretrained policies benefit substantially from targeted human corrections on their own failure modes. π0.6 is trained using RECAP.
|
||||||
|
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ pip install -e ".[hilserl]"
|
|||||||
|
|
||||||
### Understanding Configuration
|
### Understanding Configuration
|
||||||
|
|
||||||
The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
|
The training process begins with proper configuration for the HILSERl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` (defined in `lerobot/envs/configs.py`) and `DatasetConfig`. The configuration is organized into focused, nested sub-configs:
|
||||||
|
|
||||||
<!-- prettier-ignore-start -->
|
<!-- prettier-ignore-start -->
|
||||||
```python
|
```python
|
||||||
@@ -95,6 +95,7 @@ class HILSerlProcessorConfig:
|
|||||||
class ObservationConfig:
|
class ObservationConfig:
|
||||||
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
|
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
|
||||||
add_current_to_observation: bool = False # Add motor currents to state
|
add_current_to_observation: bool = False # Add motor currents to state
|
||||||
|
add_ee_pose_to_observation: bool = False # Add end-effector pose to state
|
||||||
display_cameras: bool = False # Display camera feeds during execution
|
display_cameras: bool = False # Display camera feeds during execution
|
||||||
|
|
||||||
class ImagePreprocessingConfig:
|
class ImagePreprocessingConfig:
|
||||||
@@ -326,14 +327,22 @@ lerobot-find-joint-limits \
|
|||||||
Max joint positions [-20.0, -20.0, -20.0, -20.0, -20.0, -20.0]
|
Max joint positions [-20.0, -20.0, -20.0, -20.0, -20.0, -20.0]
|
||||||
Min joint positions [50.0, 50.0, 50.0, 50.0, 50.0, 50.0]
|
Min joint positions [50.0, 50.0, 50.0, 50.0, 50.0, 50.0]
|
||||||
```
|
```
|
||||||
3. Use these values in the configuration of your teleoperation device (TeleoperatorConfig) under the `end_effector_bounds` field
|
3. Use these values in your environment configuration under `env.processor.inverse_kinematics.end_effector_bounds` (see `InverseKinematicsConfig` in `lerobot/envs/configs.py`)
|
||||||
|
|
||||||
**Example Configuration**
|
**Example Configuration**
|
||||||
|
|
||||||
```json
|
```json
|
||||||
"end_effector_bounds": {
|
{
|
||||||
"max": [0.24, 0.20, 0.10],
|
"env": {
|
||||||
"min": [0.16, -0.08, 0.03]
|
"processor": {
|
||||||
|
"inverse_kinematics": {
|
||||||
|
"end_effector_bounds": {
|
||||||
|
"max": [0.24, 0.2, 0.1],
|
||||||
|
"min": [0.16, -0.08, 0.03]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -404,30 +413,24 @@ We support using a gamepad or a keyboard or the leader arm of the robot.
|
|||||||
|
|
||||||
HIL-Serl learns actions in the end-effector space of the robot. Therefore, the teleoperation will control the end-effector's x,y,z displacements.
|
HIL-Serl learns actions in the end-effector space of the robot. Therefore, the teleoperation will control the end-effector's x,y,z displacements.
|
||||||
|
|
||||||
For that we need to define a version of the robot that takes actions in the end-effector space. Check the robot class `SO100FollowerEndEffector` and its configuration `SO100FollowerEndEffectorConfig` for the default parameters related to the end-effector space.
|
The end-effector transformation is applied by the processor pipeline (`InverseKinematicsRLStep`, `EEBoundsAndSafety`, `EEReferenceAndDelta`, `GripperVelocityToJoint`) configured under `env.processor.inverse_kinematics` (`InverseKinematicsConfig`) and `env.processor.gripper` / `env.processor.max_gripper_pos`. The defaults related to the end-effector space are:
|
||||||
|
|
||||||
<!-- prettier-ignore-start -->
|
<!-- prettier-ignore-start -->
|
||||||
```python
|
```python
|
||||||
class SO100FollowerEndEffectorConfig(SO100FollowerConfig):
|
class InverseKinematicsConfig:
|
||||||
"""Configuration for the SO100FollowerEndEffector robot."""
|
"""Configuration for inverse kinematics processing."""
|
||||||
|
|
||||||
# Default bounds for the end-effector position (in meters)
|
urdf_path: str | None = None
|
||||||
end_effector_bounds: dict[str, list[float]] = field( # bounds for the end-effector in x,y,z direction
|
target_frame_name: str | None = None
|
||||||
default_factory=lambda: {
|
# bounds for the end-effector in x,y,z direction
|
||||||
"min": [-1.0, -1.0, -1.0], # min x, y, z
|
end_effector_bounds: dict[str, list[float]] | None = None
|
||||||
"max": [1.0, 1.0, 1.0], # max x, y, z
|
# maximum step size for the end-effector in x,y,z direction
|
||||||
}
|
end_effector_step_sizes: dict[str, float] | None = None
|
||||||
)
|
|
||||||
|
|
||||||
max_gripper_pos: float = 50 # maximum gripper position that the gripper will be open at
|
class HILSerlProcessorConfig:
|
||||||
|
...
|
||||||
end_effector_step_sizes: dict[str, float] = field( # maximum step size for the end-effector in x,y,z direction
|
# maximum gripper position that the gripper will be open at
|
||||||
default_factory=lambda: {
|
max_gripper_pos: float | None = 100.0
|
||||||
"x": 0.02,
|
|
||||||
"y": 0.02,
|
|
||||||
"z": 0.02,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
```
|
```
|
||||||
<!-- prettier-ignore-end -->
|
<!-- prettier-ignore-end -->
|
||||||
|
|
||||||
@@ -606,11 +609,11 @@ This guide explains how to train a reward classifier for human-in-the-loop reinf
|
|||||||
|
|
||||||
**Note**: Training a reward classifier is optional. You can start the first round of RL experiments by annotating the success manually with your gamepad or keyboard device.
|
**Note**: Training a reward classifier is optional. You can start the first round of RL experiments by annotating the success manually with your gamepad or keyboard device.
|
||||||
|
|
||||||
The reward classifier implementation in `modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings.
|
The reward classifier implementation in `lerobot/rewards/classifier/modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings.
|
||||||
|
|
||||||
**Collecting a Dataset for the reward classifier**
|
**Collecting a Dataset for the reward classifier**
|
||||||
|
|
||||||
Before training, you need to collect a dataset with labeled examples. The `record_dataset` function in `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards.
|
Before training, you need to collect a dataset with labeled examples. Setting `mode: "record"` in your config and running `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards.
|
||||||
|
|
||||||
To collect a dataset, you need to modify some parameters in the environment configuration based on HILSerlRobotEnvConfig.
|
To collect a dataset, you need to modify some parameters in the environment configuration based on HILSerlRobotEnvConfig.
|
||||||
|
|
||||||
@@ -658,7 +661,7 @@ Example configuration section for data collection:
|
|||||||
},
|
},
|
||||||
"dataset": {
|
"dataset": {
|
||||||
"repo_id": "hf_username/dataset_name",
|
"repo_id": "hf_username/dataset_name",
|
||||||
"dataset_root": "data/your_dataset",
|
"root": "data/your_dataset",
|
||||||
"task": "reward_classifier_task",
|
"task": "reward_classifier_task",
|
||||||
"num_episodes_to_record": 20,
|
"num_episodes_to_record": 20,
|
||||||
"replay_episode": null,
|
"replay_episode": null,
|
||||||
@@ -671,7 +674,7 @@ Example configuration section for data collection:
|
|||||||
|
|
||||||
**Reward Classifier Configuration**
|
**Reward Classifier Configuration**
|
||||||
|
|
||||||
The reward classifier is configured using `configuration_classifier.py`. Here are the key parameters:
|
The reward classifier is configured using `lerobot/rewards/classifier/configuration_classifier.py`. Here are the key parameters:
|
||||||
|
|
||||||
- **model_name**: Base model architecture (e.g., we mainly use `"helper2424/resnet10"`)
|
- **model_name**: Base model architecture (e.g., we mainly use `"helper2424/resnet10"`)
|
||||||
- **model_type**: `"cnn"` or `"transformer"`
|
- **model_type**: `"cnn"` or `"transformer"`
|
||||||
@@ -689,7 +692,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
|
|||||||
"repo_id": "hf_username/dataset_name",
|
"repo_id": "hf_username/dataset_name",
|
||||||
"root": null
|
"root": null
|
||||||
},
|
},
|
||||||
"policy": {
|
"reward_model": {
|
||||||
"type": "reward_classifier",
|
"type": "reward_classifier",
|
||||||
"model_name": "helper2424/resnet10",
|
"model_name": "helper2424/resnet10",
|
||||||
"model_type": "cnn",
|
"model_type": "cnn",
|
||||||
@@ -699,7 +702,6 @@ Example configuration for training the [reward classifier](https://huggingface.c
|
|||||||
"dropout_rate": 0.1,
|
"dropout_rate": 0.1,
|
||||||
"learning_rate": 1e-4,
|
"learning_rate": 1e-4,
|
||||||
"device": "cuda",
|
"device": "cuda",
|
||||||
"use_amp": true,
|
|
||||||
"input_features": {
|
"input_features": {
|
||||||
"observation.images.front": {
|
"observation.images.front": {
|
||||||
"type": "VISUAL",
|
"type": "VISUAL",
|
||||||
@@ -818,13 +820,14 @@ The LeRobot system uses a distributed actor-learner architecture for training. T
|
|||||||
|
|
||||||
**Configuration Setup**
|
**Configuration Setup**
|
||||||
|
|
||||||
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`.
|
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/rl/train_rl.py`.
|
||||||
|
|
||||||
1. Configure the policy settings (`type="sac"`, `device`, etc.)
|
1. Configure the policy settings (`type="gaussian_actor"`, `device`, etc.)
|
||||||
2. Set `dataset` to your cropped dataset
|
2. Configure the algorithm settings under the top-level `algorithm` block (`type="sac"`, learning rates, discount, etc., defined in `lerobot/rl/algorithms/sac/configuration_sac.py`).
|
||||||
3. Configure environment settings with crop parameters
|
3. Set `dataset` to your cropped dataset
|
||||||
4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/sac/configuration_sac.py#L79).
|
4. Configure environment settings with crop parameters
|
||||||
5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
|
5. Check the other parameters related to the Gaussian Actor in [configuration_gaussian_actor.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py#L79).
|
||||||
|
6. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
|
||||||
|
|
||||||
**Starting the Learner**
|
**Starting the Learner**
|
||||||
|
|
||||||
@@ -926,7 +929,7 @@ The ideal behaviour is that your intervention rate should drop gradually during
|
|||||||
|
|
||||||
Some configuration values have a disproportionate impact on training stability and speed:
|
Some configuration values have a disproportionate impact on training stability and speed:
|
||||||
|
|
||||||
- **`temperature_init`** (`policy.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
|
- **`temperature_init`** (`algorithm.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
|
||||||
- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency.
|
- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency.
|
||||||
- **`storage_device`** (`policy.storage_device`) – device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second.
|
- **`storage_device`** (`policy.storage_device`) – device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second.
|
||||||
|
|
||||||
|
|||||||
@@ -509,121 +509,42 @@ hf upload ${HF_USER}/act_so101_test${CKPT} \
|
|||||||
|
|
||||||
## Run inference and evaluate your policy
|
## Run inference and evaluate your policy
|
||||||
|
|
||||||
You can use the `record` script from [`lerobot-record`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
|
Use `lerobot-rollout` to deploy a trained policy on your robot. You can choose different strategies depending on your needs:
|
||||||
|
|
||||||
<hfoptions id="eval">
|
<hfoptions id="eval">
|
||||||
<hfoption id="Command">
|
<hfoption id="Base mode (no recording)">
|
||||||
```bash
|
```bash
|
||||||
lerobot-record \
|
lerobot-rollout \
|
||||||
|
--strategy.type=base \
|
||||||
|
--policy.path=${HF_USER}/my_policy \
|
||||||
--robot.type=so100_follower \
|
--robot.type=so100_follower \
|
||||||
--robot.port=/dev/ttyACM1 \
|
--robot.port=/dev/ttyACM1 \
|
||||||
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
|
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
|
||||||
--robot.id=my_awesome_follower_arm \
|
--task="Put lego brick into the transparent box" \
|
||||||
--display_data=false \
|
--duration=60
|
||||||
--dataset.repo_id=${HF_USER}/eval_so100 \
|
|
||||||
--dataset.single_task="Put lego brick into the transparent box" \
|
|
||||||
--dataset.streaming_encoding=true \
|
|
||||||
--dataset.encoder_threads=2 \
|
|
||||||
# --dataset.vcodec=auto \
|
|
||||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
|
||||||
# --teleop.type=so100_leader \
|
|
||||||
# --teleop.port=/dev/ttyACM0 \
|
|
||||||
# --teleop.id=my_awesome_leader_arm \
|
|
||||||
--policy.path=${HF_USER}/my_policy
|
|
||||||
```
|
```
|
||||||
</hfoption>
|
</hfoption>
|
||||||
<hfoption id="API example">
|
<hfoption id="Sentry mode (with recording)">
|
||||||
|
```bash
|
||||||
<!-- prettier-ignore-start -->
|
lerobot-rollout \
|
||||||
```python
|
--strategy.type=sentry \
|
||||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
--strategy.upload_every_n_episodes=5 \
|
||||||
from lerobot.datasets import LeRobotDataset
|
--policy.path=${HF_USER}/my_policy \
|
||||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
--robot.type=so100_follower \
|
||||||
from lerobot.policies.act import ACTPolicy
|
--robot.port=/dev/ttyACM1 \
|
||||||
from lerobot.policies import make_pre_post_processors
|
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
|
||||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
--dataset.repo_id=${HF_USER}/eval_so100 \
|
||||||
from lerobot.scripts.lerobot_record import record_loop
|
--dataset.single_task="Put lego brick into the transparent box" \
|
||||||
from lerobot.common.control_utils import init_keyboard_listener
|
--duration=600
|
||||||
from lerobot.utils.utils import log_say
|
|
||||||
from lerobot.utils.visualization_utils import init_rerun
|
|
||||||
|
|
||||||
|
|
||||||
NUM_EPISODES = 5
|
|
||||||
FPS = 30
|
|
||||||
EPISODE_TIME_SEC = 60
|
|
||||||
TASK_DESCRIPTION = "My task description"
|
|
||||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
|
||||||
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
|
||||||
|
|
||||||
# Create the robot configuration
|
|
||||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
|
||||||
robot_config = SO100FollowerConfig(
|
|
||||||
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize the robot
|
|
||||||
robot = SO100Follower(robot_config)
|
|
||||||
|
|
||||||
# Initialize the policy
|
|
||||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
|
||||||
|
|
||||||
# Configure the dataset features
|
|
||||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
|
||||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
|
||||||
dataset_features = {**action_features, **obs_features}
|
|
||||||
|
|
||||||
# Create the dataset
|
|
||||||
dataset = LeRobotDataset.create(
|
|
||||||
repo_id=HF_DATASET_ID,
|
|
||||||
fps=FPS,
|
|
||||||
features=dataset_features,
|
|
||||||
robot_type=robot.name,
|
|
||||||
use_videos=True,
|
|
||||||
image_writer_threads=4,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize the keyboard listener and rerun visualization
|
|
||||||
_, events = init_keyboard_listener()
|
|
||||||
init_rerun(session_name="recording")
|
|
||||||
|
|
||||||
# Connect the robot
|
|
||||||
robot.connect()
|
|
||||||
|
|
||||||
preprocessor, postprocessor = make_pre_post_processors(
|
|
||||||
policy_cfg=policy,
|
|
||||||
pretrained_path=HF_MODEL_ID,
|
|
||||||
dataset_stats=dataset.meta.stats,
|
|
||||||
)
|
|
||||||
|
|
||||||
for episode_idx in range(NUM_EPISODES):
|
|
||||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
|
||||||
|
|
||||||
# Run the policy inference loop
|
|
||||||
record_loop(
|
|
||||||
robot=robot,
|
|
||||||
events=events,
|
|
||||||
fps=FPS,
|
|
||||||
policy=policy,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
postprocessor=postprocessor,
|
|
||||||
dataset=dataset,
|
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
|
||||||
single_task=TASK_DESCRIPTION,
|
|
||||||
display_data=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset.save_episode()
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
robot.disconnect()
|
|
||||||
dataset.push_to_hub()
|
|
||||||
```
|
```
|
||||||
<!-- prettier-ignore-end -->
|
|
||||||
|
|
||||||
</hfoption>
|
</hfoption>
|
||||||
</hfoptions>
|
</hfoptions>
|
||||||
|
|
||||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
The `--strategy.type` flag selects the execution mode:
|
||||||
|
|
||||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so101_test`).
|
- `base`: Autonomous rollout with no data recording (useful for quick evaluation)
|
||||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so101_test`).
|
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
||||||
|
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
||||||
|
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
||||||
|
|
||||||
|
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||||
|
|||||||
261
docs/source/inference.mdx
Normal file
261
docs/source/inference.mdx
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
# Policy Deployment (lerobot-rollout)
|
||||||
|
|
||||||
|
`lerobot-rollout` is the single CLI for deploying trained policies on real robots. It supports multiple execution strategies and inference backends, from quick evaluation to continuous recording and human-in-the-loop data collection.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
No extra dependencies are needed beyond your robot and policy extras.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=base \
|
||||||
|
--policy.path=lerobot/act_koch_real \
|
||||||
|
--robot.type=koch_follower \
|
||||||
|
--robot.port=/dev/ttyACM0 \
|
||||||
|
--task="pick up cube" \
|
||||||
|
--duration=30
|
||||||
|
```
|
||||||
|
|
||||||
|
This runs the policy for 30 seconds with no recording.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Strategies
|
||||||
|
|
||||||
|
Select a strategy with `--strategy.type=<name>`. Each strategy defines a different control loop with its own recording and interaction semantics.
|
||||||
|
|
||||||
|
### Base (`--strategy.type=base`)
|
||||||
|
|
||||||
|
Autonomous policy execution with no data recording. Use this for quick evaluation, demos, or when you only need to observe the robot.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=base \
|
||||||
|
--policy.path=${HF_USER}/my_policy \
|
||||||
|
--robot.type=so100_follower \
|
||||||
|
--robot.port=/dev/ttyACM0 \
|
||||||
|
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
|
--task="Put lego brick into the box" \
|
||||||
|
--duration=60
|
||||||
|
```
|
||||||
|
|
||||||
|
| Flag | Description |
|
||||||
|
| ---------------- | ------------------------------------------------------ |
|
||||||
|
| `--duration` | Run time in seconds (0 = infinite) |
|
||||||
|
| `--task` | Task description passed to the policy |
|
||||||
|
| `--display_data` | Stream observations/actions to Rerun for visualization |
|
||||||
|
|
||||||
|
### Sentry (`--strategy.type=sentry`)
|
||||||
|
|
||||||
|
Continuous autonomous recording with periodic upload to the Hugging Face Hub. Episode boundaries are auto-computed from camera resolution and FPS so each saved episode produces a complete video file, keeping uploads efficient.
|
||||||
|
|
||||||
|
Policy state (hidden state, RTC queue) persists across episode boundaries: the robot does not reset between episodes.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=sentry \
|
||||||
|
--strategy.upload_every_n_episodes=5 \
|
||||||
|
--policy.path=${HF_USER}/my_policy \
|
||||||
|
--robot.type=so100_follower \
|
||||||
|
--robot.port=/dev/ttyACM0 \
|
||||||
|
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
|
--dataset.repo_id=${HF_USER}/rollout_eval_data \
|
||||||
|
--dataset.single_task="Put lego brick into the box" \
|
||||||
|
--duration=3600
|
||||||
|
```
|
||||||
|
|
||||||
|
| Flag | Description |
|
||||||
|
| -------------------------------------- | ----------------------------------------------------------- |
|
||||||
|
| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) |
|
||||||
|
| `--strategy.target_video_file_size_mb` | Target video file size for episode rotation (default: auto) |
|
||||||
|
| `--dataset.repo_id` | **Required.** Hub repository for the recorded dataset |
|
||||||
|
| `--dataset.push_to_hub` | Whether to push to Hub on teardown (default: true) |
|
||||||
|
|
||||||
|
### Highlight (`--strategy.type=highlight`)
|
||||||
|
|
||||||
|
Autonomous rollout with on-demand recording via a memory-bounded ring buffer. The robot runs continuously while the buffer captures the last N seconds of telemetry. Press the save key to flush the buffer and start live recording; press it again to save the episode.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=highlight \
|
||||||
|
--strategy.ring_buffer_seconds=30 \
|
||||||
|
--strategy.save_key=s \
|
||||||
|
--strategy.push_key=h \
|
||||||
|
--policy.path=${HF_USER}/my_policy \
|
||||||
|
--robot.type=koch_follower \
|
||||||
|
--robot.port=/dev/ttyACM0 \
|
||||||
|
--dataset.repo_id=${HF_USER}/rollout_highlight_data \
|
||||||
|
--dataset.single_task="Pick up the red cube"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Keyboard controls:**
|
||||||
|
|
||||||
|
| Key | Action |
|
||||||
|
| ------------------ | -------------------------------------------------------- |
|
||||||
|
| `s` (configurable) | Start recording (flushes buffer) / stop and save episode |
|
||||||
|
| `h` (configurable) | Push dataset to Hub |
|
||||||
|
| `ESC` | Stop the session |
|
||||||
|
|
||||||
|
| Flag | Description |
|
||||||
|
| -------------------------------------- | ---------------------------------------------- |
|
||||||
|
| `--strategy.ring_buffer_seconds` | Duration of buffered telemetry (default: 30) |
|
||||||
|
| `--strategy.ring_buffer_max_memory_mb` | Memory cap for the ring buffer (default: 2048) |
|
||||||
|
| `--strategy.save_key` | Key to toggle recording (default: `s`) |
|
||||||
|
| `--strategy.push_key` | Key to push to Hub (default: `h`) |
|
||||||
|
|
||||||
|
### DAgger (`--strategy.type=dagger`)
|
||||||
|
|
||||||
|
Human-in-the-loop data collection. Alternates between autonomous policy execution and human intervention via a teleoperator. Intervention frames are tagged with `intervention=True`. Requires a teleoperator (`--teleop.type`).
|
||||||
|
|
||||||
|
See the [Human-In-the-Loop Data Collection](./hil_data_collection) guide for a detailed walkthrough.
|
||||||
|
|
||||||
|
**Corrections-only mode** (default): Only human correction windows are recorded. Each correction becomes one episode.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=dagger \
|
||||||
|
--strategy.num_episodes=20 \
|
||||||
|
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||||
|
--robot.type=bi_openarm_follower \
|
||||||
|
--teleop.type=openarm_mini \
|
||||||
|
--dataset.repo_id=${HF_USER}/rollout_hil_data \
|
||||||
|
--dataset.single_task="Fold the T-shirt"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Continuous recording mode** (`--strategy.record_autonomous=true`): Both autonomous and correction frames are recorded with time-based episode rotation (same as Sentry).
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=dagger \
|
||||||
|
--strategy.record_autonomous=true \
|
||||||
|
--strategy.num_episodes=50 \
|
||||||
|
--policy.path=${HF_USER}/my_policy \
|
||||||
|
--robot.type=so100_follower \
|
||||||
|
--robot.port=/dev/ttyACM0 \
|
||||||
|
--teleop.type=so101_leader \
|
||||||
|
--teleop.port=/dev/ttyACM1 \
|
||||||
|
--dataset.repo_id=${HF_USER}/rollout_dagger_data \
|
||||||
|
--dataset.single_task="Grasp the block"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Keyboard controls** (default input device):
|
||||||
|
|
||||||
|
| Key | Action |
|
||||||
|
| ------- | ------------------------------------------- |
|
||||||
|
| `Space` | Pause / resume policy execution |
|
||||||
|
| `Tab` | Start / stop human correction |
|
||||||
|
| `Enter` | Push dataset to Hub (corrections-only mode) |
|
||||||
|
| `ESC` | Stop the session |
|
||||||
|
|
||||||
|
Foot pedal input is also supported via `--strategy.input_device=pedal`. Configure pedal codes with `--strategy.pedal.*` flags.
|
||||||
|
|
||||||
|
| Flag | Description |
|
||||||
|
| ------------------------------------ | ------------------------------------------------------- |
|
||||||
|
| `--strategy.num_episodes` | Number of correction episodes to record (default: 10) |
|
||||||
|
| `--strategy.record_autonomous` | Record autonomous frames too (default: false) |
|
||||||
|
| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) |
|
||||||
|
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
||||||
|
| `--teleop.type` | **Required.** Teleoperator type |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Inference Backends
|
||||||
|
|
||||||
|
Select a backend with `--inference.type=<name>`. All strategies work with both backends.
|
||||||
|
|
||||||
|
### Sync (default)
|
||||||
|
|
||||||
|
One policy call per control tick. The main loop blocks until the action is computed.
|
||||||
|
|
||||||
|
Works with all policies. No extra flags needed.
|
||||||
|
|
||||||
|
### Real-Time Chunking (`--inference.type=rtc`)
|
||||||
|
|
||||||
|
A background thread produces action chunks asynchronously. The main control loop polls for the next ready action while the policy computes the next chunk in parallel.
|
||||||
|
|
||||||
|
Use RTC with large, slow VLA models (Pi0, Pi0.5, SmolVLA) for smooth, continuous motion despite high inference latency.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=base \
|
||||||
|
--inference.type=rtc \
|
||||||
|
--inference.rtc.execution_horizon=10 \
|
||||||
|
--inference.rtc.max_guidance_weight=10.0 \
|
||||||
|
--policy.path=${HF_USER}/pi0_policy \
|
||||||
|
--robot.type=so100_follower \
|
||||||
|
--robot.port=/dev/ttyACM0 \
|
||||||
|
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
|
--task="Pick up the cube" \
|
||||||
|
--duration=60 \
|
||||||
|
--device=cuda
|
||||||
|
```
|
||||||
|
|
||||||
|
| Flag | Description |
|
||||||
|
| ------------------------------------------- | -------------------------------------------------------------- |
|
||||||
|
| `--inference.rtc.execution_horizon` | Steps to blend with previous chunk (default: varies by policy) |
|
||||||
|
| `--inference.rtc.max_guidance_weight` | Consistency enforcement strength (default: varies by policy) |
|
||||||
|
| `--inference.rtc.prefix_attention_schedule` | Blend schedule: `LINEAR`, `EXP`, `ONES`, `ZEROS` |
|
||||||
|
| `--inference.queue_threshold` | Max queue size before backpressure (default: 30) |
|
||||||
|
|
||||||
|
See the [Real-Time Chunking](./rtc) guide for details on tuning RTC parameters.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Common Flags
|
||||||
|
|
||||||
|
| Flag | Description | Default |
|
||||||
|
| --------------------------------- | ----------------------------------------------------------------- | ------- |
|
||||||
|
| `--policy.path` | **Required.** HF Hub model ID or local checkpoint path | -- |
|
||||||
|
| `--robot.type` | **Required.** Robot type (e.g. `so100_follower`, `koch_follower`) | -- |
|
||||||
|
| `--robot.port` | Serial port for the robot | -- |
|
||||||
|
| `--robot.cameras` | Camera configuration (JSON dict) | -- |
|
||||||
|
| `--fps` | Control loop frequency | 30 |
|
||||||
|
| `--duration` | Run time in seconds (0 = infinite) | 0 |
|
||||||
|
| `--device` | Torch device (`cpu`, `cuda`, `mps`) | auto |
|
||||||
|
| `--task` | Task description (used when no dataset is provided) | -- |
|
||||||
|
| `--display_data` | Stream telemetry to Rerun visualization | false |
|
||||||
|
| `--display_ip` / `--display_port` | Remote Rerun server address | -- |
|
||||||
|
| `--interpolation_multiplier` | Action interpolation factor | 1 |
|
||||||
|
| `--use_torch_compile` | Enable `torch.compile` for inference | false |
|
||||||
|
| `--resume` | Resume a previous recording session | false |
|
||||||
|
| `--play_sounds` | Vocal synthesis for events | true |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Programmatic Usage
|
||||||
|
|
||||||
|
For custom deployments (e.g. with kinematics processors), use the rollout module API directly:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||||
|
from lerobot.rollout.inference import SyncInferenceConfig
|
||||||
|
from lerobot.rollout.strategies import BaseStrategy
|
||||||
|
from lerobot.utils.process import ProcessSignalHandler
|
||||||
|
|
||||||
|
cfg = RolloutConfig(
|
||||||
|
robot=my_robot_config,
|
||||||
|
policy=my_policy_config,
|
||||||
|
strategy=BaseStrategyConfig(),
|
||||||
|
inference=SyncInferenceConfig(),
|
||||||
|
fps=30,
|
||||||
|
duration=60,
|
||||||
|
task="my task",
|
||||||
|
)
|
||||||
|
|
||||||
|
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||||
|
ctx = build_rollout_context(
|
||||||
|
cfg,
|
||||||
|
signal_handler.shutdown_event,
|
||||||
|
robot_action_processor=my_custom_action_processor, # optional
|
||||||
|
robot_observation_processor=my_custom_obs_processor, # optional
|
||||||
|
)
|
||||||
|
|
||||||
|
strategy = BaseStrategy(cfg.strategy)
|
||||||
|
try:
|
||||||
|
strategy.setup(ctx)
|
||||||
|
strategy.run(ctx)
|
||||||
|
finally:
|
||||||
|
strategy.teardown(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
See `examples/so100_to_so100_EE/rollout.py` and `examples/phone_to_so100/rollout.py` for full examples with kinematics processors.
|
||||||
@@ -207,6 +207,56 @@ pip install 'lerobot[feetech]' # Feetech motor support
|
|||||||
|
|
||||||
_Multiple extras can be combined (e.g., `.[core_scripts,pi,pusht]`). For a full list of available extras, refer to `pyproject.toml`._
|
_Multiple extras can be combined (e.g., `.[core_scripts,pi,pusht]`). For a full list of available extras, refer to `pyproject.toml`._
|
||||||
|
|
||||||
|
### PyTorch CUDA variant (Linux only)
|
||||||
|
|
||||||
|
On Linux, the install path determines which CUDA wheel you get. macOS and Windows installs use the PyPI default (MPS / CPU / CUDA-Windows wheel respectively) and can skip this section.
|
||||||
|
|
||||||
|
<!-- prettier-ignore-start -->
|
||||||
|
|
||||||
|
<hfoptions id="cuda_variant">
|
||||||
|
<hfoption id="uv-source">
|
||||||
|
|
||||||
|
**Source install via `uv` (`uv sync` or `uv pip install -e .`)**
|
||||||
|
|
||||||
|
`torch` and `torchvision` are pinned by the project to the **CUDA 12.8** PyTorch index (`https://download.pytorch.org/whl/cu128`, driver floor **570.86**) — covers Ampere/Ada/Hopper/Blackwell GPUs. No action needed for typical NVIDIA setups.
|
||||||
|
|
||||||
|
To override for a different CUDA variant:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install --force-reinstall torch torchvision \
|
||||||
|
--index-url https://download.pytorch.org/whl/cu126 # older drivers; or cu130 for Blackwell on driver ≥ 580
|
||||||
|
```
|
||||||
|
|
||||||
|
</hfoption>
|
||||||
|
<hfoption id="pip-conda">
|
||||||
|
|
||||||
|
**Source install via `pip`/`conda`, or `pip install lerobot` from PyPI**
|
||||||
|
|
||||||
|
PyPI default torch wheel is currently a cu130-bundled Linux wheel, driver floor **580.65**.
|
||||||
|
|
||||||
|
To pick a specific CUDA variant:
|
||||||
|
|
||||||
|
**Using `pip` or `conda`** — install torch first with an explicit index, then lerobot:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install --index-url https://download.pytorch.org/whl/cu128 torch torchvision
|
||||||
|
pip install -e ".[all]" # source
|
||||||
|
# — or —
|
||||||
|
pip install lerobot # from PyPI
|
||||||
|
```
|
||||||
|
|
||||||
|
**Using `uv` to install from PyPI** — one-liner via `--torch-backend` (uv ≥ 0.6):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install --torch-backend cu128 lerobot
|
||||||
|
```
|
||||||
|
|
||||||
|
Supported values include `auto`, `cpu`, `cu126`, `cu128`, `cu129`, `cu130`, plus various `rocm*` and `xpu`. Swap as needed for your driver.
|
||||||
|
|
||||||
|
</hfoption>
|
||||||
|
</hfoptions>
|
||||||
|
<!-- prettier-ignore-end -->
|
||||||
|
|
||||||
### Troubleshooting
|
### Troubleshooting
|
||||||
|
|
||||||
If you encounter build errors, you may need to install additional system dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
If you encounter build errors, you may need to install additional system dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
||||||
|
|||||||
@@ -28,13 +28,15 @@ lerobot-train \
|
|||||||
--steps=100000 \
|
--steps=100000 \
|
||||||
--batch_size=32 \
|
--batch_size=32 \
|
||||||
--peft.method_type=LORA \
|
--peft.method_type=LORA \
|
||||||
--peft.r=64
|
--peft.r=64 \
|
||||||
|
--peft.lora_alpha=64
|
||||||
```
|
```
|
||||||
|
|
||||||
Note the `--peft.method_type` parameter that let's you select which PEFT method to use. Here we use
|
Note the `--peft.method_type` parameter that let's you select which PEFT method to use. Here we use
|
||||||
[LoRA](https://huggingface.co/docs/peft/main/en/package_reference/lora) (Low-Rank Adapter) which is probably the most
|
[LoRA](https://huggingface.co/docs/peft/main/en/package_reference/lora) (Low-Rank Adapter) which is probably the most
|
||||||
popular fine-tuning method to date. Low-rank adaption means that we only fine-tune a matrix with comparably low rank
|
popular fine-tuning method to date. Low-rank adaption means that we only fine-tune a matrix with comparably low rank
|
||||||
instead of the full weight matrix. This rank can be specified using the `--peft.r` parameter. The higher the rank
|
instead of the full weight matrix. This rank can be specified using the `--peft.r` parameter, and the LoRA scaling factor with
|
||||||
|
`--peft.lora_alpha` (where `scaling = lora_alpha / r`). The higher the rank
|
||||||
the closer you get to full fine-tuning
|
the closer you get to full fine-tuning
|
||||||
|
|
||||||
There are more complex methods that have more parameters. These are not yet supported, feel free to raise an issue
|
There are more complex methods that have more parameters. These are not yet supported, feel free to raise an issue
|
||||||
|
|||||||
@@ -61,17 +61,6 @@ lerobot-eval \
|
|||||||
--rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
|
--rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
|
||||||
```
|
```
|
||||||
|
|
||||||
### Recording
|
|
||||||
|
|
||||||
`lerobot-record` also supports rename maps, nested under the dataset config:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
lerobot-record \ # When running inference
|
|
||||||
--policy.path="<user>/smolVLA_finetuned" \
|
|
||||||
... \
|
|
||||||
--dataset.rename_map='{"observation.images.glove2": "observation.images.image"}'
|
|
||||||
```
|
|
||||||
|
|
||||||
## Alternative: edit the policy config directly
|
## Alternative: edit the policy config directly
|
||||||
|
|
||||||
If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
|
If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
|
||||||
@@ -105,10 +94,10 @@ XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset
|
|||||||
|
|
||||||
## Quick reference
|
## Quick reference
|
||||||
|
|
||||||
| Goal | What to do |
|
| Goal | What to do |
|
||||||
| ----------------------------------------- | --------------------------------------------------------------------------- |
|
| --------------------------------------- | --------------------------------------------------------------------------- |
|
||||||
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
|
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
|
||||||
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
|
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
|
||||||
| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'`. |
|
| Rollout with different keys (inference) | `--rename_map='{"source_key": "policy_key", ...}'`. |
|
||||||
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
|
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
|
||||||
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |
|
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ pip install -e ".[smolvla]"
|
|||||||
|
|
||||||
### Using RTC with Pi0
|
### Using RTC with Pi0
|
||||||
|
|
||||||
You can find a complete reference implementation in [eval_with_real_robot.py](examples/rtc/eval_with_real_robot.py).
|
You can use `lerobot-rollout --strategy.type=base --inference.type=rtc` for RTC deployment on real robots.
|
||||||
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
|
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@@ -137,8 +137,12 @@ The script generates a visualization of the denoising process, comparing standar
|
|||||||
## Testing RTC with a Real Robot
|
## Testing RTC with a Real Robot
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python examples/rtc/eval_with_real_robot.py \
|
lerobot-rollout \
|
||||||
|
--strategy.type=base \
|
||||||
--policy.path=${HF_USERNAME}/policy_repo_id \
|
--policy.path=${HF_USERNAME}/policy_repo_id \
|
||||||
|
--inference.type=rtc \
|
||||||
|
--inference.rtc.execution_horizon=10 \
|
||||||
|
--inference.rtc.max_guidance_weight=10.0 \
|
||||||
--robot.type=so100_follower \
|
--robot.type=so100_follower \
|
||||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
@@ -178,7 +182,7 @@ visualizer = RTCDebugVisualizer()
|
|||||||
# ... create plots
|
# ... create plots
|
||||||
```
|
```
|
||||||
|
|
||||||
See `examples/rtc/eval_dataset.py` for a complete example of visualization.
|
See `examples/rtc/eval_dataset.py` for a complete example of offline RTC visualization.
|
||||||
|
|
||||||
## References
|
## References
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ This ensures identical task states map to consistent progress values, even acros
|
|||||||
|
|
||||||
## Inputs and Targets (What the new code expects)
|
## Inputs and Targets (What the new code expects)
|
||||||
|
|
||||||
SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which:
|
SARM is trained through its processor (`src/lerobot/rewards/sarm/processor_sarm.py`), which:
|
||||||
|
|
||||||
- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
|
- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
|
||||||
- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
|
- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
|
||||||
@@ -347,7 +347,7 @@ Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predict
|
|||||||
<hfoption id="single_stage">
|
<hfoption id="single_stage">
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||||
--dataset-repo-id your-username/your-dataset \
|
--dataset-repo-id your-username/your-dataset \
|
||||||
--reward-model-path your-username/sarm-model \
|
--reward-model-path your-username/sarm-model \
|
||||||
--visualize-only \
|
--visualize-only \
|
||||||
@@ -360,7 +360,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
|||||||
<hfoption id="dense_only">
|
<hfoption id="dense_only">
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||||
--dataset-repo-id your-username/your-dataset \
|
--dataset-repo-id your-username/your-dataset \
|
||||||
--reward-model-path your-username/sarm-model \
|
--reward-model-path your-username/sarm-model \
|
||||||
--visualize-only \
|
--visualize-only \
|
||||||
@@ -373,7 +373,7 @@ python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
|||||||
<hfoption id="dual">
|
<hfoption id="dual">
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||||
--dataset-repo-id your-username/your-dataset \
|
--dataset-repo-id your-username/your-dataset \
|
||||||
--reward-model-path your-username/sarm-model \
|
--reward-model-path your-username/sarm-model \
|
||||||
--visualize-only \
|
--visualize-only \
|
||||||
@@ -429,7 +429,7 @@ The weighting follows **Equations 8-9** from the paper:
|
|||||||
First, run the SARM model on all frames in your dataset to compute progress values:
|
First, run the SARM model on all frames in your dataset to compute progress values:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||||
--dataset-repo-id your-username/your-dataset \
|
--dataset-repo-id your-username/your-dataset \
|
||||||
--reward-model-path your-username/sarm-model \
|
--reward-model-path your-username/sarm-model \
|
||||||
--head-mode sparse \
|
--head-mode sparse \
|
||||||
@@ -465,15 +465,15 @@ This script:
|
|||||||
|
|
||||||
### Step 5b: Train Policy with RA-BC
|
### Step 5b: Train Policy with RA-BC
|
||||||
|
|
||||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`) if not explicitly provided. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-train \
|
lerobot-train \
|
||||||
--dataset.repo_id=your-username/your-dataset \
|
--dataset.repo_id=your-username/your-dataset \
|
||||||
--policy.type=pi0 \
|
--policy.type=pi0 \
|
||||||
--use_rabc=true \
|
--sample_weighting.type=rabc \
|
||||||
--rabc_head_mode=sparse \
|
--sample_weighting.head_mode=sparse \
|
||||||
--rabc_kappa=0.01 \
|
--sample_weighting.kappa=0.01 \
|
||||||
--output_dir=outputs/train/policy_rabc \
|
--output_dir=outputs/train/policy_rabc \
|
||||||
--batch_size=32 \
|
--batch_size=32 \
|
||||||
--steps=40000
|
--steps=40000
|
||||||
@@ -488,12 +488,13 @@ The training script automatically:
|
|||||||
|
|
||||||
**RA-BC Arguments:**
|
**RA-BC Arguments:**
|
||||||
|
|
||||||
| Argument | Description | Default |
|
| Argument | Description | Default |
|
||||||
| ---------------------- | ---------------------------------------------------------- | ---------------------------------- |
|
| ---------------------------------- | ------------------------------------------------------ | ----------------------- |
|
||||||
| `--use_rabc` | Enable RA-BC sample weighting | `false` |
|
| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
|
||||||
| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
|
| `--sample_weighting.progress_path` | Path to progress parquet file | `sarm_progress.parquet` |
|
||||||
| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||||
| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
|
| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||||
|
| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
|
||||||
|
|
||||||
### Tuning RA-BC Kappa
|
### Tuning RA-BC Kappa
|
||||||
|
|
||||||
@@ -511,30 +512,30 @@ The `kappa` parameter is the threshold that determines which samples get full we
|
|||||||
|
|
||||||
Monitor these WandB metrics during training:
|
Monitor these WandB metrics during training:
|
||||||
|
|
||||||
| Metric | Healthy Range | Problem Indicator |
|
| Metric | Healthy Range | Problem Indicator |
|
||||||
| ------------------ | ------------- | ------------------------- |
|
| ----------------------------- | ------------- | ------------------------- |
|
||||||
| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||||
| `rabc_delta_mean` | > 0 | Should be positive |
|
| `sample_weighting/delta_mean` | > 0 | Should be positive |
|
||||||
| `rabc_delta_std` | > 0 | Variance in data quality |
|
| `sample_weighting/delta_std` | > 0 | Variance in data quality |
|
||||||
|
|
||||||
**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
**If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
||||||
|
|
||||||
**Setting kappa based on your data:**
|
**Setting kappa based on your data:**
|
||||||
|
|
||||||
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:
|
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `sample_weighting/delta_mean` and `sample_weighting/delta_std`:
|
||||||
|
|
||||||
```
|
```
|
||||||
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
|
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
|
||||||
# Most deltas fall in range [0.01, 0.05]
|
# Most deltas fall in range [0.01, 0.05]
|
||||||
|
|
||||||
# Option 1: Set kappa = delta_mean (medium selectivity)
|
# Option 1: Set kappa = delta_mean (medium selectivity)
|
||||||
--rabc_kappa=0.03
|
--sample_weighting.kappa=0.03
|
||||||
|
|
||||||
# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
|
# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
|
||||||
--rabc_kappa=0.05
|
--sample_weighting.kappa=0.05
|
||||||
|
|
||||||
# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
|
# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
|
||||||
--rabc_kappa=0.07
|
--sample_weighting.kappa=0.07
|
||||||
```
|
```
|
||||||
|
|
||||||
**When RA-BC may not help:**
|
**When RA-BC may not help:**
|
||||||
@@ -550,8 +551,8 @@ accelerate launch \
|
|||||||
src/lerobot/scripts/lerobot_train.py \
|
src/lerobot/scripts/lerobot_train.py \
|
||||||
--dataset.repo_id=your-username/your-dataset \
|
--dataset.repo_id=your-username/your-dataset \
|
||||||
--policy.type=pi0 \
|
--policy.type=pi0 \
|
||||||
--use_rabc=true \
|
--sample_weighting.type=rabc \
|
||||||
--rabc_kappa=0.01 \
|
--sample_weighting.kappa=0.01 \
|
||||||
--output_dir=outputs/train/policy_rabc \
|
--output_dir=outputs/train/policy_rabc \
|
||||||
--batch_size=32 \
|
--batch_size=32 \
|
||||||
--steps=40000
|
--steps=40000
|
||||||
@@ -576,7 +577,7 @@ accelerate launch \
|
|||||||
### RA-BC
|
### RA-BC
|
||||||
|
|
||||||
1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
|
1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
|
||||||
2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
2. **Monitor `sample_weight_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
@@ -274,7 +274,8 @@ python src/lerobot/scripts/lerobot_train.py \
|
|||||||
Once trained, we recommend deploying policies using inference-time RTC:
|
Once trained, we recommend deploying policies using inference-time RTC:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python examples/rtc/eval_with_real_robot.py \
|
lerobot-rollout \
|
||||||
|
--strategy.type=base \
|
||||||
--policy.path=your-username/your-repo-id \
|
--policy.path=your-username/your-repo-id \
|
||||||
--policy.device=cuda \
|
--policy.device=cuda \
|
||||||
--robot.type=unitree_g1 \
|
--robot.type=unitree_g1 \
|
||||||
@@ -284,7 +285,7 @@ python examples/rtc/eval_with_real_robot.py \
|
|||||||
--task="task_description" \
|
--task="task_description" \
|
||||||
--duration=1000 \
|
--duration=1000 \
|
||||||
--fps=30 \
|
--fps=30 \
|
||||||
--rtc.enabled=true
|
--inference.type=rtc
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|||||||
@@ -220,7 +220,7 @@ REAL_DIM = 12
|
|||||||
# Postprocessing: Trim 20D predictions to 12D for deployment
|
# Postprocessing: Trim 20D predictions to 12D for deployment
|
||||||
```
|
```
|
||||||
|
|
||||||
See the [action_hub.py](/home/jade_choghari/robot/lerobot/src/lerobot/policies/xvla/action_hub.py) implementation for details.
|
See the [action_hub.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py) implementation for details.
|
||||||
|
|
||||||
#### Auto Action Mode (Recommended)
|
#### Auto Action Mode (Recommended)
|
||||||
|
|
||||||
@@ -519,9 +519,9 @@ If you use X-VLA in your research, please cite:
|
|||||||
|
|
||||||
- [X-VLA Paper](https://arxiv.org/pdf/2510.10274)
|
- [X-VLA Paper](https://arxiv.org/pdf/2510.10274)
|
||||||
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
||||||
- [Action Registry Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/action_hub.py)
|
- [Action Registry Implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py)
|
||||||
- [Processor Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/processor_xvla.py)
|
- [Processor Implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/processor_xvla.py)
|
||||||
- [Model Configuration](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/configuration_xvla.py)
|
- [Model Configuration](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/configuration_xvla.py)
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
|
|||||||
244
examples/dataset/create_robometer_progress_videos.py
Normal file
244
examples/dataset/create_robometer_progress_videos.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Create videos with a Robometer progress overlay for one LeRobot dataset episode.
|
||||||
|
|
||||||
|
This is a lightweight smoke-test utility for Robometer checkpoints. It downloads
|
||||||
|
one episode video, samples a small number of frames, runs Robometer on those
|
||||||
|
frames, and reuses the progress overlay renderer from
|
||||||
|
``examples/dataset/create_progress_videos.py``.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
uv run python examples/dataset/create_robometer_progress_videos.py \\
|
||||||
|
--repo-id lerobot/aloha_mobile_cabinet \\
|
||||||
|
--episode 0 \\
|
||||||
|
--reward-model-path lilkm/robometer-4b \\
|
||||||
|
--device cuda
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from examples.dataset.create_progress_videos import (
|
||||||
|
composite_progress_video,
|
||||||
|
convert_mp4_to_gif,
|
||||||
|
download_episode_metadata,
|
||||||
|
download_video_file,
|
||||||
|
load_episode_meta,
|
||||||
|
)
|
||||||
|
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
|
||||||
|
from lerobot.rewards.robometer.modeling_robometer import decode_progress_outputs
|
||||||
|
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
|
||||||
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
|
|
||||||
|
def _default_device() -> str:
|
||||||
|
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def sample_episode_frames(
|
||||||
|
video_path: Path,
|
||||||
|
*,
|
||||||
|
from_timestamp: float,
|
||||||
|
to_timestamp: float,
|
||||||
|
fps: float,
|
||||||
|
num_frames: int,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""Sample RGB frames uniformly from an episode video segment.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
``(frames, frame_indices)`` where ``frames`` is ``(T,H,W,C)`` uint8 RGB
|
||||||
|
and ``frame_indices`` are local episode frame indices used for overlay.
|
||||||
|
"""
|
||||||
|
if num_frames <= 0:
|
||||||
|
raise ValueError(f"num_frames must be positive, got {num_frames}")
|
||||||
|
|
||||||
|
duration_seconds = to_timestamp - from_timestamp
|
||||||
|
total_frames = max(int(round(duration_seconds * fps)), 1)
|
||||||
|
frame_indices = np.linspace(0, total_frames - 1, num=min(num_frames, total_frames), dtype=int)
|
||||||
|
|
||||||
|
capture = cv2.VideoCapture(str(video_path))
|
||||||
|
frames: list[np.ndarray] = []
|
||||||
|
try:
|
||||||
|
for frame_idx in frame_indices:
|
||||||
|
timestamp = from_timestamp + frame_idx / fps
|
||||||
|
capture.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
|
||||||
|
ret, frame_bgr = capture.read()
|
||||||
|
if not ret:
|
||||||
|
logging.warning("Could not read frame %d at %.3fs", frame_idx, timestamp)
|
||||||
|
continue
|
||||||
|
frames.append(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
|
||||||
|
finally:
|
||||||
|
capture.release()
|
||||||
|
|
||||||
|
if not frames:
|
||||||
|
raise RuntimeError(f"No frames could be sampled from {video_path}")
|
||||||
|
|
||||||
|
return np.stack(frames), frame_indices[: len(frames)]
|
||||||
|
|
||||||
|
|
||||||
|
def predict_robometer_progress(
|
||||||
|
frames: np.ndarray,
|
||||||
|
*,
|
||||||
|
task: str,
|
||||||
|
reward_model_path: str,
|
||||||
|
device: str,
|
||||||
|
) -> list[float]:
|
||||||
|
"""Run Robometer and return per-sampled-frame progress predictions."""
|
||||||
|
config = RobometerConfig(pretrained_path=reward_model_path, device=device, max_frames=None)
|
||||||
|
model = RobometerRewardModel.from_pretrained(reward_model_path, config=config)
|
||||||
|
|
||||||
|
encoder = RobometerEncoderProcessorStep(
|
||||||
|
base_model_id=model.config.base_model_id,
|
||||||
|
use_multi_image=model.config.use_multi_image,
|
||||||
|
use_per_frame_progress_token=model.config.use_per_frame_progress_token,
|
||||||
|
max_frames=None,
|
||||||
|
)
|
||||||
|
batch = encoder.encode_samples([(frames, task)])
|
||||||
|
|
||||||
|
model_device = next(model.model.parameters()).device
|
||||||
|
inputs = {key: value.to(model_device) if hasattr(value, "to") else value for key, value in batch.items()}
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
progress_logits, success_logits = model._compute_rbm_logits(inputs)
|
||||||
|
|
||||||
|
decoded = decode_progress_outputs(
|
||||||
|
progress_logits,
|
||||||
|
success_logits,
|
||||||
|
is_discrete_mode=model.config.use_discrete_progress,
|
||||||
|
)
|
||||||
|
return decoded["progress_pred"][0]
|
||||||
|
|
||||||
|
|
||||||
|
def process_dataset(
|
||||||
|
repo_id: str,
|
||||||
|
episode: int,
|
||||||
|
reward_model_path: str,
|
||||||
|
device: str,
|
||||||
|
camera_key: str | None,
|
||||||
|
output_dir: Path,
|
||||||
|
num_frames: int,
|
||||||
|
task: str | None = None,
|
||||||
|
create_gif: bool = False,
|
||||||
|
) -> Path:
|
||||||
|
safe_name = repo_id.replace("/", "_")
|
||||||
|
logging.info("Processing %s episode %d with Robometer %s", repo_id, episode, reward_model_path)
|
||||||
|
|
||||||
|
local_path = download_episode_metadata(repo_id, episode)
|
||||||
|
episode_meta = load_episode_meta(local_path, episode, camera_key)
|
||||||
|
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
|
||||||
|
|
||||||
|
task_name = task or episode_meta.get("task_name", "")
|
||||||
|
if not task_name:
|
||||||
|
raise ValueError("No task found in dataset metadata. Pass --task explicitly.")
|
||||||
|
|
||||||
|
frames, frame_indices = sample_episode_frames(
|
||||||
|
video_path,
|
||||||
|
from_timestamp=episode_meta["from_ts"],
|
||||||
|
to_timestamp=episode_meta["to_ts"],
|
||||||
|
fps=episode_meta["fps"],
|
||||||
|
num_frames=num_frames,
|
||||||
|
)
|
||||||
|
logging.info("Sampled %d frames for Robometer inference", len(frames))
|
||||||
|
|
||||||
|
progress = predict_robometer_progress(
|
||||||
|
frames,
|
||||||
|
task=task_name,
|
||||||
|
reward_model_path=reward_model_path,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
progress_data = np.stack([frame_indices, np.asarray(progress, dtype=np.float32)], axis=1)
|
||||||
|
logging.info("Progress predictions: %s", [round(float(value), 3) for value in progress])
|
||||||
|
|
||||||
|
output_path = output_dir / f"{safe_name}_ep{episode}_robometer_progress.mp4"
|
||||||
|
final_path = composite_progress_video(
|
||||||
|
video_path=video_path,
|
||||||
|
from_timestamp=episode_meta["from_ts"],
|
||||||
|
to_timestamp=episode_meta["to_ts"],
|
||||||
|
progress_data=progress_data,
|
||||||
|
output_path=output_path,
|
||||||
|
fps=episode_meta["fps"],
|
||||||
|
task_name=task_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
if create_gif:
|
||||||
|
final_path = convert_mp4_to_gif(final_path)
|
||||||
|
return final_path
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Create MP4/GIF videos with Robometer progress overlay for dataset episodes."
|
||||||
|
)
|
||||||
|
parser.add_argument("--repo-id", required=True, help="Hugging Face LeRobot dataset repo id.")
|
||||||
|
parser.add_argument("--episode", type=int, required=True, help="Episode index to visualize.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--reward-model-path",
|
||||||
|
default="lilkm/robometer-4b",
|
||||||
|
help="Robometer checkpoint path or Hub repo id (e.g. lilkm/robometer-4b).",
|
||||||
|
)
|
||||||
|
parser.add_argument("--device", default=_default_device(), help="Torch device for Robometer inference.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--camera-key",
|
||||||
|
default=None,
|
||||||
|
help="Camera observation key (e.g. observation.images.top). Auto-selects first camera if omitted.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--task", default=None, help="Task description override if dataset metadata lacks one."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-frames",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="Number of episode frames to sample for Robometer inference.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("progress_videos"),
|
||||||
|
help="Directory to write output files.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--gif", action="store_true", help="Also generate a GIF from the MP4 output.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
init_logging()
|
||||||
|
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
result = process_dataset(
|
||||||
|
repo_id=args.repo_id,
|
||||||
|
episode=args.episode,
|
||||||
|
reward_model_path=args.reward_model_path,
|
||||||
|
device=args.device,
|
||||||
|
camera_key=args.camera_key,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
num_frames=args.num_frames,
|
||||||
|
task=args.task,
|
||||||
|
create_gif=args.gif,
|
||||||
|
)
|
||||||
|
logging.info("Output: %s", result)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -69,7 +69,7 @@ class ComputeProgressShards(PipelineStep):
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lerobot.policies.sarm.compute_rabc_weights import (
|
from lerobot.rewards.sarm.compute_rabc_weights import (
|
||||||
generate_all_frame_indices,
|
generate_all_frame_indices,
|
||||||
interpolate_progress,
|
interpolate_progress,
|
||||||
load_sarm_resources,
|
load_sarm_resources,
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,226 +0,0 @@
|
|||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""Shared utilities for Human-in-the-Loop data collection scripts."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from lerobot.common.control_utils import is_headless
|
|
||||||
from lerobot.processor import (
|
|
||||||
IdentityProcessorStep,
|
|
||||||
RobotAction,
|
|
||||||
RobotObservation,
|
|
||||||
RobotProcessorPipeline,
|
|
||||||
observation_to_transition,
|
|
||||||
robot_action_observation_to_transition,
|
|
||||||
transition_to_observation,
|
|
||||||
transition_to_robot_action,
|
|
||||||
)
|
|
||||||
from lerobot.robots import Robot
|
|
||||||
from lerobot.teleoperators import Teleoperator
|
|
||||||
from lerobot.utils.robot_utils import precise_sleep
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class HILDatasetConfig:
|
|
||||||
repo_id: str
|
|
||||||
single_task: str
|
|
||||||
root: str | Path | None = None
|
|
||||||
fps: int = 30
|
|
||||||
episode_time_s: float = 120
|
|
||||||
num_episodes: int = 50
|
|
||||||
video: bool = True
|
|
||||||
push_to_hub: bool = True
|
|
||||||
private: bool = False
|
|
||||||
tags: list[str] | None = None
|
|
||||||
num_image_writer_processes: int = 0
|
|
||||||
num_image_writer_threads_per_camera: int = 4
|
|
||||||
video_encoding_batch_size: int = 1
|
|
||||||
vcodec: str = "auto"
|
|
||||||
streaming_encoding: bool = True
|
|
||||||
encoder_queue_maxsize: int = 30
|
|
||||||
encoder_threads: int | None = None
|
|
||||||
rename_map: dict[str, str] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
def teleop_has_motor_control(teleop: Teleoperator) -> bool:
|
|
||||||
"""Check if teleoperator has motor control capabilities."""
|
|
||||||
return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions"))
|
|
||||||
|
|
||||||
|
|
||||||
def teleop_disable_torque(teleop: Teleoperator) -> None:
|
|
||||||
"""Disable teleop torque if supported."""
|
|
||||||
if hasattr(teleop, "disable_torque"):
|
|
||||||
teleop.disable_torque()
|
|
||||||
|
|
||||||
|
|
||||||
def teleop_enable_torque(teleop: Teleoperator) -> None:
|
|
||||||
"""Enable teleop torque if supported."""
|
|
||||||
if hasattr(teleop, "enable_torque"):
|
|
||||||
teleop.enable_torque()
|
|
||||||
|
|
||||||
|
|
||||||
def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50):
|
|
||||||
"""Smoothly move teleop to target position if motor control is available."""
|
|
||||||
if not teleop_has_motor_control(teleop):
|
|
||||||
logger.warning("Teleop does not support motor control - cannot mirror robot position")
|
|
||||||
return
|
|
||||||
|
|
||||||
teleop_enable_torque(teleop)
|
|
||||||
current = teleop.get_action()
|
|
||||||
steps = max(int(duration_s * fps), 1)
|
|
||||||
|
|
||||||
for step in range(steps + 1):
|
|
||||||
t = step / steps
|
|
||||||
interp = {}
|
|
||||||
for k in current:
|
|
||||||
if k in target_pos:
|
|
||||||
interp[k] = current[k] * (1 - t) + target_pos[k] * t
|
|
||||||
else:
|
|
||||||
interp[k] = current[k]
|
|
||||||
teleop.write_goal_positions(interp)
|
|
||||||
time.sleep(1 / fps)
|
|
||||||
|
|
||||||
|
|
||||||
def init_keyboard_listener():
|
|
||||||
"""Initialize keyboard listener with HIL controls."""
|
|
||||||
events = {
|
|
||||||
"exit_early": False,
|
|
||||||
"rerecord_episode": False,
|
|
||||||
"stop_recording": False,
|
|
||||||
"policy_paused": False,
|
|
||||||
"correction_active": False,
|
|
||||||
"resume_policy": False,
|
|
||||||
"in_reset": False,
|
|
||||||
"start_next_episode": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
if is_headless():
|
|
||||||
logger.warning("Headless environment - keyboard controls unavailable")
|
|
||||||
return None, events
|
|
||||||
|
|
||||||
from pynput import keyboard
|
|
||||||
|
|
||||||
def on_press(key):
|
|
||||||
try:
|
|
||||||
if events["in_reset"]:
|
|
||||||
if key in [keyboard.Key.space, keyboard.Key.right]:
|
|
||||||
logger.info("[HIL] Starting next episode...")
|
|
||||||
events["start_next_episode"] = True
|
|
||||||
elif hasattr(key, "char") and key.char == "c":
|
|
||||||
events["start_next_episode"] = True
|
|
||||||
elif key == keyboard.Key.esc:
|
|
||||||
logger.info("[HIL] ESC - Stop recording, pushing to hub...")
|
|
||||||
events["stop_recording"] = True
|
|
||||||
events["start_next_episode"] = True
|
|
||||||
else:
|
|
||||||
if key == keyboard.Key.space:
|
|
||||||
if not events["policy_paused"] and not events["correction_active"]:
|
|
||||||
logger.info("[HIL] PAUSED - Press 'c' to take control or 'p' to resume policy")
|
|
||||||
events["policy_paused"] = True
|
|
||||||
elif hasattr(key, "char") and key.char == "c":
|
|
||||||
if events["policy_paused"] and not events["correction_active"]:
|
|
||||||
logger.info("[HIL] Taking control...")
|
|
||||||
events["start_next_episode"] = True
|
|
||||||
elif hasattr(key, "char") and key.char == "p":
|
|
||||||
if events["policy_paused"] or events["correction_active"]:
|
|
||||||
logger.info("[HIL] Resuming policy...")
|
|
||||||
events["resume_policy"] = True
|
|
||||||
elif key == keyboard.Key.right:
|
|
||||||
logger.info("[HIL] End episode")
|
|
||||||
events["exit_early"] = True
|
|
||||||
elif key == keyboard.Key.left:
|
|
||||||
logger.info("[HIL] Re-record episode")
|
|
||||||
events["rerecord_episode"] = True
|
|
||||||
events["exit_early"] = True
|
|
||||||
elif key == keyboard.Key.esc:
|
|
||||||
logger.info("[HIL] ESC - Stop recording...")
|
|
||||||
events["stop_recording"] = True
|
|
||||||
events["exit_early"] = True
|
|
||||||
except Exception as e:
|
|
||||||
logger.info(f"Key error: {e}")
|
|
||||||
|
|
||||||
listener = keyboard.Listener(on_press=on_press)
|
|
||||||
listener.start()
|
|
||||||
return listener, events
|
|
||||||
|
|
||||||
|
|
||||||
def make_identity_processors():
|
|
||||||
"""Create identity processors for recording."""
|
|
||||||
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
|
||||||
steps=[IdentityProcessorStep()],
|
|
||||||
to_transition=robot_action_observation_to_transition,
|
|
||||||
to_output=transition_to_robot_action,
|
|
||||||
)
|
|
||||||
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
|
||||||
steps=[IdentityProcessorStep()],
|
|
||||||
to_transition=observation_to_transition,
|
|
||||||
to_output=transition_to_observation,
|
|
||||||
)
|
|
||||||
return teleop_proc, obs_proc
|
|
||||||
|
|
||||||
|
|
||||||
def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int):
|
|
||||||
"""Reset period where human repositions environment."""
|
|
||||||
logger.info("[HIL] RESET")
|
|
||||||
|
|
||||||
events["in_reset"] = True
|
|
||||||
events["start_next_episode"] = False
|
|
||||||
|
|
||||||
obs = robot.get_observation()
|
|
||||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
|
||||||
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
|
||||||
|
|
||||||
logger.info("Press any key to enable teleoperation")
|
|
||||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
|
||||||
precise_sleep(0.05)
|
|
||||||
|
|
||||||
if events["stop_recording"]:
|
|
||||||
return
|
|
||||||
|
|
||||||
events["start_next_episode"] = False
|
|
||||||
teleop_disable_torque(teleop)
|
|
||||||
logger.info("Teleop enabled - press any key to start episode")
|
|
||||||
|
|
||||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
|
||||||
loop_start = time.perf_counter()
|
|
||||||
action = teleop.get_action()
|
|
||||||
robot.send_action(action)
|
|
||||||
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
|
|
||||||
|
|
||||||
events["in_reset"] = False
|
|
||||||
events["start_next_episode"] = False
|
|
||||||
events["exit_early"] = False
|
|
||||||
events["policy_paused"] = False
|
|
||||||
events["correction_active"] = False
|
|
||||||
events["resume_policy"] = False
|
|
||||||
|
|
||||||
|
|
||||||
def print_controls(rtc: bool = False):
|
|
||||||
"""Print control instructions."""
|
|
||||||
mode = "Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else "")
|
|
||||||
logger.info(
|
|
||||||
"%s\n Controls:\n"
|
|
||||||
" SPACE - Pause policy\n"
|
|
||||||
" c - Take control\n"
|
|
||||||
" p - Resume policy after pause/correction\n"
|
|
||||||
" → - End episode\n"
|
|
||||||
" ESC - Stop and push to hub",
|
|
||||||
mode,
|
|
||||||
)
|
|
||||||
@@ -14,17 +14,21 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from lerobot.common.control_utils import init_keyboard_listener
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
from lerobot.common.control_utils import init_keyboard_listener, predict_action
|
||||||
from lerobot.datasets import LeRobotDataset
|
from lerobot.datasets import LeRobotDataset
|
||||||
from lerobot.policies import make_pre_post_processors
|
from lerobot.policies import make_pre_post_processors
|
||||||
from lerobot.policies.act import ACTPolicy
|
from lerobot.policies.act import ACTPolicy
|
||||||
|
from lerobot.policies.utils import make_robot_action
|
||||||
from lerobot.processor import make_default_processors
|
from lerobot.processor import make_default_processors
|
||||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||||
from lerobot.scripts.lerobot_record import record_loop
|
|
||||||
from lerobot.utils.constants import ACTION, OBS_STR
|
from lerobot.utils.constants import ACTION, OBS_STR
|
||||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||||
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
from lerobot.utils.utils import log_say
|
from lerobot.utils.utils import log_say
|
||||||
from lerobot.utils.visualization_utils import init_rerun
|
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||||
|
|
||||||
NUM_EPISODES = 2
|
NUM_EPISODES = 2
|
||||||
FPS = 30
|
FPS = 30
|
||||||
@@ -35,6 +39,9 @@ HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
|
||||||
|
# This script provides a self-contained example for educational purposes.
|
||||||
|
|
||||||
# Create the robot configuration & robot
|
# Create the robot configuration & robot
|
||||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||||
|
|
||||||
@@ -83,43 +90,67 @@ def main():
|
|||||||
raise ValueError("Robot is not connected!")
|
raise ValueError("Robot is not connected!")
|
||||||
|
|
||||||
print("Starting evaluate loop...")
|
print("Starting evaluate loop...")
|
||||||
|
control_interval = 1 / FPS
|
||||||
recorded_episodes = 0
|
recorded_episodes = 0
|
||||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||||
|
|
||||||
# Main record loop
|
# Inline evaluation loop: predict actions and send to robot
|
||||||
record_loop(
|
timestamp = 0
|
||||||
robot=robot,
|
start_episode_t = time.perf_counter()
|
||||||
events=events,
|
while timestamp < EPISODE_TIME_SEC:
|
||||||
fps=FPS,
|
start_loop_t = time.perf_counter()
|
||||||
policy=policy,
|
|
||||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
if events["exit_early"]:
|
||||||
postprocessor=postprocessor,
|
events["exit_early"] = False
|
||||||
dataset=dataset,
|
break
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
|
||||||
single_task=TASK_DESCRIPTION,
|
# Get robot observation
|
||||||
display_data=True,
|
obs = robot.get_observation()
|
||||||
teleop_action_processor=teleop_action_processor,
|
obs_processed = robot_observation_processor(obs)
|
||||||
robot_action_processor=robot_action_processor,
|
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
|
||||||
robot_observation_processor=robot_observation_processor,
|
|
||||||
)
|
# Predict action using the policy
|
||||||
|
action_tensor = predict_action(
|
||||||
|
observation=observation_frame,
|
||||||
|
policy=policy,
|
||||||
|
device=policy.config.device,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
use_amp=policy.config.device.type == "cuda",
|
||||||
|
task=TASK_DESCRIPTION,
|
||||||
|
robot_type=robot.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert policy output to robot action dict
|
||||||
|
action_values = make_robot_action(action_tensor, dataset.features)
|
||||||
|
|
||||||
|
# Process and send action to robot
|
||||||
|
robot_action_to_send = robot_action_processor((action_values, obs))
|
||||||
|
robot.send_action(robot_action_to_send)
|
||||||
|
|
||||||
|
# Write to dataset
|
||||||
|
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||||
|
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
|
||||||
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
|
log_rerun_data(observation=obs_processed, action=action_values)
|
||||||
|
|
||||||
|
dt_s = time.perf_counter() - start_loop_t
|
||||||
|
sleep_time_s = control_interval - dt_s
|
||||||
|
if sleep_time_s < 0:
|
||||||
|
logging.warning(
|
||||||
|
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
|
||||||
|
)
|
||||||
|
precise_sleep(max(sleep_time_s, 0.0))
|
||||||
|
timestamp = time.perf_counter() - start_episode_t
|
||||||
|
|
||||||
# Reset the environment if not stopping or re-recording
|
# Reset the environment if not stopping or re-recording
|
||||||
if not events["stop_recording"] and (
|
if not events["stop_recording"] and (
|
||||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||||
):
|
):
|
||||||
log_say("Reset the environment")
|
log_say("Reset the environment")
|
||||||
record_loop(
|
log_say("Waiting for environment reset, press right arrow key when ready...")
|
||||||
robot=robot,
|
|
||||||
events=events,
|
|
||||||
fps=FPS,
|
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
|
||||||
single_task=TASK_DESCRIPTION,
|
|
||||||
display_data=True,
|
|
||||||
teleop_action_processor=teleop_action_processor,
|
|
||||||
robot_action_processor=robot_action_processor,
|
|
||||||
robot_observation_processor=robot_observation_processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
if events["rerecord_episode"]:
|
if events["rerecord_episode"]:
|
||||||
log_say("Re-record episode")
|
log_say("Re-record episode")
|
||||||
|
|||||||
@@ -45,9 +45,6 @@ def main():
|
|||||||
leader_arm = SO100Leader(leader_arm_config)
|
leader_arm = SO100Leader(leader_arm_config)
|
||||||
keyboard = KeyboardTeleop(keyboard_config)
|
keyboard = KeyboardTeleop(keyboard_config)
|
||||||
|
|
||||||
# TODO(Steven): Update this example to use pipelines
|
|
||||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
|
||||||
|
|
||||||
# Configure the dataset features
|
# Configure the dataset features
|
||||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||||
@@ -77,6 +74,10 @@ def main():
|
|||||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||||
raise ValueError("Robot or teleop is not connected!")
|
raise ValueError("Robot or teleop is not connected!")
|
||||||
|
|
||||||
|
teleop_action_processor, robot_action_processor, robot_observation_processor = (
|
||||||
|
make_default_processors()
|
||||||
|
)
|
||||||
|
|
||||||
print("Starting record loop...")
|
print("Starting record loop...")
|
||||||
recorded_episodes = 0
|
recorded_episodes = 0
|
||||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||||
@@ -87,14 +88,14 @@ def main():
|
|||||||
robot=robot,
|
robot=robot,
|
||||||
events=events,
|
events=events,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
|
teleop_action_processor=teleop_action_processor,
|
||||||
|
robot_action_processor=robot_action_processor,
|
||||||
|
robot_observation_processor=robot_observation_processor,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
teleop=[leader_arm, keyboard],
|
teleop=[leader_arm, keyboard],
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
control_time_s=EPISODE_TIME_SEC,
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
display_data=True,
|
display_data=True,
|
||||||
teleop_action_processor=teleop_action_processor,
|
|
||||||
robot_action_processor=robot_action_processor,
|
|
||||||
robot_observation_processor=robot_observation_processor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reset the environment if not stopping or re-recording
|
# Reset the environment if not stopping or re-recording
|
||||||
@@ -106,13 +107,13 @@ def main():
|
|||||||
robot=robot,
|
robot=robot,
|
||||||
events=events,
|
events=events,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
|
teleop_action_processor=teleop_action_processor,
|
||||||
|
robot_action_processor=robot_action_processor,
|
||||||
|
robot_observation_processor=robot_observation_processor,
|
||||||
teleop=[leader_arm, keyboard],
|
teleop=[leader_arm, keyboard],
|
||||||
control_time_s=RESET_TIME_SEC,
|
control_time_s=RESET_TIME_SEC,
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
display_data=True,
|
display_data=True,
|
||||||
teleop_action_processor=teleop_action_processor,
|
|
||||||
robot_action_processor=robot_action_processor,
|
|
||||||
robot_observation_processor=robot_observation_processor,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if events["rerecord_episode"]:
|
if events["rerecord_episode"]:
|
||||||
|
|||||||
77
examples/lekiwi/rollout.py
Normal file
77
examples/lekiwi/rollout.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
# !/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Run a trained policy on LeKiwi without recording (base rollout).
|
||||||
|
|
||||||
|
Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
|
||||||
|
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
|
||||||
|
control tick). For a CLI entry point with the same capabilities plus
|
||||||
|
recording, upload, and human-in-the-loop variants, see ``lerobot-rollout``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from lerobot.configs import PreTrainedConfig
|
||||||
|
from lerobot.robots.lekiwi import LeKiwiClientConfig
|
||||||
|
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||||
|
from lerobot.rollout.inference import SyncInferenceConfig
|
||||||
|
from lerobot.rollout.strategies import BaseStrategy
|
||||||
|
from lerobot.utils.process import ProcessSignalHandler
|
||||||
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
|
FPS = 30
|
||||||
|
DURATION_SEC = 60
|
||||||
|
TASK_DESCRIPTION = "My task description"
|
||||||
|
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
init_logging()
|
||||||
|
|
||||||
|
# Robot: LeKiwi client — make sure lekiwi_host is already running on the robot.
|
||||||
|
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||||
|
|
||||||
|
# Policy: load the pretrained config. ``pretrained_path`` is read downstream
|
||||||
|
# by ``build_rollout_context`` to reload the full model.
|
||||||
|
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||||
|
policy_config.pretrained_path = HF_MODEL_ID
|
||||||
|
|
||||||
|
# Assemble the rollout config: base strategy (no recording) + sync inference.
|
||||||
|
cfg = RolloutConfig(
|
||||||
|
robot=robot_config,
|
||||||
|
policy=policy_config,
|
||||||
|
strategy=BaseStrategyConfig(),
|
||||||
|
inference=SyncInferenceConfig(),
|
||||||
|
fps=FPS,
|
||||||
|
duration=DURATION_SEC,
|
||||||
|
task=TASK_DESCRIPTION,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Graceful Ctrl-C: the strategy loop exits when shutdown_event is set.
|
||||||
|
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||||
|
|
||||||
|
# Build the context (connects robot, loads policy, wires the inference strategy).
|
||||||
|
# No custom processors here — LeKiwi runs on raw joint features.
|
||||||
|
ctx = build_rollout_context(cfg, signal_handler.shutdown_event)
|
||||||
|
|
||||||
|
strategy = BaseStrategy(cfg.strategy)
|
||||||
|
try:
|
||||||
|
strategy.setup(ctx)
|
||||||
|
strategy.run(ctx)
|
||||||
|
finally:
|
||||||
|
strategy.teardown(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
136
examples/omx/README.md
Normal file
136
examples/omx/README.md
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
# OMX Follower — Cube Pick And Place Example
|
||||||
|
|
||||||
|
This is an example of what is possible to do with LeRobot on a physical setup.
|
||||||
|
It is a WIP and being used internally at LeRobot and specific to our setup, but we hope it can be a useful reference for how to use LeRobot APIs and CLIs.
|
||||||
|
|
||||||
|
It includes an end-to-end example for the **OMX Follower** robot arm: pick and place a cube dataset, train a policy, and deploy it autonomously.
|
||||||
|
|
||||||
|
## Hardware
|
||||||
|
|
||||||
|
| Component | Value |
|
||||||
|
| --------- | ------------------------------------ |
|
||||||
|
| Robot | OMX Follower |
|
||||||
|
| Cameras | 2× OpenCV cameras (wrist + top-down) |
|
||||||
|
|
||||||
|
## Scripts
|
||||||
|
|
||||||
|
| Script | Purpose |
|
||||||
|
| ---------------------- | --------------------------------------------------------------- |
|
||||||
|
| `reset_environment.py` | Standalone utility: sweep workspace, grab cube, place cube |
|
||||||
|
| `record_grab.py` | Automated data collection: reset → place → record grab episodes |
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
Make sure you have LeRobot installed in your env. (See [the installation guide](https://huggingface.co/docs/lerobot/installation))
|
||||||
|
|
||||||
|
Next, we will declare some environment variables for convenience. Adjust the camera indices and robot port to match your system configuration.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export ROBOT_PORT=/dev/ttyACM0
|
||||||
|
export TELEOP_PORT=/dev/ttyACM1
|
||||||
|
export HF_USERNAME=<your_hf_username>
|
||||||
|
export ROBOT_CAMERAS="{ wrist: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30, fourcc: MJPG}, top: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30, fourcc: MJPG} }"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Step 1 — Collect Data
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-record \
|
||||||
|
--robot.type=omx_follower \
|
||||||
|
--robot.port=$ROBOT_PORT \
|
||||||
|
--robot.id=omx_follower \
|
||||||
|
--robot.cameras="$ROBOT_CAMERAS" \
|
||||||
|
--teleop.type=omx_leader \
|
||||||
|
--teleop.port=$TELEOP_PORT \
|
||||||
|
--teleop.id=omx_leader \
|
||||||
|
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
|
||||||
|
--dataset.root=data/omx_pickandplace \
|
||||||
|
--dataset.num_episodes=50 \
|
||||||
|
--dataset.single_task="Pick the cube and place it in the blue square" \
|
||||||
|
--dataset.streaming_encoding=true \
|
||||||
|
--dataset.push_to_hub=true
|
||||||
|
```
|
||||||
|
|
||||||
|
### Bonus Auto-Collect script
|
||||||
|
|
||||||
|
/!\ This is specific to our setup and the task of picking and placing a cube. It is not a general-purpose data collection script. As you may notice, it doesn't require a teleop.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m examples.omx.record_grab \
|
||||||
|
--robot.type=omx_follower \
|
||||||
|
--robot.port=$ROBOT_PORT \
|
||||||
|
--robot.id=omx_follower \
|
||||||
|
--robot.cameras="$ROBOT_CAMERAS" \
|
||||||
|
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
|
||||||
|
--dataset.root=data/omx_pickandplace \
|
||||||
|
--dataset.num_episodes=50 \
|
||||||
|
--dataset.single_task="Pick the cube and place it in the blue square" \
|
||||||
|
--dataset.streaming_encoding=true \
|
||||||
|
--dataset.push_to_hub=true
|
||||||
|
```
|
||||||
|
|
||||||
|
Each episode:
|
||||||
|
|
||||||
|
1. The arm grabs the cube from the center of the workspace and places it at a random position.
|
||||||
|
2. The arm returns to HOME.
|
||||||
|
3. A targeted grab is recorded: HOME → approach raised → lower onto cube → grasp → lift → carry → drop → HOME.
|
||||||
|
|
||||||
|
A dataset is already available here [`maximellerbach/omx_pickandplace`](https://huggingface.co/datasets/maximellerbach/omx_pickandplace), so you can skip directly to training if you want.
|
||||||
|
|
||||||
|
## Step 2 — Train
|
||||||
|
|
||||||
|
To train a simple `ACT` policy on the collected dataset, you can use the `lerobot-train` CLI:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--dataset.repo_id=$HF_USERNAME/omx_pickandplace \
|
||||||
|
--policy.type=act \
|
||||||
|
--output_dir=outputs/train/omx_pickandplace_act \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--policy.repo_id=$HF_USERNAME/omx_pickandplace_act \
|
||||||
|
--steps=20000 \
|
||||||
|
--wandb.enable=true
|
||||||
|
```
|
||||||
|
|
||||||
|
A pretrained `ACT` policy is already available here [`maximellerbach/omx_pickandplace_act`](https://huggingface.co/maximellerbach/omx_pickandplace_act).
|
||||||
|
|
||||||
|
## Step 3 — Rollout
|
||||||
|
|
||||||
|
Use the `lerobot-rollout` CLI with base strategy:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=base \
|
||||||
|
--robot.type=omx_follower \
|
||||||
|
--robot.port=$ROBOT_PORT \
|
||||||
|
--robot.id=omx_follower \
|
||||||
|
--robot.cameras="$ROBOT_CAMERAS" \
|
||||||
|
--policy.path=$HF_USERNAME/omx_pickandplace_act \
|
||||||
|
```
|
||||||
|
|
||||||
|
For continuous recording with automatic upload (sentry mode):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-rollout \
|
||||||
|
--strategy.type=sentry \
|
||||||
|
--strategy.upload_every_n_episodes=10 \
|
||||||
|
--robot.type=omx_follower \
|
||||||
|
--robot.port=$ROBOT_PORT \
|
||||||
|
--robot.id=omx_follower \
|
||||||
|
--robot.cameras="$ROBOT_CAMERAS" \
|
||||||
|
--policy.path=$HF_USERNAME/omx_pickandplace_act \
|
||||||
|
--dataset.repo_id=$HF_USERNAME/rollout_omx_pickandplace_act \
|
||||||
|
```
|
||||||
|
|
||||||
|
## Environment Reset Utility
|
||||||
|
|
||||||
|
Those are specific to this particular physical setup. Those are scripts that execute hardcoded sequences of actions on the robot to reset the environment, which is useful for data collection and evaluation. They are not general-purpose scripts.
|
||||||
|
|
||||||
|
`reset_environment.py` can be run standalone to prepare the workspace:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Grab cube + place it at a random position on the left side
|
||||||
|
python -m examples.omx.reset_environment --port $ROBOT_PORT --mode grab_and_place
|
||||||
|
```
|
||||||
|
|
||||||
|
It also exposes `grab_cube(robot)` and `place_cube(robot)` for use in custom scripts.
|
||||||
422
examples/omx/record_grab.py
Normal file
422
examples/omx/record_grab.py
Normal file
@@ -0,0 +1,422 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Auto-record grab episodes for the OMX robot arm.
|
||||||
|
|
||||||
|
Each episode cycle:
|
||||||
|
1. grab_and_place — grab cube from workspace center and place at a random (pan, reach) position
|
||||||
|
2. HOME — return arm to home with gripper open
|
||||||
|
3. record_grab — execute a targeted grab to the stored position while recording
|
||||||
|
observations + actions to a LeRobotDataset
|
||||||
|
|
||||||
|
Usage (run from repo root):
|
||||||
|
python -m examples.omx.record_grab \\
|
||||||
|
--robot.type=omx_follower \\
|
||||||
|
--robot.port=/dev/ttyACM0 \\
|
||||||
|
--robot.id=omx_follower \\
|
||||||
|
--robot.cameras="{ wrist: {type: opencv, index_or_path: 6, width: 640, height: 480, fps: 30, fourcc: MJPG}, top: {type: opencv, index_or_path: 4, width: 640, height: 480, fps: 30, fourcc: MJPG} }" \\
|
||||||
|
--dataset.repo_id=<hf_username>/<dataset_name> \\
|
||||||
|
--dataset.root=data/omx_grab \\
|
||||||
|
--dataset.num_episodes=50 \\
|
||||||
|
--dataset.single_task="Grab the cube" \\
|
||||||
|
--dataset.streaming_encoding=true
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pprint import pformat
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lerobot.cameras import CameraConfig # noqa: F401
|
||||||
|
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
||||||
|
from lerobot.configs import parser
|
||||||
|
from lerobot.configs.dataset import DatasetRecordConfig
|
||||||
|
from lerobot.datasets import (
|
||||||
|
LeRobotDataset,
|
||||||
|
VideoEncodingManager,
|
||||||
|
aggregate_pipeline_dataset_features,
|
||||||
|
create_initial_features,
|
||||||
|
)
|
||||||
|
from lerobot.processor import make_default_processors
|
||||||
|
from lerobot.robots import RobotConfig, make_robot_from_config
|
||||||
|
from lerobot.robots.omx_follower import OmxFollower
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_STR
|
||||||
|
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||||
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
|
|
||||||
|
from .reset_environment import (
|
||||||
|
APPROACH_SPEED,
|
||||||
|
GRIPPER_CLOSE_POS,
|
||||||
|
HOME_POSE,
|
||||||
|
PUSH_END_ELBOW_FLEX,
|
||||||
|
PUSH_END_SHOULDER_LIFT,
|
||||||
|
PUSH_START_ELBOW_FLEX,
|
||||||
|
PUSH_START_SHOULDER_LIFT,
|
||||||
|
array_to_pose,
|
||||||
|
grab_cube,
|
||||||
|
horizontal_wrist_flex,
|
||||||
|
move_to_pose,
|
||||||
|
place_cube,
|
||||||
|
pose_to_array,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Grab-episode motion parameters ────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Shoulder-lift offset for the raised approach phase (subtracted from the target sl, arm is higher).
|
||||||
|
GRAB_RAISE_SL_OFFSET = 20.0
|
||||||
|
GRAB_LOWER_SPEED = 20.0
|
||||||
|
RECORD_SPEED = 30.0
|
||||||
|
|
||||||
|
# Pose the arm travels to after closing the gripper (cube held).
|
||||||
|
GRAB_CARRY_POSE = {
|
||||||
|
"shoulder_pan.pos": -23.0,
|
||||||
|
"shoulder_lift.pos": 5.0,
|
||||||
|
"elbow_flex.pos": 18.0,
|
||||||
|
"wrist_flex.pos": -14.0,
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Per-joint jitter limits (degrees) applied to transit waypoints for human-like variation.
|
||||||
|
# Cube-approach and carry poses are never jittered to preserve precision.
|
||||||
|
_JITTER_LIMITS: dict[str, float] = {
|
||||||
|
"shoulder_pan.pos": 5.0,
|
||||||
|
"shoulder_lift.pos": 4.0,
|
||||||
|
"elbow_flex.pos": 4.0,
|
||||||
|
"wrist_flex.pos": 3.0,
|
||||||
|
"wrist_roll.pos": 2.0,
|
||||||
|
"gripper.pos": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _jitter_pose(pose: dict, rng: np.random.Generator) -> dict:
|
||||||
|
"""Return a copy of pose with independent per-joint random perturbations."""
|
||||||
|
return {
|
||||||
|
k: v + rng.uniform(-_JITTER_LIMITS.get(k, 0.0), _JITTER_LIMITS.get(k, 0.0)) for k, v in pose.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _random_stuck_pose(rng: np.random.Generator) -> dict:
|
||||||
|
"""Return a physically plausible stuck pose (failed grasp), gripper closed.
|
||||||
|
|
||||||
|
ef bounds are piecewise-linear in sl so the arm stays in a reachable,
|
||||||
|
table-safe envelope across the full sl range:
|
||||||
|
sl=-50 → ef ∈ [ 0, 50] (arm raised, can be bent forward)
|
||||||
|
sl= 0 → ef ∈ [-25, 25] (mid reach)
|
||||||
|
sl= 30 → ef ∈ [-20, 0] (arm extended, little room to flex)
|
||||||
|
wrist_flex is randomly offset from the horizontal value.
|
||||||
|
"""
|
||||||
|
pan = float(rng.uniform(-5.0, 35.0))
|
||||||
|
sl = float(rng.uniform(-50.0, 30.0))
|
||||||
|
|
||||||
|
if sl <= 0.0:
|
||||||
|
alpha = (sl + 50.0) / 50.0 # 0 at sl=-50, 1 at sl=0
|
||||||
|
ef_lo = alpha * -25.0 # 0 → -25
|
||||||
|
ef_hi = 50.0 + alpha * -25.0 # 50 → 25
|
||||||
|
else:
|
||||||
|
alpha = sl / 30.0 # 0 at sl=0, 1 at sl=30
|
||||||
|
ef_lo = -25.0 + alpha * 5.0 # -25 → -20
|
||||||
|
ef_hi = 25.0 + alpha * -25.0 # 25 → 0
|
||||||
|
|
||||||
|
ef = float(rng.uniform(ef_lo, ef_hi))
|
||||||
|
wf = horizontal_wrist_flex(sl, ef) + float(rng.uniform(-15.0, 15.0))
|
||||||
|
return {
|
||||||
|
"shoulder_pan.pos": pan,
|
||||||
|
"shoulder_lift.pos": sl,
|
||||||
|
"elbow_flex.pos": ef,
|
||||||
|
"wrist_flex.pos": wf,
|
||||||
|
"wrist_roll.pos": float(rng.uniform(-15.0, 15.0)),
|
||||||
|
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OmxRecordGrabConfig:
|
||||||
|
robot: RobotConfig
|
||||||
|
dataset: DatasetRecordConfig
|
||||||
|
# Resume recording on an existing dataset.
|
||||||
|
resume: bool = False
|
||||||
|
# Fraction of episodes that start from a random stuck pose (gripper closed) to
|
||||||
|
# generate recovery data. 0.0 = disabled, 1.0 = all episodes are recovery starts.
|
||||||
|
recovery_prob: float = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def record_episode_spline(
|
||||||
|
robot: OmxFollower,
|
||||||
|
waypoints: list[dict],
|
||||||
|
speeds: list[float],
|
||||||
|
dataset: LeRobotDataset,
|
||||||
|
task: str,
|
||||||
|
) -> None:
|
||||||
|
"""Execute a Catmull-Rom-style spline through waypoints, recording each frame.
|
||||||
|
|
||||||
|
Segment durations are parameterized from the maximum absolute joint delta
|
||||||
|
between consecutive waypoints divided by the requested segment speed,
|
||||||
|
producing non-uniform timing in joint space. Interior tangents are derived
|
||||||
|
from the adjacent per-segment velocities, with clamped (zero-velocity)
|
||||||
|
endpoints so the arm starts and stops smoothly. Each segment is cubic
|
||||||
|
Hermite, giving C1 continuity at every waypoint.
|
||||||
|
"""
|
||||||
|
pts = [pose_to_array(w) for w in waypoints]
|
||||||
|
n = len(pts)
|
||||||
|
|
||||||
|
# Steps and duration per segment
|
||||||
|
n_steps_list = []
|
||||||
|
timestamps = []
|
||||||
|
for i in range(n - 1):
|
||||||
|
max_dist = float(np.max(np.abs(pts[i + 1] - pts[i])))
|
||||||
|
ns = max(1, int(max_dist / speeds[i] * dataset.fps)) if max_dist >= 0.5 else 0
|
||||||
|
n_steps_list.append(ns)
|
||||||
|
timestamps.append(ns / dataset.fps)
|
||||||
|
|
||||||
|
# Velocity tangents (deg/sec) — clamped at endpoints, Catmull-Rom for interior
|
||||||
|
vels = [np.zeros_like(pts[0])]
|
||||||
|
for i in range(1, n - 1):
|
||||||
|
v_prev = (pts[i] - pts[i - 1]) / timestamps[i - 1] if timestamps[i - 1] > 0 else np.zeros_like(pts[0])
|
||||||
|
v_next = (pts[i + 1] - pts[i]) / timestamps[i] if timestamps[i] > 0 else np.zeros_like(pts[0])
|
||||||
|
vels.append(0.5 * (v_prev + v_next))
|
||||||
|
vels.append(np.zeros_like(pts[0]))
|
||||||
|
|
||||||
|
dt = 1.0 / dataset.fps
|
||||||
|
for seg in range(n - 1):
|
||||||
|
ns = n_steps_list[seg]
|
||||||
|
if ns == 0:
|
||||||
|
continue
|
||||||
|
p0, p1 = pts[seg], pts[seg + 1]
|
||||||
|
# Scale velocity (deg/sec) to t-space tangent (deg/t-unit, where t: 0→1 over ns steps)
|
||||||
|
m0 = vels[seg] * timestamps[seg]
|
||||||
|
m1 = vels[seg + 1] * timestamps[seg]
|
||||||
|
|
||||||
|
for step in range(1, ns + 1):
|
||||||
|
t = step / ns
|
||||||
|
h00 = 2 * t**3 - 3 * t**2 + 1
|
||||||
|
h10 = t**3 - 2 * t**2 + t
|
||||||
|
h01 = -2 * t**3 + 3 * t**2
|
||||||
|
h11 = t**3 - t**2
|
||||||
|
commanded = h00 * p0 + h10 * m0 + h01 * p1 + h11 * m1
|
||||||
|
|
||||||
|
action = array_to_pose(commanded)
|
||||||
|
robot.send_action(action)
|
||||||
|
obs = robot.get_observation()
|
||||||
|
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
|
||||||
|
action_frame = build_dataset_frame(dataset.features, action, prefix=ACTION)
|
||||||
|
dataset.add_frame({**obs_frame, **action_frame, "task": task})
|
||||||
|
precise_sleep(dt)
|
||||||
|
|
||||||
|
|
||||||
|
def record_grab_episode(
|
||||||
|
robot: OmxFollower,
|
||||||
|
dataset: LeRobotDataset,
|
||||||
|
pan: float,
|
||||||
|
t: float,
|
||||||
|
task: str,
|
||||||
|
recovery_start: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Execute a targeted grab to the stored (pan, t) position, recording every frame.
|
||||||
|
|
||||||
|
Normal sequence (initial HOME move is NOT recorded):
|
||||||
|
HOME → raised approach above cube → lower → close gripper
|
||||||
|
→ raise [jittered] → retract [jittered] → GRAB_CARRY_POSE → drop → HOME
|
||||||
|
|
||||||
|
Recovery sequence (recovery_start=True): arm is moved to a random stuck pose
|
||||||
|
(gripper closed) without recording, then recording begins from there:
|
||||||
|
stuck_pose → raised approach above cube → [normal grab sequence from there]
|
||||||
|
|
||||||
|
All segments are joined by a Catmull-Rom spline (C1-continuous velocities).
|
||||||
|
"""
|
||||||
|
sl = PUSH_START_SHOULDER_LIFT + t * (PUSH_END_SHOULDER_LIFT - PUSH_START_SHOULDER_LIFT)
|
||||||
|
ef = PUSH_START_ELBOW_FLEX + t * (PUSH_END_ELBOW_FLEX - PUSH_START_ELBOW_FLEX)
|
||||||
|
sl_raised = sl - GRAB_RAISE_SL_OFFSET
|
||||||
|
wf_horizontal = horizontal_wrist_flex(sl, ef)
|
||||||
|
|
||||||
|
rng = np.random.default_rng()
|
||||||
|
|
||||||
|
if recovery_start:
|
||||||
|
stuck_pose = _random_stuck_pose(rng)
|
||||||
|
logger.info(f"Recovery start: {stuck_pose}")
|
||||||
|
move_to_pose(robot, stuck_pose, APPROACH_SPEED)
|
||||||
|
first_waypoints = [stuck_pose]
|
||||||
|
first_speeds = []
|
||||||
|
else:
|
||||||
|
jittery_start = _jitter_pose(HOME_POSE, rng)
|
||||||
|
move_to_pose(robot, jittery_start, APPROACH_SPEED)
|
||||||
|
first_waypoints = [jittery_start]
|
||||||
|
first_speeds = []
|
||||||
|
|
||||||
|
waypoints = first_waypoints + [
|
||||||
|
{ # raised approach: arm above cube
|
||||||
|
"shoulder_pan.pos": pan,
|
||||||
|
"shoulder_lift.pos": sl_raised,
|
||||||
|
"elbow_flex.pos": ef,
|
||||||
|
"wrist_flex.pos": horizontal_wrist_flex(sl_raised, ef),
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": 60.0,
|
||||||
|
},
|
||||||
|
{ # lower onto cube — no jitter: precision needed
|
||||||
|
"shoulder_pan.pos": pan,
|
||||||
|
"shoulder_lift.pos": sl,
|
||||||
|
"elbow_flex.pos": ef,
|
||||||
|
"wrist_flex.pos": wf_horizontal,
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": 60.0,
|
||||||
|
},
|
||||||
|
{ # close gripper — no jitter: precision needed
|
||||||
|
"shoulder_pan.pos": pan,
|
||||||
|
"shoulder_lift.pos": sl,
|
||||||
|
"elbow_flex.pos": ef,
|
||||||
|
"wrist_flex.pos": wf_horizontal,
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||||
|
},
|
||||||
|
_jitter_pose(
|
||||||
|
{ # raise with cube
|
||||||
|
"shoulder_pan.pos": pan,
|
||||||
|
"shoulder_lift.pos": sl_raised,
|
||||||
|
"elbow_flex.pos": ef,
|
||||||
|
"wrist_flex.pos": horizontal_wrist_flex(sl_raised, ef),
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||||
|
},
|
||||||
|
rng,
|
||||||
|
),
|
||||||
|
_jitter_pose(
|
||||||
|
{ # retract: fold arm toward HOME before sweeping to carry zone
|
||||||
|
"shoulder_pan.pos": pan * 0.25,
|
||||||
|
"shoulder_lift.pos": HOME_POSE["shoulder_lift.pos"] + 5.0,
|
||||||
|
"elbow_flex.pos": HOME_POSE["elbow_flex.pos"] - 5.0,
|
||||||
|
"wrist_flex.pos": 0.0,
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": GRIPPER_CLOSE_POS,
|
||||||
|
},
|
||||||
|
rng,
|
||||||
|
),
|
||||||
|
GRAB_CARRY_POSE, # no jitter: target drop zone
|
||||||
|
{**GRAB_CARRY_POSE, "gripper.pos": 60.0}, # drop cube
|
||||||
|
HOME_POSE,
|
||||||
|
]
|
||||||
|
speeds = first_speeds + [
|
||||||
|
RECORD_SPEED, # (HOME →) raised approach
|
||||||
|
GRAB_LOWER_SPEED, # raised approach → lower
|
||||||
|
GRAB_LOWER_SPEED, # lower → close gripper
|
||||||
|
RECORD_SPEED, # close gripper → raise
|
||||||
|
RECORD_SPEED, # raise → retract
|
||||||
|
RECORD_SPEED, # retract → carry pose
|
||||||
|
RECORD_SPEED, # carry pose → drop
|
||||||
|
RECORD_SPEED, # drop → HOME
|
||||||
|
]
|
||||||
|
|
||||||
|
record_episode_spline(robot, waypoints, speeds, dataset, task)
|
||||||
|
|
||||||
|
# Dwell at HOME for ~0.5 s before next episode
|
||||||
|
home_action = build_dataset_frame(dataset.features, HOME_POSE, prefix=ACTION)
|
||||||
|
dt = 1.0 / dataset.fps
|
||||||
|
for _ in range(int(dataset.fps * 0.5)):
|
||||||
|
robot.send_action(HOME_POSE)
|
||||||
|
obs = robot.get_observation()
|
||||||
|
obs_frame = build_dataset_frame(dataset.features, obs, prefix=OBS_STR)
|
||||||
|
dataset.add_frame({**obs_frame, **home_action, "task": task})
|
||||||
|
precise_sleep(dt)
|
||||||
|
|
||||||
|
|
||||||
|
@parser.wrap()
|
||||||
|
def record_grab(cfg: OmxRecordGrabConfig) -> LeRobotDataset:
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||||
|
logger.info(pformat(cfg))
|
||||||
|
|
||||||
|
robot = make_robot_from_config(cfg.robot)
|
||||||
|
use_videos = cfg.dataset.video
|
||||||
|
|
||||||
|
teleop_action_processor, _, robot_obs_processor = make_default_processors()
|
||||||
|
|
||||||
|
dataset_features = combine_feature_dicts(
|
||||||
|
aggregate_pipeline_dataset_features(
|
||||||
|
pipeline=teleop_action_processor,
|
||||||
|
initial_features=create_initial_features(action=robot.action_features),
|
||||||
|
use_videos=use_videos,
|
||||||
|
),
|
||||||
|
aggregate_pipeline_dataset_features(
|
||||||
|
pipeline=robot_obs_processor,
|
||||||
|
initial_features=create_initial_features(observation=robot.observation_features),
|
||||||
|
use_videos=use_videos,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
num_cameras = len(robot.cameras) if hasattr(robot, "cameras") else 0
|
||||||
|
dataset = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
if cfg.resume:
|
||||||
|
dataset = LeRobotDataset.resume(
|
||||||
|
cfg.dataset.repo_id,
|
||||||
|
root=cfg.dataset.root,
|
||||||
|
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||||
|
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||||
|
vcodec=cfg.dataset.vcodec,
|
||||||
|
encoder_threads=cfg.dataset.encoder_threads,
|
||||||
|
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
|
||||||
|
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
|
||||||
|
if num_cameras > 0
|
||||||
|
else 0,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cfg.dataset.stamp_repo_id()
|
||||||
|
dataset = LeRobotDataset.create(
|
||||||
|
cfg.dataset.repo_id,
|
||||||
|
cfg.dataset.fps,
|
||||||
|
root=cfg.dataset.root,
|
||||||
|
robot_type=robot.name,
|
||||||
|
features=dataset_features,
|
||||||
|
use_videos=use_videos,
|
||||||
|
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||||
|
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||||
|
vcodec=cfg.dataset.vcodec,
|
||||||
|
encoder_threads=cfg.dataset.encoder_threads,
|
||||||
|
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
|
||||||
|
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
|
||||||
|
if num_cameras > 0
|
||||||
|
else 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
robot.connect(calibrate=True)
|
||||||
|
|
||||||
|
rng = np.random.default_rng()
|
||||||
|
with VideoEncodingManager(dataset):
|
||||||
|
for episode_idx in range(cfg.dataset.num_episodes):
|
||||||
|
logger.info(f"=== Episode {episode_idx + 1}/{cfg.dataset.num_episodes} ===")
|
||||||
|
|
||||||
|
logger.info("Step 1: grabbing and placing cube...")
|
||||||
|
grab_cube(robot)
|
||||||
|
pan, t = place_cube(robot)
|
||||||
|
logger.info(f"Cube placed at pan={pan:.1f}, reach={t:.2f}")
|
||||||
|
|
||||||
|
recovery_start = cfg.recovery_prob > 0 and float(rng.random()) < cfg.recovery_prob
|
||||||
|
logger.info(f"Step 2: recording {'recovery ' if recovery_start else ''}grab episode...")
|
||||||
|
record_grab_episode(
|
||||||
|
robot,
|
||||||
|
dataset,
|
||||||
|
pan,
|
||||||
|
t,
|
||||||
|
cfg.dataset.single_task,
|
||||||
|
recovery_start=recovery_start,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset.save_episode()
|
||||||
|
logger.info(f"Episode {episode_idx + 1} saved.")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if dataset:
|
||||||
|
dataset.finalize()
|
||||||
|
if robot.is_connected:
|
||||||
|
robot.disconnect()
|
||||||
|
|
||||||
|
if cfg.dataset.push_to_hub and dataset and dataset.num_episodes > 0:
|
||||||
|
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
record_grab()
|
||||||
267
examples/omx/reset_environment.py
Normal file
267
examples/omx/reset_environment.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Auto-reset and cube-grab utility for the OMX robot arm.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
- grab_cube(robot): sweep workspace, center cube, close gripper
|
||||||
|
- place_cube(robot): carry cube to a random position, release
|
||||||
|
|
||||||
|
Standalone usage (run from repo root):
|
||||||
|
python -m examples.omx.reset_environment --port /dev/ttyACM1 --mode grab
|
||||||
|
python -m examples.omx.reset_environment --port /dev/ttyACM1 --mode grab_and_place
|
||||||
|
|
||||||
|
Joint range: -100 to 100 for arm joints; gripper: 50 = closed, 80 = open.
|
||||||
|
|
||||||
|
To read current joint values for calibration, add after robot.connect():
|
||||||
|
obs = robot.get_observation()
|
||||||
|
print({k: round(obs[k], 1) for k in JOINT_NAMES})
|
||||||
|
robot.disconnect(); raise SystemExit
|
||||||
|
|
||||||
|
Parallel-to-ground IK: wrist_flex = WRIST_HORIZONTAL_OFFSET - shoulder_lift - elbow_flex.
|
||||||
|
Linear interpolation preserves this constraint between any two poses that satisfy it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from lerobot.robots.omx_follower import OmxFollower, OmxFollowerConfig
|
||||||
|
from lerobot.robots.robot import Robot
|
||||||
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Poses ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
HOME_POSE = {
|
||||||
|
"shoulder_pan.pos": 0.0,
|
||||||
|
"shoulder_lift.pos": -50.0,
|
||||||
|
"elbow_flex.pos": 50.0,
|
||||||
|
"wrist_flex.pos": 0.0,
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": 60.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
SWEEP_WAYPOINTS = [
|
||||||
|
{
|
||||||
|
"shoulder_pan.pos": -60.0,
|
||||||
|
"shoulder_lift.pos": 50.0,
|
||||||
|
"elbow_flex.pos": -60.0,
|
||||||
|
"wrist_flex.pos": -20.0,
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": 60.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"shoulder_pan.pos": -30.0,
|
||||||
|
"shoulder_lift.pos": 50.0,
|
||||||
|
"elbow_flex.pos": -60.0,
|
||||||
|
"wrist_flex.pos": -5.0,
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": 60.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"shoulder_pan.pos": 20.0,
|
||||||
|
"shoulder_lift.pos": 50.0,
|
||||||
|
"elbow_flex.pos": -55.0,
|
||||||
|
"wrist_flex.pos": -5.0,
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": 60.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# ── Motion parameters ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
CONTROL_HZ = 30
|
||||||
|
APPROACH_SPEED = 50.0
|
||||||
|
SWEEP_SPEED = 40.0
|
||||||
|
|
||||||
|
# ── Grab-sequence parameters ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
GRAB_PAN = 0.0
|
||||||
|
SWEEP_LEFT_PAN = -60.0
|
||||||
|
SWEEP_RIGHT_PAN = 60.0
|
||||||
|
SWEEP_END_OFFSET = 5.0 # stop before center so the cube isn't pushed past GRAB_PAN
|
||||||
|
SWEEP_END_PAN_RANGE = (15.0, 20.0)
|
||||||
|
|
||||||
|
SWEEP_LOW_SHOULDER_LIFT = 50.0
|
||||||
|
SWEEP_LOW_ELBOW_FLEX_START = -60.0
|
||||||
|
SWEEP_LOW_ELBOW_FLEX_END = -55.0
|
||||||
|
|
||||||
|
SWEEP_HIGH_WRIST_FLEX = -20.0 # wrist tilted up during high approach to clear obstacles
|
||||||
|
|
||||||
|
PUSH_START_SHOULDER_LIFT = 0.0
|
||||||
|
PUSH_START_ELBOW_FLEX = 45.0
|
||||||
|
PUSH_END_SHOULDER_LIFT = 50.0
|
||||||
|
PUSH_END_ELBOW_FLEX = -50.0
|
||||||
|
# Subtracted from shoulder_lift during the push sweep to clear the platform surface.
|
||||||
|
# Does not affect the grab-target interpolation in record_grab.py.
|
||||||
|
PUSH_RAISE_OFFSET = 5.0
|
||||||
|
|
||||||
|
WRIST_HORIZONTAL_OFFSET = 0.0 # tune if gripper tilts during push: + tilts nose up, - down
|
||||||
|
GRIPPER_CLOSE_POS = 50.0
|
||||||
|
|
||||||
|
PLACE_LEFT_PAN_RANGE = (5.0, 30.0) # random pan range for cube placement on the left side
|
||||||
|
PLACE_REACH_RANGE = (0.1, 0.7) # 0 = arm retracted (PUSH_START), 1 = fully extended (PUSH_END)
|
||||||
|
|
||||||
|
JOINT_NAMES = [
|
||||||
|
"shoulder_pan.pos",
|
||||||
|
"shoulder_lift.pos",
|
||||||
|
"elbow_flex.pos",
|
||||||
|
"wrist_flex.pos",
|
||||||
|
"wrist_roll.pos",
|
||||||
|
"gripper.pos",
|
||||||
|
]
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def pose_to_array(pose: dict) -> np.ndarray:
|
||||||
|
return np.array([pose[k] for k in JOINT_NAMES])
|
||||||
|
|
||||||
|
|
||||||
|
def array_to_pose(arr: np.ndarray) -> dict:
|
||||||
|
return {k: float(arr[i]) for i, k in enumerate(JOINT_NAMES)}
|
||||||
|
|
||||||
|
|
||||||
|
def horizontal_wrist_flex(shoulder_lift: float, elbow_flex: float) -> float:
|
||||||
|
return WRIST_HORIZONTAL_OFFSET - shoulder_lift - elbow_flex
|
||||||
|
|
||||||
|
|
||||||
|
def _low_sweep_pose(pan: float, elbow_flex: float, wrist_flex: float | None = None) -> dict:
|
||||||
|
sl = SWEEP_LOW_SHOULDER_LIFT
|
||||||
|
return {
|
||||||
|
"shoulder_pan.pos": pan,
|
||||||
|
"shoulder_lift.pos": sl,
|
||||||
|
"elbow_flex.pos": elbow_flex,
|
||||||
|
"wrist_flex.pos": horizontal_wrist_flex(sl, elbow_flex) if wrist_flex is None else wrist_flex,
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": 60.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _high_sweep_pose(pan: float) -> dict:
|
||||||
|
return {**HOME_POSE, "shoulder_pan.pos": pan, "wrist_flex.pos": SWEEP_HIGH_WRIST_FLEX}
|
||||||
|
|
||||||
|
|
||||||
|
def _push_pose(shoulder_lift: float, elbow_flex: float, pan: float = GRAB_PAN, gripper: float = 70.0) -> dict:
|
||||||
|
return {
|
||||||
|
"shoulder_pan.pos": pan,
|
||||||
|
"shoulder_lift.pos": shoulder_lift,
|
||||||
|
"elbow_flex.pos": elbow_flex,
|
||||||
|
"wrist_flex.pos": horizontal_wrist_flex(shoulder_lift, elbow_flex),
|
||||||
|
"wrist_roll.pos": 0.0,
|
||||||
|
"gripper.pos": gripper,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def move_to_pose(robot: Robot, target: dict, speed: float) -> None:
|
||||||
|
"""Interpolate from current position to target at the given speed (units/s)."""
|
||||||
|
obs = robot.get_observation()
|
||||||
|
current = np.array([obs[k] for k in JOINT_NAMES])
|
||||||
|
goal = pose_to_array(target)
|
||||||
|
|
||||||
|
max_distance = float(np.max(np.abs(goal - current)))
|
||||||
|
if max_distance < 0.5:
|
||||||
|
return
|
||||||
|
|
||||||
|
n_steps = max(1, int(max_distance / speed * CONTROL_HZ))
|
||||||
|
dt = 1.0 / CONTROL_HZ
|
||||||
|
for step in range(1, n_steps + 1):
|
||||||
|
t = step / n_steps
|
||||||
|
robot.send_action(array_to_pose(current + t * (goal - current)))
|
||||||
|
precise_sleep(dt)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Sequences ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def grab_cube(robot: Robot) -> None:
|
||||||
|
"""Left sweep → right sweep → extend arm parallel to ground → close gripper."""
|
||||||
|
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
|
||||||
|
|
||||||
|
for pan, end_pan in [
|
||||||
|
(SWEEP_LEFT_PAN, GRAB_PAN - SWEEP_END_OFFSET),
|
||||||
|
(SWEEP_RIGHT_PAN, GRAB_PAN + SWEEP_END_OFFSET),
|
||||||
|
]:
|
||||||
|
logger.info(f"Sweeping {'left' if pan < 0 else 'right'} → center...")
|
||||||
|
move_to_pose(robot, _high_sweep_pose(pan), APPROACH_SPEED)
|
||||||
|
move_to_pose(
|
||||||
|
robot, _low_sweep_pose(pan, SWEEP_LOW_ELBOW_FLEX_START, wrist_flex=-20.0), APPROACH_SPEED
|
||||||
|
)
|
||||||
|
move_to_pose(robot, _low_sweep_pose(end_pan, SWEEP_LOW_ELBOW_FLEX_END, wrist_flex=0.0), SWEEP_SPEED)
|
||||||
|
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
|
||||||
|
|
||||||
|
logger.info("Extending to push cube into gripper...")
|
||||||
|
move_to_pose(
|
||||||
|
robot,
|
||||||
|
_push_pose(PUSH_START_SHOULDER_LIFT - PUSH_RAISE_OFFSET, PUSH_START_ELBOW_FLEX),
|
||||||
|
APPROACH_SPEED,
|
||||||
|
)
|
||||||
|
move_to_pose(
|
||||||
|
robot,
|
||||||
|
_push_pose(PUSH_END_SHOULDER_LIFT - PUSH_RAISE_OFFSET, PUSH_END_ELBOW_FLEX),
|
||||||
|
SWEEP_SPEED,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Closing gripper...")
|
||||||
|
move_to_pose(
|
||||||
|
robot,
|
||||||
|
_push_pose(PUSH_END_SHOULDER_LIFT, PUSH_END_ELBOW_FLEX, gripper=GRIPPER_CLOSE_POS),
|
||||||
|
APPROACH_SPEED,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Grab complete.")
|
||||||
|
|
||||||
|
|
||||||
|
def place_cube(robot: Robot) -> tuple[float, float]:
|
||||||
|
"""Carry the cube (gripper closed) to a random position on the left side, then release.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(pan, t): pan angle and reach scalar [0, 1] of the placement position.
|
||||||
|
"""
|
||||||
|
pan = float(np.random.uniform(*PLACE_LEFT_PAN_RANGE))
|
||||||
|
t = float(np.random.uniform(*PLACE_REACH_RANGE))
|
||||||
|
sl = PUSH_START_SHOULDER_LIFT + t * (PUSH_END_SHOULDER_LIFT - PUSH_START_SHOULDER_LIFT)
|
||||||
|
ef = PUSH_START_ELBOW_FLEX + t * (PUSH_END_ELBOW_FLEX - PUSH_START_ELBOW_FLEX)
|
||||||
|
logger.info(f"Placing cube at pan={pan:.1f}, reach={t:.2f}...")
|
||||||
|
|
||||||
|
move_to_pose(robot, {**HOME_POSE, "gripper.pos": GRIPPER_CLOSE_POS}, APPROACH_SPEED)
|
||||||
|
move_to_pose(
|
||||||
|
robot, {**HOME_POSE, "shoulder_pan.pos": pan, "gripper.pos": GRIPPER_CLOSE_POS}, APPROACH_SPEED
|
||||||
|
)
|
||||||
|
move_to_pose(robot, _push_pose(sl, ef, pan=pan, gripper=GRIPPER_CLOSE_POS), APPROACH_SPEED)
|
||||||
|
move_to_pose(robot, _push_pose(sl, ef, pan=pan, gripper=80.0), APPROACH_SPEED)
|
||||||
|
move_to_pose(robot, HOME_POSE, APPROACH_SPEED)
|
||||||
|
logger.info("Place complete.")
|
||||||
|
return pan, t
|
||||||
|
|
||||||
|
|
||||||
|
# ── Entry point ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="OMX arm reset / grab script")
|
||||||
|
parser.add_argument("--port", default="/dev/ttyACM1")
|
||||||
|
parser.add_argument("--robot_id", default="omx_follower")
|
||||||
|
parser.add_argument("--mode", choices=["grab", "grab_and_place"], default="grab_and_place")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||||
|
|
||||||
|
robot = OmxFollower(OmxFollowerConfig(port=args.port, id=args.robot_id))
|
||||||
|
robot.connect(calibrate=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if args.mode == "grab":
|
||||||
|
grab_cube(robot)
|
||||||
|
elif args.mode == "grab_and_place":
|
||||||
|
grab_cube(robot)
|
||||||
|
place_cube(robot)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
robot.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -14,13 +14,17 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||||
from lerobot.common.control_utils import init_keyboard_listener
|
from lerobot.common.control_utils import init_keyboard_listener, predict_action
|
||||||
from lerobot.configs import FeatureType, PolicyFeature
|
from lerobot.configs import FeatureType, PolicyFeature
|
||||||
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
|
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
|
||||||
from lerobot.model.kinematics import RobotKinematics
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
from lerobot.policies import make_pre_post_processors
|
from lerobot.policies import make_pre_post_processors
|
||||||
from lerobot.policies.act import ACTPolicy
|
from lerobot.policies.act import ACTPolicy
|
||||||
|
from lerobot.policies.utils import make_robot_action
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
RobotProcessorPipeline,
|
RobotProcessorPipeline,
|
||||||
make_default_teleop_action_processor,
|
make_default_teleop_action_processor,
|
||||||
@@ -34,11 +38,12 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
|||||||
ForwardKinematicsJointsToEE,
|
ForwardKinematicsJointsToEE,
|
||||||
InverseKinematicsEEToJoints,
|
InverseKinematicsEEToJoints,
|
||||||
)
|
)
|
||||||
from lerobot.scripts.lerobot_record import record_loop
|
|
||||||
from lerobot.types import RobotAction, RobotObservation
|
from lerobot.types import RobotAction, RobotObservation
|
||||||
from lerobot.utils.feature_utils import combine_feature_dicts
|
from lerobot.utils.constants import ACTION, OBS_STR
|
||||||
|
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||||
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
from lerobot.utils.utils import log_say
|
from lerobot.utils.utils import log_say
|
||||||
from lerobot.utils.visualization_utils import init_rerun
|
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||||
|
|
||||||
NUM_EPISODES = 5
|
NUM_EPISODES = 5
|
||||||
FPS = 30
|
FPS = 30
|
||||||
@@ -49,6 +54,9 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
|
||||||
|
# This script provides a self-contained example for educational purposes.
|
||||||
|
|
||||||
# Create the robot configuration & robot
|
# Create the robot configuration & robot
|
||||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||||
robot_config = SO100FollowerConfig(
|
robot_config = SO100FollowerConfig(
|
||||||
@@ -143,43 +151,67 @@ def main():
|
|||||||
raise ValueError("Robot is not connected!")
|
raise ValueError("Robot is not connected!")
|
||||||
|
|
||||||
print("Starting evaluate loop...")
|
print("Starting evaluate loop...")
|
||||||
|
control_interval = 1 / FPS
|
||||||
episode_idx = 0
|
episode_idx = 0
|
||||||
for episode_idx in range(NUM_EPISODES):
|
for episode_idx in range(NUM_EPISODES):
|
||||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||||
|
|
||||||
# Main record loop
|
# Inline evaluation loop: predict actions and send to robot
|
||||||
record_loop(
|
timestamp = 0
|
||||||
robot=robot,
|
start_episode_t = time.perf_counter()
|
||||||
events=events,
|
while timestamp < EPISODE_TIME_SEC:
|
||||||
fps=FPS,
|
start_loop_t = time.perf_counter()
|
||||||
policy=policy,
|
|
||||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
if events["exit_early"]:
|
||||||
postprocessor=postprocessor,
|
events["exit_early"] = False
|
||||||
dataset=dataset,
|
break
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
|
||||||
single_task=TASK_DESCRIPTION,
|
# Get robot observation
|
||||||
display_data=True,
|
obs = robot.get_observation()
|
||||||
teleop_action_processor=make_default_teleop_action_processor(),
|
obs_processed = robot_joints_to_ee_pose_processor(obs)
|
||||||
robot_action_processor=robot_ee_to_joints_processor,
|
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
|
||||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
|
||||||
)
|
# Predict action using the policy
|
||||||
|
action_tensor = predict_action(
|
||||||
|
observation=observation_frame,
|
||||||
|
policy=policy,
|
||||||
|
device=policy.config.device,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
use_amp=policy.config.device.type == "cuda",
|
||||||
|
task=TASK_DESCRIPTION,
|
||||||
|
robot_type=robot.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert policy output to robot action dict
|
||||||
|
action_values = make_robot_action(action_tensor, dataset.features)
|
||||||
|
|
||||||
|
# Process and send action to robot (EE -> joints via IK)
|
||||||
|
robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
|
||||||
|
robot.send_action(robot_action_to_send)
|
||||||
|
|
||||||
|
# Write to dataset
|
||||||
|
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||||
|
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
|
||||||
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
|
log_rerun_data(observation=obs_processed, action=action_values)
|
||||||
|
|
||||||
|
dt_s = time.perf_counter() - start_loop_t
|
||||||
|
sleep_time_s = control_interval - dt_s
|
||||||
|
if sleep_time_s < 0:
|
||||||
|
logging.warning(
|
||||||
|
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
|
||||||
|
)
|
||||||
|
precise_sleep(max(sleep_time_s, 0.0))
|
||||||
|
timestamp = time.perf_counter() - start_episode_t
|
||||||
|
|
||||||
# Reset the environment if not stopping or re-recording
|
# Reset the environment if not stopping or re-recording
|
||||||
if not events["stop_recording"] and (
|
if not events["stop_recording"] and (
|
||||||
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||||
):
|
):
|
||||||
log_say("Reset the environment")
|
log_say("Reset the environment")
|
||||||
record_loop(
|
log_say("Waiting for environment reset, press right arrow key when ready...")
|
||||||
robot=robot,
|
|
||||||
events=events,
|
|
||||||
fps=FPS,
|
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
|
||||||
single_task=TASK_DESCRIPTION,
|
|
||||||
display_data=True,
|
|
||||||
teleop_action_processor=make_default_teleop_action_processor(),
|
|
||||||
robot_action_processor=robot_ee_to_joints_processor,
|
|
||||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
if events["rerecord_episode"]:
|
if events["rerecord_episode"]:
|
||||||
log_say("Re-record episode")
|
log_say("Re-record episode")
|
||||||
@@ -190,7 +222,6 @@ def main():
|
|||||||
|
|
||||||
# Save episode
|
# Save episode
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
episode_idx += 1
|
|
||||||
finally:
|
finally:
|
||||||
# Clean up
|
# Clean up
|
||||||
log_say("Stop recording")
|
log_say("Stop recording")
|
||||||
|
|||||||
@@ -65,14 +65,15 @@ def main():
|
|||||||
robot = SO100Follower(robot_config)
|
robot = SO100Follower(robot_config)
|
||||||
phone = Phone(teleop_config)
|
phone = Phone(teleop_config)
|
||||||
|
|
||||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||||
|
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||||
kinematics_solver = RobotKinematics(
|
kinematics_solver = RobotKinematics(
|
||||||
urdf_path="./SO101/so101_new_calib.urdf",
|
urdf_path="./SO101/so101_new_calib.urdf",
|
||||||
target_frame_name="gripper_frame_link",
|
target_frame_name="gripper_frame_link",
|
||||||
joint_names=list(robot.bus.motors.keys()),
|
joint_names=list(robot.bus.motors.keys()),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert phone action to EE action
|
# Build pipeline to convert phone action to EE action (with gripper velocity mapped to joint).
|
||||||
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[
|
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[
|
||||||
tuple[RobotAction, RobotObservation], RobotAction
|
tuple[RobotAction, RobotObservation], RobotAction
|
||||||
](
|
](
|
||||||
@@ -94,7 +95,7 @@ def main():
|
|||||||
to_output=transition_to_robot_action,
|
to_output=transition_to_robot_action,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert EE action to joints action
|
# Build pipeline to convert EE action to joints action (IK).
|
||||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||||
steps=[
|
steps=[
|
||||||
InverseKinematicsEEToJoints(
|
InverseKinematicsEEToJoints(
|
||||||
@@ -107,7 +108,7 @@ def main():
|
|||||||
to_output=transition_to_robot_action,
|
to_output=transition_to_robot_action,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert joint observation to EE observation
|
# Build pipeline to convert joint observation to EE observation (FK).
|
||||||
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||||
steps=[
|
steps=[
|
||||||
ForwardKinematicsJointsToEE(
|
ForwardKinematicsJointsToEE(
|
||||||
@@ -118,13 +119,12 @@ def main():
|
|||||||
to_output=transition_to_observation,
|
to_output=transition_to_observation,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create the dataset
|
# Create the dataset, deriving features from the pipelines so the on-disk schema
|
||||||
|
# matches exactly what the pipelines produce at runtime.
|
||||||
dataset = LeRobotDataset.create(
|
dataset = LeRobotDataset.create(
|
||||||
repo_id=HF_REPO_ID,
|
repo_id=HF_REPO_ID,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
features=combine_feature_dicts(
|
features=combine_feature_dicts(
|
||||||
# Run the feature contract of the pipelines
|
|
||||||
# This tells you how the features would look like after the pipeline steps
|
|
||||||
aggregate_pipeline_dataset_features(
|
aggregate_pipeline_dataset_features(
|
||||||
pipeline=phone_to_robot_ee_pose_processor,
|
pipeline=phone_to_robot_ee_pose_processor,
|
||||||
initial_features=create_initial_features(action=phone.action_features),
|
initial_features=create_initial_features(action=phone.action_features),
|
||||||
@@ -163,14 +163,14 @@ def main():
|
|||||||
robot=robot,
|
robot=robot,
|
||||||
events=events,
|
events=events,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
|
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||||
|
robot_action_processor=robot_ee_to_joints_processor,
|
||||||
|
robot_observation_processor=robot_joints_to_ee_pose,
|
||||||
teleop=phone,
|
teleop=phone,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
control_time_s=EPISODE_TIME_SEC,
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
display_data=True,
|
display_data=True,
|
||||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
|
||||||
robot_action_processor=robot_ee_to_joints_processor,
|
|
||||||
robot_observation_processor=robot_joints_to_ee_pose,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reset the environment if not stopping or re-recording
|
# Reset the environment if not stopping or re-recording
|
||||||
@@ -182,13 +182,13 @@ def main():
|
|||||||
robot=robot,
|
robot=robot,
|
||||||
events=events,
|
events=events,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
|
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||||
|
robot_action_processor=robot_ee_to_joints_processor,
|
||||||
|
robot_observation_processor=robot_joints_to_ee_pose,
|
||||||
teleop=phone,
|
teleop=phone,
|
||||||
control_time_s=RESET_TIME_SEC,
|
control_time_s=RESET_TIME_SEC,
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
display_data=True,
|
display_data=True,
|
||||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
|
||||||
robot_action_processor=robot_ee_to_joints_processor,
|
|
||||||
robot_observation_processor=robot_joints_to_ee_pose,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if events["rerecord_episode"]:
|
if events["rerecord_episode"]:
|
||||||
|
|||||||
126
examples/phone_to_so100/rollout.py
Normal file
126
examples/phone_to_so100/rollout.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
# !/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Run a trained EE-space policy on SO100 (phone-trained) without recording.
|
||||||
|
|
||||||
|
Mirrors ``examples/so100_to_so100_EE/rollout.py`` — the model was trained
|
||||||
|
with phone teleoperation in EE space, so at deployment we only need the
|
||||||
|
joint↔EE conversion on the robot side; the phone is not used.
|
||||||
|
|
||||||
|
Uses :class:`BaseStrategy` (no recording) + :class:`SyncInferenceConfig`
|
||||||
|
(inline policy call). For recording during rollout, switch to Sentry,
|
||||||
|
Highlight, or DAgger via ``lerobot-rollout --strategy.type=...``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||||
|
from lerobot.configs import PreTrainedConfig
|
||||||
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
|
from lerobot.processor import (
|
||||||
|
RobotProcessorPipeline,
|
||||||
|
observation_to_transition,
|
||||||
|
robot_action_observation_to_transition,
|
||||||
|
transition_to_observation,
|
||||||
|
transition_to_robot_action,
|
||||||
|
)
|
||||||
|
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||||
|
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||||
|
ForwardKinematicsJointsToEE,
|
||||||
|
InverseKinematicsEEToJoints,
|
||||||
|
)
|
||||||
|
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||||
|
from lerobot.rollout.inference import SyncInferenceConfig
|
||||||
|
from lerobot.rollout.strategies import BaseStrategy
|
||||||
|
from lerobot.types import RobotAction, RobotObservation
|
||||||
|
from lerobot.utils.process import ProcessSignalHandler
|
||||||
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
|
FPS = 30
|
||||||
|
DURATION_SEC = 60
|
||||||
|
TASK_DESCRIPTION = "My task description"
|
||||||
|
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
init_logging()
|
||||||
|
|
||||||
|
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||||
|
robot_config = SO100FollowerConfig(
|
||||||
|
port="/dev/tty.usbmodem58760434471",
|
||||||
|
id="my_awesome_follower_arm",
|
||||||
|
cameras=camera_config,
|
||||||
|
use_degrees=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Peek at motor names once to build the kinematic solver.
|
||||||
|
temp_robot = SO100Follower(robot_config)
|
||||||
|
motor_names = list(temp_robot.bus.motors.keys())
|
||||||
|
|
||||||
|
kinematics_solver = RobotKinematics(
|
||||||
|
urdf_path="./SO101/so101_new_calib.urdf",
|
||||||
|
target_frame_name="gripper_frame_link",
|
||||||
|
joint_names=motor_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||||
|
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
|
||||||
|
to_transition=observation_to_transition,
|
||||||
|
to_output=transition_to_observation,
|
||||||
|
)
|
||||||
|
|
||||||
|
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||||
|
steps=[
|
||||||
|
InverseKinematicsEEToJoints(
|
||||||
|
kinematics=kinematics_solver,
|
||||||
|
motor_names=motor_names,
|
||||||
|
initial_guess_current_joints=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
to_transition=robot_action_observation_to_transition,
|
||||||
|
to_output=transition_to_robot_action,
|
||||||
|
)
|
||||||
|
|
||||||
|
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||||
|
policy_config.pretrained_path = HF_MODEL_ID
|
||||||
|
|
||||||
|
cfg = RolloutConfig(
|
||||||
|
robot=robot_config,
|
||||||
|
policy=policy_config,
|
||||||
|
strategy=BaseStrategyConfig(),
|
||||||
|
inference=SyncInferenceConfig(),
|
||||||
|
fps=FPS,
|
||||||
|
duration=DURATION_SEC,
|
||||||
|
task=TASK_DESCRIPTION,
|
||||||
|
)
|
||||||
|
|
||||||
|
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||||
|
|
||||||
|
ctx = build_rollout_context(
|
||||||
|
cfg,
|
||||||
|
signal_handler.shutdown_event,
|
||||||
|
robot_action_processor=robot_ee_to_joints_processor,
|
||||||
|
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||||
|
)
|
||||||
|
|
||||||
|
strategy = BaseStrategy(cfg.strategy)
|
||||||
|
try:
|
||||||
|
strategy.setup(ctx)
|
||||||
|
strategy.run(ctx)
|
||||||
|
finally:
|
||||||
|
strategy.teardown(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,673 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""
|
|
||||||
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
|
|
||||||
|
|
||||||
This script demonstrates:
|
|
||||||
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
|
|
||||||
2. Consuming actions from the policy while the robot executes
|
|
||||||
3. Periodically requesting new action chunks in the background using threads
|
|
||||||
4. Managing action buffers and timing for real-time operation
|
|
||||||
|
|
||||||
For simulation environments, see eval_with_simulation.py
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
# Run RTC with Real robot with RTC
|
|
||||||
uv run examples/rtc/eval_with_real_robot.py \
|
|
||||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
|
||||||
--policy.device=mps \
|
|
||||||
--rtc.enabled=true \
|
|
||||||
--rtc.execution_horizon=20 \
|
|
||||||
--robot.type=so100_follower \
|
|
||||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
|
||||||
--robot.id=so100_follower \
|
|
||||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
|
||||||
--task="Move green small object into the purple platform" \
|
|
||||||
--duration=120
|
|
||||||
|
|
||||||
# Run RTC with Real robot without RTC
|
|
||||||
uv run examples/rtc/eval_with_real_robot.py \
|
|
||||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
|
||||||
--policy.device=mps \
|
|
||||||
--rtc.enabled=false \
|
|
||||||
--robot.type=so100_follower \
|
|
||||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
|
||||||
--robot.id=so100_follower \
|
|
||||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
|
||||||
--task="Move green small object into the purple platform" \
|
|
||||||
--duration=120
|
|
||||||
|
|
||||||
# Run RTC with Real robot with pi0.5 policy
|
|
||||||
uv run examples/rtc/eval_with_real_robot.py \
|
|
||||||
--policy.path=<USER>/pi05_check_rtc \
|
|
||||||
--policy.device=mps \
|
|
||||||
--rtc.enabled=true \
|
|
||||||
--rtc.execution_horizon=20 \
|
|
||||||
--robot.type=so100_follower \
|
|
||||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
|
||||||
--robot.id=so100_follower \
|
|
||||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
|
|
||||||
--task="Move green small object into the purple platform" \
|
|
||||||
--duration=120
|
|
||||||
|
|
||||||
# Run RTC with bi_openarm_follower (dual-arm OpenArms) and pi0.5 policy
|
|
||||||
python examples/rtc/eval_with_real_robot.py \
|
|
||||||
--policy.path=lerobot-data-collection/folding_final \
|
|
||||||
--robot.type=bi_openarm_follower \
|
|
||||||
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \
|
|
||||||
--robot.left_arm_config.port=can0 \
|
|
||||||
--robot.left_arm_config.side=left \
|
|
||||||
--robot.left_arm_config.can_interface=socketcan \
|
|
||||||
--robot.left_arm_config.disable_torque_on_disconnect=true \
|
|
||||||
--robot.left_arm_config.max_relative_target=8.0 \
|
|
||||||
--robot.right_arm_config.port=can1 \
|
|
||||||
--robot.right_arm_config.side=right \
|
|
||||||
--robot.right_arm_config.can_interface=socketcan \
|
|
||||||
--robot.right_arm_config.disable_torque_on_disconnect=true \
|
|
||||||
--robot.right_arm_config.max_relative_target=8.0 \
|
|
||||||
--task="Fold the T-shirt properly" \
|
|
||||||
--fps=30 \
|
|
||||||
--duration=2000 \
|
|
||||||
--interpolation_multiplier=3 \
|
|
||||||
--rtc.enabled=true \
|
|
||||||
--rtc.execution_horizon=20 \
|
|
||||||
--rtc.max_guidance_weight=5.0 \
|
|
||||||
--rtc.prefix_attention_schedule=LINEAR \
|
|
||||||
--device=cuda
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from threading import Event, Lock, Thread
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
|
||||||
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
|
|
||||||
from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401
|
|
||||||
from lerobot.configs import PreTrainedConfig, RTCAttentionSchedule, parser
|
|
||||||
from lerobot.policies import get_policy_class, make_pre_post_processors
|
|
||||||
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
|
|
||||||
from lerobot.processor import (
|
|
||||||
NormalizerProcessorStep,
|
|
||||||
RelativeActionsProcessorStep,
|
|
||||||
TransitionKey,
|
|
||||||
create_transition,
|
|
||||||
make_default_robot_action_processor,
|
|
||||||
make_default_robot_observation_processor,
|
|
||||||
to_relative_actions,
|
|
||||||
)
|
|
||||||
from lerobot.rl.process import ProcessSignalHandler
|
|
||||||
from lerobot.robots import ( # noqa: F401
|
|
||||||
Robot,
|
|
||||||
RobotConfig,
|
|
||||||
bi_openarm_follower,
|
|
||||||
bi_so_follower,
|
|
||||||
koch_follower,
|
|
||||||
so_follower,
|
|
||||||
unitree_g1,
|
|
||||||
)
|
|
||||||
from lerobot.robots.utils import make_robot_from_config
|
|
||||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
|
||||||
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
|
||||||
from lerobot.utils.hub import HubMixin
|
|
||||||
from lerobot.utils.utils import init_logging
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class RobotWrapper:
|
|
||||||
def __init__(self, robot: Robot):
|
|
||||||
self.robot = robot
|
|
||||||
self.lock = Lock()
|
|
||||||
|
|
||||||
def get_observation(self) -> dict[str, Tensor]:
|
|
||||||
with self.lock:
|
|
||||||
return self.robot.get_observation()
|
|
||||||
|
|
||||||
def send_action(self, action: Tensor):
|
|
||||||
with self.lock:
|
|
||||||
self.robot.send_action(action)
|
|
||||||
|
|
||||||
def observation_features(self) -> list[str]:
|
|
||||||
with self.lock:
|
|
||||||
return self.robot.observation_features
|
|
||||||
|
|
||||||
def action_features(self) -> list[str]:
|
|
||||||
with self.lock:
|
|
||||||
return self.robot.action_features
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RTCDemoConfig(HubMixin):
|
|
||||||
"""Configuration for RTC demo with action chunking policies and real robots."""
|
|
||||||
|
|
||||||
# Policy configuration
|
|
||||||
policy: PreTrainedConfig | None = None
|
|
||||||
|
|
||||||
# Robot configuration
|
|
||||||
robot: RobotConfig | None = None
|
|
||||||
|
|
||||||
# RTC configuration
|
|
||||||
rtc: RTCConfig = field(
|
|
||||||
default_factory=lambda: RTCConfig(
|
|
||||||
execution_horizon=10,
|
|
||||||
max_guidance_weight=1.0,
|
|
||||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Demo parameters
|
|
||||||
duration: float = 30.0 # Duration to run the demo (seconds)
|
|
||||||
fps: float = 10.0 # Action execution frequency (Hz)
|
|
||||||
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
|
|
||||||
|
|
||||||
# Compute device
|
|
||||||
device: str | None = None # Device to run on (cuda, cpu, auto)
|
|
||||||
|
|
||||||
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
|
|
||||||
# It should be higher than inference delay + execution horizon.
|
|
||||||
action_queue_size_to_get_new_actions: int = 30
|
|
||||||
|
|
||||||
# Task to execute
|
|
||||||
task: str = field(default="", metadata={"help": "Task to execute"})
|
|
||||||
|
|
||||||
# Torch compile configuration
|
|
||||||
use_torch_compile: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_compile_backend: str = field(
|
|
||||||
default="inductor",
|
|
||||||
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_compile_mode: str = field(
|
|
||||||
default="default",
|
|
||||||
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
|
|
||||||
)
|
|
||||||
|
|
||||||
torch_compile_disable_cudagraphs: bool = field(
|
|
||||||
default=True,
|
|
||||||
metadata={
|
|
||||||
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
|
|
||||||
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
|
||||||
policy_path = parser.get_path_arg("policy")
|
|
||||||
if policy_path:
|
|
||||||
cli_overrides = parser.get_cli_overrides("policy")
|
|
||||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
|
||||||
self.policy.pretrained_path = policy_path
|
|
||||||
else:
|
|
||||||
raise ValueError("Policy path is required")
|
|
||||||
|
|
||||||
# Validate that robot configuration is provided
|
|
||||||
if self.robot is None:
|
|
||||||
raise ValueError("Robot configuration must be provided")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __get_path_fields__(cls) -> list[str]:
|
|
||||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
|
||||||
return ["policy"]
|
|
||||||
|
|
||||||
|
|
||||||
def is_image_key(k: str) -> bool:
|
|
||||||
return k.startswith(OBS_IMAGES)
|
|
||||||
|
|
||||||
|
|
||||||
def _reanchor_relative_rtc_prefix(
|
|
||||||
prev_actions_absolute: Tensor,
|
|
||||||
current_state: Tensor,
|
|
||||||
relative_step: RelativeActionsProcessorStep,
|
|
||||||
normalizer_step: NormalizerProcessorStep | None,
|
|
||||||
policy_device: torch.device | str,
|
|
||||||
) -> Tensor:
|
|
||||||
"""Convert absolute leftovers into model-space for relative-action RTC policies.
|
|
||||||
|
|
||||||
When a policy uses relative actions, the RTC prefix (leftover actions from
|
|
||||||
the previous chunk) is stored in absolute space. Before feeding it back to
|
|
||||||
the policy we need to re-express it relative to the *current* robot state
|
|
||||||
and then re-normalize.
|
|
||||||
"""
|
|
||||||
state = current_state.detach().cpu()
|
|
||||||
if state.dim() == 1:
|
|
||||||
state = state.unsqueeze(0)
|
|
||||||
|
|
||||||
action_cpu = prev_actions_absolute.detach().cpu()
|
|
||||||
mask = relative_step._build_mask(action_cpu.shape[-1])
|
|
||||||
relative_actions = to_relative_actions(action_cpu, state, mask)
|
|
||||||
|
|
||||||
transition = create_transition(action=relative_actions)
|
|
||||||
if normalizer_step is not None:
|
|
||||||
transition = normalizer_step(transition)
|
|
||||||
|
|
||||||
return transition[TransitionKey.ACTION].to(policy_device)
|
|
||||||
|
|
||||||
|
|
||||||
def get_actions(
|
|
||||||
policy,
|
|
||||||
robot: RobotWrapper,
|
|
||||||
robot_observation_processor,
|
|
||||||
action_queue: ActionQueue,
|
|
||||||
shutdown_event: Event,
|
|
||||||
cfg: RTCDemoConfig,
|
|
||||||
):
|
|
||||||
"""Thread function to request action chunks from the policy.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
policy: The policy instance (SmolVLA, Pi0, etc.)
|
|
||||||
robot: The robot instance for getting observations
|
|
||||||
robot_observation_processor: Processor for raw robot observations
|
|
||||||
action_queue: Queue to put new action chunks
|
|
||||||
shutdown_event: Event to signal shutdown
|
|
||||||
cfg: Demo configuration
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info("[GET_ACTIONS] Starting get actions thread")
|
|
||||||
|
|
||||||
latency_tracker = LatencyTracker() # Track latency of action chunks
|
|
||||||
fps = cfg.fps
|
|
||||||
time_per_chunk = 1.0 / fps
|
|
||||||
|
|
||||||
# Only keep .pos joints + camera streams if the policy was trained on positions,
|
|
||||||
# not the full pos/vel/torque state the robot exposes.
|
|
||||||
observation_features_hw = {
|
|
||||||
key: value
|
|
||||||
for key, value in robot.observation_features().items()
|
|
||||||
if key.endswith(".pos") or isinstance(value, tuple)
|
|
||||||
}
|
|
||||||
|
|
||||||
dataset_features = hw_to_dataset_features(observation_features_hw, "observation")
|
|
||||||
policy_device = policy.config.device
|
|
||||||
|
|
||||||
# Load preprocessor and postprocessor from pretrained files
|
|
||||||
# The stats are embedded in the processor .safetensors files
|
|
||||||
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
|
|
||||||
|
|
||||||
preprocessor, postprocessor = make_pre_post_processors(
|
|
||||||
policy_cfg=cfg.policy,
|
|
||||||
pretrained_path=cfg.policy.pretrained_path,
|
|
||||||
dataset_stats=None, # Will load from pretrained processor files
|
|
||||||
preprocessor_overrides={
|
|
||||||
"device_processor": {"device": cfg.policy.device},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
|
|
||||||
|
|
||||||
relative_step = next(
|
|
||||||
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
normalizer_step = next(
|
|
||||||
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
if relative_step is not None:
|
|
||||||
if relative_step.action_names is None:
|
|
||||||
cfg_names = getattr(cfg.policy, "action_feature_names", None)
|
|
||||||
if cfg_names:
|
|
||||||
relative_step.action_names = list(cfg_names)
|
|
||||||
else:
|
|
||||||
relative_step.action_names = [
|
|
||||||
k for k in robot.robot.action_features if k.endswith(".pos")
|
|
||||||
]
|
|
||||||
logger.info("[GET_ACTIONS] Relative actions enabled: will re-anchor RTC prefix")
|
|
||||||
|
|
||||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
|
||||||
|
|
||||||
if not cfg.rtc.enabled:
|
|
||||||
get_actions_threshold = 0
|
|
||||||
|
|
||||||
while not shutdown_event.is_set():
|
|
||||||
if action_queue.qsize() <= get_actions_threshold:
|
|
||||||
current_time = time.perf_counter()
|
|
||||||
action_index_before_inference = action_queue.get_action_index()
|
|
||||||
prev_actions = action_queue.get_left_over()
|
|
||||||
|
|
||||||
inference_latency = latency_tracker.max()
|
|
||||||
inference_delay = math.ceil(inference_latency / time_per_chunk)
|
|
||||||
|
|
||||||
obs = robot.get_observation()
|
|
||||||
|
|
||||||
# Apply robot observation processor
|
|
||||||
obs_processed = robot_observation_processor(obs)
|
|
||||||
|
|
||||||
obs_with_policy_features = build_dataset_frame(
|
|
||||||
dataset_features, obs_processed, prefix="observation"
|
|
||||||
)
|
|
||||||
|
|
||||||
for name in obs_with_policy_features:
|
|
||||||
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
|
|
||||||
if "image" in name:
|
|
||||||
obs_with_policy_features[name] = (
|
|
||||||
obs_with_policy_features[name].type(torch.float32) / 255
|
|
||||||
)
|
|
||||||
obs_with_policy_features[name] = (
|
|
||||||
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
|
|
||||||
)
|
|
||||||
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
|
|
||||||
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
|
|
||||||
|
|
||||||
obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string!
|
|
||||||
obs_with_policy_features["robot_type"] = (
|
|
||||||
robot.robot.name if hasattr(robot.robot, "name") else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
preproceseded_obs = preprocessor(obs_with_policy_features)
|
|
||||||
|
|
||||||
# Re-anchor leftover actions for relative-action policies.
|
|
||||||
# We need the *postprocessed* (absolute) leftover, not the original
|
|
||||||
# (normalized/relative) one that get_left_over() returns.
|
|
||||||
if (
|
|
||||||
prev_actions is not None
|
|
||||||
and relative_step is not None
|
|
||||||
and OBS_STATE in obs_with_policy_features
|
|
||||||
):
|
|
||||||
with action_queue.lock:
|
|
||||||
if action_queue.queue is not None:
|
|
||||||
prev_actions_abs = action_queue.queue[action_queue.last_index :].clone()
|
|
||||||
else:
|
|
||||||
prev_actions_abs = None
|
|
||||||
if prev_actions_abs is not None and prev_actions_abs.numel() > 0:
|
|
||||||
prev_actions = _reanchor_relative_rtc_prefix(
|
|
||||||
prev_actions_absolute=prev_actions_abs,
|
|
||||||
current_state=obs_with_policy_features[OBS_STATE],
|
|
||||||
relative_step=relative_step,
|
|
||||||
normalizer_step=normalizer_step,
|
|
||||||
policy_device=policy_device,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate actions WITH RTC
|
|
||||||
actions = policy.predict_action_chunk(
|
|
||||||
preproceseded_obs,
|
|
||||||
inference_delay=inference_delay,
|
|
||||||
prev_chunk_left_over=prev_actions,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store original actions (before postprocessing) for RTC
|
|
||||||
original_actions = actions.squeeze(0).clone()
|
|
||||||
|
|
||||||
postprocessed_actions = postprocessor(actions)
|
|
||||||
|
|
||||||
postprocessed_actions = postprocessed_actions.squeeze(0)
|
|
||||||
|
|
||||||
new_latency = time.perf_counter() - current_time
|
|
||||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
|
||||||
latency_tracker.add(new_latency)
|
|
||||||
|
|
||||||
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
|
|
||||||
logger.warning(
|
|
||||||
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
|
|
||||||
)
|
|
||||||
|
|
||||||
action_queue.merge(
|
|
||||||
original_actions, postprocessed_actions, new_delay, action_index_before_inference
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Small sleep to prevent busy waiting
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
logger.info("[GET_ACTIONS] get actions thread shutting down")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def actor_control(
|
|
||||||
robot: RobotWrapper,
|
|
||||||
robot_action_processor,
|
|
||||||
action_queue: ActionQueue,
|
|
||||||
shutdown_event: Event,
|
|
||||||
cfg: RTCDemoConfig,
|
|
||||||
):
|
|
||||||
"""Thread function to execute actions on the robot.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
robot: The robot instance
|
|
||||||
action_queue: Queue to get actions from
|
|
||||||
shutdown_event: Event to signal shutdown
|
|
||||||
cfg: Demo configuration
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info("[ACTOR] Starting actor thread")
|
|
||||||
|
|
||||||
action_keys = [k for k in robot.action_features() if k.endswith(".pos")]
|
|
||||||
|
|
||||||
action_count = 0
|
|
||||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
|
||||||
action_interval = interpolator.get_control_interval(cfg.fps)
|
|
||||||
|
|
||||||
while not shutdown_event.is_set():
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
|
|
||||||
if interpolator.needs_new_action():
|
|
||||||
new_action = action_queue.get()
|
|
||||||
if new_action is not None:
|
|
||||||
interpolator.add(new_action.cpu())
|
|
||||||
|
|
||||||
action = interpolator.get()
|
|
||||||
if action is not None:
|
|
||||||
action = action.cpu()
|
|
||||||
action_dict = {key: action[i].item() for i, key in enumerate(action_keys)}
|
|
||||||
action_processed = robot_action_processor((action_dict, None))
|
|
||||||
robot.send_action(action_processed)
|
|
||||||
action_count += 1
|
|
||||||
|
|
||||||
dt_s = time.perf_counter() - start_time
|
|
||||||
time.sleep(max(0, (action_interval - dt_s) - 0.001))
|
|
||||||
|
|
||||||
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
|
|
||||||
"""Apply torch.compile to the policy's predict_action_chunk method.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
policy: Policy instance to compile
|
|
||||||
cfg: Configuration containing torch compile settings
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Policy with compiled predict_action_chunk method
|
|
||||||
"""
|
|
||||||
|
|
||||||
# PI models handle their own compilation
|
|
||||||
if policy.type == "pi05" or policy.type == "pi0":
|
|
||||||
return policy
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Check if torch.compile is available (PyTorch 2.0+)
|
|
||||||
if not hasattr(torch, "compile"):
|
|
||||||
logger.warning(
|
|
||||||
f"torch.compile is not available. Requires PyTorch 2.0+. "
|
|
||||||
f"Current version: {torch.__version__}. Skipping compilation."
|
|
||||||
)
|
|
||||||
return policy
|
|
||||||
|
|
||||||
logger.info("Applying torch.compile to predict_action_chunk...")
|
|
||||||
logger.info(f" Backend: {cfg.torch_compile_backend}")
|
|
||||||
logger.info(f" Mode: {cfg.torch_compile_mode}")
|
|
||||||
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
|
|
||||||
|
|
||||||
# Compile the predict_action_chunk method
|
|
||||||
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
|
|
||||||
compile_kwargs = {
|
|
||||||
"backend": cfg.torch_compile_backend,
|
|
||||||
"mode": cfg.torch_compile_mode,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
|
|
||||||
if cfg.torch_compile_disable_cudagraphs:
|
|
||||||
compile_kwargs["options"] = {"triton.cudagraphs": False}
|
|
||||||
|
|
||||||
original_method = policy.predict_action_chunk
|
|
||||||
compiled_method = torch.compile(original_method, **compile_kwargs)
|
|
||||||
policy.predict_action_chunk = compiled_method
|
|
||||||
logger.info("✓ Successfully compiled predict_action_chunk")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to apply torch.compile: {e}")
|
|
||||||
logger.warning("Continuing without torch.compile")
|
|
||||||
|
|
||||||
return policy
|
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
|
||||||
def demo_cli(cfg: RTCDemoConfig):
|
|
||||||
"""Main entry point for RTC demo with draccus configuration."""
|
|
||||||
|
|
||||||
# Initialize logging
|
|
||||||
init_logging()
|
|
||||||
|
|
||||||
logger.info(f"Using device: {cfg.device}")
|
|
||||||
|
|
||||||
# Setup signal handler for graceful shutdown
|
|
||||||
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
|
|
||||||
shutdown_event = signal_handler.shutdown_event
|
|
||||||
|
|
||||||
policy = None
|
|
||||||
robot = None
|
|
||||||
get_actions_thread = None
|
|
||||||
actor_thread = None
|
|
||||||
|
|
||||||
policy_class = get_policy_class(cfg.policy.type)
|
|
||||||
|
|
||||||
# Load config and set compile_model for pi0/pi05 models
|
|
||||||
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
|
||||||
|
|
||||||
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
|
|
||||||
config.compile_model = cfg.use_torch_compile
|
|
||||||
|
|
||||||
if config.use_peft:
|
|
||||||
from peft import PeftConfig, PeftModel
|
|
||||||
|
|
||||||
peft_pretrained_path = cfg.policy.pretrained_path
|
|
||||||
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
|
||||||
|
|
||||||
policy = policy_class.from_pretrained(
|
|
||||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
|
|
||||||
)
|
|
||||||
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
|
|
||||||
else:
|
|
||||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
|
||||||
|
|
||||||
# Turn on RTC
|
|
||||||
policy.config.rtc_config = cfg.rtc
|
|
||||||
|
|
||||||
# Init RTC processort, as by default if RTC disabled in the config
|
|
||||||
# The processor won't be created
|
|
||||||
policy.init_rtc_processor()
|
|
||||||
|
|
||||||
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
|
|
||||||
|
|
||||||
policy = policy.to(cfg.device)
|
|
||||||
policy.eval()
|
|
||||||
|
|
||||||
# Apply torch.compile to predict_action_chunk method if enabled
|
|
||||||
if cfg.use_torch_compile:
|
|
||||||
policy = _apply_torch_compile(policy, cfg)
|
|
||||||
|
|
||||||
# Create robot
|
|
||||||
logger.info(f"Initializing robot: {cfg.robot.type}")
|
|
||||||
robot = make_robot_from_config(cfg.robot)
|
|
||||||
robot.connect()
|
|
||||||
robot_wrapper = RobotWrapper(robot)
|
|
||||||
|
|
||||||
# Create robot observation processor
|
|
||||||
robot_observation_processor = make_default_robot_observation_processor()
|
|
||||||
robot_action_processor = make_default_robot_action_processor()
|
|
||||||
|
|
||||||
# Create action queue for communication between threads
|
|
||||||
action_queue = ActionQueue(cfg.rtc)
|
|
||||||
|
|
||||||
# Start chunk requester thread
|
|
||||||
get_actions_thread = Thread(
|
|
||||||
target=get_actions,
|
|
||||||
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
|
|
||||||
daemon=True,
|
|
||||||
name="GetActions",
|
|
||||||
)
|
|
||||||
get_actions_thread.start()
|
|
||||||
logger.info("Started get actions thread")
|
|
||||||
|
|
||||||
# Start action executor thread
|
|
||||||
actor_thread = Thread(
|
|
||||||
target=actor_control,
|
|
||||||
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
|
|
||||||
daemon=True,
|
|
||||||
name="Actor",
|
|
||||||
)
|
|
||||||
actor_thread.start()
|
|
||||||
logger.info("Started actor thread")
|
|
||||||
|
|
||||||
logger.info("Started stop by duration thread")
|
|
||||||
|
|
||||||
# Main thread monitors for duration or shutdown
|
|
||||||
logger.info(f"Running demo for {cfg.duration} seconds...")
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
|
|
||||||
time.sleep(10)
|
|
||||||
|
|
||||||
# Log queue status periodically
|
|
||||||
if int(time.time() - start_time) % 5 == 0:
|
|
||||||
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
|
|
||||||
|
|
||||||
if time.time() - start_time > cfg.duration:
|
|
||||||
break
|
|
||||||
|
|
||||||
logger.info("Demo duration reached or shutdown requested")
|
|
||||||
|
|
||||||
# Signal shutdown
|
|
||||||
shutdown_event.set()
|
|
||||||
|
|
||||||
# Wait for threads to finish
|
|
||||||
if get_actions_thread and get_actions_thread.is_alive():
|
|
||||||
logger.info("Waiting for chunk requester thread to finish...")
|
|
||||||
get_actions_thread.join()
|
|
||||||
|
|
||||||
if actor_thread and actor_thread.is_alive():
|
|
||||||
logger.info("Waiting for action executor thread to finish...")
|
|
||||||
actor_thread.join()
|
|
||||||
|
|
||||||
# Cleanup robot
|
|
||||||
if robot:
|
|
||||||
robot.disconnect()
|
|
||||||
logger.info("Robot disconnected")
|
|
||||||
|
|
||||||
logger.info("Cleanup completed")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
demo_cli()
|
|
||||||
logging.info("RTC demo finished")
|
|
||||||
@@ -14,13 +14,17 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||||
from lerobot.common.control_utils import init_keyboard_listener
|
from lerobot.common.control_utils import init_keyboard_listener, predict_action
|
||||||
from lerobot.configs import FeatureType, PolicyFeature
|
from lerobot.configs import FeatureType, PolicyFeature
|
||||||
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
|
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
|
||||||
from lerobot.model.kinematics import RobotKinematics
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
from lerobot.policies import make_pre_post_processors
|
from lerobot.policies import make_pre_post_processors
|
||||||
from lerobot.policies.act import ACTPolicy
|
from lerobot.policies.act import ACTPolicy
|
||||||
|
from lerobot.policies.utils import make_robot_action
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
RobotProcessorPipeline,
|
RobotProcessorPipeline,
|
||||||
make_default_teleop_action_processor,
|
make_default_teleop_action_processor,
|
||||||
@@ -34,11 +38,12 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
|||||||
ForwardKinematicsJointsToEE,
|
ForwardKinematicsJointsToEE,
|
||||||
InverseKinematicsEEToJoints,
|
InverseKinematicsEEToJoints,
|
||||||
)
|
)
|
||||||
from lerobot.scripts.lerobot_record import record_loop
|
|
||||||
from lerobot.types import RobotAction, RobotObservation
|
from lerobot.types import RobotAction, RobotObservation
|
||||||
from lerobot.utils.feature_utils import combine_feature_dicts
|
from lerobot.utils.constants import ACTION, OBS_STR
|
||||||
|
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||||
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
from lerobot.utils.utils import log_say
|
from lerobot.utils.utils import log_say
|
||||||
from lerobot.utils.visualization_utils import init_rerun
|
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||||
|
|
||||||
NUM_EPISODES = 5
|
NUM_EPISODES = 5
|
||||||
FPS = 30
|
FPS = 30
|
||||||
@@ -49,6 +54,9 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
|
||||||
|
# This script provides a self-contained example for educational purposes.
|
||||||
|
|
||||||
# Create the robot configuration & robot
|
# Create the robot configuration & robot
|
||||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||||
robot_config = SO100FollowerConfig(
|
robot_config = SO100FollowerConfig(
|
||||||
@@ -143,43 +151,67 @@ def main():
|
|||||||
raise ValueError("Robot is not connected!")
|
raise ValueError("Robot is not connected!")
|
||||||
|
|
||||||
print("Starting evaluate loop...")
|
print("Starting evaluate loop...")
|
||||||
|
control_interval = 1 / FPS
|
||||||
episode_idx = 0
|
episode_idx = 0
|
||||||
for episode_idx in range(NUM_EPISODES):
|
for episode_idx in range(NUM_EPISODES):
|
||||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||||
|
|
||||||
# Main record loop
|
# Inline evaluation loop: predict actions and send to robot
|
||||||
record_loop(
|
timestamp = 0
|
||||||
robot=robot,
|
start_episode_t = time.perf_counter()
|
||||||
events=events,
|
while timestamp < EPISODE_TIME_SEC:
|
||||||
fps=FPS,
|
start_loop_t = time.perf_counter()
|
||||||
policy=policy,
|
|
||||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
if events["exit_early"]:
|
||||||
postprocessor=postprocessor,
|
events["exit_early"] = False
|
||||||
dataset=dataset,
|
break
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
|
||||||
single_task=TASK_DESCRIPTION,
|
# Get robot observation
|
||||||
display_data=True,
|
obs = robot.get_observation()
|
||||||
teleop_action_processor=make_default_teleop_action_processor(),
|
obs_processed = robot_joints_to_ee_pose_processor(obs)
|
||||||
robot_action_processor=robot_ee_to_joints_processor,
|
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
|
||||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
|
||||||
)
|
# Predict action using the policy
|
||||||
|
action_tensor = predict_action(
|
||||||
|
observation=observation_frame,
|
||||||
|
policy=policy,
|
||||||
|
device=policy.config.device,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
use_amp=policy.config.device.type == "cuda",
|
||||||
|
task=TASK_DESCRIPTION,
|
||||||
|
robot_type=robot.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert policy output to robot action dict
|
||||||
|
action_values = make_robot_action(action_tensor, dataset.features)
|
||||||
|
|
||||||
|
# Process and send action to robot (EE -> joints via IK)
|
||||||
|
robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
|
||||||
|
robot.send_action(robot_action_to_send)
|
||||||
|
|
||||||
|
# Write to dataset
|
||||||
|
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||||
|
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
|
||||||
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
|
log_rerun_data(observation=obs_processed, action=action_values)
|
||||||
|
|
||||||
|
dt_s = time.perf_counter() - start_loop_t
|
||||||
|
sleep_time_s = control_interval - dt_s
|
||||||
|
if sleep_time_s < 0:
|
||||||
|
logging.warning(
|
||||||
|
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
|
||||||
|
)
|
||||||
|
precise_sleep(max(sleep_time_s, 0.0))
|
||||||
|
timestamp = time.perf_counter() - start_episode_t
|
||||||
|
|
||||||
# Reset the environment if not stopping or re-recording
|
# Reset the environment if not stopping or re-recording
|
||||||
if not events["stop_recording"] and (
|
if not events["stop_recording"] and (
|
||||||
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||||
):
|
):
|
||||||
log_say("Reset the environment")
|
log_say("Reset the environment")
|
||||||
record_loop(
|
log_say("Waiting for environment reset, press right arrow key when ready...")
|
||||||
robot=robot,
|
|
||||||
events=events,
|
|
||||||
fps=FPS,
|
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
|
||||||
single_task=TASK_DESCRIPTION,
|
|
||||||
display_data=True,
|
|
||||||
teleop_action_processor=make_default_teleop_action_processor(),
|
|
||||||
robot_action_processor=robot_ee_to_joints_processor,
|
|
||||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
if events["rerecord_episode"]:
|
if events["rerecord_episode"]:
|
||||||
log_say("Re-record episode")
|
log_say("Re-record episode")
|
||||||
@@ -190,7 +222,6 @@ def main():
|
|||||||
|
|
||||||
# Save episode
|
# Save episode
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
episode_idx += 1
|
|
||||||
finally:
|
finally:
|
||||||
# Clean up
|
# Clean up
|
||||||
log_say("Stop recording")
|
log_say("Stop recording")
|
||||||
|
|||||||
@@ -62,21 +62,20 @@ def main():
|
|||||||
follower = SO100Follower(follower_config)
|
follower = SO100Follower(follower_config)
|
||||||
leader = SO100Leader(leader_config)
|
leader = SO100Leader(leader_config)
|
||||||
|
|
||||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||||
|
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||||
follower_kinematics_solver = RobotKinematics(
|
follower_kinematics_solver = RobotKinematics(
|
||||||
urdf_path="./SO101/so101_new_calib.urdf",
|
urdf_path="./SO101/so101_new_calib.urdf",
|
||||||
target_frame_name="gripper_frame_link",
|
target_frame_name="gripper_frame_link",
|
||||||
joint_names=list(follower.bus.motors.keys()),
|
joint_names=list(follower.bus.motors.keys()),
|
||||||
)
|
)
|
||||||
|
|
||||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
|
||||||
leader_kinematics_solver = RobotKinematics(
|
leader_kinematics_solver = RobotKinematics(
|
||||||
urdf_path="./SO101/so101_new_calib.urdf",
|
urdf_path="./SO101/so101_new_calib.urdf",
|
||||||
target_frame_name="gripper_frame_link",
|
target_frame_name="gripper_frame_link",
|
||||||
joint_names=list(leader.bus.motors.keys()),
|
joint_names=list(leader.bus.motors.keys()),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert follower joints to EE observation
|
# Build pipeline to convert follower joints to EE observation.
|
||||||
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||||
steps=[
|
steps=[
|
||||||
ForwardKinematicsJointsToEE(
|
ForwardKinematicsJointsToEE(
|
||||||
@@ -87,7 +86,7 @@ def main():
|
|||||||
to_output=transition_to_observation,
|
to_output=transition_to_observation,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert leader joints to EE action
|
# Build pipeline to convert leader joints to EE action.
|
||||||
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||||
steps=[
|
steps=[
|
||||||
ForwardKinematicsJointsToEE(
|
ForwardKinematicsJointsToEE(
|
||||||
@@ -98,9 +97,9 @@ def main():
|
|||||||
to_output=transition_to_robot_action,
|
to_output=transition_to_robot_action,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build pipeline to convert EE action to follower joints
|
# Build pipeline to convert EE action to follower joints (with safety bounds).
|
||||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||||
[
|
steps=[
|
||||||
EEBoundsAndSafety(
|
EEBoundsAndSafety(
|
||||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||||
max_ee_step_m=0.10,
|
max_ee_step_m=0.10,
|
||||||
@@ -115,13 +114,12 @@ def main():
|
|||||||
to_output=transition_to_robot_action,
|
to_output=transition_to_robot_action,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create the dataset
|
# Create the dataset, deriving features from the pipelines so the on-disk schema
|
||||||
|
# matches exactly what the pipelines produce at runtime.
|
||||||
dataset = LeRobotDataset.create(
|
dataset = LeRobotDataset.create(
|
||||||
repo_id=HF_REPO_ID,
|
repo_id=HF_REPO_ID,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
features=combine_feature_dicts(
|
features=combine_feature_dicts(
|
||||||
# Run the feature contract of the pipelines
|
|
||||||
# This tells you how the features would look like after the pipeline steps
|
|
||||||
aggregate_pipeline_dataset_features(
|
aggregate_pipeline_dataset_features(
|
||||||
pipeline=leader_joints_to_ee,
|
pipeline=leader_joints_to_ee,
|
||||||
initial_features=create_initial_features(action=leader.action_features),
|
initial_features=create_initial_features(action=leader.action_features),
|
||||||
@@ -144,7 +142,7 @@ def main():
|
|||||||
|
|
||||||
# Initialize the keyboard listener and rerun visualization
|
# Initialize the keyboard listener and rerun visualization
|
||||||
listener, events = init_keyboard_listener()
|
listener, events = init_keyboard_listener()
|
||||||
init_rerun(session_name="recording_phone")
|
init_rerun(session_name="recording_so100_ee")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not leader.is_connected or not follower.is_connected:
|
if not leader.is_connected or not follower.is_connected:
|
||||||
@@ -160,14 +158,14 @@ def main():
|
|||||||
robot=follower,
|
robot=follower,
|
||||||
events=events,
|
events=events,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
|
teleop_action_processor=leader_joints_to_ee,
|
||||||
|
robot_action_processor=ee_to_follower_joints,
|
||||||
|
robot_observation_processor=follower_joints_to_ee,
|
||||||
teleop=leader,
|
teleop=leader,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
control_time_s=EPISODE_TIME_SEC,
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
display_data=True,
|
display_data=True,
|
||||||
teleop_action_processor=leader_joints_to_ee,
|
|
||||||
robot_action_processor=ee_to_follower_joints,
|
|
||||||
robot_observation_processor=follower_joints_to_ee,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reset the environment if not stopping or re-recording
|
# Reset the environment if not stopping or re-recording
|
||||||
@@ -179,13 +177,13 @@ def main():
|
|||||||
robot=follower,
|
robot=follower,
|
||||||
events=events,
|
events=events,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
|
teleop_action_processor=leader_joints_to_ee,
|
||||||
|
robot_action_processor=ee_to_follower_joints,
|
||||||
|
robot_observation_processor=follower_joints_to_ee,
|
||||||
teleop=leader,
|
teleop=leader,
|
||||||
control_time_s=RESET_TIME_SEC,
|
control_time_s=RESET_TIME_SEC,
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
display_data=True,
|
display_data=True,
|
||||||
teleop_action_processor=leader_joints_to_ee,
|
|
||||||
robot_action_processor=ee_to_follower_joints,
|
|
||||||
robot_observation_processor=follower_joints_to_ee,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if events["rerecord_episode"]:
|
if events["rerecord_episode"]:
|
||||||
|
|||||||
134
examples/so100_to_so100_EE/rollout.py
Normal file
134
examples/so100_to_so100_EE/rollout.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
# !/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Run a trained EE-space policy on SO100 without recording (base rollout).
|
||||||
|
|
||||||
|
Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
|
||||||
|
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
|
||||||
|
control tick). The custom observation/action processors convert between
|
||||||
|
joint space (robot hardware) and end-effector space (policy I/O) via
|
||||||
|
forward/inverse kinematics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||||
|
from lerobot.configs import PreTrainedConfig
|
||||||
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
|
from lerobot.processor import (
|
||||||
|
RobotProcessorPipeline,
|
||||||
|
observation_to_transition,
|
||||||
|
robot_action_observation_to_transition,
|
||||||
|
transition_to_observation,
|
||||||
|
transition_to_robot_action,
|
||||||
|
)
|
||||||
|
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||||
|
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||||
|
ForwardKinematicsJointsToEE,
|
||||||
|
InverseKinematicsEEToJoints,
|
||||||
|
)
|
||||||
|
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||||
|
from lerobot.rollout.inference import SyncInferenceConfig
|
||||||
|
from lerobot.rollout.strategies import BaseStrategy
|
||||||
|
from lerobot.types import RobotAction, RobotObservation
|
||||||
|
from lerobot.utils.process import ProcessSignalHandler
|
||||||
|
from lerobot.utils.utils import init_logging
|
||||||
|
|
||||||
|
FPS = 30
|
||||||
|
DURATION_SEC = 60
|
||||||
|
TASK_DESCRIPTION = "My task description"
|
||||||
|
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
init_logging()
|
||||||
|
|
||||||
|
# Robot configuration — the rollout engine will connect it inside build_rollout_context.
|
||||||
|
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||||
|
robot_config = SO100FollowerConfig(
|
||||||
|
port="/dev/tty.usbmodem5A460814411",
|
||||||
|
id="my_awesome_follower_arm",
|
||||||
|
cameras=camera_config,
|
||||||
|
use_degrees=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Kinematic solver: we need the motor-name list, so peek at the robot once.
|
||||||
|
# (The rollout engine owns the connected instance; we only use this for introspection.)
|
||||||
|
temp_robot = SO100Follower(robot_config)
|
||||||
|
motor_names = list(temp_robot.bus.motors.keys())
|
||||||
|
|
||||||
|
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||||
|
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||||
|
kinematics_solver = RobotKinematics(
|
||||||
|
urdf_path="./SO101/so101_new_calib.urdf",
|
||||||
|
target_frame_name="gripper_frame_link",
|
||||||
|
joint_names=motor_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Joint-space observation → EE-space observation (consumed by the policy).
|
||||||
|
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||||
|
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
|
||||||
|
to_transition=observation_to_transition,
|
||||||
|
to_output=transition_to_observation,
|
||||||
|
)
|
||||||
|
|
||||||
|
# EE-space action (produced by the policy) → joint-space action (sent to robot).
|
||||||
|
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||||
|
steps=[
|
||||||
|
InverseKinematicsEEToJoints(
|
||||||
|
kinematics=kinematics_solver,
|
||||||
|
motor_names=motor_names,
|
||||||
|
initial_guess_current_joints=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
to_transition=robot_action_observation_to_transition,
|
||||||
|
to_output=transition_to_robot_action,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Policy config (full model is loaded inside build_rollout_context).
|
||||||
|
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||||
|
policy_config.pretrained_path = HF_MODEL_ID
|
||||||
|
|
||||||
|
cfg = RolloutConfig(
|
||||||
|
robot=robot_config,
|
||||||
|
policy=policy_config,
|
||||||
|
strategy=BaseStrategyConfig(),
|
||||||
|
inference=SyncInferenceConfig(),
|
||||||
|
fps=FPS,
|
||||||
|
duration=DURATION_SEC,
|
||||||
|
task=TASK_DESCRIPTION,
|
||||||
|
)
|
||||||
|
|
||||||
|
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||||
|
|
||||||
|
# Pass the EE kinematic processors via kwargs; the defaults (identity) would
|
||||||
|
# otherwise skip the joint↔EE conversion and the policy would receive the
|
||||||
|
# wrong observation/action space.
|
||||||
|
ctx = build_rollout_context(
|
||||||
|
cfg,
|
||||||
|
signal_handler.shutdown_event,
|
||||||
|
robot_action_processor=robot_ee_to_joints_processor,
|
||||||
|
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||||
|
)
|
||||||
|
|
||||||
|
strategy = BaseStrategy(cfg.strategy)
|
||||||
|
try:
|
||||||
|
strategy.setup(ctx)
|
||||||
|
strategy.run(ctx)
|
||||||
|
finally:
|
||||||
|
strategy.teardown(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -4,13 +4,13 @@ from pathlib import Path
|
|||||||
from queue import Empty, Full
|
from queue import Empty, Full
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.optim as optim
|
|
||||||
|
|
||||||
from lerobot.datasets import LeRobotDataset
|
from lerobot.datasets import LeRobotDataset
|
||||||
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||||
from lerobot.policies import SACConfig
|
from lerobot.policies import GaussianActorConfig
|
||||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||||
|
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||||
from lerobot.rl.buffer import ReplayBuffer
|
from lerobot.rl.buffer import ReplayBuffer
|
||||||
from lerobot.rl.gym_manipulator import make_robot_env
|
from lerobot.rl.gym_manipulator import make_robot_env
|
||||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||||
@@ -28,7 +28,7 @@ def run_learner(
|
|||||||
transitions_queue: mp.Queue,
|
transitions_queue: mp.Queue,
|
||||||
parameters_queue: mp.Queue,
|
parameters_queue: mp.Queue,
|
||||||
shutdown_event: mp.Event,
|
shutdown_event: mp.Event,
|
||||||
policy_learner: SACPolicy,
|
policy_learner: GaussianActorPolicy,
|
||||||
online_buffer: ReplayBuffer,
|
online_buffer: ReplayBuffer,
|
||||||
offline_buffer: ReplayBuffer,
|
offline_buffer: ReplayBuffer,
|
||||||
lr: float = 3e-4,
|
lr: float = 3e-4,
|
||||||
@@ -40,8 +40,9 @@ def run_learner(
|
|||||||
policy_learner.train()
|
policy_learner.train()
|
||||||
policy_learner.to(device)
|
policy_learner.to(device)
|
||||||
|
|
||||||
# Create Adam optimizer from scratch - simple and clean
|
algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config)
|
||||||
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
|
algorithm = SACAlgorithm(policy=policy_learner, config=algo_config)
|
||||||
|
algorithm.make_optimizers_and_scheduler()
|
||||||
|
|
||||||
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
|
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
|
||||||
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
|
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
|
||||||
@@ -83,24 +84,26 @@ def run_learner(
|
|||||||
else:
|
else:
|
||||||
batch[key] = online_batch[key]
|
batch[key] = online_batch[key]
|
||||||
|
|
||||||
loss, _ = policy_learner.forward(batch)
|
def batch_iter(b=batch):
|
||||||
|
while True:
|
||||||
|
yield b
|
||||||
|
|
||||||
optimizer.zero_grad()
|
stats = algorithm.update(batch_iter())
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
training_step += 1
|
training_step += 1
|
||||||
|
|
||||||
if training_step % LOG_EVERY == 0:
|
if training_step % LOG_EVERY == 0:
|
||||||
|
log_dict = stats.to_log_dict()
|
||||||
print(
|
print(
|
||||||
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
|
f"[LEARNER] Training step {training_step}, "
|
||||||
|
f"critic_loss: {log_dict.get('critic', 'N/A'):.4f}, "
|
||||||
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
|
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send updated parameters to actor every 10 training steps
|
# Send updated parameters to actor every 10 training steps
|
||||||
if training_step % SEND_EVERY == 0:
|
if training_step % SEND_EVERY == 0:
|
||||||
try:
|
try:
|
||||||
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
|
weights = algorithm.get_weights()
|
||||||
parameters_queue.put_nowait(state_dict)
|
parameters_queue.put_nowait(weights)
|
||||||
print("[LEARNER] Sent updated parameters to actor")
|
print("[LEARNER] Sent updated parameters to actor")
|
||||||
except Full:
|
except Full:
|
||||||
# Missing write due to queue not being consumed (should happen rarely)
|
# Missing write due to queue not being consumed (should happen rarely)
|
||||||
@@ -113,7 +116,7 @@ def run_actor(
|
|||||||
transitions_queue: mp.Queue,
|
transitions_queue: mp.Queue,
|
||||||
parameters_queue: mp.Queue,
|
parameters_queue: mp.Queue,
|
||||||
shutdown_event: mp.Event,
|
shutdown_event: mp.Event,
|
||||||
policy_actor: SACPolicy,
|
policy_actor: GaussianActorPolicy,
|
||||||
reward_classifier: Classifier,
|
reward_classifier: Classifier,
|
||||||
env_cfg: HILSerlRobotEnvConfig,
|
env_cfg: HILSerlRobotEnvConfig,
|
||||||
device: torch.device = "mps",
|
device: torch.device = "mps",
|
||||||
@@ -144,15 +147,15 @@ def run_actor(
|
|||||||
|
|
||||||
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
|
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
|
||||||
try:
|
try:
|
||||||
new_params = parameters_queue.get_nowait()
|
new_weights = parameters_queue.get_nowait()
|
||||||
policy_actor.load_state_dict(new_params)
|
policy_actor.load_state_dict(new_weights)
|
||||||
print("[ACTOR] Updated policy parameters from learner")
|
print("[ACTOR] Updated policy parameters from learner")
|
||||||
except Empty: # No new updated parameters available from learner, waiting
|
except Empty: # No new updated parameters available from learner, waiting
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Get action from policy
|
# Get action from policy (returns full action: continuous + discrete)
|
||||||
policy_obs = make_policy_obs(obs, device=device)
|
policy_obs = make_policy_obs(obs, device=device)
|
||||||
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
|
action_tensor = policy_actor.select_action(policy_obs)
|
||||||
action = action_tensor.squeeze(0).cpu().numpy()
|
action = action_tensor.squeeze(0).cpu().numpy()
|
||||||
|
|
||||||
# Step environment
|
# Step environment
|
||||||
@@ -261,14 +264,14 @@ def main():
|
|||||||
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
||||||
|
|
||||||
# Create SAC policy for action selection
|
# Create SAC policy for action selection
|
||||||
policy_cfg = SACConfig(
|
policy_cfg = GaussianActorConfig(
|
||||||
device=device,
|
device=device,
|
||||||
input_features=obs_features,
|
input_features=obs_features,
|
||||||
output_features=action_features,
|
output_features=action_features,
|
||||||
)
|
)
|
||||||
|
|
||||||
policy_actor = SACPolicy(policy_cfg)
|
policy_actor = GaussianActorPolicy(policy_cfg)
|
||||||
policy_learner = SACPolicy(policy_cfg)
|
policy_learner = GaussianActorPolicy(policy_cfg)
|
||||||
|
|
||||||
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
||||||
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.datasets import LeRobotDataset
|
from lerobot.datasets import LeRobotDataset
|
||||||
from lerobot.policies import RewardClassifierConfig, make_policy, make_pre_post_processors
|
from lerobot.rewards import RewardClassifierConfig, make_reward_model, make_reward_pre_post_processors
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -22,10 +22,10 @@ def main():
|
|||||||
model_name="microsoft/resnet-18",
|
model_name="microsoft/resnet-18",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make policy, preprocessor, and optimizer
|
# Make reward model, preprocessor, and optimizer
|
||||||
policy = make_policy(config, ds_meta=dataset.meta)
|
reward_model = make_reward_model(config, dataset_stats=dataset.meta.stats)
|
||||||
optimizer = config.get_optimizer_preset().build(policy.parameters())
|
optimizer = config.get_optimizer_preset().build(reward_model.parameters())
|
||||||
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
|
preprocessor, _ = make_reward_pre_post_processors(config, dataset_stats=dataset.meta.stats)
|
||||||
|
|
||||||
classifier_id = "<user>/reward_classifier_hil_serl_example"
|
classifier_id = "<user>/reward_classifier_hil_serl_example"
|
||||||
|
|
||||||
@@ -42,7 +42,7 @@ def main():
|
|||||||
batch = preprocessor(batch)
|
batch = preprocessor(batch)
|
||||||
|
|
||||||
# Forward pass
|
# Forward pass
|
||||||
loss, output_dict = policy.forward(batch)
|
loss, output_dict = reward_model.forward(batch)
|
||||||
|
|
||||||
# Backward pass and optimization
|
# Backward pass and optimization
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
@@ -58,8 +58,8 @@ def main():
|
|||||||
|
|
||||||
print("Training finished!")
|
print("Training finished!")
|
||||||
|
|
||||||
# You can now save the trained policy.
|
# You can now save the trained reward model.
|
||||||
policy.push_to_hub(classifier_id)
|
reward_model.push_to_hub(classifier_id)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -59,8 +59,8 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
|
|||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
# Core ML
|
# Core ML
|
||||||
"torch>=2.7,<2.11.0",
|
"torch>=2.7,<2.12.0",
|
||||||
"torchvision>=0.22.0,<0.26.0",
|
"torchvision>=0.22.0,<0.27.0",
|
||||||
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
|
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
|
||||||
"opencv-python-headless>=4.9.0,<4.14.0",
|
"opencv-python-headless>=4.9.0,<4.14.0",
|
||||||
"Pillow>=10.0.0,<13.0.0",
|
"Pillow>=10.0.0,<13.0.0",
|
||||||
@@ -99,7 +99,18 @@ dataset = [
|
|||||||
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
||||||
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
||||||
"lerobot[av-dep]",
|
"lerobot[av-dep]",
|
||||||
"torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10).
|
|
||||||
|
# NOTE: torchcodec wheel availability matrix (PyPI):
|
||||||
|
# - linux x86_64/amd64 + macOS arm64 : wheels since 0.3.0 (the historic supported set).
|
||||||
|
# - win32 x86_64 : wheels since 0.7.0 (needs torch>=2.8).
|
||||||
|
# - linux aarch64/arm64 : wheels since 0.11.0 (needs torch>=2.11).
|
||||||
|
# - macOS x86_64 (Intel) and linux armv7l: no wheels in any released version -> fall through to the PyAV decoder.
|
||||||
|
# Each platform gets its own line so the resolver picks the minimum version that has a wheel for it.
|
||||||
|
|
||||||
|
# Other torch/torchcodec pairings (informational): 0.8.1 = ffmpeg>=8 support, 0.10 = system-wide ffmpeg support, 0.12 needs torch==2.12.
|
||||||
|
"torchcodec>=0.3.0,<0.12.0; (sys_platform == 'linux' and (platform_machine == 'x86_64' or platform_machine == 'AMD64')) or (sys_platform == 'darwin' and platform_machine == 'arm64')",
|
||||||
|
"torchcodec>=0.7.0,<0.12.0; sys_platform == 'win32'",
|
||||||
|
"torchcodec>=0.11.0,<0.12.0; sys_platform == 'linux' and (platform_machine == 'aarch64' or platform_machine == 'arm64')",
|
||||||
"jsonlines>=4.0.0,<5.0.0",
|
"jsonlines>=4.0.0,<5.0.0",
|
||||||
]
|
]
|
||||||
training = [
|
training = [
|
||||||
@@ -128,7 +139,7 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
|
|||||||
av-dep = ["av>=15.0.0,<16.0.0"]
|
av-dep = ["av>=15.0.0,<16.0.0"]
|
||||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||||
transformers-dep = ["transformers==5.3.0"] # TODO(Steven): https://github.com/huggingface/lerobot/pull/3249
|
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||||
@@ -193,8 +204,10 @@ groot = [
|
|||||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||||
]
|
]
|
||||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||||
|
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
|
||||||
xvla = ["lerobot[transformers-dep]"]
|
xvla = ["lerobot[transformers-dep]"]
|
||||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||||
|
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||||
|
|
||||||
# Features
|
# Features
|
||||||
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||||
@@ -289,8 +302,24 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
|||||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||||
|
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||||
|
lerobot-export-robometer="lerobot.scripts.lerobot_export_robometer:main"
|
||||||
|
|
||||||
# ---------------- Tool Configurations ----------------
|
# ---------------- Tool Configurations ----------------
|
||||||
|
|
||||||
|
# cu128 wheels keep broad hardware reach; the driver floor is 570.86.
|
||||||
|
# To use a different CUDA variant, reinstall torch with an explicit index, e.g.:
|
||||||
|
# uv pip install --force-reinstall torch torchvision \
|
||||||
|
# --index-url https://download.pytorch.org/whl/cu130
|
||||||
|
[[tool.uv.index]]
|
||||||
|
name = "pytorch-cu128"
|
||||||
|
url = "https://download.pytorch.org/whl/cu128"
|
||||||
|
explicit = true
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||||
|
torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }]
|
||||||
|
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
lerobot = ["envs/*.json"]
|
lerobot = ["envs/*.json"]
|
||||||
|
|
||||||
|
|||||||
164
scripts/debug_robometer_embed_diff.py
Normal file
164
scripts/debug_robometer_embed_diff.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
"""Pinpoint exactly which rows of ``embed_tokens`` / ``lm_head`` differ.
|
||||||
|
|
||||||
|
Useful follow-up to ``scripts/verify_robometer_export.py`` when the verifier
|
||||||
|
reports a small tail of differing keys but you want to know whether the
|
||||||
|
diff is:
|
||||||
|
|
||||||
|
1. Concentrated in the 5 special-token rows added by ``resize_token_embeddings``
|
||||||
|
(expected non-determinism: mean-resize sampling differs between runs).
|
||||||
|
2. Spread across the full vocabulary (would point to a real loading bug).
|
||||||
|
|
||||||
|
Also confirms whether ``apply_upstream_checkpoint`` actually overwrites the
|
||||||
|
embed/lm-head tensors when loading the upstream state dict (vs. silently
|
||||||
|
skipping them due to a key mismatch).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
from lerobot.configs.rewards import RewardModelConfig
|
||||||
|
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
|
||||||
|
from lerobot.rewards.robometer._upstream_loader import (
|
||||||
|
_download_robometer_snapshot,
|
||||||
|
_remap_state_dict_keys,
|
||||||
|
_resolve_checkpoint_safetensors_files,
|
||||||
|
apply_upstream_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
EMBED_KEY = "model.model.language_model.embed_tokens.weight"
|
||||||
|
LMHEAD_KEY = "model.lm_head.weight"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_upstream(path: str) -> RobometerRewardModel:
|
||||||
|
cfg = RobometerConfig(pretrained_path=path, device="cpu")
|
||||||
|
model = RobometerRewardModel(cfg)
|
||||||
|
apply_upstream_checkpoint(model, path)
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _load_lerobot(path: str) -> RobometerRewardModel:
|
||||||
|
cfg = RewardModelConfig.from_pretrained(path)
|
||||||
|
if not isinstance(cfg, RobometerConfig):
|
||||||
|
raise TypeError(f"Expected RobometerConfig, got {type(cfg)}")
|
||||||
|
cfg.pretrained_path = path
|
||||||
|
cfg.device = "cpu"
|
||||||
|
return RobometerRewardModel.from_pretrained(path, config=cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def _inspect_upstream_state_dict(upstream_path: str, model: RobometerRewardModel) -> None:
|
||||||
|
"""Dump the upstream state-dict view of the embed/lm-head tensors.
|
||||||
|
|
||||||
|
Loads the raw upstream safetensors (pre-remap), runs the remapper, and
|
||||||
|
reports whether the embed/lm-head keys survive into the merged dict that
|
||||||
|
eventually hits ``model.load_state_dict``.
|
||||||
|
"""
|
||||||
|
snapshot_dir = _download_robometer_snapshot(upstream_path)
|
||||||
|
files = _resolve_checkpoint_safetensors_files(snapshot_dir)
|
||||||
|
merged: dict[str, torch.Tensor] = {}
|
||||||
|
for path in files:
|
||||||
|
merged.update(load_file(str(path)))
|
||||||
|
remapped = _remap_state_dict_keys(merged, model)
|
||||||
|
|
||||||
|
print(f"\n=== Upstream state-dict inspection (snapshot at {snapshot_dir}) ===")
|
||||||
|
print(f"raw keys (before remap) : {len(merged)}")
|
||||||
|
print(f"keys after remap : {len(remapped)}")
|
||||||
|
print(f"model expects (state_dict): {len(model.state_dict())}")
|
||||||
|
|
||||||
|
expected = set(model.state_dict())
|
||||||
|
present_after_remap = set(remapped) & expected
|
||||||
|
print(f"keys present after remap : {len(present_after_remap)}")
|
||||||
|
|
||||||
|
missing_keys = expected - set(remapped)
|
||||||
|
print(f"keys missing from remap : {len(missing_keys)}")
|
||||||
|
if missing_keys:
|
||||||
|
sample = list(missing_keys)[:10]
|
||||||
|
print(f" sample missing keys : {sample}")
|
||||||
|
|
||||||
|
unexpected_keys = set(remapped) - expected
|
||||||
|
print(f"keys unexpected by model : {len(unexpected_keys)}")
|
||||||
|
if unexpected_keys:
|
||||||
|
sample = list(unexpected_keys)[:10]
|
||||||
|
print(f" sample unexpected keys : {sample}")
|
||||||
|
|
||||||
|
for key in (EMBED_KEY, LMHEAD_KEY):
|
||||||
|
present = key in remapped
|
||||||
|
shape = tuple(remapped[key].shape) if present else None
|
||||||
|
print(f" {key:60s} present={present}, shape={shape}")
|
||||||
|
|
||||||
|
|
||||||
|
def _diff_embed(name: str, a: torch.Tensor, b: torch.Tensor, special_token_count: int) -> None:
|
||||||
|
a = a.float()
|
||||||
|
b = b.float()
|
||||||
|
if a.shape != b.shape:
|
||||||
|
print(f"❌ {name} shape mismatch: {tuple(a.shape)} vs {tuple(b.shape)}")
|
||||||
|
return
|
||||||
|
|
||||||
|
abs_diff = (a - b).abs()
|
||||||
|
per_row_max = abs_diff.max(dim=1).values
|
||||||
|
nz_rows = (per_row_max > 0).nonzero(as_tuple=True)[0].tolist()
|
||||||
|
print(f"\n=== {name} (shape {tuple(a.shape)}) ===")
|
||||||
|
print(f"global max|Δ| = {abs_diff.max().item():.3e}")
|
||||||
|
print(f"rows with any diff = {len(nz_rows)}")
|
||||||
|
if nz_rows:
|
||||||
|
first = nz_rows[:10]
|
||||||
|
last = nz_rows[-10:]
|
||||||
|
print(f" first nonzero rows = {first}")
|
||||||
|
print(f" last nonzero rows = {last}")
|
||||||
|
vocab_size = a.shape[0]
|
||||||
|
base_vocab = vocab_size - special_token_count
|
||||||
|
special_rows = list(range(base_vocab, vocab_size))
|
||||||
|
in_special = [r for r in nz_rows if r in special_rows]
|
||||||
|
out_special = [r for r in nz_rows if r not in special_rows]
|
||||||
|
print(
|
||||||
|
f" diffs in special-token rows ({base_vocab}..{vocab_size - 1}): {len(in_special)}/{special_token_count}"
|
||||||
|
)
|
||||||
|
print(f" diffs in base-vocab rows (0..{base_vocab - 1}) : {len(out_special)}")
|
||||||
|
for r in special_rows:
|
||||||
|
print(
|
||||||
|
f" row {r}: max|Δ|={per_row_max[r].item():.3e}, "
|
||||||
|
f"upstream_norm={a[r].norm().item():.3e}, lerobot_norm={b[r].norm().item():.3e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||||
|
)
|
||||||
|
parser.add_argument("--upstream", required=True)
|
||||||
|
parser.add_argument("--lerobot", required=True)
|
||||||
|
parser.add_argument(
|
||||||
|
"--special-token-count",
|
||||||
|
type=int,
|
||||||
|
default=5,
|
||||||
|
help="Number of special tokens Robometer adds. Defaults to len(ROBOMETER_SPECIAL_TOKENS)=5.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(f"Loading upstream: {args.upstream}")
|
||||||
|
upstream = _load_upstream(args.upstream)
|
||||||
|
print(f"Loading LeRobot-format: {args.lerobot}")
|
||||||
|
lerobot = _load_lerobot(args.lerobot)
|
||||||
|
|
||||||
|
_inspect_upstream_state_dict(args.upstream, upstream)
|
||||||
|
|
||||||
|
sd_u, sd_l = upstream.state_dict(), lerobot.state_dict()
|
||||||
|
|
||||||
|
for key in (EMBED_KEY, LMHEAD_KEY):
|
||||||
|
if key not in sd_u or key not in sd_l:
|
||||||
|
print(f"❌ key missing: {key} (upstream={key in sd_u}, lerobot={key in sd_l})")
|
||||||
|
continue
|
||||||
|
_diff_embed(key, sd_u[key], sd_l[key], args.special_token_count)
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
168
scripts/extract_libero_episode_for_parity.py
Normal file
168
scripts/extract_libero_episode_for_parity.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
"""Extract one LIBERO episode for Robometer parity testing.
|
||||||
|
|
||||||
|
Loads a LeRobot LIBERO (or any video-bearing LeRobot) dataset, picks one
|
||||||
|
episode, samples ``--num-frames`` frames uniformly across its duration
|
||||||
|
(matching upstream Robometer's default of 8 frames), and saves them to
|
||||||
|
``.npz`` plus a sidecar ``.txt`` task file.
|
||||||
|
|
||||||
|
The ``.npz`` layout (``frames`` key, ``(T, H, W, C) uint8``) is what upstream
|
||||||
|
``example_inference_local.py`` consumes, so the same file feeds both pipelines
|
||||||
|
and frame sampling cannot drift.
|
||||||
|
|
||||||
|
Workflow:
|
||||||
|
|
||||||
|
1. Run this script (LeRobot env) to produce ``frames.npz`` + ``task.txt``.
|
||||||
|
2. Pass them to upstream ``scripts/example_inference_local.py``
|
||||||
|
(upstream env) to produce reference progress / success outputs.
|
||||||
|
3. Pass the same ``frames.npz`` to ``scripts/parity_robometer.py``
|
||||||
|
(LeRobot env) to compare both sides.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
uv run python scripts/extract_libero_episode_for_parity.py \\
|
||||||
|
--repo-id lerobot/libero_10_image \\
|
||||||
|
--episode 0 \\
|
||||||
|
--num-frames 8 \\
|
||||||
|
--out-dir /tmp/libero_ep0
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.types import FeatureType
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
|
||||||
|
|
||||||
|
def _pick_visual_feature(features: dict, requested: str | None) -> str:
|
||||||
|
"""Return a visual feature key, preferring ``requested`` when given."""
|
||||||
|
visual_keys = [
|
||||||
|
key
|
||||||
|
for key, ft in features.items()
|
||||||
|
if getattr(ft, "type", None) == FeatureType.VISUAL or ft.get("dtype", "") == "video"
|
||||||
|
]
|
||||||
|
if not visual_keys:
|
||||||
|
raise ValueError(f"Dataset has no visual feature; available: {list(features)}")
|
||||||
|
if requested is not None:
|
||||||
|
if requested not in visual_keys:
|
||||||
|
raise ValueError(f"Camera key {requested!r} not in dataset visual features {visual_keys}")
|
||||||
|
return requested
|
||||||
|
return visual_keys[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _frame_uint8_hwc(tensor: torch.Tensor) -> np.ndarray:
|
||||||
|
"""Convert a LeRobotDataset video frame to ``uint8`` ``(H, W, C)`` RGB."""
|
||||||
|
arr = tensor.detach().cpu().numpy()
|
||||||
|
if arr.ndim == 3 and arr.shape[0] in (1, 3):
|
||||||
|
arr = arr.transpose(1, 2, 0)
|
||||||
|
if arr.dtype != np.uint8:
|
||||||
|
arr = np.clip(arr * 255.0 if arr.max() <= 1.0 + 1e-3 else arr, 0, 255).astype(np.uint8)
|
||||||
|
if arr.shape[-1] == 1:
|
||||||
|
arr = np.repeat(arr, 3, axis=-1)
|
||||||
|
return arr
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=__doc__,
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--repo-id",
|
||||||
|
default="lerobot/libero_10_image",
|
||||||
|
help="LeRobot LIBERO (or other) dataset repo id (default: lerobot/libero_10_image).",
|
||||||
|
)
|
||||||
|
parser.add_argument("--episode", type=int, default=0, help="Episode index.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--camera-key",
|
||||||
|
default=None,
|
||||||
|
help="Visual feature key (e.g. observation.images.image). Auto-selects first if omitted.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-frames",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="Number of frames to sample uniformly (default: 8 — Robometer's training-time default).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--out-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("outputs/robometer_parity/libero"),
|
||||||
|
help="Directory to write frames.npz / task.txt / frame_indices.npy.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(f"Loading {args.repo_id} (episode {args.episode})...")
|
||||||
|
dataset = LeRobotDataset(args.repo_id, episodes=[args.episode])
|
||||||
|
|
||||||
|
camera_key = _pick_visual_feature(dataset.features, args.camera_key)
|
||||||
|
print(f"Using camera key: {camera_key}")
|
||||||
|
|
||||||
|
ep_from = int(dataset.episode_data_index["from"][0].item())
|
||||||
|
ep_to = int(dataset.episode_data_index["to"][0].item())
|
||||||
|
total_frames = ep_to - ep_from
|
||||||
|
if total_frames <= 0:
|
||||||
|
print(f"ERROR: episode {args.episode} has no frames.", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
print(f"Episode has {total_frames} frames; sampling {args.num_frames} uniformly.")
|
||||||
|
|
||||||
|
indices = np.linspace(0, total_frames - 1, num=min(args.num_frames, total_frames), dtype=int)
|
||||||
|
frames: list[np.ndarray] = []
|
||||||
|
task: str = ""
|
||||||
|
for offset in indices:
|
||||||
|
sample = dataset[ep_from + int(offset)]
|
||||||
|
frame_tensor = sample[camera_key]
|
||||||
|
frames.append(_frame_uint8_hwc(frame_tensor))
|
||||||
|
if not task:
|
||||||
|
task = sample.get("task", "") or ""
|
||||||
|
|
||||||
|
if not task:
|
||||||
|
print("ERROR: episode has no task description in metadata.", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
frames_array = np.stack(frames)
|
||||||
|
|
||||||
|
args.out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
frames_path = args.out_dir / "frames.npz"
|
||||||
|
task_path = args.out_dir / "task.txt"
|
||||||
|
indices_path = args.out_dir / "frame_indices.npy"
|
||||||
|
|
||||||
|
np.savez(frames_path, frames=frames_array)
|
||||||
|
task_path.write_text(task + "\n", encoding="utf-8")
|
||||||
|
np.save(indices_path, indices)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(f"Wrote {frames_path} (shape={frames_array.shape}, dtype={frames_array.dtype})")
|
||||||
|
print(f"Wrote {task_path} (task={task!r})")
|
||||||
|
print(f"Wrote {indices_path} (frame_indices={indices.tolist()})")
|
||||||
|
print()
|
||||||
|
print("Next steps:")
|
||||||
|
print(" # in upstream env (where `robometer` is importable):")
|
||||||
|
print(
|
||||||
|
f" python third_party/robometer/scripts/example_inference_local.py \\\n"
|
||||||
|
f" --model-path robometer/Robometer-4B \\\n"
|
||||||
|
f" --video {frames_path} \\\n"
|
||||||
|
f' --task "{task}" \\\n'
|
||||||
|
f" --out {args.out_dir / 'upstream.npy'}"
|
||||||
|
)
|
||||||
|
print()
|
||||||
|
print(" # back in LeRobot env:")
|
||||||
|
print(
|
||||||
|
f" uv run python scripts/parity_robometer.py \\\n"
|
||||||
|
f" --frames {frames_path} \\\n"
|
||||||
|
f' --task "{task}" \\\n'
|
||||||
|
f" --upstream-progress {args.out_dir / 'upstream.npy'} \\\n"
|
||||||
|
f" --upstream-success {args.out_dir / 'upstream_success_probs.npy'}"
|
||||||
|
)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
232
scripts/parity_robometer.py
Normal file
232
scripts/parity_robometer.py
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
"""Functional parity check: LeRobot Robometer vs. upstream Robometer.
|
||||||
|
|
||||||
|
Runs the in-tree :class:`RobometerRewardModel` on the same frames + task that
|
||||||
|
upstream Robometer was run on, and compares per-frame progress / success
|
||||||
|
predictions against reference outputs saved by upstream's
|
||||||
|
``scripts/example_inference_local.py``.
|
||||||
|
|
||||||
|
Workflow:
|
||||||
|
|
||||||
|
1. In the upstream Robometer environment (where ``robometer`` is importable),
|
||||||
|
run::
|
||||||
|
|
||||||
|
python third_party/robometer/scripts/example_inference_local.py \\
|
||||||
|
--model-path robometer/Robometer-4B \\
|
||||||
|
--video /path/to/episode.mp4 \\
|
||||||
|
--task "Open the drawer" \\
|
||||||
|
--fps 1.0 \\
|
||||||
|
--out /tmp/robometer_upstream.npy
|
||||||
|
|
||||||
|
This produces:
|
||||||
|
- ``/tmp/robometer_upstream.npy`` (progress predictions)
|
||||||
|
- ``/tmp/robometer_upstream_success_probs.npy`` (success probabilities)
|
||||||
|
|
||||||
|
2. Extract the exact same frames the upstream script used, save as ``.npz``::
|
||||||
|
|
||||||
|
# quick helper: extract frames at the same fps and save as .npz
|
||||||
|
python -c "
|
||||||
|
from third_party.robometer.scripts.example_inference_local import load_frames_input
|
||||||
|
import numpy as np
|
||||||
|
frames = load_frames_input('/path/to/episode.mp4', fps=1.0, max_frames=512)
|
||||||
|
np.savez('/tmp/robometer_frames.npz', frames=frames)
|
||||||
|
"
|
||||||
|
|
||||||
|
3. In this LeRobot env, run this script::
|
||||||
|
|
||||||
|
uv run python scripts/parity_robometer.py \\
|
||||||
|
--frames /tmp/robometer_frames.npz \\
|
||||||
|
--task "Open the drawer" \\
|
||||||
|
--upstream-progress /tmp/robometer_upstream.npy \\
|
||||||
|
--upstream-success /tmp/robometer_upstream_success_probs.npy \\
|
||||||
|
--lerobot-model lilkm/robometer-4b
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.rewards import RewardModelConfig
|
||||||
|
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
|
||||||
|
from lerobot.rewards.robometer.modeling_robometer import decode_progress_outputs
|
||||||
|
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
|
||||||
|
|
||||||
|
|
||||||
|
def _load_frames(path: str) -> np.ndarray:
|
||||||
|
"""Load frames from .npy/.npz. Expects (T, H, W, C) uint8."""
|
||||||
|
if path.endswith(".npy"):
|
||||||
|
frames = np.load(path)
|
||||||
|
elif path.endswith(".npz"):
|
||||||
|
with np.load(path, allow_pickle=False) as npz:
|
||||||
|
frames = npz["frames"].copy() if "frames" in npz else next(iter(npz.values())).copy()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Frames must be .npy or .npz (got {path!r}).")
|
||||||
|
|
||||||
|
if frames.dtype != np.uint8:
|
||||||
|
frames = np.clip(frames, 0, 255).astype(np.uint8)
|
||||||
|
if frames.ndim != 4:
|
||||||
|
raise ValueError(f"Frames must be 4D (T,H,W,C); got shape {frames.shape}.")
|
||||||
|
if frames.shape[-1] not in (1, 3):
|
||||||
|
# Probably (T,C,H,W) — transpose
|
||||||
|
if frames.shape[1] in (1, 3):
|
||||||
|
frames = frames.transpose(0, 2, 3, 1)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Cannot interpret frame channel layout: {frames.shape}.")
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def _run_lerobot(
|
||||||
|
frames: np.ndarray,
|
||||||
|
task: str,
|
||||||
|
model_path: str,
|
||||||
|
device: str,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""Run LeRobot's Robometer on the given frames; return (progress, success)."""
|
||||||
|
cfg = RobometerConfig(pretrained_path=model_path, device=device, max_frames=None)
|
||||||
|
model = RobometerRewardModel.from_pretrained(model_path, config=cfg)
|
||||||
|
|
||||||
|
encoder = RobometerEncoderProcessorStep(
|
||||||
|
base_model_id=model.config.base_model_id,
|
||||||
|
use_multi_image=model.config.use_multi_image,
|
||||||
|
use_per_frame_progress_token=model.config.use_per_frame_progress_token,
|
||||||
|
max_frames=None,
|
||||||
|
)
|
||||||
|
batch = encoder.encode_samples([(frames, task)])
|
||||||
|
|
||||||
|
model_device = next(model.model.parameters()).device
|
||||||
|
inputs = {key: value.to(model_device) if hasattr(value, "to") else value for key, value in batch.items()}
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
progress_logits, success_logits = model._compute_rbm_logits(inputs)
|
||||||
|
|
||||||
|
decoded = decode_progress_outputs(
|
||||||
|
progress_logits,
|
||||||
|
success_logits,
|
||||||
|
is_discrete_mode=model.config.use_discrete_progress,
|
||||||
|
)
|
||||||
|
progress = np.asarray(decoded["progress_pred"][0], dtype=np.float32)
|
||||||
|
success = (
|
||||||
|
np.asarray(decoded["success_probs"][0], dtype=np.float32)
|
||||||
|
if decoded["success_probs"]
|
||||||
|
else np.array([], dtype=np.float32)
|
||||||
|
)
|
||||||
|
return progress, success
|
||||||
|
|
||||||
|
|
||||||
|
def _compare(name: str, lerobot: np.ndarray, upstream: np.ndarray, atol: float, rtol: float) -> bool:
|
||||||
|
print(f"\n=== {name} ===")
|
||||||
|
if lerobot.shape != upstream.shape:
|
||||||
|
print(f"shape mismatch: lerobot={lerobot.shape} upstream={upstream.shape}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
abs_diff = np.abs(lerobot - upstream)
|
||||||
|
rel_diff = abs_diff / (np.abs(upstream) + 1e-12)
|
||||||
|
print(f"shape : {lerobot.shape}")
|
||||||
|
print(f"max |Δ| : {abs_diff.max():.3e}")
|
||||||
|
print(f"mean |Δ| : {abs_diff.mean():.3e}")
|
||||||
|
print(f"max rel |Δ| : {rel_diff.max():.3e}")
|
||||||
|
print(f"lerobot[:5] : {lerobot[:5]}")
|
||||||
|
print(f"upstream[:5] : {upstream[:5]}")
|
||||||
|
|
||||||
|
within_tol = bool(np.allclose(lerobot, upstream, atol=atol, rtol=rtol))
|
||||||
|
print(f"allclose(atol={atol}, rtol={rtol}) -> {within_tol}")
|
||||||
|
return within_tol
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=__doc__,
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--frames",
|
||||||
|
required=True,
|
||||||
|
help=".npy / .npz file with the exact frames upstream was run on (T,H,W,C uint8).",
|
||||||
|
)
|
||||||
|
parser.add_argument("--task", required=True, help="Task instruction string.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--upstream-progress",
|
||||||
|
required=True,
|
||||||
|
help="Reference progress .npy saved by upstream example_inference_local.py.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--upstream-success",
|
||||||
|
default=None,
|
||||||
|
help="Optional reference success_probs .npy. If omitted, success comparison is skipped.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lerobot-model",
|
||||||
|
default="lilkm/robometer-4b",
|
||||||
|
help="LeRobot-format Robometer Hub repo id or local path.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
|
help="Device for the LeRobot model (default: cuda if available).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--atol",
|
||||||
|
type=float,
|
||||||
|
default=1e-3,
|
||||||
|
help="Absolute tolerance for allclose (default: 1e-3; bf16 round-trip headroom).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rtol",
|
||||||
|
type=float,
|
||||||
|
default=1e-2,
|
||||||
|
help="Relative tolerance for allclose (default: 1e-2).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--out-prefix",
|
||||||
|
default="lerobot_robometer_outputs",
|
||||||
|
help="Save the LeRobot outputs as <prefix>_progress.npy / <prefix>_success.npy.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# 0. Sanity: confirm the LeRobot config is a RobometerConfig.
|
||||||
|
cfg = RewardModelConfig.from_pretrained(args.lerobot_model)
|
||||||
|
if not isinstance(cfg, RobometerConfig):
|
||||||
|
print(f"ERROR: {args.lerobot_model!r} does not resolve to a RobometerConfig.", file=sys.stderr)
|
||||||
|
return 2
|
||||||
|
|
||||||
|
# 1. Load frames + task + upstream reference outputs.
|
||||||
|
frames = _load_frames(args.frames)
|
||||||
|
upstream_progress = np.load(args.upstream_progress).astype(np.float32)
|
||||||
|
upstream_success = (
|
||||||
|
np.load(args.upstream_success).astype(np.float32) if args.upstream_success is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Loaded {frames.shape[0]} frames at {frames.shape[1:]}, task={args.task!r}")
|
||||||
|
print(f"LeRobot model: {args.lerobot_model} device: {args.device}")
|
||||||
|
|
||||||
|
# 2. Run LeRobot pipeline.
|
||||||
|
progress, success = _run_lerobot(frames, args.task, args.lerobot_model, args.device)
|
||||||
|
np.save(f"{args.out_prefix}_progress.npy", progress)
|
||||||
|
if success.size > 0:
|
||||||
|
np.save(f"{args.out_prefix}_success.npy", success)
|
||||||
|
print(f"Saved LeRobot outputs to {args.out_prefix}_progress.npy / _success.npy")
|
||||||
|
|
||||||
|
# 3. Compare to upstream references.
|
||||||
|
progress_ok = _compare("progress", progress, upstream_progress, args.atol, args.rtol)
|
||||||
|
if upstream_success is not None and success.size > 0:
|
||||||
|
success_ok = _compare("success_probs", success, upstream_success, args.atol, args.rtol)
|
||||||
|
else:
|
||||||
|
success_ok = True
|
||||||
|
print("\n(skipping success comparison — upstream success file not provided)")
|
||||||
|
|
||||||
|
print()
|
||||||
|
if progress_ok and success_ok:
|
||||||
|
print("Parity check passed.")
|
||||||
|
return 0
|
||||||
|
print("Parity check FAILED.")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
362
scripts/parity_robometer_upstream_examples.py
Normal file
362
scripts/parity_robometer_upstream_examples.py
Normal file
@@ -0,0 +1,362 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
"""Run LeRobot Robometer parity against upstream Robometer's bundled examples.
|
||||||
|
|
||||||
|
Upstream Robometer ships three reference videos with their pre-computed
|
||||||
|
progress / success outputs at
|
||||||
|
``third_party/robometer/scripts/example_videos/``::
|
||||||
|
|
||||||
|
soar_put_green_stick_in_brown_bowl.mp4
|
||||||
|
+ soar_put_green_stick_in_brown_bowl_rewards.npy (progress)
|
||||||
|
+ soar_put_green_stick_in_brown_bowl_rewards_success_probs.npy (success)
|
||||||
|
berkeley_rpt_stack_cup.mp4
|
||||||
|
+ berkeley_rpt_stack_cup_rewards.npy
|
||||||
|
+ berkeley_rpt_stack_cup_rewards_success_probs.npy
|
||||||
|
jaco_play_pick_up_green_cup.mp4
|
||||||
|
+ pick_up_green_cup_rewards.npy
|
||||||
|
+ pick_up_green_cup_rewards_success_probs.npy
|
||||||
|
|
||||||
|
This script:
|
||||||
|
1. Decodes each video at upstream's sampling fps using ``av`` (PyAV), with the
|
||||||
|
same linspace-over-total-frames logic as upstream's ``extract_frames``.
|
||||||
|
2. Runs the LeRobot ``RobometerRewardModel`` on those frames + the task from
|
||||||
|
upstream's README.
|
||||||
|
3. Compares per-frame progress / success to the pre-saved upstream outputs.
|
||||||
|
|
||||||
|
This means you do **not** need to install upstream Robometer to confirm parity.
|
||||||
|
|
||||||
|
Run::
|
||||||
|
|
||||||
|
uv run python scripts/parity_robometer_upstream_examples.py \\
|
||||||
|
--lerobot-model lilkm/robometer-4b \\
|
||||||
|
--device cuda \\
|
||||||
|
--decoder decord
|
||||||
|
|
||||||
|
The number of frames sampled per video is derived from the length of each
|
||||||
|
upstream ``.npy`` reference, so the script does not need a ``--fps`` argument
|
||||||
|
(the README documents ``fps=3`` for SOAR / Berkeley, but the Jaco Play
|
||||||
|
reference was generated with a different fps).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.rewards import RewardModelConfig
|
||||||
|
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
|
||||||
|
from lerobot.rewards.robometer.modeling_robometer import decode_progress_outputs
|
||||||
|
from lerobot.rewards.robometer.processor_robometer import RobometerEncoderProcessorStep
|
||||||
|
|
||||||
|
try:
|
||||||
|
import decord # type: ignore
|
||||||
|
|
||||||
|
_HAS_DECORD = True
|
||||||
|
except ImportError:
|
||||||
|
decord = None # type: ignore
|
||||||
|
_HAS_DECORD = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import av
|
||||||
|
|
||||||
|
_HAS_AV = True
|
||||||
|
except ImportError:
|
||||||
|
av = None # type: ignore
|
||||||
|
_HAS_AV = False
|
||||||
|
|
||||||
|
EXAMPLES = [
|
||||||
|
{
|
||||||
|
"name": "soar_put_green_stick_in_brown_bowl",
|
||||||
|
"video": "soar_put_green_stick_in_brown_bowl.mp4",
|
||||||
|
"task": "Put green stick in brown bowl",
|
||||||
|
"progress_npy": "soar_put_green_stick_in_brown_bowl_rewards.npy",
|
||||||
|
"success_npy": "soar_put_green_stick_in_brown_bowl_rewards_success_probs.npy",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "berkeley_rpt_stack_cup",
|
||||||
|
"video": "berkeley_rpt_stack_cup.mp4",
|
||||||
|
"task": "Pick up the yellow cup and stack it on the other cup",
|
||||||
|
"progress_npy": "berkeley_rpt_stack_cup_rewards.npy",
|
||||||
|
"success_npy": "berkeley_rpt_stack_cup_rewards_success_probs.npy",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "jaco_play_pick_up_green_cup",
|
||||||
|
"video": "jaco_play_pick_up_green_cup.mp4",
|
||||||
|
"task": "Pick up the green cup",
|
||||||
|
"progress_npy": "pick_up_green_cup_rewards.npy",
|
||||||
|
"success_npy": "pick_up_green_cup_rewards_success_probs.npy",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_frames_decord(video_path: Path, num_frames: int) -> tuple[np.ndarray, str]:
|
||||||
|
"""Sample ``num_frames`` indices uniformly from the video using decord.
|
||||||
|
|
||||||
|
Mirrors upstream's ``extract_frames`` indexing
|
||||||
|
(``third_party/robometer/scripts/example_inference.py``): a
|
||||||
|
``np.linspace(0, total_frames-1, num_frames)`` lookup over decord's
|
||||||
|
``VideoReader``. We pass ``num_frames`` explicitly (derived from the
|
||||||
|
upstream reference output length) so we don't have to guess what ``fps``
|
||||||
|
upstream actually used when generating each saved ``.npy`` — the file
|
||||||
|
length is the ground truth.
|
||||||
|
"""
|
||||||
|
vr = decord.VideoReader(str(video_path), num_threads=1)
|
||||||
|
total_frames = len(vr)
|
||||||
|
if total_frames == 0:
|
||||||
|
raise RuntimeError(f"No decodable frames in {video_path}.")
|
||||||
|
desired_frames = max(1, min(int(num_frames), total_frames))
|
||||||
|
indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist()
|
||||||
|
frames = vr.get_batch(indices).asnumpy()
|
||||||
|
native_fps = float(vr.get_avg_fps()) or 1.0
|
||||||
|
return frames, f"decord total={total_frames} native_fps={native_fps:.3f}"
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_frames_av(video_path: Path, num_frames: int) -> tuple[np.ndarray, str]:
|
||||||
|
"""PyAV fallback for environments without decord.
|
||||||
|
|
||||||
|
PyAV and decord can disagree on ``total_frames`` for the same container,
|
||||||
|
so the sampled frame indices can drift. Install ``decord`` for a real
|
||||||
|
parity check; this fallback is for smoke tests only.
|
||||||
|
"""
|
||||||
|
container = av.open(str(video_path))
|
||||||
|
stream = container.streams.video[0]
|
||||||
|
native_fps = float(stream.average_rate) if stream.average_rate else float(stream.guessed_rate or 30.0)
|
||||||
|
rgb_frames: list[np.ndarray] = []
|
||||||
|
for frame in container.decode(stream):
|
||||||
|
rgb_frames.append(frame.to_ndarray(format="rgb24"))
|
||||||
|
container.close()
|
||||||
|
total_frames = len(rgb_frames)
|
||||||
|
if total_frames == 0:
|
||||||
|
raise RuntimeError(f"No decodable frames in {video_path}.")
|
||||||
|
desired_frames = max(1, min(int(num_frames), total_frames))
|
||||||
|
indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int)
|
||||||
|
frames = np.stack([rgb_frames[i] for i in indices])
|
||||||
|
return frames, f"av total={total_frames} native_fps={native_fps:.3f}"
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_frames(video_path: Path, num_frames: int, prefer: str) -> tuple[np.ndarray, str]:
|
||||||
|
"""Decoder dispatch. ``prefer`` is ``"decord"`` | ``"av"`` | ``"auto"``."""
|
||||||
|
if prefer == "decord":
|
||||||
|
if not _HAS_DECORD:
|
||||||
|
raise RuntimeError("decord requested but not installed (`uv pip install decord`).")
|
||||||
|
return _extract_frames_decord(video_path, num_frames)
|
||||||
|
if prefer == "av":
|
||||||
|
if not _HAS_AV:
|
||||||
|
raise RuntimeError("av requested but not installed.")
|
||||||
|
return _extract_frames_av(video_path, num_frames)
|
||||||
|
# auto
|
||||||
|
if _HAS_DECORD:
|
||||||
|
return _extract_frames_decord(video_path, num_frames)
|
||||||
|
if _HAS_AV:
|
||||||
|
return _extract_frames_av(video_path, num_frames)
|
||||||
|
raise RuntimeError("No video decoder available (install `decord` or `av`).")
|
||||||
|
|
||||||
|
|
||||||
|
def _pearson(a: np.ndarray, b: np.ndarray) -> float:
|
||||||
|
"""Pearson correlation; returns 1.0 for constant inputs (no signal to align)."""
|
||||||
|
a = a.astype(np.float64)
|
||||||
|
b = b.astype(np.float64)
|
||||||
|
if a.size < 2:
|
||||||
|
return 1.0
|
||||||
|
da = a - a.mean()
|
||||||
|
db = b - b.mean()
|
||||||
|
denom = float(np.sqrt((da * da).sum()) * np.sqrt((db * db).sum()))
|
||||||
|
if denom == 0:
|
||||||
|
return 1.0
|
||||||
|
return float((da * db).sum() / denom)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_lerobot(
|
||||||
|
model: RobometerRewardModel,
|
||||||
|
encoder: RobometerEncoderProcessorStep,
|
||||||
|
frames: np.ndarray,
|
||||||
|
task: str,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
batch = encoder.encode_samples([(frames, task)])
|
||||||
|
device = next(model.model.parameters()).device
|
||||||
|
inputs = {key: value.to(device) if hasattr(value, "to") else value for key, value in batch.items()}
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
progress_logits, success_logits = model._compute_rbm_logits(inputs)
|
||||||
|
decoded = decode_progress_outputs(
|
||||||
|
progress_logits, success_logits, is_discrete_mode=model.config.use_discrete_progress
|
||||||
|
)
|
||||||
|
progress = np.asarray(decoded["progress_pred"][0], dtype=np.float32)
|
||||||
|
success = (
|
||||||
|
np.asarray(decoded["success_probs"][0], dtype=np.float32)
|
||||||
|
if decoded["success_probs"]
|
||||||
|
else np.array([], dtype=np.float32)
|
||||||
|
)
|
||||||
|
return progress, success
|
||||||
|
|
||||||
|
|
||||||
|
def _compare(
|
||||||
|
name: str,
|
||||||
|
lerobot: np.ndarray,
|
||||||
|
upstream: np.ndarray,
|
||||||
|
*,
|
||||||
|
atol: float,
|
||||||
|
pearson_min: float,
|
||||||
|
) -> bool:
|
||||||
|
if lerobot.shape != upstream.shape:
|
||||||
|
print(f" {name:8s} SHAPE MISMATCH lerobot={lerobot.shape} upstream={upstream.shape}")
|
||||||
|
return False
|
||||||
|
abs_diff = np.abs(lerobot - upstream)
|
||||||
|
pearson = _pearson(lerobot, upstream)
|
||||||
|
abs_ok = bool(abs_diff.max() <= atol)
|
||||||
|
pearson_ok = bool(pearson >= pearson_min)
|
||||||
|
verdict = "PASS" if (abs_ok or pearson_ok) else "FAIL"
|
||||||
|
print(
|
||||||
|
f" {name:8s} shape={lerobot.shape} max|Δ|={abs_diff.max():.3e} "
|
||||||
|
f"mean|Δ|={abs_diff.mean():.3e} pearson={pearson:.4f} "
|
||||||
|
f"(atol={atol:.0e} pearson_min={pearson_min:.3f}) -> {verdict}"
|
||||||
|
)
|
||||||
|
return abs_ok or pearson_ok
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=__doc__,
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--examples-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("third_party/robometer/scripts/example_videos"),
|
||||||
|
help="Directory containing the upstream Robometer example mp4s + .npy outputs.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lerobot-model",
|
||||||
|
default="lilkm/robometer-4b",
|
||||||
|
help="LeRobot-format Robometer Hub repo id or local path.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||||
|
help="Device for the LeRobot model.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder",
|
||||||
|
choices=("auto", "decord", "av"),
|
||||||
|
default="auto",
|
||||||
|
help=(
|
||||||
|
"Video decoder. ``auto`` prefers decord (matches upstream) and falls back to av. "
|
||||||
|
"Force ``decord`` for a clean parity check."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--progress-atol",
|
||||||
|
type=float,
|
||||||
|
default=1e-2,
|
||||||
|
help="Absolute tolerance for the progress array. Default 1e-2 covers CUDA bf16 noise.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--success-atol",
|
||||||
|
type=float,
|
||||||
|
default=1e-1,
|
||||||
|
help=(
|
||||||
|
"Absolute tolerance for the success array. Looser than progress because "
|
||||||
|
"``sigmoid`` amplifies logit-space noise near 0.5."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pearson-min",
|
||||||
|
type=float,
|
||||||
|
default=0.99,
|
||||||
|
help="Minimum Pearson correlation for a PASS verdict (per array).",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.decoder == "av" or (args.decoder == "auto" and not _HAS_DECORD):
|
||||||
|
print(
|
||||||
|
"WARNING: using PyAV decoder. PyAV's total-frame count can differ from decord's, "
|
||||||
|
"which propagates into different sampled-frame indices. Install `decord` and "
|
||||||
|
"re-run for a clean parity check.",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
examples_dir = args.examples_dir.resolve()
|
||||||
|
if not examples_dir.is_dir():
|
||||||
|
print(f"ERROR: examples dir {examples_dir} does not exist.", file=sys.stderr)
|
||||||
|
return 2
|
||||||
|
|
||||||
|
# Sanity-check the LeRobot config is a RobometerConfig before loading weights.
|
||||||
|
cfg = RewardModelConfig.from_pretrained(args.lerobot_model)
|
||||||
|
if not isinstance(cfg, RobometerConfig):
|
||||||
|
print(f"ERROR: {args.lerobot_model!r} did not resolve to a RobometerConfig.", file=sys.stderr)
|
||||||
|
return 2
|
||||||
|
|
||||||
|
print(f"Loading LeRobot Robometer from {args.lerobot_model} on {args.device}...")
|
||||||
|
cfg.pretrained_path = args.lerobot_model
|
||||||
|
cfg.device = args.device
|
||||||
|
model = RobometerRewardModel.from_pretrained(args.lerobot_model, config=cfg)
|
||||||
|
encoder = RobometerEncoderProcessorStep(
|
||||||
|
base_model_id=model.config.base_model_id,
|
||||||
|
use_multi_image=model.config.use_multi_image,
|
||||||
|
use_per_frame_progress_token=model.config.use_per_frame_progress_token,
|
||||||
|
max_frames=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
all_ok = True
|
||||||
|
for ex in EXAMPLES:
|
||||||
|
video_path = examples_dir / ex["video"]
|
||||||
|
upstream_progress_path = examples_dir / ex["progress_npy"]
|
||||||
|
upstream_success_path = examples_dir / ex["success_npy"]
|
||||||
|
|
||||||
|
missing = [p for p in (video_path, upstream_progress_path, upstream_success_path) if not p.exists()]
|
||||||
|
if missing:
|
||||||
|
print(f"[skip] {ex['name']}: missing {[str(m) for m in missing]}")
|
||||||
|
all_ok = False
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"\n=== {ex['name']} ===")
|
||||||
|
print(f" task: {ex['task']!r}")
|
||||||
|
|
||||||
|
# Trust the upstream reference array as the source of truth for how
|
||||||
|
# many frames to sample. The README documents fps=3 for SOAR/Berkeley
|
||||||
|
# but Jaco Play was generated with a different fps, so any hardcoded
|
||||||
|
# ``--fps`` mismatches at least one example. The npy length always
|
||||||
|
# tells us what upstream actually used.
|
||||||
|
upstream_progress = np.load(upstream_progress_path).astype(np.float32)
|
||||||
|
upstream_success = np.load(upstream_success_path).astype(np.float32)
|
||||||
|
target_num_frames = int(upstream_progress.shape[0])
|
||||||
|
frames, decoder_info = _extract_frames(video_path, target_num_frames, prefer=args.decoder)
|
||||||
|
print(
|
||||||
|
f" decoded {frames.shape[0]} frames (matches upstream npy length); "
|
||||||
|
f"shape={frames.shape} [{decoder_info}]"
|
||||||
|
)
|
||||||
|
|
||||||
|
progress, success = _run_lerobot(model, encoder, frames, ex["task"])
|
||||||
|
|
||||||
|
progress_ok = _compare(
|
||||||
|
"progress",
|
||||||
|
progress,
|
||||||
|
upstream_progress,
|
||||||
|
atol=args.progress_atol,
|
||||||
|
pearson_min=args.pearson_min,
|
||||||
|
)
|
||||||
|
success_ok = _compare(
|
||||||
|
"success",
|
||||||
|
success,
|
||||||
|
upstream_success,
|
||||||
|
atol=args.success_atol,
|
||||||
|
pearson_min=args.pearson_min,
|
||||||
|
)
|
||||||
|
verdict = "PASS" if (progress_ok and success_ok) else "FAIL"
|
||||||
|
print(f" -> {verdict}")
|
||||||
|
all_ok = all_ok and progress_ok and success_ok
|
||||||
|
|
||||||
|
print()
|
||||||
|
if all_ok:
|
||||||
|
print("All upstream example parity checks passed.")
|
||||||
|
return 0
|
||||||
|
print("Some upstream example parity checks FAILED.")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
149
scripts/verify_robometer_export.py
Normal file
149
scripts/verify_robometer_export.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
"""Verify that a LeRobot-format Robometer is byte-equivalent to its upstream source.
|
||||||
|
|
||||||
|
Run this once after publishing a LeRobot-format Robometer to the Hub, before
|
||||||
|
flipping the default `RobometerConfig.pretrained_path` to it. It loads both
|
||||||
|
the upstream snapshot and the re-exported copy, compares state dicts, and
|
||||||
|
prints a clear pass/fail summary.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
python scripts/verify_robometer_export.py \\
|
||||||
|
--upstream robometer/Robometer-4B \\
|
||||||
|
--lerobot lerobot/robometer-4b
|
||||||
|
|
||||||
|
python scripts/verify_robometer_export.py \\
|
||||||
|
--upstream robometer/Robometer-4B \\
|
||||||
|
--lerobot ./robometer-4b-lerobot # local folder also works
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from lerobot.configs.rewards import RewardModelConfig
|
||||||
|
from lerobot.rewards.robometer import RobometerConfig, RobometerRewardModel
|
||||||
|
from lerobot.rewards.robometer._upstream_loader import apply_upstream_checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def _load_upstream(path: str) -> RobometerRewardModel:
|
||||||
|
# Fresh ``RobometerConfig`` (``vlm_config=None``) triggers
|
||||||
|
# ``RobometerRewardModel.__init__``'s upstream-matching path: download
|
||||||
|
# base Qwen, resize for ROBOMETER_SPECIAL_TOKENS. The subsequent
|
||||||
|
# ``apply_upstream_checkpoint`` call resizes again if the checkpoint's
|
||||||
|
# vocab differs (e.g. upstream was trained against an older Qwen).
|
||||||
|
cfg = RobometerConfig(pretrained_path=path, device="cpu")
|
||||||
|
model = RobometerRewardModel(cfg)
|
||||||
|
apply_upstream_checkpoint(model, path)
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _load_lerobot(path: str) -> RobometerRewardModel:
|
||||||
|
cfg = RewardModelConfig.from_pretrained(path)
|
||||||
|
if not isinstance(cfg, RobometerConfig):
|
||||||
|
raise TypeError(f"Expected RobometerConfig in LeRobot export, got {type(cfg)}")
|
||||||
|
cfg.pretrained_path = path
|
||||||
|
cfg.device = "cpu"
|
||||||
|
return RobometerRewardModel.from_pretrained(path, config=cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def compare_state_dicts(a: RobometerRewardModel, b: RobometerRewardModel) -> bool:
|
||||||
|
sd_a, sd_b = a.state_dict(), b.state_dict()
|
||||||
|
keys_a, keys_b = set(sd_a), set(sd_b)
|
||||||
|
|
||||||
|
missing = keys_a - keys_b
|
||||||
|
extra = keys_b - keys_a
|
||||||
|
if missing:
|
||||||
|
print(f"❌ {len(missing)} keys missing in LeRobot-format model (sample: {list(missing)[:5]})")
|
||||||
|
if extra:
|
||||||
|
print(f"❌ {len(extra)} extra keys in LeRobot-format model (sample: {list(extra)[:5]})")
|
||||||
|
if missing or extra:
|
||||||
|
return False
|
||||||
|
|
||||||
|
diff_summary: list[tuple[str, float]] = []
|
||||||
|
for key in sorted(keys_a):
|
||||||
|
ta, tb = sd_a[key], sd_b[key]
|
||||||
|
if ta.shape != tb.shape:
|
||||||
|
print(f"❌ shape mismatch at {key}: {tuple(ta.shape)} vs {tuple(tb.shape)}")
|
||||||
|
return False
|
||||||
|
# Compare in float to avoid bfloat16 equality quirks.
|
||||||
|
max_abs = (ta.float() - tb.float()).abs().max().item()
|
||||||
|
if max_abs > 0:
|
||||||
|
diff_summary.append((key, max_abs))
|
||||||
|
|
||||||
|
if not diff_summary:
|
||||||
|
print(f"✅ All {len(keys_a)} parameters identical")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Some keys differ; show worst offenders.
|
||||||
|
diff_summary.sort(key=lambda kv: kv[1], reverse=True)
|
||||||
|
print(f"⚠️ {len(diff_summary)} keys differ. Top 10 by max abs diff:")
|
||||||
|
for key, value in diff_summary[:10]:
|
||||||
|
print(f" {key:60s} max|Δ| = {value:.3e}")
|
||||||
|
|
||||||
|
# Tolerance: bf16 round-trips can introduce ULP-level noise but no real
|
||||||
|
# change. Allow up to 1e-3 absolute difference; anything larger is a real
|
||||||
|
# divergence.
|
||||||
|
worst = diff_summary[0][1]
|
||||||
|
if worst < 1e-3:
|
||||||
|
print(f"✅ Worst diff {worst:.3e} is within bf16 round-trip tolerance")
|
||||||
|
return True
|
||||||
|
print(f"❌ Worst diff {worst:.3e} exceeds tolerance (1e-3)")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||||
|
)
|
||||||
|
parser.add_argument("--upstream", required=True, help="Upstream Robometer repo id or local path.")
|
||||||
|
parser.add_argument("--lerobot", required=True, help="LeRobot-format Robometer repo id or local path.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print(f"Loading upstream: {args.upstream}")
|
||||||
|
upstream = _load_upstream(args.upstream)
|
||||||
|
print(f"Loading LeRobot-format: {args.lerobot}")
|
||||||
|
lerobot = _load_lerobot(args.lerobot)
|
||||||
|
|
||||||
|
print("\n=== Config comparison ===")
|
||||||
|
config_ok = True
|
||||||
|
for field in [
|
||||||
|
"base_model_id",
|
||||||
|
"torch_dtype",
|
||||||
|
"use_multi_image",
|
||||||
|
"use_per_frame_progress_token",
|
||||||
|
"average_temporal_patches",
|
||||||
|
"frame_pooling",
|
||||||
|
"frame_pooling_attn_temperature",
|
||||||
|
"progress_loss_type",
|
||||||
|
"progress_discrete_bins",
|
||||||
|
]:
|
||||||
|
a, b = getattr(upstream.config, field), getattr(lerobot.config, field)
|
||||||
|
field_ok = a == b
|
||||||
|
config_ok = config_ok and field_ok
|
||||||
|
ok = "✅" if field_ok else "❌"
|
||||||
|
print(f" {ok} {field}: upstream={a!r}, lerobot={b!r}")
|
||||||
|
|
||||||
|
print("\n=== State-dict comparison ===")
|
||||||
|
state_dict_ok = compare_state_dicts(upstream, lerobot)
|
||||||
|
|
||||||
|
print()
|
||||||
|
if config_ok and state_dict_ok:
|
||||||
|
print("🎉 Verification passed — safe to flip the default.")
|
||||||
|
return 0
|
||||||
|
print("⛔ Verification failed — DO NOT flip the default.")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
@@ -17,6 +17,7 @@ Provides the RealSenseCamera class for capturing frames from Intel RealSense cam
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from threading import Event, Lock, Thread
|
from threading import Event, Lock, Thread
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
@@ -41,6 +42,7 @@ from ..utils import get_cv2_rotation
|
|||||||
from .configuration_realsense import RealSenseCameraConfig
|
from .configuration_realsense import RealSenseCameraConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
pkg_name = "pyrealsense2-macosx" if sys.platform == "darwin" else "pyrealsense2"
|
||||||
|
|
||||||
|
|
||||||
class RealSenseCamera(Camera):
|
class RealSenseCamera(Camera):
|
||||||
@@ -114,7 +116,7 @@ class RealSenseCamera(Camera):
|
|||||||
Args:
|
Args:
|
||||||
config: The configuration settings for the camera.
|
config: The configuration settings for the camera.
|
||||||
"""
|
"""
|
||||||
require_package("pyrealsense2", extra="intelrealsense")
|
require_package(pkg_name, extra="intelrealsense", import_name="pyrealsense2")
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ def save_checkpoint(
|
|||||||
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
|
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
|
||||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||||
|
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
||||||
"""
|
"""
|
||||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||||
policy.save_pretrained(pretrained_dir)
|
policy.save_pretrained(pretrained_dir)
|
||||||
|
|||||||
@@ -41,8 +41,12 @@ def cfg_to_group(
|
|||||||
return tag
|
return tag
|
||||||
return tag[:max_tag_length]
|
return tag[:max_tag_length]
|
||||||
|
|
||||||
|
if cfg.is_reward_model_training:
|
||||||
|
trainable_tag = f"reward_model:{cfg.reward_model.type}"
|
||||||
|
else:
|
||||||
|
trainable_tag = f"policy:{cfg.policy.type}"
|
||||||
lst = [
|
lst = [
|
||||||
f"policy:{cfg.policy.type}",
|
trainable_tag,
|
||||||
f"seed:{cfg.seed}",
|
f"seed:{cfg.seed}",
|
||||||
]
|
]
|
||||||
if cfg.dataset is not None:
|
if cfg.dataset is not None:
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ are intentionally NOT re-exported here to avoid circular dependencies
|
|||||||
Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
|
Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .dataset import DatasetRecordConfig
|
||||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||||
from .policies import PreTrainedConfig
|
from .policies import PreTrainedConfig
|
||||||
from .types import (
|
from .types import (
|
||||||
@@ -39,6 +40,7 @@ __all__ = [
|
|||||||
"PolicyFeature",
|
"PolicyFeature",
|
||||||
"RTCAttentionSchedule",
|
"RTCAttentionSchedule",
|
||||||
# Config classes
|
# Config classes
|
||||||
|
"DatasetRecordConfig",
|
||||||
"DatasetConfig",
|
"DatasetConfig",
|
||||||
"EvalConfig",
|
"EvalConfig",
|
||||||
"PeftConfig",
|
"PeftConfig",
|
||||||
|
|||||||
80
src/lerobot/configs/dataset.py
Normal file
80
src/lerobot/configs/dataset.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DatasetRecordConfig:
|
||||||
|
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
|
||||||
|
repo_id: str = ""
|
||||||
|
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
|
||||||
|
single_task: str = ""
|
||||||
|
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||||
|
root: str | Path | None = None
|
||||||
|
# Limit the frames per second.
|
||||||
|
fps: int = 30
|
||||||
|
# Number of seconds for data recording for each episode.
|
||||||
|
episode_time_s: int | float = 60
|
||||||
|
# Number of seconds for resetting the environment after each episode.
|
||||||
|
reset_time_s: int | float = 60
|
||||||
|
# Number of episodes to record.
|
||||||
|
num_episodes: int = 50
|
||||||
|
# Encode frames in the dataset into video
|
||||||
|
video: bool = True
|
||||||
|
# Upload dataset to Hugging Face hub.
|
||||||
|
push_to_hub: bool = True
|
||||||
|
# Upload on private repository on the Hugging Face hub.
|
||||||
|
private: bool = False
|
||||||
|
# Add tags to your dataset on the hub.
|
||||||
|
tags: list[str] | None = None
|
||||||
|
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
||||||
|
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
|
||||||
|
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
|
||||||
|
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
|
||||||
|
num_image_writer_processes: int = 0
|
||||||
|
# Number of threads writing the frames as png images on disk, per camera.
|
||||||
|
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
|
||||||
|
# Not enough threads might cause low camera fps.
|
||||||
|
num_image_writer_threads_per_camera: int = 4
|
||||||
|
# Number of episodes to record before batch encoding videos
|
||||||
|
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||||
|
video_encoding_batch_size: int = 1
|
||||||
|
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
|
||||||
|
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
|
||||||
|
# Use 'auto' to auto-detect the best available hardware encoder.
|
||||||
|
vcodec: str = "libsvtav1"
|
||||||
|
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||||
|
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||||
|
streaming_encoding: bool = False
|
||||||
|
# Maximum number of frames to buffer per camera when using streaming encoding.
|
||||||
|
# ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
|
||||||
|
encoder_queue_maxsize: int = 30
|
||||||
|
# Number of threads per encoder instance. None = auto (codec default).
|
||||||
|
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
|
||||||
|
encoder_threads: int | None = None
|
||||||
|
|
||||||
|
def stamp_repo_id(self) -> None:
|
||||||
|
"""Append a date-time tag to ``repo_id`` so each recording session gets a unique name.
|
||||||
|
|
||||||
|
Must be called explicitly at dataset *creation* time — not on resume,
|
||||||
|
where the existing ``repo_id`` (already stamped) must be preserved.
|
||||||
|
"""
|
||||||
|
if self.repo_id:
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
self.repo_id = f"{self.repo_id}_{timestamp}"
|
||||||
@@ -117,3 +117,9 @@ class PeftConfig:
|
|||||||
# the rank used for the adapter. In general a higher rank means more trainable parameters and closer to full
|
# the rank used for the adapter. In general a higher rank means more trainable parameters and closer to full
|
||||||
# fine-tuning.
|
# fine-tuning.
|
||||||
r: int = 16
|
r: int = 16
|
||||||
|
|
||||||
|
# Alpha parameter for LoRA scaling (scaling = lora_alpha / r).
|
||||||
|
# In general, a higher alpha means stronger adaptation signal.
|
||||||
|
# If None, the PEFT library defaults to alpha=8, which may dampen high-rank adapters.
|
||||||
|
# Common values are r (alpha == rank) or 2*r.
|
||||||
|
lora_alpha: int | None = None
|
||||||
|
|||||||
@@ -46,8 +46,11 @@ class EvalPipelineConfig:
|
|||||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||||
policy_path = parser.get_path_arg("policy")
|
policy_path = parser.get_path_arg("policy")
|
||||||
if policy_path:
|
if policy_path:
|
||||||
cli_overrides = parser.get_cli_overrides("policy")
|
yaml_overrides = parser.get_yaml_overrides("policy")
|
||||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
cli_overrides = parser.get_cli_overrides("policy") or []
|
||||||
|
self.policy = PreTrainedConfig.from_pretrained(
|
||||||
|
policy_path, cli_overrides=yaml_overrides + cli_overrides
|
||||||
|
)
|
||||||
self.policy.pretrained_path = Path(policy_path)
|
self.policy.pretrained_path = Path(policy_path)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -13,8 +13,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
|
import json
|
||||||
import pkgutil
|
import pkgutil
|
||||||
import sys
|
import sys
|
||||||
|
import tempfile
|
||||||
from argparse import ArgumentError
|
from argparse import ArgumentError
|
||||||
from collections.abc import Callable, Iterable, Sequence
|
from collections.abc import Callable, Iterable, Sequence
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
@@ -24,6 +26,7 @@ from types import ModuleType
|
|||||||
from typing import Any, TypeVar, cast
|
from typing import Any, TypeVar, cast
|
||||||
|
|
||||||
import draccus
|
import draccus
|
||||||
|
import yaml # type: ignore[import-untyped]
|
||||||
|
|
||||||
from lerobot.utils.utils import has_method
|
from lerobot.utils.utils import has_method
|
||||||
|
|
||||||
@@ -32,6 +35,29 @@ F = TypeVar("F", bound=Callable[..., object])
|
|||||||
PATH_KEY = "path"
|
PATH_KEY = "path"
|
||||||
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
||||||
|
|
||||||
|
# Storage for path args extracted from YAML/JSON config files, so that
|
||||||
|
# get_path_arg() can find them even when they weren't passed via CLI.
|
||||||
|
_config_path_args: dict[str, str] = {}
|
||||||
|
|
||||||
|
# Storage for non-path YAML overrides so validate() can pass them to from_pretrained.
|
||||||
|
_config_yaml_overrides: dict[str, list[str]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _flatten_to_cli_args(d: dict, prefix: str = "") -> list[str]:
|
||||||
|
"""Recursively flatten a nested dict to CLI-style args (e.g. {"lr": 1e-4} -> ["--lr=0.0001"])."""
|
||||||
|
args = []
|
||||||
|
for key, value in d.items():
|
||||||
|
if key in (PATH_KEY, draccus.CHOICE_TYPE_KEY):
|
||||||
|
continue
|
||||||
|
full_key = f"{prefix}.{key}" if prefix else key
|
||||||
|
if isinstance(value, bool):
|
||||||
|
value = str(value).lower()
|
||||||
|
if isinstance(value, dict):
|
||||||
|
args.extend(_flatten_to_cli_args(value, full_key))
|
||||||
|
elif value is not None and not isinstance(value, list):
|
||||||
|
args.append(f"--{full_key}={value}")
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
|
def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None:
|
||||||
"""Parses arguments from cli at a given nested attribute level.
|
"""Parses arguments from cli at a given nested attribute level.
|
||||||
@@ -145,7 +171,14 @@ def load_plugin(plugin_path: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||||
return parse_arg(f"{field_name}.{PATH_KEY}", args)
|
result = parse_arg(f"{field_name}.{PATH_KEY}", args)
|
||||||
|
if result is None:
|
||||||
|
result = _config_path_args.get(field_name)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_yaml_overrides(field_name: str) -> list[str]:
|
||||||
|
return _config_yaml_overrides.get(field_name, [])
|
||||||
|
|
||||||
|
|
||||||
def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||||
@@ -192,6 +225,52 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
|||||||
return filtered_args
|
return filtered_args
|
||||||
|
|
||||||
|
|
||||||
|
def extract_path_fields_from_config(config_path: str, path_fields: list[str]) -> str:
|
||||||
|
"""Extract `path` fields from a YAML/JSON config before draccus processes it.
|
||||||
|
|
||||||
|
When a user specifies e.g. ``policy.path: lerobot/smolvla_base`` in a YAML config,
|
||||||
|
draccus will fail because ``path`` is not a valid field on policy config classes.
|
||||||
|
This function extracts those path values, stores them in ``_config_path_args`` for
|
||||||
|
later retrieval by ``get_path_arg()``, and returns a cleaned temp config file path.
|
||||||
|
"""
|
||||||
|
config_file = Path(config_path)
|
||||||
|
suffix = config_file.suffix.lower()
|
||||||
|
|
||||||
|
if suffix in (".yaml", ".yml"):
|
||||||
|
with open(config_file) as f:
|
||||||
|
config_data = yaml.safe_load(f)
|
||||||
|
elif suffix == ".json":
|
||||||
|
with open(config_file) as f:
|
||||||
|
config_data = json.load(f)
|
||||||
|
else:
|
||||||
|
return config_path
|
||||||
|
|
||||||
|
if not isinstance(config_data, dict):
|
||||||
|
return config_path
|
||||||
|
|
||||||
|
modified = False
|
||||||
|
for field in path_fields:
|
||||||
|
if field in config_data and isinstance(config_data[field], dict) and PATH_KEY in config_data[field]:
|
||||||
|
_config_path_args[field] = str(config_data[field].pop(PATH_KEY))
|
||||||
|
remaining = config_data[field]
|
||||||
|
if remaining:
|
||||||
|
_config_yaml_overrides[field] = _flatten_to_cli_args(remaining)
|
||||||
|
else:
|
||||||
|
del config_data[field]
|
||||||
|
modified = True
|
||||||
|
|
||||||
|
if not modified:
|
||||||
|
return config_path
|
||||||
|
|
||||||
|
# Write cleaned config to a temp file
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as tmp:
|
||||||
|
if suffix in (".yaml", ".yml"):
|
||||||
|
yaml.dump(config_data, tmp, default_flow_style=False)
|
||||||
|
else:
|
||||||
|
json.dump(config_data, tmp, indent=2)
|
||||||
|
return tmp.name
|
||||||
|
|
||||||
|
|
||||||
def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
||||||
"""
|
"""
|
||||||
HACK: Similar to draccus.wrap but does three additional things:
|
HACK: Similar to draccus.wrap but does three additional things:
|
||||||
@@ -225,6 +304,9 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
|
|||||||
if has_method(argtype, "__get_path_fields__"):
|
if has_method(argtype, "__get_path_fields__"):
|
||||||
path_fields = argtype.__get_path_fields__()
|
path_fields = argtype.__get_path_fields__()
|
||||||
cli_args = filter_path_args(path_fields, cli_args)
|
cli_args = filter_path_args(path_fields, cli_args)
|
||||||
|
# Also extract path fields from the YAML/JSON config file
|
||||||
|
if config_path_cli:
|
||||||
|
config_path_cli = extract_path_fields_from_config(config_path_cli, path_fields)
|
||||||
if has_method(argtype, "from_pretrained") and config_path_cli:
|
if has_method(argtype, "from_pretrained") and config_path_cli:
|
||||||
cli_args = filter_arg("config_path", cli_args)
|
cli_args = filter_arg("config_path", cli_args)
|
||||||
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
cfg = argtype.from_pretrained(config_path_cli, cli_args=cli_args)
|
||||||
|
|||||||
170
src/lerobot/configs/rewards.py
Normal file
170
src/lerobot/configs/rewards.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import builtins
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
|
import draccus
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from huggingface_hub.constants import CONFIG_NAME
|
||||||
|
from huggingface_hub.errors import HfHubHTTPError
|
||||||
|
|
||||||
|
from lerobot.configs.types import PolicyFeature
|
||||||
|
from lerobot.optim.optimizers import OptimizerConfig
|
||||||
|
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||||
|
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
|
||||||
|
from lerobot.utils.hub import HubMixin
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="RewardModelConfig")
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||||
|
"""Base configuration for reward models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_features: A dictionary defining the PolicyFeature of the input data for the reward. The key represents
|
||||||
|
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||||
|
output_features: A dictionary defining the PolicyFeature of the output data for the reward. The key represents
|
||||||
|
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Reuses PolicyFeature
|
||||||
|
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
|
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
|
|
||||||
|
device: str | None = None
|
||||||
|
|
||||||
|
pretrained_path: str | None = None
|
||||||
|
|
||||||
|
push_to_hub: bool = False
|
||||||
|
repo_id: str | None = None
|
||||||
|
|
||||||
|
# Hub metadata
|
||||||
|
license: str | None = None
|
||||||
|
tags: list[str] | None = None
|
||||||
|
private: bool | None = None
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
if not self.device or not is_torch_device_available(self.device):
|
||||||
|
auto_device = auto_select_torch_device()
|
||||||
|
logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||||
|
self.device = auto_device.type
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self) -> str:
|
||||||
|
choice_name = self.get_choice_name(self.__class__)
|
||||||
|
if not isinstance(choice_name, str):
|
||||||
|
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
|
||||||
|
return choice_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_optimizer_preset(self) -> OptimizerConfig | None:
|
||||||
|
"""Default optimizer for this reward model, or ``None`` for zero-shot models.
|
||||||
|
|
||||||
|
Trainable reward models (e.g. SARM, Classifier) must override this with a
|
||||||
|
concrete optimizer config. Zero-shot reward models (e.g. Robometer) leave
|
||||||
|
the default ``None`` — they error out earlier via the
|
||||||
|
:attr:`~lerobot.rewards.pretrained.PreTrainedRewardModel.is_trainable`
|
||||||
|
check in ``lerobot-train``.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def validate_features(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _save_pretrained(self, save_directory: Path) -> None:
|
||||||
|
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||||
|
draccus.dump(self, f, indent=4)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls: builtins.type[T],
|
||||||
|
pretrained_name_or_path: str | Path,
|
||||||
|
*,
|
||||||
|
force_download: bool = False,
|
||||||
|
resume_download: bool | None = None,
|
||||||
|
proxies: dict[Any, Any] | None = None,
|
||||||
|
token: str | bool | None = None,
|
||||||
|
cache_dir: str | Path | None = None,
|
||||||
|
local_files_only: bool = False,
|
||||||
|
revision: str | None = None,
|
||||||
|
**reward_kwargs: Any,
|
||||||
|
) -> T:
|
||||||
|
model_id = str(pretrained_name_or_path)
|
||||||
|
config_file: str | None = None
|
||||||
|
if Path(model_id).is_dir():
|
||||||
|
if CONFIG_NAME in os.listdir(model_id):
|
||||||
|
config_file = os.path.join(model_id, CONFIG_NAME)
|
||||||
|
else:
|
||||||
|
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
config_file = hf_hub_download(
|
||||||
|
repo_id=model_id,
|
||||||
|
filename=CONFIG_NAME,
|
||||||
|
revision=revision,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
token=token,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
except HfHubHTTPError as e:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if config_file is None:
|
||||||
|
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
|
||||||
|
|
||||||
|
# HACK: Parse the original config to get the config subclass, so that we can
|
||||||
|
# apply cli overrides.
|
||||||
|
with draccus.config_type("json"):
|
||||||
|
orig_config = draccus.parse(cls, config_file, args=[])
|
||||||
|
|
||||||
|
with open(config_file) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
config.pop("type", None)
|
||||||
|
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
||||||
|
json.dump(config, f)
|
||||||
|
config_file = f.name
|
||||||
|
|
||||||
|
cli_overrides = reward_kwargs.pop("cli_overrides", [])
|
||||||
|
with draccus.config_type("json"):
|
||||||
|
return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)
|
||||||
@@ -13,7 +13,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import builtins
|
import builtins
|
||||||
import datetime as dt
|
import datetime as dt
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -26,18 +28,57 @@ from lerobot import envs
|
|||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
|
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
|
||||||
from lerobot.utils.hub import HubMixin
|
from lerobot.utils.hub import HubMixin
|
||||||
|
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||||
|
|
||||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||||
from .policies import PreTrainedConfig
|
from .policies import PreTrainedConfig
|
||||||
|
from .rewards import RewardModelConfig
|
||||||
|
|
||||||
TRAIN_CONFIG_NAME = "train_config.json"
|
TRAIN_CONFIG_NAME = "train_config.json"
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_legacy_rabc_fields(config: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
"""Return migrated payload for legacy RA-BC fields, or None when no migration is needed."""
|
||||||
|
legacy_fields = (
|
||||||
|
"use_rabc",
|
||||||
|
"rabc_progress_path",
|
||||||
|
"rabc_kappa",
|
||||||
|
"rabc_epsilon",
|
||||||
|
"rabc_head_mode",
|
||||||
|
)
|
||||||
|
if not any(key in config for key in legacy_fields):
|
||||||
|
return None
|
||||||
|
|
||||||
|
migrated_config = dict(config)
|
||||||
|
use_rabc = bool(migrated_config.pop("use_rabc", False))
|
||||||
|
rabc_progress_path = migrated_config.pop("rabc_progress_path", None)
|
||||||
|
rabc_kappa = migrated_config.pop("rabc_kappa", None)
|
||||||
|
rabc_epsilon = migrated_config.pop("rabc_epsilon", None)
|
||||||
|
rabc_head_mode = migrated_config.pop("rabc_head_mode", None)
|
||||||
|
|
||||||
|
# New configs may already define sample_weighting explicitly. In that case,
|
||||||
|
# legacy fields are ignored after being stripped from the payload.
|
||||||
|
if migrated_config.get("sample_weighting") is None and use_rabc:
|
||||||
|
sample_weighting: dict[str, Any] = {"type": "rabc"}
|
||||||
|
if rabc_progress_path is not None:
|
||||||
|
sample_weighting["progress_path"] = rabc_progress_path
|
||||||
|
if rabc_kappa is not None:
|
||||||
|
sample_weighting["kappa"] = rabc_kappa
|
||||||
|
if rabc_epsilon is not None:
|
||||||
|
sample_weighting["epsilon"] = rabc_epsilon
|
||||||
|
if rabc_head_mode is not None:
|
||||||
|
sample_weighting["head_mode"] = rabc_head_mode
|
||||||
|
migrated_config["sample_weighting"] = sample_weighting
|
||||||
|
|
||||||
|
return migrated_config
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainPipelineConfig(HubMixin):
|
class TrainPipelineConfig(HubMixin):
|
||||||
dataset: DatasetConfig
|
dataset: DatasetConfig
|
||||||
env: envs.EnvConfig | None = None
|
env: envs.EnvConfig | None = None
|
||||||
policy: PreTrainedConfig | None = None
|
policy: PreTrainedConfig | None = None
|
||||||
|
reward_model: RewardModelConfig | None = None
|
||||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
||||||
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||||
output_dir: Path | None = None
|
output_dir: Path | None = None
|
||||||
@@ -72,27 +113,44 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||||
peft: PeftConfig | None = None
|
peft: PeftConfig | None = None
|
||||||
|
|
||||||
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
# Sample weighting configuration (e.g., for RA-BC training)
|
||||||
use_rabc: bool = False # Enable reward-weighted training
|
sample_weighting: SampleWeightingConfig | None = None
|
||||||
rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
|
|
||||||
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
|
|
||||||
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
|
|
||||||
rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense"
|
|
||||||
|
|
||||||
# Rename map for the observation to override the image and state keys
|
# Rename map for the observation to override the image and state keys
|
||||||
rename_map: dict[str, str] = field(default_factory=dict)
|
rename_map: dict[str, str] = field(default_factory=dict)
|
||||||
checkpoint_path: Path | None = field(init=False, default=None)
|
checkpoint_path: Path | None = field(init=False, default=None)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_reward_model_training(self) -> bool:
|
||||||
|
"""True when the config targets a reward model rather than a policy."""
|
||||||
|
return self.reward_model is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def trainable_config(self) -> PreTrainedConfig | RewardModelConfig:
|
||||||
|
"""Return whichever config (policy or reward_model) is active."""
|
||||||
|
if self.is_reward_model_training:
|
||||||
|
return self.reward_model # type: ignore[return-value]
|
||||||
|
return self.policy # type: ignore[return-value]
|
||||||
|
|
||||||
def validate(self) -> None:
|
def validate(self) -> None:
|
||||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||||
policy_path = parser.get_path_arg("policy")
|
policy_path = parser.get_path_arg("policy")
|
||||||
if policy_path:
|
reward_model_path = parser.get_path_arg("reward_model")
|
||||||
# Only load the policy config
|
|
||||||
cli_overrides = parser.get_cli_overrides("policy")
|
if reward_model_path:
|
||||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
cli_overrides = parser.get_cli_overrides("reward_model")
|
||||||
|
self.reward_model = RewardModelConfig.from_pretrained(
|
||||||
|
reward_model_path, cli_overrides=cli_overrides
|
||||||
|
)
|
||||||
|
self.reward_model.pretrained_path = str(Path(reward_model_path))
|
||||||
|
elif policy_path:
|
||||||
|
yaml_overrides = parser.get_yaml_overrides("policy")
|
||||||
|
cli_overrides = parser.get_cli_overrides("policy") or []
|
||||||
|
self.policy = PreTrainedConfig.from_pretrained(
|
||||||
|
policy_path, cli_overrides=yaml_overrides + cli_overrides
|
||||||
|
)
|
||||||
self.policy.pretrained_path = Path(policy_path)
|
self.policy.pretrained_path = Path(policy_path)
|
||||||
elif self.resume:
|
elif self.resume:
|
||||||
# The entire train config is already loaded, we just need to get the checkpoint dir
|
|
||||||
config_path = parser.parse_arg("config_path")
|
config_path = parser.parse_arg("config_path")
|
||||||
if not config_path:
|
if not config_path:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -108,18 +166,22 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
policy_dir = Path(config_path).parent
|
policy_dir = Path(config_path).parent
|
||||||
if self.policy is not None:
|
if self.policy is not None:
|
||||||
self.policy.pretrained_path = policy_dir
|
self.policy.pretrained_path = policy_dir
|
||||||
|
if self.reward_model is not None:
|
||||||
|
self.reward_model.pretrained_path = str(policy_dir)
|
||||||
self.checkpoint_path = policy_dir.parent
|
self.checkpoint_path = policy_dir.parent
|
||||||
|
|
||||||
if self.policy is None:
|
if self.policy is None and self.reward_model is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Policy is not configured. Please specify a pretrained policy with `--policy.path`."
|
"Neither policy nor reward_model is configured. "
|
||||||
|
"Please specify one with `--policy.path` or `--reward_model.path`."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
active_cfg = self.trainable_config
|
||||||
if not self.job_name:
|
if not self.job_name:
|
||||||
if self.env is None:
|
if self.env is None:
|
||||||
self.job_name = f"{self.policy.type}"
|
self.job_name = f"{active_cfg.type}"
|
||||||
else:
|
else:
|
||||||
self.job_name = f"{self.env.type}_{self.policy.type}"
|
self.job_name = f"{self.env.type}_{active_cfg.type}"
|
||||||
|
|
||||||
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
|
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
|
||||||
raise FileExistsError(
|
raise FileExistsError(
|
||||||
@@ -137,26 +199,16 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
||||||
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
||||||
elif self.use_policy_training_preset and not self.resume:
|
elif self.use_policy_training_preset and not self.resume:
|
||||||
self.optimizer = self.policy.get_optimizer_preset()
|
self.optimizer = active_cfg.get_optimizer_preset()
|
||||||
self.scheduler = self.policy.get_scheduler_preset()
|
self.scheduler = active_cfg.get_scheduler_preset()
|
||||||
|
|
||||||
if self.policy.push_to_hub and not self.policy.repo_id:
|
if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
|
||||||
raise ValueError(
|
raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
|
||||||
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.use_rabc and not self.rabc_progress_path:
|
|
||||||
# Auto-detect from dataset path
|
|
||||||
repo_id = self.dataset.repo_id
|
|
||||||
if self.dataset.root:
|
|
||||||
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
|
|
||||||
else:
|
|
||||||
self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __get_path_fields__(cls) -> list[str]:
|
def __get_path_fields__(cls) -> list[str]:
|
||||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
"""Keys for draccus pretrained-path loading."""
|
||||||
return ["policy"]
|
return ["policy", "reward_model"]
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
|
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
|
||||||
@@ -207,12 +259,16 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
) from e
|
) from e
|
||||||
|
|
||||||
cli_args = kwargs.pop("cli_args", [])
|
cli_args = kwargs.pop("cli_args", [])
|
||||||
|
# Legacy RA-BC migration only applies to framework-saved checkpoints (always JSON).
|
||||||
|
# Hand-written YAML/TOML configs are expected to use the current sample_weighting schema.
|
||||||
|
if config_file is not None and config_file.endswith(".json"):
|
||||||
|
with open(config_file) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
migrated_config = _migrate_legacy_rabc_fields(config)
|
||||||
|
if migrated_config is not None:
|
||||||
|
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
||||||
|
json.dump(migrated_config, f)
|
||||||
|
config_file = f.name
|
||||||
|
|
||||||
with draccus.config_type("json"):
|
with draccus.config_type("json"):
|
||||||
return draccus.parse(cls, config_file, args=cli_args)
|
return draccus.parse(cls, config_file, args=cli_args)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
|
||||||
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
|
||||||
# NOTE: In RL, we don't need an offline dataset
|
|
||||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
|
||||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
|
||||||
|
|||||||
@@ -97,8 +97,8 @@ def update_data_df(df, src_meta, dst_meta):
|
|||||||
pd.DataFrame: Updated DataFrame with adjusted indices.
|
pd.DataFrame: Updated DataFrame with adjusted indices.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
df["episode_index"] = df["episode_index"] + dst_meta.info.total_episodes
|
||||||
df["index"] = df["index"] + dst_meta.info["total_frames"]
|
df["index"] = df["index"] + dst_meta.info.total_frames
|
||||||
|
|
||||||
src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
|
src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
|
||||||
df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
|
df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
|
||||||
@@ -225,9 +225,9 @@ def update_meta_data(
|
|||||||
# Clean up temporary columns
|
# Clean up temporary columns
|
||||||
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||||
|
|
||||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
|
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info.total_frames
|
||||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
|
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info.total_frames
|
||||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
df["episode_index"] = df["episode_index"] + dst_meta.info.total_episodes
|
||||||
|
|
||||||
return df
|
return df
|
||||||
|
|
||||||
@@ -237,8 +237,8 @@ def aggregate_datasets(
|
|||||||
aggr_repo_id: str,
|
aggr_repo_id: str,
|
||||||
roots: list[Path] | None = None,
|
roots: list[Path] | None = None,
|
||||||
aggr_root: Path | None = None,
|
aggr_root: Path | None = None,
|
||||||
data_files_size_in_mb: float | None = None,
|
data_files_size_in_mb: int | None = None,
|
||||||
video_files_size_in_mb: float | None = None,
|
video_files_size_in_mb: int | None = None,
|
||||||
chunk_size: int | None = None,
|
chunk_size: int | None = None,
|
||||||
):
|
):
|
||||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||||
@@ -313,8 +313,8 @@ def aggregate_datasets(
|
|||||||
# to avoid interference between different source datasets
|
# to avoid interference between different source datasets
|
||||||
data_idx.pop("src_to_dst", None)
|
data_idx.pop("src_to_dst", None)
|
||||||
|
|
||||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
dst_meta.info.total_episodes += src_meta.total_episodes
|
||||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
dst_meta.info.total_frames += src_meta.total_frames
|
||||||
|
|
||||||
finalize_aggregation(dst_meta, all_metadata)
|
finalize_aggregation(dst_meta, all_metadata)
|
||||||
logging.info("Aggregation complete.")
|
logging.info("Aggregation complete.")
|
||||||
@@ -640,14 +640,10 @@ def finalize_aggregation(aggr_meta, all_metadata):
|
|||||||
write_tasks(aggr_meta.tasks, aggr_meta.root)
|
write_tasks(aggr_meta.tasks, aggr_meta.root)
|
||||||
|
|
||||||
logging.info("write info")
|
logging.info("write info")
|
||||||
aggr_meta.info.update(
|
aggr_meta.info.total_tasks = len(aggr_meta.tasks)
|
||||||
{
|
aggr_meta.info.total_episodes = sum(m.total_episodes for m in all_metadata)
|
||||||
"total_tasks": len(aggr_meta.tasks),
|
aggr_meta.info.total_frames = sum(m.total_frames for m in all_metadata)
|
||||||
"total_episodes": sum(m.total_episodes for m in all_metadata),
|
aggr_meta.info.splits = {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"}
|
||||||
"total_frames": sum(m.total_frames for m in all_metadata),
|
|
||||||
"splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
write_info(aggr_meta.info, aggr_meta.root)
|
write_info(aggr_meta.info, aggr_meta.root)
|
||||||
|
|
||||||
logging.info("write stats")
|
logging.info("write stats")
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import contextlib
|
import contextlib
|
||||||
|
from collections.abc import Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -37,13 +38,11 @@ from .io_utils import (
|
|||||||
load_subtasks,
|
load_subtasks,
|
||||||
load_tasks,
|
load_tasks,
|
||||||
write_info,
|
write_info,
|
||||||
write_json,
|
|
||||||
write_stats,
|
write_stats,
|
||||||
write_tasks,
|
write_tasks,
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
DEFAULT_EPISODES_PATH,
|
DEFAULT_EPISODES_PATH,
|
||||||
INFO_PATH,
|
|
||||||
check_version_compatibility,
|
check_version_compatibility,
|
||||||
get_safe_version,
|
get_safe_version,
|
||||||
has_legacy_hub_download_metadata,
|
has_legacy_hub_download_metadata,
|
||||||
@@ -191,6 +190,29 @@ class LeRobotDatasetMetadata:
|
|||||||
if self.episodes is None:
|
if self.episodes is None:
|
||||||
self._load_metadata()
|
self._load_metadata()
|
||||||
|
|
||||||
|
def filter_episodes(
|
||||||
|
self,
|
||||||
|
predicate: Callable[[dict], bool],
|
||||||
|
candidates: list[int] | None = None,
|
||||||
|
) -> list[int]:
|
||||||
|
"""Filter episodes whose metadata satisfies a given predicate.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
predicate: Predicate over per-episode metadata rows used to select episodes.
|
||||||
|
candidates: Optional list of episode indices to restrict evaluation to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of sorted episode indices that satisfy the predicate.
|
||||||
|
"""
|
||||||
|
self.ensure_readable()
|
||||||
|
if candidates is not None:
|
||||||
|
candidate_set = set(candidates)
|
||||||
|
combined = lambda ep: ep["episode_index"] in candidate_set and predicate(ep) # noqa: E731
|
||||||
|
else:
|
||||||
|
combined = predicate
|
||||||
|
filtered = self.episodes.filter(combined, keep_in_memory=True, load_from_cache_file=False)
|
||||||
|
return sorted(int(idx) for idx in filtered["episode_index"])
|
||||||
|
|
||||||
def _pull_from_repo(
|
def _pull_from_repo(
|
||||||
self,
|
self,
|
||||||
allow_patterns: list[str] | str | None = None,
|
allow_patterns: list[str] | str | None = None,
|
||||||
@@ -228,7 +250,7 @@ class LeRobotDatasetMetadata:
|
|||||||
@property
|
@property
|
||||||
def _version(self) -> packaging.version.Version:
|
def _version(self) -> packaging.version.Version:
|
||||||
"""Codebase version used to create this dataset."""
|
"""Codebase version used to create this dataset."""
|
||||||
return packaging.version.parse(self.info["codebase_version"])
|
return packaging.version.parse(self.info.codebase_version)
|
||||||
|
|
||||||
def get_data_file_path(self, ep_index: int) -> Path:
|
def get_data_file_path(self, ep_index: int) -> Path:
|
||||||
"""Return the relative parquet file path for the given episode index.
|
"""Return the relative parquet file path for the given episode index.
|
||||||
@@ -283,27 +305,27 @@ class LeRobotDatasetMetadata:
|
|||||||
@property
|
@property
|
||||||
def data_path(self) -> str:
|
def data_path(self) -> str:
|
||||||
"""Formattable string for the parquet files."""
|
"""Formattable string for the parquet files."""
|
||||||
return self.info["data_path"]
|
return self.info.data_path
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def video_path(self) -> str | None:
|
def video_path(self) -> str | None:
|
||||||
"""Formattable string for the video files."""
|
"""Formattable string for the video files."""
|
||||||
return self.info["video_path"]
|
return self.info.video_path
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def robot_type(self) -> str | None:
|
def robot_type(self) -> str | None:
|
||||||
"""Robot type used in recording this dataset."""
|
"""Robot type used in recording this dataset."""
|
||||||
return self.info["robot_type"]
|
return self.info.robot_type
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def fps(self) -> int:
|
def fps(self) -> int:
|
||||||
"""Frames per second used during data collection."""
|
"""Frames per second used during data collection."""
|
||||||
return self.info["fps"]
|
return self.info.fps
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def features(self) -> dict[str, dict]:
|
def features(self) -> dict[str, dict]:
|
||||||
"""All features contained in the dataset."""
|
"""All features contained in the dataset."""
|
||||||
return self.info["features"]
|
return self.info.features
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def image_keys(self) -> list[str]:
|
def image_keys(self) -> list[str]:
|
||||||
@@ -333,32 +355,32 @@ class LeRobotDatasetMetadata:
|
|||||||
@property
|
@property
|
||||||
def total_episodes(self) -> int:
|
def total_episodes(self) -> int:
|
||||||
"""Total number of episodes available."""
|
"""Total number of episodes available."""
|
||||||
return self.info["total_episodes"]
|
return self.info.total_episodes
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def total_frames(self) -> int:
|
def total_frames(self) -> int:
|
||||||
"""Total number of frames saved in this dataset."""
|
"""Total number of frames saved in this dataset."""
|
||||||
return self.info["total_frames"]
|
return self.info.total_frames
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def total_tasks(self) -> int:
|
def total_tasks(self) -> int:
|
||||||
"""Total number of different tasks performed in this dataset."""
|
"""Total number of different tasks performed in this dataset."""
|
||||||
return self.info["total_tasks"]
|
return self.info.total_tasks
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chunks_size(self) -> int:
|
def chunks_size(self) -> int:
|
||||||
"""Max number of files per chunk."""
|
"""Max number of files per chunk."""
|
||||||
return self.info["chunks_size"]
|
return self.info.chunks_size
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data_files_size_in_mb(self) -> int:
|
def data_files_size_in_mb(self) -> int:
|
||||||
"""Max size of data file in mega bytes."""
|
"""Max size of data file in mega bytes."""
|
||||||
return self.info["data_files_size_in_mb"]
|
return self.info.data_files_size_in_mb
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def video_files_size_in_mb(self) -> int:
|
def video_files_size_in_mb(self) -> int:
|
||||||
"""Max size of video file in mega bytes."""
|
"""Max size of video file in mega bytes."""
|
||||||
return self.info["video_files_size_in_mb"]
|
return self.info.video_files_size_in_mb
|
||||||
|
|
||||||
def get_task_index(self, task: str) -> int | None:
|
def get_task_index(self, task: str) -> int | None:
|
||||||
"""
|
"""
|
||||||
@@ -502,10 +524,10 @@ class LeRobotDatasetMetadata:
|
|||||||
self._save_episode_metadata(episode_dict)
|
self._save_episode_metadata(episode_dict)
|
||||||
|
|
||||||
# Update info
|
# Update info
|
||||||
self.info["total_episodes"] += 1
|
self.info.total_episodes += 1
|
||||||
self.info["total_frames"] += episode_length
|
self.info.total_frames += episode_length
|
||||||
self.info["total_tasks"] = len(self.tasks)
|
self.info.total_tasks = len(self.tasks)
|
||||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
self.info.splits = {"train": f"0:{self.info.total_episodes}"}
|
||||||
|
|
||||||
write_info(self.info, self.root)
|
write_info(self.info, self.root)
|
||||||
|
|
||||||
@@ -524,7 +546,7 @@ class LeRobotDatasetMetadata:
|
|||||||
for key in video_keys:
|
for key in video_keys:
|
||||||
if not self.features[key].get("info", None):
|
if not self.features[key].get("info", None):
|
||||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
self.info.features[key]["info"] = get_video_info(video_path)
|
||||||
|
|
||||||
def update_chunk_settings(
|
def update_chunk_settings(
|
||||||
self,
|
self,
|
||||||
@@ -546,17 +568,17 @@ class LeRobotDatasetMetadata:
|
|||||||
if chunks_size is not None:
|
if chunks_size is not None:
|
||||||
if chunks_size <= 0:
|
if chunks_size <= 0:
|
||||||
raise ValueError(f"chunks_size must be positive, got {chunks_size}")
|
raise ValueError(f"chunks_size must be positive, got {chunks_size}")
|
||||||
self.info["chunks_size"] = chunks_size
|
self.info.chunks_size = chunks_size
|
||||||
|
|
||||||
if data_files_size_in_mb is not None:
|
if data_files_size_in_mb is not None:
|
||||||
if data_files_size_in_mb <= 0:
|
if data_files_size_in_mb <= 0:
|
||||||
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
|
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
|
||||||
self.info["data_files_size_in_mb"] = data_files_size_in_mb
|
self.info.data_files_size_in_mb = data_files_size_in_mb
|
||||||
|
|
||||||
if video_files_size_in_mb is not None:
|
if video_files_size_in_mb is not None:
|
||||||
if video_files_size_in_mb <= 0:
|
if video_files_size_in_mb <= 0:
|
||||||
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
|
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
|
||||||
self.info["video_files_size_in_mb"] = video_files_size_in_mb
|
self.info.video_files_size_in_mb = video_files_size_in_mb
|
||||||
|
|
||||||
# Update the info file on disk
|
# Update the info file on disk
|
||||||
write_info(self.info, self.root)
|
write_info(self.info, self.root)
|
||||||
@@ -653,7 +675,7 @@ class LeRobotDatasetMetadata:
|
|||||||
f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. "
|
f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. "
|
||||||
"Either remove video features from the features dict, or set 'use_videos=True'."
|
"Either remove video features from the features dict, or set 'use_videos=True'."
|
||||||
)
|
)
|
||||||
write_json(obj.info, obj.root / INFO_PATH)
|
write_info(obj.info, obj.root)
|
||||||
obj.revision = None
|
obj.revision = None
|
||||||
obj._pq_writer = None
|
obj._pq_writer = None
|
||||||
obj.latest_episode = None
|
obj.latest_episode = None
|
||||||
|
|||||||
@@ -897,14 +897,10 @@ def _copy_and_reindex_episodes_metadata(
|
|||||||
|
|
||||||
dst_meta.finalize()
|
dst_meta.finalize()
|
||||||
|
|
||||||
dst_meta.info.update(
|
dst_meta.info.total_episodes = len(episode_mapping)
|
||||||
{
|
dst_meta.info.total_frames = total_frames
|
||||||
"total_episodes": len(episode_mapping),
|
dst_meta.info.total_tasks = len(dst_meta.tasks) if dst_meta.tasks is not None else 0
|
||||||
"total_frames": total_frames,
|
dst_meta.info.splits = {"train": f"0:{len(episode_mapping)}"}
|
||||||
"total_tasks": len(dst_meta.tasks) if dst_meta.tasks is not None else 0,
|
|
||||||
"splits": {"train": f"0:{len(episode_mapping)}"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
write_info(dst_meta.info, dst_meta.root)
|
write_info(dst_meta.info, dst_meta.root)
|
||||||
|
|
||||||
if not all_stats:
|
if not all_stats:
|
||||||
@@ -1069,21 +1065,20 @@ def _copy_episodes_metadata_and_stats(
|
|||||||
if episodes_dir.exists():
|
if episodes_dir.exists():
|
||||||
shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True)
|
shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True)
|
||||||
|
|
||||||
dst_meta.info.update(
|
dst_meta.info.total_episodes = src_dataset.meta.total_episodes
|
||||||
{
|
dst_meta.info.total_frames = src_dataset.meta.total_frames
|
||||||
"total_episodes": src_dataset.meta.total_episodes,
|
dst_meta.info.total_tasks = src_dataset.meta.total_tasks
|
||||||
"total_frames": src_dataset.meta.total_frames,
|
# Preserve original splits if available, otherwise create default
|
||||||
"total_tasks": src_dataset.meta.total_tasks,
|
dst_meta.info.splits = (
|
||||||
"splits": src_dataset.meta.info.get("splits", {"train": f"0:{src_dataset.meta.total_episodes}"}),
|
src_dataset.meta.info.splits
|
||||||
}
|
if src_dataset.meta.info.splits
|
||||||
|
else {"train": f"0:{src_dataset.meta.total_episodes}"}
|
||||||
)
|
)
|
||||||
|
|
||||||
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
||||||
for key in dst_meta.video_keys:
|
for key in dst_meta.video_keys:
|
||||||
if key in src_dataset.meta.features:
|
if key in src_dataset.meta.features:
|
||||||
dst_meta.info["features"][key]["info"] = src_dataset.meta.info["features"][key].get(
|
dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {})
|
||||||
"info", {}
|
|
||||||
)
|
|
||||||
|
|
||||||
write_info(dst_meta.info, dst_meta.root)
|
write_info(dst_meta.info, dst_meta.root)
|
||||||
|
|
||||||
@@ -1525,7 +1520,7 @@ def modify_tasks(
|
|||||||
write_tasks(new_task_df, root)
|
write_tasks(new_task_df, root)
|
||||||
|
|
||||||
# Update info.json
|
# Update info.json
|
||||||
dataset.meta.info["total_tasks"] = len(unique_tasks)
|
dataset.meta.info.total_tasks = len(unique_tasks)
|
||||||
write_info(dataset.meta.info, root)
|
write_info(dataset.meta.info, root)
|
||||||
|
|
||||||
# Reload metadata to reflect changes
|
# Reload metadata to reflect changes
|
||||||
@@ -1858,10 +1853,10 @@ def convert_image_to_video_dataset(
|
|||||||
episodes_df.to_parquet(episodes_path, index=False)
|
episodes_df.to_parquet(episodes_path, index=False)
|
||||||
|
|
||||||
# Update metadata info
|
# Update metadata info
|
||||||
new_meta.info["total_episodes"] = len(episode_indices)
|
new_meta.info.total_episodes = len(episode_indices)
|
||||||
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata.values())
|
new_meta.info.total_frames = sum(ep["length"] for ep in all_episode_metadata.values())
|
||||||
new_meta.info["total_tasks"] = dataset.meta.total_tasks
|
new_meta.info.total_tasks = dataset.meta.total_tasks
|
||||||
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
|
new_meta.info.splits = {"train": f"0:{len(episode_indices)}"}
|
||||||
|
|
||||||
# Update video info for all image keys (now videos)
|
# Update video info for all image keys (now videos)
|
||||||
# We need to manually set video info since update_video_info() checks video_keys first
|
# We need to manually set video info since update_video_info() checks video_keys first
|
||||||
@@ -1870,7 +1865,7 @@ def convert_image_to_video_dataset(
|
|||||||
video_path = new_meta.root / new_meta.video_path.format(
|
video_path = new_meta.root / new_meta.video_path.format(
|
||||||
video_key=img_key, chunk_index=0, file_index=0
|
video_key=img_key, chunk_index=0, file_index=0
|
||||||
)
|
)
|
||||||
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
|
new_meta.info.features[img_key]["info"] = get_video_info(video_path)
|
||||||
|
|
||||||
write_info(new_meta.info, new_meta.root)
|
write_info(new_meta.info, new_meta.root)
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from pprint import pformat
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.configs import PreTrainedConfig
|
from lerobot.configs import PreTrainedConfig
|
||||||
|
from lerobot.configs.rewards import RewardModelConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.transforms import ImageTransforms
|
from lerobot.transforms import ImageTransforms
|
||||||
from lerobot.utils.constants import ACTION, IMAGENET_STATS, OBS_PREFIX, REWARD
|
from lerobot.utils.constants import ACTION, IMAGENET_STATS, OBS_PREFIX, REWARD
|
||||||
@@ -30,12 +31,14 @@ from .streaming_dataset import StreamingLeRobotDataset
|
|||||||
|
|
||||||
|
|
||||||
def resolve_delta_timestamps(
|
def resolve_delta_timestamps(
|
||||||
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
|
cfg: PreTrainedConfig | RewardModelConfig, ds_meta: LeRobotDatasetMetadata
|
||||||
) -> dict[str, list] | None:
|
) -> dict[str, list] | None:
|
||||||
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.
|
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
|
cfg (PreTrainedConfig | RewardModelConfig): The config to read delta_indices from. Both
|
||||||
|
``PreTrainedConfig`` and concrete ``RewardModelConfig`` subclasses expose the
|
||||||
|
``{observation,action,reward}_delta_indices`` properties used below.
|
||||||
ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
|
ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
|
||||||
delta_timestamps against.
|
delta_timestamps against.
|
||||||
|
|
||||||
@@ -82,7 +85,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
|||||||
ds_meta = LeRobotDatasetMetadata(
|
ds_meta = LeRobotDatasetMetadata(
|
||||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||||
)
|
)
|
||||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, ds_meta)
|
||||||
if not cfg.dataset.streaming:
|
if not cfg.dataset.streaming:
|
||||||
dataset = LeRobotDataset(
|
dataset = LeRobotDataset(
|
||||||
cfg.dataset.repo_id,
|
cfg.dataset.repo_id,
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from .utils import (
|
|||||||
DEFAULT_DATA_PATH,
|
DEFAULT_DATA_PATH,
|
||||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_VIDEO_PATH,
|
DEFAULT_VIDEO_PATH,
|
||||||
|
DatasetInfo,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -78,8 +79,8 @@ def create_empty_dataset_info(
|
|||||||
chunks_size: int | None = None,
|
chunks_size: int | None = None,
|
||||||
data_files_size_in_mb: int | None = None,
|
data_files_size_in_mb: int | None = None,
|
||||||
video_files_size_in_mb: int | None = None,
|
video_files_size_in_mb: int | None = None,
|
||||||
) -> dict:
|
) -> DatasetInfo:
|
||||||
"""Create a template dictionary for a new dataset's `info.json`.
|
"""Create a template ``DatasetInfo`` object for a new dataset's ``meta/info.json``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
codebase_version (str): The version of the LeRobot codebase.
|
codebase_version (str): The version of the LeRobot codebase.
|
||||||
@@ -87,25 +88,24 @@ def create_empty_dataset_info(
|
|||||||
features (dict): The LeRobot features dictionary for the dataset.
|
features (dict): The LeRobot features dictionary for the dataset.
|
||||||
use_videos (bool): Whether the dataset will store videos.
|
use_videos (bool): Whether the dataset will store videos.
|
||||||
robot_type (str | None): The type of robot used, if any.
|
robot_type (str | None): The type of robot used, if any.
|
||||||
|
chunks_size (int | None): Max files per chunk directory. Defaults to ``DEFAULT_CHUNK_SIZE``.
|
||||||
|
data_files_size_in_mb (int | None): Max parquet file size in MB. Defaults to ``DEFAULT_DATA_FILE_SIZE_IN_MB``.
|
||||||
|
video_files_size_in_mb (int | None): Max video file size in MB. Defaults to ``DEFAULT_VIDEO_FILE_SIZE_IN_MB``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dictionary with the initial dataset metadata.
|
DatasetInfo: A typed dataset information object with initial metadata.
|
||||||
"""
|
"""
|
||||||
return {
|
return DatasetInfo(
|
||||||
"codebase_version": codebase_version,
|
codebase_version=codebase_version,
|
||||||
"robot_type": robot_type,
|
fps=fps,
|
||||||
"total_episodes": 0,
|
features=features,
|
||||||
"total_frames": 0,
|
robot_type=robot_type,
|
||||||
"total_tasks": 0,
|
chunks_size=chunks_size or DEFAULT_CHUNK_SIZE,
|
||||||
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
|
data_files_size_in_mb=data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
video_files_size_in_mb=video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||||
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
data_path=DEFAULT_DATA_PATH,
|
||||||
"fps": fps,
|
video_path=DEFAULT_VIDEO_PATH if use_videos else None,
|
||||||
"splits": {},
|
)
|
||||||
"data_path": DEFAULT_DATA_PATH,
|
|
||||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
|
||||||
"features": features,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def check_delta_timestamps(
|
def check_delta_timestamps(
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from .utils import (
|
|||||||
EPISODES_DIR,
|
EPISODES_DIR,
|
||||||
INFO_PATH,
|
INFO_PATH,
|
||||||
STATS_PATH,
|
STATS_PATH,
|
||||||
|
DatasetInfo,
|
||||||
serialize_dict,
|
serialize_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -115,25 +116,21 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
|||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
def write_info(info: dict, local_dir: Path) -> None:
|
def write_info(info: DatasetInfo, local_dir: Path) -> None:
|
||||||
write_json(info, local_dir / INFO_PATH)
|
write_json(info.to_dict(), local_dir / INFO_PATH)
|
||||||
|
|
||||||
|
|
||||||
def load_info(local_dir: Path) -> dict:
|
def load_info(local_dir: Path) -> DatasetInfo:
|
||||||
"""Load dataset info metadata from its standard file path.
|
"""Load dataset info metadata from its standard file path.
|
||||||
|
|
||||||
Also converts shape lists to tuples for consistency.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
local_dir (Path): The root directory of the dataset.
|
local_dir (Path): The root directory of the dataset.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: The dataset information dictionary.
|
DatasetInfo: The typed dataset information object.
|
||||||
"""
|
"""
|
||||||
info = load_json(local_dir / INFO_PATH)
|
raw = load_json(local_dir / INFO_PATH)
|
||||||
for ft in info["features"].values():
|
return DatasetInfo.from_dict(raw)
|
||||||
ft["shape"] = tuple(ft["shape"])
|
|
||||||
return info
|
|
||||||
|
|
||||||
|
|
||||||
def write_stats(stats: dict, local_dir: Path) -> None:
|
def write_stats(stats: dict, local_dir: Path) -> None:
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
repo_id: str,
|
repo_id: str,
|
||||||
root: str | Path | None = None,
|
root: str | Path | None = None,
|
||||||
episodes: list[int] | None = None,
|
episodes: list[int] | None = None,
|
||||||
|
episode_filter: Callable[[dict], bool] | None = None,
|
||||||
image_transforms: Callable | None = None,
|
image_transforms: Callable | None = None,
|
||||||
delta_timestamps: dict[str, list[float]] | None = None,
|
delta_timestamps: dict[str, list[float]] | None = None,
|
||||||
tolerance_s: float = 1e-4,
|
tolerance_s: float = 1e-4,
|
||||||
@@ -153,6 +154,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
``$HF_LEROBOT_HOME/hub``.
|
``$HF_LEROBOT_HOME/hub``.
|
||||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||||
their episode_index in this list. Defaults to None.
|
their episode_index in this list. Defaults to None.
|
||||||
|
episode_filter (Callable[[dict], bool] | None, optional): Predicate over per-episode
|
||||||
|
metadata rows used to select episodes. Evaluated against ``meta/`` without ``stats`` keys
|
||||||
|
(e.g.``task_index``, ``episode_index``, ``length``, ``from_timestamp``, ``to_timestamp``).
|
||||||
|
Intersected with ``episodes`` when both are set. Example: ``lambda ep: ep["length"] >= 100``.
|
||||||
|
Defaults to None.
|
||||||
image_transforms (Callable | None, optional):
|
image_transforms (Callable | None, optional):
|
||||||
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
|
Transform applied to visual modalities inside `__getitem__` after image decoding / tensor
|
||||||
conversion. This works for both image-backed and video-backed observations and can later be
|
conversion. This works for both image-backed and video-backed observations and can later be
|
||||||
@@ -199,7 +205,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.reader = None
|
self.reader = None
|
||||||
self.set_image_transforms(image_transforms)
|
self.set_image_transforms(image_transforms)
|
||||||
self.delta_timestamps = delta_timestamps
|
self.delta_timestamps = delta_timestamps
|
||||||
self.episodes = episodes
|
|
||||||
self.tolerance_s = tolerance_s
|
self.tolerance_s = tolerance_s
|
||||||
self.revision = revision if revision else CODEBASE_VERSION
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
self._video_backend = video_backend if video_backend else get_safe_default_codec()
|
self._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||||
@@ -218,6 +223,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.root = self.meta.root
|
self.root = self.meta.root
|
||||||
self.revision = self.meta.revision
|
self.revision = self.meta.revision
|
||||||
|
|
||||||
|
if episodes is not None and any(
|
||||||
|
episode >= self.meta.total_episodes or episode < 0 for episode in episodes
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"Some episodes in the provided episodes list are out of range for this dataset ({self.meta.total_episodes})."
|
||||||
|
)
|
||||||
|
|
||||||
|
if episode_filter is not None:
|
||||||
|
resolved = self.meta.filter_episodes(episode_filter, candidates=episodes)
|
||||||
|
if not resolved:
|
||||||
|
raise ValueError(
|
||||||
|
"The episode filter did not match any episode. Make sure the filter and episodes list are valid and compatible."
|
||||||
|
)
|
||||||
|
logger.info(f"The episode filter matched {len(resolved)} episode(s).")
|
||||||
|
episodes = resolved
|
||||||
|
self.episodes = episodes
|
||||||
|
|
||||||
# Create reader (hf_dataset loaded below)
|
# Create reader (hf_dataset loaded below)
|
||||||
self.reader = DatasetReader(
|
self.reader = DatasetReader(
|
||||||
meta=self.meta,
|
meta=self.meta,
|
||||||
@@ -630,6 +652,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
streaming_encoding: bool = False,
|
streaming_encoding: bool = False,
|
||||||
encoder_queue_maxsize: int = 30,
|
encoder_queue_maxsize: int = 30,
|
||||||
encoder_threads: int | None = None,
|
encoder_threads: int | None = None,
|
||||||
|
video_files_size_in_mb: int | None = None,
|
||||||
|
data_files_size_in_mb: int | None = None,
|
||||||
) -> "LeRobotDataset":
|
) -> "LeRobotDataset":
|
||||||
"""Create a new LeRobotDataset from scratch for recording data.
|
"""Create a new LeRobotDataset from scratch for recording data.
|
||||||
|
|
||||||
@@ -677,6 +701,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
root=root,
|
root=root,
|
||||||
use_videos=use_videos,
|
use_videos=use_videos,
|
||||||
metadata_buffer_size=metadata_buffer_size,
|
metadata_buffer_size=metadata_buffer_size,
|
||||||
|
video_files_size_in_mb=video_files_size_in_mb,
|
||||||
|
data_files_size_in_mb=data_files_size_in_mb,
|
||||||
)
|
)
|
||||||
obj.repo_id = obj.meta.repo_id
|
obj.repo_id = obj.meta.repo_id
|
||||||
obj._requested_root = obj.meta.root
|
obj._requested_root = obj.meta.root
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||||
"""
|
"""
|
||||||
return self._datasets[0].meta.info["fps"]
|
return self._datasets[0].meta.info.fps
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def video(self) -> bool:
|
def video(self) -> bool:
|
||||||
@@ -133,7 +133,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||||
"""
|
"""
|
||||||
return self._datasets[0].meta.info.get("video", False)
|
return len(self._datasets[0].meta.video_keys) > 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def features(self) -> datasets.Features:
|
def features(self) -> datasets.Features:
|
||||||
|
|||||||
@@ -434,7 +434,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
|||||||
|
|
||||||
def _make_padding_camera_frame(self, camera_key: str):
|
def _make_padding_camera_frame(self, camera_key: str):
|
||||||
"""Variable-shape padding frame for given camera keys, given in (H, W, C)"""
|
"""Variable-shape padding frame for given camera keys, given in (H, W, C)"""
|
||||||
return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1)
|
return torch.zeros(self.meta.info.features[camera_key]["shape"]).permute(-1, 0, 1)
|
||||||
|
|
||||||
def _get_video_frame_padding_mask(
|
def _get_video_frame_padding_mask(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -14,9 +14,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import dataclasses
|
||||||
import importlib.resources
|
import importlib.resources
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
@@ -70,6 +72,9 @@ class ForwardCompatibilityError(CompatibilityError):
|
|||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
|
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
|
||||||
@@ -94,6 +99,123 @@ LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
|||||||
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
|
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DatasetInfo:
|
||||||
|
"""Typed representation of the ``meta/info.json`` file for a LeRobot dataset.
|
||||||
|
|
||||||
|
Replaces the previously untyped ``dict`` returned by ``load_info()`` and
|
||||||
|
created by ``create_empty_dataset_info()``. Using a dataclass provides
|
||||||
|
explicit field definitions, IDE auto-completion, and validation at
|
||||||
|
construction time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
codebase_version: str
|
||||||
|
fps: int
|
||||||
|
features: dict[str, dict]
|
||||||
|
|
||||||
|
# Episode / frame counters — start at zero for new datasets
|
||||||
|
total_episodes: int = 0
|
||||||
|
total_frames: int = 0
|
||||||
|
total_tasks: int = 0
|
||||||
|
|
||||||
|
# Storage settings
|
||||||
|
chunks_size: int = field(default=DEFAULT_CHUNK_SIZE)
|
||||||
|
data_files_size_in_mb: int = field(default=DEFAULT_DATA_FILE_SIZE_IN_MB)
|
||||||
|
video_files_size_in_mb: int = field(default=DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||||
|
|
||||||
|
# File path templates
|
||||||
|
data_path: str = field(default=DEFAULT_DATA_PATH)
|
||||||
|
video_path: str | None = field(default=DEFAULT_VIDEO_PATH)
|
||||||
|
|
||||||
|
# Optional metadata
|
||||||
|
robot_type: str | None = None
|
||||||
|
splits: dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
# Coerce feature shapes from list to tuple — JSON deserialisation
|
||||||
|
# returns lists, but the rest of the codebase expects tuples.
|
||||||
|
for ft in self.features.values():
|
||||||
|
if isinstance(ft.get("shape"), list):
|
||||||
|
ft["shape"] = tuple(ft["shape"])
|
||||||
|
|
||||||
|
if self.fps <= 0:
|
||||||
|
raise ValueError(f"fps must be positive, got {self.fps}")
|
||||||
|
if self.chunks_size <= 0:
|
||||||
|
raise ValueError(f"chunks_size must be positive, got {self.chunks_size}")
|
||||||
|
if self.data_files_size_in_mb <= 0:
|
||||||
|
raise ValueError(f"data_files_size_in_mb must be positive, got {self.data_files_size_in_mb}")
|
||||||
|
if self.video_files_size_in_mb <= 0:
|
||||||
|
raise ValueError(f"video_files_size_in_mb must be positive, got {self.video_files_size_in_mb}")
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
"""Return a JSON-serialisable dict.
|
||||||
|
|
||||||
|
Converts tuple shapes back to lists so ``json.dump`` can handle them.
|
||||||
|
"""
|
||||||
|
d = dataclasses.asdict(self)
|
||||||
|
for ft in d["features"].values():
|
||||||
|
if isinstance(ft.get("shape"), tuple):
|
||||||
|
ft["shape"] = list(ft["shape"])
|
||||||
|
return d
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict) -> "DatasetInfo":
|
||||||
|
"""Construct from a raw dict (e.g. loaded directly from JSON).
|
||||||
|
|
||||||
|
Unknown keys are ignored for forward compatibility with datasets that
|
||||||
|
carry additional fields (e.g. ``total_videos`` from v2.x). A warning is
|
||||||
|
logged when such fields are present.
|
||||||
|
"""
|
||||||
|
known = {f.name for f in dataclasses.fields(cls)}
|
||||||
|
unknown = sorted(k for k in data if k not in known)
|
||||||
|
if unknown:
|
||||||
|
logger.warning(f"Unknown fields in DatasetInfo: {unknown}. These will be ignored.")
|
||||||
|
return cls(**{k: v for k, v in data.items() if k in known})
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Temporary dict-style compatibility layer
|
||||||
|
# Allows existing ``info["key"]`` call-sites to keep working without changes.
|
||||||
|
# Once all callers have been migrated to attribute access, remove these.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
def __getitem__(self, key: str):
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
f"Accessing DatasetInfo with dict-style syntax info['{key}'] is deprecated. "
|
||||||
|
f"Use attribute access info.{key} instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
return getattr(self, key)
|
||||||
|
except AttributeError as err:
|
||||||
|
raise KeyError(key) from err
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, value) -> None:
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
f"Setting DatasetInfo with dict-style syntax info['{key}'] = ... is deprecated. "
|
||||||
|
f"Use attribute assignment info.{key} = ... instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
if not hasattr(self, key):
|
||||||
|
raise KeyError(f"DatasetInfo has no field '{key}'")
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
def __contains__(self, key: str) -> bool:
|
||||||
|
"""Check if a field exists (dict-like interface)."""
|
||||||
|
return hasattr(self, key)
|
||||||
|
|
||||||
|
def get(self, key: str, default=None):
|
||||||
|
"""Get attribute value with default fallback (dict-like interface)."""
|
||||||
|
try:
|
||||||
|
return getattr(self, key)
|
||||||
|
except AttributeError:
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
def has_legacy_hub_download_metadata(root: Path) -> bool:
|
def has_legacy_hub_download_metadata(root: Path) -> bool:
|
||||||
"""Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.
|
"""Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.
|
||||||
|
|
||||||
@@ -294,7 +416,7 @@ def create_branch(repo_id: str, *, branch: str, repo_type: str | None = None) ->
|
|||||||
|
|
||||||
def create_lerobot_dataset_card(
|
def create_lerobot_dataset_card(
|
||||||
tags: list | None = None,
|
tags: list | None = None,
|
||||||
dataset_info: dict | None = None,
|
dataset_info: DatasetInfo | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> DatasetCard:
|
) -> DatasetCard:
|
||||||
"""Create a `DatasetCard` for a LeRobot dataset.
|
"""Create a `DatasetCard` for a LeRobot dataset.
|
||||||
@@ -305,7 +427,7 @@ def create_lerobot_dataset_card(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
tags (list | None): A list of tags to add to the dataset card.
|
tags (list | None): A list of tags to add to the dataset card.
|
||||||
dataset_info (dict | None): The dataset's info dictionary, which will
|
dataset_info (DatasetInfo | None): The dataset's info object, which will
|
||||||
be displayed on the card.
|
be displayed on the card.
|
||||||
**kwargs: Additional keyword arguments to populate the card template.
|
**kwargs: Additional keyword arguments to populate the card template.
|
||||||
|
|
||||||
@@ -318,7 +440,7 @@ def create_lerobot_dataset_card(
|
|||||||
card_tags += tags
|
card_tags += tags
|
||||||
if dataset_info:
|
if dataset_info:
|
||||||
dataset_structure = "[meta/info.json](meta/info.json):\n"
|
dataset_structure = "[meta/info.json](meta/info.json):\n"
|
||||||
dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n"
|
dataset_structure += f"```json\n{json.dumps(dataset_info.to_dict(), indent=4)}\n```\n"
|
||||||
kwargs = {**kwargs, "dataset_structure": dataset_structure}
|
kwargs = {**kwargs, "dataset_structure": dataset_structure}
|
||||||
card_data = DatasetCardData(
|
card_data = DatasetCardData(
|
||||||
license=kwargs.get("license"),
|
license=kwargs.get("license"),
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ import fsspec
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
|
||||||
from datasets.features.features import register_feature
|
from datasets.features.features import register_feature
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
@@ -132,7 +131,9 @@ def decode_video_frames(
|
|||||||
video_path (Path): Path to the video file.
|
video_path (Path): Path to the video file.
|
||||||
timestamps (list[float]): List of timestamps to extract frames.
|
timestamps (list[float]): List of timestamps to extract frames.
|
||||||
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
||||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav".
|
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available
|
||||||
|
in the platform; otherwise, defaults to "pyav". The legacy value "video_reader" is
|
||||||
|
accepted for one release as an alias for "pyav" and will be removed in a future version.
|
||||||
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
|
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
|
||||||
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
|
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
|
||||||
|
|
||||||
@@ -145,85 +146,87 @@ def decode_video_frames(
|
|||||||
backend = get_safe_default_codec()
|
backend = get_safe_default_codec()
|
||||||
if backend == "torchcodec":
|
if backend == "torchcodec":
|
||||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||||
elif backend in ["pyav", "video_reader"]:
|
elif backend == "pyav":
|
||||||
return decode_video_frames_torchvision(
|
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||||
video_path, timestamps, tolerance_s, backend, return_uint8=return_uint8
|
elif backend == "video_reader":
|
||||||
)
|
logger.warning("backend='video_reader' is deprecated and now aliases to 'pyav'.")
|
||||||
|
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported video backend: {backend}")
|
raise ValueError(f"Unsupported video backend: {backend}")
|
||||||
|
|
||||||
|
|
||||||
def decode_video_frames_torchvision(
|
def decode_video_frames_pyav(
|
||||||
video_path: Path | str,
|
video_path: Path | str,
|
||||||
timestamps: list[float],
|
timestamps: list[float],
|
||||||
tolerance_s: float,
|
tolerance_s: float,
|
||||||
backend: str = "pyav",
|
|
||||||
log_loaded_timestamps: bool = False,
|
log_loaded_timestamps: bool = False,
|
||||||
return_uint8: bool = False,
|
return_uint8: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Loads frames associated to the requested timestamps of a video
|
"""Loads frames associated to the requested timestamps of a video using PyAV.
|
||||||
|
|
||||||
The backend can be either "pyav" (default) or "video_reader".
|
This is the fallback decoder for platforms where torchcodec has no wheel (currently macOS
|
||||||
"video_reader" requires installing torchvision from source, see:
|
x86_64 and linux armv7l — see the torchcodec block in pyproject.toml for the full matrix).
|
||||||
https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
|
On supported platforms, prefer `decode_video_frames_torchcodec`, which is faster and supports
|
||||||
(note that you need to compile against ffmpeg<4.3)
|
accurate seek.
|
||||||
|
|
||||||
While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
|
PyAV doesn't support accurate seek: we seek to the nearest preceding keyframe and decode
|
||||||
For more info on video decoding, see `benchmark/video/README.md`
|
forward until we have covered the requested timestamp range. The number of key frames in a
|
||||||
|
video can be adjusted at encoding time to trade off decoding speed against file size.
|
||||||
|
|
||||||
See torchvision doc for more info on these two backends:
|
Args:
|
||||||
https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend
|
video_path: Path to the video file.
|
||||||
|
timestamps: List of timestamps (in seconds) to extract frames for.
|
||||||
|
tolerance_s: Allowed deviation in seconds between a queried timestamp and the closest
|
||||||
|
decoded frame.
|
||||||
|
log_loaded_timestamps: When True, log every decoded frame's timestamp at INFO level.
|
||||||
|
return_uint8: When True, return raw uint8 frames (C, H, W). Otherwise, return float32 in
|
||||||
|
[0, 1] range.
|
||||||
|
|
||||||
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
Returns:
|
||||||
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
|
torch.Tensor of shape (len(timestamps), C, H, W).
|
||||||
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
|
|
||||||
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
|
||||||
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
|
||||||
"""
|
"""
|
||||||
video_path = str(video_path)
|
|
||||||
|
|
||||||
# set backend
|
|
||||||
keyframes_only = False
|
|
||||||
torchvision.set_video_backend(backend)
|
|
||||||
if backend == "pyav":
|
|
||||||
keyframes_only = True # pyav doesn't support accurate seek
|
|
||||||
|
|
||||||
# set a video stream reader
|
|
||||||
# TODO(rcadene): also load audio stream at the same time
|
# TODO(rcadene): also load audio stream at the same time
|
||||||
reader = torchvision.io.VideoReader(video_path, "video")
|
video_path = str(video_path)
|
||||||
|
|
||||||
# set the first and last requested timestamps
|
# set the first and last requested timestamps
|
||||||
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
|
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
|
||||||
first_ts = min(timestamps)
|
first_ts = min(timestamps)
|
||||||
last_ts = max(timestamps)
|
last_ts = max(timestamps)
|
||||||
|
|
||||||
# access closest key frame of the first requested frame
|
loaded_frames: list[torch.Tensor] = []
|
||||||
# Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
|
loaded_ts: list[float] = []
|
||||||
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
|
|
||||||
reader.seek(first_ts, keyframes_only=keyframes_only)
|
|
||||||
|
|
||||||
# load all frames until last requested frame
|
# Seek + decode. `container.seek(offset)` with no `stream` argument expects the offset in
|
||||||
loaded_frames = []
|
# av.time_base units (microseconds). `backward=True` lands us on the nearest keyframe at or
|
||||||
loaded_ts = []
|
# before `first_ts`, so we can then decode forward until we cover `last_ts`. See:
|
||||||
for frame in reader:
|
# https://pyav.basswood-io.com/docs/stable/api/container.html#av.container.InputContainer.seek
|
||||||
current_ts = frame["pts"]
|
with av.open(video_path) as container:
|
||||||
if log_loaded_timestamps:
|
stream = container.streams.video[0]
|
||||||
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
|
container.seek(int(first_ts * av.time_base), backward=True)
|
||||||
loaded_frames.append(frame["data"])
|
|
||||||
loaded_ts.append(current_ts)
|
|
||||||
if current_ts >= last_ts:
|
|
||||||
break
|
|
||||||
|
|
||||||
if backend == "pyav":
|
for frame in container.decode(stream):
|
||||||
reader.container.close()
|
if frame.pts is None:
|
||||||
|
continue
|
||||||
|
current_ts = float(frame.pts * stream.time_base)
|
||||||
|
if log_loaded_timestamps:
|
||||||
|
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||||
|
# Convert to CHW uint8 to match torchcodec's output layout.
|
||||||
|
arr = frame.to_ndarray(format="rgb24") # H, W, 3
|
||||||
|
loaded_frames.append(torch.from_numpy(arr).permute(2, 0, 1).contiguous())
|
||||||
|
loaded_ts.append(current_ts)
|
||||||
|
if current_ts >= last_ts:
|
||||||
|
break
|
||||||
|
|
||||||
reader = None
|
if not loaded_frames:
|
||||||
|
raise FrameTimestampError(
|
||||||
|
f"No frames could be decoded from {video_path} in the timestamp range [{first_ts}, {last_ts}]."
|
||||||
|
)
|
||||||
|
|
||||||
query_ts = torch.tensor(timestamps)
|
query_ts = torch.tensor(timestamps)
|
||||||
loaded_ts = torch.tensor(loaded_ts)
|
loaded_ts_t = torch.tensor(loaded_ts)
|
||||||
|
|
||||||
# compute distances between each query timestamp and timestamps of all loaded frames
|
# compute distances between each query timestamp and timestamps of all loaded frames
|
||||||
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
|
dist = torch.cdist(query_ts[:, None], loaded_ts_t[:, None], p=1)
|
||||||
min_, argmin_ = dist.min(1)
|
min_, argmin_ = dist.min(1)
|
||||||
|
|
||||||
is_within_tol = min_ < tolerance_s
|
is_within_tol = min_ < tolerance_s
|
||||||
@@ -234,14 +237,14 @@ def decode_video_frames_torchvision(
|
|||||||
" This might be due to synchronization issues with timestamps during data collection."
|
" This might be due to synchronization issues with timestamps during data collection."
|
||||||
" To be safe, we advise to ignore this item during training."
|
" To be safe, we advise to ignore this item during training."
|
||||||
f"\nqueried timestamps: {query_ts}"
|
f"\nqueried timestamps: {query_ts}"
|
||||||
f"\nloaded timestamps: {loaded_ts}"
|
f"\nloaded timestamps: {loaded_ts_t}"
|
||||||
f"\nvideo: {video_path}"
|
f"\nvideo: {video_path}"
|
||||||
f"\nbackend: {backend}"
|
f"\nbackend: pyav"
|
||||||
)
|
)
|
||||||
|
|
||||||
# get closest frames to the query timestamps
|
# get closest frames to the query timestamps
|
||||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||||
closest_ts = loaded_ts[argmin_]
|
closest_ts = loaded_ts_t[argmin_]
|
||||||
|
|
||||||
if log_loaded_timestamps:
|
if log_loaded_timestamps:
|
||||||
logger.info(f"{closest_ts=}")
|
logger.info(f"{closest_ts=}")
|
||||||
@@ -282,7 +285,11 @@ class VideoDecoderCache:
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
if video_path not in self._cache:
|
if video_path not in self._cache:
|
||||||
file_handle = fsspec.open(video_path).__enter__()
|
file_handle = fsspec.open(video_path).__enter__()
|
||||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
try:
|
||||||
|
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||||
|
except Exception:
|
||||||
|
file_handle.close()
|
||||||
|
raise
|
||||||
self._cache[video_path] = (decoder, file_handle)
|
self._cache[video_path] = (decoder, file_handle)
|
||||||
|
|
||||||
return self._cache[video_path][0]
|
return self._cache[video_path][0]
|
||||||
|
|||||||
@@ -12,19 +12,19 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterpolator
|
||||||
|
|
||||||
from .act.configuration_act import ACTConfig as ACTConfig
|
from .act.configuration_act import ACTConfig as ACTConfig
|
||||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||||
|
from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||||
|
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
|
||||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||||
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
||||||
from .rtc import ActionInterpolator as ActionInterpolator
|
|
||||||
from .sac.configuration_sac import SACConfig as SACConfig
|
|
||||||
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
|
||||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
|
||||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||||
from .utils import make_robot_action, prepare_observation_for_inference
|
from .utils import make_robot_action, prepare_observation_for_inference
|
||||||
@@ -32,22 +32,21 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
|||||||
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig
|
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig
|
||||||
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
|
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
|
||||||
|
|
||||||
# NOTE: Policy modeling classes (e.g., SACPolicy) are intentionally NOT re-exported here.
|
# NOTE: Policy modeling classes (e.g., GaussianActorPolicy) are intentionally NOT re-exported here.
|
||||||
# They have heavy optional dependencies and are loaded lazily via get_policy_class().
|
# They have heavy optional dependencies and are loaded lazily via get_policy_class().
|
||||||
# Import directly: ``from lerobot.policies.sac.modeling_sac import SACPolicy``
|
# Import directly: ``from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy``
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Configuration classes
|
# Configuration classes
|
||||||
"ACTConfig",
|
"ACTConfig",
|
||||||
"DiffusionConfig",
|
"DiffusionConfig",
|
||||||
|
"EO1Config",
|
||||||
|
"GaussianActorConfig",
|
||||||
"GrootConfig",
|
"GrootConfig",
|
||||||
"MultiTaskDiTConfig",
|
"MultiTaskDiTConfig",
|
||||||
"PI0Config",
|
"PI0Config",
|
||||||
"PI0FastConfig",
|
"PI0FastConfig",
|
||||||
"PI05Config",
|
"PI05Config",
|
||||||
"RewardClassifierConfig",
|
|
||||||
"SACConfig",
|
|
||||||
"SARMConfig",
|
|
||||||
"SmolVLAConfig",
|
"SmolVLAConfig",
|
||||||
"TDMPCConfig",
|
"TDMPCConfig",
|
||||||
"VQBeTConfig",
|
"VQBeTConfig",
|
||||||
|
|||||||
@@ -142,9 +142,10 @@ class ACTPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||||
|
|
||||||
l1_loss = (
|
abs_err = F.l1_loss(batch[ACTION], actions_hat, reduction="none")
|
||||||
F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
valid_mask = ~batch["action_is_pad"].unsqueeze(-1)
|
||||||
).mean()
|
num_valid = valid_mask.sum() * abs_err.shape[-1]
|
||||||
|
l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1)
|
||||||
|
|
||||||
loss_dict = {"l1_loss": l1_loss.item()}
|
loss_dict = {"l1_loss": l1_loss.item()}
|
||||||
if self.config.use_vae:
|
if self.config.use_vae:
|
||||||
|
|||||||
@@ -100,8 +100,8 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
# Inputs / output structure.
|
# Inputs / output structure.
|
||||||
n_obs_steps: int = 2
|
n_obs_steps: int = 2
|
||||||
horizon: int = 16
|
horizon: int = 64
|
||||||
n_action_steps: int = 8
|
n_action_steps: int = 32
|
||||||
|
|
||||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
@@ -122,10 +122,10 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
crop_ratio: float = 1.0
|
crop_ratio: float = 1.0
|
||||||
crop_shape: tuple[int, int] | None = None
|
crop_shape: tuple[int, int] | None = None
|
||||||
crop_is_random: bool = True
|
crop_is_random: bool = True
|
||||||
pretrained_backbone_weights: str | None = None
|
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||||
use_group_norm: bool = True
|
use_group_norm: bool = False
|
||||||
spatial_softmax_num_keypoints: int = 32
|
spatial_softmax_num_keypoints: int = 32
|
||||||
use_separate_rgb_encoder_per_camera: bool = False
|
use_separate_rgb_encoder_per_camera: bool = True
|
||||||
# Unet.
|
# Unet.
|
||||||
down_dims: tuple[int, ...] = (512, 1024, 2048)
|
down_dims: tuple[int, ...] = (512, 1024, 2048)
|
||||||
kernel_size: int = 5
|
kernel_size: int = 5
|
||||||
|
|||||||
@@ -380,7 +380,9 @@ class DiffusionModel(nn.Module):
|
|||||||
f"{self.config.do_mask_loss_for_padding=}."
|
f"{self.config.do_mask_loss_for_padding=}."
|
||||||
)
|
)
|
||||||
in_episode_bound = ~batch["action_is_pad"]
|
in_episode_bound = ~batch["action_is_pad"]
|
||||||
loss = loss * in_episode_bound.unsqueeze(-1)
|
mask = in_episode_bound.unsqueeze(-1)
|
||||||
|
num_valid = mask.sum() * loss.shape[-1]
|
||||||
|
return (loss * mask).sum() / num_valid.clamp_min(1)
|
||||||
|
|
||||||
return loss.mean()
|
return loss.mean()
|
||||||
|
|
||||||
|
|||||||
1
src/lerobot/policies/eo1/README.md
Symbolic link
1
src/lerobot/policies/eo1/README.md
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../../../../docs/source/eo1.mdx
|
||||||
7
src/lerobot/policies/eo1/__init__.py
Normal file
7
src/lerobot/policies/eo1/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
from .configuration_eo1 import EO1Config
|
||||||
|
from .modeling_eo1 import EO1Policy
|
||||||
|
from .processor_eo1 import make_eo1_pre_post_processors
|
||||||
|
|
||||||
|
__all__ = ["EO1Config", "EO1Policy", "make_eo1_pre_post_processors"]
|
||||||
193
src/lerobot/policies/eo1/configuration_eo1.py
Normal file
193
src/lerobot/policies/eo1/configuration_eo1.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
|
from lerobot.optim.optimizers import AdamWConfig
|
||||||
|
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||||
|
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _transformers_available:
|
||||||
|
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||||
|
Qwen2_5_VLConfig,
|
||||||
|
Qwen2_5_VLTextConfig,
|
||||||
|
Qwen2_5_VLVisionConfig,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
Qwen2_5_VLConfig = None
|
||||||
|
Qwen2_5_VLTextConfig = None
|
||||||
|
Qwen2_5_VLVisionConfig = None
|
||||||
|
|
||||||
|
|
||||||
|
@PreTrainedConfig.register_subclass("eo1")
|
||||||
|
@dataclass
|
||||||
|
class EO1Config(PreTrainedConfig):
|
||||||
|
"""Configuration for native EO1 policy integration in LeRobot."""
|
||||||
|
|
||||||
|
vlm_base: str = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||||
|
vlm_config: dict | None = None
|
||||||
|
|
||||||
|
# Vision processor settings.
|
||||||
|
image_min_pixels: int | None = 64 * 28 * 28
|
||||||
|
image_max_pixels: int | None = 128 * 28 * 28
|
||||||
|
use_fast_processor: bool = False
|
||||||
|
|
||||||
|
# Execution and action horizon.
|
||||||
|
n_obs_steps: int = 1
|
||||||
|
chunk_size: int = 8
|
||||||
|
n_action_steps: int = 8
|
||||||
|
|
||||||
|
# State/action padding to match EO1 flow head dimensionality.
|
||||||
|
max_state_dim: int = 32
|
||||||
|
max_action_dim: int = 32
|
||||||
|
|
||||||
|
# Flow matching sampling.
|
||||||
|
num_denoise_steps: int = 10
|
||||||
|
num_action_layers: int = 2
|
||||||
|
action_act: str = "linear"
|
||||||
|
time_sampling_beta_alpha: float = 1.5
|
||||||
|
time_sampling_beta_beta: float = 1.0
|
||||||
|
time_sampling_scale: float = 0.999
|
||||||
|
time_sampling_offset: float = 0.001
|
||||||
|
min_period: float = 4e-3
|
||||||
|
max_period: float = 4.0
|
||||||
|
supervise_padding_action_dims: bool = True
|
||||||
|
supervise_padding_actions: bool = True
|
||||||
|
|
||||||
|
# Policy-level dtype request for the Qwen backbone.
|
||||||
|
# - "auto": follow the backbone config/checkpoint default dtype. For Qwen2.5-VL this resolves to bf16.
|
||||||
|
# The EO1 flow-matching head still keeps its own parameters in fp32.
|
||||||
|
# - "bfloat16": force the backbone to initialize/load in bf16 regardless of the saved config default.
|
||||||
|
# - "float32": force the backbone to initialize/load in fp32 for maximum numerical conservatism.
|
||||||
|
dtype: str = "auto" # Options: "auto", "bfloat16", "float32"
|
||||||
|
force_fp32_autocast: bool = True
|
||||||
|
|
||||||
|
# Optional attention backend request passed through to the Qwen backbone.
|
||||||
|
# Common values: None, "eager", "sdpa", "flash_attention_2".
|
||||||
|
attn_implementation: str | None = None
|
||||||
|
|
||||||
|
# Training settings.
|
||||||
|
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||||
|
|
||||||
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"VISUAL": NormalizationMode.IDENTITY,
|
||||||
|
"STATE": NormalizationMode.MEAN_STD,
|
||||||
|
"ACTION": NormalizationMode.MEAN_STD,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optimizer settings aligned with EO1/experiments/2_libero/train.sh and EO1 TrainPipelineConfig defaults.
|
||||||
|
optimizer_lr: float = 1e-4
|
||||||
|
optimizer_betas: tuple[float, float] = (0.9, 0.999)
|
||||||
|
optimizer_eps: float = 1e-8
|
||||||
|
optimizer_weight_decay: float = 0.1
|
||||||
|
optimizer_grad_clip_norm: float = 1.0
|
||||||
|
|
||||||
|
# Scheduler settings aligned with EO1 train.sh: cosine schedule with warmup_ratio=0.03.
|
||||||
|
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||||
|
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
|
||||||
|
scheduler_warmup_steps: int = 900 # 0.03 * 30_000 long-run steps
|
||||||
|
scheduler_decay_steps: int = 30_000
|
||||||
|
scheduler_decay_lr: float = 0.0
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
|
||||||
|
if self.n_action_steps > self.chunk_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Populate the serialized backbone config only when the caller did not provide one.
|
||||||
|
if self.vlm_config is None:
|
||||||
|
require_package("transformers", extra="eo1")
|
||||||
|
self.vlm_config = Qwen2_5_VLConfig.from_pretrained(self.vlm_base).to_dict()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vlm_backbone_config(self) -> Qwen2_5_VLConfig:
|
||||||
|
require_package("transformers", extra="eo1")
|
||||||
|
config_dict = deepcopy(self.vlm_config)
|
||||||
|
if self.attn_implementation is not None:
|
||||||
|
config_dict["attn_implementation"] = self.attn_implementation
|
||||||
|
return Qwen2_5_VLConfig(**config_dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text_config(self) -> Qwen2_5_VLTextConfig:
|
||||||
|
return self.vlm_backbone_config.text_config
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vision_config(self) -> Qwen2_5_VLVisionConfig:
|
||||||
|
return self.vlm_backbone_config.vision_config
|
||||||
|
|
||||||
|
def validate_features(self) -> None:
|
||||||
|
"""Validate and set up EO1 input and output features."""
|
||||||
|
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||||
|
if not image_features:
|
||||||
|
raise ValueError(
|
||||||
|
"EO1 policy requires at least one visual input feature. "
|
||||||
|
"No features of type FeatureType.VISUAL found in input_features."
|
||||||
|
)
|
||||||
|
|
||||||
|
if OBS_STATE not in self.input_features:
|
||||||
|
state_feature = PolicyFeature(
|
||||||
|
type=FeatureType.STATE,
|
||||||
|
shape=(self.max_state_dim,),
|
||||||
|
)
|
||||||
|
self.input_features[OBS_STATE] = state_feature
|
||||||
|
|
||||||
|
if ACTION not in self.output_features:
|
||||||
|
action_feature = PolicyFeature(
|
||||||
|
type=FeatureType.ACTION,
|
||||||
|
shape=(self.max_action_dim,),
|
||||||
|
)
|
||||||
|
self.output_features[ACTION] = action_feature
|
||||||
|
|
||||||
|
def get_optimizer_preset(self) -> AdamWConfig:
|
||||||
|
return AdamWConfig(
|
||||||
|
lr=self.optimizer_lr,
|
||||||
|
betas=self.optimizer_betas,
|
||||||
|
eps=self.optimizer_eps,
|
||||||
|
weight_decay=self.optimizer_weight_decay,
|
||||||
|
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_scheduler_preset(self):
|
||||||
|
return CosineDecayWithWarmupSchedulerConfig(
|
||||||
|
peak_lr=self.optimizer_lr,
|
||||||
|
decay_lr=self.scheduler_decay_lr,
|
||||||
|
num_warmup_steps=self.scheduler_warmup_steps,
|
||||||
|
num_decay_steps=self.scheduler_decay_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def observation_delta_indices(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_delta_indices(self) -> list[int]:
|
||||||
|
return list(range(self.chunk_size))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reward_delta_indices(self) -> None:
|
||||||
|
return None
|
||||||
620
src/lerobot/policies/eo1/modeling_eo1.py
Normal file
620
src/lerobot/policies/eo1/modeling_eo1.py
Normal file
@@ -0,0 +1,620 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from collections import deque
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F # noqa: N812
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||||
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||||
|
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _transformers_available:
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||||
|
from transformers.utils import torch_compilable_check
|
||||||
|
else:
|
||||||
|
ACT2FN = None
|
||||||
|
Qwen2_5_VLForConditionalGeneration = None
|
||||||
|
torch_compilable_check = None
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def pad_vector(vector, new_dim):
|
||||||
|
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||||
|
|
||||||
|
Can be (batch_size x sequence_length x features_dimension)
|
||||||
|
or (batch_size x features_dimension)
|
||||||
|
"""
|
||||||
|
if vector.shape[-1] >= new_dim:
|
||||||
|
return vector
|
||||||
|
return F.pad(vector, (0, new_dim - vector.shape[-1]))
|
||||||
|
|
||||||
|
|
||||||
|
class EO1Policy(PreTrainedPolicy):
|
||||||
|
"""EO1 policy wrapper for LeRobot robot-only training/evaluation."""
|
||||||
|
|
||||||
|
config_class = EO1Config
|
||||||
|
name = "eo1"
|
||||||
|
|
||||||
|
def __init__(self, config: EO1Config, **kwargs):
|
||||||
|
require_package("transformers", extra="eo1")
|
||||||
|
super().__init__(config)
|
||||||
|
config.validate_features()
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
if config.pretrained_path is None:
|
||||||
|
# Initialize from pretrained VLM
|
||||||
|
vlm_backbone = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||||
|
config.vlm_base,
|
||||||
|
dtype=config.dtype,
|
||||||
|
attn_implementation=config.attn_implementation,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
vlm_backbone = Qwen2_5_VLForConditionalGeneration._from_config(
|
||||||
|
config.vlm_backbone_config,
|
||||||
|
dtype=config.vlm_backbone_config.dtype if config.dtype == "auto" else config.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.model = EO1VisionFlowMatchingModel(config, vlm_backbone)
|
||||||
|
if config.gradient_checkpointing:
|
||||||
|
self.model.gradient_checkpointing_enable()
|
||||||
|
|
||||||
|
self.model.to(config.device)
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._action_queue = deque(maxlen=self.config.n_action_steps)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_model_inputs(batch: dict[str, Tensor], excluded_keys: set[str]) -> dict[str, Tensor]:
|
||||||
|
return {key: value for key, value in batch.items() if key not in excluded_keys}
|
||||||
|
|
||||||
|
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||||
|
state = self.prepare_state(batch[OBS_STATE])
|
||||||
|
actions = self.prepare_action(batch[ACTION])
|
||||||
|
model_inputs = self._get_model_inputs(batch, {OBS_STATE, ACTION})
|
||||||
|
loss = self.model(states=state, action=actions, **model_inputs)
|
||||||
|
|
||||||
|
loss_dict = {"loss": loss.item()}
|
||||||
|
return loss, loss_dict
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
states = self.prepare_state(batch[OBS_STATE])
|
||||||
|
model_inputs = self._get_model_inputs(batch, {OBS_STATE})
|
||||||
|
actions = self.model.sample_actions(states=states, **model_inputs).to(torch.float32)
|
||||||
|
|
||||||
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
|
return actions[:, :, :original_action_dim]
|
||||||
|
|
||||||
|
def prepare_state(self, state: Tensor) -> Tensor:
|
||||||
|
return pad_vector(state, self.config.max_state_dim)
|
||||||
|
|
||||||
|
def prepare_action(self, action: Tensor) -> Tensor:
|
||||||
|
return pad_vector(action, self.config.max_action_dim)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
if len(self._action_queue) == 0:
|
||||||
|
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
||||||
|
self._action_queue.extend(actions.transpose(0, 1))
|
||||||
|
|
||||||
|
return self._action_queue.popleft()
|
||||||
|
|
||||||
|
def get_optim_params(self) -> dict:
|
||||||
|
return self.parameters()
|
||||||
|
|
||||||
|
|
||||||
|
def get_safe_dtype(target_dtype, device_type):
|
||||||
|
"""Get a safe dtype for the given device type."""
|
||||||
|
if device_type == "mps" and target_dtype == torch.float64:
|
||||||
|
return torch.float32
|
||||||
|
if device_type == "cpu":
|
||||||
|
# CPU doesn't support bfloat16, use float32 instead
|
||||||
|
if target_dtype == torch.bfloat16:
|
||||||
|
return torch.float32
|
||||||
|
if target_dtype == torch.float64:
|
||||||
|
return torch.float64
|
||||||
|
return target_dtype
|
||||||
|
|
||||||
|
|
||||||
|
def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
|
||||||
|
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||||
|
) -> Tensor:
|
||||||
|
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||||
|
if dimension % 2 != 0:
|
||||||
|
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||||
|
|
||||||
|
if time.ndim != 1:
|
||||||
|
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||||
|
|
||||||
|
dtype = get_safe_dtype(torch.float64, device.type)
|
||||||
|
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||||
|
period = min_period * (max_period / min_period) ** fraction
|
||||||
|
|
||||||
|
# Compute the outer product
|
||||||
|
scaling_factor = 1.0 / period * 2 * math.pi
|
||||||
|
sin_input = scaling_factor[None, :] * time[:, None]
|
||||||
|
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||||
|
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
|
||||||
|
alpha_t = torch.tensor(alpha, dtype=torch.float32)
|
||||||
|
beta_t = torch.tensor(beta, dtype=torch.float32)
|
||||||
|
dist = torch.distributions.Beta(alpha_t, beta_t)
|
||||||
|
return dist.sample((bsize,)).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
class EO1VisionActionProjector(torch.nn.Sequential):
|
||||||
|
"""This block implements the multi-layer perceptron (MLP) module."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
num_layers: int = 2,
|
||||||
|
activation_layer: str = "linear",
|
||||||
|
bias: bool = True,
|
||||||
|
device: Any = None,
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
|
):
|
||||||
|
layers = []
|
||||||
|
in_dim = in_channels
|
||||||
|
hidden_channels = [in_dim] * (num_layers - 1) + [out_channels]
|
||||||
|
for hidden_dim in hidden_channels[:-1]:
|
||||||
|
layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device))
|
||||||
|
layers.append(ACT2FN[activation_layer])
|
||||||
|
in_dim = hidden_dim
|
||||||
|
layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias, dtype=dtype, device=device))
|
||||||
|
super().__init__(*layers)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return self[0].weight.dtype
|
||||||
|
|
||||||
|
|
||||||
|
class EO1VisionFlowMatchingModel(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: EO1Config,
|
||||||
|
vlm_backbone: Qwen2_5_VLForConditionalGeneration | None = None,
|
||||||
|
):
|
||||||
|
require_package("transformers", extra="eo1")
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
# Preserve the backbone dtype selected at construction time so Qwen's fp32 rotary buffers stay intact.
|
||||||
|
self.vlm_backbone = vlm_backbone
|
||||||
|
self.hidden_size = self.vlm_backbone.config.text_config.hidden_size
|
||||||
|
max_state_dim = config.max_state_dim
|
||||||
|
max_action_dim = config.max_action_dim
|
||||||
|
self.state_proj = nn.Linear(max_state_dim, self.hidden_size, dtype=torch.float32)
|
||||||
|
self.action_in_proj = nn.Linear(max_action_dim, self.hidden_size, dtype=torch.float32)
|
||||||
|
self.action_out_proj = EO1VisionActionProjector(
|
||||||
|
self.hidden_size,
|
||||||
|
max_action_dim,
|
||||||
|
config.num_action_layers,
|
||||||
|
config.action_act,
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
self.action_time_mlp_in = nn.Linear(self.hidden_size * 2, self.hidden_size, dtype=torch.float32)
|
||||||
|
self.action_time_mlp_out = nn.Linear(self.hidden_size, self.hidden_size, dtype=torch.float32)
|
||||||
|
self.gradient_checkpointing_enabled = False
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.vlm_backbone.get_input_embeddings()
|
||||||
|
|
||||||
|
def flow_head_autocast_context(self):
|
||||||
|
if self.config.force_fp32_autocast:
|
||||||
|
return torch.autocast(
|
||||||
|
device_type=self.state_proj.weight.device.type,
|
||||||
|
enabled=False,
|
||||||
|
)
|
||||||
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
def gradient_checkpointing_enable(self):
|
||||||
|
"""Enable gradient checkpointing for the Qwen2.5-VL backbone."""
|
||||||
|
self.gradient_checkpointing_enabled = True
|
||||||
|
self.vlm_backbone.gradient_checkpointing_enable(
|
||||||
|
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||||
|
)
|
||||||
|
logger.info("Enabled gradient checkpointing for EO1VisionFlowMatchingModel")
|
||||||
|
|
||||||
|
def gradient_checkpointing_disable(self):
|
||||||
|
"""Disable gradient checkpointing for the Qwen2.5-VL backbone."""
|
||||||
|
self.gradient_checkpointing_enabled = False
|
||||||
|
self.vlm_backbone.gradient_checkpointing_disable()
|
||||||
|
logger.info("Disabled gradient checkpointing for EO1VisionFlowMatchingModel")
|
||||||
|
|
||||||
|
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||||
|
"""Apply manual gradient checkpointing to EO1 flow-head computations when training."""
|
||||||
|
if self.gradient_checkpointing_enabled and self.training and torch.is_grad_enabled():
|
||||||
|
return torch.utils.checkpoint.checkpoint(
|
||||||
|
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
|
||||||
|
)
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
def sample_noise(self, shape, device):
|
||||||
|
noise = torch.normal(
|
||||||
|
mean=0.0,
|
||||||
|
std=1.0,
|
||||||
|
size=shape,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
return noise
|
||||||
|
|
||||||
|
def sample_time(self, bsize, device):
|
||||||
|
time_beta = sample_beta(
|
||||||
|
self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
|
||||||
|
)
|
||||||
|
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
|
||||||
|
return time.to(dtype=torch.float32, device=device)
|
||||||
|
|
||||||
|
def get_placeholder_mask(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None,
|
||||||
|
inputs_embeds: torch.FloatTensor | None,
|
||||||
|
state_features: torch.FloatTensor | None = None,
|
||||||
|
action_features: torch.FloatTensor | None = None,
|
||||||
|
*,
|
||||||
|
state_token_id: int,
|
||||||
|
action_token_id: int,
|
||||||
|
) -> tuple[torch.BoolTensor, torch.BoolTensor]:
|
||||||
|
"""Return EO1 state/action placeholder masks, following Qwen's multimodal mask style."""
|
||||||
|
if input_ids is None:
|
||||||
|
special_state_mask = inputs_embeds == self.get_input_embeddings()(
|
||||||
|
torch.tensor(state_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||||
|
)
|
||||||
|
special_state_mask = special_state_mask.all(-1)
|
||||||
|
special_action_mask = inputs_embeds == self.get_input_embeddings()(
|
||||||
|
torch.tensor(action_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||||
|
)
|
||||||
|
special_action_mask = special_action_mask.all(-1)
|
||||||
|
else:
|
||||||
|
special_state_mask = input_ids == state_token_id
|
||||||
|
special_action_mask = input_ids == action_token_id
|
||||||
|
|
||||||
|
n_state_tokens = special_state_mask.sum()
|
||||||
|
special_state_mask = (
|
||||||
|
special_state_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||||
|
)
|
||||||
|
if state_features is not None:
|
||||||
|
torch_compilable_check(
|
||||||
|
inputs_embeds[special_state_mask].numel() == state_features.numel(),
|
||||||
|
f"State features and state tokens do not match, tokens: {n_state_tokens}, features: {state_features.shape[0]}",
|
||||||
|
)
|
||||||
|
|
||||||
|
n_action_tokens = special_action_mask.sum()
|
||||||
|
special_action_mask = (
|
||||||
|
special_action_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||||
|
)
|
||||||
|
if action_features is not None:
|
||||||
|
torch_compilable_check(
|
||||||
|
inputs_embeds[special_action_mask].numel() == action_features.numel(),
|
||||||
|
f"Action features and action tokens do not match, tokens: {n_action_tokens}, features: {action_features.shape[0]}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return special_state_mask, special_action_mask
|
||||||
|
|
||||||
|
def embed_prefix(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
states: torch.Tensor,
|
||||||
|
*,
|
||||||
|
state_token_id: int,
|
||||||
|
action_token_id: int,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
"""Embed the EO1 prefix tokens before native Qwen injects multimodal features."""
|
||||||
|
|
||||||
|
# Get the input embeddings for the input IDs
|
||||||
|
def input_embed_func(input_ids: torch.LongTensor) -> torch.FloatTensor:
|
||||||
|
return self.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
|
inputs_embeds = self._apply_checkpoint(input_embed_func, input_ids)
|
||||||
|
|
||||||
|
# Project the states to the hidden size
|
||||||
|
def state_proj_func(states: torch.Tensor) -> torch.FloatTensor:
|
||||||
|
with self.flow_head_autocast_context():
|
||||||
|
states = states.to(dtype=self.state_proj.weight.dtype)
|
||||||
|
return self.state_proj(states)
|
||||||
|
|
||||||
|
state_embs = self._apply_checkpoint(state_proj_func, states)
|
||||||
|
state_mask, _ = self.get_placeholder_mask(
|
||||||
|
input_ids,
|
||||||
|
inputs_embeds,
|
||||||
|
state_features=state_embs,
|
||||||
|
state_token_id=state_token_id,
|
||||||
|
action_token_id=action_token_id,
|
||||||
|
)
|
||||||
|
state_embs = state_embs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(state_mask, state_embs)
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
def embed_suffix(
|
||||||
|
self,
|
||||||
|
timestep: torch.Tensor,
|
||||||
|
noisy_actions: torch.Tensor,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
"""Embed the suffix"""
|
||||||
|
|
||||||
|
def action_proj_func(noisy_actions: torch.Tensor) -> torch.FloatTensor:
|
||||||
|
with self.flow_head_autocast_context():
|
||||||
|
noisy_actions = noisy_actions.to(dtype=self.action_in_proj.weight.dtype)
|
||||||
|
return self.action_in_proj(noisy_actions)
|
||||||
|
|
||||||
|
action_embs = self._apply_checkpoint(action_proj_func, noisy_actions)
|
||||||
|
time_embs = create_sinusoidal_pos_embedding(
|
||||||
|
timestep,
|
||||||
|
self.hidden_size,
|
||||||
|
min_period=self.config.min_period,
|
||||||
|
max_period=self.config.max_period,
|
||||||
|
device=action_embs.device,
|
||||||
|
)
|
||||||
|
time_embs = time_embs.to(dtype=action_embs.dtype)
|
||||||
|
time_embs = time_embs[:, None, :].expand_as(action_embs)
|
||||||
|
action_time_embs = torch.cat([action_embs, time_embs], dim=2)
|
||||||
|
|
||||||
|
def mlp_func(action_time_embs: torch.Tensor) -> torch.FloatTensor:
|
||||||
|
with self.flow_head_autocast_context():
|
||||||
|
action_time_embs = action_time_embs.to(dtype=self.action_time_mlp_in.weight.dtype)
|
||||||
|
action_time_embs = self.action_time_mlp_in(action_time_embs)
|
||||||
|
action_time_embs = F.silu(action_time_embs)
|
||||||
|
return self.action_time_mlp_out(action_time_embs)
|
||||||
|
|
||||||
|
action_time_embs = self._apply_checkpoint(mlp_func, action_time_embs)
|
||||||
|
return action_time_embs
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
attention_mask: torch.LongTensor | None = None,
|
||||||
|
pixel_values: torch.FloatTensor | None = None,
|
||||||
|
image_grid_thw: torch.LongTensor | None = None,
|
||||||
|
mm_token_type_ids: torch.IntTensor | None = None,
|
||||||
|
states: torch.FloatTensor | None = None,
|
||||||
|
action: torch.FloatTensor | None = None,
|
||||||
|
action_is_pad: torch.BoolTensor | None = None,
|
||||||
|
*,
|
||||||
|
state_token_id: int,
|
||||||
|
action_token_id: int,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tensor:
|
||||||
|
"""Run the EO1 training forward pass and compute the flow-matching loss."""
|
||||||
|
|
||||||
|
# 1. Build the EO1 prefix with state placeholders resolved.
|
||||||
|
inputs_embeds = self.embed_prefix(
|
||||||
|
input_ids,
|
||||||
|
states=states,
|
||||||
|
state_token_id=state_token_id,
|
||||||
|
action_token_id=action_token_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Sample the diffusion target and replace the action placeholders.
|
||||||
|
time = self.sample_time(action.shape[0], inputs_embeds.device)
|
||||||
|
noise = self.sample_noise(action.shape, inputs_embeds.device)
|
||||||
|
|
||||||
|
time_expanded = time[:, None, None]
|
||||||
|
x_t = time_expanded * noise + (1 - time_expanded) * action
|
||||||
|
u_t = noise - action
|
||||||
|
action_time_embs = self.embed_suffix(time, x_t)
|
||||||
|
_, action_mask = self.get_placeholder_mask(
|
||||||
|
input_ids,
|
||||||
|
inputs_embeds,
|
||||||
|
action_features=action_time_embs,
|
||||||
|
state_token_id=state_token_id,
|
||||||
|
action_token_id=action_token_id,
|
||||||
|
)
|
||||||
|
action_time_embs = action_time_embs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(action_mask, action_time_embs)
|
||||||
|
|
||||||
|
# 3. Optionally drop padded action tokens from backbone attention.
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||||
|
|
||||||
|
if not self.config.supervise_padding_actions:
|
||||||
|
action_is_pad = action_is_pad.to(device=inputs_embeds.device, dtype=torch.bool)
|
||||||
|
action_token_mask = action_mask[..., 0]
|
||||||
|
action_padding_mask = torch.zeros_like(action_token_mask)
|
||||||
|
action_padding_mask = action_padding_mask.masked_scatter(
|
||||||
|
action_token_mask,
|
||||||
|
action_is_pad.reshape(-1),
|
||||||
|
)
|
||||||
|
attention_mask = attention_mask.masked_fill(action_padding_mask, 0)
|
||||||
|
|
||||||
|
# 4. Run the Qwen backbone on the fused EO1 sequence.
|
||||||
|
def vlm_forward_func(
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
attention_mask: torch.Tensor | None,
|
||||||
|
inputs_embeds: torch.FloatTensor,
|
||||||
|
pixel_values: torch.Tensor | None,
|
||||||
|
image_grid_thw: torch.LongTensor | None,
|
||||||
|
mm_token_type_ids: torch.IntTensor | None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
outputs = self.vlm_backbone.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
image_grid_thw=image_grid_thw,
|
||||||
|
mm_token_type_ids=mm_token_type_ids,
|
||||||
|
use_cache=False,
|
||||||
|
output_hidden_states=False,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
return outputs.last_hidden_state
|
||||||
|
|
||||||
|
hidden_states = self._apply_checkpoint(
|
||||||
|
vlm_forward_func,
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
inputs_embeds,
|
||||||
|
pixel_values,
|
||||||
|
image_grid_thw,
|
||||||
|
mm_token_type_ids,
|
||||||
|
)
|
||||||
|
action_hidden_states = hidden_states[action_mask[..., 0]]
|
||||||
|
|
||||||
|
# 5. Project the action-token hidden states back to the flow target space.
|
||||||
|
def action_out_proj_func(action_hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
with self.flow_head_autocast_context():
|
||||||
|
action_hidden_states = action_hidden_states.to(dtype=self.action_out_proj.dtype)
|
||||||
|
return self.action_out_proj(action_hidden_states)
|
||||||
|
|
||||||
|
v_t = self._apply_checkpoint(action_out_proj_func, action_hidden_states)
|
||||||
|
v_t = v_t.reshape(u_t.shape).to(dtype=u_t.dtype)
|
||||||
|
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||||
|
|
||||||
|
# 6. Apply the configured supervision mask and reduce the loss.
|
||||||
|
if not self.config.supervise_padding_action_dims:
|
||||||
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
|
losses = losses[..., :original_action_dim]
|
||||||
|
|
||||||
|
if not self.config.supervise_padding_actions:
|
||||||
|
losses = losses[~action_is_pad]
|
||||||
|
|
||||||
|
return losses.mean()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_actions(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
attention_mask: torch.Tensor | None = None,
|
||||||
|
pixel_values: torch.Tensor | None = None,
|
||||||
|
image_grid_thw: torch.LongTensor | None = None,
|
||||||
|
mm_token_type_ids: torch.IntTensor | None = None,
|
||||||
|
states: torch.Tensor | None = None,
|
||||||
|
*,
|
||||||
|
state_token_id: int,
|
||||||
|
action_token_id: int,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tensor:
|
||||||
|
"""Sample actions from the model."""
|
||||||
|
if states is None:
|
||||||
|
raise ValueError("states are required for EO1 action sampling.")
|
||||||
|
if mm_token_type_ids is None:
|
||||||
|
raise ValueError("mm_token_type_ids are required for EO1 action sampling.")
|
||||||
|
|
||||||
|
# 1. Resolve the left-padded rollout prompt and locate the action span.
|
||||||
|
chunk_size = self.config.chunk_size
|
||||||
|
|
||||||
|
inputs_embeds = self.embed_prefix(
|
||||||
|
input_ids,
|
||||||
|
states=states,
|
||||||
|
state_token_id=state_token_id,
|
||||||
|
action_token_id=action_token_id,
|
||||||
|
).clone()
|
||||||
|
_, action_placeholder_mask = self.get_placeholder_mask(
|
||||||
|
input_ids,
|
||||||
|
inputs_embeds,
|
||||||
|
state_token_id=state_token_id,
|
||||||
|
action_token_id=action_token_id,
|
||||||
|
)
|
||||||
|
action_mask = action_placeholder_mask[..., 0]
|
||||||
|
token_counts = action_mask.sum(dim=1)
|
||||||
|
if not torch.all(token_counts == chunk_size):
|
||||||
|
raise ValueError(
|
||||||
|
f"Each sample must contain exactly {chunk_size} action tokens, got {token_counts.tolist()}."
|
||||||
|
)
|
||||||
|
if action_mask.ne(action_mask[:1]).any():
|
||||||
|
raise ValueError(
|
||||||
|
"Batch inference expects all samples to share the same action token mask after left padding."
|
||||||
|
)
|
||||||
|
act_start = int(action_mask[0].to(torch.int64).argmax().item())
|
||||||
|
act_end = act_start + self.config.chunk_size
|
||||||
|
if not torch.all(action_mask[:, act_start:act_end]):
|
||||||
|
raise ValueError("Action tokens must form a contiguous chunk of length chunk_size.")
|
||||||
|
act_slice = slice(act_start, act_end)
|
||||||
|
|
||||||
|
# 2. Encode the fixed prefix once and cache its KV state.
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
|
device = inputs_embeds.device
|
||||||
|
attention_mask = attention_mask.to(device)
|
||||||
|
mm_token_type_ids = mm_token_type_ids.to(device)
|
||||||
|
position_ids, _ = self.vlm_backbone.model.get_rope_index(
|
||||||
|
input_ids,
|
||||||
|
image_grid_thw=image_grid_thw,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
mm_token_type_ids=mm_token_type_ids,
|
||||||
|
)
|
||||||
|
position_ids = position_ids.to(device)
|
||||||
|
|
||||||
|
outputs = self.vlm_backbone.model(
|
||||||
|
input_ids=input_ids[:, :act_start],
|
||||||
|
attention_mask=attention_mask[:, :act_start],
|
||||||
|
position_ids=position_ids[..., :act_start],
|
||||||
|
inputs_embeds=inputs_embeds[:, :act_start],
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
image_grid_thw=image_grid_thw,
|
||||||
|
mm_token_type_ids=mm_token_type_ids[:, :act_start],
|
||||||
|
use_cache=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
x_t = self.sample_noise(
|
||||||
|
(batch_size, chunk_size, self.config.max_action_dim),
|
||||||
|
device,
|
||||||
|
).to(dtype=self.action_in_proj.weight.dtype)
|
||||||
|
dt = -1.0 / self.config.num_denoise_steps
|
||||||
|
past_key_values = outputs.past_key_values
|
||||||
|
|
||||||
|
# 3. Denoise only the action chunk while keeping the prefix cache invariant.
|
||||||
|
for step in range(self.config.num_denoise_steps):
|
||||||
|
time = torch.full(
|
||||||
|
(batch_size,),
|
||||||
|
1.0 + step * dt,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
action_time_embs = self.embed_suffix(time, x_t)
|
||||||
|
inputs_embeds[:, act_slice] = action_time_embs.to(inputs_embeds.dtype)
|
||||||
|
|
||||||
|
# Keep the prefix KV cache invariant across denoising steps.
|
||||||
|
past_key_values.crop(act_start)
|
||||||
|
outputs = self.vlm_backbone.model(
|
||||||
|
attention_mask=attention_mask[:, :act_end],
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds[:, act_slice],
|
||||||
|
position_ids=position_ids[..., act_slice],
|
||||||
|
use_cache=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
with self.flow_head_autocast_context():
|
||||||
|
hidden_states = outputs.last_hidden_state[:, :chunk_size]
|
||||||
|
hidden_states = hidden_states.to(dtype=self.action_out_proj.dtype)
|
||||||
|
v_t = self.action_out_proj(hidden_states)
|
||||||
|
|
||||||
|
x_t += dt * v_t.reshape(x_t.shape)
|
||||||
|
|
||||||
|
return x_t
|
||||||
282
src/lerobot/policies/eo1/processor_eo1.py
Normal file
282
src/lerobot/policies/eo1/processor_eo1.py
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||||
|
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||||
|
from lerobot.processor import (
|
||||||
|
AddBatchDimensionProcessorStep,
|
||||||
|
ComplementaryDataProcessorStep,
|
||||||
|
DeviceProcessorStep,
|
||||||
|
NormalizerProcessorStep,
|
||||||
|
PolicyAction,
|
||||||
|
PolicyProcessorPipeline,
|
||||||
|
ProcessorStep,
|
||||||
|
ProcessorStepRegistry,
|
||||||
|
RenameObservationsProcessorStep,
|
||||||
|
UnnormalizerProcessorStep,
|
||||||
|
)
|
||||||
|
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||||
|
from lerobot.types import TransitionKey
|
||||||
|
from lerobot.utils.constants import (
|
||||||
|
OBS_STATE,
|
||||||
|
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
|
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
|
)
|
||||||
|
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _transformers_available:
|
||||||
|
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
||||||
|
else:
|
||||||
|
Qwen2_5_VLProcessor = None
|
||||||
|
|
||||||
|
SYSTEM_MESSAGE = "You are a helpful physical assistant."
|
||||||
|
|
||||||
|
# EO-1 special tokens
|
||||||
|
ACTION_START_TOKEN = "<|action_start|>" # nosec B105
|
||||||
|
DEFAULT_ACTION_TOKEN = "<|action_pad|>" # nosec B105
|
||||||
|
ACTION_END_TOKEN = "<|action_end|>" # nosec B105
|
||||||
|
STATE_START_TOKEN = "<|state_start|>" # nosec B105
|
||||||
|
DEFAULT_STATE_TOKEN = "<|state_pad|>" # nosec B105
|
||||||
|
STATE_END_TOKEN = "<|state_end|>" # nosec B105
|
||||||
|
TASK_VLA_TOKEN = "<|vla|>" # nosec B105
|
||||||
|
|
||||||
|
EO1_SPECIAL_TOKENS = [
|
||||||
|
ACTION_START_TOKEN,
|
||||||
|
DEFAULT_ACTION_TOKEN,
|
||||||
|
ACTION_END_TOKEN,
|
||||||
|
STATE_START_TOKEN,
|
||||||
|
DEFAULT_STATE_TOKEN,
|
||||||
|
STATE_END_TOKEN,
|
||||||
|
TASK_VLA_TOKEN,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register(name="eo1_conversation_template_processor")
|
||||||
|
class EO1ConversationTemplateStep(ComplementaryDataProcessorStep):
|
||||||
|
input_features: dict[str, PolicyFeature] | dict[str, dict[str, Any]]
|
||||||
|
chunk_size: int
|
||||||
|
|
||||||
|
_image_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# Robust JSON deserialization handling (guard empty maps).
|
||||||
|
if self.input_features:
|
||||||
|
first_val = next(iter(self.input_features.values()))
|
||||||
|
if isinstance(first_val, dict):
|
||||||
|
reconstructed = {}
|
||||||
|
for key, ft_dict in self.input_features.items():
|
||||||
|
reconstructed[key] = PolicyFeature(
|
||||||
|
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||||
|
)
|
||||||
|
self.input_features = reconstructed
|
||||||
|
|
||||||
|
self._image_keys = [
|
||||||
|
key for key, value in self.input_features.items() if value.type == FeatureType.VISUAL
|
||||||
|
]
|
||||||
|
|
||||||
|
def complementary_data(self, complementary_data):
|
||||||
|
tasks = complementary_data.get("task")
|
||||||
|
if tasks is None:
|
||||||
|
raise ValueError("Task is required for EO1ConversationTemplateStep.")
|
||||||
|
|
||||||
|
observation = self.transition.get(TransitionKey.OBSERVATION)
|
||||||
|
if observation is None:
|
||||||
|
raise ValueError("Observation is required for EO1ConversationTemplateStep.")
|
||||||
|
|
||||||
|
if OBS_STATE in observation and observation[OBS_STATE].shape[0] != len(tasks):
|
||||||
|
raise ValueError("Batch size mismatch between observation.state and task list.")
|
||||||
|
|
||||||
|
# LeRobot visual observations reach in processor as float32 tensors in [0, 1].
|
||||||
|
# Convert to uint8 in [0, 255] to meet the input requirement of Qwen2.5-VL-3B-Instruct.
|
||||||
|
images = {
|
||||||
|
key: observation[key].clamp(0, 1).mul(255.0).round().to(torch.uint8) for key in self._image_keys
|
||||||
|
}
|
||||||
|
messages = []
|
||||||
|
for i in range(len(tasks)):
|
||||||
|
content = [
|
||||||
|
*[{"type": "image", "image": images[key][i]} for key in self._image_keys],
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": (
|
||||||
|
f"{STATE_START_TOKEN}{DEFAULT_STATE_TOKEN}{STATE_END_TOKEN}{tasks[i]}{TASK_VLA_TOKEN}"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
messages.append(
|
||||||
|
[
|
||||||
|
{"role": "system", "content": [{"type": "text", "text": SYSTEM_MESSAGE}]},
|
||||||
|
{"role": "user", "content": content},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": f"{ACTION_START_TOKEN}{DEFAULT_ACTION_TOKEN * self.chunk_size}{ACTION_END_TOKEN}",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
complementary_data["messages"] = messages
|
||||||
|
|
||||||
|
return complementary_data
|
||||||
|
|
||||||
|
def transform_features(
|
||||||
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
|
"""
|
||||||
|
This step only materializes EO1-specific message objects in complementary_data.
|
||||||
|
PipelineFeatureType tracks only ACTION and OBSERVATION, so there is no static
|
||||||
|
feature contract change to record here.
|
||||||
|
"""
|
||||||
|
return features
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"input_features": {
|
||||||
|
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.input_features.items()
|
||||||
|
},
|
||||||
|
"chunk_size": self.chunk_size,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register(name="eo1_qwen_processor")
|
||||||
|
class EO1QwenProcessorStep(ComplementaryDataProcessorStep):
|
||||||
|
processor_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||||
|
image_min_pixels: int | None = 64 * 28 * 28
|
||||||
|
image_max_pixels: int | None = 128 * 28 * 28
|
||||||
|
use_fast_processor: bool = False
|
||||||
|
|
||||||
|
_processor: Qwen2_5_VLProcessor | None = field(default=None, init=False, repr=False)
|
||||||
|
_state_token_id: int | None = field(default=None, init=False, repr=False)
|
||||||
|
_action_token_id: int | None = field(default=None, init=False, repr=False)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
require_package("transformers", extra="eo1")
|
||||||
|
self._processor = Qwen2_5_VLProcessor.from_pretrained(
|
||||||
|
self.processor_name,
|
||||||
|
use_fast=self.use_fast_processor,
|
||||||
|
)
|
||||||
|
self._processor.tokenizer.add_tokens(EO1_SPECIAL_TOKENS, special_tokens=True)
|
||||||
|
self._state_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_STATE_TOKEN)
|
||||||
|
self._action_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_ACTION_TOKEN)
|
||||||
|
|
||||||
|
def complementary_data(self, complementary_data):
|
||||||
|
messages = complementary_data.pop("messages", None)
|
||||||
|
if messages is None:
|
||||||
|
raise ValueError("Messages are required for EO1QwenProcessorStep.")
|
||||||
|
|
||||||
|
# Rollout batches use left padding so action spans stay aligned across samples.
|
||||||
|
# Supervised batches use right padding to match standard training collation.
|
||||||
|
padding_side = "right" if self.transition.get(TransitionKey.ACTION) is not None else "left"
|
||||||
|
|
||||||
|
inputs = self._processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=True,
|
||||||
|
padding=True,
|
||||||
|
padding_side=padding_side,
|
||||||
|
min_pixels=self.image_min_pixels,
|
||||||
|
max_pixels=self.image_max_pixels,
|
||||||
|
add_generation_prompt=False,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
complementary_data["input_ids"] = inputs["input_ids"]
|
||||||
|
complementary_data["pixel_values"] = inputs["pixel_values"]
|
||||||
|
complementary_data["image_grid_thw"] = inputs["image_grid_thw"]
|
||||||
|
complementary_data["attention_mask"] = inputs["attention_mask"]
|
||||||
|
complementary_data["mm_token_type_ids"] = inputs["mm_token_type_ids"]
|
||||||
|
complementary_data["state_token_id"] = self._state_token_id
|
||||||
|
complementary_data["action_token_id"] = self._action_token_id
|
||||||
|
|
||||||
|
return complementary_data
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"processor_name": self.processor_name,
|
||||||
|
"image_min_pixels": self.image_min_pixels,
|
||||||
|
"image_max_pixels": self.image_max_pixels,
|
||||||
|
"use_fast_processor": self.use_fast_processor,
|
||||||
|
}
|
||||||
|
|
||||||
|
def transform_features(
|
||||||
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
|
"""
|
||||||
|
This step only converts the messages to the model input format.
|
||||||
|
"""
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def make_eo1_pre_post_processors(
|
||||||
|
config: EO1Config,
|
||||||
|
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||||
|
) -> tuple[
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||||
|
]:
|
||||||
|
"""Build pre/post processor pipelines for EO1."""
|
||||||
|
|
||||||
|
input_steps: list[ProcessorStep] = [
|
||||||
|
RenameObservationsProcessorStep(rename_map={}),
|
||||||
|
AddBatchDimensionProcessorStep(),
|
||||||
|
NormalizerProcessorStep(
|
||||||
|
features={**config.input_features, **config.output_features},
|
||||||
|
norm_map=config.normalization_mapping,
|
||||||
|
stats=dataset_stats,
|
||||||
|
),
|
||||||
|
EO1ConversationTemplateStep(input_features=config.input_features, chunk_size=config.chunk_size),
|
||||||
|
EO1QwenProcessorStep(
|
||||||
|
processor_name=config.vlm_base,
|
||||||
|
image_min_pixels=config.image_min_pixels,
|
||||||
|
image_max_pixels=config.image_max_pixels,
|
||||||
|
use_fast_processor=config.use_fast_processor,
|
||||||
|
),
|
||||||
|
DeviceProcessorStep(device=config.device),
|
||||||
|
]
|
||||||
|
|
||||||
|
output_steps: list[ProcessorStep] = [
|
||||||
|
UnnormalizerProcessorStep(
|
||||||
|
features=config.output_features,
|
||||||
|
norm_map=config.normalization_mapping,
|
||||||
|
stats=dataset_stats,
|
||||||
|
),
|
||||||
|
DeviceProcessorStep(device="cpu"),
|
||||||
|
]
|
||||||
|
|
||||||
|
return (
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||||
|
steps=input_steps,
|
||||||
|
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
|
),
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||||
|
steps=output_steps,
|
||||||
|
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
|
to_transition=policy_action_to_transition,
|
||||||
|
to_output=transition_to_policy_action,
|
||||||
|
),
|
||||||
|
)
|
||||||
@@ -46,14 +46,13 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
|
|||||||
|
|
||||||
from .act.configuration_act import ACTConfig
|
from .act.configuration_act import ACTConfig
|
||||||
from .diffusion.configuration_diffusion import DiffusionConfig
|
from .diffusion.configuration_diffusion import DiffusionConfig
|
||||||
|
from .eo1.configuration_eo1 import EO1Config
|
||||||
|
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||||
from .groot.configuration_groot import GrootConfig
|
from .groot.configuration_groot import GrootConfig
|
||||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||||
from .pi0.configuration_pi0 import PI0Config
|
from .pi0.configuration_pi0 import PI0Config
|
||||||
from .pi05.configuration_pi05 import PI05Config
|
from .pi05.configuration_pi05 import PI05Config
|
||||||
from .pretrained import PreTrainedPolicy
|
from .pretrained import PreTrainedPolicy
|
||||||
from .sac.configuration_sac import SACConfig
|
|
||||||
from .sac.reward_model.configuration_classifier import RewardClassifierConfig
|
|
||||||
from .sarm.configuration_sarm import SARMConfig
|
|
||||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
from .utils import validate_visual_features_consistency
|
from .utils import validate_visual_features_consistency
|
||||||
@@ -89,7 +88,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||||
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
|
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "smolvla", "wall_x".
|
||||||
Returns:
|
Returns:
|
||||||
The policy class corresponding to the given name.
|
The policy class corresponding to the given name.
|
||||||
|
|
||||||
@@ -128,22 +127,14 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
|||||||
from .pi05.modeling_pi05 import PI05Policy
|
from .pi05.modeling_pi05 import PI05Policy
|
||||||
|
|
||||||
return PI05Policy
|
return PI05Policy
|
||||||
elif name == "sac":
|
elif name == "gaussian_actor":
|
||||||
from .sac.modeling_sac import SACPolicy
|
from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||||
|
|
||||||
return SACPolicy
|
return GaussianActorPolicy
|
||||||
elif name == "reward_classifier":
|
|
||||||
from .sac.reward_model.modeling_classifier import Classifier
|
|
||||||
|
|
||||||
return Classifier
|
|
||||||
elif name == "smolvla":
|
elif name == "smolvla":
|
||||||
from .smolvla.modeling_smolvla import SmolVLAPolicy
|
from .smolvla.modeling_smolvla import SmolVLAPolicy
|
||||||
|
|
||||||
return SmolVLAPolicy
|
return SmolVLAPolicy
|
||||||
elif name == "sarm":
|
|
||||||
from .sarm.modeling_sarm import SARMRewardModel
|
|
||||||
|
|
||||||
return SARMRewardModel
|
|
||||||
elif name == "groot":
|
elif name == "groot":
|
||||||
from .groot.modeling_groot import GrootPolicy
|
from .groot.modeling_groot import GrootPolicy
|
||||||
|
|
||||||
@@ -156,6 +147,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
|||||||
from .wall_x.modeling_wall_x import WallXPolicy
|
from .wall_x.modeling_wall_x import WallXPolicy
|
||||||
|
|
||||||
return WallXPolicy
|
return WallXPolicy
|
||||||
|
elif name == "eo1":
|
||||||
|
from .eo1.modeling_eo1 import EO1Policy
|
||||||
|
|
||||||
|
return EO1Policy
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
return _get_policy_cls_from_policy_name(name=name)
|
return _get_policy_cls_from_policy_name(name=name)
|
||||||
@@ -172,8 +167,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
|
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
|
||||||
"smolvla", "reward_classifier", "wall_x".
|
"smolvla", "wall_x".
|
||||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -196,18 +191,18 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||||||
return PI0Config(**kwargs)
|
return PI0Config(**kwargs)
|
||||||
elif policy_type == "pi05":
|
elif policy_type == "pi05":
|
||||||
return PI05Config(**kwargs)
|
return PI05Config(**kwargs)
|
||||||
elif policy_type == "sac":
|
elif policy_type == "gaussian_actor":
|
||||||
return SACConfig(**kwargs)
|
return GaussianActorConfig(**kwargs)
|
||||||
elif policy_type == "smolvla":
|
elif policy_type == "smolvla":
|
||||||
return SmolVLAConfig(**kwargs)
|
return SmolVLAConfig(**kwargs)
|
||||||
elif policy_type == "reward_classifier":
|
|
||||||
return RewardClassifierConfig(**kwargs)
|
|
||||||
elif policy_type == "groot":
|
elif policy_type == "groot":
|
||||||
return GrootConfig(**kwargs)
|
return GrootConfig(**kwargs)
|
||||||
elif policy_type == "xvla":
|
elif policy_type == "xvla":
|
||||||
return XVLAConfig(**kwargs)
|
return XVLAConfig(**kwargs)
|
||||||
elif policy_type == "wall_x":
|
elif policy_type == "wall_x":
|
||||||
return WallXConfig(**kwargs)
|
return WallXConfig(**kwargs)
|
||||||
|
elif policy_type == "eo1":
|
||||||
|
return EO1Config(**kwargs)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||||
@@ -370,18 +365,10 @@ def make_pre_post_processors(
|
|||||||
dataset_stats=kwargs.get("dataset_stats"),
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif isinstance(policy_cfg, SACConfig):
|
elif isinstance(policy_cfg, GaussianActorConfig):
|
||||||
from .sac.processor_sac import make_sac_pre_post_processors
|
from .gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors
|
||||||
|
|
||||||
processors = make_sac_pre_post_processors(
|
processors = make_gaussian_actor_pre_post_processors(
|
||||||
config=policy_cfg,
|
|
||||||
dataset_stats=kwargs.get("dataset_stats"),
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(policy_cfg, RewardClassifierConfig):
|
|
||||||
from .sac.reward_model.processor_classifier import make_classifier_processor
|
|
||||||
|
|
||||||
processors = make_classifier_processor(
|
|
||||||
config=policy_cfg,
|
config=policy_cfg,
|
||||||
dataset_stats=kwargs.get("dataset_stats"),
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
)
|
)
|
||||||
@@ -394,14 +381,6 @@ def make_pre_post_processors(
|
|||||||
dataset_stats=kwargs.get("dataset_stats"),
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif isinstance(policy_cfg, SARMConfig):
|
|
||||||
from .sarm.processor_sarm import make_sarm_pre_post_processors
|
|
||||||
|
|
||||||
processors = make_sarm_pre_post_processors(
|
|
||||||
config=policy_cfg,
|
|
||||||
dataset_stats=kwargs.get("dataset_stats"),
|
|
||||||
dataset_meta=kwargs.get("dataset_meta"),
|
|
||||||
)
|
|
||||||
elif isinstance(policy_cfg, GrootConfig):
|
elif isinstance(policy_cfg, GrootConfig):
|
||||||
from .groot.processor_groot import make_groot_pre_post_processors
|
from .groot.processor_groot import make_groot_pre_post_processors
|
||||||
|
|
||||||
@@ -427,6 +406,13 @@ def make_pre_post_processors(
|
|||||||
config=policy_cfg,
|
config=policy_cfg,
|
||||||
dataset_stats=kwargs.get("dataset_stats"),
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
)
|
)
|
||||||
|
elif isinstance(policy_cfg, EO1Config):
|
||||||
|
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
||||||
|
|
||||||
|
processors = make_eo1_pre_post_processors(
|
||||||
|
config=policy_cfg,
|
||||||
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
@@ -542,7 +528,7 @@ def make_policy(
|
|||||||
|
|
||||||
logging.info("Loading policy's PEFT adapter.")
|
logging.info("Loading policy's PEFT adapter.")
|
||||||
|
|
||||||
peft_pretrained_path = cfg.pretrained_path
|
peft_pretrained_path = str(cfg.pretrained_path)
|
||||||
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
||||||
|
|
||||||
kwargs["pretrained_name_or_path"] = peft_config.base_model_name_or_path
|
kwargs["pretrained_name_or_path"] = peft_config.base_model_name_or_path
|
||||||
@@ -555,7 +541,9 @@ def make_policy(
|
|||||||
)
|
)
|
||||||
|
|
||||||
policy = policy_cls.from_pretrained(**kwargs)
|
policy = policy_cls.from_pretrained(**kwargs)
|
||||||
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
|
policy = PeftModel.from_pretrained(
|
||||||
|
policy, peft_pretrained_path, config=peft_config, is_trainable=True
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Make a fresh policy.
|
# Make a fresh policy.
|
||||||
|
|||||||
@@ -12,8 +12,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from .configuration_sac import SACConfig
|
from .configuration_gaussian_actor import GaussianActorConfig
|
||||||
from .modeling_sac import SACPolicy
|
from .modeling_gaussian_actor import GaussianActorPolicy
|
||||||
from .processor_sac import make_sac_pre_post_processors
|
from .processor_gaussian_actor import make_gaussian_actor_pre_post_processors
|
||||||
|
|
||||||
__all__ = ["SACConfig", "SACPolicy", "make_sac_pre_post_processors"]
|
__all__ = ["GaussianActorConfig", "GaussianActorPolicy", "make_gaussian_actor_pre_post_processors"]
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
# !/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team.
|
# Copyright 2025 The HuggingFace Inc. team.
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
@@ -75,18 +75,19 @@ class PolicyConfig:
|
|||||||
init_final: float = 0.05
|
init_final: float = 0.05
|
||||||
|
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("sac")
|
@PreTrainedConfig.register_subclass("gaussian_actor")
|
||||||
@dataclass
|
@dataclass
|
||||||
class SACConfig(PreTrainedConfig):
|
class GaussianActorConfig(PreTrainedConfig):
|
||||||
"""Soft Actor-Critic (SAC) configuration.
|
"""Gaussian actor configuration.
|
||||||
|
|
||||||
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
|
This configures the policy-side (actor + observation encoder) of a Gaussian
|
||||||
reinforcement learning framework. It learns a policy and a Q-function simultaneously
|
policy, as used by SAC and related maximum-entropy continuous-control algorithms.
|
||||||
using experience collected from the environment.
|
By default the actor output is a tanh-squashed diagonal Gaussian
|
||||||
|
(``TanhMultivariateNormalDiag``); the tanh squashing can be disabled via
|
||||||
|
``policy_kwargs.use_tanh_squash``. The critics, temperature, and Bellman-update
|
||||||
|
logic live on the algorithm side (see ``lerobot.rl.algorithms.sac``).
|
||||||
|
|
||||||
This configuration class contains all the parameters needed to define a SAC agent,
|
CLI: ``--policy.type=gaussian_actor``.
|
||||||
including network architectures, optimization settings, and algorithm-specific
|
|
||||||
hyperparameters.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Mapping of feature types to normalization modes
|
# Mapping of feature types to normalization modes
|
||||||
@@ -122,7 +123,7 @@ class SACConfig(PreTrainedConfig):
|
|||||||
device: str = "cpu"
|
device: str = "cpu"
|
||||||
# Device to store the model on
|
# Device to store the model on
|
||||||
storage_device: str = "cpu"
|
storage_device: str = "cpu"
|
||||||
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
|
# Name of the vision encoder model (Set to "lerobot/resnet10" for hil serl resnet10)
|
||||||
vision_encoder_name: str | None = None
|
vision_encoder_name: str | None = None
|
||||||
# Whether to freeze the vision encoder during training
|
# Whether to freeze the vision encoder during training
|
||||||
freeze_vision_encoder: bool = True
|
freeze_vision_encoder: bool = True
|
||||||
@@ -135,7 +136,13 @@ class SACConfig(PreTrainedConfig):
|
|||||||
# Dimension of the image embedding pooling
|
# Dimension of the image embedding pooling
|
||||||
image_embedding_pooling_dim: int = 8
|
image_embedding_pooling_dim: int = 8
|
||||||
|
|
||||||
# Training parameter
|
# Encoder architecture
|
||||||
|
# Hidden dimension size for the state encoder
|
||||||
|
state_encoder_hidden_dim: int = 256
|
||||||
|
# Dimension of the latent space
|
||||||
|
latent_dim: int = 256
|
||||||
|
|
||||||
|
# Online training (TODO(Khalil): relocate to TrainRLServerPipelineConfig)
|
||||||
# Number of steps for online training
|
# Number of steps for online training
|
||||||
online_steps: int = 1000000
|
online_steps: int = 1000000
|
||||||
# Capacity of the online replay buffer
|
# Capacity of the online replay buffer
|
||||||
@@ -146,67 +153,38 @@ class SACConfig(PreTrainedConfig):
|
|||||||
async_prefetch: bool = False
|
async_prefetch: bool = False
|
||||||
# Number of steps before learning starts
|
# Number of steps before learning starts
|
||||||
online_step_before_learning: int = 100
|
online_step_before_learning: int = 100
|
||||||
# Frequency of policy updates
|
|
||||||
policy_update_freq: int = 1
|
|
||||||
|
|
||||||
# SAC algorithm parameters
|
# Actor-learner transport (TODO(Khalil): relocate to TrainRLServerPipelineConfig).
|
||||||
# Discount factor for the SAC algorithm
|
|
||||||
discount: float = 0.99
|
|
||||||
# Initial temperature value
|
|
||||||
temperature_init: float = 1.0
|
|
||||||
# Number of critics in the ensemble
|
|
||||||
num_critics: int = 2
|
|
||||||
# Number of subsampled critics for training
|
|
||||||
num_subsample_critics: int | None = None
|
|
||||||
# Learning rate for the critic network
|
|
||||||
critic_lr: float = 3e-4
|
|
||||||
# Learning rate for the actor network
|
|
||||||
actor_lr: float = 3e-4
|
|
||||||
# Learning rate for the temperature parameter
|
|
||||||
temperature_lr: float = 3e-4
|
|
||||||
# Weight for the critic target update
|
|
||||||
critic_target_update_weight: float = 0.005
|
|
||||||
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
|
|
||||||
utd_ratio: int = 1
|
|
||||||
# Hidden dimension size for the state encoder
|
|
||||||
state_encoder_hidden_dim: int = 256
|
|
||||||
# Dimension of the latent space
|
|
||||||
latent_dim: int = 256
|
|
||||||
# Target entropy for the SAC algorithm
|
|
||||||
target_entropy: float | None = None
|
|
||||||
# Whether to use backup entropy for the SAC algorithm
|
|
||||||
use_backup_entropy: bool = True
|
|
||||||
# Gradient clipping norm for the SAC algorithm
|
|
||||||
grad_clip_norm: float = 40.0
|
|
||||||
|
|
||||||
# Network configuration
|
|
||||||
# Configuration for the critic network architecture
|
|
||||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
|
||||||
# Configuration for the actor network architecture
|
|
||||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
|
||||||
# Configuration for the policy parameters
|
|
||||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
|
||||||
# Configuration for the discrete critic network
|
|
||||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
|
||||||
# Configuration for actor-learner architecture
|
# Configuration for actor-learner architecture
|
||||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||||
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
|
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
|
||||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||||
|
|
||||||
# Optimizations
|
# Network architecture
|
||||||
use_torch_compile: bool = True
|
# Configuration for the actor network architecture
|
||||||
|
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||||
|
# Configuration for the policy parameters (Gaussian head)
|
||||||
|
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||||
|
# Configuration for the discrete critic network
|
||||||
|
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
# Any validation specific to SAC configuration
|
# Any validation specific to GaussianActor configuration
|
||||||
|
|
||||||
def get_optimizer_preset(self) -> MultiAdamConfig:
|
def get_optimizer_preset(self) -> MultiAdamConfig:
|
||||||
|
# Default learning rate used to satisfy the abstract ``get_optimizer_preset()``
|
||||||
|
# contract from ``PreTrainedConfig``. The actual optimizers used during RL
|
||||||
|
# training are built by ``SACAlgorithm.make_optimizers_and_scheduler()`` from
|
||||||
|
# ``SACAlgorithmConfig.{actor_lr,critic_lr,temperature_lr}`` and fully bypass
|
||||||
|
# this preset.
|
||||||
|
default_lr = 3e-4
|
||||||
return MultiAdamConfig(
|
return MultiAdamConfig(
|
||||||
weight_decay=0.0,
|
weight_decay=0.0,
|
||||||
optimizer_groups={
|
optimizer_groups={
|
||||||
"actor": {"lr": self.actor_lr},
|
"actor": {"lr": default_lr},
|
||||||
"critic": {"lr": self.critic_lr},
|
"critic": {"lr": default_lr},
|
||||||
"temperature": {"lr": self.temperature_lr},
|
"temperature": {"lr": default_lr},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -15,16 +15,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import einops
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F # noqa: N812
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||||
|
|
||||||
@@ -32,20 +27,20 @@ from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
|
|||||||
|
|
||||||
from ..pretrained import PreTrainedPolicy
|
from ..pretrained import PreTrainedPolicy
|
||||||
from ..utils import get_device_from_parameters
|
from ..utils import get_device_from_parameters
|
||||||
from .configuration_sac import SACConfig, is_image_feature
|
from .configuration_gaussian_actor import GaussianActorConfig, is_image_feature
|
||||||
|
|
||||||
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
|
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
|
||||||
|
|
||||||
|
|
||||||
class SACPolicy(
|
class GaussianActorPolicy(
|
||||||
PreTrainedPolicy,
|
PreTrainedPolicy,
|
||||||
):
|
):
|
||||||
config_class = SACConfig
|
config_class = GaussianActorConfig
|
||||||
name = "sac"
|
name = "gaussian_actor"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: SACConfig | None = None,
|
config: GaussianActorConfig | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
config.validate_features()
|
config.validate_features()
|
||||||
@@ -54,9 +49,8 @@ class SACPolicy(
|
|||||||
# Determine action dimension and initialize all components
|
# Determine action dimension and initialize all components
|
||||||
continuous_action_dim = config.output_features[ACTION].shape[0]
|
continuous_action_dim = config.output_features[ACTION].shape[0]
|
||||||
self._init_encoders()
|
self._init_encoders()
|
||||||
self._init_critics(continuous_action_dim)
|
|
||||||
self._init_actor(continuous_action_dim)
|
self._init_actor(continuous_action_dim)
|
||||||
self._init_temperature()
|
self._init_discrete_critic()
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
optim_params = {
|
optim_params = {
|
||||||
@@ -65,11 +59,7 @@ class SACPolicy(
|
|||||||
for n, p in self.actor.named_parameters()
|
for n, p in self.actor.named_parameters()
|
||||||
if not n.startswith("encoder") or not self.shared_encoder
|
if not n.startswith("encoder") or not self.shared_encoder
|
||||||
],
|
],
|
||||||
"critic": self.critic_ensemble.parameters(),
|
|
||||||
"temperature": self.log_alpha,
|
|
||||||
}
|
}
|
||||||
if self.config.num_discrete_actions is not None:
|
|
||||||
optim_params["discrete_critic"] = self.discrete_critic.parameters()
|
|
||||||
return optim_params
|
return optim_params
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
@@ -79,7 +69,9 @@ class SACPolicy(
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Predict a chunk of actions given environment observations."""
|
"""Predict a chunk of actions given environment observations."""
|
||||||
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
|
raise NotImplementedError(
|
||||||
|
"GaussianActorPolicy does not support action chunking. It returns single actions!"
|
||||||
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
@@ -92,360 +84,43 @@ class SACPolicy(
|
|||||||
actions, _, _ = self.actor(batch, observations_features)
|
actions, _, _ = self.actor(batch, observations_features)
|
||||||
|
|
||||||
if self.config.num_discrete_actions is not None:
|
if self.config.num_discrete_actions is not None:
|
||||||
discrete_action_value = self.discrete_critic(batch, observations_features)
|
if self.discrete_critic is not None:
|
||||||
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
|
discrete_action_value = self.discrete_critic(batch, observations_features)
|
||||||
|
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
|
||||||
|
else:
|
||||||
|
discrete_action = torch.ones(
|
||||||
|
(*actions.shape[:-1], 1), device=actions.device, dtype=actions.dtype
|
||||||
|
)
|
||||||
actions = torch.cat([actions, discrete_action], dim=-1)
|
actions = torch.cat([actions, discrete_action], dim=-1)
|
||||||
|
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
def critic_forward(
|
def forward(self, batch: dict[str, Tensor | dict[str, Tensor]]) -> dict[str, Tensor]:
|
||||||
self,
|
"""Actor forward pass: sample actions and return log-probabilities.
|
||||||
observations: dict[str, Tensor],
|
|
||||||
actions: Tensor,
|
|
||||||
use_target: bool = False,
|
|
||||||
observation_features: Tensor | None = None,
|
|
||||||
) -> Tensor:
|
|
||||||
"""Forward pass through a critic network ensemble
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
observations: Dictionary of observations
|
batch: A flat observation dict, or a training dict containing
|
||||||
actions: Action tensor
|
``"state"`` (observations) and optionally ``"observation_feature"``
|
||||||
use_target: If True, use target critics, otherwise use ensemble critics
|
(pre-computed encoder features).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor of Q-values from all critics
|
Dict with ``"action"``, ``"log_prob"``, and ``"action_mean"`` tensors.
|
||||||
"""
|
"""
|
||||||
|
observations = batch.get("state", batch)
|
||||||
critics = self.critic_target if use_target else self.critic_ensemble
|
observation_features = batch.get("observation_feature") if isinstance(batch, dict) else None
|
||||||
q_values = critics(observations, actions, observation_features)
|
actions, log_probs, means = self.actor(observations, observation_features)
|
||||||
return q_values
|
return {"action": actions, "log_prob": log_probs, "action_mean": means}
|
||||||
|
|
||||||
def discrete_critic_forward(
|
|
||||||
self, observations, use_target=False, observation_features=None
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Forward pass through a discrete critic network
|
|
||||||
|
|
||||||
Args:
|
|
||||||
observations: Dictionary of observations
|
|
||||||
use_target: If True, use target critics, otherwise use ensemble critics
|
|
||||||
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor of Q-values from the discrete critic network
|
|
||||||
"""
|
|
||||||
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
|
|
||||||
q_values = discrete_critic(observations, observation_features)
|
|
||||||
return q_values
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
batch: dict[str, Tensor | dict[str, Tensor]],
|
|
||||||
model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic",
|
|
||||||
) -> dict[str, Tensor]:
|
|
||||||
"""Compute the loss for the given model
|
|
||||||
|
|
||||||
Args:
|
|
||||||
batch: Dictionary containing:
|
|
||||||
- action: Action tensor
|
|
||||||
- reward: Reward tensor
|
|
||||||
- state: Observations tensor dict
|
|
||||||
- next_state: Next observations tensor dict
|
|
||||||
- done: Done mask tensor
|
|
||||||
- observation_feature: Optional pre-computed observation features
|
|
||||||
- next_observation_feature: Optional pre-computed next observation features
|
|
||||||
model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The computed loss tensor
|
|
||||||
"""
|
|
||||||
# Extract common components from batch
|
|
||||||
actions: Tensor = batch[ACTION]
|
|
||||||
observations: dict[str, Tensor] = batch["state"]
|
|
||||||
observation_features: Tensor = batch.get("observation_feature")
|
|
||||||
|
|
||||||
if model == "critic":
|
|
||||||
# Extract critic-specific components
|
|
||||||
rewards: Tensor = batch["reward"]
|
|
||||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
|
||||||
done: Tensor = batch["done"]
|
|
||||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
|
||||||
|
|
||||||
loss_critic = self.compute_loss_critic(
|
|
||||||
observations=observations,
|
|
||||||
actions=actions,
|
|
||||||
rewards=rewards,
|
|
||||||
next_observations=next_observations,
|
|
||||||
done=done,
|
|
||||||
observation_features=observation_features,
|
|
||||||
next_observation_features=next_observation_features,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"loss_critic": loss_critic}
|
|
||||||
|
|
||||||
if model == "discrete_critic" and self.config.num_discrete_actions is not None:
|
|
||||||
# Extract critic-specific components
|
|
||||||
rewards: Tensor = batch["reward"]
|
|
||||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
|
||||||
done: Tensor = batch["done"]
|
|
||||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
|
||||||
complementary_info = batch.get("complementary_info")
|
|
||||||
loss_discrete_critic = self.compute_loss_discrete_critic(
|
|
||||||
observations=observations,
|
|
||||||
actions=actions,
|
|
||||||
rewards=rewards,
|
|
||||||
next_observations=next_observations,
|
|
||||||
done=done,
|
|
||||||
observation_features=observation_features,
|
|
||||||
next_observation_features=next_observation_features,
|
|
||||||
complementary_info=complementary_info,
|
|
||||||
)
|
|
||||||
return {"loss_discrete_critic": loss_discrete_critic}
|
|
||||||
if model == "actor":
|
|
||||||
return {
|
|
||||||
"loss_actor": self.compute_loss_actor(
|
|
||||||
observations=observations,
|
|
||||||
observation_features=observation_features,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if model == "temperature":
|
|
||||||
return {
|
|
||||||
"loss_temperature": self.compute_loss_temperature(
|
|
||||||
observations=observations,
|
|
||||||
observation_features=observation_features,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
raise ValueError(f"Unknown model type: {model}")
|
|
||||||
|
|
||||||
def update_target_networks(self):
|
|
||||||
"""Update target networks with exponential moving average"""
|
|
||||||
for target_param, param in zip(
|
|
||||||
self.critic_target.parameters(),
|
|
||||||
self.critic_ensemble.parameters(),
|
|
||||||
strict=True,
|
|
||||||
):
|
|
||||||
target_param.data.copy_(
|
|
||||||
param.data * self.config.critic_target_update_weight
|
|
||||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
|
||||||
)
|
|
||||||
if self.config.num_discrete_actions is not None:
|
|
||||||
for target_param, param in zip(
|
|
||||||
self.discrete_critic_target.parameters(),
|
|
||||||
self.discrete_critic.parameters(),
|
|
||||||
strict=True,
|
|
||||||
):
|
|
||||||
target_param.data.copy_(
|
|
||||||
param.data * self.config.critic_target_update_weight
|
|
||||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def temperature(self) -> float:
|
|
||||||
"""Return the current temperature value, always in sync with log_alpha."""
|
|
||||||
return self.log_alpha.exp().item()
|
|
||||||
|
|
||||||
def compute_loss_critic(
|
|
||||||
self,
|
|
||||||
observations,
|
|
||||||
actions,
|
|
||||||
rewards,
|
|
||||||
next_observations,
|
|
||||||
done,
|
|
||||||
observation_features: Tensor | None = None,
|
|
||||||
next_observation_features: Tensor | None = None,
|
|
||||||
) -> Tensor:
|
|
||||||
with torch.no_grad():
|
|
||||||
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
|
|
||||||
|
|
||||||
# 2- compute q targets
|
|
||||||
q_targets = self.critic_forward(
|
|
||||||
observations=next_observations,
|
|
||||||
actions=next_action_preds,
|
|
||||||
use_target=True,
|
|
||||||
observation_features=next_observation_features,
|
|
||||||
)
|
|
||||||
|
|
||||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
|
||||||
# TODO: Get indices before forward pass to avoid unnecessary computation
|
|
||||||
if self.config.num_subsample_critics is not None:
|
|
||||||
indices = torch.randperm(self.config.num_critics)
|
|
||||||
indices = indices[: self.config.num_subsample_critics]
|
|
||||||
q_targets = q_targets[indices]
|
|
||||||
|
|
||||||
# critics subsample size
|
|
||||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
|
||||||
if self.config.use_backup_entropy:
|
|
||||||
min_q = min_q - (self.temperature * next_log_probs)
|
|
||||||
|
|
||||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
|
||||||
|
|
||||||
# 3- compute predicted qs
|
|
||||||
if self.config.num_discrete_actions is not None:
|
|
||||||
# NOTE: We only want to keep the continuous action part
|
|
||||||
# In the buffer we have the full action space (continuous + discrete)
|
|
||||||
# We need to split them before concatenating them in the critic forward
|
|
||||||
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
|
|
||||||
q_preds = self.critic_forward(
|
|
||||||
observations=observations,
|
|
||||||
actions=actions,
|
|
||||||
use_target=False,
|
|
||||||
observation_features=observation_features,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4- Calculate loss
|
|
||||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
|
||||||
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
|
||||||
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
|
|
||||||
critics_loss = (
|
|
||||||
F.mse_loss(
|
|
||||||
input=q_preds,
|
|
||||||
target=td_target_duplicate,
|
|
||||||
reduction="none",
|
|
||||||
).mean(dim=1)
|
|
||||||
).sum()
|
|
||||||
return critics_loss
|
|
||||||
|
|
||||||
def compute_loss_discrete_critic(
|
|
||||||
self,
|
|
||||||
observations,
|
|
||||||
actions,
|
|
||||||
rewards,
|
|
||||||
next_observations,
|
|
||||||
done,
|
|
||||||
observation_features=None,
|
|
||||||
next_observation_features=None,
|
|
||||||
complementary_info=None,
|
|
||||||
):
|
|
||||||
# NOTE: We only want to keep the discrete action part
|
|
||||||
# In the buffer we have the full action space (continuous + discrete)
|
|
||||||
# We need to split them before concatenating them in the critic forward
|
|
||||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
|
||||||
actions_discrete = torch.round(actions_discrete)
|
|
||||||
actions_discrete = actions_discrete.long()
|
|
||||||
|
|
||||||
discrete_penalties: Tensor | None = None
|
|
||||||
if complementary_info is not None:
|
|
||||||
discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty")
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
# For DQN, select actions using online network, evaluate with target network
|
|
||||||
next_discrete_qs = self.discrete_critic_forward(
|
|
||||||
next_observations, use_target=False, observation_features=next_observation_features
|
|
||||||
)
|
|
||||||
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
# Get target Q-values from target network
|
|
||||||
target_next_discrete_qs = self.discrete_critic_forward(
|
|
||||||
observations=next_observations,
|
|
||||||
use_target=True,
|
|
||||||
observation_features=next_observation_features,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use gather to select Q-values for best actions
|
|
||||||
target_next_discrete_q = torch.gather(
|
|
||||||
target_next_discrete_qs, dim=1, index=best_next_discrete_action
|
|
||||||
).squeeze(-1)
|
|
||||||
|
|
||||||
# Compute target Q-value with Bellman equation
|
|
||||||
rewards_discrete = rewards
|
|
||||||
if discrete_penalties is not None:
|
|
||||||
rewards_discrete = rewards + discrete_penalties
|
|
||||||
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
|
|
||||||
|
|
||||||
# Get predicted Q-values for current observations
|
|
||||||
predicted_discrete_qs = self.discrete_critic_forward(
|
|
||||||
observations=observations, use_target=False, observation_features=observation_features
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use gather to select Q-values for taken actions
|
|
||||||
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
|
|
||||||
|
|
||||||
# Compute MSE loss between predicted and target Q-values
|
|
||||||
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
|
|
||||||
return discrete_critic_loss
|
|
||||||
|
|
||||||
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
|
||||||
"""Compute the temperature loss"""
|
|
||||||
# calculate temperature loss
|
|
||||||
with torch.no_grad():
|
|
||||||
_, log_probs, _ = self.actor(observations, observation_features)
|
|
||||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
|
|
||||||
return temperature_loss
|
|
||||||
|
|
||||||
def compute_loss_actor(
|
|
||||||
self,
|
|
||||||
observations,
|
|
||||||
observation_features: Tensor | None = None,
|
|
||||||
) -> Tensor:
|
|
||||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
|
||||||
|
|
||||||
q_preds = self.critic_forward(
|
|
||||||
observations=observations,
|
|
||||||
actions=actions_pi,
|
|
||||||
use_target=False,
|
|
||||||
observation_features=observation_features,
|
|
||||||
)
|
|
||||||
min_q_preds = q_preds.min(dim=0)[0]
|
|
||||||
|
|
||||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
|
||||||
return actor_loss
|
|
||||||
|
|
||||||
def _init_encoders(self):
|
def _init_encoders(self):
|
||||||
"""Initialize shared or separate encoders for actor and critic."""
|
"""Initialize shared or separate encoders for actor and critic."""
|
||||||
self.shared_encoder = self.config.shared_encoder
|
self.shared_encoder = self.config.shared_encoder
|
||||||
self.encoder_critic = SACObservationEncoder(self.config)
|
self.encoder_critic = GaussianActorObservationEncoder(self.config)
|
||||||
self.encoder_actor = (
|
self.encoder_actor = (
|
||||||
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
|
self.encoder_critic if self.shared_encoder else GaussianActorObservationEncoder(self.config)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init_critics(self, continuous_action_dim):
|
|
||||||
"""Build critic ensemble, targets, and optional discrete critic."""
|
|
||||||
heads = [
|
|
||||||
CriticHead(
|
|
||||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
|
||||||
**asdict(self.config.critic_network_kwargs),
|
|
||||||
)
|
|
||||||
for _ in range(self.config.num_critics)
|
|
||||||
]
|
|
||||||
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
|
|
||||||
target_heads = [
|
|
||||||
CriticHead(
|
|
||||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
|
||||||
**asdict(self.config.critic_network_kwargs),
|
|
||||||
)
|
|
||||||
for _ in range(self.config.num_critics)
|
|
||||||
]
|
|
||||||
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
|
|
||||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
|
||||||
|
|
||||||
if self.config.use_torch_compile:
|
|
||||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
|
||||||
self.critic_target = torch.compile(self.critic_target)
|
|
||||||
|
|
||||||
if self.config.num_discrete_actions is not None:
|
|
||||||
self._init_discrete_critics()
|
|
||||||
|
|
||||||
def _init_discrete_critics(self):
|
|
||||||
"""Build discrete discrete critic ensemble and target networks."""
|
|
||||||
self.discrete_critic = DiscreteCritic(
|
|
||||||
encoder=self.encoder_critic,
|
|
||||||
input_dim=self.encoder_critic.output_dim,
|
|
||||||
output_dim=self.config.num_discrete_actions,
|
|
||||||
**asdict(self.config.discrete_critic_network_kwargs),
|
|
||||||
)
|
|
||||||
self.discrete_critic_target = DiscreteCritic(
|
|
||||||
encoder=self.encoder_critic,
|
|
||||||
input_dim=self.encoder_critic.output_dim,
|
|
||||||
output_dim=self.config.num_discrete_actions,
|
|
||||||
**asdict(self.config.discrete_critic_network_kwargs),
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: (maractingi, azouitine) Compile the discrete critic
|
|
||||||
self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict())
|
|
||||||
|
|
||||||
def _init_actor(self, continuous_action_dim):
|
def _init_actor(self, continuous_action_dim):
|
||||||
"""Initialize policy actor network and default target entropy."""
|
"""Initialize policy actor network."""
|
||||||
# NOTE: The actor select only the continuous action part
|
# NOTE: The actor select only the continuous action part
|
||||||
self.actor = Policy(
|
self.actor = Policy(
|
||||||
encoder=self.encoder_actor,
|
encoder=self.encoder_actor,
|
||||||
@@ -455,21 +130,25 @@ class SACPolicy(
|
|||||||
**asdict(self.config.policy_kwargs),
|
**asdict(self.config.policy_kwargs),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.target_entropy = self.config.target_entropy
|
def _init_discrete_critic(self) -> None:
|
||||||
if self.target_entropy is None:
|
"""Initialize discrete critic network."""
|
||||||
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
if self.config.num_discrete_actions is None:
|
||||||
self.target_entropy = -np.prod(dim) / 2
|
self.discrete_critic = None
|
||||||
|
return
|
||||||
|
|
||||||
def _init_temperature(self) -> None:
|
# TODO(Khalil): Compile the discrete critic
|
||||||
"""Set up temperature parameter (log_alpha)."""
|
self.discrete_critic = DiscreteCritic(
|
||||||
temp_init = self.config.temperature_init
|
encoder=self.encoder_critic,
|
||||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
input_dim=self.encoder_critic.output_dim,
|
||||||
|
output_dim=self.config.num_discrete_actions,
|
||||||
|
**asdict(self.config.discrete_critic_network_kwargs),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SACObservationEncoder(nn.Module):
|
class GaussianActorObservationEncoder(nn.Module):
|
||||||
"""Encode image and/or state vector observations."""
|
"""Encode image and/or state vector observations."""
|
||||||
|
|
||||||
def __init__(self, config: SACConfig) -> None:
|
def __init__(self, config: GaussianActorConfig) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self._init_image_layers()
|
self._init_image_layers()
|
||||||
@@ -677,84 +356,6 @@ class MLP(nn.Module):
|
|||||||
return self.net(x)
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
class CriticHead(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
input_dim: int,
|
|
||||||
hidden_dims: list[int],
|
|
||||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
|
||||||
activate_final: bool = False,
|
|
||||||
dropout_rate: float | None = None,
|
|
||||||
init_final: float | None = None,
|
|
||||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.net = MLP(
|
|
||||||
input_dim=input_dim,
|
|
||||||
hidden_dims=hidden_dims,
|
|
||||||
activations=activations,
|
|
||||||
activate_final=activate_final,
|
|
||||||
dropout_rate=dropout_rate,
|
|
||||||
final_activation=final_activation,
|
|
||||||
)
|
|
||||||
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
|
|
||||||
if init_final is not None:
|
|
||||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
|
||||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
|
||||||
else:
|
|
||||||
orthogonal_init()(self.output_layer.weight)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
return self.output_layer(self.net(x))
|
|
||||||
|
|
||||||
|
|
||||||
class CriticEnsemble(nn.Module):
|
|
||||||
"""
|
|
||||||
CriticEnsemble wraps multiple CriticHead modules into an ensemble.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
encoder (SACObservationEncoder): encoder for observations.
|
|
||||||
ensemble (List[CriticHead]): list of critic heads.
|
|
||||||
init_final (float | None): optional initializer scale for final layers.
|
|
||||||
|
|
||||||
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
encoder: SACObservationEncoder,
|
|
||||||
ensemble: list[CriticHead],
|
|
||||||
init_final: float | None = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.encoder = encoder
|
|
||||||
self.init_final = init_final
|
|
||||||
self.critics = nn.ModuleList(ensemble)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
observations: dict[str, torch.Tensor],
|
|
||||||
actions: torch.Tensor,
|
|
||||||
observation_features: torch.Tensor | None = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
device = get_device_from_parameters(self)
|
|
||||||
# Move each tensor in observations to device
|
|
||||||
observations = {k: v.to(device) for k, v in observations.items()}
|
|
||||||
|
|
||||||
obs_enc = self.encoder(observations, cache=observation_features)
|
|
||||||
|
|
||||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
|
||||||
|
|
||||||
# Loop through critics and collect outputs
|
|
||||||
q_values = []
|
|
||||||
for critic in self.critics:
|
|
||||||
q_values.append(critic(inputs))
|
|
||||||
|
|
||||||
# Stack outputs to match expected shape [num_critics, batch_size]
|
|
||||||
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
|
|
||||||
return q_values
|
|
||||||
|
|
||||||
|
|
||||||
class DiscreteCritic(nn.Module):
|
class DiscreteCritic(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -800,7 +401,7 @@ class DiscreteCritic(nn.Module):
|
|||||||
class Policy(nn.Module):
|
class Policy(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
encoder: SACObservationEncoder,
|
encoder: GaussianActorObservationEncoder,
|
||||||
network: nn.Module,
|
network: nn.Module,
|
||||||
action_dim: int,
|
action_dim: int,
|
||||||
std_min: float = -5,
|
std_min: float = -5,
|
||||||
@@ -811,7 +412,7 @@ class Policy(nn.Module):
|
|||||||
encoder_is_shared: bool = False,
|
encoder_is_shared: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.encoder: SACObservationEncoder = encoder
|
self.encoder: GaussianActorObservationEncoder = encoder
|
||||||
self.network = network
|
self.network = network
|
||||||
self.action_dim = action_dim
|
self.action_dim = action_dim
|
||||||
self.std_min = std_min
|
self.std_min = std_min
|
||||||
@@ -885,7 +486,7 @@ class Policy(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class DefaultImageEncoder(nn.Module):
|
class DefaultImageEncoder(nn.Module):
|
||||||
def __init__(self, config: SACConfig):
|
def __init__(self, config: GaussianActorConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
image_key = next(key for key in config.input_features if is_image_feature(key))
|
image_key = next(key for key in config.input_features if is_image_feature(key))
|
||||||
self.image_enc_layers = nn.Sequential(
|
self.image_enc_layers = nn.Sequential(
|
||||||
@@ -931,12 +532,12 @@ def freeze_image_encoder(image_encoder: nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class PretrainedImageEncoder(nn.Module):
|
class PretrainedImageEncoder(nn.Module):
|
||||||
def __init__(self, config: SACConfig):
|
def __init__(self, config: GaussianActorConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||||
|
|
||||||
def _load_pretrained_vision_encoder(self, config: SACConfig):
|
def _load_pretrained_vision_encoder(self, config: GaussianActorConfig):
|
||||||
"""Set up CNN encoder"""
|
"""Set up CNN encoder"""
|
||||||
from transformers import AutoModel
|
from transformers import AutoModel
|
||||||
|
|
||||||
@@ -32,18 +32,18 @@ from lerobot.processor import (
|
|||||||
)
|
)
|
||||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||||
|
|
||||||
from .configuration_sac import SACConfig
|
from .configuration_gaussian_actor import GaussianActorConfig
|
||||||
|
|
||||||
|
|
||||||
def make_sac_pre_post_processors(
|
def make_gaussian_actor_pre_post_processors(
|
||||||
config: SACConfig,
|
config: GaussianActorConfig,
|
||||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
Constructs pre-processor and post-processor pipelines for the SAC policy.
|
Constructs pre-processor and post-processor pipelines for the Gaussian actor policy.
|
||||||
|
|
||||||
The pre-processing pipeline prepares input data for the model by:
|
The pre-processing pipeline prepares input data for the model by:
|
||||||
1. Renaming features to match pretrained configurations.
|
1. Renaming features to match pretrained configurations.
|
||||||
@@ -56,7 +56,7 @@ def make_sac_pre_post_processors(
|
|||||||
2. Unnormalizing the output features to their original scale.
|
2. Unnormalizing the output features to their original scale.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: The configuration object for the SAC policy.
|
config: The configuration object for the tanh-Gaussian policy.
|
||||||
dataset_stats: A dictionary of statistics for normalization.
|
dataset_stats: A dictionary of statistics for normalization.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import field
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -109,7 +109,6 @@ class MultiEmbodimentActionEncoder(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FlowmatchingActionHeadConfig(PretrainedConfig):
|
class FlowmatchingActionHeadConfig(PretrainedConfig):
|
||||||
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
|
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
@@ -174,17 +173,14 @@ N_COLOR_CHANNELS = 3
|
|||||||
|
|
||||||
|
|
||||||
# config
|
# config
|
||||||
@dataclass
|
|
||||||
class GR00TN15Config(PretrainedConfig):
|
class GR00TN15Config(PretrainedConfig):
|
||||||
model_type = "gr00t_n1_5"
|
model_type = "gr00t_n1_5"
|
||||||
backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."})
|
|
||||||
|
|
||||||
action_head_cfg: dict = field(init=False, metadata={"help": "Action head configuration."})
|
backbone_cfg: dict
|
||||||
|
action_head_cfg: dict
|
||||||
action_horizon: int = field(init=False, metadata={"help": "Action horizon."})
|
action_horizon: int
|
||||||
|
action_dim: int
|
||||||
action_dim: int = field(init=False, metadata={"help": "Action dimension."})
|
compute_dtype: str = "float32"
|
||||||
compute_dtype: str = field(default="float32", metadata={"help": "Compute dtype."})
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|||||||
@@ -688,8 +688,9 @@ class DiffusionObjective(nn.Module):
|
|||||||
loss = F.mse_loss(predicted, target, reduction="none")
|
loss = F.mse_loss(predicted, target, reduction="none")
|
||||||
|
|
||||||
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
|
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
|
||||||
valid_actions = ~batch["action_is_pad"]
|
mask = ~batch["action_is_pad"].unsqueeze(-1)
|
||||||
loss = loss * valid_actions.unsqueeze(-1)
|
num_valid = mask.sum() * loss.shape[-1]
|
||||||
|
return (loss * mask).sum() / num_valid.clamp_min(1)
|
||||||
|
|
||||||
return loss.mean()
|
return loss.mean()
|
||||||
|
|
||||||
@@ -752,8 +753,9 @@ class FlowMatchingObjective(nn.Module):
|
|||||||
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
|
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
|
||||||
|
|
||||||
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
|
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
|
||||||
valid_mask = ~batch["action_is_pad"]
|
mask = ~batch["action_is_pad"].unsqueeze(-1)
|
||||||
loss = loss * valid_mask.unsqueeze(-1)
|
num_valid = mask.sum() * loss.shape[-1]
|
||||||
|
return (loss * mask).sum() / num_valid.clamp_min(1)
|
||||||
|
|
||||||
return loss.mean()
|
return loss.mean()
|
||||||
|
|
||||||
|
|||||||
@@ -444,13 +444,13 @@ class PaliGemmaWithExpertModel(
|
|||||||
if image.dtype != torch.float32:
|
if image.dtype != torch.float32:
|
||||||
image = image.to(torch.float32)
|
image = image.to(torch.float32)
|
||||||
image_outputs = self.paligemma.model.get_image_features(image)
|
image_outputs = self.paligemma.model.get_image_features(image)
|
||||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
features = image_outputs.pooler_output
|
||||||
if features.dtype != out_dtype:
|
if features.dtype != out_dtype:
|
||||||
features = features.to(out_dtype)
|
features = features.to(out_dtype)
|
||||||
return features
|
return features
|
||||||
|
|
||||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -666,8 +666,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
# Process language tokens
|
# Process language tokens
|
||||||
def lang_embed_func(lang_tokens):
|
def lang_embed_func(lang_tokens):
|
||||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
||||||
lang_emb_dim = lang_emb.shape[-1]
|
return lang_emb
|
||||||
return lang_emb * math.sqrt(lang_emb_dim)
|
|
||||||
|
|
||||||
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
|
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
|
||||||
embs.append(lang_emb)
|
embs.append(lang_emb)
|
||||||
@@ -748,16 +747,8 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
return embs, pad_masks, att_masks, adarms_cond
|
return embs, pad_masks, att_masks, adarms_cond
|
||||||
|
|
||||||
def forward(
|
def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) -> Tensor:
|
||||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
|
||||||
) -> Tensor:
|
|
||||||
"""Do a full training forward pass and compute the loss."""
|
"""Do a full training forward pass and compute the loss."""
|
||||||
if noise is None:
|
|
||||||
noise = self.sample_noise(actions.shape, actions.device)
|
|
||||||
|
|
||||||
if time is None:
|
|
||||||
time = self.sample_time(actions.shape[0], actions.device)
|
|
||||||
|
|
||||||
time_expanded = time[:, None, None]
|
time_expanded = time[:, None, None]
|
||||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||||
u_t = noise - actions
|
u_t = noise - actions
|
||||||
@@ -1292,8 +1283,11 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
state = self.prepare_state(batch)
|
state = self.prepare_state(batch)
|
||||||
actions = self.prepare_action(batch)
|
actions = self.prepare_action(batch)
|
||||||
|
|
||||||
|
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||||
|
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||||
|
|
||||||
# Compute loss
|
# Compute loss
|
||||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||||
|
|
||||||
# Truncate losses to actual action dimensions
|
# Truncate losses to actual action dimensions
|
||||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
|
|||||||
@@ -728,14 +728,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
return embs, pad_masks, att_masks, adarms_cond
|
return embs, pad_masks, att_masks, adarms_cond
|
||||||
|
|
||||||
def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
|
def forward(self, images, img_masks, tokens, masks, actions, noise, time) -> Tensor:
|
||||||
"""Do a full training forward pass and compute the loss."""
|
"""Do a full training forward pass and compute the loss."""
|
||||||
if noise is None:
|
|
||||||
noise = self.sample_noise(actions.shape, actions.device)
|
|
||||||
|
|
||||||
if time is None:
|
|
||||||
time = self.sample_time(actions.shape[0], actions.device)
|
|
||||||
|
|
||||||
time_expanded = time[:, None, None]
|
time_expanded = time[:, None, None]
|
||||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||||
u_t = noise - actions
|
u_t = noise - actions
|
||||||
@@ -1262,8 +1256,11 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
|
|
||||||
actions = self.prepare_action(batch)
|
actions = self.prepare_action(batch)
|
||||||
|
|
||||||
|
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||||
|
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||||
|
|
||||||
# Compute loss (no separate state needed for PI05)
|
# Compute loss (no separate state needed for PI05)
|
||||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
losses = self.model.forward(images, img_masks, tokens, masks, actions, noise, time)
|
||||||
|
|
||||||
# Truncate losses to actual action dimensions
|
# Truncate losses to actual action dimensions
|
||||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
|
|||||||
@@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||||
@@ -227,6 +226,7 @@ class PI0FastPaliGemma(nn.Module):
|
|||||||
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
|
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
|
||||||
if use_adarms[0]:
|
if use_adarms[0]:
|
||||||
text_config = self.paligemma.config.text_config
|
text_config = self.paligemma.config.text_config
|
||||||
|
del self.paligemma.model.language_model
|
||||||
self.paligemma.model.language_model = PiGemmaModel(text_config)
|
self.paligemma.model.language_model = PiGemmaModel(text_config)
|
||||||
|
|
||||||
self.to_bfloat16_for_selected_params(precision)
|
self.to_bfloat16_for_selected_params(precision)
|
||||||
@@ -260,13 +260,15 @@ class PI0FastPaliGemma(nn.Module):
|
|||||||
if image.dtype != torch.float32:
|
if image.dtype != torch.float32:
|
||||||
image = image.to(torch.float32)
|
image = image.to(torch.float32)
|
||||||
image_outputs = self.paligemma.model.get_image_features(image)
|
image_outputs = self.paligemma.model.get_image_features(image)
|
||||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
features = image_outputs.pooler_output
|
||||||
|
norm = 2048**0.5
|
||||||
|
features = features / norm * norm
|
||||||
if features.dtype != out_dtype:
|
if features.dtype != out_dtype:
|
||||||
features = features.to(out_dtype)
|
features = features.to(out_dtype)
|
||||||
return features
|
return features
|
||||||
|
|
||||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -416,8 +418,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
# Process language instruction tokens
|
# Process language instruction tokens
|
||||||
def lang_embed_func(tokens):
|
def lang_embed_func(tokens):
|
||||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||||
lang_emb_dim = lang_emb.shape[-1]
|
return lang_emb
|
||||||
return lang_emb * math.sqrt(lang_emb_dim)
|
|
||||||
|
|
||||||
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
||||||
embs.append(lang_emb)
|
embs.append(lang_emb)
|
||||||
@@ -431,8 +432,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
def fast_action_embed_func(fast_action_tokens):
|
def fast_action_embed_func(fast_action_tokens):
|
||||||
fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens)
|
fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens)
|
||||||
fast_emb_dim = fast_emb.shape[-1]
|
return fast_emb
|
||||||
return fast_emb * math.sqrt(fast_emb_dim)
|
|
||||||
|
|
||||||
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
|
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
|
||||||
embs.append(fast_action_emb)
|
embs.append(fast_action_emb)
|
||||||
@@ -665,7 +665,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
if t < max_decoding_steps - 1:
|
if t < max_decoding_steps - 1:
|
||||||
# embed the newly generated token
|
# embed the newly generated token
|
||||||
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
||||||
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
|
|
||||||
if prefix_embs.dtype == torch.bfloat16:
|
if prefix_embs.dtype == torch.bfloat16:
|
||||||
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
@@ -770,7 +769,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
# Embed the single previous token
|
# Embed the single previous token
|
||||||
# We use embed_language_tokens directly to avoid overhead of full prefix embedding
|
# We use embed_language_tokens directly to avoid overhead of full prefix embedding
|
||||||
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
||||||
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
|
|
||||||
if prefix_embs.dtype == torch.bfloat16:
|
if prefix_embs.dtype == torch.bfloat16:
|
||||||
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
|||||||
@@ -197,6 +197,9 @@ class PiGemmaModel(GemmaModel): # type: ignore[misc]
|
|||||||
|
|
||||||
def __init__(self, config: GemmaConfig, **kwargs):
|
def __init__(self, config: GemmaConfig, **kwargs):
|
||||||
super().__init__(config, **kwargs)
|
super().__init__(config, **kwargs)
|
||||||
|
# Free parent-allocated layers/norm before replacing to avoid ~2x peak memory.
|
||||||
|
del self.layers
|
||||||
|
del self.norm
|
||||||
# if not getattr(config, "use_adarms", False):
|
# if not getattr(config, "use_adarms", False):
|
||||||
# return
|
# return
|
||||||
cond_dim = getattr(config, "adarms_cond_dim", None)
|
cond_dim = getattr(config, "adarms_cond_dim", None)
|
||||||
@@ -328,6 +331,7 @@ class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc]
|
|||||||
|
|
||||||
def __init__(self, config: GemmaConfig, **kwargs):
|
def __init__(self, config: GemmaConfig, **kwargs):
|
||||||
super().__init__(config, **kwargs)
|
super().__init__(config, **kwargs)
|
||||||
|
del self.model
|
||||||
self.model = PiGemmaModel(config)
|
self.model = PiGemmaModel(config)
|
||||||
|
|
||||||
|
|
||||||
@@ -336,6 +340,7 @@ class PaliGemmaModelWithPiGemma(PaliGemmaModel):
|
|||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
del self.language_model
|
||||||
self.language_model = PiGemmaModel(config.text_config)
|
self.language_model = PiGemmaModel(config.text_config)
|
||||||
|
|
||||||
|
|
||||||
@@ -344,6 +349,7 @@ class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGenera
|
|||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
del self.model
|
||||||
self.model = PaliGemmaModelWithPiGemma(config)
|
self.model = PaliGemmaModelWithPiGemma(config)
|
||||||
|
|
||||||
# Make modules available through conditional class for BC
|
# Make modules available through conditional class for BC
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from .action_queue import ActionQueue
|
|||||||
from .configuration_rtc import RTCConfig
|
from .configuration_rtc import RTCConfig
|
||||||
from .latency_tracker import LatencyTracker
|
from .latency_tracker import LatencyTracker
|
||||||
from .modeling_rtc import RTCProcessor
|
from .modeling_rtc import RTCProcessor
|
||||||
|
from .relative import reanchor_relative_rtc_prefix
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ActionInterpolator",
|
"ActionInterpolator",
|
||||||
@@ -26,4 +27,5 @@ __all__ = [
|
|||||||
"LatencyTracker",
|
"LatencyTracker",
|
||||||
"RTCConfig",
|
"RTCConfig",
|
||||||
"RTCProcessor",
|
"RTCProcessor",
|
||||||
|
"reanchor_relative_rtc_prefix",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,116 +1,4 @@
|
|||||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
# Moved to lerobot.utils.action_interpolator — re-exported for backwards compatibility.
|
||||||
#
|
from lerobot.utils.action_interpolator import ActionInterpolator
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""Action interpolation for smoother robot control.
|
__all__ = ["ActionInterpolator"]
|
||||||
|
|
||||||
Provides configurable Nx control rate by interpolating between consecutive actions.
|
|
||||||
Useful with RTC and action-chunking policies to reduce jerkiness.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
|
|
||||||
class ActionInterpolator:
|
|
||||||
"""Interpolates between consecutive actions for smoother control.
|
|
||||||
|
|
||||||
When enabled with multiplier N, produces N actions per policy action
|
|
||||||
by linearly interpolating between the previous and current action.
|
|
||||||
|
|
||||||
Example with multiplier=3:
|
|
||||||
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
|
|
||||||
|
|
||||||
This effectively multiplies the control rate for smoother motion.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
|
|
||||||
|
|
||||||
# In control loop:
|
|
||||||
if interpolator.needs_new_action():
|
|
||||||
new_action = queue.get()
|
|
||||||
if new_action:
|
|
||||||
interpolator.add(new_action.cpu())
|
|
||||||
|
|
||||||
action = interpolator.get()
|
|
||||||
if action:
|
|
||||||
robot.send_action(action)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, multiplier: int = 1):
|
|
||||||
"""Initialize the interpolator.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
|
|
||||||
"""
|
|
||||||
if multiplier < 1:
|
|
||||||
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
|
|
||||||
self.multiplier = multiplier
|
|
||||||
self._prev: Tensor | None = None
|
|
||||||
self._buffer: list[Tensor] = []
|
|
||||||
self._idx = 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def enabled(self) -> bool:
|
|
||||||
"""Whether interpolation is active (multiplier > 1)."""
|
|
||||||
return self.multiplier > 1
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
"""Reset interpolation state (call between episodes)."""
|
|
||||||
self._prev = None
|
|
||||||
self._buffer = []
|
|
||||||
self._idx = 0
|
|
||||||
|
|
||||||
def needs_new_action(self) -> bool:
|
|
||||||
"""Check if a new action is needed from the queue."""
|
|
||||||
return self._idx >= len(self._buffer)
|
|
||||||
|
|
||||||
def add(self, action: Tensor) -> None:
|
|
||||||
"""Add a new action and compute interpolated sequence.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action: New action tensor from policy/queue (already on CPU).
|
|
||||||
"""
|
|
||||||
if self.multiplier > 1 and self._prev is not None:
|
|
||||||
self._buffer = []
|
|
||||||
for i in range(1, self.multiplier + 1):
|
|
||||||
t = i / self.multiplier
|
|
||||||
interp = self._prev + t * (action - self._prev)
|
|
||||||
self._buffer.append(interp)
|
|
||||||
else:
|
|
||||||
# First step: no previous action yet, so run at base FPS without interpolation.
|
|
||||||
self._buffer = [action.clone()]
|
|
||||||
self._prev = action.clone()
|
|
||||||
self._idx = 0
|
|
||||||
|
|
||||||
def get(self) -> Tensor | None:
|
|
||||||
"""Get the next interpolated action.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Next action tensor, or None if buffer is exhausted.
|
|
||||||
"""
|
|
||||||
if self._idx >= len(self._buffer):
|
|
||||||
return None
|
|
||||||
action = self._buffer[self._idx]
|
|
||||||
self._idx += 1
|
|
||||||
return action
|
|
||||||
|
|
||||||
def get_control_interval(self, fps: float) -> float:
|
|
||||||
"""Get the control interval based on interpolation multiplier.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
fps: Base frames per second.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Control interval in seconds (divided by multiplier).
|
|
||||||
"""
|
|
||||||
return 1.0 / (fps * self.multiplier)
|
|
||||||
|
|||||||
@@ -92,10 +92,10 @@ class ActionQueue:
|
|||||||
Returns:
|
Returns:
|
||||||
int: Number of unconsumed actions.
|
int: Number of unconsumed actions.
|
||||||
"""
|
"""
|
||||||
if self.queue is None:
|
with self.lock:
|
||||||
return 0
|
if self.queue is None:
|
||||||
length = len(self.queue)
|
return 0
|
||||||
return length - self.last_index
|
return len(self.queue) - self.last_index
|
||||||
|
|
||||||
def empty(self) -> bool:
|
def empty(self) -> bool:
|
||||||
"""Check if the queue is empty.
|
"""Check if the queue is empty.
|
||||||
@@ -103,11 +103,10 @@ class ActionQueue:
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if no actions remain, False otherwise.
|
bool: True if no actions remain, False otherwise.
|
||||||
"""
|
"""
|
||||||
if self.queue is None:
|
with self.lock:
|
||||||
return True
|
if self.queue is None:
|
||||||
|
return True
|
||||||
length = len(self.queue)
|
return len(self.queue) - self.last_index <= 0
|
||||||
return length - self.last_index <= 0
|
|
||||||
|
|
||||||
def get_action_index(self) -> int:
|
def get_action_index(self) -> int:
|
||||||
"""Get the current action consumption index.
|
"""Get the current action consumption index.
|
||||||
@@ -115,7 +114,8 @@ class ActionQueue:
|
|||||||
Returns:
|
Returns:
|
||||||
int: Index of the next action to be consumed.
|
int: Index of the next action to be consumed.
|
||||||
"""
|
"""
|
||||||
return self.last_index
|
with self.lock:
|
||||||
|
return self.last_index
|
||||||
|
|
||||||
def get_left_over(self) -> Tensor | None:
|
def get_left_over(self) -> Tensor | None:
|
||||||
"""Get leftover original actions for RTC prev_chunk_left_over.
|
"""Get leftover original actions for RTC prev_chunk_left_over.
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class RTCConfig:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Infrastructure
|
# Infrastructure
|
||||||
enabled: bool = False
|
enabled: bool = True
|
||||||
|
|
||||||
# Core RTC settings
|
# Core RTC settings
|
||||||
# Todo change to exp
|
# Todo change to exp
|
||||||
|
|||||||
58
src/lerobot/policies/rtc/relative.py
Normal file
58
src/lerobot/policies/rtc/relative.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Relative-action helpers for Real-Time Chunking (RTC)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.processor import (
|
||||||
|
NormalizerProcessorStep,
|
||||||
|
RelativeActionsProcessorStep,
|
||||||
|
TransitionKey,
|
||||||
|
create_transition,
|
||||||
|
to_relative_actions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def reanchor_relative_rtc_prefix(
|
||||||
|
prev_actions_absolute: torch.Tensor,
|
||||||
|
current_state: torch.Tensor,
|
||||||
|
relative_step: RelativeActionsProcessorStep,
|
||||||
|
normalizer_step: NormalizerProcessorStep | None,
|
||||||
|
policy_device: torch.device | str,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Convert absolute leftover actions into model-space for relative-action RTC policies.
|
||||||
|
|
||||||
|
When using relative actions, the RTC prefix (previous chunk's unexecuted tail)
|
||||||
|
is stored in absolute coordinates. Before feeding it back to the policy, this
|
||||||
|
helper re-expresses those actions relative to the robot's current joint state
|
||||||
|
and optionally normalizes them so the policy receives correctly scaled inputs.
|
||||||
|
"""
|
||||||
|
state = current_state.detach().cpu()
|
||||||
|
if state.dim() == 1:
|
||||||
|
state = state.unsqueeze(0)
|
||||||
|
|
||||||
|
action_cpu = prev_actions_absolute.detach().cpu()
|
||||||
|
mask = relative_step._build_mask(action_cpu.shape[-1])
|
||||||
|
relative_actions = to_relative_actions(action_cpu, state, mask)
|
||||||
|
|
||||||
|
transition = create_transition(action=relative_actions)
|
||||||
|
if normalizer_step is not None:
|
||||||
|
transition = normalizer_step(transition)
|
||||||
|
|
||||||
|
return transition[TransitionKey.ACTION].to(policy_device)
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user