mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
Compare commits
53 Commits
feat/lerob
...
codex/mode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a23ebf9d35 | ||
|
|
bfff81fd4b | ||
|
|
929400cd44 | ||
|
|
fe78f8fee9 | ||
|
|
ce9bfa754d | ||
|
|
5adad11128 | ||
|
|
b86935c64b | ||
|
|
a2f72e42f6 | ||
|
|
a515eadc96 | ||
|
|
a07f22e22c | ||
|
|
282c31cfef | ||
|
|
a147fa4439 | ||
|
|
0f1c9b0851 | ||
|
|
e699e52388 | ||
|
|
b2765b39b8 | ||
|
|
777b808c70 | ||
|
|
8d982614a6 | ||
|
|
c8df80ae91 | ||
|
|
1ac8e96575 | ||
|
|
a6dd28e8b4 | ||
|
|
1842100402 | ||
|
|
00e9defb80 | ||
|
|
b81eef43c8 | ||
|
|
d483dd4c4b | ||
|
|
a56423fa33 | ||
|
|
da7da741f1 | ||
|
|
b1e16783de | ||
|
|
a4544ffea7 | ||
|
|
dbe01b0444 | ||
|
|
e16a95a78e | ||
|
|
4137b5785d | ||
|
|
8ece10e484 | ||
|
|
ddeb216ab9 | ||
|
|
d46d67f75d | ||
|
|
b746cd3c61 | ||
|
|
6d1a5fca02 | ||
|
|
8d7099cd7d | ||
|
|
516f39685a | ||
|
|
b27e838376 | ||
|
|
40470648d1 | ||
|
|
25e5062b2c | ||
|
|
35e3b28da1 | ||
|
|
ed8a98dda6 | ||
|
|
9dc38d9993 | ||
|
|
3922f81791 | ||
|
|
28e8483297 | ||
|
|
e1b22ed1c4 | ||
|
|
f2d0f04dd0 | ||
|
|
3ea722c6c0 | ||
|
|
48660e7a7c | ||
|
|
c94fe868c9 | ||
|
|
d4f27cfb6e | ||
|
|
1a2aec1b04 |
641
.github/workflows/benchmark_tests.yml
vendored
641
.github/workflows/benchmark_tests.yml
vendored
@@ -83,10 +83,13 @@ jobs:
|
||||
cache-binary: false
|
||||
|
||||
- name: Login to Docker Hub
|
||||
if: ${{ env.DOCKERHUB_USERNAME != '' }}
|
||||
uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
env:
|
||||
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
|
||||
# Build the benchmark-specific image. The Dockerfile separates dep-install
|
||||
# from source-copy, so code-only changes skip the slow uv-sync layer
|
||||
@@ -115,7 +118,7 @@ jobs:
|
||||
bash -c "
|
||||
hf auth login --token \"\$HF_USER_TOKEN\" --add-to-git-credential 2>/dev/null || true
|
||||
lerobot-eval \
|
||||
--policy.path=pepijn223/smolvla_libero \
|
||||
--policy.path=lerobot/smolvla_libero \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial \
|
||||
--eval.batch_size=1 \
|
||||
@@ -144,7 +147,7 @@ jobs:
|
||||
--artifacts-dir /tmp/libero-artifacts \
|
||||
--env libero \
|
||||
--task libero_spatial \
|
||||
--policy pepijn223/smolvla_libero
|
||||
--policy lerobot/smolvla_libero
|
||||
|
||||
- name: Upload Libero rollout video
|
||||
if: always()
|
||||
@@ -238,10 +241,13 @@ jobs:
|
||||
cache-binary: false
|
||||
|
||||
- name: Login to Docker Hub
|
||||
if: ${{ env.DOCKERHUB_USERNAME != '' }}
|
||||
uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
env:
|
||||
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
|
||||
- name: Build MetaWorld benchmark image
|
||||
uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses]
|
||||
@@ -264,7 +270,7 @@ jobs:
|
||||
bash -c "
|
||||
hf auth login --token \"\$HF_USER_TOKEN\" --add-to-git-credential 2>/dev/null || true
|
||||
lerobot-eval \
|
||||
--policy.path=pepijn223/smolvla_metaworld \
|
||||
--policy.path=lerobot/smolvla_metaworld \
|
||||
--env.type=metaworld \
|
||||
--env.task=metaworld-push-v3 \
|
||||
--eval.batch_size=1 \
|
||||
@@ -293,7 +299,7 @@ jobs:
|
||||
--artifacts-dir /tmp/metaworld-artifacts \
|
||||
--env metaworld \
|
||||
--task metaworld-push-v3 \
|
||||
--policy pepijn223/smolvla_metaworld
|
||||
--policy lerobot/smolvla_metaworld
|
||||
|
||||
- name: Upload MetaWorld rollout video
|
||||
if: always()
|
||||
@@ -310,3 +316,630 @@ jobs:
|
||||
name: metaworld-metrics
|
||||
path: /tmp/metaworld-artifacts/metrics.json
|
||||
if-no-files-found: warn
|
||||
|
||||
# ── ROBOTWIN 2.0 ──────────────────────────────────────────────────────────
|
||||
# Isolated image: full RoboTwin 2.0 stack — SAPIEN, mplib, CuRobo,
|
||||
# pytorch3d, + simulation assets (~4 GB).
|
||||
# Build takes ~20 min on first run; subsequent runs hit the layer cache.
|
||||
# Requires an NVIDIA GPU runner with CUDA 12.1 drivers.
|
||||
robotwin-integration-test:
|
||||
name: RoboTwin 2.0 — build image + 1-episode eval
|
||||
runs-on:
|
||||
group: aws-g6-4xlarge-plus
|
||||
env:
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
ROBOTWIN_POLICY: lerobot/smolvla_robotwin
|
||||
ROBOTWIN_TASKS: beat_block_hammer,click_bell,handover_block,stack_blocks_two,click_alarmclock,open_microwave,adjust_bottle,lift_pot,stamp_seal,turn_switch
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
lfs: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Login to Docker Hub
|
||||
if: ${{ env.DOCKERHUB_USERNAME != '' }}
|
||||
uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
env:
|
||||
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
|
||||
# Build the full-install image: SAPIEN, mplib, CuRobo, pytorch3d +
|
||||
# simulation assets (~4 GB). Layer cache lives in the runner's local
|
||||
# Docker daemon — reused across re-runs on the same machine.
|
||||
- name: Build RoboTwin 2.0 benchmark image
|
||||
uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
context: .
|
||||
file: docker/Dockerfile.benchmark.robotwin
|
||||
push: false
|
||||
load: true
|
||||
tags: lerobot-benchmark-robotwin:ci
|
||||
cache-from: type=local,src=/tmp/.buildx-cache-robotwin
|
||||
cache-to: type=local,dest=/tmp/.buildx-cache-robotwin,mode=max
|
||||
|
||||
- name: Run RoboTwin 2.0 smoke eval (10 tasks, 1 episode each)
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
# Named container (no --rm) so we can docker cp artifacts out.
|
||||
docker run --name robotwin-eval --gpus all \
|
||||
--shm-size=4g \
|
||||
-e HF_HOME=/tmp/hf \
|
||||
-e HF_USER_TOKEN="${HF_USER_TOKEN}" \
|
||||
-e ROBOTWIN_POLICY="${ROBOTWIN_POLICY}" \
|
||||
-e ROBOTWIN_TASKS="${ROBOTWIN_TASKS}" \
|
||||
lerobot-benchmark-robotwin:ci \
|
||||
bash -c "
|
||||
hf auth login --token \"\$HF_USER_TOKEN\" --add-to-git-credential 2>/dev/null || true
|
||||
cd /opt/robotwin && lerobot-eval \
|
||||
--policy.path=\"\$ROBOTWIN_POLICY\" \
|
||||
--env.type=robotwin \
|
||||
--env.task=\"\$ROBOTWIN_TASKS\" \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
--policy.device=cuda \
|
||||
'--rename_map={\"observation.images.head_camera\": \"observation.images.camera1\", \"observation.images.left_camera\": \"observation.images.camera2\", \"observation.images.right_camera\": \"observation.images.camera3\"}' \
|
||||
--output_dir=/tmp/eval-artifacts
|
||||
python /lerobot/scripts/ci/extract_task_descriptions.py \
|
||||
--env robotwin \
|
||||
--task \"\$ROBOTWIN_TASKS\" \
|
||||
--output /tmp/eval-artifacts/task_descriptions.json
|
||||
"
|
||||
|
||||
- name: Copy RoboTwin artifacts from container
|
||||
if: always()
|
||||
run: |
|
||||
mkdir -p /tmp/robotwin-artifacts
|
||||
docker cp robotwin-eval:/tmp/eval-artifacts/. /tmp/robotwin-artifacts/ 2>/dev/null || true
|
||||
docker rm -f robotwin-eval || true
|
||||
|
||||
- name: Parse RoboTwin eval metrics
|
||||
if: always()
|
||||
run: |
|
||||
python3 scripts/ci/parse_eval_metrics.py \
|
||||
--artifacts-dir /tmp/robotwin-artifacts \
|
||||
--env robotwin \
|
||||
--task "${ROBOTWIN_TASKS}" \
|
||||
--policy "${ROBOTWIN_POLICY}"
|
||||
|
||||
- name: Upload RoboTwin rollout video
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: robotwin-rollout-video
|
||||
path: /tmp/robotwin-artifacts/videos/
|
||||
if-no-files-found: warn
|
||||
|
||||
- name: Upload RoboTwin eval metrics
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: robotwin-metrics
|
||||
path: /tmp/robotwin-artifacts/metrics.json
|
||||
if-no-files-found: warn
|
||||
|
||||
# ── ROBOCASA365 ──────────────────────────────────────────────────────────
|
||||
# Isolated image: robocasa + robosuite installed manually as editable
|
||||
# clones (no `lerobot[robocasa]` extra — robocasa's setup.py pins
|
||||
# `lerobot==0.3.3`, which would shadow this repo's lerobot).
|
||||
robocasa-integration-test:
|
||||
name: RoboCasa365 — build image + 1-episode eval
|
||||
runs-on:
|
||||
group: aws-g6-4xlarge-plus
|
||||
env:
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
lfs: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Login to Docker Hub
|
||||
if: ${{ env.DOCKERHUB_USERNAME != '' }}
|
||||
uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
env:
|
||||
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
|
||||
- name: Build RoboCasa365 benchmark image
|
||||
uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
context: .
|
||||
file: docker/Dockerfile.benchmark.robocasa
|
||||
push: false
|
||||
load: true
|
||||
tags: lerobot-benchmark-robocasa:ci
|
||||
|
||||
- name: Run RoboCasa365 smoke eval (10 atomic tasks, 1 episode each)
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
docker run --name robocasa-eval --gpus all \
|
||||
--shm-size=4g \
|
||||
-e HF_HOME=/tmp/hf \
|
||||
-e HF_USER_TOKEN="${HF_USER_TOKEN}" \
|
||||
-e HF_HUB_DOWNLOAD_TIMEOUT=300 \
|
||||
-e MUJOCO_GL=egl \
|
||||
lerobot-benchmark-robocasa:ci \
|
||||
bash -c "
|
||||
hf auth login --token \"\$HF_USER_TOKEN\" --add-to-git-credential 2>/dev/null || true
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/smolvla_robocasa \
|
||||
--env.type=robocasa \
|
||||
--env.task=CloseFridge,OpenCabinet,OpenDrawer,TurnOnMicrowave,TurnOffStove,CloseToasterOvenDoor,SlideDishwasherRack,TurnOnSinkFaucet,NavigateKitchen,TurnOnElectricKettle \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
--policy.device=cuda \
|
||||
'--rename_map={\"observation.images.robot0_agentview_left\": \"observation.images.camera1\", \"observation.images.robot0_eye_in_hand\": \"observation.images.camera2\", \"observation.images.robot0_agentview_right\": \"observation.images.camera3\"}' \
|
||||
--output_dir=/tmp/eval-artifacts
|
||||
python scripts/ci/extract_task_descriptions.py \
|
||||
--env robocasa \
|
||||
--task CloseFridge,OpenCabinet,OpenDrawer,TurnOnMicrowave,TurnOffStove,CloseToasterOvenDoor,SlideDishwasherRack,TurnOnSinkFaucet,NavigateKitchen,TurnOnElectricKettle \
|
||||
--output /tmp/eval-artifacts/task_descriptions.json
|
||||
"
|
||||
|
||||
- name: Copy RoboCasa365 artifacts from container
|
||||
if: always()
|
||||
run: |
|
||||
mkdir -p /tmp/robocasa-artifacts
|
||||
docker cp robocasa-eval:/tmp/eval-artifacts/. /tmp/robocasa-artifacts/ 2>/dev/null || true
|
||||
docker rm -f robocasa-eval || true
|
||||
|
||||
- name: Parse RoboCasa365 eval metrics
|
||||
if: always()
|
||||
run: |
|
||||
python3 scripts/ci/parse_eval_metrics.py \
|
||||
--artifacts-dir /tmp/robocasa-artifacts \
|
||||
--env robocasa \
|
||||
--task atomic_smoke_10 \
|
||||
--policy lerobot/smolvla_robocasa
|
||||
|
||||
- name: Upload RoboCasa365 rollout video
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
name: robocasa-rollout-video
|
||||
path: /tmp/robocasa-artifacts/videos/
|
||||
if-no-files-found: warn
|
||||
|
||||
- name: Upload RoboCasa365 eval metrics
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
name: robocasa-metrics
|
||||
path: /tmp/robocasa-artifacts/metrics.json
|
||||
if-no-files-found: warn
|
||||
|
||||
# ── ROBOCEREBRA ───────────────────────────────────────────────────────────
|
||||
# Reuses the LIBERO simulator (libero_10 suite) with RoboCerebra camera
|
||||
# defaults (image/wrist_image). The image is layered on
|
||||
# huggingface/lerobot-gpu, which already ships [libero] as part of [all].
|
||||
robocerebra-integration-test:
|
||||
name: RoboCerebra — build image + 1-episode eval
|
||||
runs-on:
|
||||
group: aws-g6-4xlarge-plus
|
||||
env:
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
lfs: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Login to Docker Hub
|
||||
if: ${{ env.DOCKERHUB_USERNAME != '' }}
|
||||
uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
env:
|
||||
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
|
||||
- name: Build RoboCerebra benchmark image
|
||||
uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
context: .
|
||||
file: docker/Dockerfile.benchmark.robocerebra
|
||||
push: false
|
||||
load: true
|
||||
tags: lerobot-benchmark-robocerebra:ci
|
||||
cache-from: type=local,src=/tmp/.buildx-cache-robocerebra
|
||||
cache-to: type=local,dest=/tmp/.buildx-cache-robocerebra,mode=max
|
||||
|
||||
- name: Run RoboCerebra smoke eval (1 episode)
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
docker run --name robocerebra-eval --gpus all \
|
||||
--shm-size=4g \
|
||||
-e HF_HOME=/tmp/hf \
|
||||
-e HF_USER_TOKEN="${HF_USER_TOKEN}" \
|
||||
-e HF_HUB_DOWNLOAD_TIMEOUT=300 \
|
||||
-e LIBERO_DATA_FOLDER=/tmp/libero_data \
|
||||
lerobot-benchmark-robocerebra:ci \
|
||||
bash -c "
|
||||
hf auth login --token \"\$HF_USER_TOKEN\" --add-to-git-credential 2>/dev/null || true
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/smolvla_robocerebra \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10 \
|
||||
--env.fps=20 \
|
||||
--env.obs_type=pixels_agent_pos \
|
||||
--env.observation_height=256 \
|
||||
--env.observation_width=256 \
|
||||
'--env.camera_name_mapping={\"agentview_image\": \"image\", \"robot0_eye_in_hand_image\": \"wrist_image\"}' \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
--policy.device=cuda \
|
||||
'--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.wrist_image\": \"observation.images.camera2\"}' \
|
||||
--policy.empty_cameras=1 \
|
||||
--output_dir=/tmp/eval-artifacts
|
||||
python scripts/ci/extract_task_descriptions.py \
|
||||
--env libero --task libero_10 \
|
||||
--output /tmp/eval-artifacts/task_descriptions.json
|
||||
"
|
||||
|
||||
- name: Copy RoboCerebra artifacts from container
|
||||
if: always()
|
||||
run: |
|
||||
mkdir -p /tmp/robocerebra-artifacts
|
||||
docker cp robocerebra-eval:/tmp/eval-artifacts/. /tmp/robocerebra-artifacts/ 2>/dev/null || true
|
||||
docker rm -f robocerebra-eval || true
|
||||
|
||||
- name: Parse RoboCerebra eval metrics
|
||||
if: always()
|
||||
run: |
|
||||
python3 scripts/ci/parse_eval_metrics.py \
|
||||
--artifacts-dir /tmp/robocerebra-artifacts \
|
||||
--env robocerebra \
|
||||
--task libero_10 \
|
||||
--policy lerobot/smolvla_robocerebra
|
||||
|
||||
- name: Upload RoboCerebra rollout video
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
name: robocerebra-rollout-video
|
||||
path: /tmp/robocerebra-artifacts/videos/
|
||||
if-no-files-found: warn
|
||||
|
||||
- name: Upload RoboCerebra eval metrics
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
name: robocerebra-metrics
|
||||
path: /tmp/robocerebra-artifacts/metrics.json
|
||||
if-no-files-found: warn
|
||||
|
||||
# ── ROBOMME ───────────────────────────────────────────────────────────────
|
||||
# Isolated image: mani-skill/SAPIEN/Vulkan chain with gymnasium and numpy
|
||||
# overrides (robomme can't be a pyproject extra due to numpy<2 pin).
|
||||
robomme-integration-test:
|
||||
name: RoboMME — build image + 1-episode eval
|
||||
runs-on:
|
||||
group: aws-g6-4xlarge-plus
|
||||
env:
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
ROBOMME_POLICY: lerobot/smolvla_robomme
|
||||
ROBOMME_TASKS: PickXtimes,BinFill,StopCube,MoveCube,InsertPeg,SwingXtimes,VideoUnmask,ButtonUnmask,PickHighlight,PatternLock
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
lfs: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Login to Docker Hub
|
||||
if: ${{ env.DOCKERHUB_USERNAME != '' }}
|
||||
uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
env:
|
||||
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
|
||||
- name: Build RoboMME benchmark image
|
||||
uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
context: .
|
||||
file: docker/Dockerfile.benchmark.robomme
|
||||
push: false
|
||||
load: true
|
||||
tags: lerobot-benchmark-robomme:ci
|
||||
|
||||
- name: Run RoboMME smoke eval (10 tasks, 1 episode each)
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
docker run --name robomme-eval --gpus all \
|
||||
--shm-size=4g \
|
||||
-e HF_HOME=/tmp/hf \
|
||||
-e HF_USER_TOKEN="${HF_USER_TOKEN}" \
|
||||
-e HF_HUB_DOWNLOAD_TIMEOUT=300 \
|
||||
-e ROBOMME_POLICY="${ROBOMME_POLICY}" \
|
||||
-e ROBOMME_TASKS="${ROBOMME_TASKS}" \
|
||||
lerobot-benchmark-robomme:ci \
|
||||
bash -c "
|
||||
hf auth login --token \"\$HF_USER_TOKEN\" --add-to-git-credential 2>/dev/null || true
|
||||
lerobot-eval \
|
||||
--policy.path=\"\$ROBOMME_POLICY\" \
|
||||
--env.type=robomme \
|
||||
--env.task=\"\$ROBOMME_TASKS\" \
|
||||
--env.dataset_split=test \
|
||||
--env.task_ids=[0] \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
--policy.device=cuda \
|
||||
'--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.wrist_image\": \"observation.images.camera2\"}' \
|
||||
--policy.empty_cameras=3 \
|
||||
--output_dir=/tmp/eval-artifacts
|
||||
python scripts/ci/extract_task_descriptions.py \
|
||||
--env robomme --task \"\$ROBOMME_TASKS\" \
|
||||
--output /tmp/eval-artifacts/task_descriptions.json
|
||||
"
|
||||
|
||||
- name: Copy RoboMME artifacts from container
|
||||
if: always()
|
||||
run: |
|
||||
mkdir -p /tmp/robomme-artifacts
|
||||
docker cp robomme-eval:/tmp/eval-artifacts/. /tmp/robomme-artifacts/ 2>/dev/null || true
|
||||
docker rm -f robomme-eval || true
|
||||
|
||||
- name: Parse RoboMME eval metrics
|
||||
if: always()
|
||||
run: |
|
||||
python3 scripts/ci/parse_eval_metrics.py \
|
||||
--artifacts-dir /tmp/robomme-artifacts \
|
||||
--env robomme \
|
||||
--task "${ROBOMME_TASKS}" \
|
||||
--policy "${ROBOMME_POLICY}"
|
||||
|
||||
- name: Upload RoboMME rollout video
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
name: robomme-rollout-video
|
||||
path: /tmp/robomme-artifacts/videos/
|
||||
if-no-files-found: warn
|
||||
|
||||
- name: Upload RoboMME eval metrics
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
name: robomme-metrics
|
||||
path: /tmp/robomme-artifacts/metrics.json
|
||||
if-no-files-found: warn
|
||||
|
||||
# ── LIBERO-plus ───────────────────────────────────────────────────────────
|
||||
# Isolated image: LIBERO-plus fork cloned into /home/user_lerobot on top of
|
||||
# huggingface/lerobot-gpu (see docker/Dockerfile.benchmark.libero_plus).
|
||||
libero-plus-integration-test:
|
||||
name: LIBERO-plus — build image + 1-episode eval
|
||||
runs-on:
|
||||
group: aws-g6-4xlarge-plus
|
||||
env:
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
LIBERO_PLUS_SUITE: libero_spatial
|
||||
LIBERO_PLUS_POLICY: lerobot/smolvla_libero_plus
|
||||
LIBERO_PLUS_TASK_IDS: "[0,100,260,500,1000,1500,2000,2400]"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
lfs: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Login to Docker Hub
|
||||
if: ${{ env.DOCKERHUB_USERNAME != '' }}
|
||||
uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
env:
|
||||
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
|
||||
- name: Build LIBERO-plus benchmark image
|
||||
uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
context: .
|
||||
file: docker/Dockerfile.benchmark.libero_plus
|
||||
push: false
|
||||
load: true
|
||||
tags: lerobot-benchmark-libero-plus:ci
|
||||
cache-from: type=local,src=/tmp/.buildx-cache-libero-plus
|
||||
cache-to: type=local,dest=/tmp/.buildx-cache-libero-plus,mode=max
|
||||
|
||||
- name: Run LIBERO-plus smoke eval (1 episode)
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
docker run --name libero-plus-eval --gpus all \
|
||||
--shm-size=4g \
|
||||
-e HF_HOME=/tmp/hf \
|
||||
-e HF_USER_TOKEN="${HF_USER_TOKEN}" \
|
||||
-e HF_HUB_DOWNLOAD_TIMEOUT=300 \
|
||||
-e LIBERO_PLUS_SUITE="${LIBERO_PLUS_SUITE}" \
|
||||
-e LIBERO_PLUS_POLICY="${LIBERO_PLUS_POLICY}" \
|
||||
-e LIBERO_PLUS_TASK_IDS="${LIBERO_PLUS_TASK_IDS}" \
|
||||
lerobot-benchmark-libero-plus:ci \
|
||||
bash -c "
|
||||
hf auth login --token \"\$HF_USER_TOKEN\" --add-to-git-credential 2>/dev/null || true
|
||||
lerobot-eval \
|
||||
--policy.path=\"\$LIBERO_PLUS_POLICY\" \
|
||||
--env.type=libero_plus \
|
||||
--env.task=\"\$LIBERO_PLUS_SUITE\" \
|
||||
--env.task_ids=\"\$LIBERO_PLUS_TASK_IDS\" \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
--policy.device=cuda \
|
||||
'--env.camera_name_mapping={\"agentview_image\": \"camera1\", \"robot0_eye_in_hand_image\": \"camera2\"}' \
|
||||
--policy.empty_cameras=1 \
|
||||
--output_dir=/tmp/eval-artifacts
|
||||
python scripts/ci/extract_task_descriptions.py \
|
||||
--env libero_plus --task \"\$LIBERO_PLUS_SUITE\" \
|
||||
--output /tmp/eval-artifacts/task_descriptions.json
|
||||
"
|
||||
|
||||
- name: Copy LIBERO-plus artifacts from container
|
||||
if: always()
|
||||
run: |
|
||||
mkdir -p /tmp/libero-plus-artifacts
|
||||
docker cp libero-plus-eval:/tmp/eval-artifacts/. /tmp/libero-plus-artifacts/ 2>/dev/null || true
|
||||
docker rm -f libero-plus-eval || true
|
||||
|
||||
- name: Parse LIBERO-plus eval metrics
|
||||
if: always()
|
||||
run: |
|
||||
python3 scripts/ci/parse_eval_metrics.py \
|
||||
--artifacts-dir /tmp/libero-plus-artifacts \
|
||||
--env libero_plus \
|
||||
--task "${LIBERO_PLUS_SUITE}" \
|
||||
--policy "${LIBERO_PLUS_POLICY}"
|
||||
|
||||
- name: Upload LIBERO-plus rollout video
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
name: libero-plus-rollout-video
|
||||
path: /tmp/libero-plus-artifacts/videos/
|
||||
if-no-files-found: warn
|
||||
|
||||
- name: Upload LIBERO-plus eval metrics
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
name: libero-plus-metrics
|
||||
path: /tmp/libero-plus-artifacts/metrics.json
|
||||
if-no-files-found: warn
|
||||
|
||||
# ── VLABENCH ─────────────────────────────────────────────────────────────
|
||||
# Isolated image: lerobot[vlabench] only (VLABench, mujoco==3.2.2, dm-control chain)
|
||||
vlabench-integration-test:
|
||||
name: VLABench — build image + 1-episode eval
|
||||
runs-on:
|
||||
group: aws-g6-4xlarge-plus
|
||||
env:
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
lfs: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
cache-binary: false
|
||||
|
||||
- name: Login to Docker Hub
|
||||
if: ${{ env.DOCKERHUB_USERNAME != '' }}
|
||||
uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
env:
|
||||
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
|
||||
- name: Build VLABench benchmark image
|
||||
uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
context: .
|
||||
file: docker/Dockerfile.benchmark.vlabench
|
||||
push: false
|
||||
load: true
|
||||
tags: lerobot-benchmark-vlabench:ci
|
||||
build-args: |
|
||||
VLABENCH_ASSETS_REPO=lerobot/vlabench-assets
|
||||
|
||||
- name: Run VLABench smoke eval (10 tasks, 1 episode each)
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
docker run --name vlabench-eval --gpus all \
|
||||
--shm-size=4g \
|
||||
-e HF_HOME=/tmp/hf \
|
||||
-e HF_USER_TOKEN="${HF_USER_TOKEN}" \
|
||||
-e HF_HUB_DOWNLOAD_TIMEOUT=300 \
|
||||
-e MUJOCO_GL=egl \
|
||||
lerobot-benchmark-vlabench:ci \
|
||||
bash -c "
|
||||
hf auth login --token \"\$HF_USER_TOKEN\" --add-to-git-credential 2>/dev/null || true
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/smolvla_vlabench \
|
||||
--env.type=vlabench \
|
||||
--env.task=select_fruit,select_toy,select_book,select_painting,select_drink,select_ingredient,select_billiards,select_poker,add_condiment,insert_flower \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
--policy.device=cuda \
|
||||
'--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.second_image\": \"observation.images.camera2\", \"observation.images.wrist_image\": \"observation.images.camera3\"}' \
|
||||
--output_dir=/tmp/eval-artifacts
|
||||
python scripts/ci/extract_task_descriptions.py \
|
||||
--env vlabench \
|
||||
--task select_fruit,select_toy,select_book,select_painting,select_drink,select_ingredient,select_billiards,select_poker,add_condiment,insert_flower \
|
||||
--output /tmp/eval-artifacts/task_descriptions.json
|
||||
"
|
||||
|
||||
- name: Copy VLABench artifacts from container
|
||||
if: always()
|
||||
run: |
|
||||
mkdir -p /tmp/vlabench-artifacts
|
||||
docker cp vlabench-eval:/tmp/eval-artifacts/. /tmp/vlabench-artifacts/ 2>/dev/null || true
|
||||
docker rm -f vlabench-eval || true
|
||||
|
||||
- name: Parse VLABench eval metrics
|
||||
if: always()
|
||||
run: |
|
||||
python3 scripts/ci/parse_eval_metrics.py \
|
||||
--artifacts-dir /tmp/vlabench-artifacts \
|
||||
--env vlabench \
|
||||
--task select_fruit,select_toy,select_book,select_painting,select_drink,select_ingredient,select_billiards,select_poker,add_condiment,insert_flower \
|
||||
--policy lerobot/smolvla_vlabench
|
||||
|
||||
- name: Upload VLABench rollout video
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
name: vlabench-rollout-video
|
||||
path: /tmp/vlabench-artifacts/videos/
|
||||
if-no-files-found: warn
|
||||
|
||||
- name: Upload VLABench eval metrics
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
name: vlabench-metrics
|
||||
path: /tmp/vlabench-artifacts/metrics.json
|
||||
if-no-files-found: warn
|
||||
|
||||
237
.github/workflows/model_profiling.yml
vendored
Normal file
237
.github/workflows/model_profiling.yml
vendored
Normal file
@@ -0,0 +1,237 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: Model Profiling
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 0 * * 0"
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- .github/workflows/model_profiling.yml
|
||||
- src/lerobot/configs/train.py
|
||||
- src/lerobot/scripts/lerobot_train.py
|
||||
- src/lerobot/utils/model_profiling.py
|
||||
- tests/test_model_profiling.py
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
git_ref:
|
||||
description: Git ref to profile when no commit SHA is provided
|
||||
required: false
|
||||
type: string
|
||||
default: main
|
||||
git_commit:
|
||||
description: Optional exact commit SHA to profile
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
policies:
|
||||
description: Optional comma-separated policy filter
|
||||
required: false
|
||||
type: string
|
||||
default: ""
|
||||
profile_mode:
|
||||
description: Torch profiler mode
|
||||
required: false
|
||||
type: choice
|
||||
options:
|
||||
- trace
|
||||
- summary
|
||||
default: trace
|
||||
publish_results:
|
||||
description: Publish results to the profiling dataset when a Hub token is available
|
||||
required: false
|
||||
type: boolean
|
||||
default: true
|
||||
results_repo:
|
||||
description: Dataset repo name or fully qualified repo id
|
||||
required: false
|
||||
type: string
|
||||
default: model-profiling-history
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.event.inputs.git_commit || github.event.inputs.git_ref || github.ref_name || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
profile-models:
|
||||
name: Weekly Model Profiling
|
||||
runs-on:
|
||||
group: aws-g6-4xlarge-plus
|
||||
env:
|
||||
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
|
||||
PROFILE_MODE: ${{ github.event_name == 'pull_request' && 'summary' || github.event.inputs.profile_mode || 'trace' }}
|
||||
POLICY_FILTER: ${{ github.event_name == 'pull_request' && 'act,diffusion,pi0,pi05,smolvla,groot,xvla,wall_x' || github.event.inputs.policies || '' }}
|
||||
RESULTS_REPO: ${{ github.event.inputs.results_repo || 'model-profiling-history' }}
|
||||
SHOULD_PUBLISH: ${{ github.event_name == 'schedule' || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish_results == 'true') }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
lfs: true
|
||||
ref: ${{ github.event.pull_request.head.sha || github.event.inputs.git_commit || github.event.inputs.git_ref || 'main' }}
|
||||
|
||||
- name: Pull GPU image
|
||||
run: docker pull huggingface/lerobot-gpu:latest
|
||||
|
||||
- name: Run model profiling
|
||||
env:
|
||||
HOST_GIT_COMMIT: ${{ github.event.pull_request.head.sha || github.event.inputs.git_commit || github.sha }}
|
||||
PROFILE_GIT_REF: ${{ github.head_ref || github.ref_name || github.event.inputs.git_ref || 'main' }}
|
||||
PROFILE_PR_NUMBER: ${{ github.event.pull_request.number || '' }}
|
||||
run: |
|
||||
set -eux
|
||||
mkdir -p profiling-results
|
||||
docker run --rm --gpus all \
|
||||
--user "$(id -u):$(id -g)" \
|
||||
--shm-size=16g \
|
||||
-e HOME=/tmp/lerobot-home \
|
||||
-e HF_HOME=/tmp/hf \
|
||||
-e HF_LEROBOT_HOME=/tmp/hf-lerobot \
|
||||
-e TORCH_HOME=/tmp/torch-home \
|
||||
-e TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor-cache \
|
||||
-e UV_PROJECT_ENVIRONMENT=/tmp/lerobot-venv \
|
||||
-e UV_CACHE_DIR=/tmp/uv-cache \
|
||||
-e UV_PYTHON_PREFERENCE=only-system \
|
||||
-e XDG_DATA_HOME=/tmp/xdg-data \
|
||||
-e XDG_CACHE_HOME=/tmp/xdg-cache \
|
||||
-e HOST_GIT_COMMIT="${HOST_GIT_COMMIT}" \
|
||||
-e PROFILE_GIT_REF="${PROFILE_GIT_REF}" \
|
||||
-e PROFILE_PR_NUMBER="${PROFILE_PR_NUMBER}" \
|
||||
-e HF_USER_TOKEN="${HF_USER_TOKEN}" \
|
||||
-e HF_TOKEN="${HF_USER_TOKEN}" \
|
||||
-e PROFILE_MODE="${PROFILE_MODE}" \
|
||||
-e POLICY_FILTER="${POLICY_FILTER}" \
|
||||
-e RESULTS_REPO="${RESULTS_REPO}" \
|
||||
-e SHOULD_PUBLISH="${SHOULD_PUBLISH}" \
|
||||
-v "${GITHUB_WORKSPACE}:/workspace" \
|
||||
-w /workspace \
|
||||
huggingface/lerobot-gpu:latest \
|
||||
bash -c '
|
||||
set -euxo pipefail
|
||||
mkdir -p "${HOME}" "${HF_HOME}" "${HF_LEROBOT_HOME}" "${TORCH_HOME}" "${UV_CACHE_DIR}" "${XDG_CACHE_HOME}" "${XDG_DATA_HOME}" "${TORCHINDUCTOR_CACHE_DIR}"
|
||||
rm -rf /tmp/lerobot-src
|
||||
cp -a /workspace/. /tmp/lerobot-src
|
||||
cd /tmp/lerobot-src
|
||||
|
||||
if [[ -n "${HF_USER_TOKEN:-}" ]]; then
|
||||
hf auth login --token "${HF_USER_TOKEN}" --add-to-git-credential 2>/dev/null || true
|
||||
fi
|
||||
|
||||
policies_to_run=()
|
||||
if [[ -n "${POLICY_FILTER}" ]]; then
|
||||
IFS="," read -ra policies_to_run <<< "${POLICY_FILTER}"
|
||||
else
|
||||
policies_to_run=(act diffusion groot multi_task_dit pi0 pi0_fast pi05 smolvla wall_x xvla)
|
||||
fi
|
||||
|
||||
policy_extras() {
|
||||
case "$1" in
|
||||
act) ;;
|
||||
diffusion) echo "diffusion" ;;
|
||||
groot) echo "groot" ;;
|
||||
multi_task_dit) echo "multi_task_dit" ;;
|
||||
pi0|pi0_fast|pi05) echo "pi" ;;
|
||||
smolvla) echo "smolvla" ;;
|
||||
wall_x) echo "wallx" ;;
|
||||
xvla) echo "xvla" ;;
|
||||
*)
|
||||
echo "Unknown profiling policy $1" >&2
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
# Policies whose dep-install may fail due to environment constraints
|
||||
# (e.g. groot requires compiling flash-attn, which needs nvcc; the CI
|
||||
# image only ships the CUDA runtime). Install failures for these are
|
||||
# logged as warnings and do not fail the job. See the TODO next to
|
||||
# `lerobot[groot]` in pyproject.toml.
|
||||
is_install_failure_tolerated() {
|
||||
case "$1" in
|
||||
groot) return 0 ;;
|
||||
*) return 1 ;;
|
||||
esac
|
||||
}
|
||||
|
||||
overall_status=0
|
||||
for raw_policy in "${policies_to_run[@]}"; do
|
||||
policy="$(echo "${raw_policy}" | xargs)"
|
||||
[[ -z "${policy}" ]] && continue
|
||||
|
||||
echo "::group::Profile ${policy}"
|
||||
|
||||
extra="$(policy_extras "${policy}")" || { overall_status=1; echo "::endgroup::"; continue; }
|
||||
|
||||
# Fresh, isolated dependency resolution per policy so that
|
||||
# incompatible extras (e.g. flash-attn for groot) never block
|
||||
# the rest of the matrix.
|
||||
sync_cmd=(uv sync --locked --extra training --extra test)
|
||||
if [[ -n "${extra}" ]]; then
|
||||
sync_cmd+=(--extra "${extra}")
|
||||
fi
|
||||
# flash-attn does not declare torch as a build-time dep, so its
|
||||
# isolated build env fails with ModuleNotFoundError. Torch is a
|
||||
# core lerobot dep and is already resolved here, so we disable
|
||||
# build isolation for flash-attn specifically.
|
||||
sync_cmd+=(--no-build-isolation-package flash-attn)
|
||||
if ! "${sync_cmd[@]}"; then
|
||||
if is_install_failure_tolerated "${policy}"; then
|
||||
echo "::warning::Dependency install failed for ${policy} (known-fragile); skipping."
|
||||
else
|
||||
echo "Dependency install failed for ${policy}; skipping." >&2
|
||||
overall_status=1
|
||||
fi
|
||||
echo "::endgroup::"
|
||||
continue
|
||||
fi
|
||||
|
||||
cmd=(
|
||||
uv run python -m lerobot.utils.model_profiling
|
||||
--output_dir=/workspace/profiling-results
|
||||
--hub_org=lerobot
|
||||
--results_repo="${RESULTS_REPO}"
|
||||
--profile_mode="${PROFILE_MODE}"
|
||||
--git_commit="${HOST_GIT_COMMIT}"
|
||||
--git_ref="${PROFILE_GIT_REF}"
|
||||
--pr_number="${PROFILE_PR_NUMBER}"
|
||||
--policies "${policy}"
|
||||
)
|
||||
if [[ "${SHOULD_PUBLISH}" == "true" && -n "${HF_USER_TOKEN:-}" ]]; then
|
||||
cmd+=(--publish)
|
||||
fi
|
||||
|
||||
if ! "${cmd[@]}"; then
|
||||
echo "Profiling failed for ${policy}." >&2
|
||||
overall_status=1
|
||||
fi
|
||||
|
||||
echo "::endgroup::"
|
||||
done
|
||||
|
||||
exit "${overall_status}"
|
||||
'
|
||||
|
||||
- name: Upload profiling artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4 # zizmor: ignore[unpinned-uses]
|
||||
with:
|
||||
name: model-profiling-results
|
||||
path: profiling-results
|
||||
if-no-files-found: warn
|
||||
84
docker/Dockerfile.benchmark.libero_plus
Normal file
84
docker/Dockerfile.benchmark.libero_plus
Normal file
@@ -0,0 +1,84 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Benchmark image for LIBERO-plus integration tests.
|
||||
# Extends the nightly GPU image (which has lerobot[all]) with the LIBERO-plus
|
||||
# fork source + its 6.4 GB perturbation assets.
|
||||
#
|
||||
# Build: docker build -f docker/Dockerfile.benchmark.libero_plus -t lerobot-benchmark-libero-plus .
|
||||
# Run: docker run --gpus all --rm lerobot-benchmark-libero-plus lerobot-eval ...
|
||||
|
||||
FROM huggingface/lerobot-gpu:latest
|
||||
ENV MUJOCO_GL=egl
|
||||
|
||||
# unzip for the 6.4 GB assets.zip; the rest are LIBERO-plus build-time extras
|
||||
# (wand / ImageMagick / fontconfig) not in the nightly base.
|
||||
USER root
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
unzip libexpat1 libfontconfig1-dev libmagickwand-dev \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
USER user_lerobot
|
||||
|
||||
# robosuite==1.4.1 is mandatory (the fork uses `single_arm_env` removed in
|
||||
# v1.5+). The rest are LIBERO-plus runtime deps pulled from its setup.py.
|
||||
# We install these explicitly instead of via the [libero_plus] extra because
|
||||
# the extra's `libero @ git+...` dep installs as a namespace package and then
|
||||
# clone and PYTHONPATH-override it below.
|
||||
RUN uv pip install --no-cache \
|
||||
"robosuite==1.4.1" \
|
||||
"bddl==1.0.1" \
|
||||
"easydict==1.13" \
|
||||
"mujoco==3.7.0" \
|
||||
"matplotlib==3.10.8" \
|
||||
"Wand==0.6.13" \
|
||||
"scikit-image==0.25.2" \
|
||||
"gym==0.26.2"
|
||||
|
||||
# Clone LIBERO-plus and make it importable as `libero`. The nightly base has
|
||||
# hf-libero (10 tasks) preinstalled via lerobot[libero]; uninstall it so
|
||||
# Python resolves `import libero` to the 2402-task LIBERO-plus module instead.
|
||||
# Pinned to the current upstream main SHA so benchmark builds stay reproducible.
|
||||
ARG LIBERO_PLUS_SHA=4976dc3
|
||||
ENV LIBERO_PLUS_ROOT=/home/user_lerobot/libero-plus/libero/libero
|
||||
RUN git clone https://github.com/sylvestf/LIBERO-plus.git /home/user_lerobot/libero-plus \
|
||||
&& git -C /home/user_lerobot/libero-plus checkout ${LIBERO_PLUS_SHA} \
|
||||
&& cd /home/user_lerobot/libero-plus && uv pip install --no-cache --no-deps -e "." \
|
||||
&& (uv pip uninstall hf-libero 2>/dev/null || true)
|
||||
ENV PYTHONPATH="/home/user_lerobot/libero-plus:${PYTHONPATH}"
|
||||
|
||||
# Perturbation textures/scenes: bddl_base_domain.py resolves XMLs via
|
||||
# DIR_PATH/../assets (package-relative, ignoring ~/.libero/config.yaml). All
|
||||
# 2402 tasks reference files that ship only in Sylvest/LIBERO-plus's
|
||||
# assets.zip (6.4 GB) under a deep author-internal prefix — extract and
|
||||
# flatten it under ${LIBERO_PLUS_ROOT}/assets.
|
||||
RUN python -c "\
|
||||
from huggingface_hub import hf_hub_download; \
|
||||
hf_hub_download(repo_id='Sylvest/LIBERO-plus', repo_type='dataset', \
|
||||
filename='assets.zip', local_dir='/tmp/libero-plus-dl')" \
|
||||
&& unzip -q /tmp/libero-plus-dl/assets.zip -d /tmp/libero-plus-dl/extract \
|
||||
&& ASSETS_DIR=$(find /tmp/libero-plus-dl/extract -type d -name assets | head -1) \
|
||||
&& mv "${ASSETS_DIR}" ${LIBERO_PLUS_ROOT}/assets \
|
||||
&& rm -rf /tmp/libero-plus-dl
|
||||
|
||||
# Point ~/.libero/config.yaml at the clone so LIBERO-plus's imports are
|
||||
# non-interactive (it calls input() when the config is missing).
|
||||
RUN mkdir -p /home/user_lerobot/.libero \
|
||||
&& printf "assets: ${LIBERO_PLUS_ROOT}/assets\nbddl_files: ${LIBERO_PLUS_ROOT}/bddl_files\ndatasets: ${LIBERO_PLUS_ROOT}/../datasets\ninit_states: ${LIBERO_PLUS_ROOT}/init_files\n" \
|
||||
> /home/user_lerobot/.libero/config.yaml
|
||||
|
||||
# Overlay the PR's source code on top of the nightly image.
|
||||
COPY --chown=user_lerobot:user_lerobot . .
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
71
docker/Dockerfile.benchmark.robocasa
Normal file
71
docker/Dockerfile.benchmark.robocasa
Normal file
@@ -0,0 +1,71 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Benchmark image for RoboCasa365 integration tests.
|
||||
# Extends the nightly GPU image (which already has all extras installed)
|
||||
# with the PR's source code and RoboCasa-specific asset setup.
|
||||
#
|
||||
# Build: docker build -f docker/Dockerfile.benchmark.robocasa -t lerobot-benchmark-robocasa .
|
||||
# Run: docker run --gpus all --rm lerobot-benchmark-robocasa lerobot-eval ...
|
||||
|
||||
FROM huggingface/lerobot-gpu:latest
|
||||
|
||||
# Install robocasa + robosuite as editable clones. pip-installing from git
|
||||
# omits data files like robocasa/models/assets/box_links/box_links_assets.json
|
||||
# (not declared in package_data), which download_kitchen_assets needs at import.
|
||||
#
|
||||
# `--no-deps` on robocasa is deliberate: its setup.py pins `lerobot==0.3.3`
|
||||
# in install_requires, which would shadow the editable lerobot baked into
|
||||
# this image. We install robocasa's actual runtime deps explicitly instead.
|
||||
# Pinned SHAs for reproducible benchmark runs. Bump when you need an
|
||||
# upstream fix; don't rely on `main`/`master` drift.
|
||||
ARG ROBOCASA_SHA=56e355ccc64389dfc1b8a61a33b9127b975ba681
|
||||
ARG ROBOSUITE_SHA=aaa8b9b214ce8e77e82926d677b4d61d55e577ab
|
||||
RUN git clone https://github.com/robocasa/robocasa.git ~/robocasa && \
|
||||
git -C ~/robocasa checkout ${ROBOCASA_SHA} && \
|
||||
git clone https://github.com/ARISE-Initiative/robosuite.git ~/robosuite && \
|
||||
git -C ~/robosuite checkout ${ROBOSUITE_SHA} && \
|
||||
uv pip install --no-cache -e ~/robocasa --no-deps && \
|
||||
uv pip install --no-cache -e ~/robosuite && \
|
||||
uv pip install --no-cache \
|
||||
"numpy==2.2.5" "numba==0.61.2" "scipy==1.15.3" "mujoco==3.3.1" \
|
||||
"pygame==2.6.1" "Pillow==12.2.0" "opencv-python==4.13.0.92" \
|
||||
"pyyaml==6.0.3" "pynput==1.8.1" "tqdm==4.67.3" "termcolor==3.3.0" \
|
||||
"imageio==2.37.3" "h5py==3.16.0" "lxml==6.0.4" "hidapi==0.14.0.post4" \
|
||||
"tianshou==0.4.10" "gymnasium==1.2.3"
|
||||
|
||||
# Set up robocasa macros and download kitchen assets. We need:
|
||||
# - tex : base environment textures
|
||||
# - tex_generative : AI-generated textures; kitchen fixture XMLs embed
|
||||
# refs to generative_textures/wall/tex*.png
|
||||
# unconditionally, so MjModel.from_xml_string fails
|
||||
# at reset time without them (even if the env is
|
||||
# constructed with generative_textures=None).
|
||||
# - fixtures_lw : lightwheel kitchen fixtures (fridge, counters...)
|
||||
# - objs_lw : lightwheel object meshes (stools, misc props)
|
||||
# We skip the objaverse/aigen object packs (~30GB combined) by pairing
|
||||
# this with --env.obj_registries=["lightwheel"] on the lerobot side.
|
||||
# The download script prompts interactively, so pipe 'y' to auto-accept.
|
||||
RUN python -m robocasa.scripts.setup_macros && \
|
||||
yes y | python -m robocasa.scripts.download_kitchen_assets \
|
||||
--type tex tex_generative fixtures_lw objs_lw
|
||||
|
||||
# Overlay the PR's source code on top of the nightly image.
|
||||
COPY --chown=user_lerobot:user_lerobot . .
|
||||
|
||||
# Re-install lerobot editably so the new source (with RoboCasaEnv registration)
|
||||
# replaces the stale package baked into the nightly image.
|
||||
RUN uv pip install --no-cache --no-deps -e .
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
43
docker/Dockerfile.benchmark.robocerebra
Normal file
43
docker/Dockerfile.benchmark.robocerebra
Normal file
@@ -0,0 +1,43 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Benchmark image for RoboCerebra integration tests.
|
||||
# RoboCerebra reuses LIBERO's simulator (libero_10 suite) with a different
|
||||
# rename_map, so this image is identical to the LIBERO benchmark image —
|
||||
# extends the nightly GPU base with LIBERO assets + the PR's source code.
|
||||
#
|
||||
# Build: docker build -f docker/Dockerfile.benchmark.robocerebra -t lerobot-benchmark-robocerebra .
|
||||
# Run: docker run --gpus all --rm lerobot-benchmark-robocerebra lerobot-eval ...
|
||||
|
||||
FROM huggingface/lerobot-gpu:latest
|
||||
|
||||
# Pre-download lerobot/libero-assets from HF Hub so nothing is fetched at
|
||||
# runtime (which times out on CI). Point the libero config at the cached path.
|
||||
# libero/libero/__init__.py calls input() when ~/.libero/config.yaml is missing,
|
||||
# so we write the config before any libero import can happen.
|
||||
RUN LIBERO_DIR=$(python -c \
|
||||
"import importlib.util, os; s=importlib.util.find_spec('libero'); \
|
||||
print(os.path.join(os.path.dirname(s.origin), 'libero'))") && \
|
||||
mkdir -p /home/user_lerobot/.libero && \
|
||||
python -c "\
|
||||
from huggingface_hub import snapshot_download; \
|
||||
snapshot_download(repo_id='lerobot/libero-assets', repo_type='dataset', \
|
||||
local_dir='/home/user_lerobot/.libero/assets')" && \
|
||||
printf "assets: /home/user_lerobot/.libero/assets\nbddl_files: ${LIBERO_DIR}/bddl_files\ndatasets: ${LIBERO_DIR}/../datasets\ninit_states: ${LIBERO_DIR}/init_files\n" \
|
||||
> /home/user_lerobot/.libero/config.yaml
|
||||
|
||||
# Overlay the PR's source code on top of the nightly image.
|
||||
COPY --chown=user_lerobot:user_lerobot . .
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
56
docker/Dockerfile.benchmark.robomme
Normal file
56
docker/Dockerfile.benchmark.robomme
Normal file
@@ -0,0 +1,56 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Benchmark image for RoboMME integration tests.
|
||||
# Extends the nightly GPU image (which has lerobot[all]) with Vulkan system
|
||||
# libs for ManiSkill/SAPIEN and the robomme extra. robomme isn't in [all]
|
||||
# because mani-skill hard-pins gymnasium==0.29.1 and numpy<2.0.0 which
|
||||
# conflict with lerobot's defaults; both are safe at runtime:
|
||||
# - gymnasium 0.29.x has the same 5-tuple step() API as 1.x (since 0.26)
|
||||
# - numpy 1.26.4 is API-compatible with lerobot's actual usage.
|
||||
#
|
||||
# Build: docker build -f docker/Dockerfile.benchmark.robomme -t lerobot-benchmark-robomme .
|
||||
# Run: docker run --gpus all --rm lerobot-benchmark-robomme lerobot-eval ...
|
||||
|
||||
FROM huggingface/lerobot-gpu:latest
|
||||
|
||||
# NVIDIA Container Toolkit: expose Vulkan driver capability for headless rendering.
|
||||
ENV NVIDIA_DRIVER_CAPABILITIES=all \
|
||||
VK_ICD_FILENAMES=/usr/share/vulkan/icd.d/nvidia_icd.json
|
||||
|
||||
# ManiSkill/SAPIEN's renderer needs Vulkan, which isn't in the base image.
|
||||
USER root
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
libvulkan1 libvulkan-dev mesa-vulkan-drivers \
|
||||
&& mkdir -p /usr/share/vulkan/icd.d \
|
||||
&& echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.3.0"}}' \
|
||||
> /usr/share/vulkan/icd.d/nvidia_icd.json \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
USER user_lerobot
|
||||
|
||||
# Install smolvla + av-dep via the PR's pyproject, then layer robomme on top
|
||||
# with gymnasium/numpy overrides. robomme isn't a pyproject extra because its
|
||||
# mani-skill pin conflicts with lerobot's base numpy>=2 (see pyproject.toml).
|
||||
COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml uv.lock README.md MANIFEST.in ./
|
||||
RUN printf 'gymnasium==0.29.1\nnumpy==1.26.4\n' > /tmp/robomme_override.txt \
|
||||
&& uv pip install --no-cache --override /tmp/robomme_override.txt \
|
||||
-e ".[smolvla,av-dep]" \
|
||||
"robomme @ git+https://github.com/RoboMME/robomme_benchmark.git@main" \
|
||||
&& python -c "import robomme; print('robomme import OK')"
|
||||
|
||||
# Overlay the PR's source code on top of the nightly image.
|
||||
COPY --chown=user_lerobot:user_lerobot . .
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
122
docker/Dockerfile.benchmark.robotwin
Normal file
122
docker/Dockerfile.benchmark.robotwin
Normal file
@@ -0,0 +1,122 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Benchmark image for RoboTwin 2.0 integration tests.
|
||||
# Extends the nightly GPU image with the RoboTwin simulator stack:
|
||||
# sapien/mplib/pytorch3d + NVlabs CuRobo + embodiments.zip + objects.zip
|
||||
# (~3.96 GB of assets; background_texture.zip ~11 GB skipped for smoke eval).
|
||||
#
|
||||
# Build: docker build -f docker/Dockerfile.benchmark.robotwin -t lerobot-benchmark-robotwin .
|
||||
# Run: docker run --gpus all --rm lerobot-benchmark-robotwin \
|
||||
# lerobot-eval --env.type=robotwin --env.task=beat_block_hammer ...
|
||||
|
||||
FROM huggingface/lerobot-gpu:latest
|
||||
|
||||
ENV NVIDIA_DRIVER_CAPABILITIES=all \
|
||||
VK_ICD_FILENAMES=/usr/share/vulkan/icd.d/nvidia_icd.json \
|
||||
ROBOTWIN_ROOT=/opt/robotwin
|
||||
|
||||
# The nightly base is CUDA -base (no compiler, no Vulkan loader). CuRobo's
|
||||
# `pip install -e .` runs nvcc, and SAPIEN renders via Vulkan — add both.
|
||||
USER root
|
||||
# Pinned upstream SHA for reproducible benchmark runs. Bump when we need
|
||||
# an upstream fix; don't rely on `main` drift.
|
||||
ARG ROBOTWIN_SHA=0aeea2d669c0f8516f4d5785f0aa33ba812c14b4
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-12-4 cuda-cudart-dev-12-4 \
|
||||
libvulkan1 vulkan-tools \
|
||||
&& mkdir -p /usr/share/vulkan/icd.d \
|
||||
&& echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.3.0"}}' \
|
||||
> /usr/share/vulkan/icd.d/nvidia_icd.json \
|
||||
&& git clone https://github.com/RoboTwin-Platform/RoboTwin.git ${ROBOTWIN_ROOT} \
|
||||
&& git -C ${ROBOTWIN_ROOT} checkout ${ROBOTWIN_SHA} \
|
||||
&& chown -R user_lerobot:user_lerobot ${ROBOTWIN_ROOT} \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
USER user_lerobot
|
||||
|
||||
# RoboTwin runtime deps (av is already in the base via [av-dep]).
|
||||
RUN uv pip install --no-cache \
|
||||
"sapien==3.0.0b1" "mplib==0.2.1" "transforms3d==0.4.2" "trimesh==4.4.3" \
|
||||
"open3d==0.19.0" "imageio==2.34.2" termcolor zarr pydantic h5py
|
||||
|
||||
# pytorch3d has no universal wheel; must be built from source (~10 min, cached).
|
||||
RUN uv pip install --no-cache --no-build-isolation \
|
||||
"git+https://github.com/facebookresearch/pytorch3d.git@stable"
|
||||
|
||||
# CuRobo — NVlabs motion generator; TORCH_CUDA_ARCH_LIST must be set or the
|
||||
# build aborts on an empty arch list. Pinned SHA for reproducibility.
|
||||
ARG CUROBO_SHA=ca941586c33b8482ed9c0e74d60f23efd64b516a
|
||||
RUN cd ${ROBOTWIN_ROOT}/envs \
|
||||
&& git clone https://github.com/NVlabs/curobo.git \
|
||||
&& git -C curobo checkout ${CUROBO_SHA} \
|
||||
&& cd curobo \
|
||||
&& TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;8.9;9.0" \
|
||||
uv pip install -e . --no-build-isolation --no-cache
|
||||
|
||||
# Upstream patches (mirror RoboTwin's script/_install.sh).
|
||||
# These patches target the exact versions pinned above; re-check when upgrading.
|
||||
# mplib==0.2.1: drop a broken `or collide` clause in planner.py.
|
||||
# Safe to remove once mplib > 0.2.1 ships with the fix upstream.
|
||||
# sapien==3.0.0b1: fix URDF loader encoding + .srdf extension check.
|
||||
# Safe to remove once sapien > 3.0.0b1 ships with the fix upstream.
|
||||
RUN python - <<'EOF'
|
||||
import pathlib, re, site
|
||||
for d in site.getsitepackages():
|
||||
p = pathlib.Path(d) / "mplib" / "planner.py"
|
||||
if p.exists():
|
||||
p.write_text(re.sub(r"\bor collide\b", "", p.read_text(), count=1))
|
||||
print(f"mplib patch applied: {p}")
|
||||
p = pathlib.Path(d) / "sapien" / "wrapper" / "urdf_loader.py"
|
||||
if p.exists():
|
||||
src = p.read_text().replace(
|
||||
"with open(srdf_path) as f:", 'with open(srdf_path, encoding="utf-8") as f:'
|
||||
).replace('"srdf"', '".srdf"')
|
||||
p.write_text(src)
|
||||
print(f"sapien patch applied: {p}")
|
||||
EOF
|
||||
|
||||
# Simulation assets from TianxingChen/RoboTwin2.0: embodiments (~220 MB) +
|
||||
# objects (~3.74 GB). background_texture (~11 GB) is intentionally skipped.
|
||||
# The dataset is public — no auth token needed.
|
||||
RUN python - <<'EOF'
|
||||
import os, pathlib, zipfile
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
assets_dir = pathlib.Path(os.environ["ROBOTWIN_ROOT"]) / "assets"
|
||||
assets_dir.mkdir(parents=True, exist_ok=True)
|
||||
for fname in ("embodiments.zip", "objects.zip"):
|
||||
local = hf_hub_download(
|
||||
repo_id="TianxingChen/RoboTwin2.0",
|
||||
repo_type="dataset",
|
||||
filename=fname,
|
||||
local_dir=str(assets_dir),
|
||||
)
|
||||
with zipfile.ZipFile(local, "r") as z:
|
||||
z.extractall(str(assets_dir))
|
||||
pathlib.Path(local).unlink()
|
||||
EOF
|
||||
|
||||
WORKDIR ${ROBOTWIN_ROOT}
|
||||
RUN python script/update_embodiment_config_path.py
|
||||
|
||||
ENV PYTHONPATH="${ROBOTWIN_ROOT}:${PYTHONPATH}"
|
||||
|
||||
# Return to the lerobot source directory (set by base image) before overlaying.
|
||||
WORKDIR /lerobot
|
||||
|
||||
# Overlay the PR's source code on top of the nightly image.
|
||||
COPY --chown=user_lerobot:user_lerobot . .
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
99
docker/Dockerfile.benchmark.vlabench
Normal file
99
docker/Dockerfile.benchmark.vlabench
Normal file
@@ -0,0 +1,99 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Benchmark image for VLABench integration tests.
|
||||
# Extends the nightly GPU image with the PR's source code and VLABench setup.
|
||||
#
|
||||
# Build: docker build -f docker/Dockerfile.benchmark.vlabench -t lerobot-benchmark-vlabench .
|
||||
# Run: docker run --gpus all --rm lerobot-benchmark-vlabench lerobot-eval ...
|
||||
|
||||
FROM huggingface/lerobot-gpu:latest
|
||||
|
||||
# Install VLABench from GitHub (not on PyPI) and pin MuJoCo/dm-control.
|
||||
# Shallow-clone without submodule recursion (nested SSH-only submodules fail in CI).
|
||||
# Editable install (-e) because VLABench/utils/ has no __init__.py, so
|
||||
# find_packages() omits it from wheels; editable mode uses the source tree directly.
|
||||
# rrt-algorithms has the same packaging issue (rrt/ dir missing __init__.py).
|
||||
# Patch: constant.py calls os.listdir on ~100 asset/obj/meshes/* dirs at import
|
||||
# time. Guard the call so missing dirs return [] instead of crashing (in case
|
||||
# the asset download is partial).
|
||||
#
|
||||
# Pinned upstream SHAs for reproducible benchmark runs. Bump when you need
|
||||
# an upstream fix; don't rely on `main`/`develop` drift.
|
||||
ARG VLABENCH_SHA=cf588fe60c0c7282174fe979f5913170cfe69017
|
||||
ARG RRT_ALGORITHMS_SHA=e51d95ee489a225220d6ae2a764c4111f6ba7d85
|
||||
RUN git clone https://github.com/OpenMOSS/VLABench.git ~/VLABench && \
|
||||
git -C ~/VLABench checkout ${VLABENCH_SHA} && \
|
||||
git clone https://github.com/motion-planning/rrt-algorithms.git ~/rrt-algorithms && \
|
||||
git -C ~/rrt-algorithms checkout ${RRT_ALGORITHMS_SHA} && \
|
||||
python3 -c "\
|
||||
import pathlib; \
|
||||
p = pathlib.Path.home() / 'VLABench/VLABench/configs/constant.py'; \
|
||||
t = p.read_text(); \
|
||||
p.write_text(t.replace( \
|
||||
'subdirs = os.listdir(xml_dir)', \
|
||||
'if not os.path.isdir(xml_dir): return []\n subdirs = os.listdir(xml_dir)'))" && \
|
||||
uv pip install --no-cache -e ~/VLABench -e ~/rrt-algorithms \
|
||||
mujoco==3.2.2 dm-control==1.0.22 \
|
||||
open3d colorlog scikit-learn openai gdown
|
||||
|
||||
# Download VLABench mesh assets. Task configs reference object meshes
|
||||
# (obj/meshes/fruit/, containers/basket/, tablewares/plates/, etc.); without
|
||||
# them the task builder picks from an empty mesh list and crashes with
|
||||
# IndexError at task-build time (random.choice([]) in config_manager.py).
|
||||
#
|
||||
# Preferred source: an HF Hub mirror. Set VLABENCH_ASSETS_REPO at build time
|
||||
# (e.g. --build-arg VLABENCH_ASSETS_REPO=lerobot/vlabench-assets) and we'll
|
||||
# snapshot_download the repo into VLABench's assets dir. This is the reliable
|
||||
# path for CI — Google Drive frequently returns HTTP 429 ("Too many users have
|
||||
# viewed or downloaded this file recently") on shared academic files.
|
||||
#
|
||||
# After download we *validate* that at least one XML exists under each
|
||||
# task-critical subtree and fail the build loudly if not. Silent-empty asset
|
||||
# dirs are the #1 cause of VLABench runtime crashes in CI, so we surface them
|
||||
# here rather than after a 10-minute eval build.
|
||||
#
|
||||
# Fallback: VLABench's own gdown-based script. Best-effort only.
|
||||
ARG VLABENCH_ASSETS_REPO=""
|
||||
RUN ASSETS_DIR="$HOME/VLABench/VLABench/assets" && \
|
||||
if [ -n "${VLABENCH_ASSETS_REPO}" ]; then \
|
||||
echo "Downloading VLABench assets from HF Hub: ${VLABENCH_ASSETS_REPO}" && \
|
||||
uv pip install --no-cache "huggingface_hub[hf_xet]>=0.26" && \
|
||||
python -c "from huggingface_hub import snapshot_download; \
|
||||
p = snapshot_download(repo_id='${VLABENCH_ASSETS_REPO}', repo_type='dataset', \
|
||||
local_dir='${ASSETS_DIR}', allow_patterns=['obj/**', 'scenes/**']); \
|
||||
print('snapshot_download returned:', p)"; \
|
||||
else \
|
||||
echo "No VLABENCH_ASSETS_REPO set — falling back to gdown" && \
|
||||
python ~/VLABench/scripts/download_assets.py --choice all; \
|
||||
fi && \
|
||||
python -c "\
|
||||
from pathlib import Path; \
|
||||
import sys; \
|
||||
root = Path('${ASSETS_DIR}'); \
|
||||
checks = ['obj/meshes/tablewares/plates', 'obj/meshes/containers/basket', 'obj/meshes/fruit', 'obj/meshes/containers/tray']; \
|
||||
failed = []; \
|
||||
print(f'Validating VLABench assets under {root}'); \
|
||||
[print(f' {c}: {len(list((root/c).rglob(\"*.xml\")))} XMLs') for c in checks]; \
|
||||
[failed.append(c) for c in checks if not any((root/c).rglob('*.xml'))]; \
|
||||
sys.exit(f'Empty asset dirs (no *.xml): {failed}') if failed else print('All asset dirs populated.')"
|
||||
|
||||
# Overlay the PR's source code on top of the nightly image.
|
||||
COPY --chown=user_lerobot:user_lerobot . .
|
||||
|
||||
# Re-install lerobot editably so the new source (with VLABenchEnv registration
|
||||
# and updated obs handling) replaces the stale package baked into the nightly image.
|
||||
RUN uv pip install --no-cache --no-deps -e .
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -61,8 +61,6 @@
|
||||
title: SARM
|
||||
title: "Reward Models"
|
||||
- sections:
|
||||
- local: inference
|
||||
title: Policy Deployment (lerobot-rollout)
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
- local: rtc
|
||||
@@ -79,10 +77,22 @@
|
||||
title: Adding a New Benchmark
|
||||
- local: libero
|
||||
title: LIBERO
|
||||
- local: libero_plus
|
||||
title: LIBERO-plus
|
||||
- local: metaworld
|
||||
title: Meta-World
|
||||
- local: robotwin
|
||||
title: RoboTwin 2.0
|
||||
- local: robocasa
|
||||
title: RoboCasa365
|
||||
- local: robocerebra
|
||||
title: RoboCerebra
|
||||
- local: robomme
|
||||
title: RoboMME
|
||||
- local: envhub_isaaclab_arena
|
||||
title: NVIDIA IsaacLab Arena Environments
|
||||
- local: vlabench
|
||||
title: VLABench
|
||||
title: "Benchmarks"
|
||||
- sections:
|
||||
- local: introduction_processors
|
||||
|
||||
@@ -50,30 +50,30 @@ This process can be repeated iteratively: deploy, collect, fine-tune, repeat. Ea
|
||||
|
||||
### Teleoperator Requirements
|
||||
|
||||
The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with active motors** that can:
|
||||
The `examples/hil` HIL scripts require **teleoperators with active motors** that can:
|
||||
|
||||
- Enable/disable torque programmatically
|
||||
- Move to target positions (to mirror the robot state when pausing)
|
||||
|
||||
**Compatible teleoperators:**
|
||||
**Compatible teleoperators in the current `examples/hil` scripts:**
|
||||
|
||||
- `openarm_mini` - OpenArm Mini
|
||||
- `so_leader` - SO100 / SO101 leader arm
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
|
||||
> The provided `examples/hil` commands default to `bi_openarm_follower` + `openarm_mini`.
|
||||
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
|
||||
|
||||
---
|
||||
|
||||
## Script
|
||||
|
||||
Use `lerobot-rollout` with `--strategy.type=dagger` for HIL data collection. Select the inference backend with `--inference.type=sync|rtc`:
|
||||
A single script handles both synchronous and RTC-based inference. Toggle RTC with `--rtc.enabled=true`:
|
||||
|
||||
| Mode | Flag | Models |
|
||||
| ------------------------ | ---------------------- | --------------------- |
|
||||
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
|
||||
| Real-Time Chunking (RTC) | `--inference.type=rtc` | Pi0, Pi0.5, SmolVLA |
|
||||
| Mode | Flag | Models |
|
||||
| ------------------------ | -------------------- | --------------------- |
|
||||
| Standard (default) | _(no flag needed)_ | ACT, Diffusion Policy |
|
||||
| Real-Time Chunking (RTC) | `--rtc.enabled=true` | Pi0, Pi0.5, SmolVLA |
|
||||
|
||||
---
|
||||
|
||||
@@ -97,7 +97,7 @@ python src/lerobot/scripts/lerobot_train.py \
|
||||
**Standard inference (ACT, Diffusion Policy):**
|
||||
|
||||
```bash
|
||||
lerobot-rollout --strategy.type=dagger \
|
||||
python examples/hil/hil_data_collection.py \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--robot.left_arm_config.port=can1 \
|
||||
--robot.left_arm_config.side=left \
|
||||
@@ -111,7 +111,8 @@ lerobot-rollout --strategy.type=dagger \
|
||||
--dataset.repo_id=your-username/hil-dataset \
|
||||
--dataset.single_task="Fold the T-shirt properly" \
|
||||
--dataset.fps=30 \
|
||||
--strategy.num_episodes=50 \
|
||||
--dataset.episode_time_s=1000 \
|
||||
--dataset.num_episodes=50 \
|
||||
--interpolation_multiplier=2
|
||||
```
|
||||
|
||||
@@ -120,11 +121,11 @@ lerobot-rollout --strategy.type=dagger \
|
||||
For models with high inference latency, enable RTC for smooth execution:
|
||||
|
||||
```bash
|
||||
lerobot-rollout --strategy.type=dagger \
|
||||
--inference.type=rtc \
|
||||
--inference.rtc.execution_horizon=20 \
|
||||
--inference.rtc.max_guidance_weight=5.0 \
|
||||
--inference.rtc.prefix_attention_schedule=LINEAR \
|
||||
python examples/hil/hil_data_collection.py \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--rtc.max_guidance_weight=5.0 \
|
||||
--rtc.prefix_attention_schedule=LINEAR \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--robot.left_arm_config.port=can1 \
|
||||
--robot.left_arm_config.side=left \
|
||||
@@ -138,7 +139,8 @@ lerobot-rollout --strategy.type=dagger \
|
||||
--dataset.repo_id=your-username/hil-rtc-dataset \
|
||||
--dataset.single_task="Fold the T-shirt properly" \
|
||||
--dataset.fps=30 \
|
||||
--strategy.num_episodes=50 \
|
||||
--dataset.episode_time_s=1000 \
|
||||
--dataset.num_episodes=50 \
|
||||
--interpolation_multiplier=3
|
||||
```
|
||||
|
||||
@@ -233,7 +235,7 @@ This HIL data collection approach builds on ideas from interactive imitation lea
|
||||
|
||||
- **HG-DAgger** (Kelly et al., 2019) made this practical for robotics: a human expert monitors the robot and only intervenes when needed, rather than labeling every state. The gating between autonomous and human control is exactly the pause → takeover → return-to-policy loop used in the scripts here.
|
||||
|
||||
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the DAgger strategy in `lerobot-rollout`.
|
||||
- **RaC** (Hu et al., 2025) scales this loop to long-horizon tasks by explicitly decomposing interventions into **recovery** (teleoperating back to a good state) and **correction** (demonstrating the right behavior from there). This decomposition is the protocol followed by the HIL scripts in `examples/hil`.
|
||||
|
||||
- **π0.6/RECAP** (Physical Intelligence, 2025) applies the same iterative collect-and-finetune loop at scale with VLA models, showing that even large pretrained policies benefit substantially from targeted human corrections on their own failure modes. π0.6 is trained using RECAP.
|
||||
|
||||
|
||||
@@ -509,42 +509,121 @@ hf upload ${HF_USER}/act_so101_test${CKPT} \
|
||||
|
||||
## Run inference and evaluate your policy
|
||||
|
||||
Use `lerobot-rollout` to deploy a trained policy on your robot. You can choose different strategies depending on your needs:
|
||||
You can use the `record` script from [`lerobot-record`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
|
||||
|
||||
<hfoptions id="eval">
|
||||
<hfoption id="Base mode (no recording)">
|
||||
<hfoption id="Command">
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM1 \
|
||||
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Put lego brick into the transparent box" \
|
||||
--duration=60
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="Sentry mode (with recording)">
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=5 \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM1 \
|
||||
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \
|
||||
--robot.id=my_awesome_follower_arm \
|
||||
--display_data=false \
|
||||
--dataset.repo_id=${HF_USER}/eval_so100 \
|
||||
--dataset.single_task="Put lego brick into the transparent box" \
|
||||
--duration=600
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/ttyACM0 \
|
||||
# --teleop.id=my_awesome_leader_arm \
|
||||
--policy.path=${HF_USER}/my_policy
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from lerobot.policies.act import ACTPolicy
|
||||
from lerobot.policies import make_pre_post_processors
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
EPISODE_TIME_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
||||
|
||||
# Create the robot configuration
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config
|
||||
)
|
||||
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Initialize the policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
_, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording")
|
||||
|
||||
# Connect the robot
|
||||
robot.connect()
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
)
|
||||
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Run the policy inference loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
)
|
||||
|
||||
dataset.save_episode()
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
dataset.push_to_hub()
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
The `--strategy.type` flag selects the execution mode:
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
|
||||
- `base`: Autonomous rollout with no data recording (useful for quick evaluation)
|
||||
- `sentry`: Continuous recording with auto-upload (useful for large-scale evaluation)
|
||||
- `highlight`: Ring buffer recording with keystroke save (useful for capturing interesting events)
|
||||
- `dagger`: Human-in-the-loop data collection (see [HIL Data Collection](./hil_data_collection))
|
||||
|
||||
All strategies support `--inference.type=rtc` for smooth execution with slow VLA models (Pi0, Pi0.5, SmolVLA).
|
||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so101_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so101_test`).
|
||||
|
||||
@@ -1,261 +0,0 @@
|
||||
# Policy Deployment (lerobot-rollout)
|
||||
|
||||
`lerobot-rollout` is the single CLI for deploying trained policies on real robots. It supports multiple execution strategies and inference backends, from quick evaluation to continuous recording and human-in-the-loop data collection.
|
||||
|
||||
## Quick Start
|
||||
|
||||
No extra dependencies are needed beyond your robot and policy extras.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=lerobot/act_koch_real \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--task="pick up cube" \
|
||||
--duration=30
|
||||
```
|
||||
|
||||
This runs the policy for 30 seconds with no recording.
|
||||
|
||||
---
|
||||
|
||||
## Strategies
|
||||
|
||||
Select a strategy with `--strategy.type=<name>`. Each strategy defines a different control loop with its own recording and interaction semantics.
|
||||
|
||||
### Base (`--strategy.type=base`)
|
||||
|
||||
Autonomous policy execution with no data recording. Use this for quick evaluation, demos, or when you only need to observe the robot.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Put lego brick into the box" \
|
||||
--duration=60
|
||||
```
|
||||
|
||||
| Flag | Description |
|
||||
| ---------------- | ------------------------------------------------------ |
|
||||
| `--duration` | Run time in seconds (0 = infinite) |
|
||||
| `--task` | Task description passed to the policy |
|
||||
| `--display_data` | Stream observations/actions to Rerun for visualization |
|
||||
|
||||
### Sentry (`--strategy.type=sentry`)
|
||||
|
||||
Continuous autonomous recording with periodic upload to the Hugging Face Hub. Episode boundaries are auto-computed from camera resolution and FPS so each saved episode produces a complete video file, keeping uploads efficient.
|
||||
|
||||
Policy state (hidden state, RTC queue) persists across episode boundaries: the robot does not reset between episodes.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=5 \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--dataset.repo_id=${HF_USER}/eval_data \
|
||||
--dataset.single_task="Put lego brick into the box" \
|
||||
--duration=3600
|
||||
```
|
||||
|
||||
| Flag | Description |
|
||||
| -------------------------------------- | ----------------------------------------------------------- |
|
||||
| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) |
|
||||
| `--strategy.target_video_file_size_mb` | Target video file size for episode rotation (default: auto) |
|
||||
| `--dataset.repo_id` | **Required.** Hub repository for the recorded dataset |
|
||||
| `--dataset.push_to_hub` | Whether to push to Hub on teardown (default: true) |
|
||||
|
||||
### Highlight (`--strategy.type=highlight`)
|
||||
|
||||
Autonomous rollout with on-demand recording via a memory-bounded ring buffer. The robot runs continuously while the buffer captures the last N seconds of telemetry. Press the save key to flush the buffer and start live recording; press it again to save the episode.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=highlight \
|
||||
--strategy.ring_buffer_seconds=30 \
|
||||
--strategy.save_key=s \
|
||||
--strategy.push_key=h \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=koch_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--dataset.repo_id=${HF_USER}/highlight_data \
|
||||
--dataset.single_task="Pick up the red cube"
|
||||
```
|
||||
|
||||
**Keyboard controls:**
|
||||
|
||||
| Key | Action |
|
||||
| ------------------ | -------------------------------------------------------- |
|
||||
| `s` (configurable) | Start recording (flushes buffer) / stop and save episode |
|
||||
| `h` (configurable) | Push dataset to Hub |
|
||||
| `ESC` | Stop the session |
|
||||
|
||||
| Flag | Description |
|
||||
| -------------------------------------- | ---------------------------------------------- |
|
||||
| `--strategy.ring_buffer_seconds` | Duration of buffered telemetry (default: 30) |
|
||||
| `--strategy.ring_buffer_max_memory_mb` | Memory cap for the ring buffer (default: 2048) |
|
||||
| `--strategy.save_key` | Key to toggle recording (default: `s`) |
|
||||
| `--strategy.push_key` | Key to push to Hub (default: `h`) |
|
||||
|
||||
### DAgger (`--strategy.type=dagger`)
|
||||
|
||||
Human-in-the-loop data collection. Alternates between autonomous policy execution and human intervention via a teleoperator. Intervention frames are tagged with `intervention=True`. Requires a teleoperator (`--teleop.type`).
|
||||
|
||||
See the [Human-In-the-Loop Data Collection](./hil_data_collection) guide for a detailed walkthrough.
|
||||
|
||||
**Corrections-only mode** (default): Only human correction windows are recorded. Each correction becomes one episode.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=dagger \
|
||||
--strategy.num_episodes=20 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--teleop.type=openarm_mini \
|
||||
--dataset.repo_id=${HF_USER}/hil_data \
|
||||
--dataset.single_task="Fold the T-shirt"
|
||||
```
|
||||
|
||||
**Continuous recording mode** (`--strategy.record_autonomous=true`): Both autonomous and correction frames are recorded with time-based episode rotation (same as Sentry).
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=dagger \
|
||||
--strategy.record_autonomous=true \
|
||||
--strategy.num_episodes=50 \
|
||||
--policy.path=${HF_USER}/my_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/ttyACM1 \
|
||||
--dataset.repo_id=${HF_USER}/dagger_data \
|
||||
--dataset.single_task="Grasp the block"
|
||||
```
|
||||
|
||||
**Keyboard controls** (default input device):
|
||||
|
||||
| Key | Action |
|
||||
| ------- | ------------------------------------------- |
|
||||
| `Space` | Pause / resume policy execution |
|
||||
| `Tab` | Start / stop human correction |
|
||||
| `Enter` | Push dataset to Hub (corrections-only mode) |
|
||||
| `ESC` | Stop the session |
|
||||
|
||||
Foot pedal input is also supported via `--strategy.input_device=pedal`. Configure pedal codes with `--strategy.pedal.*` flags.
|
||||
|
||||
| Flag | Description |
|
||||
| ------------------------------------ | ------------------------------------------------------- |
|
||||
| `--strategy.num_episodes` | Number of correction episodes to record (default: 10) |
|
||||
| `--strategy.record_autonomous` | Record autonomous frames too (default: false) |
|
||||
| `--strategy.upload_every_n_episodes` | Push to Hub every N episodes (default: 5) |
|
||||
| `--strategy.input_device` | Input device: `keyboard` or `pedal` (default: keyboard) |
|
||||
| `--teleop.type` | **Required.** Teleoperator type |
|
||||
|
||||
---
|
||||
|
||||
## Inference Backends
|
||||
|
||||
Select a backend with `--inference.type=<name>`. All strategies work with both backends.
|
||||
|
||||
### Sync (default)
|
||||
|
||||
One policy call per control tick. The main loop blocks until the action is computed.
|
||||
|
||||
Works with all policies. No extra flags needed.
|
||||
|
||||
### Real-Time Chunking (`--inference.type=rtc`)
|
||||
|
||||
A background thread produces action chunks asynchronously. The main control loop polls for the next ready action while the policy computes the next chunk in parallel.
|
||||
|
||||
Use RTC with large, slow VLA models (Pi0, Pi0.5, SmolVLA) for smooth, continuous motion despite high inference latency.
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--inference.type=rtc \
|
||||
--inference.rtc.execution_horizon=10 \
|
||||
--inference.rtc.max_guidance_weight=10.0 \
|
||||
--policy.path=${HF_USER}/pi0_policy \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/ttyACM0 \
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Pick up the cube" \
|
||||
--duration=60 \
|
||||
--device=cuda
|
||||
```
|
||||
|
||||
| Flag | Description |
|
||||
| ------------------------------------------- | -------------------------------------------------------------- |
|
||||
| `--inference.rtc.execution_horizon` | Steps to blend with previous chunk (default: varies by policy) |
|
||||
| `--inference.rtc.max_guidance_weight` | Consistency enforcement strength (default: varies by policy) |
|
||||
| `--inference.rtc.prefix_attention_schedule` | Blend schedule: `LINEAR`, `EXP`, `ONES`, `ZEROS` |
|
||||
| `--inference.queue_threshold` | Max queue size before backpressure (default: 30) |
|
||||
|
||||
See the [Real-Time Chunking](./rtc) guide for details on tuning RTC parameters.
|
||||
|
||||
---
|
||||
|
||||
## Common Flags
|
||||
|
||||
| Flag | Description | Default |
|
||||
| --------------------------------- | ----------------------------------------------------------------- | ------- |
|
||||
| `--policy.path` | **Required.** HF Hub model ID or local checkpoint path | -- |
|
||||
| `--robot.type` | **Required.** Robot type (e.g. `so100_follower`, `koch_follower`) | -- |
|
||||
| `--robot.port` | Serial port for the robot | -- |
|
||||
| `--robot.cameras` | Camera configuration (JSON dict) | -- |
|
||||
| `--fps` | Control loop frequency | 30 |
|
||||
| `--duration` | Run time in seconds (0 = infinite) | 0 |
|
||||
| `--device` | Torch device (`cpu`, `cuda`, `mps`) | auto |
|
||||
| `--task` | Task description (used when no dataset is provided) | -- |
|
||||
| `--display_data` | Stream telemetry to Rerun visualization | false |
|
||||
| `--display_ip` / `--display_port` | Remote Rerun server address | -- |
|
||||
| `--interpolation_multiplier` | Action interpolation factor | 1 |
|
||||
| `--use_torch_compile` | Enable `torch.compile` for inference | false |
|
||||
| `--resume` | Resume a previous recording session | false |
|
||||
| `--play_sounds` | Vocal synthesis for events | true |
|
||||
|
||||
---
|
||||
|
||||
## Programmatic Usage
|
||||
|
||||
For custom deployments (e.g. with kinematics processors), use the rollout module API directly:
|
||||
|
||||
```python
|
||||
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||
from lerobot.rollout.inference import SyncInferenceConfig
|
||||
from lerobot.rollout.strategies import BaseStrategy
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
|
||||
cfg = RolloutConfig(
|
||||
robot=my_robot_config,
|
||||
policy=my_policy_config,
|
||||
strategy=BaseStrategyConfig(),
|
||||
inference=SyncInferenceConfig(),
|
||||
fps=30,
|
||||
duration=60,
|
||||
task="my task",
|
||||
)
|
||||
|
||||
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||
ctx = build_rollout_context(
|
||||
cfg,
|
||||
signal_handler.shutdown_event,
|
||||
robot_action_processor=my_custom_action_processor, # optional
|
||||
robot_observation_processor=my_custom_obs_processor, # optional
|
||||
)
|
||||
|
||||
strategy = BaseStrategy(cfg.strategy)
|
||||
try:
|
||||
strategy.setup(ctx)
|
||||
strategy.run(ctx)
|
||||
finally:
|
||||
strategy.teardown(ctx)
|
||||
```
|
||||
|
||||
See `examples/so100_to_so100_EE/rollout.py` and `examples/phone_to_so100/rollout.py` for full examples with kinematics processors.
|
||||
188
docs/source/libero_plus.mdx
Normal file
188
docs/source/libero_plus.mdx
Normal file
@@ -0,0 +1,188 @@
|
||||
# LIBERO-plus
|
||||
|
||||
LIBERO-plus is a **robustness benchmark** for Vision-Language-Action (VLA) models built on top of [LIBERO](./libero). It systematically stress-tests policies by applying **seven independent perturbation dimensions** to the original LIBERO task set, exposing failure modes that standard benchmarks miss.
|
||||
|
||||
- Paper: [In-depth Robustness Analysis of Vision-Language-Action Models](https://arxiv.org/abs/2510.13626)
|
||||
- GitHub: [sylvestf/LIBERO-plus](https://github.com/sylvestf/LIBERO-plus)
|
||||
- Dataset: [lerobot/libero_plus](https://huggingface.co/datasets/lerobot/libero_plus)
|
||||
|
||||

|
||||
|
||||
## Perturbation dimensions
|
||||
|
||||
LIBERO-plus creates ~10 000 task variants by perturbing each original LIBERO task along these axes:
|
||||
|
||||
| Dimension | What changes |
|
||||
| --------------------- | ----------------------------------------------------- |
|
||||
| Objects layout | Target position, presence of confounding objects |
|
||||
| Camera viewpoints | Camera position, orientation, field-of-view |
|
||||
| Robot initial states | Manipulator start pose |
|
||||
| Language instructions | LLM-rewritten task description (paraphrase / synonym) |
|
||||
| Light conditions | Intensity, direction, color, shadow |
|
||||
| Background textures | Scene surface and object appearance |
|
||||
| Sensor noise | Photometric distortions and image degradation |
|
||||
|
||||
## Available task suites
|
||||
|
||||
LIBERO-plus covers the same five suites as LIBERO:
|
||||
|
||||
| Suite | CLI name | Tasks | Max steps | Description |
|
||||
| -------------- | ---------------- | ----- | --------- | -------------------------------------------------- |
|
||||
| LIBERO-Spatial | `libero_spatial` | 10 | 280 | Tasks requiring reasoning about spatial relations |
|
||||
| LIBERO-Object | `libero_object` | 10 | 280 | Tasks centered on manipulating different objects |
|
||||
| LIBERO-Goal | `libero_goal` | 10 | 300 | Goal-conditioned tasks with changing targets |
|
||||
| LIBERO-90 | `libero_90` | 90 | 400 | Short-horizon tasks from the LIBERO-100 collection |
|
||||
| LIBERO-Long | `libero_10` | 10 | 520 | Long-horizon tasks from the LIBERO-100 collection |
|
||||
|
||||
<Tip warning={true}>
|
||||
Installing LIBERO-plus **replaces** vanilla LIBERO — it uninstalls `hf-libero`
|
||||
so that `import libero` resolves to the LIBERO-plus fork. You cannot have both
|
||||
installed at the same time. To switch back to vanilla LIBERO, uninstall the
|
||||
fork and reinstall with `pip install -e ".[libero]"`.
|
||||
</Tip>
|
||||
|
||||
## Installation
|
||||
|
||||
### System dependencies (Linux only)
|
||||
|
||||
```bash
|
||||
sudo apt install libexpat1 libfontconfig1-dev libmagickwand-dev
|
||||
```
|
||||
|
||||
### Python package
|
||||
|
||||
```bash
|
||||
pip install -e ".[libero]" "robosuite==1.4.1" bddl easydict mujoco wand scikit-image gym
|
||||
git clone https://github.com/sylvestf/LIBERO-plus.git
|
||||
cd LIBERO-plus && pip install --no-deps -e .
|
||||
pip uninstall -y hf-libero # so `import libero` resolves to the fork
|
||||
```
|
||||
|
||||
LIBERO-plus is installed from its GitHub fork rather than a pyproject extra — the fork ships as a namespace package that pip can't handle, so it must be cloned and added to `PYTHONPATH`. See `docker/Dockerfile.benchmark.libero_plus` for the canonical install. MuJoCo is required, so only Linux is supported.
|
||||
|
||||
<Tip>
|
||||
Set the MuJoCo rendering backend before running evaluation:
|
||||
|
||||
```bash
|
||||
export MUJOCO_GL=egl # headless / HPC / cloud
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
||||
### Download LIBERO-plus assets
|
||||
|
||||
LIBERO-plus ships its extended asset pack separately. Download `assets.zip` from the [Hugging Face dataset](https://huggingface.co/datasets/Sylvest/LIBERO-plus/tree/main) and extract it into the LIBERO-plus package directory:
|
||||
|
||||
```bash
|
||||
# After installing the package, find where it was installed:
|
||||
python -c "import libero; print(libero.__file__)"
|
||||
# Then extract assets.zip into <package_root>/libero/assets/
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Default evaluation (recommended)
|
||||
|
||||
Evaluate across the four standard suites (10 episodes per task):
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path="your-policy-id" \
|
||||
--env.type=libero_plus \
|
||||
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=10 \
|
||||
--env.max_parallel_tasks=1
|
||||
```
|
||||
|
||||
### Single-suite evaluation
|
||||
|
||||
Evaluate on one LIBERO-plus suite:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path="your-policy-id" \
|
||||
--env.type=libero_plus \
|
||||
--env.task=libero_spatial \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=10
|
||||
```
|
||||
|
||||
- `--env.task` picks the suite (`libero_spatial`, `libero_object`, etc.).
|
||||
- `--env.task_ids` restricts to specific task indices (`[0]`, `[1,2,3]`, etc.). Omit to run all tasks in the suite.
|
||||
- `--eval.batch_size` controls how many environments run in parallel.
|
||||
- `--eval.n_episodes` sets how many episodes to run per task.
|
||||
|
||||
### Multi-suite evaluation
|
||||
|
||||
Benchmark a policy across multiple suites at once by passing a comma-separated list:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path="your-policy-id" \
|
||||
--env.type=libero_plus \
|
||||
--env.task=libero_spatial,libero_object \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=10
|
||||
```
|
||||
|
||||
### Control mode
|
||||
|
||||
LIBERO-plus supports two control modes — `relative` (default) and `absolute`. Different VLA checkpoints are trained with different action parameterizations, so make sure the mode matches your policy:
|
||||
|
||||
```bash
|
||||
--env.control_mode=relative # or "absolute"
|
||||
```
|
||||
|
||||
### Policy inputs and outputs
|
||||
|
||||
**Observations:**
|
||||
|
||||
- `observation.state` — 8-dim proprioceptive features (eef position, axis-angle orientation, gripper qpos)
|
||||
- `observation.images.image` — main camera view (`agentview_image`), HWC uint8
|
||||
- `observation.images.image2` — wrist camera view (`robot0_eye_in_hand_image`), HWC uint8
|
||||
|
||||
**Actions:**
|
||||
|
||||
- Continuous control in `Box(-1, 1, shape=(7,))` — 6D end-effector delta + 1D gripper
|
||||
|
||||
### Recommended evaluation episodes
|
||||
|
||||
For reproducible benchmarking, use **10 episodes per task** across all four standard suites (Spatial, Object, Goal, Long). This gives 400 total episodes and matches the protocol used for published results.
|
||||
|
||||
## Training
|
||||
|
||||
### Dataset
|
||||
|
||||
A LeRobot-format training dataset for LIBERO-plus is available at:
|
||||
|
||||
- [lerobot/libero_plus](https://huggingface.co/datasets/lerobot/libero_plus)
|
||||
|
||||
### Example training command
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--policy.repo_id=${HF_USER}/smolvla_libero_plus \
|
||||
--policy.load_vlm_weights=true \
|
||||
--dataset.repo_id=lerobot/libero_plus \
|
||||
--env.type=libero_plus \
|
||||
--env.task=libero_spatial \
|
||||
--output_dir=./outputs/ \
|
||||
--steps=100000 \
|
||||
--batch_size=4 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval_freq=1000
|
||||
```
|
||||
|
||||
## Relationship to LIBERO
|
||||
|
||||
LIBERO-plus is a drop-in extension of LIBERO:
|
||||
|
||||
- Same Python gym interface (`LiberoEnv`, `LiberoProcessorStep`)
|
||||
- Same camera names and observation/action format
|
||||
- Same task suite names
|
||||
- Installs under the same `libero` Python package name (different GitHub repo)
|
||||
|
||||
To use the original LIBERO benchmark, see [LIBERO](./libero) and use `--env.type=libero`.
|
||||
188
docs/source/robocasa.mdx
Normal file
188
docs/source/robocasa.mdx
Normal file
@@ -0,0 +1,188 @@
|
||||
# RoboCasa365
|
||||
|
||||
[RoboCasa365](https://robocasa.ai) is a large-scale simulation framework for training and benchmarking **generalist robots** in everyday kitchen tasks. It ships 365 diverse manipulation tasks across 2,500 kitchen environments, 3,200+ object assets and 600+ hours of human demonstration data, on a PandaOmron 12-DOF mobile manipulator (Franka arm on a holonomic base).
|
||||
|
||||
- Paper: [RoboCasa: Large-Scale Simulation of Everyday Tasks for Generalist Robots](https://arxiv.org/abs/2406.02523)
|
||||
- GitHub: [robocasa/robocasa](https://github.com/robocasa/robocasa)
|
||||
- Project website: [robocasa.ai](https://robocasa.ai)
|
||||
- Pretrained policy: [`lerobot/smolvla_robocasa`](https://huggingface.co/lerobot/smolvla_robocasa)
|
||||
- Single-task dataset (CloseFridge): [`pepijn223/robocasa_CloseFridge`](https://huggingface.co/datasets/pepijn223/robocasa_CloseFridge)
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/robocasa-banner.webp"
|
||||
alt="RoboCasa365 benchmark overview"
|
||||
width="85%"
|
||||
/>
|
||||
|
||||
## Available tasks
|
||||
|
||||
RoboCasa365 organizes its 365 tasks into two families and three upstream benchmark groups that LeRobot exposes as first-class `--env.task` shortcuts:
|
||||
|
||||
| Family | Tasks | Description |
|
||||
| --------- | ----- | ------------------------------------------------------------------------------- |
|
||||
| Atomic | ~65 | Single-skill tasks: pick-and-place, door/drawer manipulation, appliance control |
|
||||
| Composite | ~300 | Multi-step tasks across 60+ categories: cooking, cleaning, organizing, etc. |
|
||||
|
||||
**Atomic task examples:** `CloseFridge`, `OpenDrawer`, `OpenCabinet`, `TurnOnMicrowave`, `TurnOffStove`, `NavigateKitchen`, `PickPlaceCounterToStove`.
|
||||
|
||||
**Composite task categories:** baking, boiling, brewing, chopping, clearing table, defrosting food, loading dishwasher, making tea, microwaving food, washing dishes, and more.
|
||||
|
||||
`--env.task` accepts three forms:
|
||||
|
||||
- a single task name (`CloseFridge`)
|
||||
- a comma-separated list (`CloseFridge,OpenBlenderLid,PickPlaceCoffee`)
|
||||
- a benchmark-group shortcut — `atomic_seen`, `composite_seen`, `composite_unseen`, `pretrain50`, `pretrain100`, `pretrain200`, `pretrain300` — which auto-expands to the upstream task list and auto-sets the dataset `split` (`target` or `pretrain`).
|
||||
|
||||
## Installation
|
||||
|
||||
RoboCasa and its dependency `robosuite` are not published on PyPI, and RoboCasa's own `setup.py` hardcodes `lerobot==0.3.3`, which conflicts with this repo's `lerobot`. LeRobot therefore does **not** expose a `robocasa` extra — install the two packages manually as editable clones (using `--no-deps` on `robocasa` to skip its shadowed `lerobot` pin):
|
||||
|
||||
```bash
|
||||
# After following the standard LeRobot installation instructions.
|
||||
|
||||
git clone https://github.com/robocasa/robocasa.git ~/robocasa
|
||||
git clone https://github.com/ARISE-Initiative/robosuite.git ~/robosuite
|
||||
pip install -e ~/robocasa --no-deps
|
||||
pip install -e ~/robosuite
|
||||
|
||||
# Robocasa's runtime deps (the ones its setup.py would have pulled, minus
|
||||
# the bad lerobot pin).
|
||||
pip install numpy numba scipy mujoco pygame Pillow opencv-python \
|
||||
pyyaml pynput tqdm termcolor imageio h5py lxml hidapi \
|
||||
tianshou gymnasium
|
||||
|
||||
python -m robocasa.scripts.setup_macros
|
||||
# Lightweight assets (lightwheel object meshes + textures). Enough for
|
||||
# the default env out of the box.
|
||||
python -m robocasa.scripts.download_kitchen_assets \
|
||||
--type tex tex_generative fixtures_lw objs_lw
|
||||
# Optional: full objaverse/aigen registries (~30GB) for richer object
|
||||
# variety. Enable at eval time via --env.obj_registries (see below).
|
||||
# python -m robocasa.scripts.download_kitchen_assets --type objs_objaverse
|
||||
```
|
||||
|
||||
<Tip>
|
||||
RoboCasa requires MuJoCo. Set the rendering backend before training or evaluation:
|
||||
|
||||
```bash
|
||||
export MUJOCO_GL=egl # for headless servers (HPC, cloud)
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
||||
### Object registries
|
||||
|
||||
By default the env samples objects only from the `lightwheel` registry (what `--type objs_lw` ships), which avoids a `Probabilities contain NaN` crash when the objaverse / aigen packs aren't on disk. If you've downloaded the full asset set, enable the full registry at runtime:
|
||||
|
||||
```bash
|
||||
--env.obj_registries='[objaverse,lightwheel]'
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
All eval snippets below mirror the CI command (see `.github/workflows/benchmark_tests.yml`). The `--rename_map` argument maps RoboCasa's native camera keys (`robot0_agentview_left` / `robot0_eye_in_hand` / `robot0_agentview_right`) onto the three-camera (`camera1` / `camera2` / `camera3`) input layout the released `smolvla_robocasa` policy was trained on.
|
||||
|
||||
### Single-task evaluation (recommended for quick iteration)
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/smolvla_robocasa \
|
||||
--env.type=robocasa \
|
||||
--env.task=CloseFridge \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=20 \
|
||||
--eval.use_async_envs=false \
|
||||
--policy.device=cuda \
|
||||
'--rename_map={"observation.images.robot0_agentview_left": "observation.images.camera1", "observation.images.robot0_eye_in_hand": "observation.images.camera2", "observation.images.robot0_agentview_right": "observation.images.camera3"}'
|
||||
```
|
||||
|
||||
### Multi-task evaluation
|
||||
|
||||
Pass a comma-separated list of tasks:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/smolvla_robocasa \
|
||||
--env.type=robocasa \
|
||||
--env.task=CloseFridge,OpenCabinet,OpenDrawer,TurnOnMicrowave,TurnOffStove \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=20 \
|
||||
--eval.use_async_envs=false \
|
||||
--policy.device=cuda \
|
||||
'--rename_map={"observation.images.robot0_agentview_left": "observation.images.camera1", "observation.images.robot0_eye_in_hand": "observation.images.camera2", "observation.images.robot0_agentview_right": "observation.images.camera3"}'
|
||||
```
|
||||
|
||||
### Benchmark-group evaluation
|
||||
|
||||
Run an entire upstream group (e.g. all 18 `atomic_seen` tasks with `split=target`):
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/smolvla_robocasa \
|
||||
--env.type=robocasa \
|
||||
--env.task=atomic_seen \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=20 \
|
||||
--eval.use_async_envs=false \
|
||||
--policy.device=cuda \
|
||||
'--rename_map={"observation.images.robot0_agentview_left": "observation.images.camera1", "observation.images.robot0_eye_in_hand": "observation.images.camera2", "observation.images.robot0_agentview_right": "observation.images.camera3"}'
|
||||
```
|
||||
|
||||
### Recommended evaluation episodes
|
||||
|
||||
**20 episodes per task** for reproducible benchmarking. Matches the protocol used in published results.
|
||||
|
||||
## Policy inputs and outputs
|
||||
|
||||
**Observations** (raw RoboCasa camera names are preserved verbatim):
|
||||
|
||||
- `observation.state` — 16-dim proprioceptive state (base position, base quaternion, relative end-effector position, relative end-effector quaternion, gripper qpos)
|
||||
- `observation.images.robot0_agentview_left` — left agent view, 256×256 HWC uint8
|
||||
- `observation.images.robot0_eye_in_hand` — wrist camera view, 256×256 HWC uint8
|
||||
- `observation.images.robot0_agentview_right` — right agent view, 256×256 HWC uint8
|
||||
|
||||
**Actions:**
|
||||
|
||||
- Continuous control in `Box(-1, 1, shape=(12,))` — base motion (4D) + control mode (1D) + end-effector position (3D) + end-effector rotation (3D) + gripper (1D).
|
||||
|
||||
## Training
|
||||
|
||||
### Single-task example
|
||||
|
||||
A ready-to-use single-task dataset is on the Hub:
|
||||
[`pepijn223/robocasa_CloseFridge`](https://huggingface.co/datasets/pepijn223/robocasa_CloseFridge).
|
||||
|
||||
Fine-tune a SmolVLA base on `CloseFridge`:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--policy.repo_id=${HF_USER}/smolvla_robocasa_CloseFridge \
|
||||
--policy.load_vlm_weights=true \
|
||||
--policy.push_to_hub=true \
|
||||
--dataset.repo_id=pepijn223/robocasa_CloseFridge \
|
||||
--env.type=robocasa \
|
||||
--env.task=CloseFridge \
|
||||
--output_dir=./outputs/smolvla_robocasa_CloseFridge \
|
||||
--steps=100000 \
|
||||
--batch_size=4 \
|
||||
--eval_freq=5000 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=5 \
|
||||
--save_freq=10000
|
||||
```
|
||||
|
||||
Evaluate the resulting checkpoint:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=${HF_USER}/smolvla_robocasa_CloseFridge \
|
||||
--env.type=robocasa \
|
||||
--env.task=CloseFridge \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=20
|
||||
```
|
||||
|
||||
## Reproducing published results
|
||||
|
||||
The released checkpoint [`lerobot/smolvla_robocasa`](https://huggingface.co/lerobot/smolvla_robocasa) is evaluated with the commands in the [Evaluation](#evaluation) section. CI runs a 10-atomic-task smoke eval (one episode each) on every PR touching the benchmark, picking fixture-centric tasks that don't require the objaverse asset pack.
|
||||
99
docs/source/robocerebra.mdx
Normal file
99
docs/source/robocerebra.mdx
Normal file
@@ -0,0 +1,99 @@
|
||||
# RoboCerebra
|
||||
|
||||
[RoboCerebra](https://robocerebra-project.github.io/) is a long-horizon manipulation benchmark that evaluates **high-level reasoning, planning, and memory** in VLAs. Episodes chain multiple sub-goals with language-grounded intermediate instructions, built on top of LIBERO's simulator stack (MuJoCo + robosuite, Franka Panda 7-DOF).
|
||||
|
||||
- Paper: [RoboCerebra: A Large-scale Benchmark for Long-horizon Robotic Manipulation Evaluation](https://arxiv.org/abs/2506.06677)
|
||||
- Project website: [robocerebra-project.github.io](https://robocerebra-project.github.io/)
|
||||
- Dataset: [`lerobot/robocerebra_unified`](https://huggingface.co/datasets/lerobot/robocerebra_unified) — LeRobot v3.0, 6,660 episodes / 571,116 frames at 20 fps, 1,728 language-grounded sub-tasks.
|
||||
- Pretrained policy: [`lerobot/smolvla_robocerebra`](https://huggingface.co/lerobot/smolvla_robocerebra)
|
||||
|
||||
## Available tasks
|
||||
|
||||
RoboCerebra reuses LIBERO's simulator, so evaluation runs against the LIBERO `libero_10` long-horizon suite:
|
||||
|
||||
| Suite | CLI name | Tasks | Description |
|
||||
| --------- | ----------- | ----- | ------------------------------------------------------------- |
|
||||
| LIBERO-10 | `libero_10` | 10 | Long-horizon kitchen/living room tasks chaining 3–6 sub-goals |
|
||||
|
||||
Each RoboCerebra episode in the dataset is segmented into multiple sub-tasks with natural-language instructions, which the unified dataset exposes as independent supervision signals.
|
||||
|
||||
## Installation
|
||||
|
||||
RoboCerebra piggybacks on LIBERO, so the `libero` extra is all you need:
|
||||
|
||||
```bash
|
||||
pip install -e ".[libero]"
|
||||
```
|
||||
|
||||
<Tip>
|
||||
RoboCerebra requires Linux (MuJoCo / robosuite). Set the rendering backend before training or evaluation:
|
||||
|
||||
```bash
|
||||
export MUJOCO_GL=egl # for headless servers (HPC, cloud)
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
||||
## Evaluation
|
||||
|
||||
RoboCerebra eval runs against LIBERO's `libero_10` suite with RoboCerebra's camera naming (`image` + `wrist_image`) and an extra empty-camera slot so a three-view-trained policy receives the expected input layout:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/smolvla_robocerebra \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10 \
|
||||
--env.fps=20 \
|
||||
--env.obs_type=pixels_agent_pos \
|
||||
--env.observation_height=256 \
|
||||
--env.observation_width=256 \
|
||||
'--env.camera_name_mapping={"agentview_image": "image", "robot0_eye_in_hand_image": "wrist_image"}' \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.use_async_envs=false \
|
||||
--policy.device=cuda \
|
||||
'--rename_map={"observation.images.image": "observation.images.camera1", "observation.images.wrist_image": "observation.images.camera2"}' \
|
||||
--policy.empty_cameras=1
|
||||
```
|
||||
|
||||
### Recommended evaluation episodes
|
||||
|
||||
**10 episodes per task** across the `libero_10` suite (100 total) for reproducible benchmarking. Matches the protocol used in the RoboCerebra paper.
|
||||
|
||||
## Policy inputs and outputs
|
||||
|
||||
**Observations:**
|
||||
|
||||
- `observation.state` — 8-dim proprioceptive state (7 joint positions + gripper)
|
||||
- `observation.images.image` — third-person view, 256×256 HWC uint8
|
||||
- `observation.images.wrist_image` — wrist-mounted camera view, 256×256 HWC uint8
|
||||
|
||||
**Actions:**
|
||||
|
||||
- Continuous control in `Box(-1, 1, shape=(7,))` — end-effector delta (6D) + gripper (1D)
|
||||
|
||||
## Training
|
||||
|
||||
The unified dataset at [`lerobot/robocerebra_unified`](https://huggingface.co/datasets/lerobot/robocerebra_unified) exposes two RGB streams and language-grounded sub-task annotations:
|
||||
|
||||
| Feature | Shape | Description |
|
||||
| -------------------------------- | ------------- | -------------------- |
|
||||
| `observation.images.image` | (256, 256, 3) | Third-person view |
|
||||
| `observation.images.wrist_image` | (256, 256, 3) | Wrist-mounted camera |
|
||||
| `observation.state` | (8,) | Joint pos + gripper |
|
||||
| `action` | (7,) | EEF delta + gripper |
|
||||
|
||||
Fine-tune a SmolVLA base on it:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path=lerobot/smolvla_base \
|
||||
--dataset.repo_id=lerobot/robocerebra_unified \
|
||||
--env.type=libero \
|
||||
--env.task=libero_10 \
|
||||
--output_dir=outputs/smolvla_robocerebra
|
||||
```
|
||||
|
||||
## Reproducing published results
|
||||
|
||||
The released checkpoint [`lerobot/smolvla_robocerebra`](https://huggingface.co/lerobot/smolvla_robocerebra) was trained on `lerobot/robocerebra_unified` and evaluated with the command in the [Evaluation](#evaluation) section. CI runs the same command with `--eval.n_episodes=1` as a smoke test on every PR touching the benchmark.
|
||||
130
docs/source/robomme.mdx
Normal file
130
docs/source/robomme.mdx
Normal file
@@ -0,0 +1,130 @@
|
||||
# RoboMME
|
||||
|
||||
[RoboMME](https://robomme.github.io) is a memory-augmented manipulation benchmark built on ManiSkill (SAPIEN). It evaluates a robot's ability to retain and use information across an episode — counting, object permanence, reference, and imitation.
|
||||
|
||||
- **16 tasks** across 4 memory-skill suites
|
||||
- **1,600 training demos** (100 per task, 50 val, 50 test)
|
||||
- **Dataset**: [`lerobot/robomme`](https://huggingface.co/datasets/lerobot/robomme) — LeRobot v3.0, 768K frames at 10 fps
|
||||
- **Simulator**: ManiSkill / SAPIEN, Panda arm, Linux only
|
||||
|
||||

|
||||
|
||||
## Tasks
|
||||
|
||||
| Suite | Tasks |
|
||||
| --------------------------------- | ------------------------------------------------------------- |
|
||||
| **Counting** (temporal memory) | BinFill, PickXtimes, SwingXtimes, StopCube |
|
||||
| **Permanence** (spatial memory) | VideoUnmask, VideoUnmaskSwap, ButtonUnmask, ButtonUnmaskSwap |
|
||||
| **Reference** (object memory) | PickHighlight, VideoRepick, VideoPlaceButton, VideoPlaceOrder |
|
||||
| **Imitation** (procedural memory) | MoveCube, InsertPeg, PatternLock, RouteStick |
|
||||
|
||||
## Installation
|
||||
|
||||
> RoboMME requires **Linux** (ManiSkill/SAPIEN uses Vulkan rendering). Docker is recommended to isolate dependency conflicts.
|
||||
|
||||
### Native (Linux)
|
||||
|
||||
```bash
|
||||
pip install --override <(printf 'gymnasium==0.29.1\nnumpy==1.26.4\n') \
|
||||
-e '.[smolvla,av-dep]' \
|
||||
'robomme @ git+https://github.com/RoboMME/robomme_benchmark.git@main'
|
||||
```
|
||||
|
||||
> **Dependency note**: `mani-skill` (pulled by `robomme`) pins `gymnasium==0.29.1` and `numpy<2.0.0`, which conflict with lerobot's base `numpy>=2.0.0`. That's why `robomme` is not a pyproject extra — use the override install above, or the Docker approach below to avoid conflicts entirely.
|
||||
|
||||
### Docker (recommended)
|
||||
|
||||
```bash
|
||||
# Build base image first (from repo root)
|
||||
docker build -f docker/Dockerfile.eval-base -t lerobot-eval-base .
|
||||
|
||||
# Build RoboMME eval image (applies gymnasium + numpy pin overrides)
|
||||
docker build -f docker/Dockerfile.benchmark.robomme -t lerobot-robomme .
|
||||
```
|
||||
|
||||
The `docker/Dockerfile.benchmark.robomme` image overrides `gymnasium==0.29.1` and `numpy==1.26.4` after lerobot's install. Both versions are runtime-safe for lerobot's actual API usage.
|
||||
|
||||
## Running Evaluation
|
||||
|
||||
### Default (single task, single episode)
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=<your_policy_repo> \
|
||||
--env.type=robomme \
|
||||
--env.task=PickXtimes \
|
||||
--env.dataset_split=test \
|
||||
--env.task_ids=[0] \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1
|
||||
```
|
||||
|
||||
### Multi-task evaluation
|
||||
|
||||
Evaluate multiple tasks in one run by comma-separating task names. Use `task_ids` to control which episodes are evaluated per task. Recommended: 50 episodes per task for the test split.
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=<your_policy_repo> \
|
||||
--env.type=robomme \
|
||||
--env.task=PickXtimes,BinFill,StopCube,MoveCube,InsertPeg \
|
||||
--env.dataset_split=test \
|
||||
--env.task_ids=[0,1,2,3,4,5,6,7,8,9] \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=50
|
||||
```
|
||||
|
||||
### Key CLI options for `env.type=robomme`
|
||||
|
||||
| Option | Default | Description |
|
||||
| -------------------- | ------------- | -------------------------------------------------- |
|
||||
| `env.task` | `PickXtimes` | Any of the 16 task names above (comma-separated) |
|
||||
| `env.dataset_split` | `test` | `train`, `val`, or `test` |
|
||||
| `env.action_space` | `joint_angle` | `joint_angle` (8-D) or `ee_pose` (7-D) |
|
||||
| `env.episode_length` | `300` | Max steps per episode |
|
||||
| `env.task_ids` | `null` | List of episode indices to evaluate (null = `[0]`) |
|
||||
|
||||
## Dataset
|
||||
|
||||
The dataset [`lerobot/robomme`](https://huggingface.co/datasets/lerobot/robomme) is in **LeRobot v3.0 format** and can be loaded directly:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
dataset = LeRobotDataset("lerobot/robomme")
|
||||
```
|
||||
|
||||
### Dataset features
|
||||
|
||||
| Feature | Shape | Description |
|
||||
| ------------------ | ------------- | ------------------------------- |
|
||||
| `image` | (256, 256, 3) | Front camera RGB |
|
||||
| `wrist_image` | (256, 256, 3) | Wrist camera RGB |
|
||||
| `actions` | (8,) | Joint angles + gripper |
|
||||
| `state` | (8,) | Joint positions + gripper state |
|
||||
| `simple_subgoal` | str | High-level language annotation |
|
||||
| `grounded_subgoal` | str | Grounded language annotation |
|
||||
| `episode_index` | int | Episode ID |
|
||||
| `frame_index` | int | Frame within episode |
|
||||
|
||||
### Feature key alignment (training)
|
||||
|
||||
The env wrapper exposes `pixels/image` and `pixels/wrist_image` as observation keys. The `features_map` in `RoboMMEEnv` maps these to `observation.images.image` and `observation.images.wrist_image` for the policy. State is exposed as `agent_pos` and maps to `observation.state`.
|
||||
|
||||
The dataset's `image` and `wrist_image` columns already align with the policy input keys, so no renaming is needed when fine-tuning.
|
||||
|
||||
## Action Spaces
|
||||
|
||||
| Type | Dim | Description |
|
||||
| ------------- | --- | --------------------------------------------------------- |
|
||||
| `joint_angle` | 8 | 7 joint angles + 1 gripper (−1 closed, +1 open, absolute) |
|
||||
| `ee_pose` | 7 | xyz + roll/pitch/yaw + gripper |
|
||||
|
||||
Set via `--env.action_space=joint_angle` (default) or `--env.action_space=ee_pose`.
|
||||
|
||||
## Platform Notes
|
||||
|
||||
- **Linux only**: ManiSkill requires SAPIEN/Vulkan. macOS and Windows are not supported.
|
||||
- **GPU recommended**: Rendering is CPU-capable but slow; CUDA + Vulkan gives full speed.
|
||||
- **gymnasium / numpy conflict**: See installation note above. Docker image handles this automatically.
|
||||
- **ManiSkill fork**: `robomme` depends on a specific ManiSkill fork (`YinpeiDai/ManiSkill`), pulled in automatically via the `robomme` package.
|
||||
223
docs/source/robotwin.mdx
Normal file
223
docs/source/robotwin.mdx
Normal file
@@ -0,0 +1,223 @@
|
||||
# RoboTwin 2.0
|
||||
|
||||
RoboTwin 2.0 is a **large-scale dual-arm manipulation benchmark** built on the SAPIEN physics engine. It provides a standardized evaluation protocol for bimanual robotic policies across 50 tasks (as of upstream `main`) with strong domain randomization (clutter, lighting, background, tabletop height, and language instructions).
|
||||
|
||||
- Paper: [RoboTwin 2.0: A Scalable Data Generator and Benchmark with Strong Domain Randomization for Robust Bimanual Robotic Manipulation](https://arxiv.org/abs/2506.18088)
|
||||
- GitHub: [RoboTwin-Platform/RoboTwin](https://github.com/RoboTwin-Platform/RoboTwin)
|
||||
- Leaderboard: [robotwin-platform.github.io/leaderboard](https://robotwin-platform.github.io/leaderboard)
|
||||
- Dataset: [lerobot/robotwin_unified](https://huggingface.co/datasets/lerobot/robotwin_unified)
|
||||
|
||||

|
||||
|
||||
## Overview
|
||||
|
||||
| Property | Value |
|
||||
| ------------- | -------------------------------------------------------- |
|
||||
| Tasks | 50 dual-arm manipulation tasks |
|
||||
| Robot | Aloha-AgileX bimanual (14 DOF, 7 per arm) |
|
||||
| Action space | 14-dim joint-space, continuous in `[-1, 1]` |
|
||||
| Cameras | `head_camera`, `left_camera`, `right_camera` |
|
||||
| Simulator | SAPIEN (not MuJoCo) |
|
||||
| Eval protocol | 100 episodes/task, 50 demo_clean demonstrations |
|
||||
| Eval settings | **Easy** (`demo_clean`) and **Hard** (`demo_randomized`) |
|
||||
|
||||
## Available tasks
|
||||
|
||||
RoboTwin 2.0 ships 50 dual-arm manipulation tasks in its upstream `envs/` directory. The canonical list is the `ROBOTWIN_TASKS` tuple in `src/lerobot/envs/robotwin.py`, mirrored verbatim from the upstream repo. Example tasks:
|
||||
|
||||
| Task | CLI name | Category |
|
||||
| ------------------------ | ------------------------ | ----------------- |
|
||||
| Beat block with hammer | `beat_block_hammer` | Tool use |
|
||||
| Click bell / alarm clock | `click_bell` | Precision press |
|
||||
| Stack blocks (2 / 3) | `stack_blocks_two/three` | Stacking |
|
||||
| Stack bowls (2 / 3) | `stack_bowls_two/three` | Stacking |
|
||||
| Handover block / mic | `handover_block` | Bimanual coord. |
|
||||
| Lift pot | `lift_pot` | Bimanual lift |
|
||||
| Shake bottle | `shake_bottle` | Continuous motion |
|
||||
| Turn switch | `turn_switch` | Articulated obj |
|
||||
| Stamp seal | `stamp_seal` | Precision place |
|
||||
| Scan object | `scan_object` | Mobile manip. |
|
||||
|
||||
Pass a comma-separated list to `--env.task` to run multiple tasks in a single eval sweep.
|
||||
|
||||
<Tip warning={true}>
|
||||
`open_laptop` is currently broken upstream (its `check_success()` uses
|
||||
`self.arm_tag`, which is only set inside the scripted-expert `play_once()`
|
||||
path and therefore unavailable during normal policy eval). Avoid it until the
|
||||
upstream bug is fixed, or patch the task to default `self.arm_tag = "left"` in
|
||||
`load_actors()`.
|
||||
</Tip>
|
||||
|
||||
## Dataset
|
||||
|
||||
The RoboTwin 2.0 dataset is available in **LeRobot v3.0 format** on the Hugging Face Hub:
|
||||
|
||||
```
|
||||
lerobot/robotwin_unified
|
||||
```
|
||||
|
||||
It contains over 100,000 pre-collected trajectories across all 50 tasks (79.6 GB, Apache 2.0 license). No format conversion is needed — it is already in the correct LeRobot v3.0 schema with video observations and action labels.
|
||||
|
||||
You can load it directly with the HF Datasets library:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("lerobot/robotwin_unified", split="train")
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
RoboTwin 2.0 requires **Linux** with an NVIDIA GPU (CUDA 12.1 recommended). Installation takes approximately 20 minutes.
|
||||
|
||||
### 1. Create a conda environment
|
||||
|
||||
```bash
|
||||
conda create -n robotwin python=3.10 -y
|
||||
conda activate robotwin
|
||||
```
|
||||
|
||||
### 2. Install LeRobot
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e "."
|
||||
```
|
||||
|
||||
### 3. Install RoboTwin 2.0
|
||||
|
||||
```bash
|
||||
git clone https://github.com/RoboTwin-Platform/RoboTwin.git
|
||||
cd RoboTwin
|
||||
bash script/_install.sh
|
||||
bash script/_download_assets.sh
|
||||
```
|
||||
|
||||
The install script handles all Python dependencies including SAPIEN, CuRobo, mplib, and pytorch3d.
|
||||
|
||||
<Tip warning={true}>
|
||||
If the automated install fails, install manually:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
|
||||
cd envs && git clone https://github.com/NVlabs/curobo.git && cd curobo
|
||||
pip install -e . --no-build-isolation
|
||||
```
|
||||
|
||||
Then apply the required mplib fix: in `mplib/planner.py` line 807, remove `or collide` from the conditional.
|
||||
|
||||
</Tip>
|
||||
|
||||
### 4. Add RoboTwin to PYTHONPATH
|
||||
|
||||
The RoboTwin task modules must be importable by LeRobot. From within the `RoboTwin/` directory:
|
||||
|
||||
```bash
|
||||
export PYTHONPATH="${PYTHONPATH}:$(pwd)"
|
||||
```
|
||||
|
||||
Add this to your shell profile to make it permanent.
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Standard evaluation (recommended)
|
||||
|
||||
Evaluate a policy on a single task with the official protocol (100 episodes):
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path="your-hf-policy-id" \
|
||||
--env.type=robotwin \
|
||||
--env.task=beat_block_hammer \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=100
|
||||
```
|
||||
|
||||
### Single-task quick check
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path="your-hf-policy-id" \
|
||||
--env.type=robotwin \
|
||||
--env.task=beat_block_hammer \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=5
|
||||
```
|
||||
|
||||
### Multi-task sweep
|
||||
|
||||
Evaluate on several tasks in one run:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path="your-hf-policy-id" \
|
||||
--env.type=robotwin \
|
||||
--env.task=beat_block_hammer,click_bell,handover_block,stack_blocks_two \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=100
|
||||
```
|
||||
|
||||
### Full benchmark (all 50 tasks)
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path="your-hf-policy-id" \
|
||||
--env.type=robotwin \
|
||||
--env.task=adjust_bottle,beat_block_hammer,blocks_ranking_rgb,blocks_ranking_size,click_alarmclock,click_bell,dump_bin_bigbin,grab_roller,handover_block,handover_mic,hanging_mug,lift_pot,move_can_pot,move_pillbottle_pad,move_playingcard_away,move_stapler_pad,open_microwave,pick_diverse_bottles,pick_dual_bottles,place_a2b_left,place_a2b_right,place_bread_basket,place_bread_skillet,place_burger_fries,place_can_basket,place_cans_plasticbox,place_container_plate,place_dual_shoes,place_empty_cup,place_fan,place_mouse_pad,place_object_basket,place_object_scale,place_object_stand,place_phone_stand,place_shoe,press_stapler,put_bottles_dustbin,put_object_cabinet,rotate_qrcode,scan_object,shake_bottle,shake_bottle_horizontally,stack_blocks_three,stack_blocks_two,stack_bowls_three,stack_bowls_two,stamp_seal,turn_switch \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=100
|
||||
```
|
||||
|
||||
<Tip>
|
||||
`open_laptop` is intentionally omitted above because of the upstream
|
||||
`self.arm_tag` bug (see the **Available tasks** section). Re-add it once the
|
||||
upstream fix lands.
|
||||
</Tip>
|
||||
|
||||
## Camera configuration
|
||||
|
||||
By default, all three cameras are included:
|
||||
|
||||
| Camera key | Description |
|
||||
| -------------- | ------------------------------ |
|
||||
| `head_camera` | Torso-mounted overhead view |
|
||||
| `left_camera` | Left arm wrist-mounted camera |
|
||||
| `right_camera` | Right arm wrist-mounted camera |
|
||||
|
||||
To use a subset of cameras, override `--env.camera_names`:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path="your-hf-policy-id" \
|
||||
--env.type=robotwin \
|
||||
--env.task=beat_block_hammer \
|
||||
--env.camera_names="head_camera,left_camera" \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=10
|
||||
```
|
||||
|
||||
## Environment config reference
|
||||
|
||||
Key parameters for `RoboTwinEnvConfig`:
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| -------------------- | ---------------------------------------- | ---------------------------------- |
|
||||
| `task` | `"beat_block_hammer"` | Comma-separated task name(s) |
|
||||
| `fps` | `25` | Simulation FPS |
|
||||
| `episode_length` | `300` | Max steps per episode |
|
||||
| `obs_type` | `"pixels_agent_pos"` | `"pixels"` or `"pixels_agent_pos"` |
|
||||
| `camera_names` | `"head_camera,left_camera,right_camera"` | Comma-separated active cameras |
|
||||
| `observation_height` | `240` | Camera pixel height |
|
||||
| `observation_width` | `320` | Camera pixel width |
|
||||
|
||||
## Leaderboard submission
|
||||
|
||||
Results can be submitted to the [RoboTwin 2.0 leaderboard](https://robotwin-platform.github.io/leaderboard). The official protocol requires:
|
||||
|
||||
- Training on 50 `demo_clean` demonstrations per task
|
||||
- Evaluating 100 episodes per task
|
||||
- Reporting success rate separately for **Easy** (`demo_clean`) and **Hard** (`demo_randomized`) settings
|
||||
|
||||
For submission instructions, refer to the [RoboTwin 2.0 documentation](https://robotwin-platform.github.io/doc/).
|
||||
@@ -34,7 +34,7 @@ pip install -e ".[smolvla]"
|
||||
|
||||
### Using RTC with Pi0
|
||||
|
||||
You can use `lerobot-rollout --strategy.type=base --inference.type=rtc` for RTC deployment on real robots.
|
||||
You can find a complete reference implementation in [eval_with_real_robot.py](examples/rtc/eval_with_real_robot.py).
|
||||
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
|
||||
|
||||
```python
|
||||
@@ -137,12 +137,8 @@ The script generates a visualization of the denoising process, comparing standar
|
||||
## Testing RTC with a Real Robot
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
python examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=${HF_USERNAME}/policy_repo_id \
|
||||
--inference.type=rtc \
|
||||
--inference.rtc.execution_horizon=10 \
|
||||
--inference.rtc.max_guidance_weight=10.0 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
@@ -182,7 +178,7 @@ visualizer = RTCDebugVisualizer()
|
||||
# ... create plots
|
||||
```
|
||||
|
||||
See `examples/rtc/eval_dataset.py` for a complete example of offline RTC visualization.
|
||||
See `examples/rtc/eval_dataset.py` for a complete example of visualization.
|
||||
|
||||
## References
|
||||
|
||||
|
||||
@@ -274,8 +274,7 @@ python src/lerobot/scripts/lerobot_train.py \
|
||||
Once trained, we recommend deploying policies using inference-time RTC:
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
python examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=your-username/your-repo-id \
|
||||
--policy.device=cuda \
|
||||
--robot.type=unitree_g1 \
|
||||
@@ -285,7 +284,7 @@ lerobot-rollout \
|
||||
--task="task_description" \
|
||||
--duration=1000 \
|
||||
--fps=30 \
|
||||
--inference.type=rtc
|
||||
--rtc.enabled=true
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
176
docs/source/vlabench.mdx
Normal file
176
docs/source/vlabench.mdx
Normal file
@@ -0,0 +1,176 @@
|
||||
# VLABench
|
||||
|
||||
[VLABench](https://github.com/OpenMOSS/VLABench) is a large-scale benchmark for **language-conditioned robotic manipulation with long-horizon reasoning**. The upstream suite covers 100 task categories across 2,000+ objects and evaluates six dimensions of robot intelligence: mesh & texture understanding, spatial reasoning, world-knowledge transfer, semantic instruction comprehension, physical-law understanding, and long-horizon planning. Built on MuJoCo / dm_control with a Franka Panda 7-DOF arm. LeRobot exposes **43 of these tasks** through `--env.task` (21 primitives + 22 composites, see [Available tasks](#available-tasks) below).
|
||||
|
||||
- Paper: [VLABench: A Large-Scale Benchmark for Language-Conditioned Robotics Manipulation with Long-Horizon Reasoning](https://arxiv.org/abs/2412.18194)
|
||||
- GitHub: [OpenMOSS/VLABench](https://github.com/OpenMOSS/VLABench)
|
||||
- Project website: [vlabench.github.io](https://vlabench.github.io)
|
||||
- Pretrained policy: [`lerobot/smolvla_vlabench`](https://huggingface.co/lerobot/smolvla_vlabench)
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/vlabench.png"
|
||||
alt="VLABench benchmark overview"
|
||||
width="85%"
|
||||
/>
|
||||
|
||||
## Available tasks
|
||||
|
||||
VLABench ships two task suites covering **43 task categories** in LeRobot's `--env.task` surface:
|
||||
|
||||
| Suite | CLI name | Tasks | Description |
|
||||
| --------- | ----------- | ----- | ---------------------------------------------------------------- |
|
||||
| Primitive | `primitive` | 21 | Single / few-skill combinations (select, insert, physics QA) |
|
||||
| Composite | `composite` | 22 | Multi-step reasoning and long-horizon planning (cook, rearrange) |
|
||||
|
||||
**Primitive tasks:** `select_fruit`, `select_toy`, `select_chemistry_tube`, `add_condiment`, `select_book`, `select_painting`, `select_drink`, `insert_flower`, `select_billiards`, `select_ingredient`, `select_mahjong`, `select_poker`, and physical-reasoning tasks (`density_qa`, `friction_qa`, `magnetism_qa`, `reflection_qa`, `simple_cuestick_usage`, `simple_seesaw_usage`, `sound_speed_qa`, `thermal_expansion_qa`, `weight_qa`).
|
||||
|
||||
**Composite tasks:** `cluster_billiards`, `cluster_book`, `cluster_drink`, `cluster_toy`, `cook_dishes`, `cool_drink`, `find_unseen_object`, `get_coffee`, `hammer_nail`, `heat_food`, `make_juice`, `play_mahjong`, `play_math_game`, `play_poker`, `play_snooker`, `rearrange_book`, `rearrange_chemistry_tube`, `set_dining_table`, `set_study_table`, `store_food`, `take_chemistry_experiment`, `use_seesaw_complex`.
|
||||
|
||||
`--env.task` accepts three forms:
|
||||
|
||||
- a single task name (`select_fruit`)
|
||||
- a comma-separated list (`select_fruit,heat_food`)
|
||||
- a suite shortcut (`primitive`, `composite`, or `primitive,composite`)
|
||||
|
||||
## Installation
|
||||
|
||||
VLABench is **not on PyPI** — its only distribution is the [OpenMOSS/VLABench](https://github.com/OpenMOSS/VLABench) GitHub repo — so LeRobot does not expose a `vlabench` extra. Install it manually as an editable clone, alongside the MuJoCo / dm_control pins VLABench needs, then fetch the mesh assets:
|
||||
|
||||
```bash
|
||||
# After following the standard LeRobot installation instructions.
|
||||
|
||||
git clone https://github.com/OpenMOSS/VLABench.git ~/VLABench
|
||||
git clone https://github.com/motion-planning/rrt-algorithms.git ~/rrt-algorithms
|
||||
pip install -e ~/VLABench -e ~/rrt-algorithms
|
||||
pip install "mujoco==3.2.2" "dm-control==1.0.22" \
|
||||
open3d colorlog scikit-learn openai gdown
|
||||
|
||||
python ~/VLABench/scripts/download_assets.py
|
||||
```
|
||||
|
||||
<Tip>
|
||||
VLABench requires Linux (`sys_platform == 'linux'`) and Python 3.10+. Set the MuJoCo rendering backend before running:
|
||||
|
||||
```bash
|
||||
export MUJOCO_GL=egl # for headless servers (HPC, cloud)
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
||||
## Evaluation
|
||||
|
||||
All eval snippets below mirror the command CI runs (see `.github/workflows/benchmark_tests.yml`). The `--rename_map` argument maps VLABench's `image` / `second_image` / `wrist_image` camera keys onto the three-camera (`camera1` / `camera2` / `camera3`) input layout the released `smolvla_vlabench` policy was trained on.
|
||||
|
||||
### Single-task evaluation (recommended for quick iteration)
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/smolvla_vlabench \
|
||||
--env.type=vlabench \
|
||||
--env.task=select_fruit \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.use_async_envs=false \
|
||||
--policy.device=cuda \
|
||||
'--rename_map={"observation.images.image": "observation.images.camera1", "observation.images.second_image": "observation.images.camera2", "observation.images.wrist_image": "observation.images.camera3"}'
|
||||
```
|
||||
|
||||
### Multi-task evaluation
|
||||
|
||||
Pass a comma-separated list of tasks:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/smolvla_vlabench \
|
||||
--env.type=vlabench \
|
||||
--env.task=select_fruit,select_toy,add_condiment,heat_food \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.use_async_envs=false \
|
||||
--policy.device=cuda \
|
||||
'--rename_map={"observation.images.image": "observation.images.camera1", "observation.images.second_image": "observation.images.camera2", "observation.images.wrist_image": "observation.images.camera3"}'
|
||||
```
|
||||
|
||||
### Suite-wide evaluation
|
||||
|
||||
Run an entire suite (all 21 primitives or all 22 composites):
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/smolvla_vlabench \
|
||||
--env.type=vlabench \
|
||||
--env.task=primitive \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.use_async_envs=false \
|
||||
--policy.device=cuda \
|
||||
--env.max_parallel_tasks=1 \
|
||||
'--rename_map={"observation.images.image": "observation.images.camera1", "observation.images.second_image": "observation.images.camera2", "observation.images.wrist_image": "observation.images.camera3"}'
|
||||
```
|
||||
|
||||
Or both suites:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/smolvla_vlabench \
|
||||
--env.type=vlabench \
|
||||
--env.task=primitive,composite \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=10 \
|
||||
--eval.use_async_envs=false \
|
||||
--policy.device=cuda \
|
||||
--env.max_parallel_tasks=1 \
|
||||
'--rename_map={"observation.images.image": "observation.images.camera1", "observation.images.second_image": "observation.images.camera2", "observation.images.wrist_image": "observation.images.camera3"}'
|
||||
```
|
||||
|
||||
### Recommended evaluation episodes
|
||||
|
||||
**10 episodes per task** for reproducible benchmarking (210 total for the full primitive suite, 220 for composite). Matches the protocol in the VLABench paper.
|
||||
|
||||
## Policy inputs and outputs
|
||||
|
||||
**Observations:**
|
||||
|
||||
- `observation.state` — 7-dim end-effector state (position xyz + Euler xyz + gripper)
|
||||
- `observation.images.image` — front camera, 480×480 HWC uint8
|
||||
- `observation.images.second_image` — second camera, 480×480 HWC uint8
|
||||
- `observation.images.wrist_image` — wrist camera, 480×480 HWC uint8
|
||||
|
||||
**Actions:**
|
||||
|
||||
- Continuous control in `Box(-1, 1, shape=(7,))` — 3D position + 3D Euler orientation + 1D gripper.
|
||||
|
||||
## Training
|
||||
|
||||
### Datasets
|
||||
|
||||
Pre-collected VLABench datasets in LeRobot format on the Hub:
|
||||
|
||||
- [`VLABench/vlabench_primitive_ft_lerobot_video`](https://huggingface.co/datasets/VLABench/vlabench_primitive_ft_lerobot_video) — 5,000 episodes, 128 tasks, 480×480 images.
|
||||
- [`VLABench/vlabench_composite_ft_lerobot_video`](https://huggingface.co/datasets/VLABench/vlabench_composite_ft_lerobot_video) — 5,977 episodes, 167 tasks, 224×224 images.
|
||||
|
||||
### Example training command
|
||||
|
||||
Fine-tune a SmolVLA base on the primitive suite:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type=smolvla \
|
||||
--policy.repo_id=${HF_USER}/smolvla_vlabench_primitive \
|
||||
--policy.load_vlm_weights=true \
|
||||
--policy.push_to_hub=true \
|
||||
--dataset.repo_id=VLABench/vlabench_primitive_ft_lerobot_video \
|
||||
--env.type=vlabench \
|
||||
--env.task=select_fruit \
|
||||
--output_dir=./outputs/smolvla_vlabench_primitive \
|
||||
--steps=100000 \
|
||||
--batch_size=4 \
|
||||
--eval_freq=5000 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--save_freq=10000
|
||||
```
|
||||
|
||||
## Reproducing published results
|
||||
|
||||
The released checkpoint [`lerobot/smolvla_vlabench`](https://huggingface.co/lerobot/smolvla_vlabench) was trained on the primitive-suite dataset above and is evaluated with the [Single-task](#single-task-evaluation-recommended-for-quick-iteration) / [Suite-wide](#suite-wide-evaluation) commands. CI runs a 10-primitive-task smoke eval (one episode each) on every PR touching the benchmark.
|
||||
1184
examples/hil/hil_data_collection.py
Normal file
1184
examples/hil/hil_data_collection.py
Normal file
File diff suppressed because it is too large
Load Diff
226
examples/hil/hil_utils.py
Normal file
226
examples/hil/hil_utils.py
Normal file
@@ -0,0 +1,226 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Shared utilities for Human-in-the-Loop data collection scripts."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.common.control_utils import is_headless
|
||||
from lerobot.processor import (
|
||||
IdentityProcessorStep,
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots import Robot
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILDatasetConfig:
|
||||
repo_id: str
|
||||
single_task: str
|
||||
root: str | Path | None = None
|
||||
fps: int = 30
|
||||
episode_time_s: float = 120
|
||||
num_episodes: int = 50
|
||||
video: bool = True
|
||||
push_to_hub: bool = True
|
||||
private: bool = False
|
||||
tags: list[str] | None = None
|
||||
num_image_writer_processes: int = 0
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
video_encoding_batch_size: int = 1
|
||||
vcodec: str = "auto"
|
||||
streaming_encoding: bool = True
|
||||
encoder_queue_maxsize: int = 30
|
||||
encoder_threads: int | None = None
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
def teleop_has_motor_control(teleop: Teleoperator) -> bool:
|
||||
"""Check if teleoperator has motor control capabilities."""
|
||||
return all(hasattr(teleop, attr) for attr in ("enable_torque", "disable_torque", "write_goal_positions"))
|
||||
|
||||
|
||||
def teleop_disable_torque(teleop: Teleoperator) -> None:
|
||||
"""Disable teleop torque if supported."""
|
||||
if hasattr(teleop, "disable_torque"):
|
||||
teleop.disable_torque()
|
||||
|
||||
|
||||
def teleop_enable_torque(teleop: Teleoperator) -> None:
|
||||
"""Enable teleop torque if supported."""
|
||||
if hasattr(teleop, "enable_torque"):
|
||||
teleop.enable_torque()
|
||||
|
||||
|
||||
def teleop_smooth_move_to(teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50):
|
||||
"""Smoothly move teleop to target position if motor control is available."""
|
||||
if not teleop_has_motor_control(teleop):
|
||||
logger.warning("Teleop does not support motor control - cannot mirror robot position")
|
||||
return
|
||||
|
||||
teleop_enable_torque(teleop)
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {}
|
||||
for k in current:
|
||||
if k in target_pos:
|
||||
interp[k] = current[k] * (1 - t) + target_pos[k] * t
|
||||
else:
|
||||
interp[k] = current[k]
|
||||
teleop.write_goal_positions(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
"""Initialize keyboard listener with HIL controls."""
|
||||
events = {
|
||||
"exit_early": False,
|
||||
"rerecord_episode": False,
|
||||
"stop_recording": False,
|
||||
"policy_paused": False,
|
||||
"correction_active": False,
|
||||
"resume_policy": False,
|
||||
"in_reset": False,
|
||||
"start_next_episode": False,
|
||||
}
|
||||
|
||||
if is_headless():
|
||||
logger.warning("Headless environment - keyboard controls unavailable")
|
||||
return None, events
|
||||
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if events["in_reset"]:
|
||||
if key in [keyboard.Key.space, keyboard.Key.right]:
|
||||
logger.info("[HIL] Starting next episode...")
|
||||
events["start_next_episode"] = True
|
||||
elif hasattr(key, "char") and key.char == "c":
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
logger.info("[HIL] ESC - Stop recording, pushing to hub...")
|
||||
events["stop_recording"] = True
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
if key == keyboard.Key.space:
|
||||
if not events["policy_paused"] and not events["correction_active"]:
|
||||
logger.info("[HIL] PAUSED - Press 'c' to take control or 'p' to resume policy")
|
||||
events["policy_paused"] = True
|
||||
elif hasattr(key, "char") and key.char == "c":
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
logger.info("[HIL] Taking control...")
|
||||
events["start_next_episode"] = True
|
||||
elif hasattr(key, "char") and key.char == "p":
|
||||
if events["policy_paused"] or events["correction_active"]:
|
||||
logger.info("[HIL] Resuming policy...")
|
||||
events["resume_policy"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
logger.info("[HIL] End episode")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
logger.info("[HIL] Re-record episode")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
logger.info("[HIL] ESC - Stop recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
except Exception as e:
|
||||
logger.info(f"Key error: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
return listener, events
|
||||
|
||||
|
||||
def make_identity_processors():
|
||||
"""Create identity processors for recording."""
|
||||
teleop_proc = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
obs_proc = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[IdentityProcessorStep()],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
return teleop_proc, obs_proc
|
||||
|
||||
|
||||
def reset_loop(robot: Robot, teleop: Teleoperator, events: dict, fps: int):
|
||||
"""Reset period where human repositions environment."""
|
||||
logger.info("[HIL] RESET")
|
||||
|
||||
events["in_reset"] = True
|
||||
events["start_next_episode"] = False
|
||||
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
||||
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
logger.info("Press any key to enable teleoperation")
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
precise_sleep(0.05)
|
||||
|
||||
if events["stop_recording"]:
|
||||
return
|
||||
|
||||
events["start_next_episode"] = False
|
||||
teleop_disable_torque(teleop)
|
||||
logger.info("Teleop enabled - press any key to start episode")
|
||||
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
loop_start = time.perf_counter()
|
||||
action = teleop.get_action()
|
||||
robot.send_action(action)
|
||||
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
|
||||
|
||||
events["in_reset"] = False
|
||||
events["start_next_episode"] = False
|
||||
events["exit_early"] = False
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
events["resume_policy"] = False
|
||||
|
||||
|
||||
def print_controls(rtc: bool = False):
|
||||
"""Print control instructions."""
|
||||
mode = "Human-in-the-Loop Data Collection" + (" (RTC)" if rtc else "")
|
||||
logger.info(
|
||||
"%s\n Controls:\n"
|
||||
" SPACE - Pause policy\n"
|
||||
" c - Take control\n"
|
||||
" p - Resume policy after pause/correction\n"
|
||||
" → - End episode\n"
|
||||
" ESC - Stop and push to hub",
|
||||
mode,
|
||||
)
|
||||
@@ -14,21 +14,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.common.control_utils import init_keyboard_listener, predict_action
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.policies import make_pre_post_processors
|
||||
from lerobot.policies.act import ACTPolicy
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
NUM_EPISODES = 2
|
||||
FPS = 30
|
||||
@@ -39,9 +35,6 @@ HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
||||
|
||||
|
||||
def main():
|
||||
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
|
||||
# This script provides a self-contained example for educational purposes.
|
||||
|
||||
# Create the robot configuration & robot
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
|
||||
@@ -90,67 +83,43 @@ def main():
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
control_interval = 1 / FPS
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
|
||||
# Inline evaluation loop: predict actions and send to robot
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < EPISODE_TIME_SEC:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
# Get robot observation
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
|
||||
|
||||
# Predict action using the policy
|
||||
action_tensor = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=policy.config.device,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.device.type == "cuda",
|
||||
task=TASK_DESCRIPTION,
|
||||
robot_type=robot.name,
|
||||
)
|
||||
|
||||
# Convert policy output to robot action dict
|
||||
action_values = make_robot_action(action_tensor, dataset.features)
|
||||
|
||||
# Process and send action to robot
|
||||
robot_action_to_send = robot_action_processor((action_values, obs))
|
||||
robot.send_action(robot_action_to_send)
|
||||
|
||||
# Write to dataset
|
||||
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
log_rerun_data(observation=obs_processed, action=action_values)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
sleep_time_s = control_interval - dt_s
|
||||
if sleep_time_s < 0:
|
||||
logging.warning(
|
||||
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
|
||||
)
|
||||
precise_sleep(max(sleep_time_s, 0.0))
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
log_say("Waiting for environment reset, press right arrow key when ready...")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
|
||||
@@ -45,6 +45,9 @@ def main():
|
||||
leader_arm = SO100Leader(leader_arm_config)
|
||||
keyboard = KeyboardTeleop(keyboard_config)
|
||||
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
@@ -74,10 +77,6 @@ def main():
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = (
|
||||
make_default_processors()
|
||||
)
|
||||
|
||||
print("Starting record loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
@@ -88,14 +87,14 @@ def main():
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
dataset=dataset,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -107,13 +106,13 @@ def main():
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run a trained policy on LeKiwi without recording (base rollout).
|
||||
|
||||
Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
|
||||
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
|
||||
control tick). For a CLI entry point with the same capabilities plus
|
||||
recording, upload, and human-in-the-loop variants, see ``lerobot-rollout``.
|
||||
"""
|
||||
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.robots.lekiwi import LeKiwiClientConfig
|
||||
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||
from lerobot.rollout.inference import SyncInferenceConfig
|
||||
from lerobot.rollout.strategies import BaseStrategy
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
FPS = 30
|
||||
DURATION_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
|
||||
|
||||
def main():
|
||||
init_logging()
|
||||
|
||||
# Robot: LeKiwi client — make sure lekiwi_host is already running on the robot.
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
|
||||
# Policy: load the pretrained config. ``pretrained_path`` is read downstream
|
||||
# by ``build_rollout_context`` to reload the full model.
|
||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||
policy_config.pretrained_path = HF_MODEL_ID
|
||||
|
||||
# Assemble the rollout config: base strategy (no recording) + sync inference.
|
||||
cfg = RolloutConfig(
|
||||
robot=robot_config,
|
||||
policy=policy_config,
|
||||
strategy=BaseStrategyConfig(),
|
||||
inference=SyncInferenceConfig(),
|
||||
fps=FPS,
|
||||
duration=DURATION_SEC,
|
||||
task=TASK_DESCRIPTION,
|
||||
)
|
||||
|
||||
# Graceful Ctrl-C: the strategy loop exits when shutdown_event is set.
|
||||
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||
|
||||
# Build the context (connects robot, loads policy, wires the inference strategy).
|
||||
# No custom processors here — LeKiwi runs on raw joint features.
|
||||
ctx = build_rollout_context(cfg, signal_handler.shutdown_event)
|
||||
|
||||
strategy = BaseStrategy(cfg.strategy)
|
||||
try:
|
||||
strategy.setup(ctx)
|
||||
strategy.run(ctx)
|
||||
finally:
|
||||
strategy.teardown(ctx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -14,17 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.common.control_utils import init_keyboard_listener, predict_action
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies import make_pre_post_processors
|
||||
from lerobot.policies.act import ACTPolicy
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import (
|
||||
RobotProcessorPipeline,
|
||||
make_default_teleop_action_processor,
|
||||
@@ -38,12 +34,11 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.feature_utils import combine_feature_dicts
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
@@ -54,9 +49,6 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
|
||||
def main():
|
||||
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
|
||||
# This script provides a self-contained example for educational purposes.
|
||||
|
||||
# Create the robot configuration & robot
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
@@ -151,67 +143,43 @@ def main():
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
control_interval = 1 / FPS
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Inline evaluation loop: predict actions and send to robot
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < EPISODE_TIME_SEC:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
# Get robot observation
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_joints_to_ee_pose_processor(obs)
|
||||
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
|
||||
|
||||
# Predict action using the policy
|
||||
action_tensor = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=policy.config.device,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.device.type == "cuda",
|
||||
task=TASK_DESCRIPTION,
|
||||
robot_type=robot.name,
|
||||
)
|
||||
|
||||
# Convert policy output to robot action dict
|
||||
action_values = make_robot_action(action_tensor, dataset.features)
|
||||
|
||||
# Process and send action to robot (EE -> joints via IK)
|
||||
robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
|
||||
robot.send_action(robot_action_to_send)
|
||||
|
||||
# Write to dataset
|
||||
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
log_rerun_data(observation=obs_processed, action=action_values)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
sleep_time_s = control_interval - dt_s
|
||||
if sleep_time_s < 0:
|
||||
logging.warning(
|
||||
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
|
||||
)
|
||||
precise_sleep(max(sleep_time_s, 0.0))
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
log_say("Waiting for environment reset, press right arrow key when ready...")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
@@ -222,6 +190,7 @@ def main():
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
|
||||
@@ -65,15 +65,14 @@ def main():
|
||||
robot = SO100Follower(robot_config)
|
||||
phone = Phone(teleop_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert phone action to EE action (with gripper velocity mapped to joint).
|
||||
# Build pipeline to convert phone action to EE action
|
||||
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[
|
||||
tuple[RobotAction, RobotObservation], RobotAction
|
||||
](
|
||||
@@ -95,7 +94,7 @@ def main():
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action (IK).
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
@@ -108,7 +107,7 @@ def main():
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to EE observation (FK).
|
||||
# Build pipeline to convert joint observation to EE observation
|
||||
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
@@ -119,12 +118,13 @@ def main():
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset, deriving features from the pipelines so the on-disk schema
|
||||
# matches exactly what the pipelines produce at runtime.
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=phone_to_robot_ee_pose_processor,
|
||||
initial_features=create_initial_features(action=phone.action_features),
|
||||
@@ -163,14 +163,14 @@ def main():
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
teleop=phone,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -182,13 +182,13 @@ def main():
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run a trained EE-space policy on SO100 (phone-trained) without recording.
|
||||
|
||||
Mirrors ``examples/so100_to_so100_EE/rollout.py`` — the model was trained
|
||||
with phone teleoperation in EE space, so at deployment we only need the
|
||||
joint↔EE conversion on the robot side; the phone is not used.
|
||||
|
||||
Uses :class:`BaseStrategy` (no recording) + :class:`SyncInferenceConfig`
|
||||
(inline policy call). For recording during rollout, switch to Sentry,
|
||||
Highlight, or DAgger via ``lerobot-rollout --strategy.type=...``.
|
||||
"""
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import (
|
||||
RobotProcessorPipeline,
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||
from lerobot.rollout.inference import SyncInferenceConfig
|
||||
from lerobot.rollout.strategies import BaseStrategy
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
FPS = 30
|
||||
DURATION_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
|
||||
|
||||
def main():
|
||||
init_logging()
|
||||
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
# Peek at motor names once to build the kinematic solver.
|
||||
temp_robot = SO100Follower(robot_config)
|
||||
motor_names = list(temp_robot.bus.motors.keys())
|
||||
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=motor_names,
|
||||
)
|
||||
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=motor_names,
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||
policy_config.pretrained_path = HF_MODEL_ID
|
||||
|
||||
cfg = RolloutConfig(
|
||||
robot=robot_config,
|
||||
policy=policy_config,
|
||||
strategy=BaseStrategyConfig(),
|
||||
inference=SyncInferenceConfig(),
|
||||
fps=FPS,
|
||||
duration=DURATION_SEC,
|
||||
task=TASK_DESCRIPTION,
|
||||
)
|
||||
|
||||
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||
|
||||
ctx = build_rollout_context(
|
||||
cfg,
|
||||
signal_handler.shutdown_event,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
strategy = BaseStrategy(cfg.strategy)
|
||||
try:
|
||||
strategy.setup(ctx)
|
||||
strategy.run(ctx)
|
||||
finally:
|
||||
strategy.teardown(ctx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
673
examples/rtc/eval_with_real_robot.py
Normal file
673
examples/rtc/eval_with_real_robot.py
Normal file
@@ -0,0 +1,673 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
|
||||
|
||||
This script demonstrates:
|
||||
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
|
||||
2. Consuming actions from the policy while the robot executes
|
||||
3. Periodically requesting new action chunks in the background using threads
|
||||
4. Managing action buffers and timing for real-time operation
|
||||
|
||||
For simulation environments, see eval_with_simulation.py
|
||||
|
||||
Usage:
|
||||
# Run RTC with Real robot with RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
|
||||
# Run RTC with Real robot without RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=false \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
|
||||
# Run RTC with Real robot with pi0.5 policy
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=<USER>/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
|
||||
# Run RTC with bi_openarm_follower (dual-arm OpenArms) and pi0.5 policy
|
||||
python examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=lerobot-data-collection/folding_final \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \
|
||||
--robot.left_arm_config.port=can0 \
|
||||
--robot.left_arm_config.side=left \
|
||||
--robot.left_arm_config.can_interface=socketcan \
|
||||
--robot.left_arm_config.disable_torque_on_disconnect=true \
|
||||
--robot.left_arm_config.max_relative_target=8.0 \
|
||||
--robot.right_arm_config.port=can1 \
|
||||
--robot.right_arm_config.side=right \
|
||||
--robot.right_arm_config.can_interface=socketcan \
|
||||
--robot.right_arm_config.disable_torque_on_disconnect=true \
|
||||
--robot.right_arm_config.max_relative_target=8.0 \
|
||||
--task="Fold the T-shirt properly" \
|
||||
--fps=30 \
|
||||
--duration=2000 \
|
||||
--interpolation_multiplier=3 \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--rtc.max_guidance_weight=5.0 \
|
||||
--rtc.prefix_attention_schedule=LINEAR \
|
||||
--device=cuda
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Event, Lock, Thread
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401
|
||||
from lerobot.configs import PreTrainedConfig, RTCAttentionSchedule, parser
|
||||
from lerobot.policies import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
|
||||
from lerobot.processor import (
|
||||
NormalizerProcessorStep,
|
||||
RelativeActionsProcessorStep,
|
||||
TransitionKey,
|
||||
create_transition,
|
||||
make_default_robot_action_processor,
|
||||
make_default_robot_observation_processor,
|
||||
to_relative_actions,
|
||||
)
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
so_follower,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.robots.utils import make_robot_from_config
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RobotWrapper:
|
||||
def __init__(self, robot: Robot):
|
||||
self.robot = robot
|
||||
self.lock = Lock()
|
||||
|
||||
def get_observation(self) -> dict[str, Tensor]:
|
||||
with self.lock:
|
||||
return self.robot.get_observation()
|
||||
|
||||
def send_action(self, action: Tensor):
|
||||
with self.lock:
|
||||
self.robot.send_action(action)
|
||||
|
||||
def observation_features(self) -> list[str]:
|
||||
with self.lock:
|
||||
return self.robot.observation_features
|
||||
|
||||
def action_features(self) -> list[str]:
|
||||
with self.lock:
|
||||
return self.robot.action_features
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCDemoConfig(HubMixin):
|
||||
"""Configuration for RTC demo with action chunking policies and real robots."""
|
||||
|
||||
# Policy configuration
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
# Robot configuration
|
||||
robot: RobotConfig | None = None
|
||||
|
||||
# RTC configuration
|
||||
rtc: RTCConfig = field(
|
||||
default_factory=lambda: RTCConfig(
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=1.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
)
|
||||
)
|
||||
|
||||
# Demo parameters
|
||||
duration: float = 30.0 # Duration to run the demo (seconds)
|
||||
fps: float = 10.0 # Action execution frequency (Hz)
|
||||
interpolation_multiplier: int = 1 # Control rate multiplier (1=off, 2=2x, 3=3x)
|
||||
|
||||
# Compute device
|
||||
device: str | None = None # Device to run on (cuda, cpu, auto)
|
||||
|
||||
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
|
||||
# It should be higher than inference delay + execution horizon.
|
||||
action_queue_size_to_get_new_actions: int = 30
|
||||
|
||||
# Task to execute
|
||||
task: str = field(default="", metadata={"help": "Task to execute"})
|
||||
|
||||
# Torch compile configuration
|
||||
use_torch_compile: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
|
||||
)
|
||||
|
||||
torch_compile_backend: str = field(
|
||||
default="inductor",
|
||||
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
|
||||
)
|
||||
|
||||
torch_compile_mode: str = field(
|
||||
default="default",
|
||||
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
|
||||
)
|
||||
|
||||
torch_compile_disable_cudagraphs: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
|
||||
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
else:
|
||||
raise ValueError("Policy path is required")
|
||||
|
||||
# Validate that robot configuration is provided
|
||||
if self.robot is None:
|
||||
raise ValueError("Robot configuration must be provided")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
|
||||
def is_image_key(k: str) -> bool:
|
||||
return k.startswith(OBS_IMAGES)
|
||||
|
||||
|
||||
def _reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute: Tensor,
|
||||
current_state: Tensor,
|
||||
relative_step: RelativeActionsProcessorStep,
|
||||
normalizer_step: NormalizerProcessorStep | None,
|
||||
policy_device: torch.device | str,
|
||||
) -> Tensor:
|
||||
"""Convert absolute leftovers into model-space for relative-action RTC policies.
|
||||
|
||||
When a policy uses relative actions, the RTC prefix (leftover actions from
|
||||
the previous chunk) is stored in absolute space. Before feeding it back to
|
||||
the policy we need to re-express it relative to the *current* robot state
|
||||
and then re-normalize.
|
||||
"""
|
||||
state = current_state.detach().cpu()
|
||||
if state.dim() == 1:
|
||||
state = state.unsqueeze(0)
|
||||
|
||||
action_cpu = prev_actions_absolute.detach().cpu()
|
||||
mask = relative_step._build_mask(action_cpu.shape[-1])
|
||||
relative_actions = to_relative_actions(action_cpu, state, mask)
|
||||
|
||||
transition = create_transition(action=relative_actions)
|
||||
if normalizer_step is not None:
|
||||
transition = normalizer_step(transition)
|
||||
|
||||
return transition[TransitionKey.ACTION].to(policy_device)
|
||||
|
||||
|
||||
def get_actions(
|
||||
policy,
|
||||
robot: RobotWrapper,
|
||||
robot_observation_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RTCDemoConfig,
|
||||
):
|
||||
"""Thread function to request action chunks from the policy.
|
||||
|
||||
Args:
|
||||
policy: The policy instance (SmolVLA, Pi0, etc.)
|
||||
robot: The robot instance for getting observations
|
||||
robot_observation_processor: Processor for raw robot observations
|
||||
action_queue: Queue to put new action chunks
|
||||
shutdown_event: Event to signal shutdown
|
||||
cfg: Demo configuration
|
||||
"""
|
||||
try:
|
||||
logger.info("[GET_ACTIONS] Starting get actions thread")
|
||||
|
||||
latency_tracker = LatencyTracker() # Track latency of action chunks
|
||||
fps = cfg.fps
|
||||
time_per_chunk = 1.0 / fps
|
||||
|
||||
# Only keep .pos joints + camera streams if the policy was trained on positions,
|
||||
# not the full pos/vel/torque state the robot exposes.
|
||||
observation_features_hw = {
|
||||
key: value
|
||||
for key, value in robot.observation_features().items()
|
||||
if key.endswith(".pos") or isinstance(value, tuple)
|
||||
}
|
||||
|
||||
dataset_features = hw_to_dataset_features(observation_features_hw, "observation")
|
||||
policy_device = policy.config.device
|
||||
|
||||
# Load preprocessor and postprocessor from pretrained files
|
||||
# The stats are embedded in the processor .safetensors files
|
||||
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=None, # Will load from pretrained processor files
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
|
||||
|
||||
relative_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
|
||||
None,
|
||||
)
|
||||
normalizer_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
|
||||
None,
|
||||
)
|
||||
if relative_step is not None:
|
||||
if relative_step.action_names is None:
|
||||
cfg_names = getattr(cfg.policy, "action_feature_names", None)
|
||||
if cfg_names:
|
||||
relative_step.action_names = list(cfg_names)
|
||||
else:
|
||||
relative_step.action_names = [
|
||||
k for k in robot.robot.action_features if k.endswith(".pos")
|
||||
]
|
||||
logger.info("[GET_ACTIONS] Relative actions enabled: will re-anchor RTC prefix")
|
||||
|
||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
||||
|
||||
if not cfg.rtc.enabled:
|
||||
get_actions_threshold = 0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if action_queue.qsize() <= get_actions_threshold:
|
||||
current_time = time.perf_counter()
|
||||
action_index_before_inference = action_queue.get_action_index()
|
||||
prev_actions = action_queue.get_left_over()
|
||||
|
||||
inference_latency = latency_tracker.max()
|
||||
inference_delay = math.ceil(inference_latency / time_per_chunk)
|
||||
|
||||
obs = robot.get_observation()
|
||||
|
||||
# Apply robot observation processor
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
|
||||
obs_with_policy_features = build_dataset_frame(
|
||||
dataset_features, obs_processed, prefix="observation"
|
||||
)
|
||||
|
||||
for name in obs_with_policy_features:
|
||||
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
|
||||
if "image" in name:
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].type(torch.float32) / 255
|
||||
)
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
|
||||
)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
|
||||
|
||||
obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string!
|
||||
obs_with_policy_features["robot_type"] = (
|
||||
robot.robot.name if hasattr(robot.robot, "name") else ""
|
||||
)
|
||||
|
||||
preproceseded_obs = preprocessor(obs_with_policy_features)
|
||||
|
||||
# Re-anchor leftover actions for relative-action policies.
|
||||
# We need the *postprocessed* (absolute) leftover, not the original
|
||||
# (normalized/relative) one that get_left_over() returns.
|
||||
if (
|
||||
prev_actions is not None
|
||||
and relative_step is not None
|
||||
and OBS_STATE in obs_with_policy_features
|
||||
):
|
||||
with action_queue.lock:
|
||||
if action_queue.queue is not None:
|
||||
prev_actions_abs = action_queue.queue[action_queue.last_index :].clone()
|
||||
else:
|
||||
prev_actions_abs = None
|
||||
if prev_actions_abs is not None and prev_actions_abs.numel() > 0:
|
||||
prev_actions = _reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=prev_actions_abs,
|
||||
current_state=obs_with_policy_features[OBS_STATE],
|
||||
relative_step=relative_step,
|
||||
normalizer_step=normalizer_step,
|
||||
policy_device=policy_device,
|
||||
)
|
||||
|
||||
# Generate actions WITH RTC
|
||||
actions = policy.predict_action_chunk(
|
||||
preproceseded_obs,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
|
||||
# Store original actions (before postprocessing) for RTC
|
||||
original_actions = actions.squeeze(0).clone()
|
||||
|
||||
postprocessed_actions = postprocessor(actions)
|
||||
|
||||
postprocessed_actions = postprocessed_actions.squeeze(0)
|
||||
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
latency_tracker.add(new_latency)
|
||||
|
||||
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
|
||||
logger.warning(
|
||||
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
|
||||
)
|
||||
|
||||
action_queue.merge(
|
||||
original_actions, postprocessed_actions, new_delay, action_index_before_inference
|
||||
)
|
||||
else:
|
||||
# Small sleep to prevent busy waiting
|
||||
time.sleep(0.1)
|
||||
|
||||
logger.info("[GET_ACTIONS] get actions thread shutting down")
|
||||
except Exception as e:
|
||||
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def actor_control(
|
||||
robot: RobotWrapper,
|
||||
robot_action_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RTCDemoConfig,
|
||||
):
|
||||
"""Thread function to execute actions on the robot.
|
||||
|
||||
Args:
|
||||
robot: The robot instance
|
||||
action_queue: Queue to get actions from
|
||||
shutdown_event: Event to signal shutdown
|
||||
cfg: Demo configuration
|
||||
"""
|
||||
try:
|
||||
logger.info("[ACTOR] Starting actor thread")
|
||||
|
||||
action_keys = [k for k in robot.action_features() if k.endswith(".pos")]
|
||||
|
||||
action_count = 0
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
action_interval = interpolator.get_control_interval(cfg.fps)
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if interpolator.needs_new_action():
|
||||
new_action = action_queue.get()
|
||||
if new_action is not None:
|
||||
interpolator.add(new_action.cpu())
|
||||
|
||||
action = interpolator.get()
|
||||
if action is not None:
|
||||
action = action.cpu()
|
||||
action_dict = {key: action[i].item() for i, key in enumerate(action_keys)}
|
||||
action_processed = robot_action_processor((action_dict, None))
|
||||
robot.send_action(action_processed)
|
||||
action_count += 1
|
||||
|
||||
dt_s = time.perf_counter() - start_time
|
||||
time.sleep(max(0, (action_interval - dt_s) - 0.001))
|
||||
|
||||
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
|
||||
except Exception as e:
|
||||
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
|
||||
"""Apply torch.compile to the policy's predict_action_chunk method.
|
||||
|
||||
Args:
|
||||
policy: Policy instance to compile
|
||||
cfg: Configuration containing torch compile settings
|
||||
|
||||
Returns:
|
||||
Policy with compiled predict_action_chunk method
|
||||
"""
|
||||
|
||||
# PI models handle their own compilation
|
||||
if policy.type == "pi05" or policy.type == "pi0":
|
||||
return policy
|
||||
|
||||
try:
|
||||
# Check if torch.compile is available (PyTorch 2.0+)
|
||||
if not hasattr(torch, "compile"):
|
||||
logger.warning(
|
||||
f"torch.compile is not available. Requires PyTorch 2.0+. "
|
||||
f"Current version: {torch.__version__}. Skipping compilation."
|
||||
)
|
||||
return policy
|
||||
|
||||
logger.info("Applying torch.compile to predict_action_chunk...")
|
||||
logger.info(f" Backend: {cfg.torch_compile_backend}")
|
||||
logger.info(f" Mode: {cfg.torch_compile_mode}")
|
||||
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
|
||||
|
||||
# Compile the predict_action_chunk method
|
||||
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
}
|
||||
|
||||
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
|
||||
if cfg.torch_compile_disable_cudagraphs:
|
||||
compile_kwargs["options"] = {"triton.cudagraphs": False}
|
||||
|
||||
original_method = policy.predict_action_chunk
|
||||
compiled_method = torch.compile(original_method, **compile_kwargs)
|
||||
policy.predict_action_chunk = compiled_method
|
||||
logger.info("✓ Successfully compiled predict_action_chunk")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to apply torch.compile: {e}")
|
||||
logger.warning("Continuing without torch.compile")
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def demo_cli(cfg: RTCDemoConfig):
|
||||
"""Main entry point for RTC demo with draccus configuration."""
|
||||
|
||||
# Initialize logging
|
||||
init_logging()
|
||||
|
||||
logger.info(f"Using device: {cfg.device}")
|
||||
|
||||
# Setup signal handler for graceful shutdown
|
||||
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
|
||||
shutdown_event = signal_handler.shutdown_event
|
||||
|
||||
policy = None
|
||||
robot = None
|
||||
get_actions_thread = None
|
||||
actor_thread = None
|
||||
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
|
||||
# Load config and set compile_model for pi0/pi05 models
|
||||
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
|
||||
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
|
||||
config.compile_model = cfg.use_torch_compile
|
||||
|
||||
if config.use_peft:
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
peft_pretrained_path = cfg.policy.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
||||
|
||||
policy = policy_class.from_pretrained(
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
|
||||
else:
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||
|
||||
# Turn on RTC
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
|
||||
# Init RTC processort, as by default if RTC disabled in the config
|
||||
# The processor won't be created
|
||||
policy.init_rtc_processor()
|
||||
|
||||
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
|
||||
# Apply torch.compile to predict_action_chunk method if enabled
|
||||
if cfg.use_torch_compile:
|
||||
policy = _apply_torch_compile(policy, cfg)
|
||||
|
||||
# Create robot
|
||||
logger.info(f"Initializing robot: {cfg.robot.type}")
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
robot.connect()
|
||||
robot_wrapper = RobotWrapper(robot)
|
||||
|
||||
# Create robot observation processor
|
||||
robot_observation_processor = make_default_robot_observation_processor()
|
||||
robot_action_processor = make_default_robot_action_processor()
|
||||
|
||||
# Create action queue for communication between threads
|
||||
action_queue = ActionQueue(cfg.rtc)
|
||||
|
||||
# Start chunk requester thread
|
||||
get_actions_thread = Thread(
|
||||
target=get_actions,
|
||||
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
|
||||
daemon=True,
|
||||
name="GetActions",
|
||||
)
|
||||
get_actions_thread.start()
|
||||
logger.info("Started get actions thread")
|
||||
|
||||
# Start action executor thread
|
||||
actor_thread = Thread(
|
||||
target=actor_control,
|
||||
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
|
||||
daemon=True,
|
||||
name="Actor",
|
||||
)
|
||||
actor_thread.start()
|
||||
logger.info("Started actor thread")
|
||||
|
||||
logger.info("Started stop by duration thread")
|
||||
|
||||
# Main thread monitors for duration or shutdown
|
||||
logger.info(f"Running demo for {cfg.duration} seconds...")
|
||||
start_time = time.time()
|
||||
|
||||
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
|
||||
time.sleep(10)
|
||||
|
||||
# Log queue status periodically
|
||||
if int(time.time() - start_time) % 5 == 0:
|
||||
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
|
||||
|
||||
if time.time() - start_time > cfg.duration:
|
||||
break
|
||||
|
||||
logger.info("Demo duration reached or shutdown requested")
|
||||
|
||||
# Signal shutdown
|
||||
shutdown_event.set()
|
||||
|
||||
# Wait for threads to finish
|
||||
if get_actions_thread and get_actions_thread.is_alive():
|
||||
logger.info("Waiting for chunk requester thread to finish...")
|
||||
get_actions_thread.join()
|
||||
|
||||
if actor_thread and actor_thread.is_alive():
|
||||
logger.info("Waiting for action executor thread to finish...")
|
||||
actor_thread.join()
|
||||
|
||||
# Cleanup robot
|
||||
if robot:
|
||||
robot.disconnect()
|
||||
logger.info("Robot disconnected")
|
||||
|
||||
logger.info("Cleanup completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo_cli()
|
||||
logging.info("RTC demo finished")
|
||||
@@ -14,17 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.common.control_utils import init_keyboard_listener, predict_action
|
||||
from lerobot.common.control_utils import init_keyboard_listener
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies import make_pre_post_processors
|
||||
from lerobot.policies.act import ACTPolicy
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import (
|
||||
RobotProcessorPipeline,
|
||||
make_default_teleop_action_processor,
|
||||
@@ -38,12 +34,11 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.feature_utils import combine_feature_dicts
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
NUM_EPISODES = 5
|
||||
FPS = 30
|
||||
@@ -54,9 +49,6 @@ HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
|
||||
def main():
|
||||
# NOTE: For production policy deployment, use `lerobot-rollout` CLI instead.
|
||||
# This script provides a self-contained example for educational purposes.
|
||||
|
||||
# Create the robot configuration & robot
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
@@ -151,67 +143,43 @@ def main():
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
control_interval = 1 / FPS
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Inline evaluation loop: predict actions and send to robot
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < EPISODE_TIME_SEC:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
# Get robot observation
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_joints_to_ee_pose_processor(obs)
|
||||
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
|
||||
|
||||
# Predict action using the policy
|
||||
action_tensor = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=policy.config.device,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.device.type == "cuda",
|
||||
task=TASK_DESCRIPTION,
|
||||
robot_type=robot.name,
|
||||
)
|
||||
|
||||
# Convert policy output to robot action dict
|
||||
action_values = make_robot_action(action_tensor, dataset.features)
|
||||
|
||||
# Process and send action to robot (EE -> joints via IK)
|
||||
robot_action_to_send = robot_ee_to_joints_processor((action_values, obs))
|
||||
robot.send_action(robot_action_to_send)
|
||||
|
||||
# Write to dataset
|
||||
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||
frame = {**observation_frame, **action_frame, "task": TASK_DESCRIPTION}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
log_rerun_data(observation=obs_processed, action=action_values)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
sleep_time_s = control_interval - dt_s
|
||||
if sleep_time_s < 0:
|
||||
logging.warning(
|
||||
f"Evaluate loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({FPS} Hz)."
|
||||
)
|
||||
precise_sleep(max(sleep_time_s, 0.0))
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
log_say("Waiting for environment reset, press right arrow key when ready...")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
@@ -222,6 +190,7 @@ def main():
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
|
||||
@@ -62,20 +62,21 @@ def main():
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert follower joints to EE observation.
|
||||
# Build pipeline to convert follower joints to EE observation
|
||||
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
@@ -86,7 +87,7 @@ def main():
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Build pipeline to convert leader joints to EE action.
|
||||
# Build pipeline to convert leader joints to EE action
|
||||
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
@@ -97,9 +98,9 @@ def main():
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to follower joints (with safety bounds).
|
||||
# Build pipeline to convert EE action to follower joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
@@ -114,12 +115,13 @@ def main():
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Create the dataset, deriving features from the pipelines so the on-disk schema
|
||||
# matches exactly what the pipelines produce at runtime.
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=leader_joints_to_ee,
|
||||
initial_features=create_initial_features(action=leader.action_features),
|
||||
@@ -142,7 +144,7 @@ def main():
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording_so100_ee")
|
||||
init_rerun(session_name="recording_phone")
|
||||
|
||||
try:
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
@@ -158,14 +160,14 @@ def main():
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
teleop=leader,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
@@ -177,13 +179,13 @@ def main():
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
teleop=leader,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
|
||||
@@ -1,134 +0,0 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run a trained EE-space policy on SO100 without recording (base rollout).
|
||||
|
||||
Uses the rollout engine's :class:`BaseStrategy` (autonomous execution,
|
||||
no dataset) with :class:`SyncInferenceConfig` (inline policy call per
|
||||
control tick). The custom observation/action processors convert between
|
||||
joint space (robot hardware) and end-effector space (policy I/O) via
|
||||
forward/inverse kinematics.
|
||||
"""
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import (
|
||||
RobotProcessorPipeline,
|
||||
observation_to_transition,
|
||||
robot_action_observation_to_transition,
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
from lerobot.rollout import BaseStrategyConfig, RolloutConfig, build_rollout_context
|
||||
from lerobot.rollout.inference import SyncInferenceConfig
|
||||
from lerobot.rollout.strategies import BaseStrategy
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
FPS = 30
|
||||
DURATION_SEC = 60
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
|
||||
|
||||
def main():
|
||||
init_logging()
|
||||
|
||||
# Robot configuration — the rollout engine will connect it inside build_rollout_context.
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
# Kinematic solver: we need the motor-name list, so peek at the robot once.
|
||||
# (The rollout engine owns the connected instance; we only use this for introspection.)
|
||||
temp_robot = SO100Follower(robot_config)
|
||||
motor_names = list(temp_robot.bus.motors.keys())
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
|
||||
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=motor_names,
|
||||
)
|
||||
|
||||
# Joint-space observation → EE-space observation (consumed by the policy).
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# EE-space action (produced by the policy) → joint-space action (sent to robot).
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=motor_names,
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Policy config (full model is loaded inside build_rollout_context).
|
||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||
policy_config.pretrained_path = HF_MODEL_ID
|
||||
|
||||
cfg = RolloutConfig(
|
||||
robot=robot_config,
|
||||
policy=policy_config,
|
||||
strategy=BaseStrategyConfig(),
|
||||
inference=SyncInferenceConfig(),
|
||||
fps=FPS,
|
||||
duration=DURATION_SEC,
|
||||
task=TASK_DESCRIPTION,
|
||||
)
|
||||
|
||||
signal_handler = ProcessSignalHandler(use_threads=True)
|
||||
|
||||
# Pass the EE kinematic processors via kwargs; the defaults (identity) would
|
||||
# otherwise skip the joint↔EE conversion and the policy would receive the
|
||||
# wrong observation/action space.
|
||||
ctx = build_rollout_context(
|
||||
cfg,
|
||||
signal_handler.shutdown_event,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
strategy = BaseStrategy(cfg.strategy)
|
||||
try:
|
||||
strategy.setup(ctx)
|
||||
strategy.run(ctx)
|
||||
finally:
|
||||
strategy.teardown(ctx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -212,6 +212,20 @@ aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
|
||||
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
|
||||
# release), so any `vlabench>=X` pip spec is unresolvable. Install it
|
||||
# manually alongside MuJoCo / dm-control — see docs/source/vlabench.mdx
|
||||
# for the recipe.
|
||||
# NOTE: robomme is NOT a pyproject extra — mani-skill hard-pins numpy<2
|
||||
# which conflicts with lerobot's numpy>=2 base pin, so the two trees can't
|
||||
# resolve into a single env. Install it only in the RoboMME Docker image
|
||||
# via `uv pip install --override` (see docker/Dockerfile.benchmark.robomme).
|
||||
# NOTE: robocasa is NOT exposed as a `lerobot` extra. Its setup.py pins
|
||||
# `lerobot==0.3.3` in install_requires, which cyclically shadows our own
|
||||
# workspace `lerobot` and makes the graph unsolvable under any resolver
|
||||
# (uv, pip). Install it manually alongside robosuite — see
|
||||
# docs/source/robocasa.mdx for the recipe.
|
||||
|
||||
# All
|
||||
all = [
|
||||
@@ -275,7 +289,6 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
lerobot-rollout="lerobot.scripts.lerobot_rollout:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.package-data]
|
||||
|
||||
@@ -31,9 +31,23 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# LIBERO-plus derives task.language by space-joining the perturbation-variant
|
||||
# filename (grab_language_from_filename in libero/libero/benchmark/__init__.py),
|
||||
# so non-_language_ variants inherit a trailing metadata blob like
|
||||
# "view 0 0 100 0 0 initstate 0 noise 45" or "add 16". Strip those tokens so
|
||||
# the description matches the base instruction used in the training dataset.
|
||||
_LIBERO_PERTURBATION_TAIL_RE = re.compile(
|
||||
r"(?:\s(?:view|initstate|noise|add|tb|table|light|level)(?:\s\d+)+)+$"
|
||||
)
|
||||
|
||||
|
||||
def _strip_libero_perturbation_tail(instruction: str) -> str:
|
||||
return _LIBERO_PERTURBATION_TAIL_RE.sub("", instruction).strip()
|
||||
|
||||
|
||||
def _libero_descriptions(task_suite: str) -> dict[str, str]:
|
||||
from libero.libero import benchmark # type: ignore[import-untyped]
|
||||
@@ -47,7 +61,10 @@ def _libero_descriptions(task_suite: str) -> dict[str, str]:
|
||||
)
|
||||
return {}
|
||||
suite = suite_dict[task_suite]()
|
||||
return {f"{task_suite}_{i}": suite.get_task(i).language for i in range(suite.n_tasks)}
|
||||
return {
|
||||
f"{task_suite}_{i}": _strip_libero_perturbation_tail(suite.get_task(i).language)
|
||||
for i in range(suite.n_tasks)
|
||||
}
|
||||
|
||||
|
||||
def _metaworld_descriptions(task_name: str) -> dict[str, str]:
|
||||
@@ -57,19 +74,120 @@ def _metaworld_descriptions(task_name: str) -> dict[str, str]:
|
||||
return {f"{task_name}_0": label}
|
||||
|
||||
|
||||
def _robotwin_descriptions(task_names: str) -> dict[str, str]:
|
||||
"""Return descriptions for each requested RoboTwin task. Reads
|
||||
`description/task_instruction/<task>.json` from the RoboTwin clone
|
||||
(cwd is /opt/robotwin in CI). Falls back to the task name if missing."""
|
||||
out: dict[str, str] = {}
|
||||
root = Path("description/task_instruction")
|
||||
for name in (t.strip() for t in task_names.split(",") if t.strip()):
|
||||
desc_file = root / f"{name}.json"
|
||||
desc = name.replace("_", " ")
|
||||
if desc_file.is_file():
|
||||
data = json.loads(desc_file.read_text())
|
||||
full = data.get("full_description") or desc
|
||||
# Strip the schema placeholders ({A}, {a}) — keep the sentence readable.
|
||||
desc = full.replace("<", "").replace(">", "")
|
||||
out[f"{name}_0"] = desc
|
||||
return out
|
||||
|
||||
|
||||
def _robocasa_descriptions(task_spec: str) -> dict[str, str]:
|
||||
"""For each task in the comma-separated list, emit a cleaned-name label.
|
||||
|
||||
RoboCasa episodes carry their language instruction in the env's
|
||||
`ep_meta['lang']`, populated per reset. Pulling it requires spinning
|
||||
up the full kitchen env per task (~seconds each); we use the task
|
||||
name as the key here and let the eval's episode info carry the
|
||||
actual instruction.
|
||||
"""
|
||||
out: dict[str, str] = {}
|
||||
for task in (t.strip() for t in task_spec.split(",") if t.strip()):
|
||||
# Split CamelCase into words: "CloseFridge" → "close fridge".
|
||||
label = "".join(f" {c.lower()}" if c.isupper() else c for c in task).strip()
|
||||
out[f"{task}_0"] = label or task
|
||||
return out
|
||||
|
||||
|
||||
_ROBOMME_DESCRIPTIONS = {
|
||||
"BinFill": "Fill the target bin with the correct number of cubes",
|
||||
"PickXtimes": "Pick the indicated cube the specified number of times",
|
||||
"SwingXtimes": "Swing the object the specified number of times",
|
||||
"StopCube": "Grasp and stop the moving cube",
|
||||
"VideoUnmask": "Pick the cube shown in the reference video",
|
||||
"VideoUnmaskSwap": "Pick the cube matching the reference video after a swap",
|
||||
"ButtonUnmask": "Press the button indicated by the reference",
|
||||
"ButtonUnmaskSwap": "Press the correct button after objects are swapped",
|
||||
"PickHighlight": "Pick the highlighted cube",
|
||||
"VideoRepick": "Repick the cube shown in the reference video",
|
||||
"VideoPlaceButton": "Place the cube on the button shown in the video",
|
||||
"VideoPlaceOrder": "Place cubes in the order shown in the video",
|
||||
"MoveCube": "Move the cube to the target location",
|
||||
"InsertPeg": "Insert the peg into the target hole",
|
||||
"PatternLock": "Unlock the pattern by pressing buttons in sequence",
|
||||
"RouteStick": "Route the stick through the required waypoints",
|
||||
}
|
||||
|
||||
|
||||
def _robomme_descriptions(task_names: str, task_ids: list[int] | None = None) -> dict[str, str]:
|
||||
"""Return descriptions for each requested RoboMME task. Keys match the
|
||||
video filename pattern `<task>_<task_id>` used by the eval script."""
|
||||
if task_ids is None:
|
||||
task_ids = [0]
|
||||
out: dict[str, str] = {}
|
||||
for name in (t.strip() for t in task_names.split(",") if t.strip()):
|
||||
desc = _ROBOMME_DESCRIPTIONS.get(name, name)
|
||||
for tid in task_ids:
|
||||
out[f"{name}_{tid}"] = desc
|
||||
return out
|
||||
|
||||
|
||||
def _vlabench_descriptions(task_spec: str) -> dict[str, str]:
|
||||
"""For each task in the comma-separated list, emit a cleaned-name label.
|
||||
|
||||
VLABench tasks carry language instructions on their dm_control task
|
||||
object, but pulling them requires loading the full env per task
|
||||
(~seconds each). The CI smoke-eval already captures the instruction
|
||||
inside its episode info; this mapping is just enough to key
|
||||
`metrics.json` by `<task>_0`.
|
||||
"""
|
||||
out: dict[str, str] = {}
|
||||
for task in (t.strip() for t in task_spec.split(",") if t.strip()):
|
||||
out[f"{task}_0"] = task.replace("_", " ").strip()
|
||||
return out
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--env", required=True, help="Environment family (libero, metaworld, ...)")
|
||||
parser.add_argument("--task", required=True, help="Task/suite name (e.g. libero_spatial)")
|
||||
parser.add_argument(
|
||||
"--task-ids",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Comma-separated task IDs (e.g. '0,1,2'). Default: [0]",
|
||||
)
|
||||
parser.add_argument("--output", required=True, help="Path to write task_descriptions.json")
|
||||
args = parser.parse_args()
|
||||
|
||||
task_ids: list[int] | None = None
|
||||
if args.task_ids:
|
||||
task_ids = [int(x.strip()) for x in args.task_ids.split(",")]
|
||||
|
||||
descriptions: dict[str, str] = {}
|
||||
try:
|
||||
if args.env == "libero":
|
||||
if args.env == ("libero", "libero_plus"):
|
||||
descriptions = _libero_descriptions(args.task)
|
||||
elif args.env == "metaworld":
|
||||
descriptions = _metaworld_descriptions(args.task)
|
||||
elif args.env == "robotwin":
|
||||
descriptions = _robotwin_descriptions(args.task)
|
||||
elif args.env == "robocasa":
|
||||
descriptions = _robocasa_descriptions(args.task)
|
||||
elif args.env == "robomme":
|
||||
descriptions = _robomme_descriptions(args.task, task_ids=task_ids)
|
||||
elif args.env == "vlabench":
|
||||
descriptions = _vlabench_descriptions(args.task)
|
||||
else:
|
||||
print(
|
||||
f"[extract_task_descriptions] No description extractor for env '{args.env}'.",
|
||||
|
||||
@@ -21,7 +21,6 @@ are intentionally NOT re-exported here to avoid circular dependencies
|
||||
Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
|
||||
"""
|
||||
|
||||
from .dataset import DatasetRecordConfig
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .types import (
|
||||
@@ -40,7 +39,6 @@ __all__ = [
|
||||
"PolicyFeature",
|
||||
"RTCAttentionSchedule",
|
||||
# Config classes
|
||||
"DatasetRecordConfig",
|
||||
"DatasetConfig",
|
||||
"EvalConfig",
|
||||
"PeftConfig",
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetRecordConfig:
|
||||
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
|
||||
repo_id: str = ""
|
||||
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
|
||||
single_task: str = ""
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second.
|
||||
fps: int = 30
|
||||
# Number of seconds for data recording for each episode.
|
||||
episode_time_s: int | float = 60
|
||||
# Number of seconds for resetting the environment after each episode.
|
||||
reset_time_s: int | float = 60
|
||||
# Number of episodes to record.
|
||||
num_episodes: int = 50
|
||||
# Encode frames in the dataset into video
|
||||
video: bool = True
|
||||
# Upload dataset to Hugging Face hub.
|
||||
push_to_hub: bool = True
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
private: bool = False
|
||||
# Add tags to your dataset on the hub.
|
||||
tags: list[str] | None = None
|
||||
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
||||
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
|
||||
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
|
||||
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
|
||||
num_image_writer_processes: int = 0
|
||||
# Number of threads writing the frames as png images on disk, per camera.
|
||||
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
|
||||
# Not enough threads might cause low camera fps.
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
# Number of episodes to record before batch encoding videos
|
||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||
video_encoding_batch_size: int = 1
|
||||
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
|
||||
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
|
||||
# Use 'auto' to auto-detect the best available hardware encoder.
|
||||
vcodec: str = "libsvtav1"
|
||||
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||
streaming_encoding: bool = False
|
||||
# Maximum number of frames to buffer per camera when using streaming encoding.
|
||||
# ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
|
||||
encoder_queue_maxsize: int = 30
|
||||
# Number of threads per encoder instance. None = auto (codec default).
|
||||
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
|
||||
encoder_threads: int | None = None
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.repo_id:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
self.repo_id = f"{self.repo_id}_{timestamp}"
|
||||
@@ -16,7 +16,7 @@ import datetime as dt
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -58,6 +58,8 @@ class TrainPipelineConfig(HubMixin):
|
||||
batch_size: int = 8
|
||||
prefetch_factor: int = 4
|
||||
persistent_workers: bool = True
|
||||
profile_mode: Literal["off", "summary", "trace"] = "off"
|
||||
profile_output_dir: Path | None = None
|
||||
steps: int = 100_000
|
||||
eval_freq: int = 20_000
|
||||
log_freq: int = 200
|
||||
@@ -130,9 +132,15 @@ class TrainPipelineConfig(HubMixin):
|
||||
now = dt.datetime.now()
|
||||
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
|
||||
self.output_dir = Path("outputs/train") / train_dir
|
||||
if self.profile_mode != "off" and self.profile_output_dir is None:
|
||||
self.profile_output_dir = self.output_dir / "profiling"
|
||||
|
||||
if isinstance(self.dataset.repo_id, list):
|
||||
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
|
||||
if self.profile_mode not in {"off", "summary", "trace"}:
|
||||
raise ValueError(
|
||||
f"`profile_mode` must be one of 'off', 'summary', or 'trace', got {self.profile_mode}."
|
||||
)
|
||||
|
||||
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
||||
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
||||
|
||||
@@ -71,8 +71,8 @@ class ForwardCompatibilityError(CompatibilityError):
|
||||
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 50 # Max size per file
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
|
||||
|
||||
INFO_PATH = "meta/info.json"
|
||||
STATS_PATH = "meta/stats.json"
|
||||
|
||||
@@ -331,6 +331,7 @@ class LiberoEnv(EnvConfig):
|
||||
camera_name_mapping: dict[str, str] | None = None
|
||||
observation_height: int = 360
|
||||
observation_width: int = 360
|
||||
is_libero_plus: bool = False
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
@@ -432,6 +433,7 @@ class LiberoEnv(EnvConfig):
|
||||
control_mode=self.control_mode,
|
||||
episode_length=self.episode_length,
|
||||
camera_name_mapping=self.camera_name_mapping,
|
||||
is_libero_plus=self.is_libero_plus,
|
||||
)
|
||||
|
||||
def get_env_processors(self):
|
||||
@@ -496,6 +498,146 @@ class MetaworldEnv(EnvConfig):
|
||||
)
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("robocasa")
|
||||
@dataclass
|
||||
class RoboCasaEnv(EnvConfig):
|
||||
task: str = "CloseFridge"
|
||||
fps: int = 20
|
||||
episode_length: int = 1000
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
camera_name: str = "robot0_agentview_left,robot0_eye_in_hand,robot0_agentview_right"
|
||||
observation_height: int = 256
|
||||
observation_width: int = 256
|
||||
visualization_height: int = 512
|
||||
visualization_width: int = 512
|
||||
split: str | None = None
|
||||
# Object-mesh registries to sample from. Upstream default is
|
||||
# ("objaverse", "lightwheel"), but objaverse is ~30GB and the CI image
|
||||
# only ships the lightwheel pack. Override to include objaverse once
|
||||
# you've run `python -m robocasa.scripts.download_kitchen_assets
|
||||
# --type objaverse` locally.
|
||||
obj_registries: list[str] = field(default_factory=lambda: ["lightwheel"])
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(12,))}
|
||||
)
|
||||
features_map: dict[str, str] = field(default_factory=lambda: {ACTION: ACTION, "agent_pos": OBS_STATE})
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type not in ("pixels", "pixels_agent_pos"):
|
||||
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
|
||||
|
||||
# Preserve raw RoboCasa camera names end-to-end (e.g.
|
||||
# `observation.images.robot0_agentview_left`). This matches the
|
||||
# naming convention used by the RoboCasa datasets on the Hub, so
|
||||
# trained policies don't need a `--rename_map` at eval time.
|
||||
cams = [c.strip() for c in self.camera_name.split(",") if c.strip()]
|
||||
for cam in cams:
|
||||
self.features[f"pixels/{cam}"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(self.observation_height, self.observation_width, 3),
|
||||
)
|
||||
self.features_map[f"pixels/{cam}"] = f"{OBS_IMAGES}.{cam}"
|
||||
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(16,))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
kwargs: dict[str, Any] = {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"observation_height": self.observation_height,
|
||||
"observation_width": self.observation_width,
|
||||
"visualization_height": self.visualization_height,
|
||||
"visualization_width": self.visualization_width,
|
||||
}
|
||||
if self.split is not None:
|
||||
kwargs["split"] = self.split
|
||||
return kwargs
|
||||
|
||||
def create_envs(self, n_envs: int, use_async_envs: bool = False):
|
||||
from .robocasa import create_robocasa_envs
|
||||
|
||||
if self.task is None:
|
||||
raise ValueError("RoboCasaEnv requires a task to be specified")
|
||||
env_cls = _make_vec_env_cls(use_async_envs, n_envs)
|
||||
return create_robocasa_envs(
|
||||
task=self.task,
|
||||
n_envs=n_envs,
|
||||
camera_name=self.camera_name,
|
||||
gym_kwargs=self.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
episode_length=self.episode_length,
|
||||
obj_registries=tuple(self.obj_registries),
|
||||
)
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("vlabench")
|
||||
@dataclass
|
||||
class VLABenchEnv(EnvConfig):
|
||||
task: str = "select_fruit"
|
||||
fps: int = 10
|
||||
episode_length: int = 500
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
render_resolution: tuple[int, int] = (480, 480)
|
||||
robot: str = "franka"
|
||||
action_mode: str = "eef"
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"pixels/image": f"{OBS_IMAGES}.image",
|
||||
"pixels/second_image": f"{OBS_IMAGES}.second_image",
|
||||
"pixels/wrist_image": f"{OBS_IMAGES}.wrist_image",
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
h, w = self.render_resolution
|
||||
if self.obs_type == "pixels":
|
||||
self.features["pixels/image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(h, w, 3))
|
||||
self.features["pixels/second_image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(h, w, 3))
|
||||
self.features["pixels/wrist_image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(h, w, 3))
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.features["pixels/image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(h, w, 3))
|
||||
self.features["pixels/second_image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(h, w, 3))
|
||||
self.features["pixels/wrist_image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(h, w, 3))
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(7,))
|
||||
else:
|
||||
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"render_resolution": self.render_resolution,
|
||||
"robot": self.robot,
|
||||
"max_episode_steps": self.episode_length,
|
||||
"action_mode": self.action_mode,
|
||||
}
|
||||
|
||||
def create_envs(self, n_envs: int, use_async_envs: bool = False):
|
||||
from .vlabench import create_vlabench_envs
|
||||
|
||||
if self.task is None:
|
||||
raise ValueError("VLABenchEnv requires a task to be specified")
|
||||
env_cls = _make_vec_env_cls(use_async_envs, n_envs)
|
||||
return create_vlabench_envs(
|
||||
task=self.task,
|
||||
n_envs=n_envs,
|
||||
gym_kwargs=self.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
)
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("isaaclab_arena")
|
||||
@dataclass
|
||||
class IsaaclabArenaEnv(HubEnvConfig):
|
||||
@@ -574,3 +716,171 @@ class IsaaclabArenaEnv(HubEnvConfig):
|
||||
),
|
||||
PolicyProcessorPipeline(steps=[]),
|
||||
)
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("libero_plus")
|
||||
@dataclass
|
||||
class LiberoPlusEnv(LiberoEnv):
|
||||
"""Config for LIBERO-plus robustness benchmark evaluation.
|
||||
|
||||
LIBERO-plus extends LIBERO with 7 perturbation dimensions (camera viewpoints,
|
||||
object layouts, robot initial states, language instructions, lighting, background
|
||||
textures, sensor noise) producing ~10k task variants.
|
||||
|
||||
The gym interface is identical to LIBERO so this class reuses ``LiberoEnv``
|
||||
entirely — only the registered name and default task suite differ.
|
||||
|
||||
Install: see docker/Dockerfile.benchmark.libero_plus — LIBERO-plus ships
|
||||
as a namespace package from a git fork and must be cloned + PYTHONPATH'd
|
||||
rather than installed as a pyproject extra.
|
||||
|
||||
See Also:
|
||||
https://github.com/sylvestf/LIBERO-plus
|
||||
"""
|
||||
|
||||
task: str = "libero_spatial"
|
||||
is_libero_plus: bool = True
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("robotwin")
|
||||
@dataclass
|
||||
class RoboTwinEnvConfig(EnvConfig):
|
||||
"""Configuration for RoboTwin 2.0 benchmark environments.
|
||||
|
||||
RoboTwin 2.0 is a dual-arm manipulation benchmark with 50 tasks built on the
|
||||
SAPIEN simulator. The robot is an Aloha-AgileX bimanual platform with 14 DOF
|
||||
(7 per arm). All three cameras are enabled by default.
|
||||
|
||||
See: https://robotwin-platform.github.io
|
||||
Dataset: https://huggingface.co/datasets/lerobot/robotwin_unified
|
||||
"""
|
||||
|
||||
task: str = "beat_block_hammer" # single task or comma-separated list
|
||||
fps: int = 25
|
||||
episode_length: int = 300
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
# Available cameras from RoboTwin's aloha-agilex embodiment: head_camera
|
||||
# (torso-mounted) + left_camera / right_camera (wrists).
|
||||
camera_names: str = "head_camera,left_camera,right_camera"
|
||||
# Match the D435 dims in task_config/demo_clean.yml (_camera_config.yml).
|
||||
# Gym's vector-env concatenate pre-allocates buffers of this shape, so it
|
||||
# must equal what SAPIEN actually renders.
|
||||
observation_height: int = 240
|
||||
observation_width: int = 320
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: ACTION,
|
||||
"pixels/head_camera": f"{OBS_IMAGES}.head_camera",
|
||||
"pixels/left_camera": f"{OBS_IMAGES}.left_camera",
|
||||
"pixels/right_camera": f"{OBS_IMAGES}.right_camera",
|
||||
"agent_pos": OBS_STATE,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
cam_list = [c.strip() for c in self.camera_names.split(",") if c.strip()]
|
||||
for cam in cam_list:
|
||||
self.features[f"pixels/{cam}"] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(self.observation_height, self.observation_width, 3),
|
||||
)
|
||||
# Keep features_map entry if already set (default_factory); add if missing.
|
||||
key = f"pixels/{cam}"
|
||||
if key not in self.features_map:
|
||||
self.features_map[key] = f"{OBS_IMAGES}.{cam}"
|
||||
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(14,), # 14 DOF: 7 per arm
|
||||
)
|
||||
elif self.obs_type != "pixels":
|
||||
raise ValueError(
|
||||
f"Unsupported obs_type '{self.obs_type}'. "
|
||||
"RoboTwinEnvConfig supports 'pixels' and 'pixels_agent_pos'."
|
||||
)
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
def create_envs(self, n_envs: int, use_async_envs: bool = True):
|
||||
from lerobot.envs.robotwin import create_robotwin_envs
|
||||
|
||||
if not self.task:
|
||||
raise ValueError("RoboTwinEnvConfig requires `task` to be specified.")
|
||||
|
||||
env_cls = _make_vec_env_cls(use_async_envs, n_envs)
|
||||
cam_list = [c.strip() for c in self.camera_names.split(",") if c.strip()]
|
||||
return create_robotwin_envs(
|
||||
task=self.task,
|
||||
n_envs=n_envs,
|
||||
env_cls=env_cls,
|
||||
camera_names=cam_list,
|
||||
observation_height=self.observation_height,
|
||||
observation_width=self.observation_width,
|
||||
episode_length=self.episode_length,
|
||||
)
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("robomme")
|
||||
@dataclass
|
||||
class RoboMMEEnv(EnvConfig):
|
||||
"""RoboMME memory-augmented manipulation benchmark (ManiSkill/SAPIEN).
|
||||
|
||||
16 tasks across 4 suites: Counting, Permanence, Reference, Imitation.
|
||||
Dataset: lerobot/robomme (LeRobot v3.0, 1,600 episodes).
|
||||
Benchmark: https://github.com/RoboMME/robomme_benchmark
|
||||
|
||||
Requires the `robomme` git package installed separately (Linux only);
|
||||
see docker/Dockerfile.benchmark.robomme for the canonical install.
|
||||
"""
|
||||
|
||||
task: str = "PickXtimes"
|
||||
fps: int = 10
|
||||
episode_length: int = 300
|
||||
action_space: str = "joint_angle" # or "ee_pose" (7-D)
|
||||
dataset_split: str = "test" # "train" | "val" | "test"
|
||||
task_ids: list[int] | None = None
|
||||
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: ACTION,
|
||||
"pixels/image": f"{OBS_IMAGES}.image",
|
||||
"pixels/wrist_image": f"{OBS_IMAGES}.wrist_image",
|
||||
"agent_pos": OBS_STATE,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
action_dim = 8 if self.action_space == "joint_angle" else 7
|
||||
self.features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
|
||||
"pixels/image": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
|
||||
"pixels/wrist_image": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
|
||||
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(8,)),
|
||||
}
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
def create_envs(self, n_envs: int, use_async_envs: bool = True):
|
||||
from lerobot.envs.robomme import create_robomme_envs
|
||||
|
||||
env_cls = _make_vec_env_cls(use_async_envs, n_envs)
|
||||
return create_robomme_envs(
|
||||
task=self.task,
|
||||
n_envs=n_envs,
|
||||
action_space_type=self.action_space,
|
||||
dataset=self.dataset_split,
|
||||
episode_length=self.episode_length,
|
||||
task_ids=self.task_ids,
|
||||
env_cls=env_cls,
|
||||
)
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
from functools import partial
|
||||
@@ -31,20 +32,7 @@ from libero.libero.envs import OffScreenRenderEnv
|
||||
|
||||
from lerobot.types import RobotObservation
|
||||
|
||||
from .utils import _LazyAsyncVectorEnv
|
||||
|
||||
|
||||
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
||||
"""Normalize camera_name into a non-empty list of strings."""
|
||||
if isinstance(camera_name, str):
|
||||
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
|
||||
elif isinstance(camera_name, (list | tuple)):
|
||||
cams = [str(c).strip() for c in camera_name if str(c).strip()]
|
||||
else:
|
||||
raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
|
||||
if not cams:
|
||||
raise ValueError("camera_name resolved to an empty list.")
|
||||
return cams
|
||||
from .utils import _LazyAsyncVectorEnv, parse_camera_names
|
||||
|
||||
|
||||
def _get_suite(name: str) -> benchmark.Benchmark:
|
||||
@@ -69,14 +57,34 @@ def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[i
|
||||
return ids
|
||||
|
||||
|
||||
def get_task_init_states(task_suite: Any, i: int) -> np.ndarray:
|
||||
init_states_path = (
|
||||
Path(get_libero_path("init_states"))
|
||||
/ task_suite.tasks[i].problem_folder
|
||||
/ task_suite.tasks[i].init_states_file
|
||||
)
|
||||
init_states = torch.load(init_states_path, weights_only=False) # nosec B614
|
||||
return init_states
|
||||
# LIBERO-plus perturbation variants encode the perturbation in the filename
|
||||
# but on disk only the base `.pruned_init` exists — strip the suffix to match
|
||||
# LIBERO-plus's own suite.get_task_init_states() (we reimplement it here so we
|
||||
# can pass weights_only=False for PyTorch 2.6+ numpy pickles).
|
||||
_LIBERO_PERTURBATION_SUFFIX_RE = re.compile(r"_(?:language|view|light)_[^.]*|_(?:table|tb)_\d+")
|
||||
|
||||
|
||||
def get_task_init_states(task_suite: Any, i: int, is_libero_plus: bool = False) -> np.ndarray:
|
||||
task = task_suite.tasks[i]
|
||||
filename = Path(task.init_states_file)
|
||||
root = Path(get_libero_path("init_states"))
|
||||
|
||||
if not is_libero_plus:
|
||||
init_states_path = root / task.problem_folder / filename.name
|
||||
return torch.load(init_states_path, weights_only=False) # nosec B614
|
||||
|
||||
# LIBERO-plus: `_add_` / `_level` variants store extra-object layouts under
|
||||
# libero_newobj/ as a flat array that must be reshaped to (1, -1).
|
||||
if "_add_" in filename.name or "_level" in filename.name:
|
||||
init_states_path = root / "libero_newobj" / task.problem_folder / filename.name
|
||||
init_states = torch.load(init_states_path, weights_only=False) # nosec B614
|
||||
return init_states.reshape(1, -1)
|
||||
|
||||
# LIBERO-plus perturbation variants encode the perturbation in the filename
|
||||
# but on disk only the base `.pruned_init` exists — strip the suffix to match.
|
||||
stripped = _LIBERO_PERTURBATION_SUFFIX_RE.sub("", filename.stem) + filename.suffix
|
||||
init_states_path = root / task.problem_folder / stripped
|
||||
return torch.load(init_states_path, weights_only=False) # nosec B614
|
||||
|
||||
|
||||
def get_libero_dummy_action():
|
||||
@@ -118,9 +126,11 @@ class LiberoEnv(gym.Env):
|
||||
camera_name_mapping: dict[str, str] | None = None,
|
||||
num_steps_wait: int = 10,
|
||||
control_mode: str = "relative",
|
||||
is_libero_plus: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.task_id = task_id
|
||||
self.is_libero_plus = is_libero_plus
|
||||
self.obs_type = obs_type
|
||||
self.render_mode = render_mode
|
||||
self.observation_width = observation_width
|
||||
@@ -128,7 +138,7 @@ class LiberoEnv(gym.Env):
|
||||
self.visualization_width = visualization_width
|
||||
self.visualization_height = visualization_height
|
||||
self.init_states = init_states
|
||||
self.camera_name = _parse_camera_names(
|
||||
self.camera_name = parse_camera_names(
|
||||
camera_name
|
||||
) # agentview_image (main) or robot0_eye_in_hand_image (wrist)
|
||||
|
||||
@@ -147,7 +157,11 @@ class LiberoEnv(gym.Env):
|
||||
self.episode_index = episode_index
|
||||
self.episode_length = episode_length
|
||||
# Load once and keep
|
||||
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
|
||||
self._init_states = (
|
||||
get_task_init_states(task_suite, self.task_id, is_libero_plus=self.is_libero_plus)
|
||||
if self.init_states
|
||||
else None
|
||||
)
|
||||
self._reset_stride = n_envs # when performing a reset, append `_reset_stride` to `init_state_id`.
|
||||
|
||||
self.init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
||||
@@ -380,6 +394,7 @@ def _make_env_fns(
|
||||
gym_kwargs: Mapping[str, Any],
|
||||
control_mode: str,
|
||||
camera_name_mapping: dict[str, str] | None = None,
|
||||
is_libero_plus: bool = False,
|
||||
) -> list[Callable[[], LiberoEnv]]:
|
||||
"""Build n_envs factory callables for a single (suite, task_id)."""
|
||||
|
||||
@@ -396,6 +411,7 @@ def _make_env_fns(
|
||||
n_envs=n_envs,
|
||||
control_mode=control_mode,
|
||||
camera_name_mapping=camera_name_mapping,
|
||||
is_libero_plus=is_libero_plus,
|
||||
**local_kwargs,
|
||||
)
|
||||
|
||||
@@ -418,6 +434,7 @@ def create_libero_envs(
|
||||
control_mode: str = "relative",
|
||||
episode_length: int | None = None,
|
||||
camera_name_mapping: dict[str, str] | None = None,
|
||||
is_libero_plus: bool = False,
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""
|
||||
Create vectorized LIBERO environments with a consistent return shape.
|
||||
@@ -437,7 +454,7 @@ def create_libero_envs(
|
||||
gym_kwargs = dict(gym_kwargs or {})
|
||||
task_ids_filter = gym_kwargs.pop("task_ids", None) # optional: limit to specific tasks
|
||||
|
||||
camera_names = _parse_camera_names(camera_name)
|
||||
camera_names = parse_camera_names(camera_name)
|
||||
suite_names = [s.strip() for s in str(task).split(",") if s.strip()]
|
||||
if not suite_names:
|
||||
raise ValueError("`task` must contain at least one LIBERO suite name.")
|
||||
@@ -462,6 +479,7 @@ def create_libero_envs(
|
||||
# Probe once and reuse to avoid creating a temp env per task.
|
||||
cached_obs_space: spaces.Space | None = None
|
||||
cached_act_space: spaces.Space | None = None
|
||||
cached_metadata: dict[str, Any] | None = None
|
||||
|
||||
for tid in selected:
|
||||
fns = _make_env_fns(
|
||||
@@ -475,12 +493,14 @@ def create_libero_envs(
|
||||
gym_kwargs=gym_kwargs,
|
||||
control_mode=control_mode,
|
||||
camera_name_mapping=camera_name_mapping,
|
||||
is_libero_plus=is_libero_plus,
|
||||
)
|
||||
if is_async:
|
||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space)
|
||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
|
||||
if cached_obs_space is None:
|
||||
cached_obs_space = lazy.observation_space
|
||||
cached_act_space = lazy.action_space
|
||||
cached_metadata = lazy.metadata
|
||||
out[suite_name][tid] = lazy
|
||||
else:
|
||||
out[suite_name][tid] = env_cls(fns)
|
||||
|
||||
@@ -311,6 +311,7 @@ def create_metaworld_envs(
|
||||
is_async = env_cls is gym.vector.AsyncVectorEnv
|
||||
cached_obs_space = None
|
||||
cached_act_space = None
|
||||
cached_metadata = None
|
||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||
|
||||
for group in task_groups:
|
||||
@@ -324,10 +325,11 @@ def create_metaworld_envs(
|
||||
fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)]
|
||||
|
||||
if is_async:
|
||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space)
|
||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
|
||||
if cached_obs_space is None:
|
||||
cached_obs_space = lazy.observation_space
|
||||
cached_act_space = lazy.action_space
|
||||
cached_metadata = lazy.metadata
|
||||
out[group][tid] = lazy
|
||||
else:
|
||||
out[group][tid] = env_cls(fns)
|
||||
|
||||
425
src/lerobot/envs/robocasa.py
Normal file
425
src/lerobot/envs/robocasa.py
Normal file
@@ -0,0 +1,425 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
from lerobot.types import RobotObservation
|
||||
|
||||
from .utils import _LazyAsyncVectorEnv, parse_camera_names
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Dimensions for the flat action/state vectors used by the LeRobot wrapper.
|
||||
# These correspond to the PandaOmron robot in RoboCasa365.
|
||||
OBS_STATE_DIM = 16 # base_pos(3) + base_quat(4) + ee_pos_rel(3) + ee_quat_rel(4) + gripper_qpos(2)
|
||||
ACTION_DIM = 12 # base_motion(4) + control_mode(1) + ee_pos(3) + ee_rot(3) + gripper(1)
|
||||
ACTION_LOW = -1.0
|
||||
ACTION_HIGH = 1.0
|
||||
|
||||
# Default PandaOmron cameras. We surface these raw names directly as
|
||||
# `observation.images.<name>` so the LeRobot dataset/policy keys match
|
||||
# RoboCasa's native convention (no implicit renaming).
|
||||
DEFAULT_CAMERAS = [
|
||||
"robot0_agentview_left",
|
||||
"robot0_eye_in_hand",
|
||||
"robot0_agentview_right",
|
||||
]
|
||||
|
||||
# Object-mesh registries to sample from. RoboCasa's upstream default is
|
||||
# ("objaverse", "lightwheel"), but the objaverse pack is huge (~30GB) and
|
||||
# most users — including our CI image — only download the lightwheel pack
|
||||
# (`--type objs_lw` in `download_kitchen_assets`). When a sampled object
|
||||
# category has zero candidates in every registry, robocasa crashes with
|
||||
# `ValueError: Probabilities contain NaN` (0/0 divide in the probability
|
||||
# normalization). Restricting to registries that are actually on disk
|
||||
# avoids the NaN and matches what the asset download provides.
|
||||
DEFAULT_OBJ_REGISTRIES: tuple[str, ...] = ("lightwheel",)
|
||||
|
||||
# Task-group shortcuts accepted as `--env.task`. When the user passes one of
|
||||
# these names, we expand it to the upstream RoboCasa task list and auto-set
|
||||
# the dataset split. Individual task names (optionally comma-separated) still
|
||||
# take precedence; this only triggers on an exact group-name match.
|
||||
_TASK_GROUP_SPLITS = {
|
||||
"atomic_seen": "target",
|
||||
"composite_seen": "target",
|
||||
"composite_unseen": "target",
|
||||
"pretrain50": "pretrain",
|
||||
"pretrain100": "pretrain",
|
||||
"pretrain200": "pretrain",
|
||||
"pretrain300": "pretrain",
|
||||
}
|
||||
|
||||
|
||||
def _resolve_tasks(task: str) -> tuple[list[str], str | None]:
|
||||
"""Resolve a `--env.task` value to (task_names, split_override).
|
||||
|
||||
If `task` is a known task-group name (e.g. `atomic_seen`, `pretrain100`),
|
||||
expand it via `robocasa.utils.dataset_registry.{TARGET,PRETRAINING}_TASKS`
|
||||
and return the matching split. Otherwise treat `task` as a single task or
|
||||
comma-separated list and leave the split untouched (None).
|
||||
"""
|
||||
key = task.strip()
|
||||
if key in _TASK_GROUP_SPLITS:
|
||||
from robocasa.utils.dataset_registry import PRETRAINING_TASKS, TARGET_TASKS
|
||||
|
||||
combined = {**TARGET_TASKS, **PRETRAINING_TASKS}
|
||||
if key not in combined:
|
||||
raise ValueError(
|
||||
f"Task group '{key}' is not available in this version of robocasa. "
|
||||
f"Known groups: {sorted(combined.keys())}."
|
||||
)
|
||||
return list(combined[key]), _TASK_GROUP_SPLITS[key]
|
||||
|
||||
names = [t.strip() for t in task.split(",") if t.strip()]
|
||||
if not names:
|
||||
raise ValueError("`task` must contain at least one RoboCasa task name.")
|
||||
return names, None
|
||||
|
||||
|
||||
def convert_action(flat_action: np.ndarray) -> dict[str, Any]:
|
||||
"""Split a flat (12,) action vector into a RoboCasa action dict.
|
||||
|
||||
Layout: base_motion(4) + control_mode(1) + ee_pos(3) + ee_rot(3) + gripper(1)
|
||||
"""
|
||||
return {
|
||||
"action.base_motion": flat_action[0:4],
|
||||
"action.control_mode": flat_action[4:5],
|
||||
"action.end_effector_position": flat_action[5:8],
|
||||
"action.end_effector_rotation": flat_action[8:11],
|
||||
"action.gripper_close": flat_action[11:12],
|
||||
}
|
||||
|
||||
|
||||
class RoboCasaEnv(gym.Env):
|
||||
"""LeRobot gym.Env wrapper for RoboCasa365 kitchen environments.
|
||||
|
||||
Wraps RoboCasaGymEnv from the robocasa package and converts its
|
||||
dict-based observations and actions into the flat arrays LeRobot expects.
|
||||
Raw RoboCasa camera names are preserved verbatim under `pixels/<cam>`.
|
||||
"""
|
||||
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 20}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task: str,
|
||||
camera_name: str | Sequence[str] = ",".join(DEFAULT_CAMERAS),
|
||||
obs_type: str = "pixels_agent_pos",
|
||||
render_mode: str = "rgb_array",
|
||||
observation_width: int = 256,
|
||||
observation_height: int = 256,
|
||||
visualization_width: int = 512,
|
||||
visualization_height: int = 512,
|
||||
split: str | None = None,
|
||||
episode_length: int | None = None,
|
||||
obj_registries: Sequence[str] = DEFAULT_OBJ_REGISTRIES,
|
||||
episode_index: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.task = task
|
||||
self.obs_type = obs_type
|
||||
self.render_mode = render_mode
|
||||
self.observation_width = observation_width
|
||||
self.observation_height = observation_height
|
||||
self.visualization_width = visualization_width
|
||||
self.visualization_height = visualization_height
|
||||
self.split = split
|
||||
self.obj_registries = tuple(obj_registries)
|
||||
# Per-worker index (0..n_envs-1) used to spread the user-provided
|
||||
# seed across factories so each sub-env explores a distinct layout
|
||||
# even when the same seed is passed to `reset()`.
|
||||
self.episode_index = int(episode_index)
|
||||
|
||||
self.camera_name = parse_camera_names(camera_name)
|
||||
|
||||
self._max_episode_steps = episode_length if episode_length is not None else 1000
|
||||
|
||||
# Deferred — created on first reset() inside the worker subprocess
|
||||
# to avoid inheriting stale GPU/EGL contexts across fork().
|
||||
self._env: Any = None
|
||||
self.task_description = ""
|
||||
|
||||
images = {
|
||||
cam: spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(self.observation_height, self.observation_width, 3),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
for cam in self.camera_name
|
||||
}
|
||||
|
||||
if self.obs_type == "pixels":
|
||||
self.observation_space = spaces.Dict({"pixels": spaces.Dict(images)})
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Dict(images),
|
||||
"agent_pos": spaces.Box(
|
||||
low=-np.inf,
|
||||
high=np.inf,
|
||||
shape=(OBS_STATE_DIM,),
|
||||
dtype=np.float32,
|
||||
),
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported obs_type '{self.obs_type}'. Use 'pixels' or 'pixels_agent_pos'.")
|
||||
|
||||
self.action_space = spaces.Box(
|
||||
low=ACTION_LOW,
|
||||
high=ACTION_HIGH,
|
||||
shape=(ACTION_DIM,),
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
def _ensure_env(self) -> None:
|
||||
"""Create the underlying RoboCasaGymEnv on first use.
|
||||
|
||||
Called inside the worker subprocess after fork(), so each worker gets
|
||||
its own clean rendering context rather than inheriting a stale one from
|
||||
the parent process (which causes crashes with AsyncVectorEnv).
|
||||
"""
|
||||
if self._env is not None:
|
||||
return
|
||||
from robocasa.wrappers.gym_wrapper import RoboCasaGymEnv
|
||||
|
||||
# RoboCasaGymEnv defaults split="test", which create_env rejects
|
||||
# (only None/"all"/"pretrain"/"target" are valid). Always pass a
|
||||
# valid value so we don't hit that default. Extra kwargs are
|
||||
# forwarded to the underlying kitchen env via create_env/robosuite.make.
|
||||
self._env = RoboCasaGymEnv(
|
||||
env_name=self.task,
|
||||
camera_widths=self.observation_width,
|
||||
camera_heights=self.observation_height,
|
||||
split=self.split if self.split is not None else "all",
|
||||
obj_registries=self.obj_registries,
|
||||
)
|
||||
|
||||
ep_meta = self._env.env.get_ep_meta()
|
||||
self.task_description = ep_meta.get("lang", self.task)
|
||||
|
||||
def _format_raw_obs(self, raw_obs: dict) -> RobotObservation:
|
||||
"""Convert RoboCasaGymEnv observation dict to LeRobot format."""
|
||||
# RoboCasaGymEnv emits camera frames under "video.<cam>".
|
||||
images = {cam: raw_obs[f"video.{cam}"] for cam in self.camera_name if f"video.{cam}" in raw_obs}
|
||||
|
||||
if self.obs_type == "pixels":
|
||||
return {"pixels": images}
|
||||
|
||||
# `state.*` keys come from PandaOmronKeyConverter inside the wrapper.
|
||||
agent_pos = np.concatenate(
|
||||
[
|
||||
raw_obs.get("state.base_position", np.zeros(3)),
|
||||
raw_obs.get("state.base_rotation", np.zeros(4)),
|
||||
raw_obs.get("state.end_effector_position_relative", np.zeros(3)),
|
||||
raw_obs.get("state.end_effector_rotation_relative", np.zeros(4)),
|
||||
raw_obs.get("state.gripper_qpos", np.zeros(2)),
|
||||
],
|
||||
axis=-1,
|
||||
).astype(np.float32)
|
||||
|
||||
return {"pixels": images, "agent_pos": agent_pos}
|
||||
|
||||
def render(self) -> np.ndarray:
|
||||
self._ensure_env()
|
||||
assert self._env is not None
|
||||
return self._env.render()
|
||||
|
||||
def reset(self, seed=None, **kwargs):
|
||||
self._ensure_env()
|
||||
assert self._env is not None
|
||||
super().reset(seed=seed)
|
||||
# Spread the seed across workers so n_envs factories don't all
|
||||
# roll the same scene. With an explicit user seed we shift it by
|
||||
# episode_index; with no seed we fall back to episode_index so
|
||||
# each worker is still distinct rather than inheriting the same
|
||||
# global RNG state.
|
||||
worker_seed = seed + self.episode_index if seed is not None else self.episode_index
|
||||
raw_obs, info = self._env.reset(seed=worker_seed)
|
||||
|
||||
ep_meta = self._env.env.get_ep_meta()
|
||||
self.task_description = ep_meta.get("lang", self.task)
|
||||
|
||||
observation = self._format_raw_obs(raw_obs)
|
||||
info = {"is_success": False}
|
||||
return observation, info
|
||||
|
||||
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
|
||||
self._ensure_env()
|
||||
assert self._env is not None
|
||||
if action.ndim != 1:
|
||||
raise ValueError(
|
||||
f"Expected action to be 1-D (shape (action_dim,)), "
|
||||
f"but got shape {action.shape} with ndim={action.ndim}"
|
||||
)
|
||||
|
||||
action_dict = convert_action(action)
|
||||
raw_obs, reward, done, truncated, info = self._env.step(action_dict)
|
||||
|
||||
is_success = bool(info.get("success", False))
|
||||
terminated = done or is_success
|
||||
info.update({"task": self.task, "done": done, "is_success": is_success})
|
||||
|
||||
observation = self._format_raw_obs(raw_obs)
|
||||
if terminated:
|
||||
info["final_info"] = {
|
||||
"task": self.task,
|
||||
"done": bool(done),
|
||||
"is_success": bool(is_success),
|
||||
}
|
||||
self.reset()
|
||||
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
def close(self):
|
||||
if self._env is not None:
|
||||
self._env.close()
|
||||
|
||||
|
||||
def _make_env_fns(
|
||||
*,
|
||||
task: str,
|
||||
n_envs: int,
|
||||
camera_names: list[str],
|
||||
obs_type: str,
|
||||
render_mode: str,
|
||||
observation_width: int,
|
||||
observation_height: int,
|
||||
visualization_width: int,
|
||||
visualization_height: int,
|
||||
split: str | None,
|
||||
episode_length: int | None,
|
||||
obj_registries: Sequence[str],
|
||||
) -> list[Callable[[], RoboCasaEnv]]:
|
||||
"""Build n_envs factory callables for a single task.
|
||||
|
||||
Each factory carries a distinct ``episode_index`` (``0..n_envs-1``) so
|
||||
``RoboCasaEnv.reset()`` can derive a per-worker seed series from the
|
||||
user-provided seed.
|
||||
"""
|
||||
|
||||
def _make_env(episode_index: int) -> RoboCasaEnv:
|
||||
return RoboCasaEnv(
|
||||
task=task,
|
||||
camera_name=camera_names,
|
||||
obs_type=obs_type,
|
||||
render_mode=render_mode,
|
||||
observation_width=observation_width,
|
||||
observation_height=observation_height,
|
||||
visualization_width=visualization_width,
|
||||
visualization_height=visualization_height,
|
||||
split=split,
|
||||
episode_length=episode_length,
|
||||
obj_registries=obj_registries,
|
||||
episode_index=episode_index,
|
||||
)
|
||||
|
||||
return [partial(_make_env, i) for i in range(n_envs)]
|
||||
|
||||
|
||||
def create_robocasa_envs(
|
||||
task: str,
|
||||
n_envs: int,
|
||||
gym_kwargs: dict[str, Any] | None = None,
|
||||
camera_name: str | Sequence[str] = ",".join(DEFAULT_CAMERAS),
|
||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||
episode_length: int | None = None,
|
||||
obj_registries: Sequence[str] = DEFAULT_OBJ_REGISTRIES,
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""Create vectorized RoboCasa365 environments with a consistent return shape.
|
||||
|
||||
Returns:
|
||||
dict[task_name][task_id] -> vec_env (env_cls([...]) with exactly n_envs factories)
|
||||
|
||||
`task` can be:
|
||||
- a single task name (e.g. `CloseFridge`)
|
||||
- a comma-separated list of task names (e.g. `CloseFridge,PickPlaceCoffee`)
|
||||
- a benchmark-group shortcut (`atomic_seen`, `composite_seen`,
|
||||
`composite_unseen`, `pretrain50`, `pretrain100`, `pretrain200`,
|
||||
`pretrain300`), which auto-expands to the upstream task list and
|
||||
auto-sets the dataset `split` ("target" or "pretrain").
|
||||
"""
|
||||
if env_cls is None or not callable(env_cls):
|
||||
raise ValueError("env_cls must be a callable that wraps a list of environment factory callables.")
|
||||
if not isinstance(n_envs, int) or n_envs <= 0:
|
||||
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
|
||||
|
||||
gym_kwargs = dict(gym_kwargs or {})
|
||||
obs_type = gym_kwargs.pop("obs_type", "pixels_agent_pos")
|
||||
render_mode = gym_kwargs.pop("render_mode", "rgb_array")
|
||||
observation_width = gym_kwargs.pop("observation_width", 256)
|
||||
observation_height = gym_kwargs.pop("observation_height", 256)
|
||||
visualization_width = gym_kwargs.pop("visualization_width", 512)
|
||||
visualization_height = gym_kwargs.pop("visualization_height", 512)
|
||||
split = gym_kwargs.pop("split", None)
|
||||
|
||||
camera_names = parse_camera_names(camera_name)
|
||||
task_names, group_split = _resolve_tasks(str(task))
|
||||
if group_split is not None and split is None:
|
||||
split = group_split
|
||||
|
||||
logger.info(
|
||||
"Creating RoboCasa envs | tasks=%s | split=%s | n_envs(per task)=%d",
|
||||
task_names,
|
||||
split,
|
||||
n_envs,
|
||||
)
|
||||
|
||||
is_async = env_cls is gym.vector.AsyncVectorEnv
|
||||
|
||||
cached_obs_space: spaces.Space | None = None
|
||||
cached_act_space: spaces.Space | None = None
|
||||
cached_metadata: dict[str, Any] | None = None
|
||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||
|
||||
for task_name in task_names:
|
||||
fns = _make_env_fns(
|
||||
task=task_name,
|
||||
n_envs=n_envs,
|
||||
camera_names=camera_names,
|
||||
obs_type=obs_type,
|
||||
render_mode=render_mode,
|
||||
observation_width=observation_width,
|
||||
observation_height=observation_height,
|
||||
visualization_width=visualization_width,
|
||||
visualization_height=visualization_height,
|
||||
split=split,
|
||||
episode_length=episode_length,
|
||||
obj_registries=obj_registries,
|
||||
)
|
||||
|
||||
if is_async:
|
||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
|
||||
if cached_obs_space is None:
|
||||
cached_obs_space = lazy.observation_space
|
||||
cached_act_space = lazy.action_space
|
||||
cached_metadata = lazy.metadata
|
||||
out[task_name][0] = lazy
|
||||
else:
|
||||
out[task_name][0] = env_cls(fns)
|
||||
logger.info("Built vec env | task=%s | n_envs=%d", task_name, n_envs)
|
||||
|
||||
return {name: dict(task_map) for name, task_map in out.items()}
|
||||
245
src/lerobot/envs/robomme.py
Normal file
245
src/lerobot/envs/robomme.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""RoboMME environment wrapper for LeRobot evaluation.
|
||||
|
||||
Wraps the RoboMME ``BenchmarkEnvBuilder`` into a Gymnasium-compatible
|
||||
``VectorEnv`` suitable for ``lerobot_eval``.
|
||||
|
||||
RoboMME tasks:
|
||||
Counting: BinFill, PickXtimes, SwingXtimes, StopCube
|
||||
Permanence: VideoUnmask, VideoUnmaskSwap, ButtonUnmask, ButtonUnmaskSwap
|
||||
Reference: PickHighlight, VideoRepick, VideoPlaceButton, VideoPlaceOrder
|
||||
Imitation: MoveCube, InsertPeg, PatternLock, RouteStick
|
||||
|
||||
Dataset: lerobot/robomme (LeRobot v3.0, 1,600 episodes)
|
||||
Install: see docker/Dockerfile.benchmark.robomme (Linux only — mani-skill vs numpy pin conflict)
|
||||
Benchmark: https://github.com/RoboMME/robomme_benchmark
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
|
||||
from .utils import _LazyAsyncVectorEnv
|
||||
|
||||
ROBOMME_TASKS = [
|
||||
"BinFill",
|
||||
"PickXtimes",
|
||||
"SwingXtimes",
|
||||
"StopCube",
|
||||
"VideoUnmask",
|
||||
"VideoUnmaskSwap",
|
||||
"ButtonUnmask",
|
||||
"ButtonUnmaskSwap",
|
||||
"PickHighlight",
|
||||
"VideoRepick",
|
||||
"VideoPlaceButton",
|
||||
"VideoPlaceOrder",
|
||||
"MoveCube",
|
||||
"InsertPeg",
|
||||
"PatternLock",
|
||||
"RouteStick",
|
||||
]
|
||||
|
||||
|
||||
class RoboMMEGymEnv(gym.Env):
|
||||
"""Thin Gymnasium wrapper around a single RoboMME episode env."""
|
||||
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 10}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task: str = "PickXtimes",
|
||||
action_space_type: str = "joint_angle",
|
||||
dataset: str = "test",
|
||||
episode_idx: int = 0,
|
||||
max_steps: int = 300,
|
||||
):
|
||||
super().__init__()
|
||||
from robomme.env_record_wrapper import BenchmarkEnvBuilder
|
||||
|
||||
self._task = task
|
||||
self._action_space_type = action_space_type
|
||||
self._dataset = dataset
|
||||
self._episode_idx = episode_idx
|
||||
self._max_steps = max_steps
|
||||
self._max_episode_steps = max_steps
|
||||
|
||||
self._builder = BenchmarkEnvBuilder(
|
||||
env_id=task,
|
||||
dataset=dataset,
|
||||
action_space=action_space_type,
|
||||
gui_render=False,
|
||||
max_steps=max_steps,
|
||||
)
|
||||
self._env = None
|
||||
self._last_raw_obs: dict | None = None
|
||||
|
||||
action_dim = 8 if action_space_type == "joint_angle" else 7
|
||||
self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(action_dim,), dtype=np.float32)
|
||||
# `pixels` must be a nested Dict so `preprocess_observation()` in
|
||||
# envs/utils.py picks it up and maps each camera to
|
||||
# `observation.images.<cam>`. A flat layout (`pixels/image`,
|
||||
# `pixels/wrist_image`) silently drops every image from the batch.
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Dict(
|
||||
{
|
||||
"image": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
|
||||
"wrist_image": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
|
||||
}
|
||||
),
|
||||
"agent_pos": spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32),
|
||||
}
|
||||
)
|
||||
|
||||
def reset(self, *, seed=None, options=None):
|
||||
super().reset(seed=seed)
|
||||
self._env = self._builder.make_env_for_episode(
|
||||
episode_idx=self._episode_idx,
|
||||
max_steps=self._max_steps,
|
||||
)
|
||||
obs, info = self._env.reset()
|
||||
self._last_raw_obs = obs
|
||||
return self._convert_obs(obs), self._convert_info(info)
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self._env.step(action)
|
||||
self._last_raw_obs = obs
|
||||
|
||||
terminated_bool = bool(terminated.item()) if hasattr(terminated, "item") else bool(terminated)
|
||||
truncated_bool = bool(truncated.item()) if hasattr(truncated, "item") else bool(truncated)
|
||||
|
||||
status = info.get("status", "ongoing")
|
||||
is_success = status == "success"
|
||||
conv_info = self._convert_info(info)
|
||||
conv_info["is_success"] = is_success
|
||||
|
||||
return self._convert_obs(obs), float(reward), terminated_bool, truncated_bool, conv_info
|
||||
|
||||
def render(self) -> np.ndarray | None:
|
||||
"""Return the front camera image from the last observation for video recording."""
|
||||
if self._last_raw_obs is None:
|
||||
return np.zeros((256, 256, 3), dtype=np.uint8)
|
||||
front = self._last_raw_obs.get("front_rgb_list")
|
||||
if front is None:
|
||||
return np.zeros((256, 256, 3), dtype=np.uint8)
|
||||
frame = front[-1] if isinstance(front, list) else front
|
||||
return np.asarray(frame, dtype=np.uint8)
|
||||
|
||||
def _convert_obs(self, obs: dict) -> dict:
|
||||
front_rgb = (
|
||||
obs["front_rgb_list"][-1] if isinstance(obs["front_rgb_list"], list) else obs["front_rgb_list"]
|
||||
)
|
||||
wrist_rgb = (
|
||||
obs["wrist_rgb_list"][-1] if isinstance(obs["wrist_rgb_list"], list) else obs["wrist_rgb_list"]
|
||||
)
|
||||
joint_state = (
|
||||
obs["joint_state_list"][-1]
|
||||
if isinstance(obs["joint_state_list"], list)
|
||||
else obs["joint_state_list"]
|
||||
)
|
||||
gripper_state = (
|
||||
obs["gripper_state_list"][-1]
|
||||
if isinstance(obs["gripper_state_list"], list)
|
||||
else obs["gripper_state_list"]
|
||||
)
|
||||
|
||||
front_rgb = np.asarray(front_rgb, dtype=np.uint8)
|
||||
wrist_rgb = np.asarray(wrist_rgb, dtype=np.uint8)
|
||||
joint = np.asarray(joint_state, dtype=np.float32).flatten()[:7]
|
||||
gripper = np.asarray(gripper_state, dtype=np.float32).flatten()[:1]
|
||||
state = np.concatenate([joint, gripper])
|
||||
|
||||
return {
|
||||
"pixels": {"image": front_rgb, "wrist_image": wrist_rgb},
|
||||
"agent_pos": state,
|
||||
}
|
||||
|
||||
def _convert_info(self, info: dict) -> dict:
|
||||
return {
|
||||
"status": info.get("status", "ongoing"),
|
||||
"task_goal": info.get("task_goal", ""),
|
||||
}
|
||||
|
||||
|
||||
def _make_env_fns(
|
||||
*,
|
||||
task: str,
|
||||
n_envs: int,
|
||||
action_space_type: str,
|
||||
dataset: str,
|
||||
episode_length: int,
|
||||
task_id: int,
|
||||
) -> list[Callable[[], RoboMMEGymEnv]]:
|
||||
"""Build n_envs factory callables for one RoboMME task id."""
|
||||
|
||||
def _make_one(episode_index: int) -> RoboMMEGymEnv:
|
||||
return RoboMMEGymEnv(
|
||||
task=task,
|
||||
action_space_type=action_space_type,
|
||||
dataset=dataset,
|
||||
episode_idx=episode_index,
|
||||
max_steps=episode_length,
|
||||
)
|
||||
|
||||
return [partial(_make_one, task_id + i) for i in range(n_envs)]
|
||||
|
||||
|
||||
def create_robomme_envs(
|
||||
task: str,
|
||||
n_envs: int = 1,
|
||||
action_space_type: str = "joint_angle",
|
||||
dataset: str = "test",
|
||||
episode_length: int = 300,
|
||||
task_ids: list[int] | None = None,
|
||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
|
||||
"""Create vectorized RoboMME environments for evaluation.
|
||||
|
||||
`task` may be a single RoboMME task name (e.g. "PickXtimes") or a
|
||||
comma-separated list (e.g. "PickXtimes,BinFill,StopCube"). Each task
|
||||
becomes its own suite in the returned mapping.
|
||||
|
||||
Returns {suite_name: {task_id: VectorEnv}} matching lerobot's expected format.
|
||||
"""
|
||||
if env_cls is None or not callable(env_cls):
|
||||
raise ValueError("env_cls must be a callable that wraps a list of env factory callables.")
|
||||
if not isinstance(n_envs, int) or n_envs <= 0:
|
||||
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
|
||||
|
||||
if task_ids is None:
|
||||
task_ids = [0]
|
||||
|
||||
task_names = [t.strip() for t in task.split(",") if t.strip()]
|
||||
is_async = env_cls is gym.vector.AsyncVectorEnv
|
||||
cached_obs_space: spaces.Space | None = None
|
||||
cached_act_space: spaces.Space | None = None
|
||||
cached_metadata: dict[str, Any] | None = None
|
||||
out: dict[str, dict[int, gym.vector.VectorEnv]] = {}
|
||||
for task_name in task_names:
|
||||
envs_by_task: dict[int, gym.vector.VectorEnv] = {}
|
||||
for task_id in task_ids:
|
||||
fns = _make_env_fns(
|
||||
task=task_name,
|
||||
n_envs=n_envs,
|
||||
action_space_type=action_space_type,
|
||||
dataset=dataset,
|
||||
episode_length=episode_length,
|
||||
task_id=task_id,
|
||||
)
|
||||
if is_async:
|
||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
|
||||
if cached_obs_space is None:
|
||||
cached_obs_space = lazy.observation_space
|
||||
cached_act_space = lazy.action_space
|
||||
cached_metadata = lazy.metadata
|
||||
envs_by_task[task_id] = lazy
|
||||
else:
|
||||
envs_by_task[task_id] = env_cls(fns)
|
||||
out[task_name] = envs_by_task
|
||||
return out
|
||||
488
src/lerobot/envs/robotwin.py
Normal file
488
src/lerobot/envs/robotwin.py
Normal file
@@ -0,0 +1,488 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from gymnasium import spaces
|
||||
|
||||
from lerobot.types import RobotObservation
|
||||
|
||||
from .utils import _LazyAsyncVectorEnv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Camera names as used by RoboTwin 2.0. The wrapper appends "_rgb" when looking
|
||||
# up keys in get_obs() output (e.g. "head_camera" → "head_camera_rgb").
|
||||
ROBOTWIN_CAMERA_NAMES: tuple[str, ...] = (
|
||||
"head_camera",
|
||||
"left_camera",
|
||||
"right_camera",
|
||||
)
|
||||
|
||||
ACTION_DIM = 14 # 7 DOF × 2 arms
|
||||
ACTION_LOW = -1.0
|
||||
ACTION_HIGH = 1.0
|
||||
DEFAULT_EPISODE_LENGTH = 300
|
||||
# D435 dims from task_config/_camera_config.yml (what demo_clean.yml selects).
|
||||
DEFAULT_CAMERA_H = 240
|
||||
DEFAULT_CAMERA_W = 320
|
||||
|
||||
# Task list from RoboTwin 2.0's `envs/` directory — mirrors upstream exactly
|
||||
# (50 tasks as of main; earlier revisions had 60 with a different split).
|
||||
# Keep this in sync with:
|
||||
# gh api /repos/RoboTwin-Platform/RoboTwin/contents/envs --paginate \
|
||||
# | jq -r '.[].name' | grep -E '\.py$' | grep -v '^_' | sed 's/\.py$//'
|
||||
ROBOTWIN_TASKS: tuple[str, ...] = (
|
||||
"adjust_bottle",
|
||||
"beat_block_hammer",
|
||||
"blocks_ranking_rgb",
|
||||
"blocks_ranking_size",
|
||||
"click_alarmclock",
|
||||
"click_bell",
|
||||
"dump_bin_bigbin",
|
||||
"grab_roller",
|
||||
"handover_block",
|
||||
"handover_mic",
|
||||
"hanging_mug",
|
||||
"lift_pot",
|
||||
"move_can_pot",
|
||||
"move_pillbottle_pad",
|
||||
"move_playingcard_away",
|
||||
"move_stapler_pad",
|
||||
"open_laptop",
|
||||
"open_microwave",
|
||||
"pick_diverse_bottles",
|
||||
"pick_dual_bottles",
|
||||
"place_a2b_left",
|
||||
"place_a2b_right",
|
||||
"place_bread_basket",
|
||||
"place_bread_skillet",
|
||||
"place_burger_fries",
|
||||
"place_can_basket",
|
||||
"place_cans_plasticbox",
|
||||
"place_container_plate",
|
||||
"place_dual_shoes",
|
||||
"place_empty_cup",
|
||||
"place_fan",
|
||||
"place_mouse_pad",
|
||||
"place_object_basket",
|
||||
"place_object_scale",
|
||||
"place_object_stand",
|
||||
"place_phone_stand",
|
||||
"place_shoe",
|
||||
"press_stapler",
|
||||
"put_bottles_dustbin",
|
||||
"put_object_cabinet",
|
||||
"rotate_qrcode",
|
||||
"scan_object",
|
||||
"shake_bottle",
|
||||
"shake_bottle_horizontally",
|
||||
"stack_blocks_three",
|
||||
"stack_blocks_two",
|
||||
"stack_bowls_three",
|
||||
"stack_bowls_two",
|
||||
"stamp_seal",
|
||||
"turn_switch",
|
||||
)
|
||||
|
||||
|
||||
_ROBOTWIN_SETUP_CACHE: dict[str, dict[str, Any]] = {}
|
||||
|
||||
|
||||
def _load_robotwin_setup_kwargs(task_name: str) -> dict[str, Any]:
|
||||
"""Build the kwargs dict RoboTwin's setup_demo expects.
|
||||
|
||||
Mirrors the config loading done by RoboTwin's ``script/eval_policy.py``:
|
||||
reads ``task_config/demo_clean.yml``, resolves the embodiment file from
|
||||
``_embodiment_config.yml``, loads the robot's own ``config.yml``, and
|
||||
reads camera dimensions from ``_camera_config.yml``.
|
||||
|
||||
Uses ``aloha-agilex`` single-robot dual-arm by default (the only embodiment
|
||||
used by beat_block_hammer and most smoke-test tasks).
|
||||
"""
|
||||
if task_name in _ROBOTWIN_SETUP_CACHE:
|
||||
return dict(_ROBOTWIN_SETUP_CACHE[task_name])
|
||||
|
||||
import os
|
||||
|
||||
import yaml # type: ignore[import-untyped]
|
||||
from envs import CONFIGS_PATH # type: ignore[import-not-found]
|
||||
|
||||
task_config = "demo_clean"
|
||||
with open(os.path.join(CONFIGS_PATH, f"{task_config}.yml"), encoding="utf-8") as f:
|
||||
args = yaml.safe_load(f)
|
||||
|
||||
# Resolve embodiment — demo_clean.yml uses [aloha-agilex] (dual-arm single robot)
|
||||
with open(os.path.join(CONFIGS_PATH, "_embodiment_config.yml"), encoding="utf-8") as f:
|
||||
embodiment_types = yaml.safe_load(f)
|
||||
embodiment = args.get("embodiment", ["aloha-agilex"])
|
||||
if len(embodiment) == 1:
|
||||
robot_file = embodiment_types[embodiment[0]]["file_path"]
|
||||
args["left_robot_file"] = robot_file
|
||||
args["right_robot_file"] = robot_file
|
||||
args["dual_arm_embodied"] = True
|
||||
elif len(embodiment) == 3:
|
||||
args["left_robot_file"] = embodiment_types[embodiment[0]]["file_path"]
|
||||
args["right_robot_file"] = embodiment_types[embodiment[1]]["file_path"]
|
||||
args["embodiment_dis"] = embodiment[2]
|
||||
args["dual_arm_embodied"] = False
|
||||
else:
|
||||
raise ValueError(f"embodiment must have 1 or 3 items, got {len(embodiment)}")
|
||||
|
||||
with open(os.path.join(args["left_robot_file"], "config.yml"), encoding="utf-8") as f:
|
||||
args["left_embodiment_config"] = yaml.safe_load(f)
|
||||
with open(os.path.join(args["right_robot_file"], "config.yml"), encoding="utf-8") as f:
|
||||
args["right_embodiment_config"] = yaml.safe_load(f)
|
||||
|
||||
# Camera dimensions
|
||||
with open(os.path.join(CONFIGS_PATH, "_camera_config.yml"), encoding="utf-8") as f:
|
||||
camera_config = yaml.safe_load(f)
|
||||
head_cam = args["camera"]["head_camera_type"]
|
||||
args["head_camera_h"] = camera_config[head_cam]["h"]
|
||||
args["head_camera_w"] = camera_config[head_cam]["w"]
|
||||
|
||||
# Headless overrides
|
||||
args["render_freq"] = 0
|
||||
args["task_name"] = task_name
|
||||
args["task_config"] = task_config
|
||||
|
||||
_ROBOTWIN_SETUP_CACHE[task_name] = args
|
||||
return dict(args)
|
||||
|
||||
|
||||
def _load_robotwin_task(task_name: str) -> type:
|
||||
"""Dynamically import and return a RoboTwin 2.0 task class.
|
||||
|
||||
RoboTwin tasks live in ``envs/<task_name>.py`` relative to the repository
|
||||
root and are expected to be on ``sys.path`` after installation.
|
||||
"""
|
||||
try:
|
||||
module = importlib.import_module(f"envs.{task_name}")
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Could not import RoboTwin task '{task_name}'. "
|
||||
"Ensure RoboTwin 2.0 is installed and its 'envs/' directory is on PYTHONPATH. "
|
||||
"See the RoboTwin installation guide: https://robotwin-platform.github.io/doc/usage/robotwin-install.html"
|
||||
) from e
|
||||
task_cls = getattr(module, task_name, None)
|
||||
if task_cls is None:
|
||||
raise AttributeError(f"Task class '{task_name}' not found in envs/{task_name}.py")
|
||||
return task_cls
|
||||
|
||||
|
||||
class RoboTwinEnv(gym.Env):
|
||||
"""Gymnasium wrapper around a single RoboTwin 2.0 task.
|
||||
|
||||
RoboTwin uses a custom SAPIEN-based API (``setup_demo`` / ``get_obs`` /
|
||||
``take_action`` / ``check_success``) rather than the standard gym interface.
|
||||
This class bridges that API to Gymnasium so that ``lerobot-eval`` can drive
|
||||
RoboTwin exactly like LIBERO or Meta-World.
|
||||
|
||||
The underlying SAPIEN environment is created lazily on the first ``reset()``
|
||||
call *inside the worker process*. This is required for
|
||||
``gym.vector.AsyncVectorEnv`` compatibility: SAPIEN allocates EGL/GPU
|
||||
contexts that must not be forked from the parent process.
|
||||
|
||||
Observations
|
||||
------------
|
||||
The ``pixels`` dict uses the raw RoboTwin camera names as keys (e.g.
|
||||
``"head_camera"``, ``"left_camera"``). ``preprocess_observation`` in
|
||||
``envs/utils.py`` then converts these to ``observation.images.<cam>``.
|
||||
|
||||
Actions
|
||||
-------
|
||||
14-dim float32 array in ``[-1, 1]`` (joint-space, 7 DOF per arm).
|
||||
|
||||
Autograd
|
||||
--------
|
||||
``setup_demo`` and ``take_action`` drive CuRobo's Newton trajectory
|
||||
optimizer, which calls ``cost.backward()`` internally. lerobot_eval wraps
|
||||
the rollout in ``torch.no_grad()``, so both call sites re-enable grad.
|
||||
"""
|
||||
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 25}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task_name: str,
|
||||
episode_index: int = 0,
|
||||
n_envs: int = 1,
|
||||
camera_names: Sequence[str] = ROBOTWIN_CAMERA_NAMES,
|
||||
observation_height: int | None = None,
|
||||
observation_width: int | None = None,
|
||||
episode_length: int = DEFAULT_EPISODE_LENGTH,
|
||||
render_mode: str = "rgb_array",
|
||||
):
|
||||
super().__init__()
|
||||
self.task_name = task_name
|
||||
self.task = task_name # used by add_envs_task() in utils.py
|
||||
self.task_description = task_name.replace("_", " ")
|
||||
self.episode_index = episode_index
|
||||
self._reset_stride = n_envs
|
||||
self.camera_names = list(camera_names)
|
||||
# Default to D435 dims (the camera type baked into task_config/demo_clean.yml).
|
||||
# The YAML-driven lookup is deferred to reset() so construction doesn't
|
||||
# import RoboTwin's `envs` module — fast-tests run without RoboTwin installed.
|
||||
self.observation_height = observation_height or DEFAULT_CAMERA_H
|
||||
self.observation_width = observation_width or DEFAULT_CAMERA_W
|
||||
self.episode_length = episode_length
|
||||
self._max_episode_steps = episode_length # lerobot_eval.rollout reads this
|
||||
self.render_mode = render_mode
|
||||
|
||||
self._env: Any | None = None # deferred — created on first reset() inside worker
|
||||
self._step_count: int = 0
|
||||
self._black_frame = np.zeros((self.observation_height, self.observation_width, 3), dtype=np.uint8)
|
||||
|
||||
image_spaces = {
|
||||
cam: spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(self.observation_height, self.observation_width, 3),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
for cam in self.camera_names
|
||||
}
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Dict(image_spaces),
|
||||
"agent_pos": spaces.Box(low=-np.inf, high=np.inf, shape=(ACTION_DIM,), dtype=np.float32),
|
||||
}
|
||||
)
|
||||
self.action_space = spaces.Box(
|
||||
low=ACTION_LOW, high=ACTION_HIGH, shape=(ACTION_DIM,), dtype=np.float32
|
||||
)
|
||||
|
||||
def _ensure_env(self) -> None:
|
||||
"""Create the SAPIEN environment on first use.
|
||||
|
||||
Called inside the worker subprocess after fork(), so each worker gets
|
||||
its own EGL/GPU context rather than inheriting a stale one from the
|
||||
parent process (which causes crashes with AsyncVectorEnv).
|
||||
"""
|
||||
if self._env is not None:
|
||||
return
|
||||
task_cls = _load_robotwin_task(self.task_name)
|
||||
self._env = task_cls()
|
||||
|
||||
def _get_obs(self) -> RobotObservation:
|
||||
assert self._env is not None, "_get_obs called before _ensure_env()"
|
||||
raw = self._env.get_obs()
|
||||
cameras_raw = raw.get("observation", {})
|
||||
|
||||
images: dict[str, np.ndarray] = {}
|
||||
for cam in self.camera_names:
|
||||
cam_data = cameras_raw.get(cam)
|
||||
img = cam_data.get("rgb") if cam_data else None
|
||||
if img is None:
|
||||
images[cam] = self._black_frame
|
||||
continue
|
||||
img = np.asarray(img, dtype=np.uint8)
|
||||
if img.ndim == 2:
|
||||
img = np.stack([img, img, img], axis=-1)
|
||||
elif img.shape[-1] != 3:
|
||||
img = img[..., :3]
|
||||
images[cam] = img
|
||||
|
||||
ja = raw.get("joint_action") or {}
|
||||
vec = ja.get("vector")
|
||||
if vec is not None:
|
||||
arr = np.asarray(vec, dtype=np.float32).ravel()
|
||||
joint_state = (
|
||||
arr[:ACTION_DIM] if arr.size >= ACTION_DIM else np.zeros(ACTION_DIM, dtype=np.float32)
|
||||
)
|
||||
else:
|
||||
joint_state = np.zeros(ACTION_DIM, dtype=np.float32)
|
||||
|
||||
return {"pixels": images, "agent_pos": joint_state}
|
||||
|
||||
def reset(self, seed: int | None = None, **kwargs) -> tuple[RobotObservation, dict]:
|
||||
self._ensure_env()
|
||||
super().reset(seed=seed)
|
||||
assert self._env is not None # set by _ensure_env() above
|
||||
|
||||
actual_seed = self.episode_index if seed is None else seed
|
||||
setup_kwargs = _load_robotwin_setup_kwargs(self.task_name)
|
||||
setup_kwargs.update(seed=actual_seed, is_test=True)
|
||||
with torch.enable_grad():
|
||||
self._env.setup_demo(**setup_kwargs)
|
||||
self.episode_index += self._reset_stride
|
||||
self._step_count = 0
|
||||
|
||||
obs = self._get_obs()
|
||||
return obs, {"is_success": False, "task": self.task_name}
|
||||
|
||||
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
|
||||
assert self._env is not None, "step() called before reset()"
|
||||
if action.ndim != 1 or action.shape[0] != ACTION_DIM:
|
||||
raise ValueError(f"Expected 1-D action of shape ({ACTION_DIM},), got {action.shape}")
|
||||
|
||||
with torch.enable_grad():
|
||||
if hasattr(self._env, "take_action"):
|
||||
self._env.take_action(action)
|
||||
else:
|
||||
self._env.step(action)
|
||||
|
||||
self._step_count += 1
|
||||
|
||||
is_success = bool(getattr(self._env, "eval_success", False))
|
||||
if not is_success and hasattr(self._env, "check_success"):
|
||||
is_success = bool(self._env.check_success())
|
||||
|
||||
obs = self._get_obs()
|
||||
reward = float(is_success)
|
||||
terminated = is_success
|
||||
truncated = self._step_count >= self.episode_length
|
||||
|
||||
info: dict[str, Any] = {
|
||||
"task": self.task_name,
|
||||
"is_success": is_success,
|
||||
"step": self._step_count,
|
||||
}
|
||||
if terminated or truncated:
|
||||
info["final_info"] = {
|
||||
"task": self.task_name,
|
||||
"is_success": is_success,
|
||||
}
|
||||
self.reset()
|
||||
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def render(self) -> np.ndarray:
|
||||
self._ensure_env()
|
||||
obs = self._get_obs()
|
||||
# Prefer head camera for rendering; fall back to first available.
|
||||
if "head_camera" in obs["pixels"]:
|
||||
return obs["pixels"]["head_camera"]
|
||||
return next(iter(obs["pixels"].values()))
|
||||
|
||||
def close(self) -> None:
|
||||
if self._env is not None:
|
||||
if hasattr(self._env, "close_env"):
|
||||
import contextlib
|
||||
|
||||
with contextlib.suppress(TypeError):
|
||||
self._env.close_env()
|
||||
self._env = None
|
||||
|
||||
|
||||
# ---- Multi-task factory --------------------------------------------------------
|
||||
|
||||
|
||||
def _make_env_fns(
|
||||
*,
|
||||
task_name: str,
|
||||
n_envs: int,
|
||||
camera_names: list[str],
|
||||
observation_height: int,
|
||||
observation_width: int,
|
||||
episode_length: int,
|
||||
) -> list[Callable[[], RoboTwinEnv]]:
|
||||
"""Return n_envs factory callables for a single task."""
|
||||
|
||||
def _make_one(episode_index: int) -> RoboTwinEnv:
|
||||
return RoboTwinEnv(
|
||||
task_name=task_name,
|
||||
episode_index=episode_index,
|
||||
n_envs=n_envs,
|
||||
camera_names=camera_names,
|
||||
observation_height=observation_height,
|
||||
observation_width=observation_width,
|
||||
episode_length=episode_length,
|
||||
)
|
||||
|
||||
return [partial(_make_one, i) for i in range(n_envs)]
|
||||
|
||||
|
||||
def create_robotwin_envs(
|
||||
task: str,
|
||||
n_envs: int,
|
||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||
camera_names: Sequence[str] = ROBOTWIN_CAMERA_NAMES,
|
||||
observation_height: int = DEFAULT_CAMERA_H,
|
||||
observation_width: int = DEFAULT_CAMERA_W,
|
||||
episode_length: int = DEFAULT_EPISODE_LENGTH,
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""Create vectorized RoboTwin 2.0 environments.
|
||||
|
||||
Returns:
|
||||
``dict[task_name][0] -> VectorEnv`` — one entry per task, each wrapping
|
||||
``n_envs`` parallel rollouts.
|
||||
|
||||
Args:
|
||||
task: Comma-separated list of task names (e.g. ``"beat_block_hammer"``
|
||||
or ``"beat_block_hammer,click_bell"``).
|
||||
n_envs: Number of parallel rollouts per task.
|
||||
env_cls: Vector env constructor (e.g. ``gym.vector.AsyncVectorEnv``).
|
||||
camera_names: Cameras to include in observations.
|
||||
observation_height: Pixel height for all cameras.
|
||||
observation_width: Pixel width for all cameras.
|
||||
episode_length: Max steps before truncation.
|
||||
"""
|
||||
if env_cls is None or not callable(env_cls):
|
||||
raise ValueError("env_cls must be callable (e.g. gym.vector.AsyncVectorEnv).")
|
||||
if not isinstance(n_envs, int) or n_envs <= 0:
|
||||
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
|
||||
|
||||
task_names = [t.strip() for t in str(task).split(",") if t.strip()]
|
||||
if not task_names:
|
||||
raise ValueError("`task` must contain at least one RoboTwin task name.")
|
||||
|
||||
unknown = [t for t in task_names if t not in ROBOTWIN_TASKS]
|
||||
if unknown:
|
||||
raise ValueError(f"Unknown RoboTwin tasks: {unknown}. Available tasks: {sorted(ROBOTWIN_TASKS)}")
|
||||
|
||||
logger.info(
|
||||
"Creating RoboTwin envs | tasks=%s | n_envs(per task)=%d",
|
||||
task_names,
|
||||
n_envs,
|
||||
)
|
||||
|
||||
is_async = env_cls is gym.vector.AsyncVectorEnv
|
||||
cached_obs_space: spaces.Space | None = None
|
||||
cached_act_space: spaces.Space | None = None
|
||||
cached_metadata: dict[str, Any] | None = None
|
||||
|
||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||
for task_name in task_names:
|
||||
fns = _make_env_fns(
|
||||
task_name=task_name,
|
||||
n_envs=n_envs,
|
||||
camera_names=list(camera_names),
|
||||
observation_height=observation_height,
|
||||
observation_width=observation_width,
|
||||
episode_length=episode_length,
|
||||
)
|
||||
if is_async:
|
||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
|
||||
if cached_obs_space is None:
|
||||
cached_obs_space = lazy.observation_space
|
||||
cached_act_space = lazy.action_space
|
||||
cached_metadata = lazy.metadata
|
||||
out[task_name][0] = lazy
|
||||
else:
|
||||
out[task_name][0] = env_cls(fns)
|
||||
logger.info("Built vec env | task=%s | n_envs=%d", task_name, n_envs)
|
||||
|
||||
return {k: dict(v) for k, v in out.items()}
|
||||
@@ -34,6 +34,25 @@ from lerobot.utils.utils import get_channel_first_image_shape
|
||||
from .configs import EnvConfig
|
||||
|
||||
|
||||
def parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
||||
"""Normalize ``camera_name`` into a non-empty list of strings.
|
||||
|
||||
Accepts a comma-separated string (``"cam_a,cam_b"``) or a sequence of
|
||||
strings (tuples/lists). Whitespace is stripped; empty entries are
|
||||
dropped. Raises ``TypeError`` for unsupported input types and
|
||||
``ValueError`` when the normalized list is empty.
|
||||
"""
|
||||
if isinstance(camera_name, str):
|
||||
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
|
||||
elif isinstance(camera_name, (list | tuple)):
|
||||
cams = [str(c).strip() for c in camera_name if str(c).strip()]
|
||||
else:
|
||||
raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
|
||||
if not cams:
|
||||
raise ValueError("camera_name resolved to an empty list.")
|
||||
return cams
|
||||
|
||||
|
||||
def _convert_nested_dict(d):
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
@@ -153,17 +172,20 @@ class _LazyAsyncVectorEnv:
|
||||
env_fns: list[Callable],
|
||||
observation_space=None,
|
||||
action_space=None,
|
||||
metadata=None,
|
||||
):
|
||||
self._env_fns = env_fns
|
||||
self._env: gym.vector.AsyncVectorEnv | None = None
|
||||
self.num_envs = len(env_fns)
|
||||
if observation_space is not None and action_space is not None:
|
||||
if observation_space is not None and action_space is not None and metadata is not None:
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self.metadata = metadata
|
||||
else:
|
||||
tmp = env_fns[0]()
|
||||
self.observation_space = tmp.observation_space
|
||||
self.action_space = tmp.action_space
|
||||
self.metadata = tmp.metadata
|
||||
tmp.close()
|
||||
self.single_observation_space = self.observation_space
|
||||
self.single_action_space = self.action_space
|
||||
@@ -172,6 +194,10 @@ class _LazyAsyncVectorEnv:
|
||||
if self._env is None:
|
||||
self._env = gym.vector.AsyncVectorEnv(self._env_fns, context="forkserver", shared_memory=True)
|
||||
|
||||
@property
|
||||
def unwrapped(self):
|
||||
return self
|
||||
|
||||
def reset(self, **kwargs):
|
||||
self._ensure()
|
||||
return self._env.reset(**kwargs)
|
||||
|
||||
589
src/lerobot/envs/vlabench.py
Normal file
589
src/lerobot/envs/vlabench.py
Normal file
@@ -0,0 +1,589 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""VLABench environment wrapper for LeRobot.
|
||||
|
||||
VLABench is a large-scale benchmark for language-conditioned robotic manipulation
|
||||
with long-horizon reasoning, built on MuJoCo/dm_control.
|
||||
|
||||
- Paper: https://arxiv.org/abs/2412.18194
|
||||
- GitHub: https://github.com/OpenMOSS/VLABench
|
||||
- Website: https://vlabench.github.io
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium import spaces
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
from lerobot.types import RobotObservation
|
||||
|
||||
from .utils import _LazyAsyncVectorEnv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ACTION_DIM = 7 # pos(3) + euler(3) + gripper(1)
|
||||
ACTION_LOW = np.array([-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 0.0], dtype=np.float32)
|
||||
ACTION_HIGH = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], dtype=np.float32)
|
||||
|
||||
# Default max episode steps per task type
|
||||
DEFAULT_MAX_EPISODE_STEPS = 500
|
||||
|
||||
# VLABench task suites
|
||||
PRIMITIVE_TASKS = [
|
||||
"select_fruit",
|
||||
"select_toy",
|
||||
"select_chemistry_tube",
|
||||
"add_condiment",
|
||||
"select_book",
|
||||
"select_painting",
|
||||
"select_drink",
|
||||
"insert_flower",
|
||||
"select_billiards",
|
||||
"select_ingredient",
|
||||
"select_mahjong",
|
||||
"select_poker",
|
||||
# Physical series
|
||||
"density_qa",
|
||||
"friction_qa",
|
||||
"magnetism_qa",
|
||||
"reflection_qa",
|
||||
"simple_cuestick_usage",
|
||||
"simple_seesaw_usage",
|
||||
"sound_speed_qa",
|
||||
"thermal_expansion_qa",
|
||||
"weight_qa",
|
||||
]
|
||||
|
||||
COMPOSITE_TASKS = [
|
||||
"cluster_billiards",
|
||||
"cluster_book",
|
||||
"cluster_drink",
|
||||
"cluster_toy",
|
||||
"cook_dishes",
|
||||
"cool_drink",
|
||||
"find_unseen_object",
|
||||
"get_coffee",
|
||||
"hammer_nail",
|
||||
"heat_food",
|
||||
"make_juice",
|
||||
"play_mahjong",
|
||||
"play_math_game",
|
||||
"play_poker",
|
||||
"play_snooker",
|
||||
"rearrange_book",
|
||||
"rearrange_chemistry_tube",
|
||||
"set_dining_table",
|
||||
"set_study_table",
|
||||
"store_food",
|
||||
"take_chemistry_experiment",
|
||||
"use_seesaw_complex",
|
||||
]
|
||||
|
||||
SUITE_TASKS: dict[str, list[str]] = {
|
||||
"primitive": PRIMITIVE_TASKS,
|
||||
"composite": COMPOSITE_TASKS,
|
||||
}
|
||||
|
||||
|
||||
class VLABenchEnv(gym.Env):
|
||||
"""Gymnasium wrapper for VLABench environments.
|
||||
|
||||
Wraps the dm_control-based VLABench simulator behind a standard gym.Env interface.
|
||||
Supports multiple cameras (front, second, wrist) and end-effector control.
|
||||
"""
|
||||
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 10}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task: str = "select_fruit",
|
||||
obs_type: str = "pixels_agent_pos",
|
||||
render_mode: str = "rgb_array",
|
||||
render_resolution: tuple[int, int] = (480, 480),
|
||||
robot: str = "franka",
|
||||
max_episode_steps: int = DEFAULT_MAX_EPISODE_STEPS,
|
||||
action_mode: str = "eef",
|
||||
):
|
||||
super().__init__()
|
||||
self.task = task
|
||||
self.obs_type = obs_type
|
||||
self.render_mode = render_mode
|
||||
self.render_resolution = render_resolution
|
||||
self.robot = robot
|
||||
self._max_episode_steps = max_episode_steps
|
||||
self.action_mode = action_mode
|
||||
|
||||
# Deferred — created on first reset() inside worker subprocess to avoid
|
||||
# inheriting stale GPU/EGL contexts when AsyncVectorEnv spawns workers.
|
||||
# We never cache `env.physics`: dm_control exposes it as a weakref
|
||||
# proxy that goes stale across resets (rebuilds the sim), so we always
|
||||
# refetch it via `self._env.physics` at the call site.
|
||||
self._env = None
|
||||
self.task_description = "" # populated on first reset
|
||||
# Cached world-frame XYZ of the robot base link. The VLABench datasets
|
||||
# log both `observation.state` positions and `actions` positions in
|
||||
# robot-base frame (see VLABench/scripts/convert_to_lerobot.py which
|
||||
# subtracts `robot_frame_pos` from ee_pos). The robot is attached at a
|
||||
# fixed offset per task so this is safe to cache once per env build.
|
||||
self._robot_base_xyz: np.ndarray | None = None
|
||||
|
||||
h, w = self.render_resolution
|
||||
|
||||
if self.obs_type == "state":
|
||||
raise NotImplementedError(
|
||||
"The 'state' observation type is not supported in VLABenchEnv. "
|
||||
"Please use 'pixels' or 'pixels_agent_pos'."
|
||||
)
|
||||
elif self.obs_type == "pixels":
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Dict(
|
||||
{
|
||||
"image": spaces.Box(low=0, high=255, shape=(h, w, 3), dtype=np.uint8),
|
||||
"second_image": spaces.Box(low=0, high=255, shape=(h, w, 3), dtype=np.uint8),
|
||||
"wrist_image": spaces.Box(low=0, high=255, shape=(h, w, 3), dtype=np.uint8),
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Dict(
|
||||
{
|
||||
"image": spaces.Box(low=0, high=255, shape=(h, w, 3), dtype=np.uint8),
|
||||
"second_image": spaces.Box(low=0, high=255, shape=(h, w, 3), dtype=np.uint8),
|
||||
"wrist_image": spaces.Box(low=0, high=255, shape=(h, w, 3), dtype=np.uint8),
|
||||
}
|
||||
),
|
||||
"agent_pos": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
|
||||
|
||||
self.action_space = spaces.Box(low=ACTION_LOW, high=ACTION_HIGH, dtype=np.float32)
|
||||
|
||||
# Max attempts to rebuild the underlying env when MuJoCo throws
|
||||
# `PhysicsError` (e.g. mjWARN_BADQACC) during VLABench's 20-step
|
||||
# reset warm-up. Some random task/layout samples land in unstable
|
||||
# initial configurations; re-sampling the layout almost always
|
||||
# gives a stable one. A handful of upstream tasks (notably
|
||||
# `select_mahjong`) have layout samplers that diverge often enough
|
||||
# to need >>5 retries, so we pick a generous ceiling.
|
||||
_ENSURE_ENV_MAX_ATTEMPTS = 20
|
||||
|
||||
def _ensure_env(self) -> None:
|
||||
"""Create the underlying VLABench env on first use.
|
||||
|
||||
Called inside the worker subprocess after fork(), so each worker gets
|
||||
its own clean rendering context rather than inheriting a stale one from
|
||||
the parent process (which causes crashes with AsyncVectorEnv).
|
||||
|
||||
Retries on `PhysicsError`: VLABench's `LM4ManipDMEnv.reset()` runs 20
|
||||
warm-up `step()` calls while toggling gravity/fluids to let the scene
|
||||
settle; for some random layouts MuJoCo's integrator diverges and
|
||||
raises `mjWARN_BADQACC`. Re-sampling the layout almost always yields
|
||||
a stable one, so we retry a number of times before giving up. Between
|
||||
attempts we reseed NumPy's global RNG from OS entropy so the upstream
|
||||
task sampler explores fresh initial states — without this, retries
|
||||
can replay the same diverging configuration when the sampler is
|
||||
deterministic given the current RNG state.
|
||||
"""
|
||||
if self._env is not None:
|
||||
return
|
||||
|
||||
import VLABench.robots # noqa: F401 # type: ignore[import-untyped]
|
||||
import VLABench.tasks # noqa: F401 # type: ignore[import-untyped]
|
||||
from dm_control.rl.control import PhysicsError # type: ignore[import-untyped]
|
||||
from VLABench.envs import load_env # type: ignore[import-untyped]
|
||||
|
||||
h, w = self.render_resolution
|
||||
last_exc: PhysicsError | None = None
|
||||
for attempt in range(1, self._ENSURE_ENV_MAX_ATTEMPTS + 1):
|
||||
try:
|
||||
env = load_env(task=self.task, robot=self.robot, render_resolution=(h, w))
|
||||
self._env = env
|
||||
break
|
||||
except PhysicsError as exc:
|
||||
last_exc = exc
|
||||
logger.warning(
|
||||
"PhysicsError on attempt %d/%d while building task '%s': %s. Retrying with fresh layout…",
|
||||
attempt,
|
||||
self._ENSURE_ENV_MAX_ATTEMPTS,
|
||||
self.task,
|
||||
exc,
|
||||
)
|
||||
np.random.seed(None)
|
||||
if self._env is None:
|
||||
assert last_exc is not None
|
||||
raise RuntimeError(
|
||||
f"VLABench task '{self.task}' failed to produce a stable "
|
||||
f"initial layout after {self._ENSURE_ENV_MAX_ATTEMPTS} "
|
||||
f"attempts. This task's upstream sampler diverges too "
|
||||
f"often for the configured robot; consider removing it "
|
||||
f"from the eval set. Last physics error: {last_exc}"
|
||||
) from last_exc
|
||||
|
||||
# Extract task description from the dm_control task
|
||||
task_obj = self._env.task
|
||||
if hasattr(task_obj, "task_description"):
|
||||
self.task_description = task_obj.task_description
|
||||
elif hasattr(task_obj, "language_instruction"):
|
||||
self.task_description = task_obj.language_instruction
|
||||
else:
|
||||
self.task_description = self.task
|
||||
|
||||
# Cache robot base world position so `_build_ctrl_from_action` and
|
||||
# `_get_obs` can translate between robot-frame (dataset) and
|
||||
# world-frame (dm_control) without hitting physics every call.
|
||||
try:
|
||||
self._robot_base_xyz = np.asarray(self._env.get_robot_frame_position(), dtype=np.float64).reshape(
|
||||
3
|
||||
)
|
||||
except Exception:
|
||||
# Fallback to VLABench's default Franka base position.
|
||||
self._robot_base_xyz = np.array([0.0, -0.4, 0.78], dtype=np.float64)
|
||||
|
||||
def _get_obs(self) -> dict:
|
||||
"""Get current observation from the environment."""
|
||||
assert self._env is not None
|
||||
|
||||
obs = self._env.get_observation()
|
||||
h, w = self.render_resolution
|
||||
|
||||
def _to_hwc3(arr: np.ndarray) -> np.ndarray:
|
||||
"""Coerce any camera array to the declared (h, w, 3) uint8 shape."""
|
||||
a = np.asarray(arr)
|
||||
# Drop a leading singleton batch dim if present.
|
||||
while a.ndim > 3 and a.shape[0] == 1:
|
||||
a = a[0]
|
||||
if a.ndim == 3 and a.shape[0] in (1, 3, 4) and a.shape[-1] not in (1, 3, 4):
|
||||
# CHW → HWC
|
||||
a = np.transpose(a, (1, 2, 0))
|
||||
if a.ndim == 2:
|
||||
a = np.stack([a] * 3, axis=-1)
|
||||
if a.ndim != 3:
|
||||
return np.zeros((h, w, 3), dtype=np.uint8)
|
||||
# Force 3 channels.
|
||||
if a.shape[-1] == 1:
|
||||
a = np.repeat(a, 3, axis=-1)
|
||||
elif a.shape[-1] == 4:
|
||||
a = a[..., :3]
|
||||
elif a.shape[-1] != 3:
|
||||
return np.zeros((h, w, 3), dtype=np.uint8)
|
||||
if a.shape[:2] != (h, w):
|
||||
a = cv2.resize(a, (w, h), interpolation=cv2.INTER_AREA)
|
||||
return a.astype(np.uint8)
|
||||
|
||||
# Extract camera images — VLABench returns (n_cameras, C, H, W) or individual arrays
|
||||
raw_frames: list[np.ndarray] = []
|
||||
if "rgb" in obs:
|
||||
rgb = obs["rgb"]
|
||||
if isinstance(rgb, np.ndarray):
|
||||
if rgb.ndim == 4:
|
||||
raw_frames = [rgb[i] for i in range(rgb.shape[0])]
|
||||
elif rgb.ndim == 3:
|
||||
raw_frames = [rgb]
|
||||
|
||||
image_keys = ["image", "second_image", "wrist_image"]
|
||||
images: dict[str, np.ndarray] = {}
|
||||
for i, key in enumerate(image_keys):
|
||||
if i < len(raw_frames):
|
||||
images[key] = _to_hwc3(raw_frames[i])
|
||||
else:
|
||||
images[key] = np.zeros((h, w, 3), dtype=np.uint8)
|
||||
|
||||
# Convert VLABench's raw ee_state `[pos_world(3), quat_wxyz(4), open(1)]`
|
||||
# to the dataset's observation.state layout `[pos_robot(3), euler_xyz(3),
|
||||
# gripper(1)]`. See VLABench/scripts/convert_to_lerobot.py — positions
|
||||
# are stored in robot-base frame and orientations as scipy extrinsic
|
||||
# 'xyz' euler angles.
|
||||
raw = np.asarray(obs.get("ee_state", np.zeros(8)), dtype=np.float64).ravel()
|
||||
pos_world = raw[:3] if raw.size >= 3 else np.zeros(3, dtype=np.float64)
|
||||
quat_wxyz = raw[3:7] if raw.size >= 7 else np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64)
|
||||
gripper = float(raw[7]) if raw.size >= 8 else 0.0
|
||||
|
||||
base = self._robot_base_xyz if self._robot_base_xyz is not None else np.zeros(3, dtype=np.float64)
|
||||
pos_robot = pos_world - base
|
||||
euler_xyz = Rotation.from_quat([quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]]).as_euler(
|
||||
"xyz", degrees=False
|
||||
)
|
||||
|
||||
ee_state = np.concatenate([pos_robot, euler_xyz, [gripper]]).astype(np.float64)
|
||||
|
||||
if self.obs_type == "pixels":
|
||||
return {"pixels": images}
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
return {
|
||||
"pixels": images,
|
||||
"agent_pos": ee_state.astype(np.float64),
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown obs_type: {self.obs_type}")
|
||||
|
||||
# ---- Action adaptation (EEF → joint ctrl) --------------------------------
|
||||
#
|
||||
# The HF vlabench datasets log 7D actions
|
||||
# `[x, y, z (robot frame), rx, ry, rz (scipy extrinsic xyz), gripper]`,
|
||||
# exactly matching VLABench's own eval pipeline (evaluator.base):
|
||||
# pos, euler, g = policy(...)
|
||||
# quat = euler_to_quaternion(*euler) # extrinsic xyz -> wxyz
|
||||
# _, qpos = robot.get_qpos_from_ee_pos(physics, pos=pos + base, quat=quat)
|
||||
# env.step(np.concatenate([qpos, [g, g]]))
|
||||
#
|
||||
# VLABench's dm_control task writes `data.ctrl[:] = action` directly — for
|
||||
# Franka that's 9 entries (7 arm joints + 2 gripper fingers). We mirror the
|
||||
# above conversion so the policy's EEF commands actually drive the robot.
|
||||
|
||||
_FRANKA_FINGER_OPEN = 0.04 # qpos when gripper fully open
|
||||
|
||||
def _build_ctrl_from_action(self, action: np.ndarray, ctrl_dim: int) -> np.ndarray:
|
||||
"""Convert a 7D EEF action into the `ctrl_dim`-sized joint command vector.
|
||||
|
||||
For the Franka default (ctrl_dim=9): 7 arm joint qposes (via IK) +
|
||||
2 gripper finger qposes (open/closed based on the gripper scalar).
|
||||
If the action is already joint-space (shape matches ctrl_dim), pass
|
||||
through.
|
||||
"""
|
||||
if action.shape[0] == ctrl_dim:
|
||||
return action.astype(np.float64, copy=False)
|
||||
|
||||
if action.shape[0] != 7:
|
||||
# Unknown layout — fall back to zero-pad so the sim doesn't crash.
|
||||
padded = np.zeros(ctrl_dim, dtype=np.float64)
|
||||
padded[: min(action.shape[0], ctrl_dim)] = action[:ctrl_dim]
|
||||
return padded
|
||||
|
||||
from dm_control.utils.inverse_kinematics import qpos_from_site_pose
|
||||
|
||||
# Action position is in robot-base frame (see convert_to_lerobot.py);
|
||||
# dm_control's IK expects a world-frame target.
|
||||
base = self._robot_base_xyz if self._robot_base_xyz is not None else np.zeros(3, dtype=np.float64)
|
||||
pos_world = np.asarray(action[:3], dtype=np.float64) + base
|
||||
rx, ry, rz = float(action[3]), float(action[4]), float(action[5])
|
||||
gripper = float(np.clip(action[6], 0.0, 1.0))
|
||||
|
||||
# Dataset euler is scipy extrinsic 'xyz' (same as VLABench's
|
||||
# `euler_to_quaternion`). scipy emits `[x, y, z, w]`; dm_control's IK
|
||||
# and MuJoCo use `[w, x, y, z]`, so reorder.
|
||||
qxyzw = Rotation.from_euler("xyz", [rx, ry, rz], degrees=False).as_quat()
|
||||
quat = np.array([qxyzw[3], qxyzw[0], qxyzw[1], qxyzw[2]], dtype=np.float64)
|
||||
|
||||
assert self._env is not None
|
||||
robot = self._env.task.robot
|
||||
site_name = robot.end_effector_site.full_identifier
|
||||
|
||||
# inplace=False so IK doesn't mutate physics state mid-step — we only
|
||||
# want the solved qpos. Fetch a fresh physics handle — caching it can
|
||||
# yield a stale weakref after a reset.
|
||||
ik_result = qpos_from_site_pose(
|
||||
self._env.physics,
|
||||
site_name=site_name,
|
||||
target_pos=pos_world,
|
||||
target_quat=quat,
|
||||
inplace=False,
|
||||
max_steps=100,
|
||||
)
|
||||
n_dof = robot.n_dof # 7 for Franka
|
||||
arm_qpos = ik_result.qpos[:n_dof]
|
||||
|
||||
# Dataset gripper convention: 1 = open (finger qpos = 0.04),
|
||||
# 0 = closed (finger qpos = 0.0). See VLABench/scripts/convert_to_lerobot.py
|
||||
# where `trajectory[i][-1] > 0.03` is encoded as `1`.
|
||||
finger_qpos = gripper * self._FRANKA_FINGER_OPEN
|
||||
|
||||
ctrl = np.zeros(ctrl_dim, dtype=np.float64)
|
||||
ctrl[:n_dof] = arm_qpos
|
||||
# Remaining entries are gripper fingers (usually 2 for Franka).
|
||||
ctrl[n_dof:] = finger_qpos
|
||||
return ctrl
|
||||
|
||||
def reset(self, seed=None, **kwargs) -> tuple[RobotObservation, dict[str, Any]]:
|
||||
self._ensure_env()
|
||||
assert self._env is not None
|
||||
super().reset(seed=seed)
|
||||
|
||||
if seed is not None:
|
||||
self._seed_inner_env(int(self.np_random.integers(0, 2**31 - 1)))
|
||||
|
||||
self._env.reset()
|
||||
|
||||
observation = self._get_obs()
|
||||
info = {"is_success": False}
|
||||
return observation, info
|
||||
|
||||
def _seed_inner_env(self, seed: int) -> None:
|
||||
"""Propagate `seed` to the inner dm_control env. `Environment.reset()`
|
||||
doesn't accept a seed, so we re-seed the task and environment
|
||||
`RandomState`s directly. Best-effort: silently skipped when the
|
||||
expected attributes are absent on a given VLABench version.
|
||||
"""
|
||||
for owner_attr, rng_attr in (("task", "random"), (None, "_random_state")):
|
||||
owner = getattr(self._env, owner_attr) if owner_attr else self._env
|
||||
rng = getattr(owner, rng_attr, None)
|
||||
rng_seed = getattr(rng, "seed", None)
|
||||
if callable(rng_seed):
|
||||
rng_seed(seed)
|
||||
|
||||
def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]:
|
||||
from dm_control.rl.control import PhysicsError # type: ignore[import-untyped]
|
||||
|
||||
self._ensure_env()
|
||||
assert self._env is not None
|
||||
|
||||
if action.ndim != 1:
|
||||
raise ValueError(
|
||||
f"Expected action to be 1-D (shape (action_dim,)), "
|
||||
f"but got shape {action.shape} with ndim={action.ndim}"
|
||||
)
|
||||
|
||||
if self.action_mode not in ("eef", "joint", "delta_eef"):
|
||||
raise ValueError(f"Unknown action_mode: {self.action_mode}")
|
||||
|
||||
# Always refetch physics — dm_control returns a weakref proxy that can
|
||||
# go stale across resets.
|
||||
physics = self._env.physics
|
||||
ctrl_dim = int(physics.data.ctrl.shape[0])
|
||||
ctrl = self._build_ctrl_from_action(action, ctrl_dim)
|
||||
try:
|
||||
timestep = self._env.step(ctrl)
|
||||
except PhysicsError as exc:
|
||||
# Physics integrator diverged (e.g. mjWARN_BADQACC). Treat it as
|
||||
# a graceful failed termination rather than a hard crash — the
|
||||
# rest of the multi-task eval should still run.
|
||||
logger.warning(
|
||||
"PhysicsError during step on task '%s': %s. Terminating episode.",
|
||||
self.task,
|
||||
exc,
|
||||
)
|
||||
observation = self._get_obs()
|
||||
info = {"task": self.task, "is_success": False, "physics_error": True}
|
||||
# Drop the stale env so the next reset() rebuilds it cleanly.
|
||||
with contextlib.suppress(Exception):
|
||||
self._env.close()
|
||||
self._env = None
|
||||
return observation, 0.0, True, False, info
|
||||
|
||||
# Extract reward from dm_control timestep
|
||||
reward = float(timestep.reward) if timestep.reward is not None else 0.0
|
||||
|
||||
# Check success via the task's termination condition
|
||||
is_success = False
|
||||
if hasattr(self._env, "task") and hasattr(self._env.task, "should_terminate_episode"):
|
||||
is_success = bool(self._env.task.should_terminate_episode(self._env.physics))
|
||||
|
||||
terminated = is_success
|
||||
truncated = False
|
||||
info = {
|
||||
"task": self.task,
|
||||
"is_success": is_success,
|
||||
}
|
||||
|
||||
observation = self._get_obs()
|
||||
|
||||
if terminated:
|
||||
self.reset()
|
||||
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
def render(self) -> np.ndarray:
|
||||
self._ensure_env()
|
||||
obs = self._get_obs()
|
||||
return obs["pixels"]["image"]
|
||||
|
||||
def close(self):
|
||||
if self._env is not None:
|
||||
self._env.close()
|
||||
self._env = None
|
||||
|
||||
|
||||
# ---- Main API ----------------------------------------------------------------
|
||||
|
||||
|
||||
def create_vlabench_envs(
|
||||
task: str,
|
||||
n_envs: int,
|
||||
gym_kwargs: dict[str, Any] | None = None,
|
||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""
|
||||
Create vectorized VLABench environments with a consistent return shape.
|
||||
|
||||
Returns:
|
||||
dict[suite_name][task_id] -> vec_env (env_cls([...]) with exactly n_envs factories)
|
||||
|
||||
Notes:
|
||||
- n_envs is the number of rollouts *per task*.
|
||||
- `task` can be a suite name ("primitive", "composite"), a comma-separated list of
|
||||
suite names, or individual task names (e.g. "select_fruit,heat_food").
|
||||
"""
|
||||
if env_cls is None or not callable(env_cls):
|
||||
raise ValueError("env_cls must be a callable that wraps a list of environment factory callables.")
|
||||
if not isinstance(n_envs, int) or n_envs <= 0:
|
||||
raise ValueError(f"n_envs must be a positive int; got {n_envs}.")
|
||||
|
||||
gym_kwargs = dict(gym_kwargs or {})
|
||||
task_groups = [t.strip() for t in task.split(",") if t.strip()]
|
||||
if not task_groups:
|
||||
raise ValueError("`task` must contain at least one VLABench task or suite name.")
|
||||
|
||||
logger.info(
|
||||
"Creating VLABench envs | task_groups=%s | n_envs(per task)=%d",
|
||||
task_groups,
|
||||
n_envs,
|
||||
)
|
||||
|
||||
is_async = env_cls is gym.vector.AsyncVectorEnv
|
||||
cached_obs_space = None
|
||||
cached_act_space = None
|
||||
cached_metadata = None
|
||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||
|
||||
for group in task_groups:
|
||||
# Check if it's a suite name, otherwise treat as individual task
|
||||
tasks = SUITE_TASKS.get(group, [group])
|
||||
|
||||
for tid, task_name in enumerate(tasks):
|
||||
logger.info(
|
||||
"Building vec env | group=%s | task_id=%d | task=%s",
|
||||
group,
|
||||
tid,
|
||||
task_name,
|
||||
)
|
||||
|
||||
fns = [(lambda tn=task_name: VLABenchEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)]
|
||||
|
||||
if is_async:
|
||||
lazy = _LazyAsyncVectorEnv(fns, cached_obs_space, cached_act_space, cached_metadata)
|
||||
if cached_obs_space is None:
|
||||
cached_obs_space = lazy.observation_space
|
||||
cached_act_space = lazy.action_space
|
||||
cached_metadata = lazy.metadata
|
||||
out[group][tid] = lazy
|
||||
else:
|
||||
out[group][tid] = env_cls(fns)
|
||||
|
||||
return {group: dict(task_map) for group, task_map in out.items()}
|
||||
@@ -12,8 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterpolator
|
||||
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
@@ -23,6 +21,7 @@ from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
||||
from .rtc import ActionInterpolator as ActionInterpolator
|
||||
from .sac.configuration_sac import SACConfig as SACConfig
|
||||
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
|
||||
@@ -1,4 +1,116 @@
|
||||
# Moved to lerobot.utils.action_interpolator — re-exported for backwards compatibility.
|
||||
from lerobot.utils.action_interpolator import ActionInterpolator
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__all__ = ["ActionInterpolator"]
|
||||
"""Action interpolation for smoother robot control.
|
||||
|
||||
Provides configurable Nx control rate by interpolating between consecutive actions.
|
||||
Useful with RTC and action-chunking policies to reduce jerkiness.
|
||||
"""
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class ActionInterpolator:
|
||||
"""Interpolates between consecutive actions for smoother control.
|
||||
|
||||
When enabled with multiplier N, produces N actions per policy action
|
||||
by linearly interpolating between the previous and current action.
|
||||
|
||||
Example with multiplier=3:
|
||||
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
|
||||
|
||||
This effectively multiplies the control rate for smoother motion.
|
||||
|
||||
Usage:
|
||||
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
|
||||
|
||||
# In control loop:
|
||||
if interpolator.needs_new_action():
|
||||
new_action = queue.get()
|
||||
if new_action:
|
||||
interpolator.add(new_action.cpu())
|
||||
|
||||
action = interpolator.get()
|
||||
if action:
|
||||
robot.send_action(action)
|
||||
"""
|
||||
|
||||
def __init__(self, multiplier: int = 1):
|
||||
"""Initialize the interpolator.
|
||||
|
||||
Args:
|
||||
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
|
||||
"""
|
||||
if multiplier < 1:
|
||||
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
|
||||
self.multiplier = multiplier
|
||||
self._prev: Tensor | None = None
|
||||
self._buffer: list[Tensor] = []
|
||||
self._idx = 0
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
"""Whether interpolation is active (multiplier > 1)."""
|
||||
return self.multiplier > 1
|
||||
|
||||
def reset(self):
|
||||
"""Reset interpolation state (call between episodes)."""
|
||||
self._prev = None
|
||||
self._buffer = []
|
||||
self._idx = 0
|
||||
|
||||
def needs_new_action(self) -> bool:
|
||||
"""Check if a new action is needed from the queue."""
|
||||
return self._idx >= len(self._buffer)
|
||||
|
||||
def add(self, action: Tensor) -> None:
|
||||
"""Add a new action and compute interpolated sequence.
|
||||
|
||||
Args:
|
||||
action: New action tensor from policy/queue (already on CPU).
|
||||
"""
|
||||
if self.multiplier > 1 and self._prev is not None:
|
||||
self._buffer = []
|
||||
for i in range(1, self.multiplier + 1):
|
||||
t = i / self.multiplier
|
||||
interp = self._prev + t * (action - self._prev)
|
||||
self._buffer.append(interp)
|
||||
else:
|
||||
# First step: no previous action yet, so run at base FPS without interpolation.
|
||||
self._buffer = [action.clone()]
|
||||
self._prev = action.clone()
|
||||
self._idx = 0
|
||||
|
||||
def get(self) -> Tensor | None:
|
||||
"""Get the next interpolated action.
|
||||
|
||||
Returns:
|
||||
Next action tensor, or None if buffer is exhausted.
|
||||
"""
|
||||
if self._idx >= len(self._buffer):
|
||||
return None
|
||||
action = self._buffer[self._idx]
|
||||
self._idx += 1
|
||||
return action
|
||||
|
||||
def get_control_interval(self, fps: float) -> float:
|
||||
"""Get the control interval based on interpolation multiplier.
|
||||
|
||||
Args:
|
||||
fps: Base frames per second.
|
||||
|
||||
Returns:
|
||||
Control interval in seconds (divided by multiplier).
|
||||
"""
|
||||
return 1.0 / (fps * self.multiplier)
|
||||
|
||||
@@ -92,10 +92,10 @@ class ActionQueue:
|
||||
Returns:
|
||||
int: Number of unconsumed actions.
|
||||
"""
|
||||
with self.lock:
|
||||
if self.queue is None:
|
||||
return 0
|
||||
return len(self.queue) - self.last_index
|
||||
if self.queue is None:
|
||||
return 0
|
||||
length = len(self.queue)
|
||||
return length - self.last_index
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Check if the queue is empty.
|
||||
@@ -103,10 +103,11 @@ class ActionQueue:
|
||||
Returns:
|
||||
bool: True if no actions remain, False otherwise.
|
||||
"""
|
||||
with self.lock:
|
||||
if self.queue is None:
|
||||
return True
|
||||
return len(self.queue) - self.last_index <= 0
|
||||
if self.queue is None:
|
||||
return True
|
||||
|
||||
length = len(self.queue)
|
||||
return length - self.last_index <= 0
|
||||
|
||||
def get_action_index(self) -> int:
|
||||
"""Get the current action consumption index.
|
||||
@@ -114,8 +115,7 @@ class ActionQueue:
|
||||
Returns:
|
||||
int: Index of the next action to be consumed.
|
||||
"""
|
||||
with self.lock:
|
||||
return self.last_index
|
||||
return self.last_index
|
||||
|
||||
def get_left_over(self) -> Tensor | None:
|
||||
"""Get leftover original actions for RTC prev_chunk_left_over.
|
||||
|
||||
@@ -655,7 +655,6 @@ class VLAFlowMatching(nn.Module):
|
||||
pad_masks.append(image_start_mask)
|
||||
|
||||
img_emb = self.vlm_with_expert.embed_image(img)
|
||||
img_emb = img_emb
|
||||
|
||||
# Normalize image embeddings
|
||||
img_emb_dim = img_emb.shape[-1]
|
||||
|
||||
@@ -76,7 +76,6 @@ from lerobot.transport.utils import (
|
||||
)
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.transition import (
|
||||
@@ -95,6 +94,7 @@ from .gym_manipulator import (
|
||||
make_robot_env,
|
||||
step_env_and_process_transition,
|
||||
)
|
||||
from .process import ProcessSignalHandler
|
||||
from .queue import get_last_item_from_queue
|
||||
|
||||
# Main entry point
|
||||
|
||||
@@ -90,7 +90,6 @@ from lerobot.utils.constants import (
|
||||
TRAINING_STATE_DIR,
|
||||
)
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
|
||||
from lerobot.utils.utils import (
|
||||
@@ -100,6 +99,7 @@ from lerobot.utils.utils import (
|
||||
|
||||
from .buffer import ReplayBuffer, concatenate_batch_transitions
|
||||
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
|
||||
from .process import ProcessSignalHandler
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Policy deployment engine with pluggable rollout strategies."""
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("datasets", extra="dataset")
|
||||
|
||||
from .configs import (
|
||||
BaseStrategyConfig,
|
||||
DAggerKeyboardConfig,
|
||||
DAggerPedalConfig,
|
||||
DAggerStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
RolloutConfig,
|
||||
RolloutStrategyConfig,
|
||||
SentryStrategyConfig,
|
||||
)
|
||||
from .context import (
|
||||
DatasetContext,
|
||||
HardwareContext,
|
||||
PolicyContext,
|
||||
ProcessorContext,
|
||||
RolloutContext,
|
||||
RuntimeContext,
|
||||
build_rollout_context,
|
||||
)
|
||||
from .inference import (
|
||||
InferenceEngine,
|
||||
InferenceEngineConfig,
|
||||
RTCInferenceConfig,
|
||||
RTCInferenceEngine,
|
||||
SyncInferenceConfig,
|
||||
SyncInferenceEngine,
|
||||
create_inference_engine,
|
||||
)
|
||||
from .ring_buffer import RolloutRingBuffer
|
||||
from .robot_wrapper import ThreadSafeRobot
|
||||
from .strategies import RolloutStrategy, create_strategy
|
||||
|
||||
__all__ = [
|
||||
"BaseStrategyConfig",
|
||||
"DAggerKeyboardConfig",
|
||||
"DAggerPedalConfig",
|
||||
"DAggerStrategyConfig",
|
||||
"DatasetContext",
|
||||
"HardwareContext",
|
||||
"HighlightStrategyConfig",
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
"PolicyContext",
|
||||
"ProcessorContext",
|
||||
"RTCInferenceConfig",
|
||||
"RTCInferenceEngine",
|
||||
"RolloutConfig",
|
||||
"RolloutContext",
|
||||
"RolloutRingBuffer",
|
||||
"RolloutStrategy",
|
||||
"RolloutStrategyConfig",
|
||||
"RuntimeContext",
|
||||
"SentryStrategyConfig",
|
||||
"SyncInferenceConfig",
|
||||
"SyncInferenceEngine",
|
||||
"ThreadSafeRobot",
|
||||
"build_rollout_context",
|
||||
"create_inference_engine",
|
||||
"create_strategy",
|
||||
]
|
||||
@@ -1,310 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration dataclasses for the rollout deployment engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs import PreTrainedConfig, parser
|
||||
from lerobot.configs.dataset import DatasetRecordConfig
|
||||
from lerobot.robots.config import RobotConfig
|
||||
from lerobot.teleoperators.config import TeleoperatorConfig
|
||||
|
||||
from .inference import InferenceEngineConfig, SyncInferenceConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Strategy configs (polymorphic dispatch via draccus ChoiceRegistry)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class RolloutStrategyConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
"""Abstract base for rollout strategy configurations.
|
||||
|
||||
Use ``--strategy.type=<name>`` on the CLI to select a strategy.
|
||||
"""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("base")
|
||||
@dataclass
|
||||
class BaseStrategyConfig(RolloutStrategyConfig):
|
||||
"""Autonomous rollout with no data recording."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("sentry")
|
||||
@dataclass
|
||||
class SentryStrategyConfig(RolloutStrategyConfig):
|
||||
"""Continuous autonomous rollout with always-on recording.
|
||||
|
||||
Episode duration is derived from camera resolution, FPS, and
|
||||
``target_video_file_size_mb`` so that each saved episode produces a
|
||||
video file that has crossed the target size. This aligns episode
|
||||
boundaries with the dataset's video file chunking, so each
|
||||
``push_to_hub`` call uploads complete video files rather than
|
||||
re-uploading a growing file that hasn't crossed the chunk boundary.
|
||||
"""
|
||||
|
||||
upload_every_n_episodes: int = 5
|
||||
# Target video file size in MB for episode rotation. Episodes are
|
||||
# saved once the estimated video duration would exceed this limit.
|
||||
# Defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB when set to None.
|
||||
target_video_file_size_mb: float | None = None
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("highlight")
|
||||
@dataclass
|
||||
class HighlightStrategyConfig(RolloutStrategyConfig):
|
||||
"""Autonomous rollout with on-demand recording via ring buffer.
|
||||
|
||||
A memory-bounded ring buffer continuously captures telemetry. When
|
||||
the user presses the save key, the buffer contents are flushed to
|
||||
the dataset and live recording continues until the key is pressed
|
||||
again.
|
||||
"""
|
||||
|
||||
ring_buffer_seconds: float = 30.0
|
||||
ring_buffer_max_memory_mb: float = 2048.0
|
||||
save_key: str = "s"
|
||||
push_key: str = "h"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DAggerKeyboardConfig:
|
||||
"""Keyboard key bindings for DAgger controls.
|
||||
|
||||
Keys are specified as single characters (e.g. ``"c"``, ``"h"``) or
|
||||
special key names (``"space"``).
|
||||
"""
|
||||
|
||||
pause_resume: str = "space"
|
||||
correction: str = "tab"
|
||||
upload: str = "enter"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DAggerPedalConfig:
|
||||
"""Foot pedal configuration for DAgger controls.
|
||||
|
||||
Pedal codes are evdev key code strings (e.g. ``"KEY_A"``).
|
||||
"""
|
||||
|
||||
device_path: str = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
|
||||
pause_resume: str = "KEY_A"
|
||||
correction: str = "KEY_B"
|
||||
upload: str = "KEY_C"
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("dagger")
|
||||
@dataclass
|
||||
class DAggerStrategyConfig(RolloutStrategyConfig):
|
||||
"""Human-in-the-loop data collection (DAgger / RaC).
|
||||
|
||||
Alternates between autonomous policy execution and human intervention.
|
||||
Intervention frames are tagged with ``intervention=True``.
|
||||
|
||||
Input is controlled via either a keyboard or foot pedal, selected by
|
||||
``input_device``. Each device exposes three actions:
|
||||
|
||||
1. **pause_resume** — toggle policy execution on/off.
|
||||
2. **correction** — toggle human correction recording.
|
||||
3. **upload** — push dataset to hub on demand (corrections-only mode).
|
||||
|
||||
When ``record_autonomous=True`` (default) both autonomous and correction
|
||||
frames are recorded with size-based episode rotation (same as Sentry)
|
||||
and background uploading. ``push_to_hub`` is blocked while a correction
|
||||
is in progress. Set to ``False`` to record only the human-correction
|
||||
windows, where each correction becomes its own episode.
|
||||
"""
|
||||
|
||||
# Number of correction episodes to collect (corrections-only mode).
|
||||
# When None, falls back to ``--dataset.num_episodes``.
|
||||
num_episodes: int | None = None
|
||||
record_autonomous: bool = False
|
||||
upload_every_n_episodes: int = 5
|
||||
# Target video file size in MB for episode rotation (record_autonomous
|
||||
# mode only). Defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB when None.
|
||||
target_video_file_size_mb: float | None = None
|
||||
input_device: str = "keyboard"
|
||||
keyboard: DAggerKeyboardConfig = field(default_factory=DAggerKeyboardConfig)
|
||||
pedal: DAggerPedalConfig = field(default_factory=DAggerPedalConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.input_device not in ("keyboard", "pedal"):
|
||||
raise ValueError(f"DAgger input_device must be 'keyboard' or 'pedal', got '{self.input_device}'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Top-level rollout config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class RolloutConfig:
|
||||
"""Top-level configuration for the ``lerobot-rollout`` CLI.
|
||||
|
||||
Combines hardware, policy, strategy, and runtime settings. The
|
||||
``__post_init__`` method performs fail-fast validation to reject
|
||||
invalid flag combinations early.
|
||||
"""
|
||||
|
||||
# Hardware
|
||||
robot: RobotConfig | None = None
|
||||
teleop: TeleoperatorConfig | None = None
|
||||
|
||||
# Policy (loaded from --policy.path via __post_init__)
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
# Strategy (polymorphic: --strategy.type=base|sentry|highlight|dagger)
|
||||
strategy: RolloutStrategyConfig = field(default_factory=BaseStrategyConfig)
|
||||
|
||||
# Inference backend (polymorphic: --inference.type=sync|rtc)
|
||||
inference: InferenceEngineConfig = field(default_factory=SyncInferenceConfig)
|
||||
|
||||
# Dataset (required for sentry, highlight, dagger; None for base)
|
||||
dataset: DatasetRecordConfig | None = None
|
||||
|
||||
# Runtime
|
||||
fps: float = 30.0
|
||||
duration: float = 0.0 # 0 = infinite (24/7 mode)
|
||||
interpolation_multiplier: int = 1
|
||||
device: str | None = None
|
||||
task: str = ""
|
||||
display_data: bool = False
|
||||
# Display data on a remote Rerun server
|
||||
display_ip: str | None = None
|
||||
# Port of the remote Rerun server
|
||||
display_port: int | None = None
|
||||
# Whether to display compressed images in Rerun
|
||||
display_compressed_images: bool = False
|
||||
# Use vocal synthesis to read events
|
||||
play_sounds: bool = True
|
||||
resume: bool = False
|
||||
|
||||
# Torch compile
|
||||
use_torch_compile: bool = False
|
||||
torch_compile_backend: str = "inductor"
|
||||
torch_compile_mode: str = "default"
|
||||
compile_warmup_inferences: int = 2
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate config invariants and load the policy config from ``--policy.path``."""
|
||||
# --- Strategy-specific validation ---
|
||||
if isinstance(self.strategy, DAggerStrategyConfig) and self.teleop is None:
|
||||
raise ValueError("DAgger strategy requires --teleop.type to be set")
|
||||
|
||||
# TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
|
||||
needs_dataset = isinstance(
|
||||
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
|
||||
)
|
||||
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
|
||||
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
|
||||
|
||||
# if isinstance(self.strategy, BaseStrategyConfig) and self.dataset is not None:
|
||||
# raise ValueError(
|
||||
# "Base strategy does not record data. Use sentry, highlight, or dagger for recording."
|
||||
# )
|
||||
|
||||
# Sentry MUST use streaming encoding to avoid disk I/O blocking the control loop
|
||||
if (
|
||||
isinstance(self.strategy, SentryStrategyConfig)
|
||||
and self.dataset is not None
|
||||
and not self.dataset.streaming_encoding
|
||||
):
|
||||
logger.warning("Sentry mode forces streaming_encoding=True")
|
||||
self.dataset.streaming_encoding = True
|
||||
|
||||
# Highlight writes frames while the policy is still running, so streaming is mandatory.
|
||||
if (
|
||||
isinstance(self.strategy, HighlightStrategyConfig)
|
||||
and self.dataset is not None
|
||||
and not self.dataset.streaming_encoding
|
||||
):
|
||||
logger.warning("Highlight mode forces streaming_encoding=True")
|
||||
self.dataset.streaming_encoding = True
|
||||
|
||||
# DAgger: streaming is mandatory only when the autonomous phase is also recorded.
|
||||
if isinstance(self.strategy, DAggerStrategyConfig) and self.dataset is not None:
|
||||
if self.strategy.record_autonomous and not self.dataset.streaming_encoding:
|
||||
logger.warning("DAgger with record_autonomous=True forces streaming_encoding=True")
|
||||
self.dataset.streaming_encoding = True
|
||||
elif not self.strategy.record_autonomous and not self.dataset.streaming_encoding:
|
||||
logger.info(
|
||||
"Streaming encoding is disabled for DAgger corrections-only mode. "
|
||||
"Consider enabling it for faster episode saving: "
|
||||
"--dataset.streaming_encoding=true --dataset.encoder_threads=2"
|
||||
)
|
||||
|
||||
# DAgger: resolve num_episodes from dataset config when not explicitly set.
|
||||
if isinstance(self.strategy, DAggerStrategyConfig) and self.strategy.num_episodes is None:
|
||||
if self.dataset is not None:
|
||||
self.strategy.num_episodes = self.dataset.num_episodes
|
||||
logger.info(
|
||||
"DAgger num_episodes not set — using --dataset.num_episodes=%d",
|
||||
self.strategy.num_episodes,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"DAgger num_episodes must be set either via --strategy.num_episodes or --dataset.num_episodes"
|
||||
)
|
||||
|
||||
# --- Policy loading ---
|
||||
if self.robot is None:
|
||||
raise ValueError("--robot.type is required for rollout")
|
||||
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
if self.policy is None:
|
||||
raise ValueError("--policy.path is required for rollout")
|
||||
|
||||
# --- Task resolution ---
|
||||
# When --dataset.rename_map (or any --dataset.* flag) is passed, draccus
|
||||
# creates a DatasetRecordConfig with single_task="". If the user set
|
||||
# the task via the top-level --task flag, propagate it so that all
|
||||
# downstream consumers (inference engine, dataset frame builders) see it.
|
||||
if self.dataset is not None and not self.dataset.single_task and self.task:
|
||||
self.dataset.single_task = self.task
|
||||
elif self.dataset is not None and self.dataset.single_task and not self.task:
|
||||
self.task = self.dataset.single_task
|
||||
|
||||
# --- Device resolution ---
|
||||
# Resolve device from the policy config when not explicitly set so all
|
||||
# components (policy.to, preprocessor, inference engine) use the same
|
||||
# device string instead of inconsistent fallbacks.
|
||||
if self.device is None and self.policy is not None:
|
||||
resolved = getattr(self.policy, "device", None)
|
||||
if resolved:
|
||||
self.device = resolved
|
||||
logger.info("Resolved device from policy config: %s", self.device)
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
return ["policy"]
|
||||
@@ -1,496 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Rollout context: shared state created once before strategy dispatch.
|
||||
|
||||
Grouped into five topical sub-contexts — :class:`RuntimeContext`,
|
||||
:class:`HardwareContext`, :class:`PolicyContext`, :class:`ProcessorContext`,
|
||||
and :class:`DatasetContext` — assembled into :class:`RolloutContext`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Event
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType, PreTrainedConfig
|
||||
from lerobot.datasets import (
|
||||
LeRobotDataset,
|
||||
aggregate_pipeline_dataset_features,
|
||||
create_initial_features,
|
||||
)
|
||||
from lerobot.policies import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import (
|
||||
PolicyProcessorPipeline,
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
make_default_processors,
|
||||
rename_stats,
|
||||
)
|
||||
from lerobot.robots import make_robot_from_config
|
||||
from lerobot.teleoperators import Teleoperator, make_teleoperator_from_config
|
||||
from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_features
|
||||
|
||||
from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig
|
||||
from .inference import (
|
||||
InferenceEngine,
|
||||
RTCInferenceConfig,
|
||||
create_inference_engine,
|
||||
)
|
||||
from .robot_wrapper import ThreadSafeRobot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_action_key_order(
|
||||
policy_action_names: list[str] | None, dataset_action_names: list[str]
|
||||
) -> list[str]:
|
||||
"""Choose action name ordering for mapping policy tensor outputs to robot action dicts."""
|
||||
if not policy_action_names:
|
||||
return dataset_action_names
|
||||
policy_action_names = list(policy_action_names)
|
||||
if len(policy_action_names) != len(dataset_action_names):
|
||||
logger.warning(
|
||||
"policy.action_feature_names length (%d) != dataset action dim (%d); using dataset order",
|
||||
len(policy_action_names),
|
||||
len(dataset_action_names),
|
||||
)
|
||||
return dataset_action_names
|
||||
if set(dataset_action_names) != set(policy_action_names):
|
||||
logger.warning("policy.action_feature_names keys don't match dataset; using dataset order")
|
||||
return dataset_action_names
|
||||
return policy_action_names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sub-contexts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class RuntimeContext:
|
||||
"""Runtime knobs shared with every strategy."""
|
||||
|
||||
cfg: RolloutConfig
|
||||
shutdown_event: Event
|
||||
|
||||
|
||||
@dataclass
|
||||
class HardwareContext:
|
||||
"""Connected hardware.
|
||||
|
||||
The raw robot is available via ``robot_wrapper.inner`` when needed
|
||||
(e.g. for disconnect); strategies should otherwise go through the
|
||||
thread-safe wrapper.
|
||||
|
||||
``initial_position`` stores the robot's joint positions at connect
|
||||
time. Strategies use it to return the robot to a safe pose before
|
||||
shutting down.
|
||||
"""
|
||||
|
||||
robot_wrapper: ThreadSafeRobot
|
||||
teleop: Teleoperator | None
|
||||
initial_position: dict | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyContext:
|
||||
"""Loaded policy and its inference engine."""
|
||||
|
||||
policy: PreTrainedPolicy
|
||||
preprocessor: PolicyProcessorPipeline
|
||||
postprocessor: PolicyProcessorPipeline
|
||||
inference: InferenceEngine
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessorContext:
|
||||
"""Robot-side pipelines (run outside the policy)."""
|
||||
|
||||
teleop_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]
|
||||
robot_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]
|
||||
robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetContext:
|
||||
"""Dataset and feature bookkeeping."""
|
||||
|
||||
dataset: LeRobotDataset | None
|
||||
dataset_features: dict = field(default_factory=dict)
|
||||
hw_features: dict = field(default_factory=dict)
|
||||
ordered_action_keys: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RolloutContext:
|
||||
"""Bundle of sub-contexts passed to every rollout strategy.
|
||||
|
||||
Built once by :func:`build_rollout_context` before strategy dispatch.
|
||||
"""
|
||||
|
||||
runtime: RuntimeContext
|
||||
hardware: HardwareContext
|
||||
policy: PolicyContext
|
||||
processors: ProcessorContext
|
||||
data: DatasetContext
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def build_rollout_context(
|
||||
cfg: RolloutConfig,
|
||||
shutdown_event: Event,
|
||||
teleop_action_processor: RobotProcessorPipeline | None = None,
|
||||
robot_action_processor: RobotProcessorPipeline | None = None,
|
||||
robot_observation_processor: RobotProcessorPipeline | None = None,
|
||||
) -> RolloutContext:
|
||||
"""Wire up policy, processors, hardware, dataset, and inference engine.
|
||||
|
||||
The order is policy-first / hardware-last so a bad ``--policy.path``
|
||||
fails fast without touching the robot.
|
||||
"""
|
||||
is_rtc = isinstance(cfg.inference, RTCInferenceConfig)
|
||||
|
||||
# --- 1. Policy (heavy I/O, but no hardware yet) -------------------
|
||||
logger.info("Loading policy from '%s'...", cfg.policy.pretrained_path)
|
||||
policy_config = cfg.policy
|
||||
policy_class = get_policy_class(policy_config.type)
|
||||
|
||||
full_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
for attr in ("device", "use_amp"):
|
||||
if hasattr(cfg.policy, attr) and hasattr(full_config, attr):
|
||||
cli_val = getattr(cfg.policy, attr)
|
||||
if cli_val is not None:
|
||||
setattr(full_config, attr, cli_val)
|
||||
|
||||
if hasattr(full_config, "compile_model"):
|
||||
full_config.compile_model = cfg.use_torch_compile
|
||||
|
||||
if full_config.type == "vqbet" and cfg.device == "mps":
|
||||
raise NotImplementedError(
|
||||
"Current implementation of VQBeT does not support `mps` backend. "
|
||||
"Please use `cpu` or `cuda` backend."
|
||||
)
|
||||
|
||||
if full_config.use_peft:
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
peft_path = cfg.policy.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_path)
|
||||
policy = policy_class.from_pretrained(
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=full_config
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
|
||||
else:
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=full_config)
|
||||
|
||||
if is_rtc:
|
||||
policy.config.rtc_config = cfg.inference.rtc
|
||||
if hasattr(policy, "init_rtc_processor"):
|
||||
policy.init_rtc_processor()
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
logger.info("Policy loaded: type=%s, device=%s", policy_config.type, cfg.device)
|
||||
|
||||
if cfg.use_torch_compile and policy.type not in ("pi0", "pi05"):
|
||||
try:
|
||||
if hasattr(torch, "compile"):
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
"options": {"triton.cudagraphs": False},
|
||||
}
|
||||
policy.predict_action_chunk = torch.compile(policy.predict_action_chunk, **compile_kwargs)
|
||||
logger.info("torch.compile applied to predict_action_chunk")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to apply torch.compile: %s", e)
|
||||
|
||||
# --- 2. Robot-side processors (user-supplied or defaults) --------
|
||||
if (
|
||||
teleop_action_processor is None
|
||||
or robot_action_processor is None
|
||||
or robot_observation_processor is None
|
||||
):
|
||||
_t, _r, _o = make_default_processors()
|
||||
teleop_action_processor = teleop_action_processor or _t
|
||||
robot_action_processor = robot_action_processor or _r
|
||||
robot_observation_processor = robot_observation_processor or _o
|
||||
|
||||
# --- 3. Hardware (heaviest side-effect, deferred) -----------------
|
||||
logger.info("Connecting robot (%s)...", cfg.robot.type if cfg.robot else "?")
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
robot.connect()
|
||||
logger.info("Robot connected: %s", robot.name)
|
||||
|
||||
# Store the initial joint positions so we can return to a safe pose on shutdown.
|
||||
initial_obs = robot.get_observation()
|
||||
initial_position = {k: v for k, v in initial_obs.items() if k.endswith(".pos")}
|
||||
logger.info("Captured initial robot position (%d keys)", len(initial_position))
|
||||
|
||||
robot_wrapper = ThreadSafeRobot(robot)
|
||||
|
||||
teleop = None
|
||||
if cfg.teleop is not None:
|
||||
logger.info("Connecting teleoperator (%s)...", cfg.teleop.type if cfg.teleop else "?")
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
teleop.connect()
|
||||
logger.info("Teleoperator connected")
|
||||
|
||||
# DAgger requires teleop with motor control capabilities (enable_torque,
|
||||
# disable_torque, write_goal_positions).
|
||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||
# if isinstance(cfg.strategy, DAggerStrategyConfig) and teleop is not None:
|
||||
# required_teleop_methods = ("enable_torque", "disable_torque", "write_goal_positions")
|
||||
# missing = [m for m in required_teleop_methods if not callable(getattr(teleop, m, None))]
|
||||
# if missing:
|
||||
# teleop.disconnect()
|
||||
# raise ValueError(
|
||||
# f"DAgger strategy requires a teleoperator with motor control methods "
|
||||
# f"{required_teleop_methods}. '{type(teleop).__name__}' is missing: {missing}"
|
||||
# )
|
||||
|
||||
# --- 4. Features + action-key reconciliation ---------------------
|
||||
# Only `.pos` joint features are used for policy inference — velocity and
|
||||
# torque channels are observation-only and must be excluded from the state
|
||||
# and action tensors that the policy sees. This matches the filtering
|
||||
# applied by the old ``hil_data_collection`` script.
|
||||
all_obs_features = robot.observation_features
|
||||
observation_features_hw = {
|
||||
k: v
|
||||
for k, v in all_obs_features.items()
|
||||
if isinstance(v, tuple) or (v is float and k.endswith(".pos"))
|
||||
}
|
||||
action_features_hw = {k: v for k, v in robot.action_features.items() if k.endswith(".pos")}
|
||||
|
||||
# The action side is always needed: sync inference reads action names from
|
||||
# ``dataset_features[ACTION]`` to map policy tensors back to robot actions.
|
||||
action_dataset_features = aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_action_processor,
|
||||
initial_features=create_initial_features(action=action_features_hw),
|
||||
use_videos=cfg.dataset.video if cfg.dataset else True,
|
||||
)
|
||||
# Observation-side aggregation is needed because of build_dataset_frame
|
||||
observation_dataset_features = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_observation_processor,
|
||||
initial_features=create_initial_features(observation=observation_features_hw),
|
||||
use_videos=cfg.dataset.video if cfg.dataset else True,
|
||||
)
|
||||
dataset_features = combine_feature_dicts(action_dataset_features, observation_dataset_features)
|
||||
hw_features = hw_to_dataset_features(observation_features_hw, "observation")
|
||||
raw_action_keys = list(action_features_hw.keys())
|
||||
policy_action_names = getattr(policy_config, "action_feature_names", None)
|
||||
ordered_action_keys = _resolve_action_key_order(
|
||||
list(policy_action_names) if policy_action_names else None,
|
||||
raw_action_keys,
|
||||
)
|
||||
|
||||
# --- Diagnostic logging ---
|
||||
_act_ft = dataset_features.get("action", {})
|
||||
_obs_ft = dataset_features.get("observation.state", {})
|
||||
logger.info(
|
||||
"Feature reconciliation: action_dim=%d, obs_state_dim=%d, ordered_action_keys=%d",
|
||||
_act_ft.get("shape", (0,))[0],
|
||||
_obs_ft.get("shape", (0,))[0],
|
||||
len(ordered_action_keys),
|
||||
)
|
||||
logger.info(" action names : %s", _act_ft.get("names", []))
|
||||
logger.info(" obs state names: %s", _obs_ft.get("names", []))
|
||||
logger.info(" ordered keys : %s", ordered_action_keys)
|
||||
logger.info(
|
||||
" policy.action_feature_names: %s",
|
||||
list(policy_action_names) if policy_action_names else "None (using raw_action_keys)",
|
||||
)
|
||||
if full_config.input_features:
|
||||
logger.info(" policy input_features: %s", list(full_config.input_features.keys()))
|
||||
else:
|
||||
logger.warning(" policy input_features is EMPTY — policy may not process images!")
|
||||
if full_config.output_features:
|
||||
for k, v in full_config.output_features.items():
|
||||
logger.info(" policy output_features[%s]: shape=%s", k, v.shape)
|
||||
# Validate action dimension consistency
|
||||
if full_config.output_features:
|
||||
for ft in full_config.output_features.values():
|
||||
policy_action_dim = ft.shape[0]
|
||||
if len(ordered_action_keys) != policy_action_dim:
|
||||
logger.error(
|
||||
"ACTION DIM MISMATCH: policy expects %d dims, hardware produces %d keys. "
|
||||
"First 5 keys: %s",
|
||||
policy_action_dim,
|
||||
len(ordered_action_keys),
|
||||
ordered_action_keys[:5],
|
||||
)
|
||||
break
|
||||
|
||||
# Validate visual features if no rename_map is active
|
||||
rename_map = cfg.dataset.rename_map if cfg.dataset else {}
|
||||
if not rename_map:
|
||||
expected_visuals = {k for k, v in full_config.input_features.items() if v.type == FeatureType.VISUAL}
|
||||
provided_visuals = {
|
||||
f"observation.images.{k}" for k, v in robot.observation_features.items() if isinstance(v, tuple)
|
||||
}
|
||||
policy_subset = expected_visuals.issubset(provided_visuals)
|
||||
hw_subset = provided_visuals.issubset(expected_visuals)
|
||||
if not (policy_subset or hw_subset):
|
||||
raise ValueError(
|
||||
f"Visual feature mismatch between policy and robot hardware.\n"
|
||||
f"Policy expects: {expected_visuals}\n"
|
||||
f"Robot provides: {provided_visuals}"
|
||||
)
|
||||
|
||||
# --- 5. Dataset -------------
|
||||
dataset = None
|
||||
if cfg.dataset is not None and not isinstance(cfg.strategy, BaseStrategyConfig):
|
||||
logger.info("Setting up dataset (repo_id=%s)...", cfg.dataset.repo_id)
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset.resume(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
|
||||
* len(robot.cameras if hasattr(robot, "cameras") else []),
|
||||
)
|
||||
else:
|
||||
if isinstance(cfg.strategy, DAggerStrategyConfig):
|
||||
dataset_features["intervention"] = {
|
||||
"dtype": "bool",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
root=cfg.dataset.root,
|
||||
robot_type=robot.name,
|
||||
features=dataset_features,
|
||||
use_videos=cfg.dataset.video,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
|
||||
* len(robot.cameras if hasattr(robot, "cameras") else []),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
)
|
||||
|
||||
if dataset is not None:
|
||||
logger.info("Dataset ready: %s (%d existing episodes)", dataset.repo_id, dataset.num_episodes)
|
||||
|
||||
# --- 6. Policy pre/post processors (needs dataset stats if any) ---
|
||||
dataset_stats = None
|
||||
if dataset is not None:
|
||||
dataset_stats = rename_stats(
|
||||
dataset.meta.stats,
|
||||
cfg.dataset.rename_map if cfg.dataset else {},
|
||||
)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy_config,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=dataset_stats,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map if cfg.dataset else {}},
|
||||
},
|
||||
)
|
||||
|
||||
# --- Debug: verify normalizer stats loaded from pretrained ---
|
||||
from lerobot.processor import NormalizerProcessorStep, UnnormalizerProcessorStep
|
||||
|
||||
for step in preprocessor.steps:
|
||||
if isinstance(step, NormalizerProcessorStep):
|
||||
n_stats = sum(len(v) for v in step._tensor_stats.values()) if step._tensor_stats else 0
|
||||
logger.info(
|
||||
"Preprocessor normalizer: %d stat tensors, keys=%s",
|
||||
n_stats,
|
||||
list(step._tensor_stats.keys())[:3],
|
||||
)
|
||||
if n_stats == 0:
|
||||
logger.error("PREPROCESSOR NORMALIZER HAS NO STATS — observations will NOT be normalized!")
|
||||
for step in postprocessor.steps:
|
||||
if isinstance(step, UnnormalizerProcessorStep):
|
||||
n_stats = sum(len(v) for v in step._tensor_stats.values()) if step._tensor_stats else 0
|
||||
logger.info(
|
||||
"Postprocessor unnormalizer: %d stat tensors, keys=%s",
|
||||
n_stats,
|
||||
list(step._tensor_stats.keys())[:3],
|
||||
)
|
||||
if n_stats == 0:
|
||||
logger.error("POSTPROCESSOR UNNORMALIZER HAS NO STATS — actions will NOT be denormalized!")
|
||||
|
||||
# --- 7. Inference strategy (needs policy + pre/post + hardware) --
|
||||
logger.info(
|
||||
"Creating inference engine (type=%s)...",
|
||||
cfg.inference.type if hasattr(cfg.inference, "type") else "sync",
|
||||
)
|
||||
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
|
||||
inference_strategy = create_inference_engine(
|
||||
cfg.inference,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
robot_wrapper=robot_wrapper,
|
||||
hw_features=hw_features,
|
||||
dataset_features=dataset_features,
|
||||
ordered_action_keys=ordered_action_keys,
|
||||
task=task_str,
|
||||
fps=cfg.fps,
|
||||
device=cfg.device,
|
||||
use_torch_compile=cfg.use_torch_compile,
|
||||
compile_warmup_inferences=cfg.compile_warmup_inferences,
|
||||
shutdown_event=shutdown_event,
|
||||
)
|
||||
|
||||
# --- 8. Assemble ---------------------------------------------------
|
||||
logger.info("Rollout context assembled successfully")
|
||||
return RolloutContext(
|
||||
runtime=RuntimeContext(cfg=cfg, shutdown_event=shutdown_event),
|
||||
hardware=HardwareContext(
|
||||
robot_wrapper=robot_wrapper, teleop=teleop, initial_position=initial_position
|
||||
),
|
||||
policy=PolicyContext(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
inference=inference_strategy,
|
||||
),
|
||||
processors=ProcessorContext(
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
),
|
||||
data=DatasetContext(
|
||||
dataset=dataset,
|
||||
dataset_features=dataset_features,
|
||||
hw_features=hw_features,
|
||||
ordered_action_keys=ordered_action_keys,
|
||||
),
|
||||
)
|
||||
@@ -1,39 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Inference engine package — backend-agnostic action production.
|
||||
|
||||
Concrete strategies (sync, RTC, …) expose the same small interface so
|
||||
rollout strategies never branch on the inference backend.
|
||||
"""
|
||||
|
||||
from .base import InferenceEngine
|
||||
from .factory import (
|
||||
InferenceEngineConfig,
|
||||
RTCInferenceConfig,
|
||||
SyncInferenceConfig,
|
||||
create_inference_engine,
|
||||
)
|
||||
from .rtc import RTCInferenceEngine
|
||||
from .sync import SyncInferenceEngine
|
||||
|
||||
__all__ = [
|
||||
"InferenceEngine",
|
||||
"InferenceEngineConfig",
|
||||
"RTCInferenceConfig",
|
||||
"RTCInferenceEngine",
|
||||
"SyncInferenceConfig",
|
||||
"SyncInferenceEngine",
|
||||
"create_inference_engine",
|
||||
]
|
||||
@@ -1,88 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Inference engine ABC.
|
||||
|
||||
Rollout strategies consume actions through this small interface so they
|
||||
do not need to know whether the inference engine is synchronous, runs in
|
||||
a background thread (RTC), or comes from an external source.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class InferenceEngine(abc.ABC):
|
||||
"""Abstract backend for producing actions during rollout.
|
||||
|
||||
Subclasses decide whether inference happens inline, in a background
|
||||
thread, or externally. The contract is minimal so new backends can
|
||||
be added without touching rollout strategies.
|
||||
|
||||
Lifecycle
|
||||
---------
|
||||
``start`` — prepare the backend (e.g. launch a background thread).
|
||||
``stop`` — shut the backend down cleanly.
|
||||
``reset`` — clear episode-scoped state (policy hidden state, queues…).
|
||||
|
||||
Action production
|
||||
-----------------
|
||||
``get_action(obs_frame)`` — return the next action tensor, or
|
||||
``None`` if none is available (e.g. async queue empty). Sync
|
||||
backends always compute from ``obs_frame``; async backends may
|
||||
ignore it (they get observations via ``notify_observation``).
|
||||
|
||||
Optional hooks
|
||||
--------------
|
||||
``notify_observation`` / ``pause`` / ``resume`` have a no-op default
|
||||
so rollout strategies can invoke them unconditionally.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def start(self) -> None:
|
||||
"""Initialise the backend."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def stop(self) -> None:
|
||||
"""Tear the backend down."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Clear episode-scoped state."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
|
||||
"""Return the next action tensor, or ``None`` if unavailable."""
|
||||
|
||||
def notify_observation(self, obs: dict) -> None: # noqa: B027
|
||||
"""Publish the latest processed observation. Default: no-op."""
|
||||
|
||||
def pause(self) -> None: # noqa: B027
|
||||
"""Pause background inference. Default: no-op."""
|
||||
|
||||
def resume(self) -> None: # noqa: B027
|
||||
"""Resume background inference. Default: no-op."""
|
||||
|
||||
@property
|
||||
def ready(self) -> bool:
|
||||
"""True once the backend can produce actions (e.g. warmup done)."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def failed(self) -> bool:
|
||||
"""True if an unrecoverable error occurred in the backend."""
|
||||
return False
|
||||
@@ -1,129 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Inference engine configs and factory.
|
||||
|
||||
Selection is explicit via ``--inference.type=sync|rtc``. Adding a new
|
||||
backend requires registering its config subclass and dispatching it in
|
||||
:func:`create_inference_engine`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Event
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
|
||||
from ..robot_wrapper import ThreadSafeRobot
|
||||
from .base import InferenceEngine
|
||||
from .rtc import RTCInferenceEngine
|
||||
from .sync import SyncInferenceEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceEngineConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
"""Abstract base for inference backend configuration.
|
||||
|
||||
Use ``--inference.type=<name>`` on the CLI to select a backend.
|
||||
"""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
|
||||
@InferenceEngineConfig.register_subclass("sync")
|
||||
@dataclass
|
||||
class SyncInferenceConfig(InferenceEngineConfig):
|
||||
"""Inline synchronous inference (one policy call per control tick)."""
|
||||
|
||||
|
||||
@InferenceEngineConfig.register_subclass("rtc")
|
||||
@dataclass
|
||||
class RTCInferenceConfig(InferenceEngineConfig):
|
||||
"""Real-Time Chunking: async policy inference in a background thread."""
|
||||
|
||||
# ``RTCConfig`` is a small dataclass with default-only fields, so eagerly
|
||||
# constructing one here costs nothing and keeps draccus' CLI surface flat
|
||||
# (``--inference.rtc.execution_horizon=...`` etc.). No need to lazy-init.
|
||||
rtc: RTCConfig = field(default_factory=RTCConfig)
|
||||
queue_threshold: int = 30
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_inference_engine(
|
||||
config: InferenceEngineConfig,
|
||||
*,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
robot_wrapper: ThreadSafeRobot,
|
||||
hw_features: dict,
|
||||
dataset_features: dict,
|
||||
ordered_action_keys: list[str],
|
||||
task: str,
|
||||
fps: float,
|
||||
device: str | None,
|
||||
use_torch_compile: bool = False,
|
||||
compile_warmup_inferences: int = 2,
|
||||
shutdown_event: Event | None = None,
|
||||
) -> InferenceEngine:
|
||||
"""Instantiate the appropriate inference engine from a config object."""
|
||||
logger.info("Creating inference engine: %s", config.type)
|
||||
if isinstance(config, SyncInferenceConfig):
|
||||
return SyncInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset_features=dataset_features,
|
||||
ordered_action_keys=ordered_action_keys,
|
||||
task=task,
|
||||
device=device,
|
||||
robot_type=robot_wrapper.robot_type,
|
||||
)
|
||||
if isinstance(config, RTCInferenceConfig):
|
||||
return RTCInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
robot_wrapper=robot_wrapper,
|
||||
rtc_config=config.rtc,
|
||||
hw_features=hw_features,
|
||||
task=task,
|
||||
fps=fps,
|
||||
device=device,
|
||||
use_torch_compile=use_torch_compile,
|
||||
compile_warmup_inferences=compile_warmup_inferences,
|
||||
rtc_queue_threshold=config.queue_threshold,
|
||||
shutdown_event=shutdown_event,
|
||||
)
|
||||
raise ValueError(f"Unknown inference engine type: {type(config).__name__}")
|
||||
@@ -1,391 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Real-Time Chunking inference engine.
|
||||
|
||||
A background thread produces action chunks asynchronously via
|
||||
:meth:`policy.predict_action_chunk`. The main control loop polls
|
||||
``get_action`` for the next ready action; observations flow the other
|
||||
way via ``notify_observation``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
import traceback
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc import ActionQueue, LatencyTracker
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.utils import prepare_observation_for_inference
|
||||
from lerobot.processor import (
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
RelativeActionsProcessorStep,
|
||||
TransitionKey,
|
||||
create_transition,
|
||||
to_relative_actions,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
|
||||
from ..robot_wrapper import ThreadSafeRobot
|
||||
from .base import InferenceEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# How long the RTC loop sleeps when paused, idle, or backpressured by a full queue.
|
||||
_RTC_IDLE_SLEEP_S: float = 0.01
|
||||
# Backoff between transient inference errors (per consecutive failure).
|
||||
_RTC_ERROR_RETRY_DELAY_S: float = 0.5
|
||||
# Consecutive transient errors tolerated before giving up and propagating shutdown.
|
||||
_RTC_MAX_CONSECUTIVE_ERRORS: int = 10
|
||||
# Hard timeout for joining the RTC thread on stop().
|
||||
_RTC_JOIN_TIMEOUT_S: float = 3.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RTC helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute: torch.Tensor,
|
||||
current_state: torch.Tensor,
|
||||
relative_step: RelativeActionsProcessorStep,
|
||||
normalizer_step: NormalizerProcessorStep | None,
|
||||
policy_device: torch.device | str,
|
||||
) -> torch.Tensor:
|
||||
"""Convert absolute leftover actions into model-space for relative-action RTC policies.
|
||||
|
||||
When using relative actions, the RTC prefix (previous chunk's unexecuted tail)
|
||||
is stored in absolute coordinates. Before feeding it back to the policy, this
|
||||
helper re-expresses those actions relative to the robot's current joint state
|
||||
and optionally normalizes them so the policy receives correctly scaled inputs.
|
||||
"""
|
||||
state = current_state.detach().cpu()
|
||||
if state.dim() == 1:
|
||||
state = state.unsqueeze(0)
|
||||
|
||||
action_cpu = prev_actions_absolute.detach().cpu()
|
||||
mask = relative_step._build_mask(action_cpu.shape[-1])
|
||||
relative_actions = to_relative_actions(action_cpu, state, mask)
|
||||
|
||||
transition = create_transition(action=relative_actions)
|
||||
if normalizer_step is not None:
|
||||
transition = normalizer_step(transition)
|
||||
|
||||
return transition[TransitionKey.ACTION].to(policy_device)
|
||||
|
||||
|
||||
def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int) -> torch.Tensor:
|
||||
"""Pad or truncate RTC prefix actions to a fixed length for stable compiled inference."""
|
||||
if prev_actions.ndim != 2:
|
||||
raise ValueError(f"Expected 2D [T, A] tensor, got shape={tuple(prev_actions.shape)}")
|
||||
steps, action_dim = prev_actions.shape
|
||||
if steps == target_steps:
|
||||
return prev_actions
|
||||
if steps > target_steps:
|
||||
return prev_actions[:target_steps]
|
||||
padded = torch.zeros((target_steps, action_dim), dtype=prev_actions.dtype, device=prev_actions.device)
|
||||
padded[:steps] = prev_actions
|
||||
return padded
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RTCInferenceEngine
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RTCInferenceEngine(InferenceEngine):
|
||||
"""Async RTC inference: a background thread produces action chunks.
|
||||
|
||||
``get_action`` pops the next action from the shared queue (or
|
||||
returns ``None`` if the queue is empty). The main loop should call
|
||||
``notify_observation`` every tick and ``pause``/``resume`` around
|
||||
human-intervention phases.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
robot_wrapper: ThreadSafeRobot,
|
||||
rtc_config: RTCConfig,
|
||||
hw_features: dict,
|
||||
task: str,
|
||||
fps: float,
|
||||
device: str | None,
|
||||
use_torch_compile: bool = False,
|
||||
compile_warmup_inferences: int = 2,
|
||||
rtc_queue_threshold: int = 30,
|
||||
shutdown_event: Event | None = None,
|
||||
) -> None:
|
||||
self._policy = policy
|
||||
self._preprocessor = preprocessor
|
||||
self._postprocessor = postprocessor
|
||||
self._robot = robot_wrapper
|
||||
self._rtc_config = rtc_config
|
||||
self._hw_features = hw_features
|
||||
self._task = task
|
||||
self._fps = fps
|
||||
self._device = device or "cpu"
|
||||
self._use_torch_compile = use_torch_compile
|
||||
self._compile_warmup_inferences = compile_warmup_inferences
|
||||
self._rtc_queue_threshold = rtc_queue_threshold
|
||||
|
||||
self._action_queue: ActionQueue | None = None
|
||||
self._obs_holder: dict[str, Any] = {}
|
||||
self._obs_lock = Lock()
|
||||
self._policy_active = Event()
|
||||
self._compile_warmup_done = Event()
|
||||
self._shutdown_event = Event()
|
||||
self._rtc_error = Event()
|
||||
self._global_shutdown_event = shutdown_event
|
||||
self._rtc_thread: Thread | None = None
|
||||
|
||||
if not self._use_torch_compile:
|
||||
self._compile_warmup_done.set()
|
||||
logger.info("RTCInferenceEngine initialized (torch.compile disabled, no warmup needed)")
|
||||
else:
|
||||
logger.info(
|
||||
"RTCInferenceEngine initialized (torch.compile enabled, %d warmup inferences)",
|
||||
compile_warmup_inferences,
|
||||
)
|
||||
|
||||
# Processor introspection for relative-action re-anchoring.
|
||||
self._relative_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
|
||||
None,
|
||||
)
|
||||
self._normalizer_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
|
||||
None,
|
||||
)
|
||||
if self._relative_step is not None:
|
||||
if self._relative_step.action_names is None:
|
||||
cfg_names = getattr(policy.config, "action_feature_names", None)
|
||||
if cfg_names:
|
||||
self._relative_step.action_names = list(cfg_names)
|
||||
else:
|
||||
self._relative_step.action_names = [
|
||||
k for k in robot_wrapper.action_features if k.endswith(".pos")
|
||||
]
|
||||
logger.info("Relative actions enabled: RTC prefix will be re-anchored")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def ready(self) -> bool:
|
||||
"""True once torch.compile warmup is complete (or immediately if compile is disabled)."""
|
||||
return self._compile_warmup_done.is_set()
|
||||
|
||||
@property
|
||||
def failed(self) -> bool:
|
||||
"""True if the RTC background thread exited due to an unrecoverable error."""
|
||||
return self._rtc_error.is_set()
|
||||
|
||||
@property
|
||||
def action_queue(self) -> ActionQueue | None:
|
||||
"""The shared action queue between the RTC thread and the main loop."""
|
||||
return self._action_queue
|
||||
|
||||
def start(self) -> None:
|
||||
"""Launch the RTC background thread."""
|
||||
self._action_queue = ActionQueue(self._rtc_config)
|
||||
self._obs_holder = {
|
||||
"obs": None,
|
||||
"robot_type": self._robot.robot_type,
|
||||
}
|
||||
self._shutdown_event.clear()
|
||||
self._rtc_thread = Thread(
|
||||
target=self._rtc_loop,
|
||||
daemon=True,
|
||||
name="RTCInference",
|
||||
)
|
||||
self._rtc_thread.start()
|
||||
logger.info("RTC inference thread started")
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the RTC thread to stop and wait for it."""
|
||||
logger.info("Stopping RTC inference thread...")
|
||||
self._shutdown_event.set()
|
||||
self._policy_active.clear()
|
||||
if self._rtc_thread is not None and self._rtc_thread.is_alive():
|
||||
self._rtc_thread.join(timeout=_RTC_JOIN_TIMEOUT_S)
|
||||
if self._rtc_thread.is_alive():
|
||||
logger.warning("RTC thread did not join within %.1fs", _RTC_JOIN_TIMEOUT_S)
|
||||
else:
|
||||
logger.info("RTC inference thread stopped")
|
||||
self._rtc_thread = None
|
||||
|
||||
def pause(self) -> None:
|
||||
"""Pause the RTC background thread."""
|
||||
logger.info("Pausing RTC inference thread")
|
||||
self._policy_active.clear()
|
||||
|
||||
def resume(self) -> None:
|
||||
"""Resume the RTC background thread."""
|
||||
logger.info("Resuming RTC inference thread")
|
||||
self._policy_active.set()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the policy, processors, and action queue."""
|
||||
logger.info("Resetting RTC inference state (policy + processors + queue)")
|
||||
self._policy.reset()
|
||||
self._preprocessor.reset()
|
||||
self._postprocessor.reset()
|
||||
if self._action_queue is not None:
|
||||
self._action_queue.clear()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Action production (called from main thread)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
|
||||
"""Pop the next action from the RTC queue (ignores ``obs_frame``)."""
|
||||
if self._action_queue is None:
|
||||
return None
|
||||
return self._action_queue.get()
|
||||
|
||||
def notify_observation(self, obs: dict) -> None:
|
||||
"""Publish the latest observation for the RTC thread to consume."""
|
||||
with self._obs_lock:
|
||||
self._obs_holder["obs"] = obs
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# RTC: background inference thread
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _rtc_loop(self) -> None:
|
||||
"""Background thread that generates action chunks via RTC."""
|
||||
try:
|
||||
latency_tracker = LatencyTracker()
|
||||
time_per_chunk = 1.0 / self._fps
|
||||
policy_device = torch.device(self._device)
|
||||
|
||||
warmup_required = max(1, self._compile_warmup_inferences) if self._use_torch_compile else 0
|
||||
inference_count = 0
|
||||
consecutive_errors = 0
|
||||
|
||||
while not self._shutdown_event.is_set():
|
||||
if not self._policy_active.is_set():
|
||||
time.sleep(_RTC_IDLE_SLEEP_S)
|
||||
continue
|
||||
|
||||
queue = self._action_queue
|
||||
with self._obs_lock:
|
||||
obs = self._obs_holder.get("obs")
|
||||
if queue is None or obs is None:
|
||||
time.sleep(_RTC_IDLE_SLEEP_S)
|
||||
continue
|
||||
|
||||
if queue.qsize() <= self._rtc_queue_threshold:
|
||||
try:
|
||||
current_time = time.perf_counter()
|
||||
idx_before = queue.get_action_index()
|
||||
prev_actions = queue.get_left_over()
|
||||
|
||||
latency = latency_tracker.max()
|
||||
delay = math.ceil(latency / time_per_chunk) if latency else 0
|
||||
|
||||
obs_batch = build_dataset_frame(self._hw_features, obs, prefix="observation")
|
||||
obs_batch = prepare_observation_for_inference(
|
||||
obs_batch, policy_device, self._task, self._robot.robot_type
|
||||
)
|
||||
obs_batch["task"] = [self._task]
|
||||
|
||||
preprocessed = self._preprocessor(obs_batch)
|
||||
|
||||
if prev_actions is not None and self._relative_step is not None:
|
||||
state_tensor = preprocessed.get(OBS_STATE)
|
||||
if state_tensor is not None:
|
||||
prev_abs = queue.get_processed_left_over()
|
||||
if prev_abs is not None and prev_abs.numel() > 0:
|
||||
prev_actions = _reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=prev_abs,
|
||||
current_state=state_tensor,
|
||||
relative_step=self._relative_step,
|
||||
normalizer_step=self._normalizer_step,
|
||||
policy_device=policy_device,
|
||||
)
|
||||
|
||||
if prev_actions is not None:
|
||||
prev_actions = _normalize_prev_actions_length(
|
||||
prev_actions, target_steps=self._rtc_config.execution_horizon
|
||||
)
|
||||
|
||||
actions = self._policy.predict_action_chunk(
|
||||
preprocessed, inference_delay=delay, prev_chunk_left_over=prev_actions
|
||||
)
|
||||
|
||||
original = actions.squeeze(0).clone()
|
||||
processed = self._postprocessor(actions).squeeze(0)
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
|
||||
inference_count += 1
|
||||
consecutive_errors = 0
|
||||
is_warmup = self._use_torch_compile and inference_count <= warmup_required
|
||||
if is_warmup:
|
||||
latency_tracker.reset()
|
||||
else:
|
||||
latency_tracker.add(new_latency)
|
||||
|
||||
queue.merge(original, processed, new_delay, idx_before)
|
||||
|
||||
if (
|
||||
is_warmup
|
||||
and inference_count >= warmup_required
|
||||
and not self._compile_warmup_done.is_set()
|
||||
):
|
||||
self._compile_warmup_done.set()
|
||||
logger.info("Compile warmup complete (%d inferences)", inference_count)
|
||||
|
||||
logger.debug("RTC inference latency=%.2fs, queue=%d", new_latency, queue.qsize())
|
||||
|
||||
except Exception as e:
|
||||
consecutive_errors += 1
|
||||
logger.error(
|
||||
"RTC inference error (%d/%d): %s",
|
||||
consecutive_errors,
|
||||
_RTC_MAX_CONSECUTIVE_ERRORS,
|
||||
e,
|
||||
)
|
||||
logger.debug(traceback.format_exc())
|
||||
if consecutive_errors >= _RTC_MAX_CONSECUTIVE_ERRORS:
|
||||
# Persistent failure: stop retrying and propagate shutdown.
|
||||
raise
|
||||
time.sleep(_RTC_ERROR_RETRY_DELAY_S)
|
||||
else:
|
||||
time.sleep(_RTC_IDLE_SLEEP_S)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Fatal error in RTC thread: %s", e)
|
||||
logger.error(traceback.format_exc())
|
||||
self._rtc_error.set()
|
||||
# Unblock any warmup waiters so the main loop doesn't spin forever
|
||||
self._compile_warmup_done.set()
|
||||
# Signal the top-level shutdown so strategies exit their control loops
|
||||
if self._global_shutdown_event is not None:
|
||||
self._global_shutdown_event.set()
|
||||
@@ -1,127 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Synchronous inference engine: inline policy call per control tick."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
|
||||
from .base import InferenceEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SyncInferenceEngine(InferenceEngine):
|
||||
"""Inline synchronous inference: compute one action per call.
|
||||
|
||||
``get_action`` runs the full policy pipeline (pre/post-processor +
|
||||
``select_action``) on the given observation frame and returns a
|
||||
CPU action tensor reordered to match the dataset action keys.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: PreTrainedPolicy,
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
postprocessor: PolicyProcessorPipeline,
|
||||
dataset_features: dict,
|
||||
ordered_action_keys: list[str],
|
||||
task: str,
|
||||
device: str | None,
|
||||
robot_type: str,
|
||||
) -> None:
|
||||
self._policy = policy
|
||||
self._preprocessor = preprocessor
|
||||
self._postprocessor = postprocessor
|
||||
self._dataset_features = dataset_features
|
||||
self._ordered_action_keys = ordered_action_keys
|
||||
self._task = task
|
||||
self._device = torch.device(device or "cpu")
|
||||
self._robot_type = robot_type
|
||||
logger.info(
|
||||
"SyncInferenceEngine initialized (device=%s, action_keys=%d)",
|
||||
self._device,
|
||||
len(ordered_action_keys),
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
"""No background resources to start."""
|
||||
logger.info("SyncInferenceEngine started (inline mode — no background thread)")
|
||||
|
||||
def stop(self) -> None:
|
||||
"""No background resources to stop."""
|
||||
logger.info("SyncInferenceEngine stopped")
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the policy and pre/post-processors."""
|
||||
logger.info("Resetting sync inference state (policy + processors)")
|
||||
self._policy.reset()
|
||||
self._preprocessor.reset()
|
||||
self._postprocessor.reset()
|
||||
|
||||
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
|
||||
"""Run the full inference pipeline on ``obs_frame`` and return an action tensor."""
|
||||
if obs_frame is None:
|
||||
return None
|
||||
# Shallow copy is intentional: the caller (`send_next_action`) builds
|
||||
# ``obs_frame`` fresh per tick via ``build_dataset_frame``, so the
|
||||
# tensor/array values are not shared with any other reader.
|
||||
observation = copy(obs_frame)
|
||||
autocast_ctx = (
|
||||
torch.autocast(device_type=self._device.type)
|
||||
if self._device.type == "cuda" and self._policy.config.use_amp
|
||||
else nullcontext()
|
||||
)
|
||||
with torch.inference_mode(), autocast_ctx:
|
||||
observation = prepare_observation_for_inference(
|
||||
observation, self._device, self._task, self._robot_type
|
||||
)
|
||||
observation = self._preprocessor(observation)
|
||||
action_raw = self._policy.select_action(observation)
|
||||
action = self._postprocessor(action_raw)
|
||||
action_tensor = action.squeeze(0).cpu()
|
||||
|
||||
if not hasattr(self, "_log_count"):
|
||||
self._log_count = 0
|
||||
if self._log_count < 3:
|
||||
raw_flat = action_raw.squeeze(0).cpu()
|
||||
logger.info(
|
||||
"[Sync tick %d] raw action (first 5): %s | post-processed (first 5): %s",
|
||||
self._log_count,
|
||||
raw_flat[:5].tolist(),
|
||||
action_tensor[:5].tolist(),
|
||||
)
|
||||
obs_state = obs_frame.get("observation.state")
|
||||
if obs_state is not None:
|
||||
logger.info(
|
||||
"[Sync tick %d] obs_frame['observation.state'] (first 5): %s | shape: %s",
|
||||
self._log_count,
|
||||
obs_state[:5].tolist() if hasattr(obs_state, "tolist") else str(obs_state)[:80],
|
||||
obs_state.shape if hasattr(obs_state, "shape") else "?",
|
||||
)
|
||||
self._log_count += 1
|
||||
|
||||
# Reorder to match dataset action ordering so the caller can treat
|
||||
# the returned tensor uniformly across backends.
|
||||
action_dict = make_robot_action(action_tensor, self._dataset_features)
|
||||
return torch.tensor([action_dict[k] for k in self._ordered_action_keys])
|
||||
@@ -1,112 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Memory-bounded ring buffer for the Highlight Reel rollout strategy."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class RolloutRingBuffer:
|
||||
"""Fixed-capacity circular buffer for observation/action frames.
|
||||
|
||||
Stores the last *N* seconds of telemetry in memory, bounded by both
|
||||
time (``max_frames``) and memory (``max_memory_bytes``). When either
|
||||
limit is reached the oldest frames are evicted.
|
||||
|
||||
.. note::
|
||||
This class is **single-threaded**. ``append``/``drain``/``clear``
|
||||
must all be called from the same thread (the rollout main loop).
|
||||
Concurrent access from a background thread will corrupt
|
||||
``_current_bytes`` accounting.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_seconds:
|
||||
Maximum duration of buffered telemetry.
|
||||
max_memory_mb:
|
||||
Hard memory cap in MiB. Frames are evicted when the estimated
|
||||
total size exceeds this.
|
||||
fps:
|
||||
Frames per second — used to convert ``max_seconds`` to a frame
|
||||
count.
|
||||
"""
|
||||
|
||||
def __init__(self, max_seconds: float = 30.0, max_memory_mb: float = 2048.0, fps: float = 30.0) -> None:
|
||||
self._max_frames = int(max_seconds * fps)
|
||||
self._max_bytes = int(max_memory_mb * 1024 * 1024)
|
||||
self._buffer: deque[dict] = deque(maxlen=self._max_frames)
|
||||
self._current_bytes: int = 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def append(self, frame: dict) -> None:
|
||||
"""Add *frame* to the buffer, evicting the oldest if at capacity."""
|
||||
frame_bytes = _estimate_frame_bytes(frame)
|
||||
|
||||
# Evict oldest frames until we are under the memory cap
|
||||
while self._current_bytes + frame_bytes > self._max_bytes and self._buffer:
|
||||
evicted = self._buffer.popleft()
|
||||
self._current_bytes -= _estimate_frame_bytes(evicted)
|
||||
|
||||
self._buffer.append(frame)
|
||||
self._current_bytes += frame_bytes
|
||||
|
||||
def drain(self) -> list[dict]:
|
||||
"""Return all buffered frames and clear the buffer."""
|
||||
frames = list(self._buffer)
|
||||
self._buffer.clear()
|
||||
self._current_bytes = 0
|
||||
return frames
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Discard all buffered frames."""
|
||||
self._buffer.clear()
|
||||
self._current_bytes = 0
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._buffer)
|
||||
|
||||
@property
|
||||
def estimated_bytes(self) -> int:
|
||||
"""Estimated total byte size of all buffered frames."""
|
||||
return self._current_bytes
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _estimate_frame_bytes(frame: dict) -> int:
|
||||
"""Rough byte estimate for a single frame dictionary."""
|
||||
total = 0
|
||||
for v in frame.values():
|
||||
if isinstance(v, torch.Tensor):
|
||||
# ``torch.Tensor`` has no ``nbytes``; compute it explicitly so the
|
||||
# memory cap is honoured even when frames hold unconverted tensors.
|
||||
total += v.nelement() * v.element_size()
|
||||
elif isinstance(v, np.ndarray) or hasattr(v, "nbytes"):
|
||||
total += v.nbytes
|
||||
elif isinstance(v, (int, float)):
|
||||
total += 8
|
||||
elif isinstance(v, (str, bytes)):
|
||||
total += len(v)
|
||||
return max(total, 1) # avoid zero-size frames
|
||||
@@ -1,79 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Thread-safe robot wrapper for concurrent observation/action access."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
from lerobot.robots import Robot
|
||||
|
||||
|
||||
class ThreadSafeRobot:
|
||||
"""Lock-protected wrapper around a :class:`Robot` for use with background threads.
|
||||
|
||||
When RTC inference runs in a background thread while the main loop
|
||||
executes actions, both threads may access the robot concurrently.
|
||||
This wrapper serialises ``get_observation`` and ``send_action`` calls.
|
||||
|
||||
Read-only properties are proxied without the lock since they don't
|
||||
mutate hardware state.
|
||||
"""
|
||||
|
||||
def __init__(self, robot: Robot) -> None:
|
||||
self._robot = robot
|
||||
self._lock = Lock()
|
||||
|
||||
# -- Lock-protected I/O --------------------------------------------------
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
return self._robot.get_observation()
|
||||
|
||||
def send_action(self, action: dict[str, Any] | Any) -> Any:
|
||||
with self._lock:
|
||||
return self._robot.send_action(action)
|
||||
|
||||
# -- Read-only proxies (no lock needed) -----------------------------------
|
||||
|
||||
@property
|
||||
def observation_features(self) -> dict:
|
||||
return self._robot.observation_features
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict:
|
||||
return self._robot.action_features
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._robot.name
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str:
|
||||
return self._robot.robot_type
|
||||
|
||||
@property
|
||||
def cameras(self):
|
||||
return getattr(self._robot, "cameras", {})
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._robot.is_connected
|
||||
|
||||
@property
|
||||
def inner(self) -> Robot:
|
||||
"""Access the underlying robot (e.g. for connect/disconnect)."""
|
||||
return self._robot
|
||||
@@ -1,36 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Rollout strategies — public API re-exports."""
|
||||
|
||||
from .base import BaseStrategy
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
from .dagger import DAggerEvents, DAggerPhase, DAggerStrategy
|
||||
from .factory import create_strategy
|
||||
from .highlight import HighlightStrategy
|
||||
from .sentry import SentryStrategy
|
||||
|
||||
__all__ = [
|
||||
"BaseStrategy",
|
||||
"DAggerEvents",
|
||||
"DAggerPhase",
|
||||
"DAggerStrategy",
|
||||
"HighlightStrategy",
|
||||
"RolloutStrategy",
|
||||
"SentryStrategy",
|
||||
"create_strategy",
|
||||
"estimate_max_episode_seconds",
|
||||
"safe_push_to_hub",
|
||||
"send_next_action",
|
||||
]
|
||||
@@ -1,90 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Base rollout strategy: autonomous policy execution with no data recording."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
from ..context import RolloutContext
|
||||
from .core import RolloutStrategy, send_next_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseStrategy(RolloutStrategy):
|
||||
"""Autonomous policy rollout with no data recording.
|
||||
|
||||
All actions flow through the ``robot_action_processor`` pipeline
|
||||
before reaching the robot.
|
||||
"""
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
"""Initialise the inference engine."""
|
||||
self._init_engine(ctx)
|
||||
logger.info("Base strategy ready")
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
"""Run the autonomous control loop until shutdown or duration expires."""
|
||||
engine = self._engine
|
||||
cfg = ctx.runtime.cfg
|
||||
robot = ctx.hardware.robot_wrapper
|
||||
interpolator = self._interpolator
|
||||
|
||||
control_interval = interpolator.get_control_interval(cfg.fps)
|
||||
|
||||
# Flush a few observation reads so CAN bus / sensor state is fresh
|
||||
# before the first inference. Without this, the first observation(s)
|
||||
# can return stale or identical values for all joints, poisoning the
|
||||
# entire first action chunk.
|
||||
_OBS_WARMUP_READS = 5
|
||||
for _ in range(_OBS_WARMUP_READS):
|
||||
robot.get_observation()
|
||||
precise_sleep(1 / cfg.fps)
|
||||
logger.info("Flushed %d observation warmup reads", _OBS_WARMUP_READS)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
engine.resume()
|
||||
logger.info("Base strategy control loop started")
|
||||
|
||||
while not ctx.runtime.shutdown_event.is_set():
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
|
||||
logger.info("Duration limit reached (%.0fs)", cfg.duration)
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = ctx.processors.robot_observation_processor(obs)
|
||||
engine.notify_observation(obs_processed)
|
||||
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
self._warn_if_slow(dt, control_interval, cfg.fps)
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
"""Disconnect hardware and stop inference."""
|
||||
self._teardown_hardware(ctx.hardware)
|
||||
logger.info("Base strategy teardown complete")
|
||||
@@ -1,302 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Rollout strategy ABC and shared action-dispatch helper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
from lerobot.utils.action_interpolator import ActionInterpolator
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.visualization_utils import log_rerun_data
|
||||
|
||||
from ..inference import InferenceEngine
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..configs import RolloutStrategyConfig
|
||||
from ..context import HardwareContext, RolloutContext, RuntimeContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RolloutStrategy(abc.ABC):
|
||||
"""Abstract base for rollout execution strategies.
|
||||
|
||||
Each concrete strategy implements a self-contained control loop with
|
||||
its own recording/interaction semantics. Strategies are mutually
|
||||
exclusive — only one runs per session.
|
||||
"""
|
||||
|
||||
def __init__(self, config: RolloutStrategyConfig) -> None:
|
||||
self.config = config
|
||||
self._engine: InferenceEngine | None = None
|
||||
self._interpolator: ActionInterpolator | None = None
|
||||
self._warmup_flushed: bool = False
|
||||
|
||||
def _init_engine(self, ctx: RolloutContext) -> None:
|
||||
"""Attach the inference engine and action interpolator, then start the backend.
|
||||
|
||||
Creates an :class:`ActionInterpolator` from the config's
|
||||
``interpolation_multiplier`` and starts the inference engine.
|
||||
Call this from ``setup()`` so strategies share identical
|
||||
initialisation without duplicating code.
|
||||
"""
|
||||
self._interpolator = ActionInterpolator(multiplier=ctx.runtime.cfg.interpolation_multiplier)
|
||||
self._engine = ctx.policy.inference
|
||||
logger.info("Starting inference engine...")
|
||||
self._engine.start()
|
||||
# Reset policy and processor state so the first inference starts clean
|
||||
# (matches the old HIL script which called policy.reset() / preprocessor.reset()
|
||||
# at the beginning of each episode).
|
||||
self._engine.reset()
|
||||
self._warmup_flushed = False
|
||||
logger.info("Inference engine started")
|
||||
|
||||
def _handle_warmup(self, use_torch_compile: bool, loop_start: float, control_interval: float) -> bool:
|
||||
"""Handle torch.compile warmup phase.
|
||||
|
||||
Returns ``True`` if the caller should ``continue`` (still warming
|
||||
up). On the first post-warmup iteration the engine and
|
||||
interpolator are reset so stale warmup state is discarded.
|
||||
"""
|
||||
engine = self._engine
|
||||
interpolator = self._interpolator
|
||||
if not use_torch_compile:
|
||||
return False
|
||||
if not engine.ready:
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
return True
|
||||
if not self._warmup_flushed:
|
||||
logger.info("Warmup complete — flushing stale state and resuming engine")
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
self._warmup_flushed = True
|
||||
engine.resume()
|
||||
return False
|
||||
|
||||
def _teardown_hardware(self, hw: HardwareContext) -> None:
|
||||
"""Stop the inference engine, return robot to initial position, and disconnect hardware."""
|
||||
if self._engine is not None:
|
||||
logger.info("Stopping inference engine...")
|
||||
self._engine.stop()
|
||||
robot = hw.robot_wrapper.inner
|
||||
if robot.is_connected:
|
||||
if hw.initial_position:
|
||||
logger.info("Returning robot to initial position before shutdown...")
|
||||
self._return_to_initial_position(hw)
|
||||
logger.info("Disconnecting robot...")
|
||||
robot.disconnect()
|
||||
teleop = hw.teleop
|
||||
if teleop is not None and teleop.is_connected:
|
||||
logger.info("Disconnecting teleoperator...")
|
||||
teleop.disconnect()
|
||||
|
||||
@staticmethod
|
||||
def _return_to_initial_position(hw: HardwareContext, duration_s: float = 3.0, fps: int = 50) -> None:
|
||||
"""Smoothly interpolate the robot back to its initial position."""
|
||||
robot = hw.robot_wrapper
|
||||
target = hw.initial_position
|
||||
try:
|
||||
current_obs = robot.get_observation()
|
||||
current_pos = {k: v for k, v in current_obs.items() if k in target}
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
for step in range(1, steps + 1):
|
||||
t = step / steps
|
||||
interp = {}
|
||||
for k in current_pos:
|
||||
interp[k] = current_pos[k] * (1 - t) + target[k] * t
|
||||
robot.send_action(interp)
|
||||
precise_sleep(1 / fps)
|
||||
except Exception as e:
|
||||
logger.warning("Could not return to initial position: %s", e)
|
||||
|
||||
@staticmethod
|
||||
def _log_telemetry(
|
||||
obs_processed: dict | None,
|
||||
action_dict: dict | None,
|
||||
runtime_ctx: RuntimeContext,
|
||||
) -> None:
|
||||
"""Log observation/action telemetry to Rerun if display_data is enabled."""
|
||||
cfg = runtime_ctx.cfg
|
||||
if not cfg.display_data:
|
||||
return
|
||||
log_rerun_data(
|
||||
observation=obs_processed,
|
||||
action=action_dict,
|
||||
compress_images=cfg.display_compressed_images,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _warn_if_slow(dt: float, control_interval: float, fps: float) -> None:
|
||||
"""Log a warning when the control loop runs slower than target FPS."""
|
||||
if dt > control_interval:
|
||||
actual_fps = 1.0 / dt if dt > 0 else 0
|
||||
logger.warning(
|
||||
"Control loop is running slower (%.1f Hz) than target FPS (%.0f Hz). "
|
||||
"Dataset frames might be dropped and robot control might be unstable. "
|
||||
"Common causes: 1) Camera FPS not keeping up "
|
||||
"2) Policy inference taking too long 3) CPU starvation",
|
||||
actual_fps,
|
||||
fps,
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
"""Strategy-specific initialisation (keyboard listeners, buffers, etc.)."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
"""Main rollout loop. Returns when shutdown is requested or duration expires."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
"""Cleanup: save dataset, stop threads, disconnect hardware."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def safe_push_to_hub(dataset, tags=None, private=False) -> bool:
|
||||
"""Push dataset to hub, skipping if no episodes have been saved.
|
||||
|
||||
Returns ``True`` if the push was attempted, ``False`` if skipped.
|
||||
"""
|
||||
if dataset.num_episodes == 0:
|
||||
logger.warning("No episodes saved — skipping push to hub")
|
||||
return False
|
||||
dataset.push_to_hub(tags=tags, private=private)
|
||||
return True
|
||||
|
||||
|
||||
def estimate_max_episode_seconds(
|
||||
dataset_features: dict,
|
||||
fps: float,
|
||||
target_size_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
) -> float:
|
||||
"""Conservatively estimate how many seconds of video will exceed *target_size_mb*.
|
||||
|
||||
Each camera produces its own video file, so the episode duration is
|
||||
driven by the **slowest** camera to fill ``target_size_mb`` — i.e.
|
||||
the one with the fewest pixels per frame (lowest bitrate).
|
||||
|
||||
Uses a deliberately **low** bits-per-pixel estimate so the computed
|
||||
duration is *longer* than reality. By the time the timer fires the
|
||||
actual video file is guaranteed to have crossed the target size,
|
||||
which aligns episode boundaries with the dataset's video-file
|
||||
chunking — each ``push_to_hub`` uploads complete files rather than
|
||||
re-uploading a still-growing one.
|
||||
|
||||
The estimate ignores codec-specific settings (CRF, preset) on purpose:
|
||||
we only need a rough lower bound on bitrate, not a precise prediction.
|
||||
|
||||
Falls back to 600 s (10 min) when no video features are present.
|
||||
"""
|
||||
# 0.1 bits-per-pixel is a *low* estimate for CRF-30 streaming video of
|
||||
# robot footage (real-world is typically 0.1 – 0.3 bpp). Under-
|
||||
# estimating the bitrate over-estimates the time → the episode will be
|
||||
# *larger* than target_size_mb when we save, which is what we want.
|
||||
conservative_bpp = 0.1
|
||||
|
||||
# Collect per-camera pixel counts — each camera has its own video file.
|
||||
camera_pixels = []
|
||||
for feat in dataset_features.values():
|
||||
if feat.get("dtype") == "video":
|
||||
shape = feat.get("shape", ())
|
||||
|
||||
# Assuming shape could be (C, H, W) or (T, C, H, W)
|
||||
# We want to extract the spatial dimensions.
|
||||
if len(shape) >= 3:
|
||||
h, w = shape[-2], shape[-1]
|
||||
pixels = h * w
|
||||
if pixels > 0:
|
||||
camera_pixels.append(pixels)
|
||||
|
||||
if not camera_pixels:
|
||||
return 600.0
|
||||
|
||||
# Use the smallest camera: it produces the lowest bitrate and therefore
|
||||
# takes the longest to reach the target — the conservative choice.
|
||||
min_pixels = min(camera_pixels)
|
||||
bits_per_frame = min_pixels * conservative_bpp
|
||||
bytes_per_second = (bits_per_frame * fps) / 8
|
||||
|
||||
# Guard against division by zero just in case
|
||||
if bytes_per_second <= 0:
|
||||
return 600.0
|
||||
|
||||
return (target_size_mb * 1024 * 1024) / bytes_per_second
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared action-dispatch helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def send_next_action(
|
||||
obs_processed: dict,
|
||||
obs_raw: dict,
|
||||
ctx: RolloutContext,
|
||||
interpolator: ActionInterpolator,
|
||||
) -> dict | None:
|
||||
"""Dispatch the next action to the robot.
|
||||
|
||||
Pulls the next action tensor from the inference engine, feeds the
|
||||
interpolator, and sends the interpolated action through the
|
||||
``robot_action_processor`` to the robot. Works identically for
|
||||
sync and async backends — the rollout strategy never needs to branch.
|
||||
|
||||
Returns the action dict that was sent, or ``None`` if no action was
|
||||
ready (e.g. empty async queue, interpolator not yet primed).
|
||||
"""
|
||||
engine = ctx.policy.inference
|
||||
features = ctx.data.dataset_features
|
||||
ordered_keys = ctx.data.ordered_action_keys
|
||||
|
||||
if interpolator.needs_new_action():
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_tensor = engine.get_action(obs_frame)
|
||||
if action_tensor is not None:
|
||||
interpolator.add(action_tensor.cpu())
|
||||
|
||||
interp = interpolator.get()
|
||||
if interp is None:
|
||||
return None
|
||||
|
||||
action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)}
|
||||
processed = ctx.processors.robot_action_processor((action_dict, obs_raw))
|
||||
|
||||
if not hasattr(send_next_action, "_log_count"):
|
||||
send_next_action._log_count = 0
|
||||
if send_next_action._log_count < 3:
|
||||
sample = {k: round(v, 4) for k, v in list(processed.items())[:5]}
|
||||
logger.info(
|
||||
"[send_next_action tick %d] action sent to robot (first 5): %s",
|
||||
send_next_action._log_count,
|
||||
sample,
|
||||
)
|
||||
send_next_action._log_count += 1
|
||||
|
||||
ctx.hardware.robot_wrapper.send_action(processed)
|
||||
return action_dict
|
||||
@@ -1,740 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""DAgger rollout strategy: Human-in-the-Loop data collection.
|
||||
|
||||
Implements the RaC paradigm (Recovery and Correction) for interactive
|
||||
imitation learning. Alternates between autonomous policy execution and
|
||||
human intervention via teleoperator.
|
||||
|
||||
Input is controlled via either a keyboard or foot pedal, selected by
|
||||
the ``input_device`` config field. Each device exposes three actions:
|
||||
|
||||
1. **pause_resume** — Toggle policy execution (AUTONOMOUS <-> PAUSED).
|
||||
2. **correction** — Toggle correction recording (PAUSED <-> CORRECTING).
|
||||
3. **upload** — Push dataset to hub on demand (corrections-only mode).
|
||||
ESC (keyboard only) — Stop session.
|
||||
|
||||
Recording Modes:
|
||||
``record_autonomous=True``: Sentry-like continuous recording with
|
||||
time-based episode rotation. Both autonomous and correction
|
||||
frames are recorded; corrections tagged ``intervention=True``.
|
||||
``record_autonomous=False``: Only correction windows are recorded.
|
||||
Each correction (start to stop) becomes one episode.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import enum
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from threading import Event, Lock
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.control_utils import is_headless
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.import_utils import _pynput_available
|
||||
from lerobot.utils.pedal import start_pedal_listener
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from ..robot_wrapper import ThreadSafeRobot
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
|
||||
PYNPUT_AVAILABLE = _pynput_available
|
||||
keyboard = None
|
||||
if PYNPUT_AVAILABLE:
|
||||
try:
|
||||
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
|
||||
logging.info("No DISPLAY set. Skipping pynput import.")
|
||||
PYNPUT_AVAILABLE = False
|
||||
else:
|
||||
from pynput import keyboard
|
||||
except Exception as e:
|
||||
PYNPUT_AVAILABLE = False
|
||||
logging.info(f"Could not import pynput: {e}")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DAgger state machine
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DAggerPhase(enum.Enum):
|
||||
"""Observable phases of a DAgger episode."""
|
||||
|
||||
AUTONOMOUS = "autonomous" # Policy driving
|
||||
PAUSED = "paused" # Engine paused, teleop aligned, awaiting input
|
||||
CORRECTING = "correcting" # Human driving via teleop, recording interventions
|
||||
|
||||
|
||||
# Valid (current_phase, event) -> next_phase
|
||||
_DAGGER_TRANSITIONS: dict[tuple[DAggerPhase, str], DAggerPhase] = {
|
||||
(DAggerPhase.AUTONOMOUS, "pause_resume"): DAggerPhase.PAUSED,
|
||||
(DAggerPhase.PAUSED, "pause_resume"): DAggerPhase.AUTONOMOUS,
|
||||
(DAggerPhase.PAUSED, "correction"): DAggerPhase.CORRECTING,
|
||||
(DAggerPhase.CORRECTING, "correction"): DAggerPhase.PAUSED,
|
||||
}
|
||||
|
||||
|
||||
class DAggerEvents:
|
||||
"""Thread-safe container for DAgger input device events.
|
||||
|
||||
The keyboard/pedal threads write transition requests; the main loop
|
||||
consumes them.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = Lock()
|
||||
self._phase = DAggerPhase.AUTONOMOUS
|
||||
self._pending_transition: str | None = None
|
||||
|
||||
# Session-level flags
|
||||
self.stop_recording = Event()
|
||||
self.upload_requested = Event()
|
||||
|
||||
# -- Thread-safe phase access ------------------------------------------
|
||||
|
||||
@property
|
||||
def phase(self) -> DAggerPhase:
|
||||
"""Current phase of the DAgger state machine."""
|
||||
with self._lock:
|
||||
return self._phase
|
||||
|
||||
@phase.setter
|
||||
def phase(self, value: DAggerPhase) -> None:
|
||||
with self._lock:
|
||||
self._phase = value
|
||||
|
||||
def request_transition(self, event: str) -> None:
|
||||
"""Request a phase transition (called from keyboard/pedal threads).
|
||||
|
||||
Only enqueues the request if it corresponds to a valid transition
|
||||
from the current phase, preventing impossible state changes.
|
||||
"""
|
||||
with self._lock:
|
||||
if (self._phase, event) in _DAGGER_TRANSITIONS:
|
||||
self._pending_transition = event
|
||||
|
||||
def consume_transition(self) -> tuple[DAggerPhase, DAggerPhase] | None:
|
||||
"""Consume a pending transition (called from main loop)."""
|
||||
with self._lock:
|
||||
if self._pending_transition is None:
|
||||
return None
|
||||
key = (self._phase, self._pending_transition)
|
||||
self._pending_transition = None
|
||||
new_phase = _DAGGER_TRANSITIONS.get(key)
|
||||
if new_phase is None:
|
||||
return None
|
||||
old_phase = self._phase
|
||||
self._phase = new_phase
|
||||
return old_phase, new_phase
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all transient state for a fresh session."""
|
||||
with self._lock:
|
||||
self._phase = DAggerPhase.AUTONOMOUS
|
||||
self._pending_transition = None
|
||||
self.upload_requested.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Teleoperator helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||
def _teleop_smooth_move_to(
|
||||
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50
|
||||
) -> None:
|
||||
"""Smoothly move teleop to target position via linear interpolation.
|
||||
|
||||
Requires the teleoperator to support motor control methods
|
||||
(``enable_torque``, ``write_goal_positions``, ``get_action``).
|
||||
"""
|
||||
teleop.enable_torque()
|
||||
current = teleop.get_action()
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {}
|
||||
for k in current:
|
||||
if k in target_pos:
|
||||
interp[k] = current[k] * (1 - t) + target_pos[k] * t
|
||||
else:
|
||||
interp[k] = current[k]
|
||||
teleop.write_goal_positions(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Input device handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
|
||||
"""Initialise keyboard listener with DAgger 3-key controls.
|
||||
|
||||
Returns the pynput Listener (or ``None`` in headless mode or when
|
||||
pynput is unavailable).
|
||||
"""
|
||||
if not PYNPUT_AVAILABLE or is_headless():
|
||||
logger.warning("Headless environment or pynput unavailable — keyboard controls disabled")
|
||||
return None
|
||||
|
||||
# Map config key names to pynput Key objects for special keys
|
||||
special_keys = {
|
||||
"space": keyboard.Key.space,
|
||||
"tab": keyboard.Key.tab,
|
||||
"enter": keyboard.Key.enter,
|
||||
}
|
||||
|
||||
def _resolve_key(key) -> str | None:
|
||||
"""Resolve a pynput key event to a config-comparable string."""
|
||||
if key == keyboard.Key.esc:
|
||||
return "esc"
|
||||
for name, pynput_key in special_keys.items():
|
||||
if key == pynput_key:
|
||||
return name
|
||||
if hasattr(key, "char") and key.char:
|
||||
return key.char
|
||||
return None
|
||||
|
||||
# Build mapping: resolved key string -> DAgger event name
|
||||
key_to_event = {
|
||||
cfg.pause_resume: "pause_resume",
|
||||
cfg.correction: "correction",
|
||||
}
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
resolved = _resolve_key(key)
|
||||
if resolved is None:
|
||||
return
|
||||
if resolved == "esc":
|
||||
logger.info("Stop recording...")
|
||||
events.stop_recording.set()
|
||||
return
|
||||
if resolved in key_to_event:
|
||||
events.request_transition(key_to_event[resolved])
|
||||
if resolved == cfg.upload:
|
||||
events.upload_requested.set()
|
||||
except Exception as e:
|
||||
logger.debug("Key error: %s", e)
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
logger.info(
|
||||
"DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
|
||||
cfg.pause_resume,
|
||||
cfg.correction,
|
||||
cfg.upload,
|
||||
)
|
||||
return listener
|
||||
|
||||
|
||||
def _init_dagger_pedal(events: DAggerEvents, cfg: DAggerPedalConfig):
|
||||
"""Initialise foot pedal listener with DAgger 3-pedal controls.
|
||||
|
||||
Returns the pedal listener thread (or ``None`` if evdev is unavailable).
|
||||
"""
|
||||
code_to_event = {
|
||||
cfg.pause_resume: "pause_resume",
|
||||
cfg.correction: "correction",
|
||||
}
|
||||
|
||||
def on_press(code: str) -> None:
|
||||
if code in code_to_event:
|
||||
events.request_transition(code_to_event[code])
|
||||
if code == cfg.upload:
|
||||
events.upload_requested.set()
|
||||
|
||||
logger.info("Initializing DAgger foot pedal listener (device=%s)", cfg.device_path)
|
||||
return start_pedal_listener(on_press, device_path=cfg.device_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DAgger Strategy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DAggerStrategy(RolloutStrategy):
|
||||
"""Human-in-the-Loop data collection with intervention tagging.
|
||||
|
||||
State machine::
|
||||
|
||||
AUTONOMOUS --(key1)--> PAUSED --(key2)--> CORRECTING --(key2)--> PAUSED
|
||||
--(key1)--> AUTONOMOUS
|
||||
|
||||
Recording modes:
|
||||
``record_autonomous=True``: Sentry-like continuous recording with
|
||||
time-based episode rotation. Intervention frames tagged True.
|
||||
``record_autonomous=False``: Only correction windows recorded.
|
||||
Each correction = one episode. Upload on demand via key3.
|
||||
"""
|
||||
|
||||
config: DAggerStrategyConfig
|
||||
|
||||
def __init__(self, config: DAggerStrategyConfig):
|
||||
super().__init__(config)
|
||||
self._listener = None
|
||||
self._pedal_thread = None
|
||||
self._events = DAggerEvents()
|
||||
self._push_executor: ThreadPoolExecutor | None = None
|
||||
self._pending_push: Future | None = None
|
||||
self._needs_push = Event()
|
||||
self._episode_lock = Lock()
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
"""Initialise the inference engine and input device listener."""
|
||||
self._init_engine(ctx)
|
||||
self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="dagger-push")
|
||||
target_mb = self.config.target_video_file_size_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
self._episode_duration_s = estimate_max_episode_seconds(
|
||||
ctx.data.dataset_features, ctx.runtime.cfg.fps, target_size_mb=target_mb
|
||||
)
|
||||
|
||||
if self.config.input_device == "keyboard":
|
||||
self._listener = _init_dagger_keyboard(self._events, self.config.keyboard)
|
||||
else:
|
||||
self._pedal_thread = _init_dagger_pedal(self._events, self.config.pedal)
|
||||
|
||||
record_mode = "all frames (sentry-like)" if self.config.record_autonomous else "corrections only"
|
||||
logger.info(
|
||||
"DAgger strategy ready (input=%s, episodes=%d, record=%s, episode_duration=%.0fs)",
|
||||
self.config.input_device,
|
||||
self.config.num_episodes,
|
||||
record_mode,
|
||||
self._episode_duration_s,
|
||||
)
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
"""Run DAgger episodes with human-in-the-loop intervention."""
|
||||
if self.config.record_autonomous:
|
||||
self._run_continuous(ctx)
|
||||
else:
|
||||
self._run_corrections_only(ctx)
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
"""Stop listeners, finalise the dataset, and disconnect hardware."""
|
||||
play_sounds = ctx.runtime.cfg.play_sounds
|
||||
logger.info("Stopping DAgger recording")
|
||||
log_say("Stopping DAgger recording", play_sounds)
|
||||
|
||||
if self._listener is not None and not is_headless():
|
||||
logger.info("Stopping keyboard listener")
|
||||
self._listener.stop()
|
||||
|
||||
# Flush any queued/running push cleanly
|
||||
if self._push_executor is not None:
|
||||
logger.info("Shutting down push executor (waiting for pending pushes)...")
|
||||
self._push_executor.shutdown(wait=True)
|
||||
self._push_executor = None
|
||||
|
||||
if ctx.data.dataset is not None:
|
||||
logger.info("Finalizing dataset...")
|
||||
ctx.data.dataset.finalize()
|
||||
if self._needs_push.is_set() and ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
|
||||
logger.info("Pushing final dataset to hub...")
|
||||
if safe_push_to_hub(
|
||||
ctx.data.dataset,
|
||||
tags=ctx.runtime.cfg.dataset.tags,
|
||||
private=ctx.runtime.cfg.dataset.private,
|
||||
):
|
||||
logger.info("Dataset uploaded to hub")
|
||||
log_say("Dataset uploaded to hub", play_sounds)
|
||||
|
||||
self._teardown_hardware(ctx.hardware)
|
||||
logger.info("DAgger strategy teardown complete")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Continuous recording mode (record_autonomous=True)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run_continuous(self, ctx: RolloutContext) -> None:
|
||||
"""Sentry-like continuous recording with intervention tagging.
|
||||
|
||||
Episodes are auto-rotated every ``episode_time_s`` seconds and
|
||||
uploaded in the background every ``upload_every_n_episodes`` episodes.
|
||||
Both autonomous and correction frames are recorded; corrections are
|
||||
tagged with ``intervention=True``.
|
||||
"""
|
||||
engine = self._engine
|
||||
cfg = ctx.runtime.cfg
|
||||
robot = ctx.hardware.robot_wrapper
|
||||
teleop = ctx.hardware.teleop
|
||||
dataset = ctx.data.dataset
|
||||
events = self._events
|
||||
interpolator = self._interpolator
|
||||
features = ctx.data.dataset_features
|
||||
|
||||
control_interval = interpolator.get_control_interval(cfg.fps)
|
||||
record_stride = max(1, cfg.interpolation_multiplier)
|
||||
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
|
||||
play_sounds = cfg.play_sounds
|
||||
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
events.reset()
|
||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||
# teleop.disable_torque()
|
||||
engine.resume()
|
||||
|
||||
last_action: dict[str, Any] | None = None
|
||||
record_tick = 0
|
||||
start_time = time.perf_counter()
|
||||
episode_start = time.perf_counter()
|
||||
episodes_since_push = 0
|
||||
episode_duration_s = self._episode_duration_s
|
||||
logger.info("DAgger continuous recording started (episode_duration=%.0fs)", episode_duration_s)
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
try:
|
||||
while not events.stop_recording.is_set() and not ctx.runtime.shutdown_event.is_set():
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
|
||||
logger.info("Duration limit reached (%.0fs)", cfg.duration)
|
||||
break
|
||||
|
||||
# Process transitions
|
||||
transition = events.consume_transition()
|
||||
if transition is not None:
|
||||
old_phase, new_phase = transition
|
||||
self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
|
||||
last_action = None
|
||||
|
||||
phase = events.phase
|
||||
obs = robot.get_observation()
|
||||
obs_processed = ctx.processors.robot_observation_processor(obs)
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
|
||||
# --- CORRECTING: human teleop control ---
|
||||
if phase == DAggerPhase.CORRECTING:
|
||||
teleop_action = teleop.get_action()
|
||||
processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
|
||||
robot.send_action(robot_action_to_send)
|
||||
last_action = robot_action_to_send
|
||||
self._log_telemetry(obs_processed, processed_teleop, ctx.runtime)
|
||||
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
|
||||
if record_tick % record_stride == 0:
|
||||
frame = {
|
||||
**obs_frame,
|
||||
**action_frame,
|
||||
"task": task_str,
|
||||
"intervention": np.array([True], dtype=bool),
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
record_tick += 1
|
||||
|
||||
# --- PAUSED: hold position ---
|
||||
elif phase == DAggerPhase.PAUSED:
|
||||
if last_action:
|
||||
robot.send_action(last_action)
|
||||
|
||||
# --- AUTONOMOUS: policy control ---
|
||||
else:
|
||||
engine.notify_observation(obs_processed)
|
||||
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
if action_dict is not None:
|
||||
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
|
||||
last_action = ctx.processors.robot_action_processor((action_dict, obs))
|
||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||
if record_tick % record_stride == 0:
|
||||
frame = {
|
||||
**obs_frame,
|
||||
**action_frame,
|
||||
"task": task_str,
|
||||
"intervention": np.array([False], dtype=bool),
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
record_tick += 1
|
||||
|
||||
# Episode rotation derived from video file-size target.
|
||||
# Do NOT save mid-correction — wait for the correction
|
||||
# to finish so the episode boundary is clean.
|
||||
elapsed = time.perf_counter() - episode_start
|
||||
if elapsed >= episode_duration_s and phase != DAggerPhase.CORRECTING:
|
||||
with self._episode_lock:
|
||||
dataset.save_episode()
|
||||
episodes_since_push += 1
|
||||
self._needs_push.set()
|
||||
logger.info(
|
||||
"Episode saved (total: %d, elapsed: %.1fs)",
|
||||
dataset.num_episodes,
|
||||
elapsed,
|
||||
)
|
||||
log_say(f"Episode {dataset.num_episodes} saved", play_sounds)
|
||||
|
||||
if episodes_since_push >= self.config.upload_every_n_episodes:
|
||||
self._background_push(dataset, cfg)
|
||||
episodes_since_push = 0
|
||||
|
||||
episode_start = time.perf_counter()
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
self._warn_if_slow(dt, control_interval, cfg.fps)
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
|
||||
finally:
|
||||
logger.info("DAgger continuous control loop ended — pausing engine")
|
||||
engine.pause()
|
||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||
# teleop.disable_torque()
|
||||
with contextlib.suppress(Exception):
|
||||
with self._episode_lock:
|
||||
dataset.save_episode()
|
||||
self._needs_push.set()
|
||||
logger.info("Final in-progress episode saved")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Corrections-only mode (record_autonomous=False)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run_corrections_only(self, ctx: RolloutContext) -> None:
|
||||
"""Record only human correction windows. Each correction = one episode.
|
||||
|
||||
The policy runs autonomously without recording. When the user
|
||||
pauses and starts a correction, frames are recorded with
|
||||
``intervention=True``. Stopping the correction saves the episode.
|
||||
The dataset can be uploaded on demand via the upload key/pedal.
|
||||
"""
|
||||
engine = self._engine
|
||||
cfg = ctx.runtime.cfg
|
||||
robot = ctx.hardware.robot_wrapper
|
||||
teleop = ctx.hardware.teleop
|
||||
dataset = ctx.data.dataset
|
||||
events = self._events
|
||||
interpolator = self._interpolator
|
||||
features = ctx.data.dataset_features
|
||||
|
||||
control_interval = interpolator.get_control_interval(cfg.fps)
|
||||
record_stride = max(1, cfg.interpolation_multiplier)
|
||||
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
|
||||
play_sounds = cfg.play_sounds
|
||||
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
events.reset()
|
||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||
# teleop.disable_torque()
|
||||
engine.resume()
|
||||
|
||||
last_action: dict[str, Any] | None = None
|
||||
start_time = time.perf_counter()
|
||||
record_tick = 0
|
||||
recorded = 0
|
||||
logger.info(
|
||||
"DAgger corrections-only recording started (target: %d episodes)", self.config.num_episodes
|
||||
)
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
try:
|
||||
while (
|
||||
recorded < self.config.num_episodes
|
||||
and not events.stop_recording.is_set()
|
||||
and not ctx.runtime.shutdown_event.is_set()
|
||||
):
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
|
||||
logger.info("Duration limit reached (%.0fs)", cfg.duration)
|
||||
break
|
||||
|
||||
# Process transitions
|
||||
transition = events.consume_transition()
|
||||
if transition is not None:
|
||||
old_phase, new_phase = transition
|
||||
self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
|
||||
last_action = None
|
||||
|
||||
# Correction ended -> save episode (blocking if not streaming)
|
||||
if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
||||
with self._episode_lock:
|
||||
dataset.save_episode()
|
||||
recorded += 1
|
||||
self._needs_push.set()
|
||||
logger.info(
|
||||
"Correction %d/%d saved",
|
||||
recorded,
|
||||
self.config.num_episodes,
|
||||
)
|
||||
log_say(f"Correction {recorded} saved", play_sounds)
|
||||
|
||||
# On-demand upload
|
||||
if events.upload_requested.is_set():
|
||||
events.upload_requested.clear()
|
||||
logger.info("Upload requested by user")
|
||||
self._background_push(dataset, cfg)
|
||||
|
||||
phase = events.phase
|
||||
obs = robot.get_observation()
|
||||
obs_processed = ctx.processors.robot_observation_processor(obs)
|
||||
|
||||
# --- CORRECTING: human teleop control + recording ---
|
||||
if phase == DAggerPhase.CORRECTING:
|
||||
teleop_action = teleop.get_action()
|
||||
processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
|
||||
robot.send_action(robot_action_to_send)
|
||||
last_action = robot_action_to_send
|
||||
self._log_telemetry(obs_processed, processed_teleop, ctx.runtime)
|
||||
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
|
||||
if record_tick % record_stride == 0:
|
||||
dataset.add_frame(
|
||||
{
|
||||
**obs_frame,
|
||||
**action_frame,
|
||||
"task": task_str,
|
||||
"intervention": np.array([True], dtype=bool),
|
||||
}
|
||||
)
|
||||
record_tick += 1
|
||||
|
||||
# --- PAUSED: hold position ---
|
||||
elif phase == DAggerPhase.PAUSED:
|
||||
if last_action:
|
||||
robot.send_action(last_action)
|
||||
|
||||
# --- AUTONOMOUS: policy control (no recording) ---
|
||||
else:
|
||||
engine.notify_observation(obs_processed)
|
||||
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
if action_dict is not None:
|
||||
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
|
||||
last_action = ctx.processors.robot_action_processor((action_dict, obs))
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
self._warn_if_slow(dt, control_interval, cfg.fps)
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
|
||||
finally:
|
||||
logger.info("DAgger corrections-only loop ended — pausing engine")
|
||||
engine.pause()
|
||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||
# teleop.disable_torque()
|
||||
with contextlib.suppress(Exception):
|
||||
with self._episode_lock:
|
||||
dataset.save_episode()
|
||||
self._needs_push.set()
|
||||
logger.info("Final in-progress episode saved")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# State-machine transition side-effects
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _apply_transition(
|
||||
old_phase: DAggerPhase,
|
||||
new_phase: DAggerPhase,
|
||||
engine,
|
||||
interpolator,
|
||||
robot: ThreadSafeRobot,
|
||||
teleop: Teleoperator,
|
||||
) -> None:
|
||||
"""Execute side-effects for a validated phase transition."""
|
||||
logger.info("Phase transition: %s -> %s", old_phase.value, new_phase.value)
|
||||
if old_phase == DAggerPhase.AUTONOMOUS and new_phase == DAggerPhase.PAUSED:
|
||||
logger.info("Pausing engine — robot holds position")
|
||||
engine.pause()
|
||||
obs = robot.get_observation()
|
||||
_robot_pos = {
|
||||
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
|
||||
}
|
||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||
# _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
elif new_phase == DAggerPhase.CORRECTING:
|
||||
logger.info("Entering correction mode — human teleop control")
|
||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||
# teleop.disable_torque()
|
||||
|
||||
elif new_phase == DAggerPhase.AUTONOMOUS:
|
||||
logger.info("Resuming autonomous mode — resetting engine and interpolator")
|
||||
interpolator.reset()
|
||||
engine.reset()
|
||||
engine.resume()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Background push (shared by both modes)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _background_push(self, dataset, cfg) -> None:
|
||||
"""Queue a Hub push on the single-worker executor.
|
||||
|
||||
The executor's max_workers=1 guarantees at most one push runs at
|
||||
a time; submitted tasks are queued rather than dropped. Pushes
|
||||
are blocked while the operator is mid-correction to avoid
|
||||
uploading a partially-recorded episode.
|
||||
"""
|
||||
if self._push_executor is None:
|
||||
return
|
||||
|
||||
if self._events.phase == DAggerPhase.CORRECTING:
|
||||
logger.info("Skipping push — correction in progress")
|
||||
return
|
||||
|
||||
if self._pending_push is not None and not self._pending_push.done():
|
||||
logger.info("Previous push still in progress; queueing next")
|
||||
|
||||
def _push():
|
||||
try:
|
||||
with self._episode_lock:
|
||||
if safe_push_to_hub(
|
||||
dataset,
|
||||
tags=cfg.dataset.tags if cfg.dataset else None,
|
||||
private=cfg.dataset.private if cfg.dataset else False,
|
||||
):
|
||||
self._needs_push.clear()
|
||||
logger.info("Background push to hub complete")
|
||||
except Exception as e:
|
||||
logger.error("Background push failed: %s", e)
|
||||
|
||||
self._pending_push = self._push_executor.submit(_push)
|
||||
logger.info("Background push task submitted")
|
||||
@@ -1,45 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Strategy factory: config type-name → strategy class dispatch."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .base import BaseStrategy
|
||||
from .core import RolloutStrategy
|
||||
from .dagger import DAggerStrategy
|
||||
from .highlight import HighlightStrategy
|
||||
from .sentry import SentryStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..configs import RolloutStrategyConfig
|
||||
|
||||
|
||||
def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
|
||||
"""Instantiate the appropriate strategy from a config object.
|
||||
|
||||
Dispatches on ``config.type`` (the name registered via
|
||||
``draccus.ChoiceRegistry``).
|
||||
"""
|
||||
if config.type == "base":
|
||||
return BaseStrategy(config)
|
||||
if config.type == "sentry":
|
||||
return SentryStrategy(config)
|
||||
if config.type == "highlight":
|
||||
return HighlightStrategy(config)
|
||||
if config.type == "dagger":
|
||||
return DAggerStrategy(config)
|
||||
raise ValueError(f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger")
|
||||
@@ -1,278 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Highlight Reel strategy: on-demand recording via ring buffer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from threading import Event as ThreadingEvent
|
||||
|
||||
from lerobot.common.control_utils import is_headless
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.import_utils import _pynput_available, require_package
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
from ..configs import HighlightStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from ..ring_buffer import RolloutRingBuffer
|
||||
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
|
||||
|
||||
PYNPUT_AVAILABLE = _pynput_available
|
||||
keyboard = None
|
||||
if PYNPUT_AVAILABLE:
|
||||
try:
|
||||
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
|
||||
logging.info("No DISPLAY set. Skipping pynput import.")
|
||||
PYNPUT_AVAILABLE = False
|
||||
else:
|
||||
from pynput import keyboard
|
||||
except Exception as e:
|
||||
PYNPUT_AVAILABLE = False
|
||||
logging.info(f"Could not import pynput: {e}")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HighlightStrategy(RolloutStrategy):
|
||||
"""Autonomous rollout with on-demand recording via ring buffer.
|
||||
|
||||
The robot runs autonomously while a memory-bounded ring buffer
|
||||
captures continuous telemetry. When the user presses the save key:
|
||||
|
||||
1. The ring buffer is flushed to the dataset (last *Z* seconds).
|
||||
2. Live recording continues until the save key is pressed again.
|
||||
3. The episode is saved and the ring buffer resumes capturing.
|
||||
|
||||
Requires ``streaming_encoding=True`` (enforced in config validation)
|
||||
so that ``dataset.add_frame`` is a non-blocking queue put — draining
|
||||
900 frames stays sub-ms per frame.
|
||||
"""
|
||||
|
||||
config: HighlightStrategyConfig
|
||||
|
||||
def __init__(self, config: HighlightStrategyConfig):
|
||||
super().__init__(config)
|
||||
require_package("pynput", extra="pynput-dep")
|
||||
self._ring: RolloutRingBuffer | None = None
|
||||
self._listener = None
|
||||
self._save_requested = ThreadingEvent()
|
||||
self._recording_live = ThreadingEvent()
|
||||
self._push_requested = ThreadingEvent()
|
||||
self._push_executor: ThreadPoolExecutor | None = None
|
||||
self._pending_push: Future | None = None
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
"""Initialise the inference engine, ring buffer, and keyboard listener."""
|
||||
self._init_engine(ctx)
|
||||
|
||||
self._ring = RolloutRingBuffer(
|
||||
max_seconds=self.config.ring_buffer_seconds,
|
||||
max_memory_mb=self.config.ring_buffer_max_memory_mb,
|
||||
fps=ctx.runtime.cfg.fps,
|
||||
)
|
||||
|
||||
self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="highlight-push")
|
||||
logger.info(
|
||||
"Ring buffer initialized (max_seconds=%.0f, max_memory=%.0fMB)",
|
||||
self.config.ring_buffer_seconds,
|
||||
self.config.ring_buffer_max_memory_mb,
|
||||
)
|
||||
self._setup_keyboard(ctx.runtime.shutdown_event)
|
||||
logger.info(
|
||||
"Highlight strategy ready (buffer=%.0fs, save='%s', push='%s')",
|
||||
self.config.ring_buffer_seconds,
|
||||
self.config.save_key,
|
||||
self.config.push_key,
|
||||
)
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
"""Run the autonomous loop, buffering frames and recording on demand."""
|
||||
engine = self._engine
|
||||
cfg = ctx.runtime.cfg
|
||||
robot = ctx.hardware.robot_wrapper
|
||||
dataset = ctx.data.dataset
|
||||
ring = self._ring
|
||||
interpolator = self._interpolator
|
||||
features = ctx.data.dataset_features
|
||||
|
||||
control_interval = interpolator.get_control_interval(cfg.fps)
|
||||
|
||||
engine.resume()
|
||||
play_sounds = cfg.play_sounds
|
||||
|
||||
start_time = time.perf_counter()
|
||||
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
|
||||
logger.info("Highlight strategy recording started (press '%s' to save)", self.config.save_key)
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
try:
|
||||
while not ctx.runtime.shutdown_event.is_set():
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
|
||||
logger.info("Duration limit reached (%.0fs)", cfg.duration)
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = ctx.processors.robot_observation_processor(obs)
|
||||
engine.notify_observation(obs_processed)
|
||||
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
|
||||
if action_dict is not None:
|
||||
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": task_str}
|
||||
|
||||
# NOTE: ``is_set()`` then ``clear()`` is not atomic
|
||||
# against the keyboard thread setting the flag again
|
||||
# in between — but that is benign: we lose at most one
|
||||
# toggle, processed on the next iteration. The
|
||||
# ``_recording_live`` branch below is reached in the
|
||||
# SAME iteration after ``clear()`` runs, so a frame
|
||||
# finalised by ``save_episode()`` is never re-added to
|
||||
# the next episode.
|
||||
if self._save_requested.is_set():
|
||||
self._save_requested.clear()
|
||||
if not self._recording_live.is_set():
|
||||
logger.info(
|
||||
"Flushing ring buffer (%d frames) + starting live recording",
|
||||
len(ring),
|
||||
)
|
||||
for buffered_frame in ring.drain():
|
||||
dataset.add_frame(buffered_frame)
|
||||
self._recording_live.set()
|
||||
else:
|
||||
dataset.add_frame(frame)
|
||||
dataset.save_episode()
|
||||
logger.info("Episode saved (total: %d)", dataset.num_episodes)
|
||||
log_say(
|
||||
f"Episode {dataset.num_episodes} saved",
|
||||
play_sounds,
|
||||
)
|
||||
self._recording_live.clear()
|
||||
|
||||
if self._push_requested.is_set():
|
||||
self._push_requested.clear()
|
||||
logger.info("Push requested by user")
|
||||
self._background_push(dataset, cfg)
|
||||
|
||||
if self._recording_live.is_set():
|
||||
dataset.add_frame(frame)
|
||||
else:
|
||||
ring.append(frame)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
self._warn_if_slow(dt, control_interval, cfg.fps)
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
|
||||
finally:
|
||||
logger.info("Highlight control loop ended")
|
||||
if self._recording_live.is_set():
|
||||
logger.info("Saving in-progress live episode")
|
||||
with contextlib.suppress(Exception):
|
||||
dataset.save_episode()
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
"""Stop listeners, finalise the dataset, and disconnect hardware."""
|
||||
play_sounds = ctx.runtime.cfg.play_sounds
|
||||
logger.info("Stopping highlight recording")
|
||||
log_say("Stopping highlight recording", play_sounds)
|
||||
|
||||
if self._listener is not None:
|
||||
logger.info("Stopping keyboard listener")
|
||||
self._listener.stop()
|
||||
|
||||
if self._push_executor is not None:
|
||||
logger.info("Shutting down push executor (waiting for pending pushes)...")
|
||||
self._push_executor.shutdown(wait=True)
|
||||
self._push_executor = None
|
||||
|
||||
if ctx.data.dataset is not None:
|
||||
logger.info("Finalizing dataset...")
|
||||
ctx.data.dataset.finalize()
|
||||
if ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
|
||||
logger.info("Pushing final dataset to hub...")
|
||||
if safe_push_to_hub(
|
||||
ctx.data.dataset,
|
||||
tags=ctx.runtime.cfg.dataset.tags,
|
||||
private=ctx.runtime.cfg.dataset.private,
|
||||
):
|
||||
logger.info("Dataset uploaded to hub")
|
||||
log_say("Dataset uploaded to hub", play_sounds)
|
||||
|
||||
self._teardown_hardware(ctx.hardware)
|
||||
logger.info("Highlight strategy teardown complete")
|
||||
|
||||
def _setup_keyboard(self, shutdown_event: ThreadingEvent) -> None:
|
||||
"""Set up keyboard listener for save and push keys."""
|
||||
if is_headless():
|
||||
logger.warning("Headless environment — highlight keys unavailable")
|
||||
return
|
||||
|
||||
try:
|
||||
save_key = self.config.save_key
|
||||
push_key = self.config.push_key
|
||||
|
||||
def on_press(key):
|
||||
with contextlib.suppress(Exception):
|
||||
if hasattr(key, "char") and key.char == save_key:
|
||||
self._save_requested.set()
|
||||
elif hasattr(key, "char") and key.char == push_key:
|
||||
self._push_requested.set()
|
||||
elif key == keyboard.Key.esc:
|
||||
self._save_requested.clear()
|
||||
shutdown_event.set()
|
||||
|
||||
self._listener = keyboard.Listener(on_press=on_press)
|
||||
self._listener.start()
|
||||
logger.info("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key)
|
||||
except ImportError:
|
||||
logger.warning("pynput not available — keyboard listener disabled")
|
||||
|
||||
def _background_push(self, dataset, cfg) -> None:
|
||||
"""Queue a Hub push on the single-worker executor."""
|
||||
if self._push_executor is None:
|
||||
return
|
||||
|
||||
if self._pending_push is not None and not self._pending_push.done():
|
||||
logger.info("Previous push still in progress; queueing next")
|
||||
|
||||
def _push():
|
||||
try:
|
||||
if safe_push_to_hub(
|
||||
dataset,
|
||||
tags=cfg.dataset.tags if cfg.dataset else None,
|
||||
private=cfg.dataset.private if cfg.dataset else False,
|
||||
):
|
||||
logger.info("Background push to hub complete")
|
||||
except Exception as e:
|
||||
logger.error("Background push failed: %s", e)
|
||||
|
||||
self._pending_push = self._push_executor.submit(_push)
|
||||
logger.info("Background push task submitted")
|
||||
@@ -1,226 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Sentry rollout strategy: continuous autonomous recording with auto-upload."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from threading import Event, Lock
|
||||
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.datasets.utils import DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
from ..configs import SentryStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SentryStrategy(RolloutStrategy):
|
||||
"""Continuous autonomous rollout with always-on recording.
|
||||
|
||||
Episode duration is derived from camera resolution, FPS, and
|
||||
``DEFAULT_VIDEO_FILE_SIZE_IN_MB`` so that each saved episode
|
||||
produces a video file that has crossed the chunk-size boundary.
|
||||
This keeps ``push_to_hub`` efficient — it uploads complete video
|
||||
files rather than re-uploading a still-growing one.
|
||||
|
||||
The dataset is pushed to the Hub via a bounded single-worker executor
|
||||
so no push is ever silently dropped and exactly one push runs at a
|
||||
time.
|
||||
|
||||
Policy state (hidden state, RTC queue) intentionally persists across
|
||||
episode boundaries — Sentry slices one continuous rollout, the robot
|
||||
does not reset between slices.
|
||||
|
||||
Requires ``streaming_encoding=True`` (enforced in config validation)
|
||||
to prevent disk I/O from blocking the control loop.
|
||||
"""
|
||||
|
||||
config: SentryStrategyConfig
|
||||
|
||||
def __init__(self, config: SentryStrategyConfig):
|
||||
super().__init__(config)
|
||||
self._push_executor: ThreadPoolExecutor | None = None
|
||||
self._pending_push: Future | None = None
|
||||
self._needs_push = Event()
|
||||
self._episode_lock = Lock()
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
"""Initialise the inference engine and background push executor."""
|
||||
self._init_engine(ctx)
|
||||
self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="sentry-push")
|
||||
target_mb = self.config.target_video_file_size_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
self._episode_duration_s = estimate_max_episode_seconds(
|
||||
ctx.data.dataset_features, ctx.runtime.cfg.fps, target_size_mb=target_mb
|
||||
)
|
||||
logger.info(
|
||||
"Sentry strategy ready (episode_duration=%.0fs, upload_every=%d eps)",
|
||||
self._episode_duration_s,
|
||||
self.config.upload_every_n_episodes,
|
||||
)
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
"""Run the continuous recording loop with automatic episode rotation."""
|
||||
engine = self._engine
|
||||
cfg = ctx.runtime.cfg
|
||||
robot = ctx.hardware.robot_wrapper
|
||||
dataset = ctx.data.dataset
|
||||
interpolator = self._interpolator
|
||||
features = ctx.data.dataset_features
|
||||
|
||||
control_interval = interpolator.get_control_interval(cfg.fps)
|
||||
|
||||
engine.resume()
|
||||
play_sounds = cfg.play_sounds
|
||||
episode_duration_s = self._episode_duration_s
|
||||
|
||||
start_time = time.perf_counter()
|
||||
episode_start = time.perf_counter()
|
||||
episodes_since_push = 0
|
||||
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
|
||||
logger.info("Sentry recording started (episode_duration=%.0fs)", episode_duration_s)
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
try:
|
||||
while not ctx.runtime.shutdown_event.is_set():
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
|
||||
logger.info("Duration limit reached (%.0fs)", cfg.duration)
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = ctx.processors.robot_observation_processor(obs)
|
||||
engine.notify_observation(obs_processed)
|
||||
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
|
||||
if action_dict is not None:
|
||||
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": task_str}
|
||||
# ``add_frame`` writes to the in-progress episode buffer; the
|
||||
# background pusher only ever touches *finalised* episode
|
||||
# artifacts on disk. The two operate on disjoint state, so
|
||||
# ``add_frame`` does not need ``_episode_lock``.
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# Episode rotation derived from video file-size target.
|
||||
# The duration is a conservative estimate so the actual
|
||||
# video has crossed DEFAULT_VIDEO_FILE_SIZE_IN_MB by now,
|
||||
# keeping push_to_hub efficient (uploads complete files).
|
||||
elapsed = time.perf_counter() - episode_start
|
||||
if elapsed >= episode_duration_s:
|
||||
# ``save_episode`` finalises the in-progress episode and
|
||||
# flushes it to disk; ``_episode_lock`` serialises this with
|
||||
# ``push_to_hub`` (run in the background executor) so the
|
||||
# pusher never reads a half-written episode.
|
||||
with self._episode_lock:
|
||||
dataset.save_episode()
|
||||
episodes_since_push += 1
|
||||
self._needs_push.set()
|
||||
logger.info(
|
||||
"Episode saved (total: %d, elapsed: %.1fs)",
|
||||
dataset.num_episodes,
|
||||
elapsed,
|
||||
)
|
||||
log_say(f"Episode {dataset.num_episodes} saved", play_sounds)
|
||||
|
||||
if episodes_since_push >= self.config.upload_every_n_episodes:
|
||||
self._background_push(dataset, cfg)
|
||||
episodes_since_push = 0
|
||||
|
||||
episode_start = time.perf_counter()
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
self._warn_if_slow(dt, control_interval, cfg.fps)
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
|
||||
finally:
|
||||
logger.info("Sentry control loop ended — saving final episode")
|
||||
with contextlib.suppress(Exception):
|
||||
with self._episode_lock:
|
||||
dataset.save_episode()
|
||||
self._needs_push.set()
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
"""Flush pending pushes, finalise the dataset, and disconnect hardware."""
|
||||
play_sounds = ctx.runtime.cfg.play_sounds
|
||||
logger.info("Stopping sentry recording")
|
||||
log_say("Stopping sentry recording", play_sounds)
|
||||
|
||||
# Flush any queued/running push cleanly.
|
||||
if self._push_executor is not None:
|
||||
logger.info("Shutting down push executor (waiting for pending pushes)...")
|
||||
self._push_executor.shutdown(wait=True)
|
||||
self._push_executor = None
|
||||
|
||||
if ctx.data.dataset is not None:
|
||||
logger.info("Finalizing dataset...")
|
||||
ctx.data.dataset.finalize()
|
||||
if self._needs_push.is_set() and ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub:
|
||||
logger.info("Pushing final dataset to hub...")
|
||||
if safe_push_to_hub(
|
||||
ctx.data.dataset,
|
||||
tags=ctx.runtime.cfg.dataset.tags,
|
||||
private=ctx.runtime.cfg.dataset.private,
|
||||
):
|
||||
logger.info("Dataset uploaded to hub")
|
||||
log_say("Dataset uploaded to hub", play_sounds)
|
||||
|
||||
self._teardown_hardware(ctx.hardware)
|
||||
logger.info("Sentry strategy teardown complete")
|
||||
|
||||
def _background_push(self, dataset, cfg) -> None:
|
||||
"""Queue a Hub push on the single-worker executor.
|
||||
|
||||
The executor's max_workers=1 guarantees at most one push runs at
|
||||
a time; submitted tasks are queued rather than dropped.
|
||||
"""
|
||||
if self._push_executor is None:
|
||||
return
|
||||
|
||||
if self._pending_push is not None and not self._pending_push.done():
|
||||
logger.info("Previous push still in progress; queueing next")
|
||||
|
||||
def _push():
|
||||
try:
|
||||
with self._episode_lock:
|
||||
if safe_push_to_hub(
|
||||
dataset,
|
||||
tags=cfg.dataset.tags if cfg.dataset else None,
|
||||
private=cfg.dataset.private if cfg.dataset else False,
|
||||
):
|
||||
self._needs_push.clear()
|
||||
logger.info("Background push to hub complete")
|
||||
except Exception as e:
|
||||
logger.error("Background push failed: %s", e)
|
||||
|
||||
self._pending_push = self._push_executor.submit(_push)
|
||||
logger.info("Background push task submitted")
|
||||
@@ -13,62 +13,70 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Records a dataset via teleoperation. This is a pure data-collection
|
||||
tool — no policy inference. For deploying trained policies, use
|
||||
``lerobot-rollout`` instead.
|
||||
Records a dataset. Actions for the robot can be either generated by teleoperation or by a policy.
|
||||
|
||||
Requires: pip install 'lerobot[core_scripts]' (includes dataset + hardware + viz extras)
|
||||
|
||||
Example:
|
||||
|
||||
```shell
|
||||
lerobot-record \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \\
|
||||
--robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \\
|
||||
--robot.id=black \\
|
||||
--teleop.type=so100_leader \\
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \\
|
||||
--teleop.id=blue \\
|
||||
--dataset.repo_id=<my_username>/<my_dataset_name> \\
|
||||
--dataset.num_episodes=2 \\
|
||||
--dataset.single_task="Grab the cube" \\
|
||||
--dataset.streaming_encoding=true \\
|
||||
--dataset.encoder_threads=2 \\
|
||||
lerobot-record \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--robot.id=black \
|
||||
--dataset.repo_id=<my_username>/<my_dataset_name> \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.single_task="Grab the cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
--display_data=true
|
||||
# <- Optional: specify video codec (auto, h264, hevc, libsvtav1). Default is libsvtav1. \
|
||||
# --dataset.vcodec=h264 \
|
||||
# <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/tty.usbmodem58760431551 \
|
||||
# --teleop.id=blue \
|
||||
# <- Policy optional if you want to record with a policy \
|
||||
# --policy.path=${HF_USER}/my_policy \
|
||||
```
|
||||
|
||||
Example recording with bimanual so100:
|
||||
```shell
|
||||
lerobot-record \\
|
||||
--robot.type=bi_so_follower \\
|
||||
--robot.left_arm_config.port=/dev/tty.usbmodem5A460822851 \\
|
||||
--robot.right_arm_config.port=/dev/tty.usbmodem5A460814411 \\
|
||||
--robot.id=bimanual_follower \\
|
||||
lerobot-record \
|
||||
--robot.type=bi_so_follower \
|
||||
--robot.left_arm_config.port=/dev/tty.usbmodem5A460822851 \
|
||||
--robot.right_arm_config.port=/dev/tty.usbmodem5A460814411 \
|
||||
--robot.id=bimanual_follower \
|
||||
--robot.left_arm_config.cameras='{
|
||||
wrist: {"type": "opencv", "index_or_path": 1, "width": 640, "height": 480, "fps": 30},
|
||||
top: {"type": "opencv", "index_or_path": 3, "width": 640, "height": 480, "fps": 30},
|
||||
}' --robot.right_arm_config.cameras='{
|
||||
wrist: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30},
|
||||
front: {"type": "opencv", "index_or_path": 4, "width": 640, "height": 480, "fps": 30},
|
||||
}' \\
|
||||
--teleop.type=bi_so_leader \\
|
||||
--teleop.left_arm_config.port=/dev/tty.usbmodem5A460852721 \\
|
||||
--teleop.right_arm_config.port=/dev/tty.usbmodem5A460819811 \\
|
||||
--teleop.id=bimanual_leader \\
|
||||
--display_data=true \\
|
||||
--dataset.repo_id=${HF_USER}/bimanual-so-handover-cube \\
|
||||
--dataset.num_episodes=25 \\
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \\
|
||||
--dataset.streaming_encoding=true \\
|
||||
}' \
|
||||
--teleop.type=bi_so_leader \
|
||||
--teleop.left_arm_config.port=/dev/tty.usbmodem5A460852721 \
|
||||
--teleop.right_arm_config.port=/dev/tty.usbmodem5A460819811 \
|
||||
--teleop.id=bimanual_leader \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/bimanual-so-handover-cube \
|
||||
--dataset.num_episodes=25 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.cameras import CameraConfig # noqa: F401
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
||||
@@ -78,10 +86,11 @@ from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401
|
||||
from lerobot.common.control_utils import (
|
||||
init_keyboard_listener,
|
||||
is_headless,
|
||||
predict_action,
|
||||
sanity_check_dataset_name,
|
||||
sanity_check_dataset_robot_compatibility,
|
||||
)
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.dataset import DatasetRecordConfig
|
||||
from lerobot.configs import PreTrainedConfig, parser
|
||||
from lerobot.datasets import (
|
||||
LeRobotDataset,
|
||||
VideoEncodingManager,
|
||||
@@ -89,11 +98,21 @@ from lerobot.datasets import (
|
||||
create_initial_features,
|
||||
safe_stop_image_writer,
|
||||
)
|
||||
from lerobot.policies import (
|
||||
ActionInterpolator,
|
||||
PreTrainedPolicy,
|
||||
make_policy,
|
||||
make_pre_post_processors,
|
||||
make_robot_action,
|
||||
)
|
||||
from lerobot.processor import (
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
make_default_processors,
|
||||
rename_stats,
|
||||
)
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
@@ -127,6 +146,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
)
|
||||
from lerobot.teleoperators.keyboard import KeyboardTeleop
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
@@ -137,12 +157,71 @@ from lerobot.utils.utils import (
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetRecordConfig:
|
||||
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
|
||||
repo_id: str
|
||||
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
|
||||
single_task: str
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second.
|
||||
fps: int = 30
|
||||
# Number of seconds for data recording for each episode.
|
||||
episode_time_s: int | float = 60
|
||||
# Number of seconds for resetting the environment after each episode.
|
||||
reset_time_s: int | float = 60
|
||||
# Number of episodes to record.
|
||||
num_episodes: int = 50
|
||||
# Encode frames in the dataset into video
|
||||
video: bool = True
|
||||
# Upload dataset to Hugging Face hub.
|
||||
push_to_hub: bool = True
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
private: bool = False
|
||||
# Add tags to your dataset on the hub.
|
||||
tags: list[str] | None = None
|
||||
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
||||
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
|
||||
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
|
||||
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
|
||||
num_image_writer_processes: int = 0
|
||||
# Number of threads writing the frames as png images on disk, per camera.
|
||||
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
|
||||
# Not enough threads might cause low camera fps.
|
||||
num_image_writer_threads_per_camera: int = 4
|
||||
# Number of episodes to record before batch encoding videos
|
||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||
video_encoding_batch_size: int = 1
|
||||
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
|
||||
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
|
||||
# Use 'auto' to auto-detect the best available hardware encoder.
|
||||
vcodec: str = "libsvtav1"
|
||||
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||
streaming_encoding: bool = False
|
||||
# Maximum number of frames to buffer per camera when using streaming encoding.
|
||||
# ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
|
||||
encoder_queue_maxsize: int = 30
|
||||
# Number of threads per encoder instance. None = auto (codec default).
|
||||
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
|
||||
encoder_threads: int | None = None
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.single_task is None:
|
||||
raise ValueError("You need to provide a task as argument in `single_task`.")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecordConfig:
|
||||
robot: RobotConfig
|
||||
dataset: DatasetRecordConfig
|
||||
# Teleoperator to control the robot (required)
|
||||
# Whether to control the robot with a teleoperator
|
||||
teleop: TeleoperatorConfig | None = None
|
||||
# Whether to control the robot with a policy
|
||||
policy: PreTrainedConfig | None = None
|
||||
# Display all cameras on screen
|
||||
display_data: bool = False
|
||||
# Display data on a remote Rerun server
|
||||
@@ -155,14 +234,27 @@ class RecordConfig:
|
||||
play_sounds: bool = True
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
# Action interpolation multiplier for smoother policy control (1=off, 2=2x, 3=3x)
|
||||
# Only applies when using a policy (not teleop)
|
||||
interpolation_multiplier: int = 1
|
||||
|
||||
def __post_init__(self):
|
||||
if self.teleop is None:
|
||||
raise ValueError(
|
||||
"A teleoperator is required for recording. "
|
||||
"Use --teleop.type=... to specify one. "
|
||||
"For policy-based deployment, use lerobot-rollout instead."
|
||||
)
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
|
||||
if self.teleop is None and self.policy is None:
|
||||
raise ValueError("Choose a policy, a teleoperator or both to control the robot")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
|
||||
""" --------------- record_loop() data flow --------------------------
|
||||
@@ -172,14 +264,18 @@ class RecordConfig:
|
||||
V
|
||||
[ robot_observation_processor ] ---> processed_obs
|
||||
V
|
||||
[ Teleoperator ]
|
||||
|
|
||||
| [teleop.get_action] -> raw_action
|
||||
| |
|
||||
| V
|
||||
| [teleop_action_processor]
|
||||
| |
|
||||
'---> processed_teleop_action
|
||||
.-----( ACTION LOGIC )------------------.
|
||||
V V
|
||||
[ From Teleoperator ] [ From Policy ]
|
||||
| |
|
||||
| [teleop.get_action] -> raw_action | [predict_action]
|
||||
| | | |
|
||||
| V | V
|
||||
| [teleop_action_processor] | |
|
||||
| | | |
|
||||
'---> processed_teleop_action '---> processed_policy_action
|
||||
| |
|
||||
'-------------------------.-------------'
|
||||
V
|
||||
[ robot_action_processor ] --> robot_action_to_send
|
||||
V
|
||||
@@ -207,9 +303,13 @@ def record_loop(
|
||||
], # runs after robot
|
||||
dataset: LeRobotDataset | None = None,
|
||||
teleop: Teleoperator | list[Teleoperator] | None = None,
|
||||
policy: PreTrainedPolicy | None = None,
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None,
|
||||
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None,
|
||||
control_time_s: int | None = None,
|
||||
single_task: str | None = None,
|
||||
display_data: bool = False,
|
||||
interpolator: ActionInterpolator | None = None,
|
||||
display_compressed_images: bool = False,
|
||||
):
|
||||
if dataset is not None and dataset.fps != fps:
|
||||
@@ -240,7 +340,21 @@ def record_loop(
|
||||
"For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot."
|
||||
)
|
||||
|
||||
control_interval = 1 / fps
|
||||
# Reset policy and processor if they are provided
|
||||
if policy is not None and preprocessor is not None and postprocessor is not None:
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
# Reset interpolator if provided
|
||||
if interpolator is not None:
|
||||
interpolator.reset()
|
||||
|
||||
# Calculate control interval based on interpolation
|
||||
use_interpolation = interpolator is not None and interpolator.enabled and policy is not None
|
||||
control_interval = interpolator.get_control_interval(fps) if interpolator else 1 / fps
|
||||
# Pre-compute action key order outside the hot loop — it won't change mid-episode.
|
||||
action_keys = sorted(robot.action_features) if use_interpolation else []
|
||||
|
||||
no_action_count = 0
|
||||
timestamp = 0
|
||||
@@ -258,11 +372,63 @@ def record_loop(
|
||||
# Applies a pipeline to the raw robot observation, default is IdentityProcessor
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
|
||||
if dataset is not None:
|
||||
if policy is not None or dataset is not None:
|
||||
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
|
||||
|
||||
# Get action from teleop
|
||||
if isinstance(teleop, Teleoperator):
|
||||
# Track whether this iteration should be recorded to the dataset.
|
||||
# Interpolated-only iterations send actions to the robot but don't record frames,
|
||||
# keeping the dataset at the original fps while the robot moves at the higher rate.
|
||||
is_record_frame = True
|
||||
|
||||
# Get action from either policy or teleop
|
||||
if policy is not None and preprocessor is not None and postprocessor is not None:
|
||||
# With interpolation: only call policy when interpolator needs new action
|
||||
if use_interpolation:
|
||||
ran_inference = False
|
||||
|
||||
if interpolator.needs_new_action():
|
||||
action_values = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
act_processed_policy = make_robot_action(action_values, dataset.features)
|
||||
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
|
||||
|
||||
action_tensor = torch.tensor([robot_action_to_send[k] for k in action_keys])
|
||||
interpolator.add(action_tensor)
|
||||
ran_inference = True
|
||||
|
||||
interp_action = interpolator.get()
|
||||
if interp_action is not None:
|
||||
robot_action_to_send = {k: interp_action[i].item() for i, k in enumerate(action_keys)}
|
||||
action_values = robot_action_to_send
|
||||
else:
|
||||
continue
|
||||
|
||||
is_record_frame = ran_inference
|
||||
else:
|
||||
action_values = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
act_processed_policy: RobotAction = make_robot_action(action_values, dataset.features)
|
||||
# Applies a pipeline to the action, default is IdentityProcessor
|
||||
robot_action_to_send = robot_action_processor((act_processed_policy, obs))
|
||||
action_values = robot_action_to_send
|
||||
|
||||
elif policy is None and isinstance(teleop, Teleoperator):
|
||||
act = teleop.get_action()
|
||||
if robot.name == "unitree_g1":
|
||||
teleop.send_feedback(obs)
|
||||
@@ -272,7 +438,7 @@ def record_loop(
|
||||
action_values = act_processed_teleop
|
||||
robot_action_to_send = robot_action_processor((act_processed_teleop, obs))
|
||||
|
||||
elif isinstance(teleop, list):
|
||||
elif policy is None and isinstance(teleop, list):
|
||||
arm_action = teleop_arm.get_action()
|
||||
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
|
||||
keyboard_action = teleop_keyboard.get_action()
|
||||
@@ -285,7 +451,7 @@ def record_loop(
|
||||
no_action_count += 1
|
||||
if no_action_count == 1 or no_action_count % 10 == 0:
|
||||
logging.warning(
|
||||
"No teleoperator provided, skipping action generation. "
|
||||
"No policy or teleoperator provided, skipping action generation. "
|
||||
"This is likely to happen when resetting the environment without a teleop device. "
|
||||
"The robot won't be at its rest position at the start of the next episode."
|
||||
)
|
||||
@@ -297,8 +463,8 @@ def record_loop(
|
||||
# TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
|
||||
_sent_action = robot.send_action(robot_action_to_send)
|
||||
|
||||
# Write to dataset
|
||||
if dataset is not None:
|
||||
# Write to dataset (only on real policy frames, not interpolated-only iterations)
|
||||
if dataset is not None and is_record_frame:
|
||||
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||
frame = {**observation_frame, **action_frame, "task": single_task}
|
||||
dataset.add_frame(frame)
|
||||
@@ -322,12 +488,7 @@ def record_loop(
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def record(
|
||||
cfg: RecordConfig,
|
||||
teleop_action_processor: RobotProcessorPipeline | None = None,
|
||||
robot_action_processor: RobotProcessorPipeline | None = None,
|
||||
robot_observation_processor: RobotProcessorPipeline | None = None,
|
||||
) -> LeRobotDataset:
|
||||
def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
init_logging()
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
if cfg.display_data:
|
||||
@@ -341,16 +502,7 @@ def record(
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None
|
||||
|
||||
# Fall back to identity pipelines when the caller doesn't supply processors.
|
||||
if (
|
||||
teleop_action_processor is None
|
||||
or robot_action_processor is None
|
||||
or robot_observation_processor is None
|
||||
):
|
||||
_t, _r, _o = make_default_processors()
|
||||
teleop_action_processor = teleop_action_processor or _t
|
||||
robot_action_processor = robot_action_processor or _r
|
||||
robot_observation_processor = robot_observation_processor or _o
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
@@ -388,12 +540,8 @@ def record(
|
||||
)
|
||||
sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features)
|
||||
else:
|
||||
# Reject eval_ prefix — for policy evaluation use lerobot-rollout
|
||||
if cfg.dataset.repo_id.startswith("eval_"):
|
||||
raise ValueError(
|
||||
"Dataset names starting with 'eval_' are reserved for policy evaluation. "
|
||||
"lerobot-record is for data collection only. Use lerobot-rollout for policy deployment."
|
||||
)
|
||||
# Create empty dataset or load existing saved episodes
|
||||
sanity_check_dataset_name(cfg.dataset.repo_id, cfg.policy)
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
@@ -410,6 +558,30 @@ def record(
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
)
|
||||
|
||||
# Load pretrained policy
|
||||
policy = (
|
||||
None
|
||||
if cfg.policy is None
|
||||
else make_policy(cfg.policy, ds_meta=dataset.meta, rename_map=cfg.dataset.rename_map)
|
||||
)
|
||||
preprocessor = None
|
||||
postprocessor = None
|
||||
interpolator = None
|
||||
if cfg.policy is not None:
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
# Create interpolator for smoother policy control
|
||||
if cfg.interpolation_multiplier > 1:
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
logging.info(f"Action interpolation enabled: {cfg.interpolation_multiplier}x control rate")
|
||||
|
||||
robot.connect()
|
||||
if teleop is not None:
|
||||
teleop.connect()
|
||||
@@ -433,10 +605,14 @@ def record(
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
display_data=cfg.display_data,
|
||||
interpolator=interpolator,
|
||||
display_compressed_images=display_compressed_images,
|
||||
)
|
||||
|
||||
@@ -484,10 +660,7 @@ def record(
|
||||
listener.stop()
|
||||
|
||||
if cfg.dataset.push_to_hub:
|
||||
if dataset and dataset.num_episodes > 0:
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
else:
|
||||
logging.warning("No episodes saved — skipping push to hub")
|
||||
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
|
||||
|
||||
log_say("Exiting", cfg.play_sounds)
|
||||
return dataset
|
||||
|
||||
@@ -1,211 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Policy deployment engine with pluggable rollout strategies.
|
||||
|
||||
``lerobot-rollout`` is the single CLI for running trained policies on
|
||||
real robots.
|
||||
|
||||
Strategies
|
||||
----------
|
||||
--strategy.type=base Autonomous rollout, no recording
|
||||
--strategy.type=sentry Continuous recording with auto-upload
|
||||
--strategy.type=highlight Ring buffer + keystroke save
|
||||
--strategy.type=dagger Human-in-the-loop (DAgger / RaC)
|
||||
|
||||
Inference backends
|
||||
------------------
|
||||
--inference.type=sync One policy call per control tick (default)
|
||||
--inference.type=rtc Real-Time Chunking for slow VLA models
|
||||
|
||||
Usage examples
|
||||
--------------
|
||||
::
|
||||
|
||||
# Base mode — quick evaluation with sync inference
|
||||
lerobot-rollout \\
|
||||
--strategy.type=base \\
|
||||
--policy.path=lerobot/act_koch_real \\
|
||||
--robot.type=koch_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--task="pick up cube" --duration=30
|
||||
|
||||
# Base mode — RTC inference for slow VLAs (Pi0, Pi0.5, SmolVLA)
|
||||
lerobot-rollout \\
|
||||
--strategy.type=base \\
|
||||
--policy.path=lerobot/pi0_base \\
|
||||
--inference.type=rtc \\
|
||||
--inference.rtc.execution_horizon=10 \\
|
||||
--inference.rtc.max_guidance_weight=10.0 \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \\
|
||||
--task="pick up cube" --duration=60
|
||||
|
||||
# Sentry mode — continuous recording with periodic upload
|
||||
lerobot-rollout \\
|
||||
--strategy.type=sentry \\
|
||||
--strategy.upload_every_n_episodes=5 \\
|
||||
--policy.path=lerobot/pi0_base \\
|
||||
--inference.type=rtc \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--dataset.repo_id=user/sentry-data \\
|
||||
--dataset.single_task="patrol" --duration=3600
|
||||
|
||||
# Highlight mode — ring buffer, press 's' to save, 'h' to push
|
||||
lerobot-rollout \\
|
||||
--strategy.type=highlight \\
|
||||
--strategy.ring_buffer_seconds=30 \\
|
||||
--policy.path=lerobot/act_koch_real \\
|
||||
--robot.type=koch_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--dataset.repo_id=user/highlight-data \\
|
||||
--dataset.single_task="pick up cube"
|
||||
|
||||
# DAgger mode — human-in-the-loop corrections only
|
||||
lerobot-rollout \\
|
||||
--strategy.type=dagger \\
|
||||
--strategy.num_episodes=20 \\
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \\
|
||||
--robot.type=bi_openarm_follower \\
|
||||
--teleop.type=openarm_mini \\
|
||||
--dataset.repo_id=user/hil-data \\
|
||||
--dataset.single_task="Fold the T-shirt"
|
||||
|
||||
# DAgger mode — continuous recording with RTC inference
|
||||
lerobot-rollout \\
|
||||
--strategy.type=dagger \\
|
||||
--strategy.record_autonomous=true \\
|
||||
--strategy.num_episodes=50 \\
|
||||
--inference.type=rtc \\
|
||||
--inference.rtc.execution_horizon=10 \\
|
||||
--policy.path=user/my_pi0_policy \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--teleop.type=so101_leader \\
|
||||
--teleop.port=/dev/ttyACM1 \\
|
||||
--dataset.repo_id=user/dagger-rtc-data \\
|
||||
--dataset.single_task="Grasp the block"
|
||||
|
||||
# With Rerun visualization and torch.compile
|
||||
lerobot-rollout \\
|
||||
--strategy.type=base \\
|
||||
--policy.path=lerobot/act_koch_real \\
|
||||
--robot.type=koch_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--task="pick up cube" --duration=60 \\
|
||||
--display_data=true \\
|
||||
--use_torch_compile=true
|
||||
|
||||
# Resume a previous sentry recording session
|
||||
lerobot-rollout \\
|
||||
--strategy.type=sentry \\
|
||||
--policy.path=user/my_policy \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/ttyACM0 \\
|
||||
--dataset.repo_id=user/sentry-data \\
|
||||
--dataset.single_task="patrol" \\
|
||||
--resume=true
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
earthrover_mini_plus,
|
||||
hope_jr,
|
||||
koch_follower,
|
||||
omx_follower,
|
||||
openarm_follower,
|
||||
reachy2,
|
||||
so_follower,
|
||||
unitree_g1 as unitree_g1_robot,
|
||||
)
|
||||
from lerobot.rollout import RolloutConfig, build_rollout_context, create_strategy
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_so_leader,
|
||||
homunculus,
|
||||
koch_leader,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
openarm_mini,
|
||||
reachy2_teleoperator,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.utils import init_logging
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def rollout(cfg: RolloutConfig):
|
||||
"""Main entry point for policy deployment."""
|
||||
init_logging()
|
||||
|
||||
if cfg.display_data:
|
||||
logger.info("Initializing Rerun visualization (ip=%s, port=%s)", cfg.display_ip, cfg.display_port)
|
||||
init_rerun(session_name="rollout", ip=cfg.display_ip, port=cfg.display_port)
|
||||
|
||||
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
|
||||
shutdown_event = signal_handler.shutdown_event
|
||||
|
||||
logger.info("Building rollout context...")
|
||||
ctx = build_rollout_context(cfg, shutdown_event)
|
||||
|
||||
strategy = create_strategy(cfg.strategy)
|
||||
logger.info("Rollout strategy: %s", cfg.strategy.type)
|
||||
logger.info(
|
||||
"Robot: %s | FPS: %.0f | Duration: %s",
|
||||
cfg.robot.type if cfg.robot else "?",
|
||||
cfg.fps,
|
||||
f"{cfg.duration}s" if cfg.duration > 0 else "infinite",
|
||||
)
|
||||
|
||||
try:
|
||||
strategy.setup(ctx)
|
||||
logger.info("Rollout setup complete, starting rollout...")
|
||||
strategy.run(ctx)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Interrupted by user")
|
||||
finally:
|
||||
strategy.teardown(ctx)
|
||||
|
||||
logger.info("Rollout finished")
|
||||
|
||||
|
||||
def main():
|
||||
"""CLI entry point for ``lerobot-rollout``."""
|
||||
register_third_party_plugins()
|
||||
rollout()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -49,6 +49,7 @@ from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.model_profiling import TrainingProfiler
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.utils import (
|
||||
cycle,
|
||||
@@ -71,6 +72,7 @@ def update_policy(
|
||||
lr_scheduler=None,
|
||||
lock=None,
|
||||
rabc_weights_provider=None,
|
||||
profiler: "TrainingProfiler | None" = None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
"""
|
||||
Performs a single training step to update the policy's weights.
|
||||
@@ -103,8 +105,10 @@ def update_policy(
|
||||
if rabc_weights_provider is not None:
|
||||
rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
|
||||
|
||||
# Let accelerator handle mixed precision
|
||||
with accelerator.autocast():
|
||||
def _section(name: str) -> Any:
|
||||
return profiler.section(name) if profiler is not None else nullcontext()
|
||||
|
||||
with _section("forward"), accelerator.autocast():
|
||||
# Use per-sample loss when RA-BC is enabled for proper weighting
|
||||
if rabc_batch_weights is not None:
|
||||
# Get per-sample losses
|
||||
@@ -123,8 +127,8 @@ def update_policy(
|
||||
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
|
||||
# Use accelerator's backward method
|
||||
accelerator.backward(loss)
|
||||
with _section("backward"):
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Clip gradients if specified
|
||||
if grad_clip_norm > 0:
|
||||
@@ -134,8 +138,7 @@ def update_policy(
|
||||
policy.parameters(), float("inf"), error_if_nonfinite=False
|
||||
)
|
||||
|
||||
# Optimizer step
|
||||
with lock if lock is not None else nullcontext():
|
||||
with _section("optimizer"), lock if lock is not None else nullcontext():
|
||||
optimizer.step()
|
||||
|
||||
optimizer.zero_grad()
|
||||
@@ -316,6 +319,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
|
||||
profiler = (
|
||||
TrainingProfiler.from_cfg(cfg, device) if cfg.profile_mode != "off" and is_main_process else None
|
||||
)
|
||||
if profiler:
|
||||
profiler.record_deterministic_forward(
|
||||
policy=policy, dataset=dataset, batch_size=cfg.batch_size, preprocessor=preprocessor
|
||||
)
|
||||
profiler.start()
|
||||
|
||||
# Load precomputed SARM progress for RA-BC if enabled
|
||||
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
|
||||
rabc_weights = None
|
||||
@@ -449,6 +461,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
accelerator=accelerator,
|
||||
lr_scheduler=lr_scheduler,
|
||||
rabc_weights_provider=rabc_weights,
|
||||
profiler=profiler,
|
||||
)
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
@@ -456,6 +469,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
step += 1
|
||||
if is_main_process:
|
||||
progbar.update(1)
|
||||
if profiler:
|
||||
profiler.step(step, train_tracker)
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
@@ -551,6 +566,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
if is_main_process:
|
||||
progbar.close()
|
||||
if profiler:
|
||||
profiler.finalize()
|
||||
|
||||
if eval_env:
|
||||
close_envs(eval_env)
|
||||
|
||||
@@ -1,116 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Action interpolation for smoother robot control.
|
||||
|
||||
Provides configurable Nx control rate by interpolating between consecutive actions.
|
||||
Useful with RTC and action-chunking policies to reduce jerkiness.
|
||||
"""
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class ActionInterpolator:
|
||||
"""Interpolates between consecutive actions for smoother control.
|
||||
|
||||
When enabled with multiplier N, produces N actions per policy action
|
||||
by linearly interpolating between the previous and current action.
|
||||
|
||||
Example with multiplier=3:
|
||||
prev_action -> [1/3 interpolated, 2/3 interpolated, current_action]
|
||||
|
||||
This effectively multiplies the control rate for smoother motion.
|
||||
|
||||
Usage:
|
||||
interpolator = ActionInterpolator(multiplier=2) # 2x control rate
|
||||
|
||||
# In control loop:
|
||||
if interpolator.needs_new_action():
|
||||
new_action = queue.get()
|
||||
if new_action:
|
||||
interpolator.add(new_action.cpu())
|
||||
|
||||
action = interpolator.get()
|
||||
if action:
|
||||
robot.send_action(action)
|
||||
"""
|
||||
|
||||
def __init__(self, multiplier: int = 1):
|
||||
"""Initialize the interpolator.
|
||||
|
||||
Args:
|
||||
multiplier: Control rate multiplier (1 = no interpolation, 2 = 2x, 3 = 3x, etc.)
|
||||
"""
|
||||
if multiplier < 1:
|
||||
raise ValueError(f"multiplier must be >= 1, got {multiplier}")
|
||||
self.multiplier = multiplier
|
||||
self._prev: Tensor | None = None
|
||||
self._buffer: list[Tensor] = []
|
||||
self._idx = 0
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
"""Whether interpolation is active (multiplier > 1)."""
|
||||
return self.multiplier > 1
|
||||
|
||||
def reset(self):
|
||||
"""Reset interpolation state (call between episodes)."""
|
||||
self._prev = None
|
||||
self._buffer = []
|
||||
self._idx = 0
|
||||
|
||||
def needs_new_action(self) -> bool:
|
||||
"""Check if a new action is needed from the queue."""
|
||||
return self._idx >= len(self._buffer)
|
||||
|
||||
def add(self, action: Tensor) -> None:
|
||||
"""Add a new action and compute interpolated sequence.
|
||||
|
||||
Args:
|
||||
action: New action tensor from policy/queue (already on CPU).
|
||||
"""
|
||||
if self.multiplier > 1 and self._prev is not None:
|
||||
self._buffer = []
|
||||
for i in range(1, self.multiplier + 1):
|
||||
t = i / self.multiplier
|
||||
interp = self._prev + t * (action - self._prev)
|
||||
self._buffer.append(interp)
|
||||
else:
|
||||
# First step: no previous action yet, so run at base FPS without interpolation.
|
||||
self._buffer = [action.clone()]
|
||||
self._prev = action.clone()
|
||||
self._idx = 0
|
||||
|
||||
def get(self) -> Tensor | None:
|
||||
"""Get the next interpolated action.
|
||||
|
||||
Returns:
|
||||
Next action tensor, or None if buffer is exhausted.
|
||||
"""
|
||||
if self._idx >= len(self._buffer):
|
||||
return None
|
||||
action = self._buffer[self._idx]
|
||||
self._idx += 1
|
||||
return action
|
||||
|
||||
def get_control_interval(self, fps: float) -> float:
|
||||
"""Get the control interval based on interpolation multiplier.
|
||||
|
||||
Args:
|
||||
fps: Base frames per second.
|
||||
|
||||
Returns:
|
||||
Control interval in seconds (divided by multiplier).
|
||||
"""
|
||||
return 1.0 / (fps * self.multiplier)
|
||||
783
src/lerobot/utils/model_profiling.py
Normal file
783
src/lerobot/utils/model_profiling.py
Normal file
@@ -0,0 +1,783 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
"""Model profiling — single-file entry point.
|
||||
|
||||
Contains three things that used to live in three separate files:
|
||||
|
||||
* `TrainingProfiler` — hooks the training loop. Captures per-step
|
||||
forward/backward/optimizer timings, the torch profiler output, and a
|
||||
deterministic-forward fingerprint for regression detection.
|
||||
* `POLICY_SPECS` — CI matrix of `policy_name → (steps, train_args)`.
|
||||
Inline so there is no separate JSON to keep in sync.
|
||||
* `main()` — CI orchestrator. For each selected policy, spawns a
|
||||
`lerobot-train` subprocess with profiling enabled, collects the
|
||||
artifacts, and (optionally) publishes a row to a HF Hub dataset.
|
||||
|
||||
Usage (CI):
|
||||
|
||||
python -m lerobot.utils.model_profiling \
|
||||
--output_dir=./profiling-results \
|
||||
--policies act diffusion \
|
||||
--profile_mode=trace \
|
||||
--publish
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
import statistics
|
||||
import subprocess
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from numbers import Real
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from huggingface_hub import CommitOperationAdd, HfApi
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from torch.utils.data import default_collate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Policy matrix. Same shape as the former JSON file; inlined so the source
|
||||
# tree has one less file to keep in sync with the training args.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_LIBERO_RENAME_BASE_RGB = (
|
||||
'--rename_map={"observation.images.front": "observation.images.base_0_rgb", '
|
||||
'"observation.images.wrist": "observation.images.left_wrist_0_rgb"}'
|
||||
)
|
||||
_LIBERO_RENAME_CAMERAS = (
|
||||
'--rename_map={"observation.images.front": "observation.images.camera1", '
|
||||
'"observation.images.wrist": "observation.images.camera2"}'
|
||||
)
|
||||
_PI_SGD = [
|
||||
"--use_policy_training_preset=false",
|
||||
"--optimizer.type=sgd",
|
||||
"--optimizer.lr=1e-5",
|
||||
"--optimizer.weight_decay=0",
|
||||
"--optimizer.grad_clip_norm=1.0",
|
||||
"--scheduler.type=cosine_decay_with_warmup",
|
||||
"--scheduler.peak_lr=1e-5",
|
||||
"--scheduler.decay_lr=1e-6",
|
||||
"--scheduler.num_warmup_steps=0",
|
||||
"--scheduler.num_decay_steps=12",
|
||||
]
|
||||
|
||||
|
||||
POLICY_SPECS: dict[str, dict[str, Any]] = {
|
||||
"act": {
|
||||
"steps": 12,
|
||||
"train_args": [
|
||||
"--dataset.repo_id=lerobot/pusht",
|
||||
"--dataset.episodes=[0]",
|
||||
"--policy.type=act",
|
||||
"--policy.device=cuda",
|
||||
"--batch_size=4",
|
||||
"--cudnn_deterministic=true",
|
||||
],
|
||||
},
|
||||
"diffusion": {
|
||||
"steps": 12,
|
||||
"train_args": [
|
||||
"--dataset.repo_id=lerobot/pusht",
|
||||
"--dataset.episodes=[0]",
|
||||
"--policy.type=diffusion",
|
||||
"--policy.device=cuda",
|
||||
"--batch_size=4",
|
||||
"--cudnn_deterministic=true",
|
||||
],
|
||||
},
|
||||
"groot": {
|
||||
"steps": 12,
|
||||
"train_args": [
|
||||
"--dataset.repo_id=lerobot/libero_plus",
|
||||
"--dataset.episodes=[0]",
|
||||
"--policy.type=groot",
|
||||
"--policy.base_model_path=nvidia/GR00T-N1.5-3B",
|
||||
"--policy.tune_diffusion_model=true",
|
||||
"--policy.tune_projector=true",
|
||||
"--policy.tune_llm=false",
|
||||
"--policy.tune_visual=false",
|
||||
"--policy.use_bf16=true",
|
||||
"--policy.device=cuda",
|
||||
"--batch_size=1",
|
||||
'--rename_map={"observation.images.image": "observation.images.camera1", '
|
||||
'"observation.images.image2": "observation.images.camera2"}',
|
||||
],
|
||||
},
|
||||
"multi_task_dit": {
|
||||
"steps": 12,
|
||||
"train_args": [
|
||||
"--dataset.repo_id=lerobot/pusht",
|
||||
"--dataset.episodes=[0]",
|
||||
"--policy.type=multi_task_dit",
|
||||
"--policy.device=cuda",
|
||||
"--policy.horizon=32",
|
||||
"--policy.n_action_steps=30",
|
||||
"--batch_size=4",
|
||||
"--cudnn_deterministic=true",
|
||||
],
|
||||
},
|
||||
"pi0": {
|
||||
"steps": 12,
|
||||
"train_args": [
|
||||
"--dataset.repo_id=lerobot/libero_plus",
|
||||
"--dataset.episodes=[0]",
|
||||
"--policy.path=lerobot/pi0_base",
|
||||
"--policy.device=cuda",
|
||||
"--policy.dtype=bfloat16",
|
||||
"--policy.n_action_steps=30",
|
||||
"--policy.use_amp=true",
|
||||
"--policy.gradient_checkpointing=true",
|
||||
"--batch_size=1",
|
||||
*_PI_SGD,
|
||||
_LIBERO_RENAME_BASE_RGB,
|
||||
],
|
||||
},
|
||||
"pi0_fast": {
|
||||
"steps": 12,
|
||||
"train_args": [
|
||||
"--dataset.repo_id=lerobot/libero_plus",
|
||||
"--dataset.episodes=[0]",
|
||||
"--policy.path=lerobot/pi0fast-base",
|
||||
"--policy.device=cuda",
|
||||
"--policy.dtype=bfloat16",
|
||||
"--policy.n_action_steps=30",
|
||||
"--policy.use_amp=true",
|
||||
"--policy.gradient_checkpointing=true",
|
||||
"--batch_size=1",
|
||||
*_PI_SGD,
|
||||
_LIBERO_RENAME_BASE_RGB,
|
||||
],
|
||||
},
|
||||
"pi05": {
|
||||
"steps": 12,
|
||||
"train_args": [
|
||||
"--dataset.repo_id=lerobot/libero_plus",
|
||||
"--dataset.episodes=[0]",
|
||||
"--policy.path=lerobot/pi05_base",
|
||||
"--policy.device=cuda",
|
||||
"--policy.dtype=bfloat16",
|
||||
"--policy.n_action_steps=30",
|
||||
"--policy.use_amp=true",
|
||||
"--policy.gradient_checkpointing=true",
|
||||
"--batch_size=1",
|
||||
*_PI_SGD,
|
||||
'--policy.normalization_mapping={"ACTION": "MEAN_STD", '
|
||||
'"STATE": "MEAN_STD", "VISUAL": "IDENTITY"}',
|
||||
_LIBERO_RENAME_BASE_RGB,
|
||||
],
|
||||
},
|
||||
"smolvla": {
|
||||
"steps": 12,
|
||||
"train_args": [
|
||||
"--dataset.repo_id=lerobot/libero_plus",
|
||||
"--dataset.episodes=[0]",
|
||||
"--policy.path=lerobot/smolvla_base",
|
||||
"--policy.load_vlm_weights=true",
|
||||
"--policy.freeze_vision_encoder=false",
|
||||
"--policy.train_expert_only=false",
|
||||
"--policy.empty_cameras=1",
|
||||
"--policy.device=cuda",
|
||||
"--batch_size=1",
|
||||
_LIBERO_RENAME_CAMERAS,
|
||||
],
|
||||
},
|
||||
"wall_x": {
|
||||
"steps": 12,
|
||||
"train_args": [
|
||||
"--dataset.repo_id=lerobot/aloha_sim_insertion_human",
|
||||
"--dataset.episodes=[0]",
|
||||
"--policy.type=wall_x",
|
||||
"--policy.pretrained_name_or_path=x-square-robot/wall-oss-flow",
|
||||
"--policy.prediction_mode=diffusion",
|
||||
"--policy.attn_implementation=eager",
|
||||
"--policy.device=cuda",
|
||||
"--batch_size=1",
|
||||
*_PI_SGD,
|
||||
],
|
||||
},
|
||||
"xvla": {
|
||||
"steps": 12,
|
||||
"train_args": [
|
||||
"--dataset.repo_id=lerobot/libero_plus",
|
||||
"--dataset.episodes=[0]",
|
||||
"--policy.path=lerobot/xvla-widowx",
|
||||
"--policy.action_mode=auto",
|
||||
"--policy.empty_cameras=1",
|
||||
"--policy.device=cuda",
|
||||
"--batch_size=1",
|
||||
'--rename_map={"observation.images.front": "observation.images.image", '
|
||||
'"observation.images.wrist": "observation.images.image2"}',
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TrainingProfiler — hooks the training loop.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _stable_float(value: float | int | None) -> float | None:
|
||||
return None if value is None else round(float(value), 8)
|
||||
|
||||
|
||||
def _as_float(value: Any) -> float:
|
||||
if isinstance(value, Real):
|
||||
return float(value)
|
||||
if hasattr(value, "val"):
|
||||
return float(value.val)
|
||||
raise TypeError(f"Expected a real-valued metric, got {type(value).__name__}")
|
||||
|
||||
|
||||
def _summary(values: list[float]) -> dict[str, float | int | None]:
|
||||
if not values:
|
||||
return {"count": 0, "mean": None, "median": None, "min": None, "max": None}
|
||||
return {
|
||||
"count": len(values),
|
||||
"mean": statistics.fmean(values),
|
||||
"median": statistics.median(values),
|
||||
"min": min(values),
|
||||
"max": max(values),
|
||||
}
|
||||
|
||||
|
||||
def _tensor_signature(tensor: torch.Tensor) -> dict[str, Any]:
|
||||
"""Small, stable summary of a tensor so forward-pass outputs can be
|
||||
compared across runs without bloating the regression JSON."""
|
||||
cpu = tensor.detach().cpu()
|
||||
hash_tensor = cpu.float() if cpu.dtype == torch.bfloat16 else cpu
|
||||
sig: dict[str, Any] = {
|
||||
"shape": list(cpu.shape),
|
||||
"dtype": str(cpu.dtype),
|
||||
"numel": cpu.numel(),
|
||||
"sha256": hashlib.sha256(hash_tensor.contiguous().numpy().tobytes()).hexdigest(),
|
||||
}
|
||||
if cpu.numel():
|
||||
promoted = cpu.to(torch.float64) if cpu.is_floating_point() else cpu.to(torch.int64)
|
||||
sig["sum"] = _stable_float(promoted.sum().item())
|
||||
sig["mean"] = _stable_float(promoted.float().mean().item())
|
||||
return sig
|
||||
|
||||
|
||||
def _summarize_value(value: Any) -> Any:
|
||||
if isinstance(value, torch.Tensor):
|
||||
return _tensor_signature(value)
|
||||
if isinstance(value, dict):
|
||||
return {k: _summarize_value(v) for k, v in value.items()}
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [_summarize_value(v) for v in value]
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
return value
|
||||
return repr(value)
|
||||
|
||||
|
||||
def _hash_payload(payload: Any) -> str:
|
||||
return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()
|
||||
|
||||
|
||||
def _get_profiler_device_time_us(event: Any) -> float | None:
|
||||
return _stable_float(
|
||||
getattr(event, "self_device_time_total", getattr(event, "self_cuda_time_total", None))
|
||||
)
|
||||
|
||||
|
||||
def _write_profiler_table(profiler: Any, path: Path, *, sort_by: str, row_limit: int = 40) -> None:
|
||||
try:
|
||||
path.write_text(profiler.key_averages().table(sort_by=sort_by, row_limit=row_limit))
|
||||
except Exception:
|
||||
logger.debug("Could not write profiler table for sort_by=%s", sort_by, exc_info=True)
|
||||
|
||||
|
||||
def write_deterministic_forward_artifacts(
|
||||
*,
|
||||
policy: Any,
|
||||
dataset: Any,
|
||||
batch_size: int,
|
||||
preprocessor: Any,
|
||||
output_dir: Path,
|
||||
device_type: str,
|
||||
) -> None:
|
||||
"""Run a seed-controlled single forward pass and dump a stable fingerprint
|
||||
(loss/output tensor hashes + op counts) for regression detection. Keeps
|
||||
the caller-selected module mode so ACT-with-VAE-style policies that only
|
||||
materialize their full forward outputs in `train()` still match. Models
|
||||
with stochastic train-mode layers still rely on the seeded RNG for stable
|
||||
fingerprints."""
|
||||
if len(dataset) == 0:
|
||||
raise ValueError("Cannot build a reference batch from an empty dataset.")
|
||||
indices = [i % len(dataset) for i in range(batch_size)]
|
||||
reference_batch = default_collate([dataset[i] for i in indices])
|
||||
# Mirror the uint8 → float32/255 conversion the train loop applies after
|
||||
# the dataloader (PR #3406). The dataset ships camera frames as uint8 for
|
||||
# faster transport, but policies like SmolVLA/xVLA run bilinear
|
||||
# interpolation on images which doesn't support Byte tensors.
|
||||
camera_keys = tuple(getattr(getattr(dataset, "meta", None), "camera_keys", ()) or ())
|
||||
if not camera_keys:
|
||||
camera_keys = tuple(
|
||||
key
|
||||
for key, value in reference_batch.items()
|
||||
if key.startswith("observation.images.") and isinstance(value, torch.Tensor)
|
||||
)
|
||||
for cam_key in camera_keys:
|
||||
if cam_key in reference_batch and reference_batch[cam_key].dtype == torch.uint8:
|
||||
reference_batch[cam_key] = reference_batch[cam_key].to(dtype=torch.float32) / 255.0
|
||||
reference_batch = preprocessor(reference_batch)
|
||||
|
||||
activities = [torch.profiler.ProfilerActivity.CPU]
|
||||
if device_type == "cuda":
|
||||
activities.append(torch.profiler.ProfilerActivity.CUDA)
|
||||
|
||||
with torch.random.fork_rng(devices=[] if device_type != "cuda" else None):
|
||||
torch.manual_seed(0)
|
||||
if device_type == "cuda":
|
||||
torch.cuda.manual_seed_all(0)
|
||||
with torch.no_grad(), torch.profiler.profile(activities=activities) as prof:
|
||||
loss, output_dict = policy.forward(reference_batch)
|
||||
|
||||
operators = sorted(
|
||||
(
|
||||
{
|
||||
"key": e.key,
|
||||
"count": e.count,
|
||||
"cpu_time_total_us": _stable_float(getattr(e, "cpu_time_total", None)),
|
||||
**(
|
||||
{"self_cuda_time_total_us": _get_profiler_device_time_us(e)}
|
||||
if device_type == "cuda"
|
||||
else {}
|
||||
),
|
||||
}
|
||||
for e in prof.key_averages()
|
||||
),
|
||||
key=lambda e: e["key"],
|
||||
)
|
||||
outputs = {"loss": _summarize_value(loss), "output_dict": _summarize_value(output_dict)}
|
||||
payload = {
|
||||
"seed": 0,
|
||||
"reference_batch_size": batch_size,
|
||||
"operator_fingerprint": _hash_payload([(o["key"], o["count"]) for o in operators]),
|
||||
"output_fingerprint": _hash_payload(outputs),
|
||||
"operators": operators,
|
||||
"outputs": outputs,
|
||||
}
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
(output_dir / "deterministic_forward.json").write_text(json.dumps(payload, indent=2, sort_keys=True))
|
||||
sort_by = "self_cuda_time_total" if device_type == "cuda" else "cpu_time_total"
|
||||
_write_profiler_table(prof, output_dir / "deterministic_forward_ops.txt", sort_by=sort_by)
|
||||
|
||||
|
||||
class TrainingProfiler:
|
||||
"""Self-contained profiling hooks for the training loop.
|
||||
|
||||
The training script interacts via ``start()``, ``section()``, ``step()``,
|
||||
``finalize()``, and (optionally) ``record_deterministic_forward()`` — a
|
||||
~7-line surface.
|
||||
"""
|
||||
|
||||
_SCHEDULE_WAIT = 1
|
||||
_SCHEDULE_WARMUP = 2
|
||||
_SCHEDULE_ACTIVE = 6
|
||||
|
||||
def __init__(self, mode: str, output_dir: Path, device: torch.device) -> None:
|
||||
self._mode = mode
|
||||
self._output_dir = output_dir
|
||||
self._output_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._device = device
|
||||
# Inline timing state — no separate collector class.
|
||||
self._total_update_s: list[float] = []
|
||||
self._dataloading_s: list[float] = []
|
||||
self._section_s: dict[str, list[float]] = {}
|
||||
self._memory: list[dict[str, int]] = []
|
||||
self._torch = self._build_torch_profiler()
|
||||
logger.info("Profiling enabled. Artifacts will be written to %s", output_dir)
|
||||
|
||||
def _build_torch_profiler(self) -> Any:
|
||||
activities = [torch.profiler.ProfilerActivity.CPU]
|
||||
if self._device.type == "cuda":
|
||||
activities.append(torch.profiler.ProfilerActivity.CUDA)
|
||||
trace_dir = self._output_dir / "torch_traces"
|
||||
trace_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _on_trace_ready(p: Any) -> None:
|
||||
if self._mode == "trace":
|
||||
p.export_chrome_trace(str(trace_dir / f"trace_step_{p.step_num}.json"))
|
||||
|
||||
return torch.profiler.profile(
|
||||
activities=activities,
|
||||
schedule=torch.profiler.schedule(
|
||||
wait=self._SCHEDULE_WAIT,
|
||||
warmup=self._SCHEDULE_WARMUP,
|
||||
active=self._SCHEDULE_ACTIVE,
|
||||
repeat=1,
|
||||
),
|
||||
on_trace_ready=_on_trace_ready,
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_flops=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cfg(cls, cfg: Any, device: torch.device) -> TrainingProfiler:
|
||||
output = cfg.profile_output_dir or (Path(cfg.output_dir) / "profiling")
|
||||
return cls(mode=cfg.profile_mode, output_dir=Path(output), device=device)
|
||||
|
||||
def record_deterministic_forward(
|
||||
self, *, policy: Any, dataset: Any, batch_size: int, preprocessor: Any
|
||||
) -> None:
|
||||
logger.info("Recording deterministic forward-pass artifacts")
|
||||
write_deterministic_forward_artifacts(
|
||||
policy=policy,
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
preprocessor=preprocessor,
|
||||
output_dir=self._output_dir,
|
||||
device_type=self._device.type,
|
||||
)
|
||||
if self._device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def start(self) -> None:
|
||||
if self._device.type == "cuda":
|
||||
torch.cuda.reset_peak_memory_stats(self._device)
|
||||
self._torch.__enter__()
|
||||
|
||||
@contextmanager
|
||||
def section(self, name: str) -> Iterator[None]:
|
||||
"""Time a region of the training step. Syncs on CUDA so the
|
||||
duration reflects GPU work, not just kernel-launch latency."""
|
||||
if self._device.type == "cuda":
|
||||
torch.cuda.synchronize(self._device)
|
||||
t0 = time.perf_counter()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if self._device.type == "cuda":
|
||||
torch.cuda.synchronize(self._device)
|
||||
self._section_s.setdefault(name, []).append(time.perf_counter() - t0)
|
||||
|
||||
def step(self, step_num: int, train_tracker: Any) -> None:
|
||||
self._total_update_s.append(_as_float(train_tracker.update_s))
|
||||
self._dataloading_s.append(_as_float(train_tracker.dataloading_s))
|
||||
if self._device.type == "cuda":
|
||||
self._memory.append(
|
||||
{
|
||||
"step": step_num,
|
||||
"allocated_bytes": torch.cuda.memory_allocated(self._device),
|
||||
"reserved_bytes": torch.cuda.memory_reserved(self._device),
|
||||
}
|
||||
)
|
||||
self._torch.step()
|
||||
|
||||
def finalize(self) -> None:
|
||||
self._torch.__exit__(None, None, None)
|
||||
payload: dict[str, Any] = {
|
||||
"profile_mode": self._mode,
|
||||
"total_update_s": _summary(self._total_update_s),
|
||||
"dataloading_s": _summary(self._dataloading_s),
|
||||
"memory_timeline": self._memory,
|
||||
}
|
||||
for name, values in self._section_s.items():
|
||||
payload[f"{name}_s"] = _summary(values)
|
||||
if self._device.type == "cuda":
|
||||
payload["peak_memory_allocated_bytes"] = torch.cuda.max_memory_allocated(self._device)
|
||||
payload["peak_memory_reserved_bytes"] = torch.cuda.max_memory_reserved(self._device)
|
||||
(self._output_dir / "step_timing_summary.json").write_text(
|
||||
json.dumps(payload, indent=2, sort_keys=True)
|
||||
)
|
||||
|
||||
tables_dir = self._output_dir / "torch_tables"
|
||||
tables_dir.mkdir(parents=True, exist_ok=True)
|
||||
_write_profiler_table(self._torch, tables_dir / "cpu_time_total.txt", sort_by="cpu_time_total")
|
||||
_write_profiler_table(self._torch, tables_dir / "cpu_memory.txt", sort_by="self_cpu_memory_usage")
|
||||
_write_profiler_table(self._torch, tables_dir / "flops.txt", sort_by="flops")
|
||||
if self._device.type == "cuda":
|
||||
_write_profiler_table(
|
||||
self._torch, tables_dir / "cuda_time_total.txt", sort_by="self_cuda_time_total"
|
||||
)
|
||||
_write_profiler_table(
|
||||
self._torch, tables_dir / "cuda_memory.txt", sort_by="self_cuda_memory_usage"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CI orchestrator. Spawns `lerobot-train` per policy, collects the
|
||||
# artifacts, (optionally) uploads to the HF Hub results dataset.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UploadTarget:
|
||||
local_path: Path
|
||||
path_in_repo: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UploadResult:
|
||||
uploaded_paths: dict[str, str]
|
||||
pr_url: str | None = None
|
||||
|
||||
|
||||
def _utc_timestamp_slug(now: datetime | None = None) -> str:
|
||||
return (now or datetime.now(UTC)).strftime("%Y%m%dT%H%M%SZ")
|
||||
|
||||
|
||||
def _hub_file_url(repo_id: str, path_in_repo: str, *, revision: str = "main") -> str:
|
||||
return f"https://huggingface.co/datasets/{repo_id}/resolve/{revision}/{path_in_repo}"
|
||||
|
||||
|
||||
def parse_discussion_num(pr_url: str | None) -> int | None:
|
||||
if not pr_url:
|
||||
return None
|
||||
m = re.search(r"/discussions/(\d+)$", pr_url)
|
||||
return int(m.group(1)) if m else None
|
||||
|
||||
|
||||
def upload_targets(
|
||||
repo_id: str,
|
||||
targets: list[UploadTarget],
|
||||
*,
|
||||
token: str | None = None,
|
||||
commit_message: str | None = None,
|
||||
create_pr: bool = False,
|
||||
) -> UploadResult:
|
||||
api = HfApi(token=token)
|
||||
commit = api.create_commit(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
operations=[
|
||||
CommitOperationAdd(path_in_repo=t.path_in_repo, path_or_fileobj=str(t.local_path))
|
||||
for t in targets
|
||||
],
|
||||
commit_message=commit_message or f"Upload {len(targets)} profiling artifacts",
|
||||
revision="main",
|
||||
create_pr=create_pr,
|
||||
)
|
||||
pr_num = parse_discussion_num(commit.pr_url)
|
||||
revision = f"refs/pr/{pr_num}" if (create_pr and pr_num) else "main"
|
||||
return UploadResult(
|
||||
uploaded_paths={
|
||||
t.path_in_repo: _hub_file_url(repo_id, t.path_in_repo, revision=revision) for t in targets
|
||||
},
|
||||
pr_url=commit.pr_url,
|
||||
)
|
||||
|
||||
|
||||
def build_train_command(policy: str, run_dir: Path, profile_mode: str) -> list[str]:
|
||||
spec = POLICY_SPECS[policy]
|
||||
return [
|
||||
"uv",
|
||||
"run",
|
||||
"lerobot-train",
|
||||
*spec["train_args"],
|
||||
f"--output_dir={run_dir / 'train'}",
|
||||
f"--steps={spec['steps']}",
|
||||
"--eval_freq=0",
|
||||
"--save_checkpoint=false",
|
||||
f"--save_freq={spec['steps']}",
|
||||
"--wandb.enable=false",
|
||||
"--policy.push_to_hub=false",
|
||||
"--num_workers=0",
|
||||
"--log_freq=1",
|
||||
f"--profile_mode={profile_mode}",
|
||||
f"--profile_output_dir={run_dir / 'profiling'}",
|
||||
]
|
||||
|
||||
|
||||
def build_artifact_index(
|
||||
*, repo_id: str, run_dir: Path, policy_name: str, run_id: str
|
||||
) -> tuple[dict[str, Any], dict[str, Any], list[UploadTarget], str]:
|
||||
"""Scan the run directory and categorize files into
|
||||
(stdout/stderr, torch_tables/*, torch_traces/*, everything else under profiling/).
|
||||
Returns (paths, urls, upload targets, row path in repo)."""
|
||||
row_path_in_repo = f"rows/{policy_name}/{run_id}.json"
|
||||
root = f"artifacts/{policy_name}/{run_id}"
|
||||
paths: dict[str, Any] = {
|
||||
"row": row_path_in_repo,
|
||||
"profiling_files": {},
|
||||
"torch_tables": {},
|
||||
"trace_files": {},
|
||||
}
|
||||
urls: dict[str, Any] = {
|
||||
"row": _hub_file_url(repo_id, row_path_in_repo),
|
||||
"profiling_files": {},
|
||||
"torch_tables": {},
|
||||
"trace_files": {},
|
||||
}
|
||||
targets: list[UploadTarget] = []
|
||||
|
||||
for name in ("stdout.txt", "stderr.txt"):
|
||||
p = run_dir / name
|
||||
if p.exists():
|
||||
key = name.removesuffix(".txt")
|
||||
repo = f"{root}/{name}"
|
||||
paths[key] = repo
|
||||
urls[key] = _hub_file_url(repo_id, repo)
|
||||
targets.append(UploadTarget(p, repo))
|
||||
|
||||
profiling_dir = run_dir / "profiling"
|
||||
if profiling_dir.exists():
|
||||
for p in sorted(profiling_dir.rglob("*")):
|
||||
if not p.is_file():
|
||||
continue
|
||||
rel = str(p.relative_to(run_dir))
|
||||
repo = f"{root}/{rel}"
|
||||
paths["profiling_files"][rel] = repo
|
||||
urls["profiling_files"][rel] = _hub_file_url(repo_id, repo)
|
||||
targets.append(UploadTarget(p, repo))
|
||||
if p.name == "step_timing_summary.json":
|
||||
paths["step_timing_summary"] = repo
|
||||
urls["step_timing_summary"] = _hub_file_url(repo_id, repo)
|
||||
elif "torch_tables" in p.parts:
|
||||
paths["torch_tables"][p.name] = repo
|
||||
urls["torch_tables"][p.name] = _hub_file_url(repo_id, repo)
|
||||
elif "torch_traces" in p.parts:
|
||||
paths["trace_files"][p.name] = repo
|
||||
urls["trace_files"][p.name] = _hub_file_url(repo_id, repo)
|
||||
|
||||
return paths, urls, targets, row_path_in_repo
|
||||
|
||||
|
||||
def upload_profile_run(
|
||||
*,
|
||||
repo_id: str,
|
||||
row_path: Path,
|
||||
row_path_in_repo: str,
|
||||
artifact_targets: list[UploadTarget],
|
||||
create_pr: bool = False,
|
||||
) -> UploadResult:
|
||||
return upload_targets(
|
||||
repo_id=repo_id,
|
||||
targets=[*artifact_targets, UploadTarget(row_path, row_path_in_repo)],
|
||||
commit_message=f"Add model profiling row {row_path_in_repo}",
|
||||
create_pr=create_pr,
|
||||
)
|
||||
|
||||
|
||||
def _load_json(path: Path) -> dict[str, Any]:
|
||||
return json.loads(path.read_text()) if path.exists() else {}
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--policies", nargs="*", default=None)
|
||||
parser.add_argument("--output_dir", type=Path, required=True)
|
||||
parser.add_argument("--hub_org", default="lerobot")
|
||||
parser.add_argument("--results_repo", default="model-profiling-history")
|
||||
parser.add_argument("--publish", action="store_true")
|
||||
parser.add_argument("--profile_mode", choices=["summary", "trace"], default="trace")
|
||||
parser.add_argument("--git_commit", default="")
|
||||
parser.add_argument("--git_ref", default="")
|
||||
parser.add_argument("--pr_number", default="")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
selected = args.policies or list(POLICY_SPECS)
|
||||
unknown = sorted(set(selected) - set(POLICY_SPECS))
|
||||
if unknown:
|
||||
raise ValueError(f"Unknown profiling policies: {', '.join(unknown)}")
|
||||
|
||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
repo_id = args.results_repo if "/" in args.results_repo else f"{args.hub_org}/{args.results_repo}"
|
||||
git_exe = shutil.which("git")
|
||||
if not git_exe:
|
||||
raise RuntimeError("git not found in PATH")
|
||||
git_commit = args.git_commit or subprocess.check_output([git_exe, "rev-parse", "HEAD"], text=True).strip()
|
||||
pr_number = int(args.pr_number) if str(args.pr_number).strip() else None
|
||||
exit_code = 0
|
||||
|
||||
for policy in selected:
|
||||
run_id = f"{_utc_timestamp_slug()}__{policy}"
|
||||
run_dir = args.output_dir / policy / run_id
|
||||
run_dir.mkdir(parents=True, exist_ok=True)
|
||||
cmd = build_train_command(policy, run_dir, args.profile_mode)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
wall_s = time.perf_counter() - t0
|
||||
|
||||
(run_dir / "stdout.txt").write_text(result.stdout)
|
||||
(run_dir / "stderr.txt").write_text(result.stderr)
|
||||
if result.returncode != 0:
|
||||
exit_code = 1
|
||||
|
||||
paths, urls, upload_list, row_in_repo = build_artifact_index(
|
||||
repo_id=repo_id, run_dir=run_dir, policy_name=policy, run_id=run_id
|
||||
)
|
||||
row: dict[str, Any] = {
|
||||
"schema_version": 1,
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
"run_id": run_id,
|
||||
"policy": policy,
|
||||
"git_commit": git_commit,
|
||||
"git_ref": args.git_ref or None,
|
||||
"pr_number": pr_number,
|
||||
"status": "success" if result.returncode == 0 else "failed",
|
||||
"return_code": result.returncode,
|
||||
"profile_mode": args.profile_mode,
|
||||
"wall_time_s": wall_s,
|
||||
"spec": {
|
||||
"steps": POLICY_SPECS[policy]["steps"],
|
||||
"train_args": POLICY_SPECS[policy]["train_args"],
|
||||
},
|
||||
"step_timing_summary": _load_json(run_dir / "profiling" / "step_timing_summary.json"),
|
||||
"deterministic_forward": _load_json(run_dir / "profiling" / "deterministic_forward.json"),
|
||||
"artifact_paths": paths,
|
||||
"artifact_urls": urls,
|
||||
"stderr_tail": result.stderr.splitlines()[-20:],
|
||||
}
|
||||
|
||||
row_path = run_dir / "profiling_row.json"
|
||||
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
|
||||
|
||||
if args.publish:
|
||||
try:
|
||||
uploaded = upload_profile_run(
|
||||
repo_id=repo_id,
|
||||
row_path=row_path,
|
||||
row_path_in_repo=row_in_repo,
|
||||
artifact_targets=upload_list,
|
||||
create_pr=pr_number is not None,
|
||||
)
|
||||
except HfHubHTTPError as exc:
|
||||
row.update({"publish_status": "failed", "publish_error": str(exc)})
|
||||
else:
|
||||
row.update(
|
||||
{
|
||||
"publish_status": "success",
|
||||
"uploaded_paths": uploaded.uploaded_paths,
|
||||
"publish_pr_url": uploaded.pr_url,
|
||||
"publish_pr_number": parse_discussion_num(uploaded.pr_url),
|
||||
}
|
||||
)
|
||||
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
|
||||
|
||||
print(json.dumps(row, indent=2, sort_keys=True))
|
||||
|
||||
return exit_code
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -1,83 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Generic foot pedal listener using evdev.
|
||||
|
||||
Callers supply a callback receiving the pressed key code (e.g. ``"KEY_A"``)
|
||||
and an optional device path. The listener runs in a daemon thread and
|
||||
silently no-ops when :mod:`evdev` is not installed or the device is
|
||||
unavailable. Strategy-specific key mapping logic lives in the caller.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
|
||||
|
||||
|
||||
def start_pedal_listener(
|
||||
on_press: Callable[[str], None],
|
||||
device_path: str = DEFAULT_PEDAL_DEVICE,
|
||||
) -> threading.Thread | None:
|
||||
"""Spawn a daemon thread that forwards pedal key-press codes to ``on_press``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
on_press:
|
||||
Callback invoked with the pressed key code string (e.g. ``"KEY_A"``)
|
||||
on each pedal press event. The callback runs in the listener thread
|
||||
and must be thread-safe.
|
||||
device_path:
|
||||
Linux input device path (e.g. ``/dev/input/by-id/...``).
|
||||
|
||||
Returns
|
||||
-------
|
||||
The started daemon :class:`threading.Thread`, or ``None`` when
|
||||
:mod:`evdev` is not installed (optional dependency; silent no-op).
|
||||
"""
|
||||
try:
|
||||
from evdev import InputDevice, categorize, ecodes
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
def pedal_reader() -> None:
|
||||
try:
|
||||
dev = InputDevice(device_path)
|
||||
logger.info("Pedal connected: %s", dev.name)
|
||||
for ev in dev.read_loop():
|
||||
if ev.type != ecodes.EV_KEY:
|
||||
continue
|
||||
key = categorize(ev)
|
||||
code = key.keycode
|
||||
if isinstance(code, (list, tuple)):
|
||||
code = code[0]
|
||||
if key.keystate != 1: # only key-down events
|
||||
continue
|
||||
try:
|
||||
on_press(code)
|
||||
except Exception as cb_err: # pragma: no cover - defensive
|
||||
logger.warning("Pedal callback error: %s", cb_err)
|
||||
except (FileNotFoundError, PermissionError):
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("Pedal error: %s", e)
|
||||
|
||||
thread = threading.Thread(target=pedal_reader, daemon=True, name="PedalListener")
|
||||
thread.start()
|
||||
return thread
|
||||
282
tests/envs/test_robotwin.py
Normal file
282
tests/envs/test_robotwin.py
Normal file
@@ -0,0 +1,282 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit tests for the RoboTwin 2.0 Gymnasium wrapper.
|
||||
|
||||
These tests mock out the SAPIEN-based RoboTwin runtime (task modules +
|
||||
YAML config loader) so they run without the full RoboTwin installation
|
||||
(SAPIEN, CuRobo, mplib, asset downloads, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.envs.robotwin import (
|
||||
ACTION_DIM,
|
||||
ROBOTWIN_CAMERA_NAMES,
|
||||
ROBOTWIN_TASKS,
|
||||
RoboTwinEnv,
|
||||
create_robotwin_envs,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures / helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_task_env(
|
||||
height: int = 240,
|
||||
width: int = 320,
|
||||
cameras: tuple[str, ...] = ROBOTWIN_CAMERA_NAMES,
|
||||
) -> MagicMock:
|
||||
"""Return a mock that mimics the RoboTwin task class API.
|
||||
|
||||
RoboTwin's real get_obs returns
|
||||
{"observation": {cam: {"rgb": img}}, "joint_action": {"vector": np.ndarray}, ...}
|
||||
so the mock follows the same nested shape.
|
||||
"""
|
||||
obs_dict = {
|
||||
"observation": {cam: {"rgb": np.zeros((height, width, 3), dtype=np.uint8)} for cam in cameras},
|
||||
"joint_action": {"vector": np.zeros(ACTION_DIM, dtype=np.float32)},
|
||||
"endpose": {},
|
||||
}
|
||||
|
||||
mock = MagicMock()
|
||||
mock.get_obs.return_value = obs_dict
|
||||
mock.setup_demo.return_value = None
|
||||
mock.take_action.return_value = None
|
||||
mock.eval_success = False
|
||||
mock.check_success.return_value = False
|
||||
mock.close_env.return_value = None
|
||||
return mock
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_runtime(mock_task_instance: MagicMock):
|
||||
"""Patch both the task-class loader and the YAML config loader so the
|
||||
env can construct + reset without a real RoboTwin install."""
|
||||
task_cls = MagicMock(return_value=mock_task_instance)
|
||||
fake_setup = {
|
||||
"head_camera_h": 240,
|
||||
"head_camera_w": 320,
|
||||
"left_embodiment_config": {},
|
||||
"right_embodiment_config": {},
|
||||
"left_robot_file": "",
|
||||
"right_robot_file": "",
|
||||
"dual_arm_embodied": True,
|
||||
"render_freq": 0,
|
||||
"task_name": "beat_block_hammer",
|
||||
"task_config": "demo_clean",
|
||||
}
|
||||
with (
|
||||
patch("lerobot.envs.robotwin._load_robotwin_task", return_value=task_cls),
|
||||
patch("lerobot.envs.robotwin._load_robotwin_setup_kwargs", return_value=fake_setup),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RoboTwinEnv unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRoboTwinEnv:
|
||||
def test_observation_space_shape(self):
|
||||
"""observation_space should have the configured h×w×3 for every camera."""
|
||||
h, w = 240, 320
|
||||
env = RoboTwinEnv(
|
||||
task_name="beat_block_hammer",
|
||||
observation_height=h,
|
||||
observation_width=w,
|
||||
camera_names=["head_camera", "left_camera"],
|
||||
)
|
||||
pixels_space = env.observation_space["pixels"]
|
||||
assert pixels_space["head_camera"].shape == (h, w, 3)
|
||||
assert pixels_space["left_camera"].shape == (h, w, 3)
|
||||
assert "right_camera" not in pixels_space
|
||||
|
||||
def test_action_space(self):
|
||||
env = RoboTwinEnv(task_name="beat_block_hammer")
|
||||
assert env.action_space.shape == (ACTION_DIM,)
|
||||
assert env.action_space.dtype == np.float32
|
||||
|
||||
def test_reset_returns_correct_obs_keys(self):
|
||||
mock_task = _make_mock_task_env()
|
||||
env = RoboTwinEnv(task_name="beat_block_hammer")
|
||||
with _patch_runtime(mock_task):
|
||||
obs, info = env.reset()
|
||||
|
||||
assert "pixels" in obs
|
||||
for cam in ROBOTWIN_CAMERA_NAMES:
|
||||
assert cam in obs["pixels"], f"Missing camera '{cam}' in obs"
|
||||
assert "agent_pos" in obs
|
||||
assert obs["agent_pos"].shape == (ACTION_DIM,)
|
||||
assert info["is_success"] is False
|
||||
|
||||
def test_reset_calls_setup_demo(self):
|
||||
mock_task = _make_mock_task_env()
|
||||
env = RoboTwinEnv(task_name="beat_block_hammer")
|
||||
with _patch_runtime(mock_task):
|
||||
env.reset(seed=42)
|
||||
# setup_demo receives the full YAML-derived kwargs plus seed + is_test;
|
||||
# we only assert the caller-provided bits.
|
||||
assert mock_task.setup_demo.call_count == 1
|
||||
call_kwargs = mock_task.setup_demo.call_args.kwargs
|
||||
assert call_kwargs["seed"] == 42
|
||||
assert call_kwargs["is_test"] is True
|
||||
|
||||
def test_step_returns_correct_types(self):
|
||||
mock_task = _make_mock_task_env()
|
||||
env = RoboTwinEnv(task_name="beat_block_hammer")
|
||||
action = np.zeros(ACTION_DIM, dtype=np.float32)
|
||||
with _patch_runtime(mock_task):
|
||||
env.reset()
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
|
||||
assert isinstance(obs, dict)
|
||||
assert isinstance(reward, float)
|
||||
assert isinstance(terminated, bool)
|
||||
assert isinstance(truncated, bool)
|
||||
assert isinstance(info, dict)
|
||||
|
||||
def test_step_wrong_action_shape_raises(self):
|
||||
mock_task = _make_mock_task_env()
|
||||
env = RoboTwinEnv(task_name="beat_block_hammer")
|
||||
bad_action = np.zeros(7, dtype=np.float32) # wrong dim
|
||||
with _patch_runtime(mock_task):
|
||||
env.reset()
|
||||
with pytest.raises(ValueError, match="Expected 1-D action"):
|
||||
env.step(bad_action)
|
||||
|
||||
def test_success_terminates_episode(self):
|
||||
mock_task = _make_mock_task_env()
|
||||
mock_task.check_success.return_value = True
|
||||
env = RoboTwinEnv(task_name="beat_block_hammer")
|
||||
action = np.zeros(ACTION_DIM, dtype=np.float32)
|
||||
with _patch_runtime(mock_task):
|
||||
env.reset()
|
||||
_, _, terminated, _, info = env.step(action)
|
||||
assert terminated is True
|
||||
assert info["is_success"] is True
|
||||
|
||||
def test_truncation_after_episode_length(self):
|
||||
mock_task = _make_mock_task_env()
|
||||
env = RoboTwinEnv(task_name="beat_block_hammer", episode_length=2)
|
||||
action = np.zeros(ACTION_DIM, dtype=np.float32)
|
||||
with _patch_runtime(mock_task):
|
||||
env.reset()
|
||||
env.step(action) # step 1
|
||||
_, _, _, truncated, _ = env.step(action) # step 2 → truncated
|
||||
assert truncated is True
|
||||
|
||||
def test_close_calls_close_env(self):
|
||||
mock_task = _make_mock_task_env()
|
||||
env = RoboTwinEnv(task_name="beat_block_hammer")
|
||||
with _patch_runtime(mock_task):
|
||||
env.reset()
|
||||
env.close()
|
||||
mock_task.close_env.assert_called_once()
|
||||
|
||||
def test_black_frame_for_missing_camera(self):
|
||||
"""If a camera key is absent from get_obs(), a black frame is returned."""
|
||||
# Mock exposes only head_camera; we ask for both head_camera + left_camera.
|
||||
mock_task = _make_mock_task_env(height=10, width=10, cameras=("head_camera",))
|
||||
env = RoboTwinEnv(
|
||||
task_name="beat_block_hammer",
|
||||
camera_names=["head_camera", "left_camera"],
|
||||
observation_height=10,
|
||||
observation_width=10,
|
||||
)
|
||||
with _patch_runtime(mock_task):
|
||||
obs, _ = env.reset()
|
||||
assert obs["pixels"]["left_camera"].shape == (10, 10, 3)
|
||||
assert obs["pixels"]["left_camera"].sum() == 0
|
||||
|
||||
def test_task_and_task_description_attributes(self):
|
||||
env = RoboTwinEnv(task_name="beat_block_hammer")
|
||||
assert env.task == "beat_block_hammer"
|
||||
assert isinstance(env.task_description, str)
|
||||
|
||||
def test_deferred_init_env_is_none_before_reset(self):
|
||||
env = RoboTwinEnv(task_name="beat_block_hammer")
|
||||
assert env._env is None # noqa: SLF001 (testing internal state)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# create_robotwin_envs tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCreateRoboTwinEnvs:
|
||||
def test_returns_correct_structure(self):
|
||||
mock_task = _make_mock_task_env()
|
||||
with _patch_runtime(mock_task):
|
||||
envs = create_robotwin_envs(
|
||||
task="beat_block_hammer",
|
||||
n_envs=1,
|
||||
env_cls=gym.vector.SyncVectorEnv,
|
||||
)
|
||||
assert "beat_block_hammer" in envs
|
||||
assert 0 in envs["beat_block_hammer"]
|
||||
assert isinstance(envs["beat_block_hammer"][0], gym.vector.SyncVectorEnv)
|
||||
|
||||
def test_multi_task(self):
|
||||
mock_task = _make_mock_task_env()
|
||||
with _patch_runtime(mock_task):
|
||||
envs = create_robotwin_envs(
|
||||
task="beat_block_hammer,click_bell",
|
||||
n_envs=1,
|
||||
env_cls=gym.vector.SyncVectorEnv,
|
||||
)
|
||||
assert set(envs.keys()) == {"beat_block_hammer", "click_bell"}
|
||||
|
||||
def test_unknown_task_raises(self):
|
||||
with pytest.raises(ValueError, match="Unknown RoboTwin tasks"):
|
||||
create_robotwin_envs(
|
||||
task="not_a_real_task",
|
||||
n_envs=1,
|
||||
env_cls=gym.vector.SyncVectorEnv,
|
||||
)
|
||||
|
||||
def test_invalid_n_envs_raises(self):
|
||||
with pytest.raises(ValueError, match="n_envs must be a positive int"):
|
||||
create_robotwin_envs(
|
||||
task="beat_block_hammer",
|
||||
n_envs=0,
|
||||
env_cls=gym.vector.SyncVectorEnv,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ROBOTWIN_TASKS list
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_task_list_not_empty():
|
||||
assert len(ROBOTWIN_TASKS) >= 50
|
||||
|
||||
|
||||
def test_all_tasks_are_strings():
|
||||
assert all(isinstance(t, str) and t for t in ROBOTWIN_TASKS)
|
||||
|
||||
|
||||
def test_no_duplicate_tasks():
|
||||
assert len(ROBOTWIN_TASKS) == len(set(ROBOTWIN_TASKS))
|
||||
@@ -17,9 +17,9 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc.action_interpolator import ActionInterpolator
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.action_interpolator import ActionInterpolator
|
||||
|
||||
# ====================== Fixtures ======================
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class TestRTCDenoiseWithRelativeLeftovers:
|
||||
|
||||
|
||||
class TestFullPipelineRelativeRTC:
|
||||
"""End-to-end test of the RTC + relative actions pipeline matching lerobot-rollout flow."""
|
||||
"""End-to-end test of the RTC + relative actions pipeline matching eval_with_real_robot.py flow."""
|
||||
|
||||
def test_preprocessor_caches_state_for_postprocessor(self):
|
||||
"""Preprocessor's relative step should cache state so postprocessor can convert back."""
|
||||
@@ -240,7 +240,7 @@ class TestFullPipelineRelativeRTC:
|
||||
torch.testing.assert_close(recovered, actions, atol=1e-5, rtol=1e-5)
|
||||
|
||||
def test_eval_loop_simulation(self):
|
||||
"""Simulate the lerobot-rollout loop with relative actions.
|
||||
"""Simulate the eval_with_real_robot.py loop with relative actions.
|
||||
|
||||
Iteration 1: No leftovers → model generates relative actions → store for RTC
|
||||
Iteration 2: Use leftovers as RTC guidance → model generates new relative actions
|
||||
@@ -401,12 +401,12 @@ class TestStateRebasingApproximation:
|
||||
|
||||
|
||||
def _detect_relative_actions(preprocessor) -> bool:
|
||||
"""Mirror of the helper in lerobot-rollout for testing without importing it."""
|
||||
"""Mirror of the helper in eval_with_real_robot.py for testing without importing it."""
|
||||
return any(isinstance(step, RelativeActionsProcessorStep) and step.enabled for step in preprocessor.steps)
|
||||
|
||||
|
||||
class TestDetectRelativeActions:
|
||||
"""Test the _detect_relative_actions helper logic used by lerobot-rollout."""
|
||||
"""Test the _detect_relative_actions helper logic used by eval_with_real_robot.py."""
|
||||
|
||||
def test_detects_enabled_relative_step(self):
|
||||
class FakePipeline:
|
||||
|
||||
@@ -24,6 +24,10 @@ def lerobot_train(args):
|
||||
return run_command(cmd="lerobot-train", module="lerobot_train", args=args)
|
||||
|
||||
|
||||
def lerobot_record(args):
|
||||
return run_command(cmd="lerobot-record", module="lerobot_record", args=args)
|
||||
|
||||
|
||||
def resolve_model_id_for_peft_training(policy_type):
|
||||
"""PEFT training needs pretrained models, this finds the pretrained model of a policy type for PEFT training."""
|
||||
if policy_type == "smolvla":
|
||||
@@ -151,3 +155,81 @@ def test_peft_training_params_are_fewer(policy_type, tmp_path):
|
||||
f"--output_dir={output_dir}",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class DummyRobot:
|
||||
name = "dummy"
|
||||
cameras = []
|
||||
action_features = {"foo": 1.0, "bar": 2.0}
|
||||
observation_features = {"obs1": 1.0, "obs2": 2.0}
|
||||
is_connected = True
|
||||
|
||||
def connect(self, *args):
|
||||
pass
|
||||
|
||||
def disconnect(self):
|
||||
pass
|
||||
|
||||
|
||||
def dummy_make_robot_from_config(*args, **kwargs):
|
||||
return DummyRobot()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_type", ["smolvla"])
|
||||
@skip_if_package_missing("peft")
|
||||
def test_peft_record_loads_policy(policy_type, tmp_path):
|
||||
"""Train a policy with PEFT and attempt to load it with `lerobot-record`."""
|
||||
from peft import PeftModel
|
||||
|
||||
output_dir = tmp_path / f"output_{policy_type}"
|
||||
model_id = resolve_model_id_for_peft_training(policy_type)
|
||||
|
||||
lerobot_train(
|
||||
[
|
||||
f"--policy.path={model_id}",
|
||||
"--policy.push_to_hub=false",
|
||||
"--policy.input_features=null",
|
||||
"--policy.output_features=null",
|
||||
"--peft.method=LORA",
|
||||
"--dataset.repo_id=lerobot/pusht",
|
||||
"--dataset.episodes=[0, 1]",
|
||||
"--steps=1",
|
||||
f"--output_dir={output_dir}",
|
||||
]
|
||||
)
|
||||
|
||||
policy_dir = output_dir / "checkpoints" / "last" / "pretrained_model"
|
||||
dataset_dir = tmp_path / "eval_pusht"
|
||||
single_task = "move the table"
|
||||
loaded_policy = None
|
||||
|
||||
def dummy_record_loop(*args, **kwargs):
|
||||
nonlocal loaded_policy
|
||||
|
||||
if "dataset" not in kwargs:
|
||||
return
|
||||
|
||||
dataset = kwargs["dataset"]
|
||||
dataset.add_frame({"task": single_task})
|
||||
loaded_policy = kwargs["policy"]
|
||||
|
||||
with (
|
||||
patch("lerobot.scripts.lerobot_record.make_robot_from_config", dummy_make_robot_from_config),
|
||||
# disable record loop since we're only interested in successful loading of the policy.
|
||||
patch("lerobot.scripts.lerobot_record.record_loop", dummy_record_loop),
|
||||
# disable speech output
|
||||
patch("lerobot.utils.utils.say"),
|
||||
):
|
||||
lerobot_record(
|
||||
[
|
||||
f"--policy.path={policy_dir}",
|
||||
"--robot.type=so101_follower",
|
||||
"--robot.port=/dev/null",
|
||||
"--dataset.repo_id=lerobot/eval_pusht",
|
||||
f'--dataset.single_task="{single_task}"',
|
||||
f"--dataset.root={dataset_dir}",
|
||||
"--dataset.push_to_hub=false",
|
||||
]
|
||||
)
|
||||
|
||||
assert isinstance(loaded_policy, PeftModel)
|
||||
|
||||
@@ -21,9 +21,8 @@ import pytest
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("deepdiff", reason="deepdiff is required (install lerobot[hardware])")
|
||||
|
||||
from lerobot.configs.dataset import DatasetRecordConfig
|
||||
from lerobot.scripts.lerobot_calibrate import CalibrateConfig, calibrate
|
||||
from lerobot.scripts.lerobot_record import RecordConfig, record
|
||||
from lerobot.scripts.lerobot_record import DatasetRecordConfig, RecordConfig, record
|
||||
from lerobot.scripts.lerobot_replay import DatasetReplayConfig, ReplayConfig, replay
|
||||
from lerobot.scripts.lerobot_teleoperate import TeleoperateConfig, teleoperate
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
348
tests/test_model_profiling.py
Normal file
348
tests/test_model_profiling.py
Normal file
@@ -0,0 +1,348 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.utils import model_profiling as mp
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Policy spec matrix
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_policy_specs_cover_expected_policies():
|
||||
assert set(mp.POLICY_SPECS) == {
|
||||
"act",
|
||||
"diffusion",
|
||||
"groot",
|
||||
"multi_task_dit",
|
||||
"pi0",
|
||||
"pi0_fast",
|
||||
"pi05",
|
||||
"smolvla",
|
||||
"wall_x",
|
||||
"xvla",
|
||||
}
|
||||
# Sanity: excluded policies should stay out of the matrix.
|
||||
for excluded in ("sac", "sarm", "tdmpc", "vqbet", "reward_classifier"):
|
||||
assert excluded not in mp.POLICY_SPECS
|
||||
|
||||
|
||||
def test_pretrained_libero_specs_match_expected_camera_keys_and_normalization():
|
||||
base_rgb_rename = (
|
||||
'--rename_map={"observation.images.front": "observation.images.base_0_rgb", '
|
||||
'"observation.images.wrist": "observation.images.left_wrist_0_rgb"}'
|
||||
)
|
||||
for name in ("pi0", "pi0_fast", "pi05"):
|
||||
assert base_rgb_rename in mp.POLICY_SPECS[name]["train_args"]
|
||||
assert any(
|
||||
arg.startswith('--policy.normalization_mapping={"ACTION": "MEAN_STD"')
|
||||
for arg in mp.POLICY_SPECS["pi05"]["train_args"]
|
||||
)
|
||||
assert (
|
||||
'--rename_map={"observation.images.front": "observation.images.camera1", '
|
||||
'"observation.images.wrist": "observation.images.camera2"}'
|
||||
in mp.POLICY_SPECS["smolvla"]["train_args"]
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CI orchestrator helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_train_command_includes_profiling_outputs(tmp_path):
|
||||
cmd = mp.build_train_command("act", tmp_path / "run", "trace")
|
||||
assert cmd[:3] == ["uv", "run", "lerobot-train"]
|
||||
assert any(a.startswith("--output_dir=") for a in cmd)
|
||||
assert any(a.startswith("--profile_output_dir=") for a in cmd)
|
||||
assert "--profile_mode=trace" in cmd
|
||||
assert "--eval_freq=0" in cmd
|
||||
|
||||
|
||||
def test_build_artifact_index_collects_tables_and_traces(tmp_path):
|
||||
run_dir = tmp_path / "act" / "20260415T000000Z__act"
|
||||
profiling = run_dir / "profiling"
|
||||
(profiling / "torch_tables").mkdir(parents=True)
|
||||
(profiling / "torch_traces").mkdir(parents=True)
|
||||
(profiling / "step_timing_summary.json").write_text("{}")
|
||||
(profiling / "deterministic_forward.json").write_text(
|
||||
json.dumps({"operator_fingerprint": "ops", "output_fingerprint": "out"})
|
||||
)
|
||||
(profiling / "torch_tables" / "cpu_time_total.txt").write_text("cpu table")
|
||||
(profiling / "torch_traces" / "trace_step_9.json").write_text("{}")
|
||||
(run_dir / "stdout.txt").write_text("stdout")
|
||||
(run_dir / "stderr.txt").write_text("stderr")
|
||||
|
||||
paths, urls, targets, row_in_repo = mp.build_artifact_index(
|
||||
repo_id="lerobot/model-profiling-history",
|
||||
run_dir=run_dir,
|
||||
policy_name="act",
|
||||
run_id="20260415T000000Z__act",
|
||||
)
|
||||
|
||||
assert row_in_repo == "rows/act/20260415T000000Z__act.json"
|
||||
assert paths["stdout"].endswith("/stdout.txt")
|
||||
assert paths["step_timing_summary"].endswith("/profiling/step_timing_summary.json")
|
||||
assert "cpu_time_total.txt" in paths["torch_tables"]
|
||||
assert "trace_step_9.json" in paths["trace_files"]
|
||||
assert urls["row"].startswith("https://huggingface.co/datasets/lerobot/model-profiling-history/")
|
||||
# stdout + stderr + 4 profiling files
|
||||
assert len(targets) == 6
|
||||
|
||||
|
||||
def test_upload_targets_batches_preview_publish_into_single_hf_pr(monkeypatch, tmp_path):
|
||||
local_path = tmp_path / "profiling_row.json"
|
||||
local_path.write_text("{}")
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class _FakeCommit:
|
||||
pr_url = "https://huggingface.co/datasets/lerobot/model-profiling-history/discussions/42"
|
||||
|
||||
class _FakeApi:
|
||||
def __init__(self, token=None):
|
||||
captured["token"] = token
|
||||
|
||||
def create_commit(self, **kwargs):
|
||||
captured.update(kwargs)
|
||||
return _FakeCommit()
|
||||
|
||||
monkeypatch.setattr(mp, "HfApi", _FakeApi)
|
||||
|
||||
result = mp.upload_targets(
|
||||
repo_id="lerobot/model-profiling-history",
|
||||
targets=[mp.UploadTarget(local_path, "rows/act/run.json")],
|
||||
create_pr=True,
|
||||
token="hf_test_token",
|
||||
)
|
||||
|
||||
assert captured["repo_id"] == "lerobot/model-profiling-history"
|
||||
assert captured["repo_type"] == "dataset"
|
||||
assert captured["create_pr"] is True
|
||||
assert result.pr_url == _FakeCommit.pr_url
|
||||
assert result.uploaded_paths["rows/act/run.json"].endswith("/resolve/refs/pr/42/rows/act/run.json")
|
||||
|
||||
|
||||
def test_parse_discussion_num_handles_hf_discussion_urls():
|
||||
assert (
|
||||
mp.parse_discussion_num(
|
||||
"https://huggingface.co/datasets/lerobot/model-profiling-history/discussions/42"
|
||||
)
|
||||
== 42
|
||||
)
|
||||
assert mp.parse_discussion_num("https://huggingface.co/datasets/lerobot/model-profiling-history") is None
|
||||
assert mp.parse_discussion_num(None) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# main() smoke tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_args(tmp_path):
|
||||
"""Shared argparse namespace for main() smoke tests — overridden per-test."""
|
||||
return argparse.Namespace(
|
||||
policies=["act"],
|
||||
output_dir=tmp_path / "results",
|
||||
hub_org="lerobot",
|
||||
results_repo="model-profiling-history",
|
||||
publish=False,
|
||||
profile_mode="summary",
|
||||
git_commit="",
|
||||
git_ref="codex/model-profiling",
|
||||
pr_number="3389",
|
||||
)
|
||||
|
||||
|
||||
def _stub_train_subprocess(mp_module, *, returncode: int = 0, write_artifacts: bool = True):
|
||||
"""Build a fake subprocess.run that writes the profiling artifacts main() expects."""
|
||||
|
||||
def _fake_run(cmd, capture_output, text):
|
||||
assert capture_output is True
|
||||
assert text is True
|
||||
profile_dir = Path(next(a.split("=", 1)[1] for a in cmd if a.startswith("--profile_output_dir=")))
|
||||
profile_dir.mkdir(parents=True, exist_ok=True)
|
||||
if write_artifacts:
|
||||
(profile_dir / "torch_tables").mkdir(parents=True, exist_ok=True)
|
||||
(profile_dir / "step_timing_summary.json").write_text(
|
||||
json.dumps({"total_update_s": {"count": 1, "mean": 0.3}, "peak_memory_allocated_bytes": 1024})
|
||||
)
|
||||
(profile_dir / "deterministic_forward.json").write_text(
|
||||
json.dumps(
|
||||
{"operator_fingerprint": "ops-fingerprint", "output_fingerprint": "output-fingerprint"}
|
||||
)
|
||||
)
|
||||
(profile_dir / "torch_tables" / "cpu_time_total.txt").write_text("cpu time table")
|
||||
return subprocess.CompletedProcess(cmd, returncode, "stdout ok", "")
|
||||
|
||||
return _fake_run
|
||||
|
||||
|
||||
def test_main_smoke_writes_row(monkeypatch, fake_args):
|
||||
monkeypatch.setattr(mp, "parse_args", lambda: fake_args)
|
||||
monkeypatch.setattr(mp.subprocess, "check_output", lambda *a, **k: "deadbeef\n")
|
||||
monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp))
|
||||
|
||||
assert mp.main() == 0
|
||||
|
||||
row_paths = list(fake_args.output_dir.rglob("profiling_row.json"))
|
||||
assert len(row_paths) == 1
|
||||
row = json.loads(row_paths[0].read_text())
|
||||
assert row["policy"] == "act"
|
||||
assert row["status"] == "success"
|
||||
assert row["git_commit"] == "deadbeef"
|
||||
assert row["git_ref"] == "codex/model-profiling"
|
||||
assert row["pr_number"] == 3389
|
||||
assert row["step_timing_summary"]["total_update_s"]["mean"] == 0.3
|
||||
assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint"
|
||||
|
||||
|
||||
def test_main_records_publish_failure_without_failing(monkeypatch, fake_args):
|
||||
fake_args.publish = True
|
||||
fake_args.git_commit = "deadbeef"
|
||||
monkeypatch.setattr(mp, "parse_args", lambda: fake_args)
|
||||
monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp, write_artifacts=False))
|
||||
|
||||
def _fail_upload(**kwargs):
|
||||
resp = type("Resp", (), {"status_code": 403, "headers": {}, "request": None})()
|
||||
raise HfHubHTTPError("403 Forbidden: Authorization error.", response=resp)
|
||||
|
||||
monkeypatch.setattr(mp, "upload_profile_run", _fail_upload)
|
||||
|
||||
assert mp.main() == 0
|
||||
row = json.loads(next(fake_args.output_dir.rglob("profiling_row.json")).read_text())
|
||||
assert row["status"] == "success"
|
||||
assert row["publish_status"] == "failed"
|
||||
assert "Authorization error" in row["publish_error"]
|
||||
|
||||
|
||||
def test_main_returns_nonzero_when_training_subprocess_fails(monkeypatch, fake_args):
|
||||
monkeypatch.setattr(mp, "parse_args", lambda: fake_args)
|
||||
monkeypatch.setattr(mp.subprocess, "check_output", lambda *a, **k: "deadbeef\n")
|
||||
monkeypatch.setattr(mp.subprocess, "run", _stub_train_subprocess(mp, returncode=3))
|
||||
|
||||
assert mp.main() == 1
|
||||
|
||||
row = json.loads(next(fake_args.output_dir.rglob("profiling_row.json")).read_text())
|
||||
assert row["status"] == "failed"
|
||||
assert row["return_code"] == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TrainingProfiler behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
|
||||
class _TrainingOnlyPolicy(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.forward_calls = 0
|
||||
|
||||
def forward(self, batch):
|
||||
self.forward_calls += 1
|
||||
assert self.training
|
||||
return batch["value"].sum(), {"value": batch["value"]}
|
||||
|
||||
dataset = [{"value": torch.tensor([1.0, 2.0])}]
|
||||
policy = _TrainingOnlyPolicy()
|
||||
policy.train()
|
||||
|
||||
mp.write_deterministic_forward_artifacts(
|
||||
policy=policy,
|
||||
dataset=dataset,
|
||||
batch_size=2,
|
||||
preprocessor=lambda b: b,
|
||||
output_dir=tmp_path,
|
||||
device_type="cpu",
|
||||
)
|
||||
|
||||
payload = json.loads((tmp_path / "deterministic_forward.json").read_text())
|
||||
assert policy.training is True
|
||||
assert policy.forward_calls == 1
|
||||
assert payload["reference_batch_size"] == 2
|
||||
assert "operator_fingerprint" in payload
|
||||
assert payload["outputs"]["loss"]["numel"] == 1
|
||||
|
||||
|
||||
def test_deterministic_forward_artifacts_infers_image_keys_without_dataset_meta(tmp_path):
|
||||
class _ImagePolicy(torch.nn.Module):
|
||||
def forward(self, batch):
|
||||
image = batch["observation.images.front"]
|
||||
assert image.dtype == torch.float32
|
||||
assert torch.all((image >= 0.0) & (image <= 1.0))
|
||||
return image.sum(), {"image": image}
|
||||
|
||||
dataset = [{"observation.images.front": torch.tensor([[[0, 255]]], dtype=torch.uint8)}]
|
||||
|
||||
mp.write_deterministic_forward_artifacts(
|
||||
policy=_ImagePolicy(),
|
||||
dataset=dataset,
|
||||
batch_size=1,
|
||||
preprocessor=lambda b: b,
|
||||
output_dir=tmp_path,
|
||||
device_type="cpu",
|
||||
)
|
||||
|
||||
payload = json.loads((tmp_path / "deterministic_forward.json").read_text())
|
||||
assert payload["outputs"]["loss"]["numel"] == 1
|
||||
assert payload["outputs"]["output_dict"]["image"]["dtype"] == "torch.float32"
|
||||
|
||||
|
||||
def test_training_profiler_section_records_forward_backward_optimizer(tmp_path):
|
||||
profiler = mp.TrainingProfiler(mode="summary", output_dir=tmp_path, device=torch.device("cpu"))
|
||||
profiler.start()
|
||||
for _ in range(3):
|
||||
with profiler.section("forward"):
|
||||
pass
|
||||
with profiler.section("backward"):
|
||||
pass
|
||||
with profiler.section("optimizer"):
|
||||
pass
|
||||
profiler.step(1, argparse.Namespace(update_s=0.5, dataloading_s=0.01))
|
||||
profiler.finalize()
|
||||
|
||||
payload = json.loads((tmp_path / "step_timing_summary.json").read_text())
|
||||
assert payload["forward_s"]["count"] == 3
|
||||
assert payload["backward_s"]["count"] == 3
|
||||
assert payload["optimizer_s"]["count"] == 3
|
||||
assert payload["total_update_s"]["mean"] == 0.5
|
||||
|
||||
|
||||
def test_training_profiler_accepts_metric_like_values(tmp_path):
|
||||
class _MetricLike:
|
||||
def __init__(self, v):
|
||||
self.val = v
|
||||
|
||||
profiler = mp.TrainingProfiler(mode="summary", output_dir=tmp_path, device=torch.device("cpu"))
|
||||
profiler.start()
|
||||
profiler.step(1, argparse.Namespace(update_s=_MetricLike(0.6), dataloading_s=_MetricLike(0.05)))
|
||||
profiler.finalize()
|
||||
|
||||
payload = json.loads((tmp_path / "step_timing_summary.json").read_text())
|
||||
assert payload["total_update_s"]["mean"] == 0.6
|
||||
assert payload["dataloading_s"]["mean"] == 0.05
|
||||
|
||||
|
||||
def test_profiler_device_time_uses_generic_attr_first():
|
||||
class _Event:
|
||||
self_device_time_total = 12.3456
|
||||
|
||||
assert mp._get_profiler_device_time_us(_Event()) == 12.3456
|
||||
232
tests/test_robomme_env.py
Normal file
232
tests/test_robomme_env.py
Normal file
@@ -0,0 +1,232 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit tests for the RoboMME env wrapper and config.
|
||||
|
||||
RoboMME requires Linux + ManiSkill (Vulkan/SAPIEN), so tests that touch the
|
||||
env wrapper mock the ``robomme`` package. Tests that only exercise the
|
||||
dataclass config run without any mocking.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _install_robomme_stub():
|
||||
"""Register a minimal stub for the ``robomme`` package on sys.modules."""
|
||||
stub = ModuleType("robomme")
|
||||
wrapper_stub = ModuleType("robomme.env_record_wrapper")
|
||||
|
||||
class FakeBuilder:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def make_env_for_episode(self, episode_idx: int, max_steps: int):
|
||||
env = MagicMock()
|
||||
obs = {
|
||||
"front_rgb_list": [np.zeros((256, 256, 3), dtype=np.uint8)],
|
||||
"wrist_rgb_list": [np.zeros((256, 256, 3), dtype=np.uint8)],
|
||||
"joint_state_list": [np.zeros(7, dtype=np.float32)],
|
||||
"gripper_state_list": [np.zeros(2, dtype=np.float32)],
|
||||
}
|
||||
env.reset.return_value = (obs, {"status": "ongoing", "task_goal": "pick the cube"})
|
||||
env.step.return_value = (obs, 0.0, False, False, {"status": "ongoing", "task_goal": ""})
|
||||
return env
|
||||
|
||||
wrapper_stub.BenchmarkEnvBuilder = FakeBuilder
|
||||
stub.env_record_wrapper = wrapper_stub
|
||||
sys.modules["robomme"] = stub
|
||||
sys.modules["robomme.env_record_wrapper"] = wrapper_stub
|
||||
|
||||
|
||||
def _uninstall_robomme_stub():
|
||||
sys.modules.pop("robomme", None)
|
||||
sys.modules.pop("robomme.env_record_wrapper", None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config tests (no sim required)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_robomme_env_config_defaults():
|
||||
from lerobot.envs.configs import RoboMMEEnv
|
||||
|
||||
cfg = RoboMMEEnv()
|
||||
assert cfg.task == "PickXtimes"
|
||||
assert cfg.fps == 10
|
||||
assert cfg.episode_length == 300
|
||||
assert cfg.action_space == "joint_angle"
|
||||
assert cfg.dataset_split == "test"
|
||||
assert cfg.task_ids is None
|
||||
|
||||
|
||||
def test_robomme_env_config_type():
|
||||
from lerobot.envs.configs import RoboMMEEnv
|
||||
|
||||
cfg = RoboMMEEnv()
|
||||
assert cfg.type == "robomme"
|
||||
|
||||
|
||||
def test_robomme_features_map():
|
||||
from lerobot.envs.configs import RoboMMEEnv
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
cfg = RoboMMEEnv()
|
||||
assert cfg.features_map[ACTION] == ACTION
|
||||
assert cfg.features_map["pixels/image"] == f"{OBS_IMAGES}.image"
|
||||
assert cfg.features_map["pixels/wrist_image"] == f"{OBS_IMAGES}.wrist_image"
|
||||
assert cfg.features_map["agent_pos"] == OBS_STATE
|
||||
|
||||
|
||||
def test_robomme_features_action_dim_joint_angle():
|
||||
from lerobot.envs.configs import RoboMMEEnv
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
cfg = RoboMMEEnv(action_space="joint_angle")
|
||||
assert cfg.features[ACTION].shape == (8,)
|
||||
|
||||
|
||||
def test_robomme_features_action_dim_ee_pose():
|
||||
"""`ee_pose` uses a 7-D action; __post_init__ sets the correct shape."""
|
||||
from lerobot.envs.configs import RoboMMEEnv
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
cfg = RoboMMEEnv(action_space="ee_pose")
|
||||
assert cfg.features[ACTION].shape == (7,)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Obs conversion (pure Python, no sim)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_convert_obs_list_format():
|
||||
"""_convert_obs takes the last element from list-format obs fields and
|
||||
emits a nested ``pixels`` dict (image, wrist_image) plus ``agent_pos``.
|
||||
|
||||
The nested layout is required so ``preprocess_observation()`` in
|
||||
``envs/utils.py`` maps each camera to ``observation.images.<cam>``.
|
||||
"""
|
||||
_install_robomme_stub()
|
||||
try:
|
||||
from lerobot.envs.robomme import RoboMMEGymEnv
|
||||
|
||||
env = RoboMMEGymEnv.__new__(RoboMMEGymEnv)
|
||||
|
||||
front = np.full((256, 256, 3), 42, dtype=np.uint8)
|
||||
wrist = np.full((256, 256, 3), 7, dtype=np.uint8)
|
||||
joints = np.arange(7, dtype=np.float32)
|
||||
gripper = np.array([0.5, 0.5], dtype=np.float32)
|
||||
|
||||
obs_raw = {
|
||||
"front_rgb_list": [np.zeros_like(front), front],
|
||||
"wrist_rgb_list": [np.zeros_like(wrist), wrist],
|
||||
"joint_state_list": [np.zeros(7, dtype=np.float32), joints],
|
||||
"gripper_state_list": [np.zeros(2, dtype=np.float32), gripper],
|
||||
}
|
||||
|
||||
result = env._convert_obs(obs_raw)
|
||||
np.testing.assert_array_equal(result["pixels"]["image"], front)
|
||||
np.testing.assert_array_equal(result["pixels"]["wrist_image"], wrist)
|
||||
assert result["agent_pos"].shape == (8,)
|
||||
np.testing.assert_array_almost_equal(result["agent_pos"][:7], joints)
|
||||
assert result["agent_pos"][7] == gripper[0]
|
||||
finally:
|
||||
_uninstall_robomme_stub()
|
||||
|
||||
|
||||
def test_convert_obs_array_format():
|
||||
"""_convert_obs also handles non-list (direct array) obs."""
|
||||
_install_robomme_stub()
|
||||
try:
|
||||
from lerobot.envs.robomme import RoboMMEGymEnv
|
||||
|
||||
env = RoboMMEGymEnv.__new__(RoboMMEGymEnv)
|
||||
|
||||
front = np.zeros((256, 256, 3), dtype=np.uint8)
|
||||
obs_raw = {
|
||||
"front_rgb_list": front,
|
||||
"wrist_rgb_list": front,
|
||||
"joint_state_list": np.zeros(7, dtype=np.float32),
|
||||
"gripper_state_list": np.zeros(2, dtype=np.float32),
|
||||
}
|
||||
result = env._convert_obs(obs_raw)
|
||||
assert result["pixels"]["image"].shape == (256, 256, 3)
|
||||
assert result["pixels"]["wrist_image"].shape == (256, 256, 3)
|
||||
assert result["agent_pos"].shape == (8,)
|
||||
finally:
|
||||
_uninstall_robomme_stub()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# create_robomme_envs (mocked sim)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_robomme_envs_returns_correct_structure():
|
||||
"""Single task -> {task_name: {task_id: VectorEnv}} with one entry per task_id."""
|
||||
_install_robomme_stub()
|
||||
try:
|
||||
from lerobot.envs.robomme import create_robomme_envs
|
||||
|
||||
env_cls = MagicMock(return_value=MagicMock())
|
||||
result = create_robomme_envs(
|
||||
task="PickXtimes",
|
||||
n_envs=1,
|
||||
task_ids=[0, 1],
|
||||
env_cls=env_cls,
|
||||
)
|
||||
|
||||
assert "PickXtimes" in result
|
||||
assert 0 in result["PickXtimes"]
|
||||
assert 1 in result["PickXtimes"]
|
||||
assert env_cls.call_count == 2
|
||||
finally:
|
||||
_uninstall_robomme_stub()
|
||||
|
||||
|
||||
def test_create_robomme_envs_multi_task():
|
||||
"""Comma-separated task list produces one suite per task."""
|
||||
_install_robomme_stub()
|
||||
try:
|
||||
from lerobot.envs.robomme import create_robomme_envs
|
||||
|
||||
env_cls = MagicMock(return_value=MagicMock())
|
||||
result = create_robomme_envs(
|
||||
task="PickXtimes,BinFill,StopCube",
|
||||
n_envs=1,
|
||||
env_cls=env_cls,
|
||||
)
|
||||
|
||||
assert set(result.keys()) == {"PickXtimes", "BinFill", "StopCube"}
|
||||
finally:
|
||||
_uninstall_robomme_stub()
|
||||
|
||||
|
||||
def test_create_robomme_envs_raises_on_invalid_env_cls():
|
||||
_install_robomme_stub()
|
||||
try:
|
||||
import pytest
|
||||
|
||||
from lerobot.envs.robomme import create_robomme_envs
|
||||
|
||||
with pytest.raises(ValueError, match="env_cls must be a callable"):
|
||||
create_robomme_envs(task="PickXtimes", n_envs=1, env_cls=None)
|
||||
finally:
|
||||
_uninstall_robomme_stub()
|
||||
@@ -1,338 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Minimal tests for the rollout module's public API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import smoke tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_rollout_top_level_imports():
|
||||
import lerobot.rollout
|
||||
|
||||
for name in lerobot.rollout.__all__:
|
||||
assert hasattr(lerobot.rollout, name), f"Missing export: {name}"
|
||||
|
||||
|
||||
def test_inference_submodule_imports():
|
||||
import lerobot.rollout.inference
|
||||
|
||||
for name in lerobot.rollout.inference.__all__:
|
||||
assert hasattr(lerobot.rollout.inference, name), f"Missing export: {name}"
|
||||
|
||||
|
||||
def test_strategies_submodule_imports():
|
||||
import lerobot.rollout.strategies
|
||||
|
||||
for name in lerobot.rollout.strategies.__all__:
|
||||
assert hasattr(lerobot.rollout.strategies, name), f"Missing export: {name}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_strategy_config_types():
|
||||
from lerobot.rollout import (
|
||||
BaseStrategyConfig,
|
||||
DAggerStrategyConfig,
|
||||
HighlightStrategyConfig,
|
||||
SentryStrategyConfig,
|
||||
)
|
||||
|
||||
assert BaseStrategyConfig().type == "base"
|
||||
assert SentryStrategyConfig().type == "sentry"
|
||||
assert HighlightStrategyConfig().type == "highlight"
|
||||
assert DAggerStrategyConfig().type == "dagger"
|
||||
|
||||
|
||||
def test_dagger_config_invalid_input_device():
|
||||
from lerobot.rollout import DAggerStrategyConfig
|
||||
|
||||
with pytest.raises(ValueError, match="input_device must be 'keyboard' or 'pedal'"):
|
||||
DAggerStrategyConfig(input_device="joystick")
|
||||
|
||||
|
||||
def test_dagger_config_defaults():
|
||||
from lerobot.rollout import DAggerStrategyConfig
|
||||
|
||||
cfg = DAggerStrategyConfig()
|
||||
assert cfg.num_episodes is None
|
||||
assert cfg.record_autonomous is False
|
||||
assert cfg.input_device == "keyboard"
|
||||
|
||||
|
||||
def test_inference_config_types():
|
||||
from lerobot.rollout.inference import RTCInferenceConfig, SyncInferenceConfig
|
||||
|
||||
assert SyncInferenceConfig().type == "sync"
|
||||
|
||||
rtc = RTCInferenceConfig()
|
||||
assert rtc.type == "rtc"
|
||||
assert rtc.queue_threshold == 30
|
||||
assert rtc.rtc is not None
|
||||
|
||||
|
||||
def test_sentry_config_defaults():
|
||||
from lerobot.rollout import SentryStrategyConfig
|
||||
|
||||
cfg = SentryStrategyConfig()
|
||||
assert cfg.upload_every_n_episodes == 5
|
||||
assert cfg.target_video_file_size_mb is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RolloutRingBuffer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_ring_buffer_append_and_eviction():
|
||||
from lerobot.rollout import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=0.5, max_memory_mb=100.0, fps=10.0)
|
||||
# max_frames = 5
|
||||
for i in range(8):
|
||||
buf.append({"val": i})
|
||||
assert len(buf) == 5
|
||||
|
||||
|
||||
def test_ring_buffer_drain():
|
||||
from lerobot.rollout import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||
for i in range(3):
|
||||
buf.append({"val": i})
|
||||
frames = buf.drain()
|
||||
assert len(frames) == 3
|
||||
assert len(buf) == 0
|
||||
assert buf.estimated_bytes == 0
|
||||
|
||||
|
||||
def test_ring_buffer_clear():
|
||||
from lerobot.rollout import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||
buf.append({"val": 1})
|
||||
buf.clear()
|
||||
assert len(buf) == 0
|
||||
assert buf.estimated_bytes == 0
|
||||
|
||||
|
||||
def test_ring_buffer_tensor_bytes():
|
||||
from lerobot.rollout import RolloutRingBuffer
|
||||
|
||||
buf = RolloutRingBuffer(max_seconds=1.0, max_memory_mb=100.0, fps=10.0)
|
||||
t = torch.zeros(100, dtype=torch.float32) # 400 bytes
|
||||
buf.append({"tensor": t})
|
||||
assert buf.estimated_bytes >= 400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ThreadSafeRobot
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_thread_safe_robot_delegates():
|
||||
from lerobot.rollout import ThreadSafeRobot
|
||||
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
||||
|
||||
robot = MockRobot(MockRobotConfig(n_motors=3))
|
||||
robot.connect()
|
||||
wrapper = ThreadSafeRobot(robot)
|
||||
|
||||
obs = wrapper.get_observation()
|
||||
assert "motor_1.pos" in obs
|
||||
assert "motor_2.pos" in obs
|
||||
assert "motor_3.pos" in obs
|
||||
|
||||
action = {"motor_1.pos": 0.0, "motor_2.pos": 1.0, "motor_3.pos": 2.0}
|
||||
result = wrapper.send_action(action)
|
||||
assert result == action
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
def test_thread_safe_robot_properties():
|
||||
from lerobot.rollout import ThreadSafeRobot
|
||||
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
||||
|
||||
robot = MockRobot(MockRobotConfig(n_motors=3))
|
||||
robot.connect()
|
||||
wrapper = ThreadSafeRobot(robot)
|
||||
|
||||
assert wrapper.name == "mock_robot"
|
||||
assert "motor_1.pos" in wrapper.observation_features
|
||||
assert "motor_1.pos" in wrapper.action_features
|
||||
assert wrapper.is_connected is True
|
||||
assert wrapper.inner is robot
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Strategy factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_strategy_dispatches():
|
||||
from lerobot.rollout import BaseStrategyConfig, DAggerStrategyConfig, SentryStrategyConfig
|
||||
from lerobot.rollout.strategies import BaseStrategy, DAggerStrategy, SentryStrategy, create_strategy
|
||||
|
||||
assert isinstance(create_strategy(BaseStrategyConfig()), BaseStrategy)
|
||||
assert isinstance(create_strategy(SentryStrategyConfig()), SentryStrategy)
|
||||
assert isinstance(create_strategy(DAggerStrategyConfig()), DAggerStrategy)
|
||||
|
||||
|
||||
def test_create_strategy_unknown_raises():
|
||||
from lerobot.rollout.strategies import create_strategy
|
||||
|
||||
cfg = MagicMock()
|
||||
cfg.type = "bogus"
|
||||
with pytest.raises(ValueError, match="Unknown strategy type"):
|
||||
create_strategy(cfg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inference factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_inference_engine_sync():
|
||||
from lerobot.rollout.inference import SyncInferenceConfig, SyncInferenceEngine, create_inference_engine
|
||||
|
||||
engine = create_inference_engine(
|
||||
SyncInferenceConfig(),
|
||||
policy=MagicMock(),
|
||||
preprocessor=MagicMock(),
|
||||
postprocessor=MagicMock(),
|
||||
robot_wrapper=MagicMock(robot_type="mock"),
|
||||
hw_features={},
|
||||
dataset_features={},
|
||||
ordered_action_keys=["k"],
|
||||
task="test",
|
||||
fps=30.0,
|
||||
device="cpu",
|
||||
)
|
||||
assert isinstance(engine, SyncInferenceEngine)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_estimate_max_episode_seconds_no_video():
|
||||
from lerobot.rollout.strategies import estimate_max_episode_seconds
|
||||
|
||||
assert estimate_max_episode_seconds({}, fps=30.0) == 600.0
|
||||
|
||||
|
||||
def test_estimate_max_episode_seconds_with_video():
|
||||
from lerobot.rollout.strategies import estimate_max_episode_seconds
|
||||
|
||||
features = {"cam": {"dtype": "video", "shape": (3, 480, 640)}}
|
||||
result = estimate_max_episode_seconds(features, fps=30.0)
|
||||
assert result > 0
|
||||
# With a real camera, duration should differ from the fallback
|
||||
assert result != 600.0
|
||||
|
||||
|
||||
def test_safe_push_to_hub():
|
||||
from lerobot.rollout.strategies import safe_push_to_hub
|
||||
|
||||
ds = MagicMock()
|
||||
ds.num_episodes = 0
|
||||
assert safe_push_to_hub(ds) is False
|
||||
ds.push_to_hub.assert_not_called()
|
||||
|
||||
ds.num_episodes = 5
|
||||
assert safe_push_to_hub(ds, tags=["test"]) is True
|
||||
ds.push_to_hub.assert_called_once_with(tags=["test"], private=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DAgger state machine
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_dagger_full_transition_cycle():
|
||||
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
|
||||
|
||||
events = DAggerEvents()
|
||||
assert events.phase == DAggerPhase.AUTONOMOUS
|
||||
|
||||
# AUTONOMOUS -> PAUSED
|
||||
events.request_transition("pause_resume")
|
||||
old, new = events.consume_transition()
|
||||
assert (old, new) == (DAggerPhase.AUTONOMOUS, DAggerPhase.PAUSED)
|
||||
|
||||
# PAUSED -> CORRECTING
|
||||
events.request_transition("correction")
|
||||
old, new = events.consume_transition()
|
||||
assert (old, new) == (DAggerPhase.PAUSED, DAggerPhase.CORRECTING)
|
||||
|
||||
# CORRECTING -> PAUSED
|
||||
events.request_transition("correction")
|
||||
old, new = events.consume_transition()
|
||||
assert (old, new) == (DAggerPhase.CORRECTING, DAggerPhase.PAUSED)
|
||||
|
||||
# PAUSED -> AUTONOMOUS
|
||||
events.request_transition("pause_resume")
|
||||
old, new = events.consume_transition()
|
||||
assert (old, new) == (DAggerPhase.PAUSED, DAggerPhase.AUTONOMOUS)
|
||||
|
||||
|
||||
def test_dagger_invalid_transition_ignored():
|
||||
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
|
||||
|
||||
events = DAggerEvents()
|
||||
events.request_transition("correction") # Not valid from AUTONOMOUS
|
||||
assert events.consume_transition() is None
|
||||
assert events.phase == DAggerPhase.AUTONOMOUS
|
||||
|
||||
|
||||
def test_dagger_events_reset():
|
||||
from lerobot.rollout.strategies import DAggerEvents, DAggerPhase
|
||||
|
||||
events = DAggerEvents()
|
||||
events.request_transition("pause_resume")
|
||||
events.consume_transition() # -> PAUSED
|
||||
events.upload_requested.set()
|
||||
events.reset()
|
||||
assert events.phase == DAggerPhase.AUTONOMOUS
|
||||
assert not events.upload_requested.is_set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_rollout_context_fields():
|
||||
from lerobot.rollout import RolloutContext
|
||||
|
||||
field_names = {f.name for f in dataclasses.fields(RolloutContext)}
|
||||
assert field_names == {"runtime", "hardware", "policy", "processors", "data"}
|
||||
@@ -24,7 +24,7 @@ import pytest
|
||||
|
||||
pytest.importorskip("grpc")
|
||||
|
||||
from lerobot.utils.process import ProcessSignalHandler # noqa: E402
|
||||
from lerobot.rl.process import ProcessSignalHandler # noqa: E402
|
||||
|
||||
|
||||
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test
|
||||
|
||||
Reference in New Issue
Block a user