Compare commits

..

69 Commits

Author SHA1 Message Date
Michel Aractingi
ea87324725 enhance doc and add images 2025-11-19 10:03:15 +01:00
Michel Aractingi
611159f8bb add docs for rtc 2025-11-19 10:03:13 +01:00
Eugene Mironov
b7b0ac2456 fixup! Fix tests 2025-11-19 12:04:00 +07:00
Eugene Mironov
59a52e557c Fix tests 2025-11-19 03:10:27 +07:00
Eugene Mironov
8008dbb02c Update images 2025-11-19 03:02:26 +07:00
Eugene Mironov
045f9c02f7 Extract simulator logic from eval_with real robot and add proper headers to files 2025-11-18 21:30:02 +07:00
Eugene Mironov
8eb28fd653 fixup! fixup! Fixup eval with real robot 2025-11-18 21:30:02 +07:00
Eugene Mironov
32f4336467 fixup! Fixup eval with real robot 2025-11-18 21:30:02 +07:00
Eugene Mironov
eb29f12ce2 Fixup eval with real robot 2025-11-18 21:30:02 +07:00
Eugene Mironov
9a38c5f4d2 fixup! Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization 2025-11-18 21:30:02 +07:00
Eugene Mironov
5ff66e498f Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-18 21:30:02 +07:00
Eugene Mironov
dfa1e76082 Fix SmolVLA init_rtc_processor to use getattr instead of direct model access
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-18 21:30:02 +07:00
Eugene Mironov
5fd1d8bce9 Fix PI0.5 init_rtc_processor to use getattr instead of direct model access
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-18 21:30:02 +07:00
Eugene Mironov
4bacf70782 Add RTC initialization tests without config for PI0.5 and SmolVLA
Add test_pi05_rtc_initialization_without_rtc_config and
test_smolvla_rtc_initialization_without_rtc_config to verify that
policies can initialize without RTC config and that _rtc_enabled()
returns False in this case.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-18 21:30:02 +07:00
Eugene Mironov
8858e0cbf1 fixup! fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled() 2025-11-18 21:30:02 +07:00
Eugene Mironov
43e631122c fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled() 2025-11-18 21:30:02 +07:00
Eugene Mironov
77fb71a903 Fix test to use _rtc_enabled() instead of is_rtc_enabled()
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-18 21:30:02 +07:00
Eugene Mironov
b4f67373e9 fixup! Add one more test 2025-11-18 21:30:02 +07:00
Eugene Mironov
7e820bc1e3 Add one more test 2025-11-18 21:30:02 +07:00
Eugene Mironov
691503d099 fixup! fixup! Add tests for flow matching models with RTC 2025-11-18 21:30:02 +07:00
Eugene Mironov
a14e8a65cd fixup! Add tests for flow matching models with RTC 2025-11-18 21:30:02 +07:00
Eugene Mironov
a59ebab66b Add tests for flow matching models with RTC 2025-11-18 21:30:02 +07:00
Eugene Mironov
60e1a0de0f Add tests for modeling_rtc 2025-11-18 21:30:02 +07:00
Eugene Mironov
da92b0169e Fix tests 2025-11-18 21:30:02 +07:00
Eugene Mironov
36dc58d05e Silent validation 2025-11-18 21:30:02 +07:00
Eugene Mironov
dd39d7a037 Update README 2025-11-18 21:30:02 +07:00
Eugene Mironov
70d5ca387e Add validatio at the end 2025-11-18 21:30:02 +07:00
Eugene Mironov
043432254e Add more tests 2025-11-18 21:30:02 +07:00
Eugene Mironov
7185a5350e Small fixes 2025-11-18 21:30:02 +07:00
Eugene Mironov
e5c2a0a892 Add workable flow 2025-11-18 21:30:02 +07:00
Eugene Mironov
a230e7424d fixup! fixup! fixup! fixup! fixup! Turn off compilation for pi0/pi05 2025-11-18 21:30:02 +07:00
Eugene Mironov
3c484a77f6 fixup! fixup! fixup! fixup! Turn off compilation for pi0/pi05 2025-11-18 21:30:02 +07:00
Eugene Mironov
dd0bf8a86e fixup! fixup! fixup! Turn off compilation for pi0/pi05 2025-11-18 21:30:02 +07:00
Eugene Mironov
0da9976c5f fixup! fixup! Turn off compilation for pi0/pi05 2025-11-18 21:30:02 +07:00
Eugene Mironov
755ba419f6 fixup! Turn off compilation for pi0/pi05 2025-11-18 21:30:02 +07:00
Eugene Mironov
2dd7c2a7ea Turn off compilation for pi0/pi05 2025-11-18 21:30:02 +07:00
Eugene Mironov
07550ff0ef fixup! Pi0 eval dataset 2025-11-18 21:30:02 +07:00
Eugene Mironov
577ab57bab Pi0 eval dataset 2025-11-18 21:30:02 +07:00
Eugene Mironov
6684c68612 Pi0 2025-11-18 21:30:02 +07:00
Eugene Mironov
687484a864 Add RTC to PI0 2025-11-18 21:30:02 +07:00
Eugene Mironov
4739ef9da3 Fix compilation 2025-11-18 21:30:02 +07:00
Eugene Mironov
d9e72662c1 Debug 2025-11-18 21:30:02 +07:00
Eugene Mironov
9354d7ef10 Experiemnt with late detach 2025-11-18 21:30:02 +07:00
Eugene Mironov
16127642d4 fixup! Add matplotliv to dev 2025-11-18 21:30:02 +07:00
Eugene Mironov
495176f252 Add matplotliv to dev 2025-11-18 21:30:02 +07:00
Eugene Mironov
6aa940346d delete policies 2025-11-18 21:30:02 +07:00
Eugene Mironov
6fdee95923 Add torch compilation for eval_dataset 2025-11-18 21:30:02 +07:00
Eugene Mironov
c5b246f57c Drop not required methods 2025-11-18 21:30:02 +07:00
Eugene Mironov
3d3cfcf751 Fix tests 2025-11-18 21:30:02 +07:00
Eugene Mironov
a29e8a6737 Add tests for tracker 2025-11-18 21:30:02 +07:00
Eugene Mironov
e758703f9a Right kwargs for the policy 2025-11-18 21:30:02 +07:00
Eugene Mironov
64c6b89c40 Fix traacking 2025-11-18 21:30:02 +07:00
Eugene Mironov
ab5cae6547 fixup! fixup! fixup! Improve visualization: separate correction plot and fix axis scaling 2025-11-18 21:30:02 +07:00
Eugene Mironov
b5ff2b38df fixup! fixup! Improve visualization: separate correction plot and fix axis scaling 2025-11-18 21:30:02 +07:00
Eugene Mironov
7dae02cec1 fixup! Improve visualization: separate correction plot and fix axis scaling 2025-11-18 21:30:02 +07:00
Eugene Mironov
b54042a98f Improve visualization: separate correction plot and fix axis scaling
Changes:
- Create separate figure for correction data instead of overlaying on v_t
- Add _rescale_axes helper method to properly scale all axes
- Add 10% margin to y-axis for better visualization
- Fix v_t chart vertical compression issue

Benefits:
- Clearer v_t plot without correction overlay
- Better axis scaling with proper margins
- Separate correction figure for focused analysis
- Improved readability of all denoising visualizations

Output files:
- denoising_xt_comparison.png (x_t trajectories)
- denoising_vt_comparison.png (v_t velocity - now cleaner)
- denoising_correction_comparison.png (NEW - separate corrections)
- denoising_x1t_comparison.png (x1_t state with error)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
2025-11-18 21:30:02 +07:00
Eugene Mironov
0385ccdd05 fixup! Refactor plotting loging 2025-11-18 21:30:02 +07:00
Eugene Mironov
bd85ea905f Refactor plotting loging 2025-11-18 21:30:02 +07:00
Eugene Mironov
a3d32cf123 Move plotting logic from modeling_smolvla to eval_dataset script
Refactor to improve separation of concerns:

modeling_smolvla.py changes:
- Remove all plotting logic from sample_actions method
- Remove viz_xt_axs, viz_vt_axs, viz_x1t_axs parameters
- Remove matplotlib and RTCDebugVisualizer imports
- Remove viz_fig, viz_axs, denoise_step_counter instance variables
- Simplify denoising loop to only track data in rtc_processor

eval_dataset.py changes:
- Add _plot_denoising_steps_from_tracker helper method
- Retrieve debug steps from tracker after inference
- Plot x_t, v_t, x1_t, correction, and error from tracker data
- Enable debug tracking (cfg.rtc.debug = True) for visualization
- Remove viz axes parameters from predict_action_chunk calls

modeling_rtc.py changes:
- Remove v_t from track() call (handled by user change)

Benefits:
- Cleaner modeling code focused on inference
- Evaluation script owns all visualization logic
- Better separation of concerns
- Tracker is single source of truth for debug data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
2025-11-18 21:30:02 +07:00
Eugene Mironov
4575ebdfa5 Refactor SmolVLA plotting to use tracker data instead of local variables
Remove local tracking variables (correction, x1_t, error) from the
denoising loop and instead retrieve plotting data from the RTC tracker
after each denoise step. This makes the code cleaner and uses the
tracker as the single source of truth for debug/visualization data.

Changes:
- Remove initialization of correction, x1_t, error before denoising loop
- After each Euler step, retrieve most recent debug step from tracker
- Extract correction, x1_t, err from debug step for plotting
- Update tracking condition to use is_debug_enabled() method

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
2025-11-18 21:30:02 +07:00
Eugene Mironov
dff5e8871c Fix logging buffering and enable tracking when RTC config provided
- Add force=True to logging.basicConfig to override existing configuration
- Enable line buffering for stdout/stderr for real-time log output
- Modify init_rtc_processor to create processor when rtc_config exists
  even if RTC is disabled, allowing tracking of denoising data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
2025-11-18 21:30:02 +07:00
Eugene Mironov
c066eb3a13 fixup! Use output_dir for saving all evaluation images 2025-11-18 21:30:02 +07:00
Eugene Mironov
e8dd5343ab Use output_dir for saving all evaluation images
Update eval_dataset.py to save all comparison images to the
configured output_dir instead of the current directory. This provides
better organization and allows users to specify where outputs should be
saved.

Changes:
- Add os import at top level
- Create output_dir at start of run_evaluation()
- Save all comparison images to output_dir
- Remove duplicate os imports
- Update init_rtc_processor() docstring to be more concise

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-18 21:30:02 +07:00
Eugene Mironov
01f76e94a3 Rename track_debug method to track
Simplify the method name from track_debug to just track for better
readability and consistency. The method already has clear documentation
about its debug tracking purpose.

Changes:
- Rename RTCProcessor.track_debug() to track()
- Update all call sites in modeling_smolvla.py and modeling_rtc.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-18 21:30:02 +07:00
Eugene Mironov
70548e55f0 Refactor RTC enabled checks to use _rtc_enabled helper
Add _rtc_enabled() helper method in VLAFlowMatching class to simplify
and clean up RTC enabled checks throughout the code. This reduces
code duplication and improves readability.

Changes:
- Add _rtc_enabled() method in VLAFlowMatching
- Replace verbose rtc_config checks with _rtc_enabled() calls
- Maintain exact same functionality with cleaner code

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-18 21:30:02 +07:00
Eugene Mironov
455d347b49 Add RTCConfig field to SmolVLAConfig
Add rtc_config as an optional field in SmolVLAConfig to properly
support Real-Time Chunking configuration. This replaces the previous
getattr() workarounds with direct attribute access, making the code
cleaner and more maintainable.

Changes:
- Import RTCConfig in configuration_smolvla.py
- Add rtc_config: RTCConfig | None = None field
- Revert getattr() calls to direct attribute access in modeling_smolvla.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-18 21:30:02 +07:00
Eugene Mironov
c835f03478 fixup! Fix rtc_config attribute access in SmolVLA 2025-11-18 21:30:02 +07:00
Eugene Mironov
48849c543d Fix rtc_config attribute access in SmolVLA
Use getattr() to safely check for rtc_config attribute existence
instead of direct attribute access. This fixes AttributeError when
loading policies without rtc_config in their config.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-18 21:30:02 +07:00
Eugene Mironov
2afe107583 Add Real-Time Chunking (RTC) support for flow matching models
Implement Real-Time Chunking (RTC) for action chunking policies using flow
matching denoising. RTC enables smooth action transitions between consecutive
chunks by using prefix guidance during denoising.

Key features:
- RTCProcessor class with denoise_step method for RTC guidance
- Tracker system for debug tracking using time-based dictionary storage
- RTCDebugVisualizer with comprehensive visualization utilities
- Integration with SmolVLA policy for flow matching models
- Support for multiple prefix attention schedules (ZEROS, ONES, LINEAR, EXP)
- Configurable execution horizon and max guidance weight
- Example scripts for dataset evaluation and real-time control

Technical details:
- Uses autograd-based gradient computation for RTC corrections
- Time-based tracking eliminates duplicate step issues
- Proxy methods in RTCProcessor for cleaner API
- Full integration with LeRobot's policy and dataset systems

Files added/modified:
- src/lerobot/configs/types.py: Add RTCAttentionSchedule enum
- src/lerobot/policies/rtc/: Core RTC implementation
  - configuration_rtc.py: RTC configuration
  - modeling_rtc.py: RTCProcessor with denoise_step
  - debug_handler.py: Tracker for debug information
  - debug_visualizer.py: Visualization utilities
- src/lerobot/policies/smolvla/modeling_smolvla.py: RTC integration
- examples/rtc/: Example scripts and evaluation tools

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-18 21:30:02 +07:00
111 changed files with 2475 additions and 13403 deletions

View File

@@ -31,8 +31,7 @@ jobs:
name: Upload Preview and Comment
if: >
github.event.workflow_run.event == 'pull_request' &&
github.event.workflow_run.conclusion == 'success' &&
github.repository == 'huggingface/lerobot'
github.event.workflow_run.conclusion == 'success'
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
with:
package_name: lerobot

View File

@@ -42,9 +42,7 @@ jobs:
# This job builds and deploys the official documentation.
build_main_docs:
name: Build Main Docs
if: >
(github.event_name == 'push' || github.event_name == 'workflow_dispatch') &&
github.repository == 'huggingface/lerobot'
if: github.event_name == 'push' || github.event_name == 'workflow_dispatch'
permissions:
contents: read
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
@@ -60,7 +58,7 @@ jobs:
# The result of this job triggers the 'Upload PR Documentation' workflow.
build_pr_docs:
name: Build PR Docs
if: github.event_name == 'pull_request' && github.repository == 'huggingface/lerobot'
if: github.event_name == 'pull_request'
permissions:
contents: read
pull-requests: write

View File

@@ -45,6 +45,7 @@ permissions:
env:
UV_VERSION: "0.8.0"
PYTHON_VERSION: "3.10"
DOCKER_IMAGE_NAME: huggingface/lerobot-gpu
# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
concurrency:
@@ -59,19 +60,12 @@ jobs:
runs-on: ubuntu-latest
env:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
lfs: true
# NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
# (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
- name: Setup /mnt storage
run: sudo chown -R $USER:$USER /mnt
# TODO(Steven): Evaluate the need of these dependencies
- name: Install apt dependencies
run: |

View File

@@ -58,19 +58,12 @@ jobs:
github.event_name == 'workflow_dispatch'
env:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
steps:
- uses: actions/checkout@v4
with:
lfs: true
persist-credentials: false
# NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
# (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
- name: Setup /mnt storage
run: sudo chown -R $USER:$USER /mnt
- name: Install apt dependencies
run: |
sudo apt-get update && sudo apt-get install -y build-essential \

View File

@@ -43,7 +43,6 @@ jobs:
name: Build CPU Docker for Nightly
runs-on:
group: aws-general-8-plus
if: github.repository == 'huggingface/lerobot'
outputs:
image_tag: ${{ env.DOCKER_IMAGE_NAME_CPU }}
steps:
@@ -78,7 +77,6 @@ jobs:
name: Build GPU Docker for Nightly
runs-on:
group: aws-general-8-plus
if: github.repository == 'huggingface/lerobot'
outputs:
image_tag: ${{ env.DOCKER_IMAGE_NAME_GPU }}
steps:

View File

@@ -29,7 +29,6 @@ jobs:
build-and-publish:
name: Build and publish Python distributions
runs-on: ubuntu-latest
if: github.repository == 'huggingface/lerobot'
outputs:
version: ${{ steps.extract_info.outputs.tag_version }}
permissions:

View File

@@ -45,7 +45,6 @@ jobs:
stale:
name: Close Stale Issues and PRs
runs-on: ubuntu-latest
if: github.repository == 'huggingface/lerobot'
permissions:
actions: write
contents: write # only for delete-branch option

View File

@@ -43,22 +43,14 @@ jobs:
full-tests:
name: Full Unbound Tests
runs-on: ubuntu-latest
if: github.repository == 'huggingface/lerobot'
env:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
steps:
- uses: actions/checkout@v4
with:
lfs: true
persist-credentials: false
# NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
# (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
- name: Setup /mnt storage
run: sudo chown -R $USER:$USER /mnt
- name: Install apt dependencies
run: |
sudo apt-get update && sudo apt-get install -y build-essential \

View File

@@ -0,0 +1,94 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from contextlib import ContextDecorator
class TimeBenchmark(ContextDecorator):
"""
Measures execution time using a context manager or decorator.
This class supports both context manager and decorator usage, and is thread-safe for multithreaded
environments.
Args:
print: If True, prints the elapsed time upon exiting the context or completing the function. Defaults
to False.
Examples:
Using as a context manager:
>>> benchmark = TimeBenchmark()
>>> with benchmark:
... time.sleep(1)
>>> print(f"Block took {benchmark.result:.4f} seconds")
Block took approximately 1.0000 seconds
Using with multithreading:
```python
import threading
benchmark = TimeBenchmark()
def context_manager_example():
with benchmark:
time.sleep(0.01)
print(f"Block took {benchmark.result_ms:.2f} milliseconds")
threads = []
for _ in range(3):
t1 = threading.Thread(target=context_manager_example)
threads.append(t1)
for t in threads:
t.start()
for t in threads:
t.join()
```
Expected output:
Block took approximately 10.00 milliseconds
Block took approximately 10.00 milliseconds
Block took approximately 10.00 milliseconds
"""
def __init__(self, print=False):
self.local = threading.local()
self.print_time = print
def __enter__(self):
self.local.start_time = time.perf_counter()
return self
def __exit__(self, *exc):
self.local.end_time = time.perf_counter()
self.local.elapsed_time = self.local.end_time - self.local.start_time
if self.print_time:
print(f"Elapsed time: {self.local.elapsed_time:.4f} seconds")
return False
@property
def result(self):
return getattr(self.local, "elapsed_time", None)
@property
def result_ms(self):
return self.result * 1e3

View File

@@ -0,0 +1,102 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Capture video feed from a camera as raw images."""
import argparse
import datetime as dt
import os
import time
from pathlib import Path
import cv2
import rerun as rr
# see https://rerun.io/docs/howto/visualization/limit-ram
RERUN_MEMORY_LIMIT = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "5%")
def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int, duration: int):
rr.init("lerobot_capture_camera_feed")
rr.spawn(memory_limit=RERUN_MEMORY_LIMIT)
now = dt.datetime.now()
capture_dir = output_dir / f"{now:%Y-%m-%d}" / f"{now:%H-%M-%S}"
if not capture_dir.exists():
capture_dir.mkdir(parents=True, exist_ok=True)
# Opens the default webcam
cap = cv2.VideoCapture(0)
if not cap.isOpened():
print("Error: Could not open video stream.")
return
cap.set(cv2.CAP_PROP_FPS, fps)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
frame_index = 0
start_time = time.time()
while time.time() - start_time < duration:
ret, frame = cap.read()
if not ret:
print("Error: Could not read frame.")
break
rr.log("video/stream", rr.Image(frame), static=True)
cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame)
frame_index += 1
# Release the capture
cap.release()
# TODO(Steven): Add a graceful shutdown via a close() method for the Viewer context, though not currently supported in the Rerun API.
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--output-dir",
type=Path,
default=Path("outputs/cam_capture/"),
help="Directory where the capture images are written. A subfolder named with the current date & time will be created inside it for each capture.",
)
parser.add_argument(
"--fps",
type=int,
default=30,
help="Frames Per Second of the capture.",
)
parser.add_argument(
"--width",
type=int,
default=1280,
help="Width of the captured images.",
)
parser.add_argument(
"--height",
type=int,
default=720,
help="Height of the captured images.",
)
parser.add_argument(
"--duration",
type=int,
default=20,
help="Duration in seconds for which the video stream should be captured.",
)
args = parser.parse_args()
display_and_save_video_stream(**vars(args))

View File

@@ -21,13 +21,11 @@ See the provided README.md or run `python benchmark/video/run_video_benchmark.py
import argparse
import datetime as dt
import itertools
import random
import shutil
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from threading import Lock
import einops
import numpy as np
@@ -37,13 +35,13 @@ import torch
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
from tqdm import tqdm
from benchmarks.video.benchmark import TimeBenchmark
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.video_utils import (
decode_video_frames,
decode_video_frames_torchvision,
encode_video_frames,
)
from lerobot.utils.constants import OBS_IMAGE
from lerobot.utils.utils import TimerManager
BASE_ENCODING = OrderedDict(
[
@@ -88,7 +86,7 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
frames = []
for ts in timestamps:
idx = int(ts * fps)
frame = PIL.Image.open(imgs_dir / f"frame-{idx:06d}.png")
frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
frame = torch.from_numpy(np.array(frame))
frame = frame.type(torch.float32) / 255
frame = einops.rearrange(frame, "h w c -> c h w")
@@ -99,21 +97,21 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
def save_decoded_frames(
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
) -> None:
if save_dir.exists() and len(list(save_dir.glob("frame-*.png"))) == len(timestamps):
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
return
save_dir.mkdir(parents=True, exist_ok=True)
for i, ts in enumerate(timestamps):
idx = int(ts * fps)
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame-{idx:06d}_decoded.png")
shutil.copyfile(imgs_dir / f"frame-{idx:06d}.png", save_dir / f"frame-{idx:06d}_original.png")
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
episode_index = 0
ep_num_images = dataset.meta.episodes["length"][episode_index]
if imgs_dir.exists() and len(list(imgs_dir.glob("frame-*.png"))) == ep_num_images:
if imgs_dir.exists() and len(list(imgs_dir.glob("frame_*.png"))) == ep_num_images:
return
imgs_dir.mkdir(parents=True, exist_ok=True)
@@ -127,7 +125,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
):
img = item[img_keys[0]]
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
if i >= ep_num_images - 1:
break
@@ -151,6 +149,18 @@ def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> lis
return [idx / fps for idx in frame_indexes]
def decode_video_frames(
video_path: str,
timestamps: list[float],
tolerance_s: float,
backend: str,
) -> torch.Tensor:
if backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
else:
raise NotImplementedError(backend)
def benchmark_decoding(
imgs_dir: Path,
video_path: Path,
@@ -162,8 +172,8 @@ def benchmark_decoding(
num_workers: int = 4,
save_frames: bool = False,
) -> dict:
def process_sample(sample: int, lock: Lock):
time_benchmark = TimerManager(log=False)
def process_sample(sample: int):
time_benchmark = TimeBenchmark()
timestamps = sample_timestamps(timestamps_mode, ep_num_images, fps)
num_frames = len(timestamps)
result = {
@@ -172,13 +182,13 @@ def benchmark_decoding(
"mse_values": [],
}
with time_benchmark, lock:
with time_benchmark:
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
result["load_time_video_ms"] = (time_benchmark.last * 1000) / num_frames
result["load_time_video_ms"] = time_benchmark.result_ms / num_frames
with time_benchmark:
original_frames = load_original_frames(imgs_dir, timestamps, fps)
result["load_time_images_ms"] = (time_benchmark.last * 1000) / num_frames
result["load_time_images_ms"] = time_benchmark.result_ms / num_frames
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
for i in range(num_frames):
@@ -205,10 +215,8 @@ def benchmark_decoding(
# A sample is a single set of decoded frames specified by timestamps_mode (e.g. a single frame, 2 frames, etc.).
# For each sample, we record metrics (loading time and quality metrics) which are then averaged over all samples.
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
# Use a single shared lock for all worker threads
shared_lock = Lock()
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(process_sample, i, shared_lock) for i in range(num_samples)]
futures = [executor.submit(process_sample, i) for i in range(num_samples)]
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
result = future.result()
load_times_video_ms.append(result["load_time_video_ms"])
@@ -350,27 +358,24 @@ def main(
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
# We only use the first episode
save_first_episode(imgs_dir, dataset)
for duet in [
dict(zip(encoding_benchmarks.keys(), unique_combination, strict=False))
for unique_combination in itertools.product(*encoding_benchmarks.values())
]:
encoding_cfg = BASE_ENCODING.copy()
encoding_cfg["vcodec"] = video_codec
encoding_cfg["pix_fmt"] = pixel_format
for key, value in duet.items():
for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False):
for value in tqdm(values, desc=f"encodings ({key})", leave=False):
encoding_cfg = BASE_ENCODING.copy()
encoding_cfg["vcodec"] = video_codec
encoding_cfg["pix_fmt"] = pixel_format
encoding_cfg[key] = value
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
benchmark_table += benchmark_encoding_decoding(
dataset,
video_path,
imgs_dir,
encoding_cfg,
decoding_benchmarks,
num_samples,
num_workers,
save_frames,
)
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
benchmark_table += benchmark_encoding_decoding(
dataset,
video_path,
imgs_dir,
encoding_cfg,
decoding_benchmarks,
num_samples,
num_workers,
save_frames,
)
# Save intermediate results
benchmark_df = pd.DataFrame(benchmark_table, columns=headers)
@@ -404,9 +409,9 @@ if __name__ == "__main__":
nargs="*",
default=[
"lerobot/pusht_image",
"lerobot/aloha_mobile_shrimp_image",
"lerobot/paris_street",
"lerobot/kitchen",
"aliberts/aloha_mobile_shrimp_image",
"aliberts/paris_street",
"aliberts/kitchen",
],
help="Datasets repo-ids to test against. First episodes only are used. Must be images.",
)
@@ -414,7 +419,7 @@ if __name__ == "__main__":
"--vcodec",
type=str,
nargs="*",
default=["h264", "hevc", "libsvtav1"],
default=["libx264", "hevc", "libsvtav1"],
help="Video codecs to be tested",
)
parser.add_argument(
@@ -463,7 +468,7 @@ if __name__ == "__main__":
"--backends",
type=str,
nargs="*",
default=["torchcodec", "pyav"],
default=["pyav", "video_reader"],
help="Torchvision decoding backend to be tested.",
)
parser.add_argument(

View File

@@ -9,8 +9,6 @@
title: Imitation Learning for Robots
- local: cameras
title: Cameras
- local: bring_your_own_policies
title: Bring Your Own Policies
- local: integrate_hardware
title: Bring Your Own Hardware
- local: hilserl
@@ -39,8 +37,6 @@
title: π₀.₅ (Pi05)
- local: groot
title: NVIDIA GR00T N1.5
- local: xvla
title: X-VLA
title: "Policies"
- sections:
- local: async
@@ -51,8 +47,8 @@
- sections:
- local: envhub
title: Environments from the Hub
- local: envhub_leisaac
title: Control & Train Robots in Sim (LeIsaac)
- local: il_sim
title: Imitation Learning in Sim
- local: libero
title: Using Libero
- local: metaworld
@@ -67,8 +63,6 @@
title: Implement your own processor
- local: processors_robots_teleop
title: Processors for Robots and Teleoperators
- local: env_processor
title: Environment Processors
title: "Robot Processors"
- sections:
- local: so101
@@ -83,19 +77,11 @@
title: Hope Jr
- local: reachy2
title: Reachy 2
- local: unitree_g1
title: Unitree G1
- local: earthrover_mini_plus
title: Earth Rover Mini
title: "Robots"
- sections:
- local: phone_teleop
title: Phone
title: "Teleoperators"
- sections:
- local: torch_accelerators
title: PyTorch accelerators
title: "Supported Hardware"
- sections:
- local: notebooks
title: Notebooks

View File

@@ -196,7 +196,7 @@ client_cfg = RobotClientConfig(
server_address="localhost:8080",
policy_device="mps",
policy_type="smolvla",
pretrained_name_or_path="<user>/smolvla_async",
pretrained_name_or_path="fracapuano/smolvla_async",
chunk_size_threshold=0.5,
actions_per_chunk=50, # make sure this is less than the max actions of the policy
)
@@ -278,7 +278,7 @@ We found the default values of `actions_per_chunk` and `chunk_size_threshold` to
2. **Adjust your `fps` based on inference latency.** While the server generates a new action chunk, the client is not idle and is stepping through its current action queue. If the two processes happen at fundamentally different speeds, the client might end up with an empty queue. As such, you should reduce your fps if you consistently run out of actions in queue.
3. **Adjust `chunk_size_threshold`**.
- Values closer to `0.0` result in almost sequential behavior. Values closer to `1.0` → send observation every step (more bandwidth, relies on good world-model).
- We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug_visualize_queue_size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.
- We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug-visualize-queue-size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.
<p align="center">
<img
@@ -289,7 +289,7 @@ We found the default values of `actions_per_chunk` and `chunk_size_threshold` to
<p align="center">
<i>
The action queue size is plotted at runtime when the
`--debug_visualize_queue_size` flag is passed, for various levels of
`--debug-visualize-queue-size` flag is passed, for various levels of
`chunk_size_threshold` (`g` in the SmolVLA paper).
</i>
</p>

View File

@@ -1,175 +0,0 @@
# Bring Your Own Policies
This tutorial explains how to integrate your own custom policy implementations into the LeRobot ecosystem, allowing you to leverage all LeRobot tools for training, evaluation, and deployment while using your own algorithms.
## Step 1: Create a Policy Package
Your custom policy should be organized as an installable Python package following LeRobot's plugin conventions.
### Package Structure
Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:
```bash
lerobot_policy_my_custom_policy/
├── pyproject.toml
└── src/
└── lerobot_policy_my_custom_policy/
├── __init__.py
├── configuration_my_custom_policy.py
├── modeling_my_custom_policy.py
└── processor_my_custom_policy.py
```
### Package Configuration
Set up your `pyproject.toml`:
```toml
[project]
name = "lerobot_policy_my_custom_policy"
version = "0.1.0"
dependencies = [
# your policy-specific dependencies
]
requires-python = ">= 3.11"
[build-system]
build-backend = # your-build-backend
requires = # your-build-system
```
## Step 2: Define the Policy Configuration
Create a configuration class that inherits from `PreTrainedConfig` and registers your policy type:
```python
# configuration_my_custom_policy.py
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("my_custom_policy")
@dataclass
class MyCustomPolicyConfig(PreTrainedConfig):
"""Configuration class for MyCustomPolicy.
Args:
n_obs_steps: Number of observation steps to use as input
horizon: Action prediction horizon
n_action_steps: Number of action steps to execute
hidden_dim: Hidden dimension for the policy network
# Add your policy-specific parameters here
"""
# ...PreTrainedConfig fields...
pass
def __post_init__(self):
super().__post_init__()
# Add any validation logic here
def validate_features(self) -> None:
"""Validate input/output feature compatibility."""
# Implement validation logic for your policy's requirements
pass
```
## Step 3: Implement the Policy Class
Create your policy implementation by inheriting from LeRobot's base `PreTrainedPolicy` class:
```python
# modeling_my_custom_policy.py
import torch
import torch.nn as nn
from typing import Dict, Any
from lerobot.policies.pretrained import PreTrainedPolicy
from .configuration_my_custom_policy import MyCustomPolicyConfig
class MyCustomPolicy(PreTrainedPolicy):
config_class = MyCustomPolicyConfig
name = "my_custom_policy"
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: Dict[str, Any] = None):
super().__init__(config, dataset_stats)
...
```
## Step 4: Add Data Processors
Create processor functions:
```python
# processor_my_custom_policy.py
from typing import Dict, Any
import torch
def make_my_custom_policy_pre_post_processors(
config,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Create preprocessing and postprocessing functions for your policy."""
pass # Define your preprocessing and postprocessing logic here
```
## Step 5: Package Initialization
Expose your classes in the package's `__init__.py`:
```python
# __init__.py
"""Custom policy package for LeRobot."""
try:
import lerobot # noqa: F401
except ImportError:
raise ImportError(
"lerobot is not installed. Please install lerobot to use this policy package."
)
from .configuration_my_custom_policy import MyCustomPolicyConfig
from .modeling_my_custom_policy import MyCustomPolicy
from .processor_my_custom_policy import make_my_custom_policy_pre_post_processors
__all__ = [
"MyCustomPolicyConfig",
"MyCustomPolicy",
"make_my_custom_policy_pre_post_processors",
]
```
## Step 6: Installation and Usage
### Install Your Policy Package
```bash
cd lerobot_policy_my_custom_policy
pip install -e .
# Or install from PyPI if published
pip install lerobot_policy_my_custom_policy
```
### Use Your Policy
Once installed, your policy automatically integrates with LeRobot's training and evaluation tools:
```bash
lerobot-train \
--policy.type my_custom_policy \
--env.type pusht \
--steps 200000
```
## Examples and Community Contributions
Check out these example policy implementations:
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) - Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
Share your policy implementations with the community! 🤗

View File

@@ -1,206 +0,0 @@
# EarthRover Mini Plus
The EarthRover Mini Plus is a fully open source mobile robot that connects through the cloud using the Frodobots SDK. This lets you control the robot and record datasets for training AI models.
## What You Need
### Hardware
- EarthRover Mini robot
- Computer with Python 3.10 or newer
- Internet connection
### Setting Up the Frodobots SDK
The robot needs the [Frodobots SDK](https://github.com/Frodobots/earth-rovers-sdk) running on your computer. Here's how:
1. Download and install the SDK:
```bash
git clone https://github.com/Frodobots/earth-rovers-sdk.git
cd earth-rovers-sdk
pip install -r requirements.txt
```
2. Start the SDK:
```bash
hypercorn main:app --reload
```
3. Open your web browser and go to `http://localhost:8000`, then click "Join"
The SDK gives you:
- Live video from front and rear cameras
> [!IMPORTANT]
> The SDK must be running before you can use the robot.
## Install LeRobot
Follow our [Installation Guide](./installation) to install LeRobot.
In addition to the base installation, install the EarthRover Mini dependencies:
```bash
pip install -e .
```
## How It Works
The robot uses the internet to communicate:
- **Movement commands**: Sent through the SDK
- **Camera video**: Received from the SDK
- **Robot info**: Battery, location, speed from the SDK
You don't need to plug anything in - it all works through the SDK.
## Calibration
No calibration needed! The robot is ready to use as soon as the SDK is running.
## Controlling the Robot
You control the robot using your keyboard - just like playing a video game with WASD keys.
### Keyboard Controls
| Key | Action |
| --- | -------------------------------- |
| W | Move forward |
| S | Move backward |
| A | Turn left (with forward motion) |
| D | Turn right (with forward motion) |
| Q | Rotate left in place |
| E | Rotate right in place |
| X | Stop all movement |
| +/= | Increase speed |
| - | Decrease speed |
| ESC | Disconnect |
### Speed Settings
You can adjust how fast the robot moves:
- **Forward/backward speed**: Default is full speed (1.0)
- **Turning speed**: Default is full speed (1.0)
- **Speed changes**: Use +/- keys to adjust by 0.1 each time
### Try It Out
Test driving the robot before recording data:
```python
from lerobot.robots.earthrover_mini_plus import EarthRoverMiniPlus, EarthRoverMiniPlusConfig
from lerobot.teleoperators.keyboard import KeyboardRoverTeleop, KeyboardRoverTeleopConfig
# Initialize robot
robot_config = EarthRoverMiniPlusConfig()
robot = EarthRoverMiniPlus(robot_config)
# Initialize teleoperator
teleop_config = KeyboardRoverTeleopConfig(
linear_speed=1.0,
angular_speed=1.0,
speed_increment=0.1
)
teleop = KeyboardRoverTeleop(teleop_config)
# Connect
robot.connect()
teleop.connect()
# Teleoperate (use keyboard controls)
try:
while True:
action = teleop.get_action()
robot.send_action(action)
except KeyboardInterrupt:
pass
finally:
robot.disconnect()
teleop.disconnect()
```
> [!TIP]
> If you're using a Mac, you might need to give Terminal permission to access your keyboard for teleoperation. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal.
## Recording Data
Once you can drive the robot well, you can start recording data to train AI models. The system records:
- **What you do**: How you move the robot (forward, backward, turning)
- **What the robot sees**:
- Videos from both cameras
- Robot speed and direction
- Battery level and location
- GPS position and signal
- Other sensor data
- **When it happened**: Timestamps for everything
### Setting Up Hugging Face
We use Hugging Face to store your data online. First, log in with your token from [Hugging Face settings](https://huggingface.co/settings/tokens):
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
```
Store your Hugging Face username:
```bash
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```
### Start Recording
Use the standard recording command:
```bash
python src/lerobot/scripts/lerobot_record.py \
--robot.type=earthrover_mini_plus \
--teleop.type=keyboard_rover \
--dataset.repo_id=your_username/dataset_name \
--dataset.num_episodes=2 \
--dataset.fps=10 \
--dataset.single_task="Navigate around obstacles" \
--display_data=true
```
Replace `your_username/dataset_name` with your Hugging Face username and a name for your dataset.
### What Gets Saved
Your dataset includes:
**Your Actions (2 things)**:
- How much you moved forward/backward
- How much you turned left/right
**Robot Observations (12 things)**:
- Front camera video
- Rear camera video
- Current speed
- Battery level
- Which way the robot is facing
- GPS location (latitude, longitude, signal strength)
- Network signal strength
- Vibration level
- Lamp status (on/off)
### Where Your Data Goes
On your computer: `~/.cache/huggingface/lerobot/{repo-id}`
After recording, your data automatically uploads to your Hugging Face page:
```bash
echo https://huggingface.co/datasets/${HF_USER}/earthrover-navigation
```
Your dataset will be tagged with `LeRobot` for community discovery.

View File

@@ -1,418 +0,0 @@
# Environment Processors
Environment processors are a critical layer in LeRobot's data processing architecture that handle **environment-specific** transformations, separate from policy-specific processing. This separation of concerns enables cleaner code, better modularity, and easier experimentation with different environments and policies.
## Why Environment Processors?
When working with different robot environments (LIBERO, MetaWorld, Aloha, etc.), each environment often has unique data formats, coordinate systems, and conventions that need standardization **before** policy processing. Without environment processors, these transformations would be:
1. **Hardcoded in environment code** - Making it difficult to experiment with different state representations
2. **Duplicated across policies** - Each policy would need to handle environment-specific quirks
3. **Mixed with policy logic** - Violating separation of concerns and making debugging harder
Environment processors solve this by providing a **dedicated processing layer** between raw environment observations and policy inputs.
## The Processing Pipeline
Here's how data flows through the complete processing pipeline during evaluation:
```python
# In lerobot_eval.py rollout() function:
# 1. Raw environment observation (numpy arrays, various formats)
raw_observation = env.step(action)
# 2. Convert numpy to torch, normalize images [0,1]
observation = preprocess_observation(raw_observation)
# 3. Add task metadata (for multi-task environments)
observation = add_envs_task(env, observation)
# 4. ENVIRONMENT-SPECIFIC preprocessing (NEW!)
# - Flatten robot states
# - Rotate images to match dataset conventions
# - Handle environment-specific coordinate systems
observation = env_preprocessor(observation)
# 5. POLICY-SPECIFIC preprocessing
# - Normalize with dataset statistics
# - Add batch dimensions
# - Move to GPU
# - Tokenize language instructions
observation = preprocessor(observation)
# 6. Policy inference
action = policy.select_action(observation)
# 7. POLICY-SPECIFIC postprocessing
# - Unnormalize actions
# - Remove batch dimensions
action = postprocessor(action)
# 8. ENVIRONMENT-SPECIFIC postprocessing (NEW!)
# - Convert action formats if needed
# - Apply environment-specific constraints
action_transition = {"action": action}
action_transition = env_postprocessor(action_transition)
action = action_transition["action"]
# 9. Execute in environment
env.step(action)
```
## The Benefits
### 1. **Separation of Concerns**
Environment processors handle transformations specific to the **environment's data format**, while policy processors handle transformations specific to the **model's requirements**.
```python
# ❌ Before: Mixed concerns
class LiberoVLAPolicy:
def preprocess(self, obs):
# Environment-specific: Flatten robot state (shouldn't be in policy!)
state = self._flatten_robot_state(obs["robot_state"])
# Policy-specific: Normalize with dataset stats
state = self.normalizer(state)
return state
# ✅ After: Clear separation
# Environment processor: Handles LIBERO's nested robot state
env_preprocessor = LiberoProcessorStep() # Flattens robot_state
# Policy processor: Handles model requirements
policy_preprocessor = NormalizerProcessorStep(stats=dataset_stats)
```
### 2. **Flexibility and Reusability**
The same policy can work with different environment processors, and the same environment processor can work with different policies:
```python
# Use SmolVLA policy with LIBERO environment
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
smolvla_preprocessor, smolvla_postprocessor = make_pre_post_processors(smolvla_cfg)
# Or use ACT policy with the same LIBERO environment
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
act_preprocessor, act_postprocessor = make_pre_post_processors(act_cfg)
```
### 3. **Easier Experimentation**
Want to try different state representations for LIBERO? Just create a new processor:
```python
# Original: 8D state (pos + quat→axisangle + gripper)
@ProcessorStepRegistry.register("libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
def _process_observation(self, obs):
eef_pos = robot_state["eef"]["pos"] # 3D
eef_axisangle = quat2axisangle(quat) # 3D
gripper = robot_state["gripper"]["qpos"] # 2D
state = torch.cat([eef_pos, eef_axisangle, gripper], dim=-1) # 8D
return state
# Experiment: Add velocity for better control
@ProcessorStepRegistry.register("libero_velocity_processor")
class LiberoVelocityProcessorStep(ObservationProcessorStep):
def _process_observation(self, obs):
# Include velocities for 14D state
eef_pos = robot_state["eef"]["pos"] # 3D
eef_axisangle = quat2axisangle(quat) # 3D
eef_vel = robot_state["eef"]["vel"] # 3D (NEW)
gripper_pos = robot_state["gripper"]["qpos"] # 2D
gripper_vel = robot_state["gripper"]["qvel"] # 3D (NEW)
state = torch.cat([eef_pos, eef_axisangle, eef_vel,
gripper_pos, gripper_vel], dim=-1) # 14D
return state
```
### 4. **Cleaner Environment Code**
Environments expose **all available data** without needing to know what downstream models will use:
```python
# LIBERO environment exposes full robot state
observation = {
"pixels": {"image": img, "image2": img2},
"robot_state": {
"eef": {"pos": ..., "quat": ..., "vel": ..., "mat": ..., "axisangle": ...},
"gripper": {"qpos": ..., "qvel": ...},
"joints": {"pos": ..., "vel": ...}
}
}
# Environment processor decides what to use
# Policy processor handles model-specific transformations
```
## Using Environment Processors
### Factory Function
The `make_env_pre_post_processors` function follows the same pattern as `make_pre_post_processors` for policies:
```python
from lerobot.envs.factory import make_env_pre_post_processors
from lerobot.envs.configs import LiberoEnv, PushtEnv
# For LIBERO: Returns LiberoProcessorStep in preprocessor
libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg)
# For other environments: Returns identity processors (no-op)
pusht_cfg = PushtEnv()
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg)
```
### Implementation in `envs/factory.py`
```python
def make_env_pre_post_processors(
env_cfg: EnvConfig,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
"""
Create preprocessor and postprocessor pipelines for environment observations.
Args:
env_cfg: The configuration of the environment.
Returns:
A tuple containing:
- preprocessor: Pipeline that processes environment observations
- postprocessor: Pipeline that processes environment outputs
"""
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
else:
# For all other environments, return an identity preprocessor
preprocessor = PolicyProcessorPipeline(steps=[])
# Postprocessor is currently identity for all environments
# Future: Could add environment-specific action transformations
postprocessor = PolicyProcessorPipeline(steps=[])
return preprocessor, postprocessor
```
### Integration in Evaluation
In `lerobot_eval.py`, the environment processors are created once and used throughout:
```python
def eval_main(cfg: EvalPipelineConfig):
# Create environment
envs = make_env(cfg.env, n_envs=cfg.eval.batch_size)
# Create policy
policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env)
# Create policy processors
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
)
# Create environment processors (NEW!)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
# Run evaluation with both processor types
eval_policy_all(
envs=envs,
policy=policy,
env_preprocessor=env_preprocessor, # Environment-specific
env_postprocessor=env_postprocessor, # Environment-specific
preprocessor=preprocessor, # Policy-specific
postprocessor=postprocessor, # Policy-specific
n_episodes=cfg.eval.n_episodes,
)
```
## Example: LIBERO Environment Processor
The `LiberoProcessorStep` demonstrates a real-world environment processor:
```python
from lerobot.processor.pipeline import ObservationProcessorStep
@dataclass
@ProcessorStepRegistry.register(name="libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
"""
Processes LIBERO observations into the LeRobot format.
**State Processing:**
- Extracts end-effector position (3D)
- Converts quaternion to axis-angle representation (3D)
- Extracts gripper joint positions (2D)
- Concatenates into 8D state vector
**Image Processing:**
- Rotates images 180° to match HuggingFaceVLA/libero convention
"""
def _process_observation(self, observation):
processed_obs = observation.copy()
# Process images: Flip 180° for camera convention
for key in list(processed_obs.keys()):
if key.startswith("observation.images."):
img = processed_obs[key]
img = torch.flip(img, dims=[2, 3]) # Flip H and W
processed_obs[key] = img
# Process robot_state: Flatten to 8D vector
if "observation.robot_state" in processed_obs:
robot_state = processed_obs.pop("observation.robot_state")
eef_pos = robot_state["eef"]["pos"] # (B, 3)
eef_quat = robot_state["eef"]["quat"] # (B, 4)
gripper_qpos = robot_state["gripper"]["qpos"] # (B, 2)
# Convert quaternion to axis-angle
eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
# Concatenate into single state vector
state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
state = state.float()
processed_obs["observation.state"] = state
return processed_obs
```
### Why These Transformations?
1. **Image Rotation**: The HuggingFaceVLA/libero dataset has images rotated 180° from the raw LIBERO simulator. The processor handles this convention mismatch so policies trained on the dataset work seamlessly.
2. **State Flattening**: The raw LIBERO environment exposes nested dictionaries with all available state information (position, quaternion, velocity, matrix representation, etc.). The processor:
- Selects the relevant components (pos, quat, gripper)
- Converts quaternion to axis-angle (more suitable for learning)
- Flattens to a single 8D vector that policies expect
3. **Flexibility**: The environment still exposes **all** raw data. If you want to try different state representations (e.g., including velocities, using matrix representation instead of axis-angle), you can create a new processor without modifying the environment code.
## Adding Environment Processors for New Environments
To add environment processors for a new environment:
### 1. Create the Processor Step
```python
# In src/lerobot/processor/env_processor.py
@dataclass
@ProcessorStepRegistry.register(name="myenv_processor")
class MyEnvProcessorStep(ObservationProcessorStep):
"""Process observations from MyEnv."""
def _process_observation(self, observation):
processed = observation.copy()
# Your environment-specific transformations
if "myenv.specific.state" in processed:
state = processed.pop("myenv.specific.state")
# Transform to standard format
processed["observation.state"] = self._transform_state(state)
return processed
```
### 2. Update the Factory
```python
# In src/lerobot/envs/factory.py
def make_env_pre_post_processors(env_cfg: EnvConfig):
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
elif isinstance(env_cfg, MyEnvConfig) or "myenv" in env_cfg.type:
preprocessor = PolicyProcessorPipeline(steps=[MyEnvProcessorStep()])
else:
preprocessor = PolicyProcessorPipeline(steps=[])
postprocessor = PolicyProcessorPipeline(steps=[])
return preprocessor, postprocessor
```
### 3. Use in Evaluation
No changes needed! The evaluation script automatically uses the appropriate processor:
```bash
lerobot-eval \
--policy.path=lerobot/my_policy \
--env.type=myenv \ # Automatically uses MyEnvProcessorStep
--eval.n_episodes=10
```
## Future: Environment Postprocessors
Currently, postprocessors are identity (no-op) for all environments. Future use cases include:
### Action Space Transformations
```python
@dataclass
class MyEnvActionPostprocessor(ProcessorStep):
"""Convert policy actions to environment-specific format."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition["action"]
# Example: Convert from Cartesian to joint space
if self.action_space == "joint":
action = self.ik_solver(action)
# Example: Apply environment-specific safety limits
action = torch.clamp(action, self.min_action, self.max_action)
transition["action"] = action
return transition
```
### Coordinate System Conversions
```python
@dataclass
class CoordinateTransformPostprocessor(ProcessorStep):
"""Transform actions between coordinate systems."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition["action"]
# Example: Policy outputs in world frame, env expects base frame
action = self.world_to_base_transform(action)
transition["action"] = action
return transition
```
## Best Practices
1. **Keep environment processors simple**: They should only handle environment-specific data format issues, not complex learning-related transformations.
2. **Use policy processors for model requirements**: Normalization, batching, device placement, and tokenization belong in policy processors.
3. **Expose all data from environments**: Let processors decide what to use rather than hardcoding choices in the environment.
4. **Document conventions**: Clearly document any coordinate system conventions, camera orientations, or data formats that your processor handles.
5. **Test independently**: Environment processors should be testable without loading full policies or environments.
## Summary
Environment processors provide a **clean separation** between environment-specific data transformations and policy-specific model requirements. This architecture:
- ✅ Enables easy experimentation with different state representations
- ✅ Allows policies to work seamlessly across different environments
- ✅ Keeps environment code focused on simulation/hardware interface
- ✅ Makes processor pipelines more maintainable and debuggable
- ✅ Follows the single responsibility principle
The key insight: **Environments define data formats, processors standardize them, policies consume standardized data.** Each layer has a clear, focused responsibility.

View File

@@ -1,301 +0,0 @@
# LeIsaac × LeRobot EnvHub
LeRobot EnvHub now supports **imitation learning in simulation** with LeIsaac.
Spin up everyday manipulation tasks, teleoperate the robot, collect demos, push them to the Hub, and train policies in LeRobot — all in one loop.
[LeIsaac](https://github.com/LightwheelAI/leisaac) integrates with IsaacLab and the SO101 Leader/Follower setup to provide:
- 🕹️ **Teleoperation-first workflows** for data collection
- 📦 **Built-in data conversion** ready for LeRobot training
- 🤖 **Everyday skills** like picking oranges, lifting cubes, cleaning tables, and folding cloth
- ☁️ **Ongoing upgrades** from [LightWheel](https://lightwheel.ai/): cloud simulation, EnvHub support, Sim2Real tooling, and more
Below youll find the currently supported LeIsaac tasks exposed through LeRobot EnvHub.
# Available Environments
The following table lists all available tasks and environments in LeIsaac x LeRobot Envhub. You can also get the latest list of environments by running the following command:
```bash
python scripts/environments/list_envs.py
```
| Task | Environment ID | Task Description | Related Robot |
| :-------------------------------------------------------------------------------------------------------------------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------------------- |
| <video src="https://github.com/user-attachments/assets/466eddff-f720-4f99-94d5-5e123e4c302c" autoplay loop muted playsinline style="max-width: 300px;"></video> | [LeIsaac-SO101-PickOrange-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/pick_orange/pick_orange_env_cfg.py)<br /><br />[LeIsaac-SO101-PickOrange-Direct-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/pick_orange/direct/pick_orange_env.py) | Pick three oranges and put them into the plate, then reset the arm to rest state. | Single-Arm SO101 Follower |
| <video src="https://github.com/user-attachments/assets/1e4eb83a-0b38-40fb-a0b2-ddb0fe201e6d" autoplay loop muted playsinline style="max-width: 300px;"></video> | [LeIsaac-SO101-LiftCube-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/lift_cube/lift_cube_env_cfg.py)<br /><br />[LeIsaac-SO101-LiftCube-Direct-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/lift_cube/direct/lift_cube_env.py) | Lift the red cube up. | Single-Arm SO101 Follower |
| <video src="https://github.com/user-attachments/assets/e49d8f1c-dcc9-412b-a88f-100680d8a45b" autoplay loop muted playsinline style="max-width: 300px;"></video> | [LeIsaac-SO101-CleanToyTable-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/clean_toy_table/clean_toy_table_env_cfg.py)<br /><br />[LeIsaac-SO101-CleanToyTable-BiArm-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/clean_toy_table/clean_toy_table_bi_arm_env_cfg.py)<br /><br />[LeIsaac-SO101-CleanToyTable-BiArm-Direct-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/clean_toy_table/direct/clean_toy_table_bi_arm_env.py) | Pick two letter e objects into the box, and reset the arm to rest state. | Single-Arm SO101 Follower<br /><br />Bi-Arm SO101 Follower |
| <video src="https://github.com/user-attachments/assets/e29a0f8a-9286-4ce6-b45d-342c3d3ba754" autoplay loop muted playsinline style="max-width: 300px;"></video> | [LeIsaac-SO101-FoldCloth-BiArm-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/fold_cloth/fold_cloth_bi_arm_env_cfg.py)<br /><br />[LeIsaac-SO101-FoldCloth-BiArm-Direct-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/fold_cloth/direct/fold_cloth_bi_arm_env.py) | Fold the cloth, and reset the arm to rest state.<br /><br />_Note: Only the DirectEnv support check_success in this task._ | Bi-Arm SO101 Follower |
# Load LeIsaac directly in LeRobot with one line of code
> EnvHub: Share LeIsaac environments through HuggingFace
[EnvHub](https://huggingface.co/docs/lerobot/envhub) is our reproducible environment hub, spin up a packaged simulation with one line, experiment immediately, and publish your own tasks for the community.
LeIsaac offers EnvHub support so you can consume or share tasks with only a few commands.
<video
controls
src="https://github.com/user-attachments/assets/687666f5-ebe0-421d-84a0-eb86116ac5f8"
style={{ width: "100%", maxWidth: "960px", borderRadius: "8px" }}
/>
## How to get started, environment Setup
Run the following commands to setup your code environments:
```bash
# Refer to Getting Started/Installation to install leisaac firstly
conda create -n leisaac_envhub python=3.11
conda activate leisaac_envhub
conda install -c "nvidia/label/cuda-12.8.1" cuda-toolkit
pip install -U torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu128
pip install 'leisaac[isaaclab] @ git+https://github.com/LightwheelAI/leisaac.git#subdirectory=source/leisaac' --extra-index-url https://pypi.nvidia.com
# Install lerobot
pip install lerobot==0.4.1
# Fix numpy version
pip install numpy==1.26.0
```
## Usage Example
EnvHub exposes every LeIsaac-supported task in a uniform interface. The examples below load `so101_pick_orange` and demonstrate a random-action rollout and an interactive teleoperation.
### Random Action
<details>
<summary>Click to expand code example</summary>
```python
# envhub_random_action.py
import torch
from lerobot.envs.factory import make_env
# Load from the hub
envs_dict = make_env("LightwheelAI/leisaac_env:envs/so101_pick_orange.py", n_envs=1, trust_remote_code=True)
# Access the environment
suite_name = next(iter(envs_dict))
sync_vector_env = envs_dict[suite_name][0]
# retrieve the isaac environment from the sync vector env
env = sync_vector_env.envs[0].unwrapped
# Use it like any gym environment
obs, info = env.reset()
while True:
action = torch.tensor(env.action_space.sample())
obs, reward, terminated, truncated, info = env.step(action)
if terminated or truncated:
obs, info = env.reset()
env.close()
```
</details>
```bash
python envhub_random_action.py
```
You should see the SO101 arm swinging under purely random commands.
### Teleoperation
LeRobots teleoperation stack can drive the simulated arm.
Connect the SO101 Leader controller, run the calibration command below.
```bash
lerobot-calibrate \
--teleop.type=so101_leader \
--teleop.port=/dev/ttyACM0 \
--teleop.id=leader
```
And then launch the teleop script.
<details>
<summary>Click to expand code example</summary>
```python
# envhub_teleop_example.py
import logging
import time
import gymnasium as gym
from dataclasses import asdict, dataclass
from pprint import pformat
from lerobot.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
make_teleoperator_from_config,
so101_leader,
)
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import init_logging
from lerobot.envs.factory import make_env
@dataclass
class TeleoperateConfig:
teleop: TeleoperatorConfig
env_name: str = "so101_pick_orange"
fps: int = 60
@dataclass
class EnvWrap:
env: gym.Env
def make_env_from_leisaac(env_name: str = "so101_pick_orange"):
envs_dict = make_env(
f'LightwheelAI/leisaac_env:envs/{env_name}.py',
n_envs=1,
trust_remote_code=True
)
suite_name = next(iter(envs_dict))
sync_vector_env = envs_dict[suite_name][0]
env = sync_vector_env.envs[0].unwrapped
return env
def teleop_loop(teleop: Teleoperator, env: gym.Env, fps: int):
from leisaac.devices.action_process import preprocess_device_action
from leisaac.assets.robots.lerobot import SO101_FOLLOWER_MOTOR_LIMITS
from leisaac.utils.env_utils import dynamic_reset_gripper_effort_limit_sim
env_wrap = EnvWrap(env=env)
obs, info = env.reset()
while True:
loop_start = time.perf_counter()
if env.cfg.dynamic_reset_gripper_effort_limit:
dynamic_reset_gripper_effort_limit_sim(env, 'so101leader')
raw_action = teleop.get_action()
processed_action = preprocess_device_action(
dict(
so101_leader=True,
joint_state={
k.removesuffix(".pos"): v for k, v in raw_action.items()},
motor_limits=SO101_FOLLOWER_MOTOR_LIMITS),
env_wrap
)
obs, reward, terminated, truncated, info = env.step(processed_action)
if terminated or truncated:
obs, info = env.reset()
dt_s = time.perf_counter() - loop_start
precise_sleep(1 / fps - dt_s)
loop_s = time.perf_counter() - loop_start
print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
def teleoperate(cfg: TeleoperateConfig):
init_logging()
logging.info(pformat(asdict(cfg)))
teleop = make_teleoperator_from_config(cfg.teleop)
env = make_env_from_leisaac(cfg.env_name)
teleop.connect()
if hasattr(env, 'initialize'):
env.initialize()
try:
teleop_loop(teleop=teleop, env=env, fps=cfg.fps)
except KeyboardInterrupt:
pass
finally:
teleop.disconnect()
env.close()
def main():
teleoperate(TeleoperateConfig(
teleop=so101_leader.SO101LeaderConfig(
port="/dev/ttyACM0",
id='leader',
use_degrees=False,
),
env_name="so101_pick_orange",
fps=60,
))
if __name__ == "__main__":
main()
```
</details>
```bash
python envhub_teleop_example.py
```
Running the script lets you operate the simulated arm using the physical Leader device.
## ☁️ Cloud Simulation (No GPU Required)
Dont have a local GPU or the right drivers? No problem! You can run LeIsaac entirely in the cloud with zero setup.
LeIsaac works out-of-the-box on **NVIDIA Brev**, giving you a fully configured environment directly in your browser.
👉 **Start here:** [https://lightwheelai.github.io/leisaac/docs/cloud_simulation/nvidia_brev](https://lightwheelai.github.io/leisaac/docs/cloud_simulation/nvidia_brev)
Once your instance is deployed, simply open the link for **port 80 (HTTP)** to launch **Visual Studio Code Server** (default password: `password`). From there, you can run simulations, edit code, and visualize IsaacLab environments — all from your web browser.
**No GPU, no drivers, no local installation. Just click and run.**
## Additional Notes
We keep EnvHub coverage aligned with the LeIsaac task. Currently supported:
- `so101_pick_orange`
- `so101_lift_cube`
- `so101_clean_toytable`
- `bi_so101_fold_cloth`
Switch tasks by targeting a different script when calling `make_env`, for example:
```python
envs_dict_pick_orange = make_env("LightwheelAI/leisaac_env:envs/so101_pick_orange.py", n_envs=1, trust_remote_code=True)
envs_dict_lift_cube = make_env("LightwheelAI/leisaac_env:envs/so101_lift_cube.py", n_envs=1, trust_remote_code=True)
envs_dict_clean_toytable = make_env("LightwheelAI/leisaac_env:envs/so101_clean_toytable.py", n_envs=1, trust_remote_code=True)
envs_dict_fold_cloth = make_env("LightwheelAI/leisaac_env:envs/bi_so101_fold_cloth.py", n_envs=1, trust_remote_code=True)
```
Note: when working with `bi_so101_fold_cloth`, call `initialize()` immediately after retrieving the env before performing any other operations:
<details>
<summary>Click to expand code example</summary>
```python
import torch
from lerobot.envs.factory import make_env
# Load from the hub
envs_dict = make_env("LightwheelAI/leisaac_env:envs/bi_so101_fold_cloth.py", n_envs=1, trust_remote_code=True)
# Access the environment
suite_name = next(iter(envs_dict))
sync_vector_env = envs_dict[suite_name][0]
# retrieve the isaac environment from the sync vector env
env = sync_vector_env.envs[0].unwrapped
# NOTE: initialize() first
env.initialize()
# other operation with env...
```
</details>

View File

@@ -393,7 +393,7 @@ import time
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
episode_idx = 0
@@ -415,7 +415,7 @@ for idx in range(dataset.num_frames):
}
robot.send_action(action)
precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
robot.disconnect()
```
@@ -428,7 +428,7 @@ Your robot should replicate movements similar to those you recorded. For example
## Train a policy
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_train.py) script. A few arguments are required. Here is an example command:
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
lerobot-train \
@@ -485,7 +485,7 @@ huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
## Run inference and evaluate your policy
You can use the `record` script from [`lerobot-record`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
You can use the `record` script from [`lerobot/record.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
<hfoptions id="eval">
<hfoption id="Command">

220
docs/source/il_sim.mdx Normal file
View File

@@ -0,0 +1,220 @@
# Imitation Learning in Sim
This tutorial will explain how to train a neural network to control a robot in simulation with imitation learning.
**You'll learn:**
1. How to record a dataset in simulation with [gym-hil](https://github.com/huggingface/gym-hil) and visualize the dataset.
2. How to train a policy using your data.
3. How to evaluate your policy in simulation and visualize the results.
For the simulation environment we use the same [repo](https://github.com/huggingface/gym-hil) that is also being used by the Human-In-the-Loop (HIL) reinforcement learning algorithm.
This environment is based on [MuJoCo](https://mujoco.org) and allows you to record datasets in LeRobotDataset format.
Teleoperation is easiest with a controller like the Logitech F710, but you can also use your keyboard if you are up for the challenge.
## Installation
First, install the `gym_hil` package within the LeRobot environment, go to your LeRobot folder and run this command:
```bash
pip install -e ".[hilserl]"
```
## Teleoperate and Record a Dataset
To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/sim_il/env_config.json).
To teleoperate and collect a dataset, we need to modify this config file. Here's an example configuration for imitation learning data collection:
```json
{
"env": {
"type": "gym_manipulator",
"name": "gym_hil",
"task": "PandaPickCubeGamepad-v0",
"fps": 10
},
"dataset": {
"repo_id": "your_username/il_gym",
"root": null,
"task": "pick_cube",
"num_episodes_to_record": 30,
"replay_episode": null,
"push_to_hub": true
},
"mode": "record",
"device": "cuda"
}
```
Key configuration points:
- Set your `repo_id` in the `dataset` section: `"repo_id": "your_username/il_gym"`
- Set `num_episodes_to_record: 30` to collect 30 demonstration episodes
- Ensure `mode` is set to `"record"`
- If you don't have an NVIDIA GPU, change `"device": "cuda"` to `"mps"` for macOS or `"cpu"`
- To use keyboard instead of gamepad, change `"task"` to `"PandaPickCubeKeyboard-v0"`
Then we can run this command to start:
<hfoptions id="teleop_sim">
<hfoption id="Linux">
```bash
python -m lerobot.rl.gym_manipulator --config_path path/to/env_config_gym_hil_il.json
```
</hfoption>
<hfoption id="MacOS">
```bash
mjpython -m lerobot.rl.gym_manipulator --config_path path/to/env_config_gym_hil_il.json
```
</hfoption>
</hfoptions>
Once rendered you can teleoperate the robot with the gamepad or keyboard, below you can find the gamepad/keyboard controls.
Note that to teleoperate the robot you have to hold the "Human Take Over Pause Policy" Button `RB` to enable control!
**Gamepad Controls**
<p align="center">
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/gamepad_guide.jpg?raw=true"
alt="Figure shows the control mappings on a Logitech gamepad."
title="Gamepad Control Mapping"
width="100%"
></img>
</p>
<p align="center">
<i>Gamepad button mapping for robot control and episode management</i>
</p>
**Keyboard controls**
For keyboard controls use the `spacebar` to enable control and the following keys to move the robot:
```bash
Arrow keys: Move in X-Y plane
Shift and Shift_R: Move in Z axis
Right Ctrl and Left Ctrl: Open and close gripper
ESC: Exit
```
## Visualize a dataset
If you uploaded your dataset to the hub you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id.
<p align="center">
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/dataset_visualizer_sim.png"
alt="Figure shows the dataset visualizer"
title="Dataset visualization"
width="100%"
></img>
</p>
<p align="center">
<i>Dataset visualizer</i>
</p>
## Train a policy
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
lerobot-train \
--dataset.repo_id=${HF_USER}/il_gym \
--policy.type=act \
--output_dir=outputs/train/il_sim_test \
--job_name=il_sim_test \
--policy.device=cuda \
--wandb.enable=true
```
Let's explain the command:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/il_gym`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
Training should take several hours, 100k steps (which is the default) will take about 1h on Nvidia A100. You will find checkpoints in `outputs/train/il_sim_test/checkpoints`.
#### Train using Collab
If your local computer doesn't have a powerful GPU you could utilize Google Collab to train your model by following the [ACT training notebook](./notebooks#training-act).
#### Upload policy checkpoints
Once training is done, upload the latest checkpoint with:
```bash
huggingface-cli upload ${HF_USER}/il_sim_test \
outputs/train/il_sim_test/checkpoints/last/pretrained_model
```
You can also upload intermediate checkpoints with:
```bash
CKPT=010000
huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \
outputs/train/il_sim_test/checkpoints/${CKPT}/pretrained_model
```
## Evaluate your policy in Sim
To evaluate your policy we have to use a configuration file. An example can be found [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/sim_il/eval_config.json).
Here's an example evaluation configuration:
```json
{
"env": {
"type": "gym_manipulator",
"name": "gym_hil",
"task": "PandaPickCubeGamepad-v0",
"fps": 10
},
"dataset": {
"repo_id": "your_username/il_sim_dataset",
"dataset_root": null,
"task": "pick_cube"
},
"pretrained_policy_name_or_path": "your_username/il_sim_model",
"device": "cuda"
}
```
Make sure to replace:
- `repo_id` with the dataset you trained on (e.g., `your_username/il_sim_dataset`)
- `pretrained_policy_name_or_path` with your model ID (e.g., `your_username/il_sim_model`)
Then you can run this command to visualize your trained policy
<hfoptions id="eval_policy">
<hfoption id="Linux">
```bash
python -m lerobot.rl.eval_policy --config_path=path/to/eval_config_gym_hil.json
```
</hfoption>
<hfoption id="MacOS">
```bash
mjpython -m lerobot.rl.eval_policy --config_path=path/to/eval_config_gym_hil.json
```
</hfoption>
</hfoptions>
> [!WARNING]
> While the main workflow of training ACT in simulation is straightforward, there is significant room for exploring how to set up the task, define the initial state of the environment, and determine the type of data required during collection to learn the most effective policy. If your trained policy doesn't perform well, investigate the quality of the dataset it was trained on using our visualizers, as well as the action values and various hyperparameters related to ACT and the simulation.
Congrats 🎉, you have finished this tutorial. If you want to continue with using LeRobot in simulation follow this [Tutorial on reinforcement learning in sim with HIL-SERL](https://huggingface.co/docs/lerobot/hilserl_sim)
> [!TIP]
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).

View File

@@ -90,7 +90,7 @@ If you encounter build errors, you may need to install additional dependencies:
To install these for linux run:
```bash
sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev
sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config
```
For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)

View File

@@ -62,11 +62,6 @@ lerobot-eval \
- Pass a comma-separated list to `--env.task` for multi-suite evaluation.
### Control Mode
LIBERO now supports two control modes: relative and absolute. This matters because different VLA checkpoints are trained with different mode of action to output hence control parameterizations.
You can switch them with: `env.control_mode = "relative"` and `env.control_mode = "absolute"`
### Policy inputs and outputs
When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**:

View File

@@ -30,6 +30,131 @@ The follower arm uses 6x STS3215 motors with 1/345 gearing. The leader, however,
| Wrist Roll | 5 | 1 / 147 |
| Gripper | 6 | 1 / 147 |
### Clean Parts
Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material.
It is advisable to install one 3-pin cable in the motor after placing them before continuing assembly.
### Joint 1
- Place the first motor into the base.
- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom.
- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
- Install both motor horns, securing the top horn with a M3x6mm screw.
- Attach the shoulder part.
- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
- Add the shoulder motor holder.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint1_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Joint 2
- Slide the second motor in from the top.
- Fasten the second motor with 4 M2x6mm screws.
- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
- Attach the upper arm with 4 M3x6mm screws on each side.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint2_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Joint 3
- Insert motor 3 and fasten using 4 M2x6mm screws
- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint3_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Joint 4
- Slide over motor holder 4.
- Slide in motor 4.
- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint4_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Joint 5
- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws.
- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw.
- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint5_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Gripper / Handle
<hfoptions id="assembly">
<hfoption id="Follower">
- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
- Attach the motor horns and again use a M3x6mm horn screw.
- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Gripper_v2.mp4"
type="video/mp4"
/>
</video>
</div>
</hfoption>
<hfoption id="Leader">
- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws.
- Attach the handle to motor 5 using 1 M2x6mm screw.
- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw.
- Attach the follower trigger with 4 M3x6mm screws.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Leader_v2.mp4"
type="video/mp4"
/>
</video>
</div>
</hfoption>
</hfoptions>
## Configure the motors
### 1. Find the USB ports associated with each arm
@@ -215,131 +340,6 @@ leader.setup_motors()
</hfoption>
</hfoptions>
### Clean Parts
Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material.
It is advisable to install one 3-pin cable in the motor after placing them before continuing assembly.
### Joint 1
- Place the first motor into the base.
- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom.
- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
- Install both motor horns, securing the top horn with a M3x6mm screw.
- Attach the shoulder part.
- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
- Add the shoulder motor holder.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint1_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Joint 2
- Slide the second motor in from the top.
- Fasten the second motor with 4 M2x6mm screws.
- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
- Attach the upper arm with 4 M3x6mm screws on each side.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint2_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Joint 3
- Insert motor 3 and fasten using 4 M2x6mm screws
- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint3_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Joint 4
- Slide over motor holder 4.
- Slide in motor 4.
- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint4_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Joint 5
- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws.
- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw.
- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint5_v2.mp4"
type="video/mp4"
/>
</video>
</div>
### Gripper / Handle
<hfoptions id="assembly">
<hfoption id="Follower">
- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
- Attach the motor horns and again use a M3x6mm horn screw.
- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Gripper_v2.mp4"
type="video/mp4"
/>
</video>
</div>
</hfoption>
<hfoption id="Leader">
- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws.
- Attach the handle to motor 5 using 1 M2x6mm screw.
- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw.
- Attach the follower trigger with 4 M3x6mm screws.
<div class="video-container">
<video controls width="600">
<source
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Leader_v2.mp4"
type="video/mp4"
/>
</video>
</div>
</hfoption>
</hfoptions>
## Calibrate
Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position.

View File

@@ -1,42 +0,0 @@
# PyTorch accelerators
LeRobot supports multiple hardware acceleration options for both training and inference.
These options include:
- **CPU**: CPU executes all computations, no dedicated accelerator is used
- **CUDA**: acceleration with NVIDIA & AMD GPUs
- **MPS**: acceleration with Apple Silicon GPUs
- **XPU**: acceleration with Intel integrated and discrete GPUs
## Getting Started
To use particular accelerator, a suitable version of PyTorch should be installed.
For CPU, CUDA, and MPS backends follow instructions provided on [PyTorch installation page](https://pytorch.org/get-started/locally).
For XPU backend, follow instructions from [PyTorch documentation](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html).
### Verifying the installation
After installation, accelerator availability can be verified by running
```python
import torch
print(torch.<backend_name>.is_available()) # <backend_name> is cuda, mps, or xpu
```
## How to run training or evaluation
To select the desired accelerator, use the `--policy.device` flag when running `lerobot-train` or `lerobot-eval`. For example, to use MPS on Apple Silicon, run:
```bash
lerobot-train
--policy.device=mps ...
```
```bash
lerobot-eval \
--policy.device=mps ...
```
However, in most cases, presence of an accelerator is detected automatically and `policy.device` parameter can be omitted from CLI commands.

View File

@@ -1,203 +0,0 @@
# Unitree G1 Robot Setup and Control
This guide covers the complete setup process for the Unitree G1 humanoid, from initial connection to running gr00t_wbc locomotion.
## About the Unitree G1
We offer support for both 29 and 23 DOF G1. In this first PR we introduce:
- **`unitree g1` robot class, handling low level communication with the humanoid**
- **ZMQ socket bridge** for remote communication over WiFi, allowing one to deploy policies remotely instead of over ethernet or directly on the Orin
- **GR00T locomotion policy** for bipedal walking and balance
---
## Part 1: Connect to Robot over Ethernet
### Step 1: Configure Your Computer's Ethernet Interface
Set a static IP on the same subnet as the robot:
```bash
# Replace 'enp131s0' with your ethernet interface name (check with `ip a`)
sudo ip addr flush dev enp131s0
sudo ip addr add 192.168.123.200/24 dev enp131s0
sudo ip link set enp131s0 up
```
**Note**: The robot's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` where x ≠ 164.
### Step 2: SSH into the Robot
```bash
ssh unitree@192.168.123.164
# Password: 123
```
You should now be connected to the robot's onboard computer.
---
## Part 2: Enable WiFi on the Robot
Once connected via Ethernet, follow these steps to enable WiFi:
### Step 1: Enable WiFi Hardware
```bash
# Unblock WiFi radio
sudo rfkill unblock wifi
sudo rfkill unblock all
# Bring up WiFi interface
sudo ip link set wlan0 up
# Enable NetworkManager control
sudo nmcli radio wifi on
sudo nmcli device set wlan0 managed yes
sudo systemctl restart NetworkManager
```
### Step 2: Enable Internet Forwarding
**On your laptop:**
```bash
# Enable IP forwarding
sudo sysctl -w net.ipv4.ip_forward=1
# Set up NAT (replace wlp132s0f0 with your WiFi interface)
sudo iptables -t nat -A POSTROUTING -o wlp132s0f0 -s 192.168.123.0/24 -j MASQUERADE
sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTABLISHED -j ACCEPT
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
```
**On the robot:**
```bash
# Add laptop as default gateway
sudo ip route del default 2>/dev/null || true
sudo ip route add default via 192.168.123.200 dev eth0
echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf
# Test connection
ping -c 3 8.8.8.8
```
### Step 3: Connect to WiFi Network
```bash
# List available networks
nmcli device wifi list
# Connect to your WiFi (example)
sudo nmcli connection add type wifi ifname wlan0 con-name "YourNetwork" ssid "YourNetwork"
sudo nmcli connection modify "YourNetwork" wifi-sec.key-mgmt wpa-psk
sudo nmcli connection modify "YourNetwork" wifi-sec.psk "YourPassword"
sudo nmcli connection modify "YourNetwork" connection.autoconnect yes
sudo nmcli connection up "YourNetwork"
# Check WiFi IP address
ip a show wlan0
```
### Step 4: SSH Over WiFi
Once connected to WiFi, note the robot's IP address and disconnect the Ethernet cable. You can now SSH over WiFi:
```bash
ssh unitree@<YOUR_ROBOT_IP>
# Password: 123
```
Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address (e.g., `172.18.129.215`).
---
## Part 3: Robot Server Setup
### Step 1: Install LeRobot on the Orin
SSH into the robot and install LeRobot:
```bash
ssh unitree@<YOUR_ROBOT_IP>
conda create -y -n lerobot python=3.10
conda activate lerobot
git clone https://github.com/huggingface/lerobot.git
cd lerobot
pip install -e '.[unitree_g1]'
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
cd unitree_sdk2_python && pip install -e .
```
**Note**: The Unitree SDK requires CycloneDDS v0.10.2 to be installed. See the [Unitree SDK documentation](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
### Step 2: Run the Robot Server
On the robot:
```bash
python src/lerobot/robots/unitree_g1/run_g1_server.py
```
**Important**: Keep this terminal running. The server must be active for remote control.
---
## Part 4: Running GR00T Locomotion
With the robot server running, you can now control the robot from your laptop.
### Step 1: Install LeRobot on your machine
```bash
conda create -y -n lerobot python=3.10
conda activate lerobot
git clone https://github.com/huggingface/lerobot.git
cd lerobot
pip install -e '.[unitree_g1]'
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
cd unitree_sdk2_python && pip install -e .
```
### Step 2: Update Robot IP in Config
Edit the config file to match your robot's WiFi IP:
```python
# In src/lerobot/robots/unitree_g1/config_unitree_g1.py
robot_ip: str = "<YOUR_ROBOT_IP>" # Replace with your robot's WiFi IP.
```
**Note**: When running directly on the G1 (not remotely), set `robot_ip: str = "127.0.0.1"` instead.
### Step 3: Run the Locomotion Policy
```bash
# Run GR00T locomotion controller
python examples/unitree_g1/gr00t_locomotion.py --repo-id "nepyope/GR00T-WholeBodyControl_g1"
```
### Step 4: Control with Remote
- **Left stick**: Forward/backward and left/right movement
- **Right stick**: Rotation
- **R1 button**: Raise waist height
- **R2 button**: Lower waist height
Press `Ctrl+C` to stop the policy.
---
## Additional Resources
- [Unitree SDK Documentation](https://github.com/unitreerobotics/unitree_sdk2_python)
- [GR00T Policy Repository](https://huggingface.co/nepyope/GR00T-WholeBodyControl_g1)
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
- [Unitree_IL_Lerobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
---
_Last updated: December 2025_

View File

@@ -11,14 +11,13 @@ LeRobot provides several utilities for manipulating datasets:
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
4. **Add Features** - Add new features to a dataset
5. **Remove Features** - Remove features from a dataset
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
The core implementation is in `lerobot.datasets.dataset_tools`.
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
## Command-Line Tool: lerobot-edit-dataset
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, remove features, and convert image datasets to video format.
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features.
Run `lerobot-edit-dataset --help` for more information on the configuration of each operation.
@@ -87,71 +86,9 @@ lerobot-edit-dataset \
--operation.feature_names "['observation.images.top']"
```
#### Convert to Video
Convert an image-based dataset to video format, creating a new LeRobotDataset where images are stored as videos. This is useful for reducing storage requirements and improving data loading performance. The new dataset will have the exact same structure as the original, but with images encoded as MP4 videos in the proper LeRobot format.
```bash
# Local-only: Save to a custom output directory (no hub push)
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir /path/to/output/pusht_video
# Save with new repo_id (local storage)
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_to_video
# Convert and push to Hugging Face Hub
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_to_video \
--push_to_hub true
# Convert with custom video codec and quality settings
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir outputs/pusht_video \
--operation.vcodec libsvtav1 \
--operation.pix_fmt yuv420p \
--operation.g 2 \
--operation.crf 30
# Convert only specific episodes
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir outputs/pusht_video \
--operation.episode_indices "[0, 1, 2, 5, 10]"
# Convert with multiple workers for parallel processing
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir outputs/pusht_video \
--operation.num_workers 8
```
**Parameters:**
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
- `fast_decode`: Fast decode tuning option (default: 0)
- `episode_indices`: List of specific episodes to convert (default: all episodes)
- `num_workers`: Number of parallel workers for processing (default: 4)
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
### Push to Hub
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
```bash
lerobot-edit-dataset \
@@ -159,7 +96,7 @@ lerobot-edit-dataset \
--new_repo_id lerobot/pusht_after_deletion \
--operation.type delete_episodes \
--operation.episode_indices "[0, 2, 5]" \
--push_to_hub true
--push_to_hub
```
There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`.

View File

@@ -1,528 +0,0 @@
# X-VLA: The First Soft-Prompted Robot Foundation Model for Any Robot, Any Task
## Overview
For years, robotics has aspired to build agents that can follow natural human instructions and operate dexterously across many environments and robot bodies. Recent breakthroughs in LLMs and VLMs suggest a path forward: extend these foundation-model architectures to embodied control by grounding them in actions. This has led to the rise of Vision-Language-Action (VLA) models, with the hope that a single generalist model could combine broad semantic understanding with robust manipulation skills.
But training such models is difficult. Robot data is fragmented across platforms, sensors, embodiments, and collection protocols. Heterogeneity appears everywhere: different arm configurations, different action spaces, different camera setups, different visual domains, and different task distributions. These inconsistencies create major distribution shifts that make pretraining unstable and adaptation unreliable.
Inspired by meta-learning and prompt learning, we ask: **"What if a VLA model could learn the structure of each robot and dataset the same way LLMs learn tasks, through prompts?"**
**X-VLA** is a soft-prompted, flow-matching VLA framework that treats each hardware setup as a "task" and encodes it using a small set of learnable embeddings. These **Soft Prompts** capture embodiment and domain-specific variations, guiding the Transformer from the earliest stages of multimodal fusion. With this mechanism, X-VLA can reconcile diverse robot morphologies, data types, and sensor setups within a single unified architecture.
<p align="center">
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture.png"
alt="XVLA Architecture"
style="max-width: 100%; height: auto; width: 800px;"
/>
</p>
Built from pure Transformer encoders, X-VLA scales naturally with model size and dataset diversity. Across 6 simulation benchmarks and 3 real robots, Soft Prompts consistently outperform existing methods in handling hardware and domain differences. X-VLA-0.9B, trained on 290K episodes spanning seven robotic platforms, learns an embodiment-agnostic generalist policy in Phase I, and adapts efficiently to new robots in Phase II simply by learning a new set of prompts, while keeping the backbone frozen.
<p align="center">
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture2.png"
alt="XVLA Architecture 2"
style="width: 60%; height: auto;"
/>
</p>
With only 1% of parameters tuned (9M), X-VLA-0.9B achieves near-π₀ performance on LIBERO and Simpler-WidowX, despite using **300× fewer trainable parameters**. It also demonstrates strong real-world dexterity with minimal demonstrations, including folding cloths in under two minutes.
<p align="center">
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-fold.png"
alt="XVLA fold visualization"
style="width: 95%; max-width: 1100px; height: auto;"
/>
</p>
X-VLA shows that generalist robot intelligence does not require increasingly complex architectures, only the right way to absorb heterogeneity. Soft Prompts offer a simple, scalable mechanism for unifying diverse robotic data, paving the way toward adaptable, cross-embodiment robot foundation models.
## Installation
After installing LeRobot, install the X-VLA dependencies:
```bash
pip install -e .[xvla]
```
After the new release, you'll be able to do:
```bash
pip install lerobot[xvla]
```
## Quick Start
### Basic Usage
To use X-VLA in your LeRobot configuration, specify the policy type as:
```bash
policy.type=xvla
```
### Evaluating Pre-trained Checkpoints
Example evaluation with LIBERO:
```bash
lerobot-eval \
--policy.path="lerobot/xvla-libero" \
--env.type=libero \
--env.task=libero_spatial,libero_goal,libero_10 \
--env.control_mode=absolute \
--eval.batch_size=1 \
--eval.n_episodes=1 \
--env.episode_length=800 \
--seed=142
```
## Available Checkpoints
### 🎯 Base Model
**[lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base)**
A 0.9B parameter instantiation of X-VLA, trained with a carefully designed data processing and learning recipe. The training pipeline consists of two phases:
- **Phase I: Pretraining** - Pretrained on 290K episodes from Droid, Robomind, and Agibot, spanning seven platforms across five types of robotic arms (single-arm to bi-manual setups). By leveraging soft prompts to absorb embodiment-specific variations, the model learns an embodiment-agnostic generalist policy.
- **Phase II: Domain Adaptation** - Adapted to deployable policies for target domains. A new set of soft prompts is introduced and optimized to encode the hardware configuration of the novel domain, while the pretrained backbone remains frozen.
### Simulation Checkpoints
**[lerobot/xvla-libero](https://huggingface.co/lerobot/xvla-libero)**
Achieves 93% success rate on LIBERO benchmarks. Fine-tuned from the base model for simulation tasks.
**[lerobot/xvla-widowx](https://huggingface.co/lerobot/xvla-widowx)**
Fine-tuned on BridgeData for pick-and-place experiments on compact WidowX platforms. Demonstrates robust manipulation capabilities.
### 🤖 Real-World Checkpoints
**[lerobot/xvla-folding](https://huggingface.co/lerobot/xvla-folding)**
A fine-tuned dexterous manipulation model trained on the high-quality Soft-FOLD cloth folding dataset. Achieves 100% success rate over 2 hours of continuous cloth folding.
**[lerobot/xvla-agibot-world](https://huggingface.co/lerobot/xvla-agibot-world)**
Optimized for AgileX robot dexterous manipulation tasks.
**[lerobot/xvla-google-robot](https://huggingface.co/lerobot/xvla-google-robot)**
Adapted for Google Robot platforms.
## Training X-VLA
### Recommended Training Configuration
When fine-tuning X-VLA for a new embodiment or task, we recommend not freezing the VLM, and also setting the `policy.dtype=bfloat16` to not hit OOM errors.
```bash
lerobot-train \
--dataset.repo_id=YOUR_DATASET \
--output_dir=./outputs/xvla_training \
--job_name=xvla_training \
--policy.path="lerobot/xvla-base" \
--policy.repo_id="HF_USER/xvla-your-robot" \
--policy.dtype=bfloat16 \
--steps=3000 \
--policy.device=cuda \
--policy.freeze_vision_encoder=false \
--policy.freeze_language_encoder=false \
--policy.train_policy_transformer=true \
--policy.train_soft_prompts=true \
--policy.action_mode=YOUR_ACTION_MODE
```
### Training Parameters Explained
| Parameter | Default | Description |
| -------------------------- | ------- | ---------------------------------------------- |
| `freeze_vision_encoder` | `false` | Do not freeze the VLM vision encoder weights |
| `freeze_language_encoder` | `false` | Do not freeze the VLM language encoder weights |
| `train_policy_transformer` | `true` | Allow policy transformer layers to train |
| `train_soft_prompts` | `true` | Allow soft prompts to train |
**💡 Best Practice**: For Phase II adaptation to new embodiments, do not freeze the VLM encoders and also train the policy transformer and soft prompts.
### Example: Training on Bimanual Robot
```bash
lerobot-train \
--dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
--output_dir=./outputs/xvla_bimanual \
--job_name=xvla_so101_training \
--policy.path="lerobot/xvla-base" \
--policy.dtype=bfloat16 \
--policy.repo_id="YOUR_USERNAME/xvla-biso101" \
--steps=3000 \
--policy.device=cuda \
--policy.action_mode=so101_bimanual \
--policy.freeze_vision_encoder=false \
--policy.freeze_language_encoder=false \
--policy.train_policy_transformer=true \
--policy.train_soft_prompts=true
```
💡 **Best Performance:** If you have sufficient computational resources and want to achieve best X-VLA finetuning performance, you should follow the official finetuning strategy:
**🔥 Full-finetune all components with a custom learning-rate scheme**
To ensure stable optimization, the Vision-Language Model (VLM) must be trained with only 1/10 of the base learning rate, while all other components use the full LR.
This LR ratio is crucial for achieving strong and stable finetuning performance. This is already done for you by default.
❕Note
Completely matching the official reported performance may require an additional warm-up LR schedule for soft-prompts, which can bring minor improvements.
We encourage implementing this in your customized training pipeline for optimal results.
## Core Concepts
### 1. Action Modes
X-VLA uses an **Action Registry** system to handle different action spaces and embodiments. The `action_mode` parameter defines how actions are processed, what loss functions are used, and how predictions are post-processed.
#### Available Action Modes
| Action Mode | Action Dim | Description | Use Case |
| ---------------- | ----------------------- | ------------------------------------------- | ------------------------------------ |
| `ee6d` | 20 | End-effector with xyz, 6D rotation, gripper | Dual-arm setups with spatial control |
| `joint` | 14 | Joint-space with gripper | Direct joint control robots |
| `agibot_ee6d` | 20 | AGI-bot variant with MSE loss | AGI-bot platforms |
| `so101_bimanual` | 20 (model), 12 (real) | SO101 bimanual robot | Bimanual manipulation tasks |
| `auto` | 20 (model), auto (real) | Auto-detects action dim from dataset | **Recommended** for new robots |
#### Why Action Modes Matter
When you have a pretrained checkpoint like `lerobot/xvla-base` trained with `action_dim=20`, and you want to train on a dataset with a different action dimension (e.g., 14 for bimanual arms), you can't simply trim the action dimension. The action mode orchestrates:
1. **Loss Computation**: Different loss functions for different action components (MSE for joints, BCE for grippers, etc.)
2. **Preprocessing**: Zeroing out gripper channels, padding dimensions
3. **Postprocessing**: Applying sigmoid to gripper logits, trimming padding
#### Example: BimanualSO101 Action Space
The `so101_bimanual` action mode handles the mismatch between model output (20D) and real robot control (12D):
```python
# Model outputs 20 dimensions for compatibility
dim_action = 20
# Real robot only needs 12 dimensions
# [left_arm (6), right_arm (6)] = [joints (5) + gripper (1)] × 2
REAL_DIM = 12
# Preprocessing: Pad 12D actions to 20D for training
# Postprocessing: Trim 20D predictions to 12D for deployment
```
See the [action_hub.py](/home/jade_choghari/robot/lerobot/src/lerobot/policies/xvla/action_hub.py) implementation for details.
#### Auto Action Mode (Recommended)
The `auto` action mode is the easiest way to use X-VLA with any robot. It automatically detects your dataset's action dimension and handles padding/trimming:
```bash
lerobot-train \
--policy.path="lerobot/xvla-base" \
--policy.action_mode=auto \
--policy.max_action_dim=20 \
...
```
**How it works:**
- Reads `action_feature.shape[-1]` from your dataset (e.g., 7 for Franka)
- Model outputs `max_action_dim` (default 20) for pretrained compatibility
- Loss is computed **only on the real dimensions**: `MSE(pred[:,:,:real_dim], target[:,:,:real_dim])`
- Postprocess trims output back to `real_dim` for robot control
This eliminates the need to create custom action modes for most robots.
### 2. Domain IDs
Domain IDs are learnable identifiers for different robot configurations and camera setups. They allow X-VLA to distinguish between:
- Different robots (Robot 1 vs Robot 2)
- Different camera configurations (cam1 vs cam2)
- Different combinations (Robot1-cam1-cam2 vs Robot1-cam1 vs Robot2-cam1)
#### Setting Domain IDs
**During Training**: By default, domain_id is set to 0 for general training.
**During Evaluation**: Specify the domain_id that matches your checkpoint's training configuration.
```python
# Example: LIBERO checkpoint uses domain_id=3
domain_id = 3
```
The domain_id is automatically added to observations by the `XVLAAddDomainIdProcessorStep` in the preprocessing pipeline.
The `lerobot/xvla-base` model has been trained on the following domain IDs. It is recommended to choose one that most resembles your robot/configuration:
#### Fine-tuning Datasets
| Dataset Name | Domain ID |
| ---------------- | --------- |
| Bridge | 0 |
| RT1 | 1 |
| Calvin | 2 |
| libero | 3 |
| widowx-air | 4 |
| AIR-AGILEX-HQ | 5 |
| robotwin2_abs_ee | 6 |
| robotwin2_clean | 6 |
| robocasa-human | 7 |
| VLABench | 8 |
| AGIBOT-challenge | 9 |
| AIR-AGILEX | 10 |
| AIRBOT | 18 |
### 3. Processor Steps
X-VLA requires specific preprocessing and postprocessing steps for proper operation.
#### Required Preprocessing Steps
1. **XVLAImageToFloatProcessorStep**: Converts images from [0, 255] to [0, 1] range
2. **XVLAImageNetNormalizeProcessorStep**: Applies ImageNet normalization (required for VLM backbone)
3. **XVLAAddDomainIdProcessorStep**: Adds domain_id to observations
#### Example Custom Processor
For LIBERO environments, a custom processor handles the specific observation format:
```python
from lerobot.policies.xvla.processor_xvla import LiberoProcessorStep
processor = LiberoProcessorStep()
# Handles robot_state dictionary, converts rotation matrices to 6D representation
# Applies 180° image rotation for camera convention
```
### 4. Configuration Parameters
Key configuration parameters for X-VLA:
```python
# Observation and action
n_obs_steps: int = 1 # Number of observation timesteps
chunk_size: int = 32 # Action sequence length
n_action_steps: int = 32 # Number of action steps to execute
# Model architecture
hidden_size: int = 1024 # Transformer hidden dimension
depth: int = 24 # Number of transformer layers
num_heads: int = 16 # Number of attention heads
num_domains: int = 30 # Maximum number of domain IDs
len_soft_prompts: int = 32 # Length of soft prompt embeddings
# Action space
action_mode: str = "ee6d" # Action space type (use "auto" for auto-detection)
use_proprio: bool = True # Use proprioceptive state
max_state_dim: int = 32 # Maximum state dimension
max_action_dim: int = 20 # Max action dim for padding (used by "auto" mode)
# Vision
num_image_views: int | None # Number of camera views
resize_imgs_with_padding: tuple[int, int] | None # Target image size with padding
# Training
num_denoising_steps: int = 10 # Flow matching denoising steps
```
## Creating Custom Action Modes
If your robot has a unique action space, you can create a custom action mode:
### Step 1: Define Your Action Space
```python
from lerobot.policies.xvla.action_hub import BaseActionSpace, register_action
import torch.nn as nn
@register_action("my_custom_robot")
class MyCustomActionSpace(BaseActionSpace):
"""Custom action space for my robot."""
dim_action = 15 # Your robot's action dimension
gripper_idx = (7, 14) # Gripper channel indices
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
self.bce = nn.BCEWithLogitsLoss()
def compute_loss(self, pred, target):
"""Define your loss computation."""
# Example: MSE for joints, BCE for grippers
joints_loss = self.mse(pred[:, :, :7], target[:, :, :7])
gripper_loss = self.bce(pred[:, :, self.gripper_idx],
target[:, :, self.gripper_idx])
return {
"joints_loss": joints_loss,
"gripper_loss": gripper_loss,
}
def preprocess(self, proprio, action, mode="train"):
"""Preprocess actions before training."""
# Example: Zero out grippers in proprioception
proprio_m = proprio.clone()
action_m = action.clone() if action is not None else None
proprio_m[..., self.gripper_idx] = 0.0
if action_m is not None:
action_m[..., self.gripper_idx] = 0.0
return proprio_m, action_m
def postprocess(self, action):
"""Post-process predictions for deployment."""
# Example: Apply sigmoid to gripper logits
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
return action
```
### Step 2: Use Your Custom Action Mode
```bash
lerobot-train \
--policy.action_mode=my_custom_robot \
--dataset.repo_id=YOUR_DATASET \
--policy.path="lerobot/xvla-base" \
...
```
## Advanced Topics
### Multi-Camera Support
X-VLA supports multiple camera views through the `num_image_views` parameter:
```python
# Configure for 3 camera views
policy.num_image_views=3
# Add empty cameras if you have fewer physical cameras
policy.empty_cameras=1 # Adds 1 zero-padded camera view
```
### Custom Preprocessing Pipeline
Create a custom preprocessing pipeline for your environment:
```python
from lerobot.processor import PolicyProcessorPipeline
from lerobot.policies.xvla.processor_xvla import (
XVLAImageToFloatProcessorStep,
XVLAImageNetNormalizeProcessorStep,
XVLAAddDomainIdProcessorStep,
)
# Build custom pipeline
preprocessor = PolicyProcessorPipeline(
steps=[
YourCustomProcessorStep(), # Your custom processing
XVLAImageToFloatProcessorStep(), # Required: convert to float
XVLAImageNetNormalizeProcessorStep(), # Required: ImageNet norm
XVLAAddDomainIdProcessorStep(domain_id=5), # Your domain ID
]
)
```
### Handling Different Action Dimensions
When your dataset has fewer action dimensions than the pretrained model:
**Option 1 (Recommended)**: Use `auto` action mode
```bash
# Automatically detects your dataset's action dimension
# Works with any robot without custom code
policy.action_mode=auto
policy.max_action_dim=20 # Match pretrained model
```
**Option 2**: Use a predefined action mode with built-in padding
```python
# Model expects 20D, dataset has 12D
# Action mode handles padding internally
action_mode = "so101_bimanual" # Pads 12 → 20
```
**Option 2**: Create a custom action mode that maps dimensions explicitly
```python
@register_action("my_mapped_action")
class MappedActionSpace(BaseActionSpace):
dim_action = 20
REAL_DIM = 12
def _pad_to_model_dim(self, x):
# Custom padding logic
...
```
## Troubleshooting
### Common Issues
**Issue**: "Action dimension mismatch"
- **Solution**: Check that your `action_mode` matches your robot's action space. Create a custom action mode if needed.
**Issue**: "Image values outside [0, 1] range"
- **Solution**: Ensure images are preprocessed with `XVLAImageToFloatProcessorStep` before normalization.
**Issue**: "Domain ID not found"
- **Solution**: Make sure `XVLAAddDomainIdProcessorStep` is in your preprocessing pipeline with the correct domain_id.
**Issue**: "Low success rate on new embodiment"
- **Solution**:
1. Verify your action_mode is correct
2. Check that soft prompts are being trained (`train_soft_prompts=True`)
3. Ensure proper preprocessing (ImageNet normalization, domain_id)
4. Consider increasing training steps
**Issue**: "Out of memory during training"
- **Solution**:
1. Reduce `chunk_size` (e.g., from 32 to 16)
2. Enable gradient checkpointing
3. Reduce batch size
4. Freeze more components
## Citation
If you use X-VLA in your research, please cite:
```bibtex
@article{zheng2025x,
title = {X-VLA: Soft-Prompted Transformer as Scalable Cross-Embodiment Vision-Language-Action Model},
author = {Zheng, Jinliang and Li, Jianxiong and Wang, Zhihao and Liu, Dongxiu and Kang, Xirui
and Feng, Yuchun and Zheng, Yinan and Zou, Jiayin and Chen, Yilun and Zeng, Jia and others},
journal = {arXiv preprint arXiv:2510.10274},
year = {2025}
}
```
## Additional Resources
- [X-VLA Paper](https://arxiv.org/pdf/2510.10274)
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
- [Action Registry Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/action_hub.py)
- [Processor Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/processor_xvla.py)
- [Model Configuration](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/configuration_xvla.py)
## Contributing
We welcome contributions! If you've implemented a new action mode or processor for your robot, please consider submitting a PR to help the community.

View File

@@ -45,7 +45,7 @@ from lerobot.robots import ( # noqa: F401
so101_follower,
)
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import (
init_logging,
log_say,
@@ -97,7 +97,7 @@ def replay(cfg: ReplayConfig):
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
precise_sleep(1 / dataset.fps - dt_s)
busy_wait(1 / dataset.fps - dt_s)
robot.disconnect()

View File

@@ -34,106 +34,105 @@ from huggingface_hub import HfApi
import lerobot
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
# We ported a number of existing datasets ourselves, use this to see the list:
print("List of available datasets:")
pprint(lerobot.available_datasets)
def main():
# We ported a number of existing datasets ourselves, use this to see the list:
print("List of available datasets:")
pprint(lerobot.available_datasets)
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
hub_api = HfApi()
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
pprint(repo_ids)
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
hub_api = HfApi()
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
pprint(repo_ids)
# Or simply explore them in your web browser directly at:
# https://huggingface.co/datasets?other=LeRobot
# Or simply explore them in your web browser directly at:
# https://huggingface.co/datasets?other=LeRobot
# Let's take this one for this example
repo_id = "lerobot/aloha_mobile_cabinet"
# We can have a look and fetch its metadata to know more about it:
ds_meta = LeRobotDatasetMetadata(repo_id)
# Let's take this one for this example
repo_id = "lerobot/aloha_mobile_cabinet"
# We can have a look and fetch its metadata to know more about it:
ds_meta = LeRobotDatasetMetadata(repo_id)
# By instantiating just this class, you can quickly access useful information about the content and the
# structure of the dataset without downloading the actual data yet (only metadata files — which are
# lightweight).
print(f"Total number of episodes: {ds_meta.total_episodes}")
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
print(f"Frames per second used during data collection: {ds_meta.fps}")
print(f"Robot type: {ds_meta.robot_type}")
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")
# By instantiating just this class, you can quickly access useful information about the content and the
# structure of the dataset without downloading the actual data yet (only metadata files — which are
# lightweight).
print(f"Total number of episodes: {ds_meta.total_episodes}")
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
print(f"Frames per second used during data collection: {ds_meta.fps}")
print(f"Robot type: {ds_meta.robot_type}")
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")
print("Tasks:")
print(ds_meta.tasks)
print("Features:")
pprint(ds_meta.features)
print("Tasks:")
print(ds_meta.tasks)
print("Features:")
pprint(ds_meta.features)
# You can also get a short summary by simply printing the object:
print(ds_meta)
# You can also get a short summary by simply printing the object:
print(ds_meta)
# You can then load the actual dataset from the hub.
# Either load any subset of episodes:
dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])
# You can then load the actual dataset from the hub.
# Either load any subset of episodes:
dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])
# And see how many frames you have:
print(f"Selected episodes: {dataset.episodes}")
print(f"Number of episodes selected: {dataset.num_episodes}")
print(f"Number of frames selected: {dataset.num_frames}")
# And see how many frames you have:
print(f"Selected episodes: {dataset.episodes}")
print(f"Number of episodes selected: {dataset.num_episodes}")
print(f"Number of frames selected: {dataset.num_frames}")
# Or simply load the entire dataset:
dataset = LeRobotDataset(repo_id)
print(f"Number of episodes selected: {dataset.num_episodes}")
print(f"Number of frames selected: {dataset.num_frames}")
# Or simply load the entire dataset:
dataset = LeRobotDataset(repo_id)
print(f"Number of episodes selected: {dataset.num_episodes}")
print(f"Number of frames selected: {dataset.num_frames}")
# The previous metadata class is contained in the 'meta' attribute of the dataset:
print(dataset.meta)
# The previous metadata class is contained in the 'meta' attribute of the dataset:
print(dataset.meta)
# LeRobotDataset actually wraps an underlying Hugging Face dataset
# (see https://huggingface.co/docs/datasets for more information).
print(dataset.hf_dataset)
# LeRobotDataset actually wraps an underlying Hugging Face dataset
# (see https://huggingface.co/docs/datasets for more information).
print(dataset.hf_dataset)
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
# with the latter, like iterating through the dataset.
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
# episodes, you can access the frame indices of any episode using dataset.meta.episodes. Here, we access
# frame indices associated to the first episode:
episode_index = 0
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
# with the latter, like iterating through the dataset.
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
# episodes, you can access the frame indices of any episode using dataset.meta.episodes. Here, we access
# frame indices associated to the first episode:
episode_index = 0
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
# Then we grab all the image frames from the first camera:
camera_key = dataset.meta.camera_keys[0]
frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]
# Then we grab all the image frames from the first camera:
camera_key = dataset.meta.camera_keys[0]
frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]
# The objects returned by the dataset are all torch.Tensors
print(type(frames[0]))
print(frames[0].shape)
# The objects returned by the dataset are all torch.Tensors
print(type(frames[0]))
print(frames[0].shape)
# Since we're using pytorch, the shape is in pytorch, channel-first convention (c, h, w).
# We can compare this shape with the information available for that feature
pprint(dataset.features[camera_key])
# In particular:
print(dataset.features[camera_key]["shape"])
# The shape is in (h, w, c) which is a more universal format.
# Since we're using pytorch, the shape is in pytorch, channel-first convention (c, h, w).
# We can compare this shape with the information available for that feature
pprint(dataset.features[camera_key])
# In particular:
print(dataset.features[camera_key]["shape"])
# The shape is in (h, w, c) which is a more universal format.
# For many machine learning applications we need to load the history of past observations or trajectories of
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
# differences with the current loaded frame. For instance:
delta_timestamps = {
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
camera_key: [-1, -0.5, -0.20, 0],
# loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
"action": [t / dataset.fps for t in range(64)],
}
# Note that in any case, these delta_timestamps values need to be multiples of (1/fps) so that added to any
# timestamp, you still get a valid timestamp.
# For many machine learning applications we need to load the history of past observations or trajectories of
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
# differences with the current loaded frame. For instance:
delta_timestamps = {
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
camera_key: [-1, -0.5, -0.20, 0],
# loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
"action": [t / dataset.fps for t in range(64)],
}
# Note that in any case, these delta_timestamps values need to be multiples of (1/fps) so that added to any
# timestamp, you still get a valid timestamp.
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
if __name__ == "__main__":
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
@@ -145,7 +144,3 @@ def main():
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
print(f"{batch['action'].shape=}") # (32, 64, c)
break
if __name__ == "__main__":
main()

View File

@@ -33,68 +33,83 @@ TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
# Create the robot configuration & robot
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
def main():
# Create the robot configuration & robot
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
robot = LeKiwiClient(robot_config)
robot = LeKiwiClient(robot_config)
# Create policy
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
# Create policy
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, ACTION)
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
dataset_features = {**action_features, **obs_features}
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, ACTION)
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_DATASET_ID,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_DATASET_ID,
# Build Policy Processors
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
)
# Connect the robot
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
robot.connect()
# TODO(Steven): Update this example to use pipelines
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
init_rerun(session_name="lekiwi_evaluate")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Build Policy Processors
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
)
# Connect the robot
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
robot.connect()
# TODO(Steven): Update this example to use pipelines
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
init_rerun(session_name="lekiwi_evaluate")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
# Main record loop
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
@@ -103,42 +118,21 @@ def main():
robot_observation_processor=robot_observation_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Save episode
dataset.save_episode()
recorded_episodes += 1
# Save episode
dataset.save_episode()
recorded_episodes += 1
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":
main()
dataset.finalize()
dataset.push_to_hub()

View File

@@ -34,62 +34,78 @@ RESET_TIME_SEC = 10
TASK_DESCRIPTION = "My task description"
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
# Create the robot and teleoperator configurations
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
keyboard_config = KeyboardTeleopConfig()
def main():
# Create the robot and teleoperator configurations
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
keyboard_config = KeyboardTeleopConfig()
# Initialize the robot and teleoperator
robot = LeKiwiClient(robot_config)
leader_arm = SO100Leader(leader_arm_config)
keyboard = KeyboardTeleop(keyboard_config)
# Initialize the robot and teleoperator
robot = LeKiwiClient(robot_config)
leader_arm = SO100Leader(leader_arm_config)
keyboard = KeyboardTeleop(keyboard_config)
# TODO(Steven): Update this example to use pipelines
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
# TODO(Steven): Update this example to use pipelines
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, ACTION)
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
dataset_features = {**action_features, **obs_features}
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, ACTION)
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_REPO_ID,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_REPO_ID,
# Connect the robot and teleoperator
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
robot.connect()
leader_arm.connect()
keyboard.connect()
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
init_rerun(session_name="lekiwi_record")
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting record loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {recorded_episodes}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
dataset=dataset,
teleop=[leader_arm, keyboard],
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
# Connect the robot and teleoperator
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
robot.connect()
leader_arm.connect()
keyboard.connect()
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
init_rerun(session_name="lekiwi_record")
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting record loop...")
recorded_episodes = 0
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {recorded_episodes}")
# Main record loop
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
dataset=dataset,
teleop=[leader_arm, keyboard],
control_time_s=EPISODE_TIME_SEC,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
@@ -97,45 +113,23 @@ def main():
robot_observation_processor=robot_observation_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=[leader_arm, keyboard],
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=teleop_action_processor,
robot_action_processor=robot_action_processor,
robot_observation_processor=robot_observation_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Save episode
dataset.save_episode()
recorded_episodes += 1
# Save episode
dataset.save_episode()
recorded_episodes += 1
# Clean up
log_say("Stop recording")
robot.disconnect()
leader_arm.disconnect()
keyboard.disconnect()
listener.stop()
# Clean up
log_say("Stop recording")
robot.disconnect()
leader_arm.disconnect()
keyboard.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":
main()
dataset.finalize()
dataset.push_to_hub()

View File

@@ -20,48 +20,42 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
EPISODE_IDX = 0
# Initialize the robot config
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
def main():
# Initialize the robot config
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
# Initialize the robot
robot = LeKiwiClient(robot_config)
# Initialize the robot
robot = LeKiwiClient(robot_config)
# Fetch the dataset to replay
dataset = LeRobotDataset("<hf_username>/<dataset_repo_id>", episodes=[EPISODE_IDX])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
actions = episode_frames.select_columns(ACTION)
# Fetch the dataset to replay
dataset = LeRobotDataset("<hf_username>/<dataset_repo_id>", episodes=[EPISODE_IDX])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
actions = episode_frames.select_columns(ACTION)
# Connect to the robot
robot.connect()
# Connect to the robot
robot.connect()
if not robot.is_connected:
raise ValueError("Robot is not connected!")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
# Get recorded action from dataset
action = {
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get recorded action from dataset
action = {
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Send action to robot
_ = robot.send_action(action)
# Send action to robot
_ = robot.send_action(action)
busy_wait(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
robot.disconnect()
if __name__ == "__main__":
main()
robot.disconnect()

View File

@@ -19,60 +19,54 @@ import time
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop, KeyboardTeleopConfig
from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
FPS = 30
# Create the robot and teleoperator configurations
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="my_lekiwi")
teleop_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
keyboard_config = KeyboardTeleopConfig(id="my_laptop_keyboard")
def main():
# Create the robot and teleoperator configurations
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="my_lekiwi")
teleop_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
keyboard_config = KeyboardTeleopConfig(id="my_laptop_keyboard")
# Initialize the robot and teleoperator
robot = LeKiwiClient(robot_config)
leader_arm = SO100Leader(teleop_arm_config)
keyboard = KeyboardTeleop(keyboard_config)
# Initialize the robot and teleoperator
robot = LeKiwiClient(robot_config)
leader_arm = SO100Leader(teleop_arm_config)
keyboard = KeyboardTeleop(keyboard_config)
# Connect to the robot and teleoperator
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
robot.connect()
leader_arm.connect()
keyboard.connect()
# Connect to the robot and teleoperator
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
robot.connect()
leader_arm.connect()
keyboard.connect()
# Init rerun viewer
init_rerun(session_name="lekiwi_teleop")
# Init rerun viewer
init_rerun(session_name="lekiwi_teleop")
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting teleop loop...")
while True:
t0 = time.perf_counter()
print("Starting teleop loop...")
while True:
t0 = time.perf_counter()
# Get robot observation
observation = robot.get_observation()
# Get robot observation
observation = robot.get_observation()
# Get teleop action
# Arm
arm_action = leader_arm.get_action()
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
# Keyboard
keyboard_keys = keyboard.get_action()
base_action = robot._from_keyboard_to_base_action(keyboard_keys)
# Get teleop action
# Arm
arm_action = leader_arm.get_action()
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
# Keyboard
keyboard_keys = keyboard.get_action()
base_action = robot._from_keyboard_to_base_action(keyboard_keys)
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
# Send action to robot
_ = robot.send_action(action)
# Send action to robot
_ = robot.send_action(action)
# Visualize
log_rerun_data(observation=observation, action=action)
# Visualize
log_rerun_data(observation=observation, action=action)
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
if __name__ == "__main__":
main()
busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))

View File

@@ -52,114 +52,125 @@ TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
# Create the robot configuration & robot
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
def main():
# Create the robot configuration & robot
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem58760434471",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
robot = SO100Follower(robot_config)
robot = SO100Follower(robot_config)
# Create policy
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
# Create policy
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert EE action to joints action
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=True,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Build pipeline to convert joints observation to EE observation
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[
ForwardKinematicsJointsToEE(
kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())
)
],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_DATASET_ID,
fps=FPS,
features=combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=robot_joints_to_ee_pose_processor,
initial_features=create_initial_features(observation=robot.observation_features),
use_videos=True,
),
# User for now should be explicit on the feature keys that were used for record
# Alternatively, the user can pass the processor step that has the right features
aggregate_pipeline_dataset_features(
pipeline=make_default_teleop_action_processor(),
initial_features=create_initial_features(
action={
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
}
),
use_videos=True,
),
# Build pipeline to convert EE action to joints action
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=True,
),
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Build pipeline to convert joints observation to EE observation
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_DATASET_ID,
fps=FPS,
features=combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=robot_joints_to_ee_pose_processor,
initial_features=create_initial_features(observation=robot.observation_features),
use_videos=True,
),
# User for now should be explicit on the feature keys that were used for record
# Alternatively, the user can pass the processor step that has the right features
aggregate_pipeline_dataset_features(
pipeline=make_default_teleop_action_processor(),
initial_features=create_initial_features(
action={
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
}
),
use_videos=True,
),
),
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Build Policy Processors
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
)
# Connect the robot
robot.connect()
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
init_rerun(session_name="phone_so100_evaluate")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Build Policy Processors
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
)
# Connect the robot
robot.connect()
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
init_rerun(session_name="phone_so100_evaluate")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
@@ -168,40 +179,21 @@ def main():
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Save episode
dataset.save_episode()
episode_idx += 1
# Save episode
dataset.save_episode()
episode_idx += 1
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":
main()
dataset.finalize()
dataset.push_to_hub()

View File

@@ -50,122 +50,133 @@ RESET_TIME_SEC = 30
TASK_DESCRIPTION = "My task description"
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
# Create the robot and teleoperator configurations
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem5A460814411",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
def main():
# Create the robot and teleoperator configurations
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem5A460814411",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
phone = Phone(teleop_config)
# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
phone = Phone(teleop_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert phone action to EE action
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[
tuple[RobotAction, RobotObservation], RobotAction
](
steps=[
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
EEReferenceAndDelta(
kinematics=kinematics_solver,
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
motor_names=list(robot.bus.motors.keys()),
use_latched_reference=True,
),
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.20,
),
GripperVelocityToJoint(speed_factor=20.0),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Build pipeline to convert EE action to joints action
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=True,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Build pipeline to convert joint observation to EE observation
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[
ForwardKinematicsJointsToEE(
kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())
)
],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_REPO_ID,
fps=FPS,
features=combine_feature_dicts(
# Run the feature contract of the pipelines
# This tells you how the features would look like after the pipeline steps
aggregate_pipeline_dataset_features(
pipeline=phone_to_robot_ee_pose_processor,
initial_features=create_initial_features(action=phone.action_features),
use_videos=True,
),
aggregate_pipeline_dataset_features(
pipeline=robot_joints_to_ee_pose,
initial_features=create_initial_features(observation=robot.observation_features),
use_videos=True,
),
# Build pipeline to convert phone action to EE action
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
EEReferenceAndDelta(
kinematics=kinematics_solver,
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
motor_names=list(robot.bus.motors.keys()),
use_latched_reference=True,
),
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.20,
),
GripperVelocityToJoint(speed_factor=20.0),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Build pipeline to convert EE action to joints action
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=True,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Build pipeline to convert joint observation to EE observation
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_REPO_ID,
fps=FPS,
features=combine_feature_dicts(
# Run the feature contract of the pipelines
# This tells you how the features would look like after the pipeline steps
aggregate_pipeline_dataset_features(
pipeline=phone_to_robot_ee_pose_processor,
initial_features=create_initial_features(action=phone.action_features),
use_videos=True,
),
aggregate_pipeline_dataset_features(
pipeline=robot_joints_to_ee_pose,
initial_features=create_initial_features(observation=robot.observation_features),
use_videos=True,
),
),
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Connect the robot and teleoperator
robot.connect()
phone.connect()
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
init_rerun(session_name="phone_so100_record")
if not robot.is_connected or not phone.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting record loop. Move your phone to teleoperate the robot...")
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
# Connect the robot and teleoperator
robot.connect()
phone.connect()
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
init_rerun(session_name="phone_so100_record")
if not robot.is_connected or not phone.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting record loop. Move your phone to teleoperate the robot...")
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
@@ -173,42 +184,22 @@ def main():
robot_observation_processor=robot_joints_to_ee_pose,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
teleop=phone,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=phone_to_robot_ee_pose_processor,
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose,
)
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Save episode
dataset.save_episode()
episode_idx += 1
# Save episode
dataset.save_episode()
episode_idx += 1
# Clean up
log_say("Stop recording")
robot.disconnect()
phone.disconnect()
listener.stop()
# Clean up
log_say("Stop recording")
robot.disconnect()
phone.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":
main()
dataset.finalize()
dataset.push_to_hub()

View File

@@ -29,78 +29,72 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
EPISODE_IDX = 0
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
# Initialize the robot config
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
)
def main():
# Initialize the robot config
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
)
# Initialize the robot
robot = SO100Follower(robot_config)
# Initialize the robot
robot = SO100Follower(robot_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert EE action to joints action
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=False, # Because replay is open loop
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Build pipeline to convert EE action to joints action
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=False, # Because replay is open loop
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Fetch the dataset to replay
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
actions = episode_frames.select_columns(ACTION)
# Fetch the dataset to replay
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
actions = episode_frames.select_columns(ACTION)
# Connect to the robot
robot.connect()
# Connect to the robot
robot.connect()
if not robot.is_connected:
raise ValueError("Robot is not connected!")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
# Get recorded action from dataset
ee_action = {
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get recorded action from dataset
ee_action = {
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get robot observation
robot_obs = robot.get_observation()
# Get robot observation
robot_obs = robot.get_observation()
# Dataset EE -> robot joints
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
# Dataset EE -> robot joints
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
# Send action to robot
_ = robot.send_action(joint_action)
# Send action to robot
_ = robot.send_action(joint_action)
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
# Clean up
robot.disconnect()
if __name__ == "__main__":
main()
# Clean up
robot.disconnect()

View File

@@ -32,90 +32,82 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.teleoperators.phone.teleop_phone import Phone
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
FPS = 30
# Initialize the robot and teleoperator
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
)
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
def main():
# Initialize the robot and teleoperator
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
)
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
teleop_device = Phone(teleop_config)
# Initialize the robot and teleoperator
robot = SO100Follower(robot_config)
teleop_device = Phone(teleop_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert phone action to ee pose action to joint action
phone_to_robot_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
EEReferenceAndDelta(
kinematics=kinematics_solver,
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
motor_names=list(robot.bus.motors.keys()),
use_latched_reference=True,
),
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
),
GripperVelocityToJoint(
speed_factor=20.0,
),
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=True,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Build pipeline to convert phone action to ee pose action to joint action
phone_to_robot_joints_processor = RobotProcessorPipeline[
tuple[RobotAction, RobotObservation], RobotAction
](
steps=[
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
EEReferenceAndDelta(
kinematics=kinematics_solver,
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
motor_names=list(robot.bus.motors.keys()),
use_latched_reference=True,
),
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
),
GripperVelocityToJoint(
speed_factor=20.0,
),
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=True,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Connect to the robot and teleoperator
robot.connect()
teleop_device.connect()
# Connect to the robot and teleoperator
robot.connect()
teleop_device.connect()
# Init rerun viewer
init_rerun(session_name="phone_so100_teleop")
# Init rerun viewer
init_rerun(session_name="phone_so100_teleop")
if not robot.is_connected or not teleop_device.is_connected:
raise ValueError("Robot or teleop is not connected!")
if not robot.is_connected or not teleop_device.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting teleop loop. Move your phone to teleoperate the robot...")
while True:
t0 = time.perf_counter()
print("Starting teleop loop. Move your phone to teleoperate the robot...")
while True:
t0 = time.perf_counter()
# Get robot observation
robot_obs = robot.get_observation()
# Get robot observation
robot_obs = robot.get_observation()
# Get teleop action
phone_obs = teleop_device.get_action()
# Get teleop action
phone_obs = teleop_device.get_action()
# Phone -> EE pose -> Joints transition
joint_action = phone_to_robot_joints_processor((phone_obs, robot_obs))
# Phone -> EE pose -> Joints transition
joint_action = phone_to_robot_joints_processor((phone_obs, robot_obs))
# Send action to robot
_ = robot.send_action(joint_action)
# Send action to robot
_ = robot.send_action(joint_action)
# Visualize
log_rerun_data(observation=phone_obs, action=joint_action)
# Visualize
log_rerun_data(observation=phone_obs, action=joint_action)
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
if __name__ == "__main__":
main()
busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))

View File

@@ -142,7 +142,7 @@ def _check_matplotlib_available():
raise ImportError(
"matplotlib is required for RTC debug visualizations. "
"Please install it by running:\n"
" uv pip install matplotlib"
" uv pip install -e '.[matplotlib-dep]'"
)

View File

@@ -52,114 +52,126 @@ TASK_DESCRIPTION = "My task description"
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
# Create the robot configuration & robot
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem5A460814411",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
def main():
# Create the robot configuration & robot
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem5A460814411",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
robot = SO100Follower(robot_config)
robot = SO100Follower(robot_config)
# Create policy
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
# Create policy
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert EE action to joints action
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=True,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Build pipeline to convert joints observation to EE observation
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[
ForwardKinematicsJointsToEE(
kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())
)
],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_DATASET_ID,
fps=FPS,
features=combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=robot_joints_to_ee_pose_processor,
initial_features=create_initial_features(observation=robot.observation_features),
use_videos=True,
),
# User for now should be explicit on the feature keys that were used for record
# Alternatively, the user can pass the processor step that has the right features
aggregate_pipeline_dataset_features(
pipeline=make_default_teleop_action_processor(),
initial_features=create_initial_features(
action={
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
}
),
use_videos=True,
),
# Build pipeline to convert EE action to joints action
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=True,
),
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Build pipeline to convert joints observation to EE observation
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_DATASET_ID,
fps=FPS,
features=combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=robot_joints_to_ee_pose_processor,
initial_features=create_initial_features(observation=robot.observation_features),
use_videos=True,
),
# User for now should be explicit on the feature keys that were used for record
# Alternatively, the user can pass the processor step that has the right features
aggregate_pipeline_dataset_features(
pipeline=make_default_teleop_action_processor(),
initial_features=create_initial_features(
action={
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
}
),
use_videos=True,
),
),
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Build Policy Processors
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
)
# Connect the robot and teleoperator
robot.connect()
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
init_rerun(session_name="so100_so100_evaluate")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Build Policy Processors
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
)
# Connect the robot and teleoperator
robot.connect()
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
init_rerun(session_name="so100_so100_evaluate")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting evaluate loop...")
episode_idx = 0
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
preprocessor=preprocessor, # Pass the pre and post policy processors
postprocessor=postprocessor,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
@@ -168,40 +180,21 @@ def main():
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=robot,
events=events,
fps=FPS,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_pose_processor,
)
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
if events["rerecord_episode"]:
log_say("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Save episode
dataset.save_episode()
episode_idx += 1
# Save episode
dataset.save_episode()
episode_idx += 1
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
# Clean up
log_say("Stop recording")
robot.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":
main()
dataset.finalize()
dataset.push_to_hub()

View File

@@ -48,122 +48,134 @@ RESET_TIME_SEC = 30
TASK_DESCRIPTION = "My task description"
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
# Create the robot and teleoperator configurations
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
follower_config = SO100FollowerConfig(
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", cameras=camera_config, use_degrees=True
)
leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm")
def main():
# Create the robot and teleoperator configurations
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
follower_config = SO100FollowerConfig(
port="/dev/tty.usbmodem5A460814411",
id="my_awesome_follower_arm",
cameras=camera_config,
use_degrees=True,
)
leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm")
# Initialize the robot and teleoperator
follower = SO100Follower(follower_config)
leader = SO100Leader(leader_config)
# Initialize the robot and teleoperator
follower = SO100Follower(follower_config)
leader = SO100Leader(leader_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
follower_kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(follower.bus.motors.keys()),
)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
follower_kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(follower.bus.motors.keys()),
)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
leader_kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(leader.bus.motors.keys()),
)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
leader_kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(leader.bus.motors.keys()),
)
# Build pipeline to convert follower joints to EE observation
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[
ForwardKinematicsJointsToEE(
kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys())
),
],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
# Build pipeline to convert leader joints to EE action
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
ForwardKinematicsJointsToEE(
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Build pipeline to convert EE action to follower joints
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
[
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
),
InverseKinematicsEEToJoints(
kinematics=follower_kinematics_solver,
motor_names=list(follower.bus.motors.keys()),
initial_guess_current_joints=True,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_REPO_ID,
fps=FPS,
features=combine_feature_dicts(
# Run the feature contract of the pipelines
# This tells you how the features would look like after the pipeline steps
aggregate_pipeline_dataset_features(
pipeline=leader_joints_to_ee,
initial_features=create_initial_features(action=leader.action_features),
use_videos=True,
),
aggregate_pipeline_dataset_features(
pipeline=follower_joints_to_ee,
initial_features=create_initial_features(observation=follower.observation_features),
use_videos=True,
),
# Build pipeline to convert follower joints to EE observation
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[
ForwardKinematicsJointsToEE(
kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys())
),
robot_type=follower.name,
use_videos=True,
image_writer_threads=4,
],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
# Build pipeline to convert leader joints to EE action
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
ForwardKinematicsJointsToEE(
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Build pipeline to convert EE action to follower joints
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
[
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
),
InverseKinematicsEEToJoints(
kinematics=follower_kinematics_solver,
motor_names=list(follower.bus.motors.keys()),
initial_guess_current_joints=True,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Create the dataset
dataset = LeRobotDataset.create(
repo_id=HF_REPO_ID,
fps=FPS,
features=combine_feature_dicts(
# Run the feature contract of the pipelines
# This tells you how the features would look like after the pipeline steps
aggregate_pipeline_dataset_features(
pipeline=leader_joints_to_ee,
initial_features=create_initial_features(action=leader.action_features),
use_videos=True,
),
aggregate_pipeline_dataset_features(
pipeline=follower_joints_to_ee,
initial_features=create_initial_features(observation=follower.observation_features),
use_videos=True,
),
),
robot_type=follower.name,
use_videos=True,
image_writer_threads=4,
)
# Connect the robot and teleoperator
leader.connect()
follower.connect()
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
init_rerun(session_name="recording_phone")
if not leader.is_connected or not follower.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting record loop...")
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
record_loop(
robot=follower,
events=events,
fps=FPS,
teleop=leader,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
)
# Connect the robot and teleoperator
leader.connect()
follower.connect()
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
init_rerun(session_name="recording_phone")
if not leader.is_connected or not follower.is_connected:
raise ValueError("Robot or teleop is not connected!")
print("Starting record loop...")
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
# Main record loop
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=follower,
events=events,
fps=FPS,
teleop=leader,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=leader_joints_to_ee,
@@ -171,42 +183,22 @@ def main():
robot_observation_processor=follower_joints_to_ee,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
record_loop(
robot=follower,
events=events,
fps=FPS,
teleop=leader,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=leader_joints_to_ee,
robot_action_processor=ee_to_follower_joints,
robot_observation_processor=follower_joints_to_ee,
)
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
if events["rerecord_episode"]:
log_say("Re-recording episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Save episode
dataset.save_episode()
episode_idx += 1
# Save episode
dataset.save_episode()
episode_idx += 1
# Clean up
log_say("Stop recording")
leader.disconnect()
follower.disconnect()
listener.stop()
# Clean up
log_say("Stop recording")
leader.disconnect()
follower.disconnect()
listener.stop()
dataset.finalize()
dataset.push_to_hub()
if __name__ == "__main__":
main()
dataset.finalize()
dataset.push_to_hub()

View File

@@ -30,78 +30,72 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
EPISODE_IDX = 0
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
# Initialize the robot config
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
)
def main():
# Initialize the robot config
robot_config = SO100FollowerConfig(
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
)
# Initialize the robot
robot = SO100Follower(robot_config)
# Initialize the robot
robot = SO100Follower(robot_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(robot.bus.motors.keys()),
)
# Build pipeline to convert EE action to joints action
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=False, # Because replay is open loop
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Build pipeline to convert EE action to joints action
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=list(robot.bus.motors.keys()),
initial_guess_current_joints=False, # Because replay is open loop
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Fetch the dataset to replay
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
actions = episode_frames.select_columns(ACTION)
# Fetch the dataset to replay
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
actions = episode_frames.select_columns(ACTION)
# Connect to the robot
robot.connect()
# Connect to the robot
robot.connect()
if not robot.is_connected:
raise ValueError("Robot is not connected!")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
# Get recorded action from dataset
ee_action = {
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get recorded action from dataset
ee_action = {
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get robot observation
robot_obs = robot.get_observation()
# Get robot observation
robot_obs = robot.get_observation()
# Dataset EE -> robot joints
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
# Dataset EE -> robot joints
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
# Send action to robot
_ = robot.send_action(joint_action)
# Send action to robot
_ = robot.send_action(joint_action)
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
# Clean up
robot.disconnect()
if __name__ == "__main__":
main()
# Clean up
robot.disconnect()

View File

@@ -32,96 +32,90 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
FPS = 30
# Initialize the robot and teleoperator config
follower_config = SO100FollowerConfig(
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
)
leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm")
def main():
# Initialize the robot and teleoperator config
follower_config = SO100FollowerConfig(
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
)
leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm")
# Initialize the robot and teleoperator
follower = SO100Follower(follower_config)
leader = SO100Leader(leader_config)
# Initialize the robot and teleoperator
follower = SO100Follower(follower_config)
leader = SO100Leader(leader_config)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
follower_kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(follower.bus.motors.keys()),
)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
follower_kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(follower.bus.motors.keys()),
)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
leader_kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(leader.bus.motors.keys()),
)
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
leader_kinematics_solver = RobotKinematics(
urdf_path="./SO101/so101_new_calib.urdf",
target_frame_name="gripper_frame_link",
joint_names=list(leader.bus.motors.keys()),
)
# Build pipeline to convert teleop joints to EE action
leader_to_ee = RobotProcessorPipeline[RobotAction, RobotAction](
steps=[
ForwardKinematicsJointsToEE(
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
),
],
to_transition=robot_action_to_transition,
to_output=transition_to_robot_action,
)
# Build pipeline to convert teleop joints to EE action
leader_to_ee = RobotProcessorPipeline[RobotAction, RobotAction](
steps=[
ForwardKinematicsJointsToEE(
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
),
],
to_transition=robot_action_to_transition,
to_output=transition_to_robot_action,
)
# build pipeline to convert EE action to robot joints
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
[
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
),
InverseKinematicsEEToJoints(
kinematics=follower_kinematics_solver,
motor_names=list(follower.bus.motors.keys()),
initial_guess_current_joints=False,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# build pipeline to convert EE action to robot joints
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
[
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
),
InverseKinematicsEEToJoints(
kinematics=follower_kinematics_solver,
motor_names=list(follower.bus.motors.keys()),
initial_guess_current_joints=False,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# Connect to the robot and teleoperator
follower.connect()
leader.connect()
# Connect to the robot and teleoperator
follower.connect()
leader.connect()
# Init rerun viewer
init_rerun(session_name="so100_so100_EE_teleop")
# Init rerun viewer
init_rerun(session_name="so100_so100_EE_teleop")
print("Starting teleop loop...")
while True:
t0 = time.perf_counter()
print("Starting teleop loop...")
while True:
t0 = time.perf_counter()
# Get robot observation
robot_obs = follower.get_observation()
# Get robot observation
robot_obs = follower.get_observation()
# Get teleop observation
leader_joints_obs = leader.get_action()
# Get teleop observation
leader_joints_obs = leader.get_action()
# teleop joints -> teleop EE action
leader_ee_act = leader_to_ee(leader_joints_obs)
# teleop joints -> teleop EE action
leader_ee_act = leader_to_ee(leader_joints_obs)
# teleop EE -> robot joints
follower_joints_act = ee_to_follower_joints((leader_ee_act, robot_obs))
# teleop EE -> robot joints
follower_joints_act = ee_to_follower_joints((leader_ee_act, robot_obs))
# Send action to robot
_ = follower.send_action(follower_joints_act)
# Send action to robot
_ = follower.send_action(follower_joints_act)
# Visualize
log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
# Visualize
log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
if __name__ == "__main__":
main()
busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))

View File

@@ -19,86 +19,80 @@ def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[flo
return [i / fps for i in delta_indices]
def main():
output_directory = Path("outputs/robot_learning_tutorial/act")
output_directory.mkdir(parents=True, exist_ok=True)
output_directory = Path("outputs/robot_learning_tutorial/act")
output_directory.mkdir(parents=True, exist_ok=True)
# Select your device
device = torch.device("mps") # or "cuda" or "cpu"
# Select your device
device = torch.device("mps") # or "cuda" or "cpu"
dataset_id = "lerobot/svla_so101_pickplace"
dataset_id = "lerobot/svla_so101_pickplace"
# This specifies the inputs the model will be expecting and the outputs it will produce
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)
# This specifies the inputs the model will be expecting and the outputs it will produce
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
cfg = ACTConfig(input_features=input_features, output_features=output_features)
policy = ACTPolicy(cfg)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
cfg = ACTConfig(input_features=input_features, output_features=output_features)
policy = ACTPolicy(cfg)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
policy.train()
policy.to(device)
policy.train()
policy.to(device)
# To perform action chunking, ACT expects a given number of actions as targets
delta_timestamps = {
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
}
# To perform action chunking, ACT expects a given number of actions as targets
delta_timestamps = {
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
}
# add image features if they are present
delta_timestamps |= {
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps)
for k in cfg.image_features
}
# add image features if they are present
delta_timestamps |= {
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
}
# Instantiate the dataset
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
# Instantiate the dataset
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
# Create the optimizer and dataloader for offline training
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
batch_size = 32
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=device.type != "cpu",
drop_last=True,
)
# Create the optimizer and dataloader for offline training
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
batch_size = 32
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=device.type != "cpu",
drop_last=True,
)
# Number of training steps and logging frequency
training_steps = 1
log_freq = 1
# Number of training steps and logging frequency
training_steps = 1
log_freq = 1
# Run training loop
step = 0
done = False
while not done:
for batch in dataloader:
batch = preprocessor(batch)
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
# Run training loop
step = 0
done = False
while not done:
for batch in dataloader:
batch = preprocessor(batch)
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
if step % log_freq == 0:
print(f"step: {step} loss: {loss.item():.3f}")
step += 1
if step >= training_steps:
done = True
break
if step % log_freq == 0:
print(f"step: {step} loss: {loss.item():.3f}")
step += 1
if step >= training_steps:
done = True
break
# Save the policy checkpoint, alongside the pre/post processors
policy.save_pretrained(output_directory)
preprocessor.save_pretrained(output_directory)
postprocessor.save_pretrained(output_directory)
# Save the policy checkpoint, alongside the pre/post processors
policy.save_pretrained(output_directory)
preprocessor.save_pretrained(output_directory)
postprocessor.save_pretrained(output_directory)
# Save all assets to the Hub
policy.push_to_hub("<user>/robot_learning_tutorial_act")
preprocessor.push_to_hub("<user>/robot_learning_tutorial_act")
postprocessor.push_to_hub("<user>/robot_learning_tutorial_act")
if __name__ == "__main__":
main()
# Save all assets to the Hub
policy.push_to_hub("fracapuano/robot_learning_tutorial_act")
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")

View File

@@ -8,56 +8,50 @@ from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "fracapuano/robot_learning_tutorial_act"
model = ACTPolicy.from_pretrained(model_id)
dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)
# # find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# # the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
def main():
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "<user>/robot_learning_tutorial_act"
model = ACTPolicy.from_pretrained(model_id)
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_metadata.features, device=device
)
# # find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
obs = preprocess(obs_frame)
# # the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
action = model.select_action(obs)
action = postprocess(action)
# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
action = make_robot_action(action, dataset_metadata.features)
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
robot.send_action(action)
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_metadata.features, device=device
)
obs = preprocess(obs_frame)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_metadata.features)
robot.send_action(action)
print("Episode finished! Starting new episode...")
if __name__ == "__main__":
main()
print("Episode finished! Starting new episode...")

View File

@@ -1,17 +1,11 @@
from lerobot.async_inference.configs import PolicyServerConfig
from lerobot.async_inference.policy_server import serve
host = ... # something like "127.0.0.1" if you're exposing to localhost
port = ... # something like 8080
def main():
host = ... # something like "127.0.0.1" if you're exposing to localhost
port = ... # something like 8080
config = PolicyServerConfig(
host=host,
port=port,
)
serve(config)
if __name__ == "__main__":
main()
config = PolicyServerConfig(
host=host,
port=port,
)
serve(config)

View File

@@ -6,56 +6,50 @@ from lerobot.async_inference.robot_client import RobotClient
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.robots.so100_follower import SO100FollowerConfig
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
# check the config.json on the Hub for the policy you are using to see the expected camera specs
camera_cfg = {
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
def main():
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
# check the config.json on the Hub for the policy you are using to see the expected camera specs
camera_cfg = {
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
# # find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# # find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# # the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
# # the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
server_address = ... # something like "127.0.0.1:8080" if using localhost
server_address = ... # something like "127.0.0.1:8080" if using localhost
# 3. Create client configuration
client_cfg = RobotClientConfig(
robot=robot_cfg,
server_address=server_address,
policy_device="mps",
policy_type="act",
pretrained_name_or_path="fracapuano/robot_learning_tutorial_act",
chunk_size_threshold=0.5, # g
actions_per_chunk=50, # make sure this is less than the max actions of the policy
)
# 3. Create client configuration
client_cfg = RobotClientConfig(
robot=robot_cfg,
server_address=server_address,
policy_device="mps",
policy_type="act",
pretrained_name_or_path="<user>/robot_learning_tutorial_act",
chunk_size_threshold=0.5, # g
actions_per_chunk=50, # make sure this is less than the max actions of the policy
)
# 4. Create and start client
client = RobotClient(client_cfg)
# 4. Create and start client
client = RobotClient(client_cfg)
# 5. Provide a textual description of the task
task = ...
# 5. Provide a textual description of the task
task = ...
if client.start():
# Start action receiver thread
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
action_receiver_thread.start()
if client.start():
# Start action receiver thread
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
action_receiver_thread.start()
try:
# Run the control loop
client.control_loop(task)
except KeyboardInterrupt:
client.stop()
action_receiver_thread.join()
# (Optionally) plot the action queue size
visualize_action_queue_size(client.action_queue_size)
if __name__ == "__main__":
main()
try:
# Run the control loop
client.control_loop(task)
except KeyboardInterrupt:
client.stop()
action_receiver_thread.join()
# (Optionally) plot the action queue size
visualize_action_queue_size(client.action_queue_size)

View File

@@ -19,87 +19,81 @@ def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[flo
return [i / fps for i in delta_indices]
def main():
output_directory = Path("outputs/robot_learning_tutorial/diffusion")
output_directory.mkdir(parents=True, exist_ok=True)
output_directory = Path("outputs/robot_learning_tutorial/diffusion")
output_directory.mkdir(parents=True, exist_ok=True)
# Select your device
device = torch.device("mps") # or "cuda" or "cpu"
# Select your device
device = torch.device("mps") # or "cuda" or "cpu"
dataset_id = "lerobot/svla_so101_pickplace"
dataset_id = "lerobot/svla_so101_pickplace"
# This specifies the inputs the model will be expecting and the outputs it will produce
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)
# This specifies the inputs the model will be expecting and the outputs it will produce
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
policy = DiffusionPolicy(cfg)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
policy = DiffusionPolicy(cfg)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
policy.train()
policy.to(device)
policy.train()
policy.to(device)
# To perform action chunking, ACT expects a given number of actions as targets
delta_timestamps = {
"observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
}
# To perform action chunking, ACT expects a given number of actions as targets
delta_timestamps = {
"observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
}
# add image features if they are present
delta_timestamps |= {
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps)
for k in cfg.image_features
}
# add image features if they are present
delta_timestamps |= {
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
}
# Instantiate the dataset
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
# Instantiate the dataset
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
# Create the optimizer and dataloader for offline training
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
batch_size = 32
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=device.type != "cpu",
drop_last=True,
)
# Create the optimizer and dataloader for offline training
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
batch_size = 32
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=device.type != "cpu",
drop_last=True,
)
# Number of training steps and logging frequency
training_steps = 1
log_freq = 1
# Number of training steps and logging frequency
training_steps = 1
log_freq = 1
# Run training loop
step = 0
done = False
while not done:
for batch in dataloader:
batch = preprocessor(batch)
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
# Run training loop
step = 0
done = False
while not done:
for batch in dataloader:
batch = preprocessor(batch)
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
if step % log_freq == 0:
print(f"step: {step} loss: {loss.item():.3f}")
step += 1
if step >= training_steps:
done = True
break
if step % log_freq == 0:
print(f"step: {step} loss: {loss.item():.3f}")
step += 1
if step >= training_steps:
done = True
break
# Save the policy checkpoint, alongside the pre/post processors
policy.save_pretrained(output_directory)
preprocessor.save_pretrained(output_directory)
postprocessor.save_pretrained(output_directory)
# Save the policy checkpoint, alongside the pre/post processors
policy.save_pretrained(output_directory)
preprocessor.save_pretrained(output_directory)
postprocessor.save_pretrained(output_directory)
# Save all assets to the Hub
policy.push_to_hub("<user>/robot_learning_tutorial_diffusion")
preprocessor.push_to_hub("<user>/robot_learning_tutorial_diffusion")
postprocessor.push_to_hub("<user>/robot_learning_tutorial_diffusion")
if __name__ == "__main__":
main()
# Save all assets to the Hub
policy.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")

View File

@@ -8,57 +8,53 @@ from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "fracapuano/robot_learning_tutorial_diffusion"
model = DiffusionPolicy.from_pretrained(model_id)
dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
preprocess, postprocess = make_pre_post_processors(
model.config, model_id, dataset_stats=dataset_metadata.stats
)
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
def main():
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "<user>/robot_learning_tutorial_diffusion"
# # find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
model = DiffusionPolicy.from_pretrained(model_id)
# # the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
preprocess, postprocess = make_pre_post_processors(
model.config, model_id, dataset_stats=dataset_metadata.stats
)
# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
# # find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# # the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_metadata.features, device=device
)
obs = preprocess(obs_frame)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_metadata.features)
robot.send_action(action)
print("Episode finished! Starting new episode...")
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
if __name__ == "__main__":
main()
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_metadata.features, device=device
)
obs = preprocess(obs_frame)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_metadata.features)
robot.send_action(action)
print("Episode finished! Starting new episode...")

View File

@@ -11,63 +11,57 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "lerobot/pi0_base"
def main():
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "lerobot/pi0_base"
model = PI0Policy.from_pretrained(model_id)
model = PI0Policy.from_pretrained(model_id)
preprocess, postprocess = make_pre_post_processors(
model.config,
model_id,
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
preprocessor_overrides={"device_processor": {"device": str(device)}},
)
preprocess, postprocess = make_pre_post_processors(
model.config,
model_id,
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
preprocessor_overrides={"device_processor": {"device": str(device)}},
)
# find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
# the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"base_0_rgb": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"left_wrist_0_rgb": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
"right_wrist_0_rgb": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=30),
}
# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"base_0_rgb": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"left_wrist_0_rgb": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
"right_wrist_0_rgb": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=30),
}
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
task = "" # something like "pick the red block"
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
task = "" # something like "pick the red block"
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
# This is used to match the raw observation keys to the keys expected by the policy
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# This is used to match the raw observation keys to the keys expected by the policy
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
)
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
)
obs = preprocess(obs_frame)
obs = preprocess(obs_frame)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_features)
robot.send_action(action)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_features)
robot.send_action(action)
print("Episode finished! Starting new episode...")
if __name__ == "__main__":
main()
print("Episode finished! Starting new episode...")

View File

@@ -20,8 +20,6 @@ from lerobot.teleoperators.utils import TeleopEvents
LOG_EVERY = 10
SEND_EVERY = 10
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
def run_learner(
@@ -225,123 +223,123 @@ def make_policy_obs(obs, device: torch.device = "cpu"):
}
def main():
"""Main function - coordinates actor and learner processes."""
"""Main function - coordinates actor and learner processes."""
device = "mps" # or "cuda" or "cpu"
output_directory = Path("outputs/robot_learning_tutorial/hil_serl")
output_directory.mkdir(parents=True, exist_ok=True)
device = "mps" # or "cuda" or "cpu"
output_directory = Path("outputs/robot_learning_tutorial/hil_serl")
output_directory.mkdir(parents=True, exist_ok=True)
# find ports using lerobot-find-port
follower_port = ...
leader_port = ...
# find ports using lerobot-find-port
follower_port = ...
leader_port = ...
# the robot ids are used the load the right calibration files
follower_id = ...
leader_id = ...
# the robot ids are used the load the right calibration files
follower_id = ...
leader_id = ...
# A pretrained model (to be used in-distribution!)
reward_classifier_id = "<user>/reward_classifier_hil_serl_example"
reward_classifier = Classifier.from_pretrained(reward_classifier_id)
# A pretrained model (to be used in-distribution!)
reward_classifier_id = "fracapuano/reward_classifier_hil_serl_example"
reward_classifier = Classifier.from_pretrained(reward_classifier_id)
reward_classifier.to(device)
reward_classifier.eval()
reward_classifier.to(device)
reward_classifier.eval()
# Robot and environment configuration
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id)
teleop_cfg = SO100LeaderConfig(port=leader_port, id=leader_id)
processor_cfg = HILSerlProcessorConfig(control_mode="leader")
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
env_cfg = HILSerlRobotEnvConfig(robot=robot_cfg, teleop=teleop_cfg, processor=processor_cfg)
# Robot and environment configuration
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id)
teleop_cfg = SO100LeaderConfig(port=leader_port, id=leader_id)
processor_cfg = HILSerlProcessorConfig(control_mode="leader")
# Create robot environment
env, teleop_device = make_robot_env(env_cfg)
env_cfg = HILSerlRobotEnvConfig(robot=robot_cfg, teleop=teleop_cfg, processor=processor_cfg)
obs_features = hw_to_dataset_features(env.robot.observation_features, "observation")
action_features = hw_to_dataset_features(env.robot.action_features, "action")
# Create robot environment
env, teleop_device = make_robot_env(env_cfg)
# Create SAC policy for action selection
policy_cfg = SACConfig(
device=device,
input_features=obs_features,
output_features=action_features,
)
obs_features = hw_to_dataset_features(env.robot.observation_features, "observation")
action_features = hw_to_dataset_features(env.robot.action_features, "action")
policy_actor = SACPolicy(policy_cfg)
policy_learner = SACPolicy(policy_cfg)
# Create SAC policy for action selection
policy_cfg = SACConfig(
device=device,
input_features=obs_features,
output_features=action_features,
)
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
policy_actor = SACPolicy(policy_cfg)
policy_learner = SACPolicy(policy_cfg)
# Online buffer: initialized from scratch
online_replay_buffer = ReplayBuffer(device=device, state_keys=list(obs_features.keys()))
# Offline buffer: Created from dataset (pre-populated it with demonstrations)
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=offline_dataset, device=device, state_keys=list(obs_features.keys())
)
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
# Create communication channels between learner and actor processes
transitions_queue = mp.Queue(maxsize=10)
parameters_queue = mp.Queue(maxsize=2)
shutdown_event = mp.Event()
# Online buffer: initialized from scratch
online_replay_buffer = ReplayBuffer(device=device, state_keys=list(obs_features.keys()))
# Offline buffer: Created from dataset (pre-populated it with demonstrations)
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=offline_dataset, device=device, state_keys=list(obs_features.keys())
)
# Signal handler for graceful shutdown
def signal_handler(sig):
print(f"\nSignal {sig} received, shutting down...")
shutdown_event.set()
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
# Create processes
learner_process = mp.Process(
target=run_learner,
args=(
transitions_queue,
parameters_queue,
shutdown_event,
policy_learner,
online_replay_buffer,
offline_replay_buffer,
),
kwargs={"device": device}, # can run on accelerated hardware for training
)
actor_process = mp.Process(
target=run_actor,
args=(
transitions_queue,
parameters_queue,
shutdown_event,
policy_actor,
reward_classifier,
env_cfg,
output_directory,
),
kwargs={"device": "cpu"}, # actor is frozen, can run on CPU or accelerate for inference
)
learner_process.start()
actor_process.start()
try:
# Wait for actor to finish (it controls the episode loop)
actor_process.join()
shutdown_event.set()
learner_process.join(timeout=10)
except KeyboardInterrupt:
print("Main process interrupted")
shutdown_event.set()
actor_process.join(timeout=5)
learner_process.join(timeout=10)
finally:
if learner_process.is_alive():
learner_process.terminate()
if actor_process.is_alive():
actor_process.terminate()
# Create communication channels between learner and actor processes
transitions_queue = mp.Queue(maxsize=10)
parameters_queue = mp.Queue(maxsize=2)
shutdown_event = mp.Event()
if __name__ == "__main__":
main()
# Signal handler for graceful shutdown
def signal_handler(sig):
print(f"\nSignal {sig} received, shutting down...")
shutdown_event.set()
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
# Create processes
learner_process = mp.Process(
target=run_learner,
args=(
transitions_queue,
parameters_queue,
shutdown_event,
policy_learner,
online_replay_buffer,
offline_replay_buffer,
),
kwargs={"device": device}, # can run on accelerated hardware for training
)
actor_process = mp.Process(
target=run_actor,
args=(
transitions_queue,
parameters_queue,
shutdown_event,
policy_actor,
reward_classifier,
env_cfg,
output_directory,
),
kwargs={"device": "cpu"}, # actor is frozen, can run on CPU or accelerate for inference
)
learner_process.start()
actor_process.start()
try:
# Wait for actor to finish (it controls the episode loop)
actor_process.join()
shutdown_event.set()
learner_process.join(timeout=10)
except KeyboardInterrupt:
print("Main process interrupted")
shutdown_event.set()
actor_process.join(timeout=5)
learner_process.join(timeout=10)
finally:
if learner_process.is_alive():
learner_process.terminate()
if actor_process.is_alive():
actor_process.terminate()

View File

@@ -4,64 +4,59 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
# Device to use for training
device = "mps" # or "cuda", or "cpu"
def main():
# Device to use for training
device = "mps" # or "cuda", or "cpu"
# Load the dataset used for training
repo_id = "lerobot/example_hil_serl_dataset"
dataset = LeRobotDataset(repo_id)
# Load the dataset used for training
repo_id = "lerobot/example_hil_serl_dataset"
dataset = LeRobotDataset(repo_id)
# Configure the policy to extract features from the image frames
camera_keys = dataset.meta.camera_keys
# Configure the policy to extract features from the image frames
camera_keys = dataset.meta.camera_keys
config = RewardClassifierConfig(
num_cameras=len(camera_keys),
device=device,
# backbone model to extract features from the image frames
model_name="microsoft/resnet-18",
)
config = RewardClassifierConfig(
num_cameras=len(camera_keys),
device=device,
# backbone model to extract features from the image frames
model_name="microsoft/resnet-18",
)
# Make policy, preprocessor, and optimizer
policy = make_policy(config, ds_meta=dataset.meta)
optimizer = config.get_optimizer_preset().build(policy.parameters())
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
classifier_id = "<user>/reward_classifier_hil_serl_example"
# Instantiate a dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
# Training loop
num_epochs = 5
for epoch in range(num_epochs):
total_loss = 0
total_accuracy = 0
for batch in dataloader:
# Preprocess the batch and move it to the correct device.
batch = preprocessor(batch)
# Forward pass
loss, output_dict = policy.forward(batch)
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
total_accuracy += output_dict["accuracy"]
avg_loss = total_loss / len(dataloader)
avg_accuracy = total_accuracy / len(dataloader)
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%")
print("Training finished!")
# You can now save the trained policy.
policy.push_to_hub(classifier_id)
# Make policy, preprocessor, and optimizer
policy = make_policy(config, ds_meta=dataset.meta)
optimizer = config.get_optimizer_preset().build(policy.parameters())
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
if __name__ == "__main__":
main()
classifier_id = "fracapuano/reward_classifier_hil_serl_example"
# Instantiate a dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
# Training loop
num_epochs = 5
for epoch in range(num_epochs):
total_loss = 0
total_accuracy = 0
for batch in dataloader:
# Preprocess the batch and move it to the correct device.
batch = preprocessor(batch)
# Forward pass
loss, output_dict = policy.forward(batch)
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
total_accuracy += output_dict["accuracy"]
avg_loss = total_loss / len(dataloader)
avg_accuracy = total_accuracy / len(dataloader)
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%")
print("Training finished!")
# You can now save the trained policy.
policy.push_to_hub(classifier_id)

View File

@@ -11,62 +11,56 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "lerobot/smolvla_base"
def main():
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "lerobot/smolvla_base"
model = SmolVLAPolicy.from_pretrained(model_id)
model = SmolVLAPolicy.from_pretrained(model_id)
preprocess, postprocess = make_pre_post_processors(
model.config,
model_id,
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
preprocessor_overrides={"device_processor": {"device": str(device)}},
)
preprocess, postprocess = make_pre_post_processors(
model.config,
model_id,
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
preprocessor_overrides={"device_processor": {"device": str(device)}},
)
# find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
# the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"camera1": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"camera2": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"camera1": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"camera2": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
task = "" # something like "pick the red block"
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
task = "" # something like "pick the red block"
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
# This is used to match the raw observation keys to the keys expected by the policy
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# This is used to match the raw observation keys to the keys expected by the policy
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
)
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
)
obs = preprocess(obs_frame)
obs = preprocess(obs_frame)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_features)
robot.send_action(action)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_features)
robot.send_action(action)
print("Episode finished! Starting new episode...")
if __name__ == "__main__":
main()
print("Episode finished! Starting new episode...")

View File

@@ -1,347 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example: GR00T Locomotion with Pre-loaded Policies
This example demonstrates the NEW pattern for loading GR00T policies externally
and passing them to the robot class.
"""
import argparse
import logging
import threading
import time
from collections import deque
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
logger = logging.getLogger(__name__)
GROOT_DEFAULT_ANGLES = np.zeros(29, dtype=np.float32)
GROOT_DEFAULT_ANGLES[[0, 6]] = -0.1 # hip pitch
GROOT_DEFAULT_ANGLES[[3, 9]] = 0.3 # knee
GROOT_DEFAULT_ANGLES[[4, 10]] = -0.2 # ankle pitch
MISSING_JOINTS = []
G1_MODEL = "g1_23" # or "g1_29"
if G1_MODEL == "g1_23":
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # waist yaw/pitch, wrist pitch/yaw
LOCOMOTION_ACTION_SCALE = 0.25
LOCOMOTION_CONTROL_DT = 0.02
ANG_VEL_SCALE: float = 0.25
DOF_POS_SCALE: float = 1.0
DOF_VEL_SCALE: float = 0.05
CMD_SCALE: list = [2.0, 2.0, 0.25]
DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1"
def load_groot_policies(
repo_id: str = DEFAULT_GROOT_REPO_ID,
) -> tuple[ort.InferenceSession, ort.InferenceSession]:
"""Load GR00T dual-policy system (Balance + Walk) from Hugging Face Hub.
Args:
repo_id: Hugging Face Hub repository ID containing the ONNX policies.
"""
logger.info(f"Loading GR00T dual-policy system from Hugging Face Hub ({repo_id})...")
# Download ONNX policies from Hugging Face Hub
balance_path = hf_hub_download(
repo_id=repo_id,
filename="GR00T-WholeBodyControl-Balance.onnx",
)
walk_path = hf_hub_download(
repo_id=repo_id,
filename="GR00T-WholeBodyControl-Walk.onnx",
)
# Load ONNX policies
policy_balance = ort.InferenceSession(balance_path)
policy_walk = ort.InferenceSession(walk_path)
logger.info("GR00T policies loaded successfully")
return policy_balance, policy_walk
class GrootLocomotionController:
"""
Handles GR00T-style locomotion control for the Unitree G1 robot.
This controller manages:
- Dual-policy system (Balance + Walk)
- 29-joint observation processing
- 15D action output (legs + waist)
- Policy inference and motor command generation
"""
def __init__(self, policy_balance, policy_walk, robot, config):
self.policy_balance = policy_balance
self.policy_walk = policy_walk
self.robot = robot
self.config = config
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot
# GR00T-specific state
self.groot_qj_all = np.zeros(29, dtype=np.float32)
self.groot_dqj_all = np.zeros(29, dtype=np.float32)
self.groot_action = np.zeros(15, dtype=np.float32)
self.groot_obs_single = np.zeros(86, dtype=np.float32)
self.groot_obs_history = deque(maxlen=6)
self.groot_obs_stacked = np.zeros(516, dtype=np.float32)
self.groot_height_cmd = 0.74 # Default base height
self.groot_orientation_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
# input to gr00t is 6 frames (6*86D=516)
for _ in range(6):
self.groot_obs_history.append(np.zeros(86, dtype=np.float32))
# Thread management
self.locomotion_running = False
self.locomotion_thread = None
logger.info("GrootLocomotionController initialized")
def groot_locomotion_run(self):
# get current observation
robot_state = self.robot.get_observation()
if robot_state is None:
return
# get command from remote controller
if robot_state.wireless_remote is not None:
self.robot.remote_controller.set(robot_state.wireless_remote)
if self.robot.remote_controller.button[0]: # R1 - raise waist
self.groot_height_cmd += 0.001
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
if self.robot.remote_controller.button[4]: # R2 - lower waist
self.groot_height_cmd -= 0.001
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
else:
self.robot.remote_controller.lx = 0.0
self.robot.remote_controller.ly = 0.0
self.robot.remote_controller.rx = 0.0
self.robot.remote_controller.ry = 0.0
self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward
self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right
self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # rotation rate
for i in range(29):
self.groot_qj_all[i] = robot_state.motor_state[i].q
self.groot_dqj_all[i] = robot_state.motor_state[i].dq
# adapt observation for g1_23dof
for idx in MISSING_JOINTS:
self.groot_qj_all[idx] = 0.0
self.groot_dqj_all[idx] = 0.0
# Scale joint positions and velocities
qj_obs = self.groot_qj_all.copy()
dqj_obs = self.groot_dqj_all.copy()
# express imu data in gravity frame of reference
quat = robot_state.imu_state.quaternion
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
gravity_orientation = self.robot.get_gravity_orientation(quat)
# scale joint positions and velocities before policy inference
qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE
dqj_obs = dqj_obs * DOF_VEL_SCALE
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
# build single frame observation
self.groot_obs_single[:3] = self.locomotion_cmd * np.array(CMD_SCALE)
self.groot_obs_single[3] = self.groot_height_cmd
self.groot_obs_single[4:7] = self.groot_orientation_cmd
self.groot_obs_single[7:10] = ang_vel_scaled
self.groot_obs_single[10:13] = gravity_orientation
self.groot_obs_single[13:42] = qj_obs
self.groot_obs_single[42:71] = dqj_obs
self.groot_obs_single[71:86] = self.groot_action # 15D previous actions
# Add to history and stack observations (6 frames × 86D = 516D)
self.groot_obs_history.append(self.groot_obs_single.copy())
# Stack all 6 frames into 516D vector
for i, obs_frame in enumerate(self.groot_obs_history):
start_idx = i * 86
end_idx = start_idx + 86
self.groot_obs_stacked[start_idx:end_idx] = obs_frame
# Run policy inference (ONNX) with 516D stacked observation
cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
selected_policy = (
self.policy_balance if cmd_magnitude < 0.05 else self.policy_walk
) # balance/standing policy for small commands, walking policy for movement commands
# run policy inference
ort_inputs = {selected_policy.get_inputs()[0].name: np.expand_dims(self.groot_obs_stacked, axis=0)}
ort_outs = selected_policy.run(None, ort_inputs)
self.groot_action = ort_outs[0].squeeze()
# transform action back to target joint positions
target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * LOCOMOTION_ACTION_SCALE
# command motors
for i in range(15):
motor_idx = i
self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos_15[i]
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
self.robot.msg.motor_cmd[motor_idx].tau = 0
# adapt action for g1_23dof
for joint_idx in MISSING_JOINTS:
self.robot.msg.motor_cmd[joint_idx].q = 0.0
self.robot.msg.motor_cmd[joint_idx].qd = 0
self.robot.msg.motor_cmd[joint_idx].kp = self.robot.kp[joint_idx]
self.robot.msg.motor_cmd[joint_idx].kd = self.robot.kd[joint_idx]
self.robot.msg.motor_cmd[joint_idx].tau = 0
# send action to robot
self.robot.send_action(self.robot.msg)
def _locomotion_thread_loop(self):
"""Background thread that runs the locomotion policy at specified rate."""
logger.info("Locomotion thread started")
while self.locomotion_running:
start_time = time.time()
try:
self.groot_locomotion_run()
except Exception as e:
logger.error(f"Error in locomotion loop: {e}")
# Sleep to maintain control rate
elapsed = time.time() - start_time
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
time.sleep(sleep_time)
logger.info("Locomotion thread stopped")
def start_locomotion_thread(self):
if self.locomotion_running:
logger.warning("Locomotion thread already running")
return
logger.info("Starting locomotion control thread...")
self.locomotion_running = True
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
self.locomotion_thread.start()
logger.info("Locomotion control thread started!")
def stop_locomotion_thread(self):
if not self.locomotion_running:
return
logger.info("Stopping locomotion control thread...")
self.locomotion_running = False
if self.locomotion_thread:
self.locomotion_thread.join(timeout=2.0)
logger.info("Locomotion control thread stopped")
def reset_robot(self):
"""Move robot legs to default standing position over 2 seconds (arms are not moved)."""
total_time = 3.0
num_step = int(total_time / self.robot.control_dt)
# Only control legs, not arms (first 12 joints)
default_pos = GROOT_DEFAULT_ANGLES # First 12 values are leg angles
dof_size = len(default_pos)
# Get current lowstate
robot_state = self.robot.get_observation()
# Record the current leg positions
init_dof_pos = np.zeros(dof_size, dtype=np.float32)
for i in range(dof_size):
init_dof_pos[i] = robot_state.motor_state[i].q
# Move legs to default pos
for i in range(num_step):
alpha = i / num_step
for motor_idx in range(dof_size):
target_pos = default_pos[motor_idx]
self.robot.msg.motor_cmd[motor_idx].q = (
init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha
)
self.robot.msg.motor_cmd[motor_idx].qd = 0
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
self.robot.msg.motor_cmd[motor_idx].tau = 0
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
self.robot.lowcmd_publisher.Write(self.robot.msg)
time.sleep(self.robot.control_dt)
logger.info("Reached default position (legs only)")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="GR00T Locomotion Controller for Unitree G1")
parser.add_argument(
"--repo-id",
type=str,
default=DEFAULT_GROOT_REPO_ID,
help=f"Hugging Face Hub repo ID for GR00T policies (default: {DEFAULT_GROOT_REPO_ID})",
)
args = parser.parse_args()
# load policies
policy_balance, policy_walk = load_groot_policies(repo_id=args.repo_id)
# initialize robot
config = UnitreeG1Config()
robot = UnitreeG1(config)
# initialize gr00t locomotion controller
groot_controller = GrootLocomotionController(
policy_balance=policy_balance,
policy_walk=policy_walk,
robot=robot,
config=config,
)
# reset legs and start locomotion thread
try:
groot_controller.reset_robot()
groot_controller.start_locomotion_thread()
# log status
logger.info("Robot initialized with GR00T locomotion policies")
logger.info("Locomotion controller running in background thread")
logger.info("Press Ctrl+C to stop")
# keep robot alive
while True:
time.sleep(1.0)
except KeyboardInterrupt:
print("\nStopping locomotion...")
groot_controller.stop_locomotion_thread()
print("Done!")

View File

@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
version = "0.4.3"
version = "0.4.2"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
readme = "README.md"
license = { text = "Apache-2.0" }
@@ -107,10 +107,6 @@ dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
unitree_g1 = [
"pyzmq>=26.2.1,<28.0.0",
"onnxruntime>=1.16.0"
]
reachy2 = ["reachy2_sdk>=1.0.14,<1.1.0"]
kinematics = ["lerobot[placo-dep]"]
intelrealsense = [
@@ -133,7 +129,6 @@ groot = [
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
xvla = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
@@ -162,7 +157,6 @@ all = [
"lerobot[pi]",
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
"lerobot[hilserl]",
"lerobot[async]",
"lerobot[dev]",
@@ -362,9 +356,9 @@ ignore_errors = false
# module = "lerobot.async_inference.*"
# ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.transport.*"
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.transport.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.scripts.*"

View File

@@ -136,40 +136,21 @@ def update_meta_data(
df["_orig_chunk"] = df[orig_chunk_col].copy()
df["_orig_file"] = df[orig_file_col].copy()
# Get mappings for this video key
# Update chunk and file indices to point to destination
df[orig_chunk_col] = video_idx["chunk"]
df[orig_file_col] = video_idx["file"]
# Apply per-source-file timestamp offsets
src_to_offset = video_idx.get("src_to_offset", {})
src_to_dst = video_idx.get("src_to_dst", {})
# Apply per-source-file mappings
if src_to_dst:
# Map each episode to its correct destination file and apply offset
if src_to_offset:
# Apply offset based on original source file
for idx in df.index:
# Convert to Python int to avoid numpy type mismatch in dict lookup
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
# Get destination chunk/file for this source file
dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
df.at[idx, orig_chunk_col] = dst_chunk
df.at[idx, orig_file_col] = dst_file
# Apply timestamp offset
offset = src_to_offset.get(src_key, 0)
df.at[idx, f"videos/{key}/from_timestamp"] += offset
df.at[idx, f"videos/{key}/to_timestamp"] += offset
elif src_to_offset:
# Fallback: use same destination for all, but apply per-file offsets
df[orig_chunk_col] = video_idx["chunk"]
df[orig_file_col] = video_idx["file"]
for idx in df.index:
# Convert to Python int to avoid numpy type mismatch in dict lookup
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
offset = src_to_offset.get(src_key, 0)
df.at[idx, f"videos/{key}/from_timestamp"] += offset
df.at[idx, f"videos/{key}/to_timestamp"] += offset
else:
# Fallback to simple offset (for backward compatibility)
df[orig_chunk_col] = video_idx["chunk"]
df[orig_file_col] = video_idx["file"]
df[f"videos/{key}/from_timestamp"] = (
df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
)
@@ -287,12 +268,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
videos_idx[key]["episode_duration"] = 0
# Track offset for each source (chunk, file) pair
videos_idx[key]["src_to_offset"] = {}
# Track destination (chunk, file) for each source (chunk, file) pair
videos_idx[key]["src_to_dst"] = {}
# Initialize dst_file_durations if not present
# dst_file_durations tracks duration of each destination file
if "dst_file_durations" not in videos_idx[key]:
videos_idx[key]["dst_file_durations"] = {}
for key, video_idx in videos_idx.items():
unique_chunk_file_pairs = {
@@ -307,13 +282,9 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
chunk_idx = video_idx["chunk"]
file_idx = video_idx["file"]
dst_file_durations = video_idx["dst_file_durations"]
current_offset = video_idx["latest_duration"]
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
# Convert to Python int to ensure consistent dict keys
src_chunk_idx = int(src_chunk_idx)
src_file_idx = int(src_file_idx)
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
video_key=key,
chunk_index=src_chunk_idx,
@@ -327,17 +298,14 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
)
src_duration = get_video_duration_in_s(src_path)
dst_key = (chunk_idx, file_idx)
if not dst_path.exists():
# New destination file: offset is 0
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
# Store offset before incrementing
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(src_path), str(dst_path))
# Track duration of this destination file
dst_file_durations[dst_key] = src_duration
videos_idx[key]["episode_duration"] += src_duration
current_offset += src_duration
continue
# Check file sizes before appending
@@ -345,11 +313,10 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
dst_size = get_file_size_in_mb(dst_path)
if dst_size + src_size >= video_files_size_in_mb:
# Rotate to a new file - offset is 0
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
dst_key = (chunk_idx, file_idx)
# Rotate to a new file, this source becomes start of new destination
# So its offset should be 0
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
video_key=key,
chunk_index=chunk_idx,
@@ -357,20 +324,16 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
)
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(src_path), str(dst_path))
# Track duration of this new destination file
dst_file_durations[dst_key] = src_duration
# Reset offset for next file
current_offset = src_duration
else:
# Append to existing destination file
# Offset is the current duration of this destination file
current_dst_duration = dst_file_durations.get(dst_key, 0)
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
# Append to existing video file - use current accumulated offset
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
concatenate_video_files(
[dst_path, src_path],
dst_path,
)
# Update duration of this destination file
dst_file_durations[dst_key] = current_dst_duration + src_duration
current_offset += src_duration
videos_idx[key]["episode_duration"] += src_duration

View File

@@ -110,8 +110,8 @@ def worker_thread_loop(queue: queue.Queue):
if item is None:
queue.task_done()
break
image_array, fpath, compress_level = item
write_image(image_array, fpath, compress_level)
image_array, fpath = item
write_image(image_array, fpath)
queue.task_done()
@@ -169,13 +169,11 @@ class AsyncImageWriter:
p.start()
self.processes.append(p)
def save_image(
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1
):
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
if isinstance(image, torch.Tensor):
# Convert tensor to numpy array to minimize main process time
image = image.cpu().numpy()
self.queue.put((image, fpath, compress_level))
self.queue.put((image, fpath))
def wait_until_done(self):
self.queue.join()

View File

@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
import contextlib
import logging
import shutil
@@ -540,15 +539,6 @@ class LeRobotDatasetMetadata:
return obj
def _encode_video_worker(video_key: str, episode_index: int, root: Path, fps: int) -> Path:
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
img_dir = (root / fpath).parent
encode_video_frames(img_dir, temp_path, fps, overwrite=True)
shutil.rmtree(img_dir)
return temp_path
class LeRobotDataset(torch.utils.data.Dataset):
def __init__(
self,
@@ -722,15 +712,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.download(download_videos)
self.hf_dataset = self.load_hf_dataset()
# Create mapping from absolute indices to relative indices when only a subset of the episodes are loaded
# Build a mapping: absolute_index -> relative_index_in_filtered_dataset
self._absolute_to_relative_idx = None
if self.episodes is not None:
self._absolute_to_relative_idx = {
abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx
for rel_idx, abs_idx in enumerate(self.hf_dataset["index"])
}
# Setup delta_indices
if self.delta_timestamps is not None:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
@@ -849,7 +830,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
features = get_hf_features_from_features(self.features)
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
hf_dataset = load_nested_dataset(self.root / "data", features=features)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
@@ -866,8 +847,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Determine requested episodes
if self.episodes is None:
# Requesting all episodes - check if we have all episodes from metadata
requested_episodes = set(range(self.meta.total_episodes))
else:
# Requesting specific episodes
requested_episodes = set(self.episodes)
# Check if all requested episodes are available in cached data
@@ -949,11 +932,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_timestamps = {}
for key in self.meta.video_keys:
if query_indices is not None and key in query_indices:
if self._absolute_to_relative_idx is not None:
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
timestamps = self.hf_dataset[relative_indices]["timestamp"]
else:
timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist()
else:
query_timestamps[key] = [current_ts]
@@ -976,16 +955,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
for key, q_idx in query_indices.items():
if key in self.meta.video_keys:
continue
# Map absolute indices to relative indices if needed
relative_indices = (
q_idx
if self._absolute_to_relative_idx is None
else [self._absolute_to_relative_idx[idx] for idx in q_idx]
)
try:
result[key] = torch.stack(self.hf_dataset[key][relative_indices])
result[key] = torch.stack(self.hf_dataset[key][q_idx])
except (KeyError, TypeError, IndexError):
result[key] = torch.stack(self.hf_dataset[relative_indices][key])
result[key] = torch.stack(self.hf_dataset[q_idx][key])
return result
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
@@ -1081,7 +1054,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
return ep_buffer
# TODO(Steven): consider move this to utils
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
fpath = DEFAULT_IMAGE_PATH.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index
@@ -1091,15 +1063,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
def _save_image(
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1
) -> None:
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
if self.image_writer is None:
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
write_image(image, fpath, compress_level=compress_level)
write_image(image, fpath)
else:
self.image_writer.save_image(image=image, fpath=fpath, compress_level=compress_level)
self.image_writer.save_image(image=image, fpath=fpath)
def add_frame(self, frame: dict) -> None:
"""
@@ -1137,19 +1107,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
if frame_index == 0:
img_path.parent.mkdir(parents=True, exist_ok=True)
compress_level = 1 if self.features[key]["dtype"] == "video" else 6
self._save_image(frame[key], img_path, compress_level)
self._save_image(frame[key], img_path)
self.episode_buffer[key].append(str(img_path))
else:
self.episode_buffer[key].append(frame[key])
self.episode_buffer["size"] += 1
def save_episode(
self,
episode_data: dict | None = None,
parallel_encoding: bool = True,
) -> None:
def save_episode(self, episode_data: dict | None = None) -> None:
"""
This will save to disk the current episode in self.episode_buffer.
@@ -1161,8 +1126,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
None.
parallel_encoding (bool, optional): If True, encode videos in parallel using ProcessPoolExecutor.
Defaults to True on Linux, False on macOS as it tends to use all the CPU available already.
"""
episode_buffer = episode_data if episode_data is not None else self.episode_buffer
@@ -1199,40 +1162,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
use_batched_encoding = self.batch_encoding_size > 1
if has_video_keys and not use_batched_encoding:
num_cameras = len(self.meta.video_keys)
if parallel_encoding and num_cameras > 1:
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
# num_cameras * num_threads = (total_cpu -1)
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor:
future_to_key = {
executor.submit(
_encode_video_worker,
video_key,
episode_index,
self.root,
self.fps,
): video_key
for video_key in self.meta.video_keys
}
results = {}
for future in concurrent.futures.as_completed(future_to_key):
video_key = future_to_key[future]
try:
temp_path = future.result()
results[video_key] = temp_path
except Exception as exc:
logging.error(f"Video encoding failed for {video_key}: {exc}")
raise exc
for video_key in self.meta.video_keys:
temp_path = results[video_key]
ep_metadata.update(
self._save_episode_video(video_key, episode_index, temp_path=temp_path)
)
else:
for video_key in self.meta.video_keys:
ep_metadata.update(self._save_episode_video(video_key, episode_index))
for video_key in self.meta.video_keys:
ep_metadata.update(self._save_episode_video(video_key, episode_index))
# `meta.save_episode` need to be executed after encoding the videos
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
@@ -1397,18 +1328,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
return metadata
def _save_episode_video(
self,
video_key: str,
episode_index: int,
temp_path: Path | None = None,
) -> dict:
def _save_episode_video(self, video_key: str, episode_index: int) -> dict:
# Encode episode frames into a temporary video
if temp_path is None:
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
else:
ep_path = temp_path
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_video_duration_in_s(ep_path)
@@ -1526,7 +1448,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
since video encoding with ffmpeg is already using multithreading.
"""
return _encode_video_worker(video_key, episode_index, self.root, self.fps)
temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4"
img_dir = self._get_image_file_dir(episode_index, video_key)
encode_video_frames(img_dir, temp_path, self.fps, overwrite=True)
shutil.rmtree(img_dir)
return temp_path
@classmethod
def create(
@@ -1572,7 +1498,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_transforms = None
obj.delta_timestamps = None
obj.delta_indices = None
obj._absolute_to_relative_idx = None
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj.writer = None
obj.latest_episode = None

View File

@@ -28,7 +28,6 @@ import numpy as np
import packaging.version
import pandas
import pandas as pd
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq
import torch
from datasets import Dataset
@@ -49,7 +48,7 @@ from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_strin
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 500 # Max size per file
INFO_PATH = "meta/info.json"
STATS_PATH = "meta/stats.json"
@@ -104,9 +103,7 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -
return chunk_idx, file_idx
def load_nested_dataset(
pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None
) -> Dataset:
def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) -> Dataset:
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
Concatenate all pyarrow references to return HF Dataset format
@@ -114,26 +111,15 @@ def load_nested_dataset(
Args:
pq_dir: Directory containing parquet files
features: Optional features schema to ensure consistent loading of complex types like images
episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency.
"""
paths = sorted(pq_dir.glob("*/*.parquet"))
if len(paths) == 0:
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
with SuppressProgressBars():
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
# PyArrow loads the entire dataset into memory
if episodes is None:
return Dataset.from_parquet([str(path) for path in paths], features=features)
arrow_dataset = pa_ds.dataset(paths, format="parquet")
filter_expr = pa_ds.field("episode_index").isin(episodes)
table = arrow_dataset.to_table(filter=filter_expr)
if features is not None:
table = table.cast(features.arrow_schema)
return Dataset(table)
datasets = Dataset.from_parquet([str(path) for path in paths], features=features)
return datasets
def get_parquet_num_frames(parquet_path: str | Path) -> int:

View File

@@ -311,7 +311,6 @@ def encode_video_frames(
fast_decode: int = 0,
log_level: int | None = av.logging.ERROR,
overwrite: bool = False,
preset: int | None = None,
) -> None:
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
# Check encoder availability
@@ -360,9 +359,6 @@ def encode_video_frames(
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
video_options[key] = value
if vcodec == "libsvtav1":
video_options["preset"] = str(preset) if preset is not None else "12"
# Set logging level
if log_level is not None:
# "While less efficient, it is generally preferable to modify logging with Python's logging"

View File

@@ -21,22 +21,7 @@ import draccus
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.robots import RobotConfig
from lerobot.teleoperators.config import TeleoperatorConfig
from lerobot.utils.constants import (
ACTION,
LIBERO_KEY_EEF_MAT,
LIBERO_KEY_EEF_POS,
LIBERO_KEY_EEF_QUAT,
LIBERO_KEY_GRIPPER_QPOS,
LIBERO_KEY_GRIPPER_QVEL,
LIBERO_KEY_JOINTS_POS,
LIBERO_KEY_JOINTS_VEL,
LIBERO_KEY_PIXELS_AGENTVIEW,
LIBERO_KEY_PIXELS_EYE_IN_HAND,
OBS_ENV_STATE,
OBS_IMAGE,
OBS_IMAGES,
OBS_STATE,
)
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
@dataclass
@@ -245,7 +230,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
class LiberoEnv(EnvConfig):
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
fps: int = 30
episode_length: int | None = None
episode_length: int = 520
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
@@ -261,62 +246,28 @@ class LiberoEnv(EnvConfig):
features_map: dict[str, str] = field(
default_factory=lambda: {
ACTION: ACTION,
LIBERO_KEY_EEF_POS: f"{OBS_STATE}.eef_pos",
LIBERO_KEY_EEF_QUAT: f"{OBS_STATE}.eef_quat",
LIBERO_KEY_EEF_MAT: f"{OBS_STATE}.eef_mat",
LIBERO_KEY_GRIPPER_QPOS: f"{OBS_STATE}.gripper_qpos",
LIBERO_KEY_GRIPPER_QVEL: f"{OBS_STATE}.gripper_qvel",
LIBERO_KEY_JOINTS_POS: f"{OBS_STATE}.joint_pos",
LIBERO_KEY_JOINTS_VEL: f"{OBS_STATE}.joint_vel",
LIBERO_KEY_PIXELS_AGENTVIEW: f"{OBS_IMAGES}.image",
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
"agent_pos": OBS_STATE,
"pixels/agentview_image": f"{OBS_IMAGES}.image",
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
}
)
control_mode: str = "relative" # or "absolute"
def __post_init__(self):
if self.obs_type == "pixels":
self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
self.features["pixels/agentview_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
elif self.obs_type == "pixels_agent_pos":
self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
self.features["pixels/agentview_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features[LIBERO_KEY_EEF_POS] = PolicyFeature(
type=FeatureType.STATE,
shape=(3,),
)
self.features[LIBERO_KEY_EEF_QUAT] = PolicyFeature(
type=FeatureType.STATE,
shape=(4,),
)
self.features[LIBERO_KEY_EEF_MAT] = PolicyFeature(
type=FeatureType.STATE,
shape=(3, 3),
)
self.features[LIBERO_KEY_GRIPPER_QPOS] = PolicyFeature(
type=FeatureType.STATE,
shape=(2,),
)
self.features[LIBERO_KEY_GRIPPER_QVEL] = PolicyFeature(
type=FeatureType.STATE,
shape=(2,),
)
self.features[LIBERO_KEY_JOINTS_POS] = PolicyFeature(
type=FeatureType.STATE,
shape=(7,),
)
self.features[LIBERO_KEY_JOINTS_VEL] = PolicyFeature(
type=FeatureType.STATE,
shape=(7,),
)
else:
raise ValueError(f"Unsupported obs_type: {self.obs_type}")

View File

@@ -14,18 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import Any
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -39,46 +33,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
raise ValueError(f"Policy type '{env_type}' is not available.")
def make_env_pre_post_processors(
env_cfg: EnvConfig,
policy_cfg: PreTrainedConfig,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
"""
Create preprocessor and postprocessor pipelines for environment observations.
This function creates processor pipelines that transform raw environment
observations and actions. By default, it returns identity processors that do nothing.
For specific environments like LIBERO, it adds environment-specific processing steps.
Args:
env_cfg: The configuration of the environment.
Returns:
A tuple containing:
- preprocessor: Pipeline that processes environment observations
- postprocessor: Pipeline that processes environment outputs (currently identity)
"""
# Preprocessor and Postprocessor steps are Identity for most environments
preprocessor_steps: list[ProcessorStep] = []
postprocessor_steps: list[ProcessorStep] = []
if isinstance(policy_cfg, XVLAConfig):
from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
return make_xvla_libero_pre_post_processors()
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor_steps.append(LiberoProcessorStep())
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
return preprocessor, postprocessor
def make_env(
cfg: EnvConfig | str,
n_envs: int = 1,
@@ -143,8 +97,6 @@ def make_env(
init_states=cfg.init_states,
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
control_mode=cfg.control_mode,
episode_length=cfg.episode_length,
)
elif "metaworld" in cfg.type:
from lerobot.envs.metaworld import create_metaworld_envs

View File

@@ -28,6 +28,7 @@ import torch
from gymnasium import spaces
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.transform_utils import quat2axisangle
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
@@ -80,7 +81,10 @@ def get_libero_dummy_action():
return [0, 0, 0, 0, 0, 0, -1]
OBS_STATE_DIM = 8
ACTION_DIM = 7
AGENT_POS_LOW = -1000.0
AGENT_POS_HIGH = 1000.0
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
TASK_SUITE_MAX_STEPS: dict[str, int] = {
@@ -100,7 +104,6 @@ class LiberoEnv(gym.Env):
task_suite: Any,
task_id: int,
task_suite_name: str,
episode_length: int | None = None,
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
obs_type: str = "pixels",
render_mode: str = "rgb_array",
@@ -112,7 +115,6 @@ class LiberoEnv(gym.Env):
episode_index: int = 0,
camera_name_mapping: dict[str, str] | None = None,
num_steps_wait: int = 10,
control_mode: str = "relative",
):
super().__init__()
self.task_id = task_id
@@ -140,19 +142,14 @@ class LiberoEnv(gym.Env):
self.camera_name_mapping = camera_name_mapping
self.num_steps_wait = num_steps_wait
self.episode_index = episode_index
self.episode_length = episode_length
# Load once and keep
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._env = self._make_envs_task(task_suite, self.task_id)
default_steps = 500
self._max_episode_steps = (
TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
if self.episode_length is None
else self.episode_length
)
self.control_mode = control_mode
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
images = {}
for cam in self.camera_name:
images[self.camera_name_mapping[cam]] = spaces.Box(
@@ -178,36 +175,11 @@ class LiberoEnv(gym.Env):
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(images),
"robot_state": spaces.Dict(
{
"eef": spaces.Dict(
{
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float64),
"quat": spaces.Box(
low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64
),
"mat": spaces.Box(
low=-np.inf, high=np.inf, shape=(3, 3), dtype=np.float64
),
}
),
"gripper": spaces.Dict(
{
"qpos": spaces.Box(
low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64
),
"qvel": spaces.Box(
low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64
),
}
),
"joints": spaces.Dict(
{
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
"vel": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
}
),
}
"agent_pos": spaces.Box(
low=AGENT_POS_LOW,
high=AGENT_POS_HIGH,
shape=(OBS_STATE_DIM,),
dtype=np.float64,
),
}
)
@@ -219,7 +191,6 @@ class LiberoEnv(gym.Env):
def render(self):
raw_obs = self._env.env._get_observations()
image = self._format_raw_obs(raw_obs)["pixels"]["image"]
image = image[::-1, ::-1] # flip both H and W for visualization
return image
def _make_envs_task(self, task_suite: Any, task_id: int = 0):
@@ -241,48 +212,23 @@ class LiberoEnv(gym.Env):
images = {}
for camera_name in self.camera_name:
image = raw_obs[camera_name]
image = image[::-1, ::-1] # rotate 180 degrees
images[self.camera_name_mapping[camera_name]] = image
eef_pos = raw_obs.get("robot0_eef_pos")
eef_quat = raw_obs.get("robot0_eef_quat")
# rotation matrix from controller
eef_mat = self._env.robots[0].controller.ee_ori_mat if eef_pos is not None else None
gripper_qpos = raw_obs.get("robot0_gripper_qpos")
gripper_qvel = raw_obs.get("robot0_gripper_qvel")
joint_pos = raw_obs.get("robot0_joint_pos")
joint_vel = raw_obs.get("robot0_joint_vel")
obs = {
"pixels": images,
"robot_state": {
"eef": {
"pos": eef_pos, # (3,)
"quat": eef_quat, # (4,)
"mat": eef_mat, # (3, 3)
},
"gripper": {
"qpos": gripper_qpos, # (2,)
"qvel": gripper_qvel, # (2,)
},
"joints": {
"pos": joint_pos, # (7,)
"vel": joint_vel, # (7,)
},
},
}
state = np.concatenate(
(
raw_obs["robot0_eef_pos"],
quat2axisangle(raw_obs["robot0_eef_quat"]),
raw_obs["robot0_gripper_qpos"],
)
)
agent_pos = state
if self.obs_type == "pixels":
return {"pixels": images.copy()}
if self.obs_type == "pixels_agent_pos":
# Validate required fields are present
if eef_pos is None or eef_quat is None or gripper_qpos is None:
raise ValueError(
f"Missing required robot state fields in raw observation. "
f"Got eef_pos={eef_pos is not None}, eef_quat={eef_quat is not None}, "
f"gripper_qpos={gripper_qpos is not None}"
)
return obs
return {
"pixels": images.copy(),
"agent_pos": agent_pos,
}
raise NotImplementedError(
f"The observation type '{self.obs_type}' is not supported in LiberoEnv. "
"Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
@@ -300,15 +246,6 @@ class LiberoEnv(gym.Env):
# Increasing this value can improve determinism and reproducibility across resets.
for _ in range(self.num_steps_wait):
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
if self.control_mode == "absolute":
for robot in self._env.robots:
robot.controller.use_delta = False
elif self.control_mode == "relative":
for robot in self._env.robots:
robot.controller.use_delta = True
else:
raise ValueError(f"Invalid control mode: {self.control_mode}")
observation = self._format_raw_obs(raw_obs)
info = {"is_success": False}
return observation, info
@@ -354,10 +291,8 @@ def _make_env_fns(
task_id: int,
n_envs: int,
camera_names: list[str],
episode_length: int | None,
init_states: bool,
gym_kwargs: Mapping[str, Any],
control_mode: str,
) -> list[Callable[[], LiberoEnv]]:
"""Build n_envs factory callables for a single (suite, task_id)."""
@@ -369,9 +304,7 @@ def _make_env_fns(
task_suite_name=suite_name,
camera_name=camera_names,
init_states=init_states,
episode_length=episode_length,
episode_index=episode_index,
control_mode=control_mode,
**local_kwargs,
)
@@ -391,8 +324,6 @@ def create_libero_envs(
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
init_states: bool = True,
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
control_mode: str = "relative",
episode_length: int | None = None,
) -> dict[str, dict[int, Any]]:
"""
Create vectorized LIBERO environments with a consistent return shape.
@@ -424,24 +355,24 @@ def create_libero_envs(
print(f"Restricting to task_ids={task_ids_filter}")
out: dict[str, dict[int, Any]] = defaultdict(dict)
for suite_name in suite_names:
suite = _get_suite(suite_name)
total = len(suite.tasks)
selected = _select_task_ids(total, task_ids_filter)
if not selected:
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
for tid in selected:
fns = _make_env_fns(
suite=suite,
episode_length=episode_length,
suite_name=suite_name,
task_id=tid,
n_envs=n_envs,
camera_names=camera_names,
init_states=init_states,
gym_kwargs=gym_kwargs,
control_mode=control_mode,
)
out[suite_name][tid] = env_cls(fns)
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")

View File

@@ -29,22 +29,10 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.utils.utils import get_channel_first_image_shape
def _convert_nested_dict(d):
result = {}
for k, v in d.items():
if isinstance(v, dict):
result[k] = _convert_nested_dict(v)
elif isinstance(v, np.ndarray):
result[k] = torch.from_numpy(v)
else:
result[k] = v
return result
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
"""Convert environment observation to LeRobot format observation.
@@ -90,14 +78,12 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return_observations[OBS_ENV_STATE] = env_state
if "agent_pos" in observations:
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations[OBS_STATE] = agent_pos
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations[OBS_STATE] = agent_pos
if "robot_state" in observations:
return_observations[f"{OBS_STR}.robot_state"] = _convert_nested_dict(observations["robot_state"])
return return_observations

View File

@@ -104,107 +104,6 @@ class SGDConfig(OptimizerConfig):
return torch.optim.SGD(params, **kwargs)
@OptimizerConfig.register_subclass("xvla-adamw")
@dataclass
class XVLAAdamWConfig(OptimizerConfig):
"""Custom AdamW optimizer for XVLA with differential learning rates.
The Vision-Language Model (VLM) is trained with 1/10 of the base learning rate
for stable optimization, while all other components use the full LR.
This LR ratio is crucial for achieving strong and stable finetuning performance.
Soft-prompts can optionally use a separate learning rate with warm-up support.
Set `soft_prompt_lr_scale` to a value < 1.0 (e.g., 0.1) to start soft-prompts
at a lower LR. Combine with a warmup scheduler for optimal results.
Note:
Completely matching official reported performance may require an additional
warm-up LR schedule for soft-prompts, which can bring minor improvements.
When `soft_prompt_warmup_lr_scale` is set, soft-prompts start at
`lr * soft_prompt_warmup_lr_scale` and should be warmed up via the scheduler.
Parameter Groups:
- Group 0 (vlm): VLM parameters at lr * 0.1, weight_decay * 0.1
- Group 1 (soft_prompts): Soft-prompt parameters at lr * soft_prompt_lr_scale
- Group 2 (other): All other parameters at full lr
"""
lr: float = 1e-4
betas: tuple[float, float] = (0.9, 0.99)
eps: float = 1e-8
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
# Soft-prompt specific settings
soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR (1.0 = same as base LR)
soft_prompt_warmup_lr_scale: float | None = None # If set, start soft-prompts at this scale (e.g., 0.01)
def build(self, params: dict) -> torch.optim.Optimizer:
"""
Build AdamW optimizer with differential learning rates.
Expects `named_parameters()` as input (dict of name -> param).
Applies:
- lr * 0.1 for all VLM-related parameters
- lr * soft_prompt_lr_scale for soft-prompt parameters (with optional warmup)
- full lr for all other parameters
Args:
params: Dictionary of parameter names to parameters (from named_parameters())
Returns:
AdamW optimizer with parameter groups for VLM, soft-prompts, and other components
"""
assert isinstance(params, dict), "Custom LR optimizer requires `named_parameters()` as inputs."
vlm_group, soft_prompt_group, other_group = [], [], []
for name, p in params.items():
if not p.requires_grad:
continue
if "vlm" in name.lower():
vlm_group.append(p)
elif "soft_prompt" in name.lower():
soft_prompt_group.append(p)
else:
other_group.append(p)
# Determine soft-prompt LR
soft_prompt_lr = self.lr * self.soft_prompt_lr_scale
if self.soft_prompt_warmup_lr_scale is not None:
# Start at warmup scale, scheduler will warm up to soft_prompt_lr
soft_prompt_lr = self.lr * self.soft_prompt_warmup_lr_scale
param_groups = [
{
"params": vlm_group,
"lr": self.lr * 0.1,
"weight_decay": self.weight_decay * 0.1,
"name": "vlm",
},
{
"params": soft_prompt_group,
"lr": soft_prompt_lr,
"weight_decay": self.weight_decay,
"name": "soft_prompts",
},
{
"params": other_group,
"lr": self.lr,
"weight_decay": self.weight_decay,
"name": "other",
},
]
# Filter out empty groups
param_groups = [g for g in param_groups if len(g["params"]) > 0]
return torch.optim.AdamW(
param_groups,
betas=self.betas,
eps=self.eps,
)
@OptimizerConfig.register_subclass("multi_adam")
@dataclass
class MultiAdamConfig(OptimizerConfig):

View File

@@ -21,7 +21,6 @@ from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
__all__ = [
"ACTConfig",
@@ -32,5 +31,4 @@ __all__ = [
"TDMPCConfig",
"VQBeTConfig",
"GrootConfig",
"XVLAConfig",
]

View File

@@ -16,7 +16,6 @@
from __future__ import annotations
import importlib
import logging
from typing import Any, TypedDict
@@ -41,7 +40,6 @@ from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import validate_visual_features_consistency
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor.converters import (
batch_to_transition,
@@ -109,15 +107,8 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.groot.modeling_groot import GrootPolicy
return GrootPolicy
elif name == "xvla":
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
return XVLAPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
except Exception as e:
raise ValueError(f"Policy type '{name}' is not available.") from e
raise NotImplementedError(f"Policy with name {name} is not implemented.")
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
@@ -159,14 +150,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return RewardClassifierConfig(**kwargs)
elif policy_type == "groot":
return GrootConfig(**kwargs)
elif policy_type == "xvla":
return XVLAConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
return config_cls(**kwargs)
except Exception as e:
raise ValueError(f"Policy type '{policy_type}' is not available.") from e
raise ValueError(f"Policy type '{policy_type}' is not available.")
class ProcessorConfigKwargs(TypedDict, total=False):
@@ -344,24 +329,9 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, XVLAConfig):
from lerobot.policies.xvla.processor_xvla import (
make_xvla_pre_post_processors,
)
processors = make_xvla_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_policy_config(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
except Exception as e:
raise ValueError(f"Processor for policy type '{policy_cfg.type}' is not implemented.") from e
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
return processors
@@ -430,7 +400,8 @@ def make_policy(
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
features = env_to_policy_features(env_cfg)
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
if not cfg.output_features:
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
@@ -454,65 +425,3 @@ def make_policy(
# TODO: (jadechoghari) - add a check_state(cfg, features) and check_action(cfg, features)
return policy
def _get_policy_cls_from_policy_name(name: str) -> type[PreTrainedConfig]:
"""Get policy class from its registered name using dynamic imports.
This is used as a helper function to import policies from 3rd party lerobot plugins.
Args:
name: The name of the policy.
Returns:
The policy class corresponding to the given name.
"""
if name not in PreTrainedConfig.get_known_choices():
raise ValueError(
f"Unknown policy name '{name}'. Available policies: {PreTrainedConfig.get_known_choices()}"
)
config_cls = PreTrainedConfig.get_choice_class(name)
config_cls_name = config_cls.__name__
model_name = config_cls_name.removesuffix("Config") # e.g., DiffusionConfig -> Diffusion
if model_name == config_cls_name:
raise ValueError(
f"The config class name '{config_cls_name}' does not follow the expected naming convention."
f"Make sure it ends with 'Config'!"
)
cls_name = model_name + "Policy" # e.g., DiffusionConfig -> DiffusionPolicy
module_path = config_cls.__module__.replace(
"configuration_", "modeling_"
) # e.g., configuration_diffusion -> modeling_diffusion
module = importlib.import_module(module_path)
policy_cls = getattr(module, cls_name)
return policy_cls
def _make_processors_from_policy_config(
config: PreTrainedConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[Any, Any]:
"""Create pre- and post-processors from a policy configuration using dynamic imports.
This is used as a helper function to import processor factories from 3rd party lerobot plugins.
Args:
config: The policy configuration object.
dataset_stats: Dataset statistics for normalization.
Returns:
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
"""
policy_type = config.type
function_name = f"make_{policy_type}_pre_post_processors"
module_path = config.__class__.__module__.replace(
"configuration_", "processor_"
) # e.g., configuration_diffusion -> processor_diffusion
logging.debug(
f"Instantiating pre/post processors using function '{function_name}' from module '{module_path}'"
)
module = importlib.import_module(module_path)
function = getattr(module, function_name)
return function(config, dataset_stats=dataset_stats)

View File

@@ -812,13 +812,16 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
)
dt = -1.0 / num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
return self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
@@ -843,11 +846,15 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
else:
v_t = denoise_step_partial_call(x_t)
x_t = x_t + dt * v_t
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
def denoise_step(

View File

@@ -538,8 +538,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
if config.compile_model:
torch.set_float32_matmul_precision("high")
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
# Also compile the main forward pass used during training
self.forward = torch.compile(self.forward, mode=config.compile_mode)
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
@@ -787,13 +785,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
)
dt = -1.0 / num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
return self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
@@ -817,11 +818,15 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
else:
v_t = denoise_step_partial_call(x_t)
x_t = x_t + dt * v_t
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
def denoise_step(

View File

@@ -783,15 +783,18 @@ class VLAFlowMatching(nn.Module):
use_cache=self.config.use_cache,
fill_kv_cache=True,
)
num_steps = self.config.num_steps
dt = -1.0 / num_steps
dt = -1.0 / self.config.num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
time = torch.tensor(1.0, dtype=torch.float32, device=device)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
while time >= -dt / 2:
expanded_time = time.expand(bsize)
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
return self.denoise_step(
x_t=input_x_t,
prefix_pad_masks=prefix_pad_masks,
@@ -815,11 +818,15 @@ class VLAFlowMatching(nn.Module):
else:
v_t = denoise_step_partial_call(x_t)
x_t = x_t + dt * v_t
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt
return x_t
def denoise_step(

View File

@@ -1,6 +0,0 @@
# register the processor steps
from lerobot.policies.xvla.processor_xvla import (
XVLAAddDomainIdProcessorStep,
XVLAImageNetNormalizeProcessorStep,
XVLAImageToFloatProcessorStep,
)

View File

@@ -1,588 +0,0 @@
# ------------------------------------------------------------------------------
# Copyright 2025 2toINF and HuggingFace Inc. (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from __future__ import annotations
from collections.abc import Iterable
import torch
import torch.nn as nn
# =============================================================================
# Registry
# =============================================================================
ACTION_REGISTRY: dict[str, type[BaseActionSpace]] = {}
def register_action(name: str):
"""Decorator for registering a new action space."""
def _wrap(cls):
key = name.lower()
if key in ACTION_REGISTRY:
raise KeyError(f"ActionSpace '{key}' already registered -> {ACTION_REGISTRY[key]}")
ACTION_REGISTRY[key] = cls
cls.name = key
return cls
return _wrap
def build_action_space(name: str, **kwargs) -> BaseActionSpace:
"""Instantiate a registered action space by name."""
key = name.lower()
if key not in ACTION_REGISTRY:
raise KeyError(f"Unknown action space '{name}'. Available: {list(ACTION_REGISTRY.keys())}")
return ACTION_REGISTRY[key](**kwargs)
# =============================================================================
# Base class
# =============================================================================
class BaseActionSpace(nn.Module):
"""
Abstract base class for all action-space definitions.
Each subclass defines:
- `dim_action`: dimension of the action vector.
- `gripper_idx`: indices of gripper channels.
- `compute_loss(pred, target)`: supervised loss for this space.
- `preprocess(proprio, action, mode)`: pre-step modifications.
- `postprocess(action)`: post-step corrections (e.g. apply sigmoid).
"""
name: str = "base"
dim_action: int = 0
gripper_idx: tuple[int, ...] = ()
def __init__(self):
super().__init__()
# ---------------------------------------------------------------------
# Core supervised loss
# ---------------------------------------------------------------------
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
raise NotImplementedError
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
"""Alias for compute_loss."""
return self.compute_loss(pred, target)
# ---------------------------------------------------------------------
# Space-level hooks
# ---------------------------------------------------------------------
def preprocess(
self,
proprio: torch.Tensor,
action: torch.Tensor,
mode: str = "train",
) -> tuple[torch.Tensor, torch.Tensor]:
"""Default: return unchanged."""
return proprio, action
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""Default: return unchanged."""
return action
# =============================================================================
# Utilities
# =============================================================================
def _ensure_indices_valid(dim_action: int, idx: Iterable[int], name: str) -> None:
bad = [i for i in idx if i < 0 or i >= dim_action]
if bad:
raise IndexError(f"{name} contains out-of-range indices {bad} for action dim dim_action={dim_action}")
# =============================================================================
# Implementations
# =============================================================================
@register_action("ee6d")
class EE6DActionSpace(BaseActionSpace):
"""End-effector layout with xyz, 6D rotation, and gripper channels."""
dim_action = 20
gripper_idx = (9, 19)
GRIPPER_SCALE = 1.0
XYZ_SCALE = 500.0
ROT_SCALE = 10.0
POS_IDX_1 = (0, 1, 2)
POS_IDX_2 = (10, 11, 12)
ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
self.bce = nn.BCEWithLogitsLoss()
def compute_loss(self, pred, target):
assert pred.shape == target.shape, "pred/target shapes must match"
batch_size, seq_len, action_dim = pred.shape
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
# Gripper BCE
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
# XYZ position
pos_loss = (
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
) * self.XYZ_SCALE
# Rotation 6D
rot_loss = (
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
) * self.ROT_SCALE
return {
"position_loss": pos_loss,
"rotate6D_loss": rot_loss,
"gripper_loss": gripper_loss,
}
def preprocess(self, proprio, action, mode="train"):
"""Zero-out gripper channels in proprio/action."""
proprio_m = proprio.clone()
action_m = action.clone()
proprio_m[..., self.gripper_idx] = 0.0
action_m[..., self.gripper_idx] = 0.0
return proprio_m, action_m
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""Apply sigmoid to gripper logits."""
if action.size(-1) > max(self.gripper_idx):
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
return action
@register_action("joint")
class JointActionSpace(BaseActionSpace):
"""Joint-space layout with joints + gripper only."""
dim_action = 14
gripper_idx = (6, 13)
GRIPPER_SCALE = 0.1
JOINTS_SCALE = 1.0
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
self.bce = nn.BCEWithLogitsLoss()
def compute_loss(self, pred, target):
assert pred.shape == target.shape
batch_size, seq_len, action_dim = pred.shape
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE
joints_idx = tuple(i for i in range(action_dim) if i not in set(self.gripper_idx))
joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE
return {
"joints_loss": joints_loss,
"gripper_loss": gripper_loss,
}
def preprocess(self, proprio, action, mode="train"):
"""Zero-out gripper channels in proprio/action."""
proprio_m = proprio.clone()
action_m = action.clone()
proprio_m[..., self.gripper_idx] = 0.0
action_m[..., self.gripper_idx] = 0.0
return proprio_m, action_m
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""Apply sigmoid to gripper logits."""
if action.size(-1) > max(self.gripper_idx):
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
return action
@register_action("agibot_ee6d")
class AGIBOTEE6DActionSpace(BaseActionSpace):
"""AGI-bot variant of EE6DActionSpace using MSE for all components."""
dim_action = 20
gripper_idx = (9, 19)
GRIPPER_SCALE = 10.0
XYZ_SCALE = 500.0
ROT_SCALE = 10.0
POS_IDX_1 = (0, 1, 2)
POS_IDX_2 = (10, 11, 12)
ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
ROT_IDX_2 = (13, 14, 15, 16, 17, 18)
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
def compute_loss(self, pred, target):
assert pred.shape == target.shape
batch_size, seq_len, action_dim = pred.shape
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")
gripper_loss = (
self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
)
pos_loss = (
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
) * self.XYZ_SCALE
rot_loss = (
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
) * self.ROT_SCALE
return {
"position_loss": pos_loss,
"rotate6D_loss": rot_loss,
"gripper_loss": gripper_loss,
}
def preprocess(self, proprio, action, mode="train"):
"""No preprocessing applied in AGIBOT variant."""
return proprio, action
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""AGIBOT does not postprocess."""
return action
@register_action("franka_joint7")
class FrankaJoint7ActionSpace(BaseActionSpace):
"""
Franka Panda joint-space: 7 joints, with gripper.
- Real robot action dim: 7
- Model-facing dim: 20 (padded with zeros)
compatible with pretrained VLA models expecting 20D.
"""
dim_action = 20 # model dimension
REAL_DIM = 7 # actual Franka joints
JOINTS_SCALE = 1.0
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
"""Pad 7 → 20 dims (zeros for the dummy channels)."""
if x is None:
return None
if x.size(-1) == self.dim_action:
return x
if x.size(-1) != self.REAL_DIM:
raise ValueError(
f"Expected last dim to be {self.REAL_DIM} or {self.dim_action}, got {x.size(-1)}"
)
pad_shape = list(x.shape[:-1]) + [self.dim_action - self.REAL_DIM] # 13 zeros
pad = x.new_zeros(pad_shape)
return torch.cat([x, pad], dim=-1)
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
"""Trim model output 20 → 7 dims."""
return x[..., : self.REAL_DIM]
def compute_loss(self, pred, target):
"""
pred : [B, T, 20]
target : [B, T, 7] or [B, T, 20]
Only compute MSE on the first 7 dims.
"""
pred = self._pad_to_model_dim(pred)
target = self._pad_to_model_dim(target)
assert pred.shape == target.shape
joints_loss = (
self.mse(
pred[:, :, : self.REAL_DIM], # use only the first 7 joints
target[:, :, : self.REAL_DIM],
)
* self.JOINTS_SCALE
)
return {"joints_loss": joints_loss}
def preprocess(self, proprio, action, mode="train"):
"""
During training:
- Pad [7] → [20]
"""
return proprio, self._pad_to_model_dim(action)
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""
After model prediction:
- Trim [20] → [7] for real robot control.
"""
return self._trim_to_real_dim(action)
@register_action("auto")
class AutoActionSpace(BaseActionSpace):
"""
Auto-detecting action space that adapts to any action dimension.
- Auto-detects the real action dimension from the policy feature
- Model outputs max_dim for compatibility with pretrained models
- Loss is computed only on the first real_dim dimensions
- Postprocess trims output back to real_dim
Args:
real_dim: The actual action dimension from the dataset/policy feature
max_dim: The model's output dimension for pretrained VLA compatibility
"""
JOINTS_SCALE = 1.0
def __init__(self, real_dim: int, max_dim: int):
super().__init__()
self.real_dim = real_dim
self.dim_action = max_dim # Model-facing dimension
self.mse = nn.MSELoss()
def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
"""Pad real_dim → max_dim (zeros for the dummy channels)."""
if x is None:
return None
if x.size(-1) == self.dim_action:
return x
if x.size(-1) != self.real_dim:
# If dimension doesn't match either, pad/trim to real_dim first
if x.size(-1) < self.real_dim:
pad_shape = list(x.shape[:-1]) + [self.real_dim - x.size(-1)]
pad = x.new_zeros(pad_shape)
x = torch.cat([x, pad], dim=-1)
else:
x = x[..., : self.real_dim]
pad_shape = list(x.shape[:-1]) + [self.dim_action - self.real_dim]
pad = x.new_zeros(pad_shape)
return torch.cat([x, pad], dim=-1)
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
"""Trim model output max_dim → real_dim."""
return x[..., : self.real_dim]
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
"""
Compute loss only on the first real_dim dimensions.
pred: [B, T, max_dim] from the model
target: [B, T, real_dim] or [B, T, max_dim]
Loss = MSE(pred[:,:,:real_dim], target[:,:,:real_dim])
"""
pred = self._pad_to_model_dim(pred)
target = self._pad_to_model_dim(target)
assert pred.shape == target.shape, f"Shape mismatch: pred {pred.shape} vs target {target.shape}"
# only compute loss on the real dimensions
joints_loss = (
self.mse(
pred[:, :, : self.real_dim],
target[:, :, : self.real_dim],
)
* self.JOINTS_SCALE
)
return {"joints_loss": joints_loss}
def preprocess(self, proprio: torch.Tensor, action: torch.Tensor, mode: str = "train"):
"""
Pad action from real_dim to max_dim for the model.
"""
return proprio, self._pad_to_model_dim(action)
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""
Trim model output from max_dim to real_dim for real robot control.
"""
return self._trim_to_real_dim(action)
@register_action("so101_bimanual")
class BimanualSO101ActionSpace(BaseActionSpace):
"""
Bimanual SO101 robot: 2 arms with 5 joints each + gripper.
Layout (real robot):
[left_arm (5 joints + gripper), right_arm (5 joints + gripper)]
- Left arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
- Right arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
Real action dim: 12
Model-facing dim: 20 (extra 8 dummy dims at the end)
"""
# Model output / training dimension (to match pretrained policy)
dim_action = 20
# Real robot action dimension
REAL_DIM = 12
# Indices of real vs dummy channels
REAL_IDXS = tuple(range(REAL_DIM)) # 0..11
DUMMY_IDXS = tuple(range(REAL_DIM, dim_action)) # 12..19
# Grippers live in the real part
gripper_idx = (5, 11) # left_gripper at idx 5, right_gripper at idx 11
GRIPPER_SCALE = 1.0
JOINTS_SCALE = 1.0
# Indices for left and right arm joints (excluding grippers)
LEFT_ARM_JOINTS = (0, 1, 2, 3, 4)
RIGHT_ARM_JOINTS = (6, 7, 8, 9, 10)
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
self.bce = nn.BCEWithLogitsLoss()
# ---------- helpers ----------
def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
"""If last dim is REAL_DIM (12), pad zeros to reach dim_action (20)."""
if x is None:
return None
if x.size(-1) == self.dim_action:
return x
if x.size(-1) != self.REAL_DIM:
raise ValueError(
f"Expected last dim to be {self.REAL_DIM} or {self.dim_action}, got {x.size(-1)}"
)
pad_shape = list(x.shape[:-1]) + [self.dim_action - self.REAL_DIM]
pad = x.new_zeros(pad_shape)
return torch.cat([x, pad], dim=-1)
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
"""Keep only the first REAL_DIM (12) dims for the real robot."""
return x[..., : self.REAL_DIM]
# ---------- loss ----------
def compute_loss(self, pred, target):
"""
pred: [B, T, 20] from the model
target: [B, T, 12] or [B, T, 20]
We pad target → 20 and compute loss only on the real dims.
"""
# Ensure both are [B, T, 20]
pred = self._pad_to_model_dim(pred)
target = self._pad_to_model_dim(target)
assert pred.shape == target.shape
# ---- MSE for all real dims (011) ----
real_dims = 12
joints_loss = (
self.mse(
pred[:, :, :real_dims],
target[:, :, :real_dims],
)
* self.JOINTS_SCALE
)
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12])
gripper_loss = (
self.mse(
pred[:, :, [5, 11]],
target[:, :, [5, 11]],
)
* self.GRIPPER_SCALE
)
return {
"joints_loss": joints_loss,
"gripper_loss": gripper_loss,
"left_arm_loss": left_arm_loss,
"right_arm_loss": right_arm_loss,
}
# ---------- preprocess / postprocess ----------
def preprocess(self, proprio, action, mode="train"):
"""
- If proprio/action are 12-dim, pad them to 20 for the model.
- Zero-out gripper channels in proprio/action to focus learning on joints.
"""
proprio_m = self._pad_to_model_dim(proprio.clone())
action_m = self._pad_to_model_dim(action.clone()) if action is not None else None
proprio_m[..., self.gripper_idx] = 0.0
if action_m is not None:
action_m[..., self.gripper_idx] = 0.0
return proprio_m, action_m
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""
- Model outputs [*, 20]
- Apply sigmoid to gripper logits
- Return only the first 12 dims for the real robot:
["left_shoulder_pan.pos",
"left_shoulder_lift.pos",
"left_elbow_flex.pos",
"left_wrist_flex.pos",
"left_wrist_roll.pos",
"left_gripper.pos",
"right_shoulder_pan.pos",
"right_shoulder_lift.pos",
"right_elbow_flex.pos",
"right_wrist_flex.pos",
"right_wrist_roll.pos",
"right_gripper.pos"]
"""
# Ensure we at least have the real dims + grippers
if action.size(-1) < self.REAL_DIM:
raise ValueError(f"Expected at least {self.REAL_DIM} dims in action, got {action.size(-1)}")
# Apply sigmoid on gripper channels in model space (indices 5 and 11)
if action.size(-1) > max(self.gripper_idx):
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
# Return only the real 12-dim control vector for the env
return self._trim_to_real_dim(action)
# =============================================================================
# Exports
# =============================================================================
__all__ = [
"BaseActionSpace",
"build_action_space",
"register_action",
"EE6DActionSpace",
"JointActionSpace",
"AGIBOTEE6DActionSpace",
"FrankaJoint7ActionSpace",
"AutoActionSpace",
"BimanualSO101ActionSpace",
"ACTION_REGISTRY",
]

View File

@@ -1,353 +0,0 @@
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
""" Florence-2 configuration"""
logger = logging.get_logger(__name__)
class Florence2VisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Florence2VisionModel architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
drop_path_rate (`float`, *optional*, defaults to 0.1):
The dropout rate of the drop path layer.
patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
The patch size of the image.
patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
The patch stride of the image.
patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
The patch padding of the image.
patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]):
Whether to apply layer normalization before the patch embedding layer.
enable_checkpoint (`bool`, *optional*, defaults to False):
Whether to enable checkpointing.
dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
The dimension of the embedding layer.
num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
The number of attention heads.
num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
The number of groups.
depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
The depth of the model.
window_size (`int`, *optional*, defaults to 12):
The window size of the model.
projection_dim (`int`, *optional*, defaults to 1024):
The dimension of the projection layer.
visual_temporal_embedding (`dict`, *optional*):
The configuration of the visual temporal embedding.
image_pos_embed (`dict`, *optional*):
The configuration of the image position embedding.
image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
The source of the image feature.
Example:
```python
>>> from transformers import Florence2VisionConfig, Florence2VisionModel
>>> # Initializing a Florence2 Vision style configuration
>>> configuration = Florence2VisionConfig()
>>> # Initializing a model (with random weights)
>>> model = Florence2VisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "davit"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
drop_path_rate=0.1,
patch_size=None,
patch_stride=None,
patch_padding=None,
patch_prenorm=None,
enable_checkpoint=False,
dim_embed=None,
num_heads=None,
num_groups=None,
depths=None,
window_size=12,
projection_dim=1024,
visual_temporal_embedding=None,
image_pos_embed=None,
image_feature_source=None,
**kwargs,
):
self.drop_path_rate = drop_path_rate
self.patch_size = patch_size if patch_size is not None else [7, 3, 3, 3]
self.patch_stride = patch_stride if patch_stride is not None else [4, 2, 2, 2]
self.patch_padding = patch_padding if patch_padding is not None else [3, 1, 1, 1]
self.patch_prenorm = patch_prenorm if patch_prenorm is not None else [False, True, True, True]
self.enable_checkpoint = enable_checkpoint
self.dim_embed = dim_embed if dim_embed is not None else [256, 512, 1024, 2048]
self.num_heads = num_heads if num_heads is not None else [8, 16, 32, 64]
self.num_groups = num_groups if num_groups is not None else [8, 16, 32, 64]
self.depths = depths if depths is not None else [1, 1, 9, 1]
self.window_size = window_size
self.projection_dim = projection_dim
if visual_temporal_embedding is None:
visual_temporal_embedding = {
"type": "COSINE",
"max_temporal_embeddings": 100,
}
self.visual_temporal_embedding = visual_temporal_embedding
if image_pos_embed is None:
image_pos_embed = {
"type": "learned_abs_2d",
"max_pos_embeddings": 1000,
}
self.image_pos_embed = image_pos_embed
self.image_feature_source = (
image_feature_source
if image_feature_source is not None
else ["spatial_avg_pool", "temporal_avg_pool"]
)
super().__init__(**kwargs)
class Florence2LanguageConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the BART
[facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 51289):
Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Florence2LanguageModel`].
d_model (`int`, *optional*, defaults to 1024):
Dimensionality of the layers and the pooler layer.
encoder_layers (`int`, *optional*, defaults to 12):
Number of encoder layers.
decoder_layers (`int`, *optional*, defaults to 12):
Number of decoder layers.
encoder_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
decoder_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"silu"` and `"gelu_new"` are supported.
dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
activation_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
classifier_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for classifier.
max_position_embeddings (`int`, *optional*, defaults to 1024):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
init_std (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
for more details.
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
for more details.
scale_embedding (`bool`, *optional*, defaults to `False`):
Scale embeddings by diving by sqrt(d_model).
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models).
num_labels (`int`, *optional*, defaults to 3):
The number of labels to use in [`Florence2LanguageForSequenceClassification`].
forced_eos_token_id (`int`, *optional*, defaults to 2):
The id of the token to force as the last generated token when `max_length` is reached. Usually set to
`eos_token_id`.
Example:
```python
>>> from transformers import Florence2LanguageConfig, Florence2LanguageModel
>>> # Initializing a Florence2 Language style configuration
>>> configuration = Florence2LanguageConfig()
>>> # Initializing a model (with random weights)
>>> model = Florence2LanguageModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "florence2_language"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
def __init__(
self,
vocab_size=51289,
max_position_embeddings=1024,
encoder_layers=12,
encoder_ffn_dim=4096,
encoder_attention_heads=16,
decoder_layers=12,
decoder_ffn_dim=4096,
decoder_attention_heads=16,
encoder_layerdrop=0.0,
decoder_layerdrop=0.0,
activation_function="gelu",
d_model=1024,
dropout=0.1,
attention_dropout=0.0,
activation_dropout=0.0,
init_std=0.02,
classifier_dropout=0.0,
scale_embedding=False,
use_cache=True,
num_labels=3,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
is_encoder_decoder=True,
decoder_start_token_id=2,
forced_eos_token_id=2,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.init_std = init_std
self.encoder_layerdrop = encoder_layerdrop
self.decoder_layerdrop = decoder_layerdrop
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
super().__init__(
num_labels=num_labels,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
decoder_start_token_id=decoder_start_token_id,
forced_eos_token_id=forced_eos_token_id,
**kwargs,
)
# ensure backward compatibility for BART CNN models
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
self.forced_bos_token_id = self.bos_token_id
warnings.warn(
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
"The config can simply be saved and uploaded again to be fixed.",
stacklevel=2,
)
class Florence2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
Florence-2 model according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`Florence2VisionConfig`, *optional*):
Custom vision config or dict
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
vocab_size (`int`, *optional*, defaults to 51289):
Vocabulary size of the Florence2model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`~Florence2ForConditionalGeneration`]
projection_dim (`int`, *optional*, defaults to 1024):
Dimension of the multimodal projection space.
Example:
```python
>>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig
>>> # Initializing a clip-like vision config
>>> vision_config = CLIPVisionConfig()
>>> # Initializing a Bart config
>>> text_config = BartConfig()
>>> # Initializing a Florence-2 configuration
>>> configuration = Florence2Config(vision_config, text_config)
>>> # Initializing a model from the florence-2 configuration
>>> model = Florence2ForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "florence2"
is_composition = False
def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
vocab_size=51289,
projection_dim=1024,
**kwargs,
):
self.ignore_index = ignore_index
self.vocab_size = vocab_size
self.projection_dim = projection_dim
if vision_config is not None:
vision_config = Florence2VisionConfig(**vision_config)
self.vision_config = vision_config
self.text_config = text_config
if text_config is not None:
self.text_config = Florence2LanguageConfig(**text_config)
super().__init__(**kwargs)

View File

@@ -1,203 +0,0 @@
#!/usr/bin/env python
# ------------------------------------------------------------------------------
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import XVLAAdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import OBS_IMAGES
# Conditional import for type checking and lazy loading
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from .configuration_florence2 import Florence2Config
else:
Florence2Config = None
@PreTrainedConfig.register_subclass("xvla")
@dataclass
class XVLAConfig(PreTrainedConfig):
"""
Configuration class for the XVLA (Extended Vision-Language-Action) policy so it can
plug into the LeRobot training stack.
The config mirrors the knobs exposed in the original XVLA repository but also
declares the input/output feature contract required by LeRobot.
"""
# Input / output structure
n_obs_steps: int = 1
chunk_size: int = 32
n_action_steps: int = 32
dtype: str = "float32" # Options: "bfloat16", "float32"
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.IDENTITY,
}
)
# Florence2 backbone and tokenizer configuration
florence_config: dict[str, Any] = field(default_factory=dict)
tokenizer_name: str = "facebook/bart-large"
tokenizer_max_length: int = 64
tokenizer_padding_side: str = "right"
pad_language_to: str = "max_length"
# Transformer head
hidden_size: int = 1024
depth: int = 24
num_heads: int = 16
mlp_ratio: float = 4.0
num_domains: int = 30
len_soft_prompts: int = 32
dim_time: int = 32
max_len_seq: int = 512
use_hetero_proj: bool = False
# Action & proprioception
action_mode: str = "ee6d"
num_denoising_steps: int = 10
use_proprio: bool = True
max_state_dim: int = 32
max_action_dim: int = 20 # Maximum action dimension for padding (used by "auto" action mode)
domain_feature_key: str | None = None
# Vision preprocessing
resize_imgs_with_padding: tuple[int, int] | None = None
num_image_views: int | None = None
empty_cameras: int = 0
# Freezing options for VLM components
# By default, VLM encoders are frozen and only policy transformer + soft prompts train
freeze_vision_encoder: bool = False # Freeze VLM vision encoder weights
freeze_language_encoder: bool = False # Freeze VLM language encoder weights
train_policy_transformer: bool = True # Allow policy transformer to train
train_soft_prompts: bool = True # Allow soft prompts to train
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.99)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.0
optimizer_grad_clip_norm: float = 10.0
# Soft-prompt LR settings (for optional warm-up)
optimizer_soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR
optimizer_soft_prompt_warmup_lr_scale: float | None = None # Start scale for warmup (e.g., 0.01)
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
def __post_init__(self) -> None:
super().__post_init__()
if self.chunk_size <= 0:
raise ValueError("`chunk_size` must be strictly positive.")
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"`n_action_steps` ({self.n_action_steps}) must be <= `chunk_size` ({self.chunk_size})."
)
if self.num_image_views is not None and self.num_image_views <= 0:
raise ValueError("`num_image_views` must be > 0 when specified.")
if self.dtype not in ["bfloat16", "float32"]:
raise ValueError(f"Invalid dtype: {self.dtype}")
self._florence_config_obj: Florence2Config | None = None
def get_florence_config(self) -> Florence2Config:
"""
Build (and cache) the Florence2 transformer config that should back the VLM.
"""
if self._florence_config_obj is None:
config_dict = dict(self.florence_config)
if "vision_config" not in config_dict or config_dict["vision_config"] is None:
raise ValueError("vision_config is required")
if "text_config" not in config_dict or config_dict["text_config"] is None:
raise ValueError("text_config is required")
self._florence_config_obj = Florence2Config(**config_dict)
return self._florence_config_obj
def validate_features(self) -> None:
if not self.image_features:
raise ValueError("XVLA requires at least one visual feature in the inputs.")
if self.use_proprio and self.robot_state_feature is None:
raise ValueError("`use_proprio=True` requires a proprioceptive state feature.")
if self.num_image_views is None:
self.num_image_views = len(self.image_features) + self.empty_cameras
else:
self.num_image_views = max(self.num_image_views, len(self.image_features) + self.empty_cameras)
if self.empty_cameras > 0:
height, width = (480, 640)
if self.resize_imgs_with_padding is not None:
height, width = self.resize_imgs_with_padding
for idx in range(self.empty_cameras):
key = f"{OBS_IMAGES}.empty_camera_{idx}"
if key not in self.input_features:
self.input_features[key] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, height, width),
)
def get_optimizer_preset(self) -> XVLAAdamWConfig:
"""Return the XVLA-specific optimizer with differential learning rates.
This optimizer applies:
- 1/10 LR for VLM parameters (stable optimization)
- Full LR for transformer/action head
- Configurable LR for soft-prompts (with optional warm-up)
"""
return XVLAAdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
soft_prompt_lr_scale=self.optimizer_soft_prompt_lr_scale,
soft_prompt_warmup_lr_scale=self.optimizer_soft_prompt_warmup_lr_scale,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list[int] | None:
return None
@property
def action_delta_indices(self) -> list[int]:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> list[int] | None:
return None

File diff suppressed because it is too large Load Diff

View File

@@ -1,548 +0,0 @@
#!/usr/bin/env python
# ------------------------------------------------------------------------------
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from __future__ import annotations
import builtins
import logging
import os
from collections import deque
from pathlib import Path
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_TOKENS, OBS_STATE
from .action_hub import build_action_space
from .configuration_florence2 import Florence2Config
from .configuration_xvla import XVLAConfig
from .modeling_florence2 import Florence2ForConditionalGeneration
from .soft_transformer import SoftPromptedTransformer
class XVLAModel(nn.Module):
"""
XVLA backbone that stitches Florence-2 embeddings with the temporal/action transformer head.
"""
def __init__(
self,
config: XVLAConfig,
florence_config: Florence2Config,
proprio_dim: int,
) -> None:
super().__init__()
self.config = config
self.chunk_size: int = config.chunk_size
self.use_proprio: bool = config.use_proprio
# Build action space with auto-detection for "auto" mode
if config.action_mode.lower() == "auto":
# Auto-detect real action dim from config.action_feature
real_dim = (
config.action_feature.shape[-1]
if config.action_feature is not None
else config.max_action_dim
)
self.action_space = build_action_space(
config.action_mode.lower(),
real_dim=real_dim,
max_dim=config.max_action_dim,
)
else:
self.action_space = build_action_space(config.action_mode.lower())
self.dim_action = self.action_space.dim_action
self.dim_proprio = proprio_dim
self.vlm = Florence2ForConditionalGeneration(florence_config)
if hasattr(self.vlm, "language_model"):
lm = self.vlm.language_model
if hasattr(lm, "model") and hasattr(lm.model, "decoder"):
del lm.model.decoder
if hasattr(lm, "lm_head"):
del lm.lm_head
projection_dim = getattr(self.vlm.config, "projection_dim", None)
if projection_dim is None:
raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.")
self.transformer = SoftPromptedTransformer(
hidden_size=config.hidden_size,
multi_modal_input_size=projection_dim,
depth=config.depth,
num_heads=config.num_heads,
mlp_ratio=config.mlp_ratio,
num_domains=config.num_domains,
dim_action=self.dim_action,
dim_propio=self.dim_proprio,
len_soft_prompts=config.len_soft_prompts,
dim_time=config.dim_time,
max_len_seq=config.max_len_seq,
use_hetero_proj=config.use_hetero_proj,
)
# Apply freezing based on config
self._apply_freezing()
# Apply dtype casting based on config
self._apply_dtype()
def _get_target_dtype(self) -> torch.dtype:
"""Get the target dtype based on config."""
if self.config.dtype == "bfloat16":
return torch.bfloat16
return torch.float32
def _apply_dtype(self) -> None:
"""
Apply dtype casting to model components based on config.
"""
target_dtype = self._get_target_dtype()
self.to(dtype=target_dtype)
def _apply_freezing(self) -> None:
"""
Freeze VLM vision and language encoders based on config options.
Keep only policy transformer and soft prompts trainable.
"""
# Freeze vision encoder
if self.config.freeze_vision_encoder and hasattr(self.vlm, "vision_tower"):
for param in self.vlm.vision_tower.parameters():
param.requires_grad = False
# Freeze language encoder
if self.config.freeze_language_encoder and hasattr(self.vlm, "language_model"):
lm = self.vlm.language_model
# Freeze encoder
if hasattr(lm, "model") and hasattr(lm.model, "encoder"):
for param in lm.model.encoder.parameters():
param.requires_grad = False
# Freeze shared embeddings
if hasattr(lm, "model") and hasattr(lm.model, "shared"):
for param in lm.model.shared.parameters():
param.requires_grad = False
# Freeze or unfreeze policy transformer
if not self.config.train_policy_transformer:
for name, param in self.transformer.named_parameters():
if "soft_prompts" not in name:
param.requires_grad = False
# Freeze or unfreeze soft prompts
if not self.config.train_soft_prompts and hasattr(self.transformer, "soft_prompt_hub"):
for param in self.transformer.soft_prompt_hub.parameters():
param.requires_grad = False
def forward_vlm(
self,
input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
image_mask: torch.Tensor,
) -> dict[str, torch.Tensor]:
"""
Encode text and multi-view images via Florence2 encoder.
"""
batch_size, num_views = pixel_values.shape[:2]
flat_mask = image_mask.view(-1).to(dtype=torch.bool)
flat_images = pixel_values.flatten(0, 1)
num_valid = int(flat_mask.sum().item())
if num_valid == 0:
raise ValueError("At least one image view must be valid per batch.")
valid_images = flat_images[flat_mask]
valid_feats = self.vlm._encode_image(valid_images)
tokens_per_view, hidden_dim = valid_feats.shape[1:]
image_features = valid_feats.new_zeros((batch_size * num_views, tokens_per_view, hidden_dim))
image_features[flat_mask] = valid_feats
image_features = image_features.view(batch_size, num_views, tokens_per_view, hidden_dim)
inputs_embeds = self.vlm.get_input_embeddings()(input_ids)
merged_embeds, attention_mask = self.vlm._merge_input_ids_with_image_features(
image_features[:, 0],
inputs_embeds,
)
enc_out = self.vlm.language_model.model.encoder(
attention_mask=attention_mask,
inputs_embeds=merged_embeds,
)[0]
aux_visual_inputs = image_features[:, 1:].reshape(batch_size, -1, hidden_dim)
return {"vlm_features": enc_out, "aux_visual_inputs": aux_visual_inputs}
def forward(
self,
input_ids: torch.LongTensor,
image_input: torch.FloatTensor,
image_mask: torch.Tensor,
domain_id: torch.LongTensor,
proprio: torch.Tensor,
action: torch.Tensor,
) -> dict[str, torch.Tensor]:
"""
Forward pass for the XVLA model.
"""
target_dtype = self._get_target_dtype()
image_input = image_input.to(dtype=target_dtype)
proprio = proprio.to(dtype=target_dtype)
action = action.to(dtype=target_dtype)
enc = self.forward_vlm(input_ids, image_input, image_mask)
batch_size = input_ids.shape[0]
t = (
torch.rand(1, device=input_ids.device, dtype=target_dtype)
+ torch.arange(batch_size, device=input_ids.device, dtype=target_dtype) / batch_size
) % (1 - 1e-5)
action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
pred_action = self.transformer(
domain_id=domain_id,
action_with_noise=action_noisy_m,
t=t,
proprio=proprio_m,
**enc,
)
return self.action_space.compute_loss(pred_action, action)
@torch.no_grad()
def generate_actions(
self,
input_ids: torch.LongTensor,
image_input: torch.FloatTensor,
image_mask: torch.Tensor,
domain_id: torch.LongTensor,
proprio: torch.Tensor,
steps: int,
) -> torch.Tensor:
self.eval()
target_dtype = self._get_target_dtype()
image_input = image_input.to(dtype=target_dtype)
proprio = proprio.to(dtype=target_dtype)
enc = self.forward_vlm(input_ids, image_input, image_mask)
batch_size = input_ids.shape[0]
action_dim = self.dim_action
x1 = torch.randn(batch_size, self.chunk_size, action_dim, device=proprio.device, dtype=target_dtype)
action = torch.zeros_like(x1)
steps = max(1, int(steps))
for i in range(steps, 0, -1):
t = torch.full((batch_size,), i / steps, device=proprio.device, dtype=target_dtype)
x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
action = self.transformer(
domain_id=domain_id,
action_with_noise=x_t_m,
proprio=proprio_m,
t=t,
**enc,
)
return self.action_space.postprocess(action)
class XVLAPolicy(PreTrainedPolicy):
"""LeRobot-compliant wrapper built around the XVLA model."""
config_class = XVLAConfig
name = "xvla"
def __init__(self, config: XVLAConfig):
super().__init__(config)
config.validate_features()
florence_config = config.get_florence_config()
proprio_dim = config.max_state_dim if config.use_proprio else 0
self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim)
self.reset()
def reset(self) -> None:
self._queues = {
ACTION: deque(maxlen=self.config.n_action_steps),
}
def get_optim_params(self) -> dict:
"""Return trainable named parameters for optimization.
Returns a dict of name -> param for all trainable parameters.
This enables the xvla-adamw optimizer to apply differential learning rates
based on parameter names (e.g., 1/10 LR for VLM components).
"""
return dict(filter(lambda kv: kv[1].requires_grad, self.named_parameters()))
def _prepare_state(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
if not self.config.use_proprio or OBS_STATE not in batch:
return torch.zeros(batch_size, 0, device=device)
state = batch[OBS_STATE]
if state.ndim > 2:
state = state[:, -1, :]
return pad_vector(state, self.model.dim_proprio)
def _prepare_images(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
present_img_keys = [key for key in self.config.image_features if key in batch]
if len(present_img_keys) == 0:
raise ValueError(
"All image features are missing from the batch. "
f"Batch keys: {list(batch.keys())}, expected at least one of {list(self.config.image_features)}."
)
images = []
masks = []
for key in present_img_keys:
img = batch[key][:, -1] if batch[key].ndim == 5 else batch[key]
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(img, *self.config.resize_imgs_with_padding)
images.append(img)
masks.append(torch.ones(img.size(0), dtype=torch.bool, device=img.device))
stacked_imgs = torch.stack(images, dim=1)
stacked_masks = torch.stack(masks, dim=1)
total_views = self.config.num_image_views or stacked_imgs.size(1)
total_views = max(total_views, stacked_imgs.size(1))
num_pad = total_views - stacked_imgs.size(1)
if num_pad > 0:
pad_shape = (stacked_imgs.size(0), num_pad, *stacked_imgs.shape[2:])
pad_imgs = stacked_imgs.new_zeros(pad_shape)
pad_masks = stacked_masks.new_zeros((stacked_masks.size(0), num_pad))
stacked_imgs = torch.cat([stacked_imgs, pad_imgs], dim=1)
stacked_masks = torch.cat([stacked_masks, pad_masks], dim=1)
return stacked_imgs, stacked_masks
def _get_domain_id(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
candidate = None
if self.config.domain_feature_key and self.config.domain_feature_key in batch:
candidate = batch[self.config.domain_feature_key]
elif "domain_id" in batch:
candidate = batch["domain_id"]
if candidate is None:
return torch.zeros(batch_size, dtype=torch.long, device=device)
if not isinstance(candidate, torch.Tensor):
candidate = torch.as_tensor(candidate, device=device)
else:
candidate = candidate.to(device=device)
if candidate.ndim == 0:
candidate = candidate.expand(batch_size)
if candidate.ndim > 1:
candidate = candidate.view(candidate.shape[0], -1)[:, 0]
if candidate.shape[0] != batch_size:
candidate = candidate.expand(batch_size)
return candidate.to(dtype=torch.long)
def _prepare_action_targets(self, batch: dict[str, Tensor]) -> Tensor:
if ACTION not in batch:
raise ValueError("Batch is missing action targets required for training.")
actions = batch[ACTION]
if actions.ndim == 2:
actions = actions.unsqueeze(1)
actions = pad_tensor_along_dim(actions, self.config.chunk_size, dim=1)
if actions.shape[-1] != self.model.dim_action:
actions = pad_vector(actions, self.model.dim_action)
return actions
def _build_model_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
input_ids = batch[OBS_LANGUAGE_TOKENS]
batch_size = input_ids.shape[0]
images, image_mask = self._prepare_images(batch)
domain_id = self._get_domain_id(batch, batch_size, images.device)
proprio = self._prepare_state(batch, batch_size, images.device)
return {
"input_ids": input_ids,
"image_input": images,
"image_mask": image_mask,
"domain_id": domain_id,
"proprio": proprio,
}
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
inputs = self._build_model_inputs(batch)
targets = self._prepare_action_targets(batch)
losses = self.model(action=targets, **inputs)
total_loss = sum(losses.values())
log_dict = {k: v.detach().item() for k, v in losses.items()}
log_dict["loss"] = total_loss.detach().item()
return total_loss, log_dict
def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
inputs = self._build_model_inputs(batch)
actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps)
return actions
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
return self._get_action_chunk(batch)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
if len(self._queues[ACTION]) == 0:
actions = self._get_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
@classmethod
def from_pretrained(
cls: builtins.type[T],
pretrained_name_or_path: str | Path,
*,
config: PreTrainedConfig | None = None,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
strict: bool = False,
**kwargs,
):
"""
Loads XVLA model weights with:
- automatic prefix 'model.' added to all keys
- skip list for layers that should remain randomly initialized
"""
import safetensors.torch
# step 1: load config
# TODO: jadechoghari, fix this
if config is None:
config = PreTrainedConfig.from_pretrained(
pretrained_name_or_path=pretrained_name_or_path,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
**kwargs,
)
model_id = str(pretrained_name_or_path)
instance = cls(config, **kwargs)
# step 2: locate model.safetensors
if os.path.isdir(model_id):
logging.info("Loading weights from local directory")
model_file = os.path.join(model_id, "model.safetensors")
else:
try:
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
model_file = hf_hub_download(
repo_id=model_id,
filename="model.safetensors",
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e
logging.info(f"Loading checkpoint from {model_file}")
# step 3: load state dict
state_dict = safetensors.torch.load_file(model_file)
encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight"
shared_key = "model.vlm.language_model.model.shared.weight"
if encoder_key in state_dict:
state_dict[shared_key] = state_dict[encoder_key]
# or deepcopy
# step 4: load into instance
instance.load_state_dict(state_dict, strict=True)
logging.info("Loaded XVLA checkpoint")
# step 5: finalize
# Reapply dtype after loading state dict
instance.model._apply_dtype()
instance.to(config.device)
instance.eval()
return instance
def resize_with_pad(img: torch.Tensor, height: int, width: int, pad_value: float = 0.0) -> torch.Tensor:
if img.ndim != 4:
raise ValueError(f"(b,c,h,w) expected, but got {img.shape}")
current_height, current_width = img.shape[2:]
if current_height == height and current_width == width:
return img
ratio = max(current_width / width, current_height / height)
resized_height = int(current_height / ratio)
resized_width = int(current_width / ratio)
resized_img = F.interpolate(
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)
pad_height = max(0, height - resized_height)
pad_width = max(0, width - resized_width)
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img
def pad_vector(vector: Tensor, new_dim: int) -> Tensor:
if vector.shape[-1] == new_dim:
return vector
if new_dim == 0:
shape = list(vector.shape)
shape[-1] = 0
return vector.new_zeros(*shape)
shape = list(vector.shape)
current_dim = shape[-1]
shape[-1] = new_dim
new_vector = vector.new_zeros(*shape)
length = min(current_dim, new_dim)
new_vector[..., :length] = vector[..., :length]
return new_vector
def pad_tensor_along_dim(tensor: Tensor, target_len: int, dim: int = 1) -> Tensor:
current_len = tensor.size(dim)
if current_len == target_len:
return tensor
if current_len > target_len:
slices = [slice(None)] * tensor.dim()
slices[dim] = slice(0, target_len)
return tensor[tuple(slices)]
pad_shape = list(tensor.shape)
pad_shape[dim] = target_len - current_len
pad_tensor = tensor.new_zeros(pad_shape)
return torch.cat([tensor, pad_tensor], dim=dim)

View File

@@ -1,554 +0,0 @@
# ------------------------------------------------------------------------------
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.datasets.factory import IMAGENET_STATS
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
ObservationProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_IMAGES,
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
def make_xvla_pre_post_processors(
config: XVLAConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Build the LeRobot processor pipelines for XVLA.
"""
features = {**config.input_features, **config.output_features}
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
TokenizerProcessorStep(
tokenizer_name=config.tokenizer_name,
max_length=config.tokenizer_max_length,
padding=config.pad_language_to,
padding_side=config.tokenizer_padding_side,
),
XVLAImageToFloatProcessorStep(),
XVLAImageNetNormalizeProcessorStep(),
XVLAAddDomainIdProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features=features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
# Custom XVLA processor steps
@dataclass
class LiberoProcessorStep(ObservationProcessorStep):
"""
Processes LIBERO observations into the LeRobot format.
This step handles the specific observation structure from LIBERO environments,
which includes nested robot_state dictionaries and image observations.
**State Processing:**
- Processes the `robot_state` dictionary which contains nested end-effector,
gripper, and joint information.
- Extracts and concatenates:
- End-effector position (3D)
- End-effector quaternion converted to axis-angle (3D)
- Gripper joint positions (2D)
- Maps the concatenated state to `"observation.state"`.
**Image Processing:**
- Rotates images by 180 degrees by flipping both height and width dimensions.
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
"""
def _process_observation(self, observation):
"""
Processes both image and robot_state observations from LIBERO.
"""
processed_obs = observation.copy()
for key in list(processed_obs.keys()):
if key.startswith(f"{OBS_IMAGES}."):
img = processed_obs[key]
if key == f"{OBS_IMAGES}.image":
# Flip both H and W
img = torch.flip(img, dims=[2, 3])
processed_obs[key] = img
# Process robot_state into a flat state vector
if "observation.robot_state" in processed_obs:
robot_state = processed_obs.pop("observation.robot_state")
# Extract components
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
eef_mat = robot_state["eef"]["mat"] # (B, 3, 3)
eef_rot6d = self._mat_to_rotate6d(eef_mat) # (B, 6)
extra = torch.zeros((eef_pos.shape[0], 1), dtype=torch.float32, device=eef_pos.device)
proprio_state = torch.cat((eef_pos, eef_rot6d, extra), dim=-1) # (B, 10)
state = torch.cat((proprio_state, torch.zeros_like(proprio_state)), dim=-1) # (B, 20)
# ensure float32
state = state.float()
if state.dim() == 1:
state = state.unsqueeze(0)
processed_obs[OBS_STATE] = state
return processed_obs
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Transforms feature keys from the LIBERO format to the LeRobot standard.
"""
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
# copy over non-STATE features
for ft, feats in features.items():
if ft != PipelineFeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
state_feats = {}
# add our new flattened state
state_feats["observation.state"] = PolicyFeature(
key="observation.state",
shape=(20,),
dtype="float32",
)
new_features[PipelineFeatureType.STATE] = state_feats
return new_features
def _mat_to_rotate6d(self, rot_mats: torch.Tensor) -> torch.Tensor:
"""
Convert batched rotation matrices (B, 3, 3) into 6D rotation representation (B, 6).
Args:
rot_mats (Tensor): Rotation matrices of shape (B, 3, 3)
Returns:
Tensor: 6D rotation representation, shape (B, 6)
Raises:
TypeError: if input is not a torch tensor
ValueError: if shape is not (B, 3, 3)
"""
if not isinstance(rot_mats, torch.Tensor):
raise TypeError(f"mat_to_rot6d expects a torch.Tensor, got {type(rot_mats)}")
if rot_mats.ndim != 3 or rot_mats.shape[1:] != (3, 3):
raise ValueError(f"mat_to_rot6d expects shape (B, 3, 3), got {tuple(rot_mats.shape)}")
rot_mats = rot_mats.to(torch.float32)
col1 = rot_mats[:, :3, 0] # (B, 3)
col2 = rot_mats[:, :3, 1] # (B, 3)
rot6d = torch.cat([col1, col2], dim=-1) # (B, 6)
return rot6d
def observation(self, observation):
return self._process_observation(observation)
@dataclass
@ProcessorStepRegistry.register(name="xvla_image_scale")
class XVLAImageScaleProcessorStep(ProcessorStep):
"""Scale image observations by 255 to convert from [0, 1] to [0, 255] range.
This processor step multiplies all image observations by 255, which is required
for XVLA models that expect images in uint8-like range.
Args:
image_keys: List of observation keys that contain images to scale.
If None, will automatically detect keys starting with "observation.images."
"""
image_keys: list[str] | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Scale image observations by 255."""
new_transition = transition.copy()
obs = new_transition.get(TransitionKey.OBSERVATION, {})
if obs is None:
return new_transition
# Make a copy of observations to avoid modifying the original
obs = obs.copy()
# Determine which keys to scale
keys_to_scale = self.image_keys
if keys_to_scale is None:
# Auto-detect image keys
keys_to_scale = [k for k in obs if k.startswith("observation.images.")]
# Scale each image
for key in keys_to_scale:
if key in obs and isinstance(obs[key], torch.Tensor):
obs[key] = obs[key] * 255
new_transition[TransitionKey.OBSERVATION] = obs
return new_transition
def transform_features(self, features):
"""Image scaling doesn't change feature structure."""
return features
def get_config(self) -> dict[str, Any]:
"""Return serializable configuration."""
return {
"image_keys": self.image_keys,
}
@dataclass
@ProcessorStepRegistry.register(name="xvla_image_to_float")
class XVLAImageToFloatProcessorStep(ProcessorStep):
"""Convert image observations from [0, 255] to [0, 1] range.
This processor step divides image observations by 255 to convert from uint8-like
range [0, 255] to float range [0, 1]. This is typically used when loading images
that are stored as uint8 values.
Args:
image_keys: List of observation keys that contain images to convert.
If None, will automatically detect keys starting with "observation.images."
validate_range: If True, validates that input values are in [0, 255] range (default: True)
Raises:
ValueError: If validate_range is True and image values are not in [0, 255] range.
"""
image_keys: list[str] | None = None
validate_range: bool = True
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Convert image observations from [0, 255] to [0, 1]."""
new_transition = transition.copy()
obs = new_transition.get(TransitionKey.OBSERVATION, {})
if obs is None:
return new_transition
# Make a copy of observations to avoid modifying the original
obs = obs.copy()
# Determine which keys to convert
keys_to_convert = self.image_keys
if keys_to_convert is None:
# Auto-detect image keys
keys_to_convert = [k for k in obs if k.startswith("observation.images.")]
# Convert each image
for key in keys_to_convert:
if key in obs and isinstance(obs[key], torch.Tensor):
tensor = obs[key]
min_val = tensor.min().item()
max_val = tensor.max().item()
if max_val <= 1.0:
obs[key] = tensor.float() # ensure float dtype, but no division
continue
# Validate that values are in [0, 255] range if requested
if self.validate_range and (min_val < 0.0 or max_val > 255.0):
raise ValueError(
f"Image '{key}' has values outside [0, 255] range: "
f"min={min_val:.4f}, max={max_val:.4f}. "
f"Cannot convert to [0, 1] range."
)
# Convert to float and divide by 255
obs[key] = tensor.float() / 255.0
new_transition[TransitionKey.OBSERVATION] = obs
return new_transition
def transform_features(self, features):
"""Image conversion doesn't change feature structure."""
return features
def get_config(self) -> dict[str, Any]:
"""Return serializable configuration."""
return {
"image_keys": self.image_keys,
"validate_range": self.validate_range,
}
@dataclass
@ProcessorStepRegistry.register(name="xvla_imagenet_normalize")
class XVLAImageNetNormalizeProcessorStep(ProcessorStep):
"""Normalize image observations using ImageNet statistics.
This processor step applies ImageNet normalization (mean and std) to image observations.
It validates that input values are in the [0, 1] range before normalizing.
The normalization formula is: (image - mean) / std
Args:
image_keys: List of observation keys that contain images to normalize.
If None, will automatically detect keys starting with "observation.images."
Raises:
ValueError: If image values are not in the [0, 1] range.
"""
image_keys: list[str] | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Normalize image observations using ImageNet statistics."""
new_transition = transition.copy()
obs = new_transition.get(TransitionKey.OBSERVATION, {})
if obs is None:
return new_transition
# Make a copy of observations to avoid modifying the original
obs = obs.copy()
# Determine which keys to normalize
keys_to_normalize = self.image_keys
if keys_to_normalize is None:
# Auto-detect image keys
keys_to_normalize = [k for k in obs if k.startswith("observation.images.")]
# Normalize each image
for key in keys_to_normalize:
if key in obs and isinstance(obs[key], torch.Tensor):
tensor = obs[key]
# Validate that values are in [0, 1] range
min_val = tensor.min().item()
max_val = tensor.max().item()
if min_val < 0.0 or max_val > 1.0:
raise ValueError(
f"Image '{key}' has values outside [0, 1] range: "
f"min={min_val:.4f}, max={max_val:.4f}. "
f"ImageNet normalization requires input values in [0, 1]."
)
# Apply ImageNet normalization
mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype)
std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype)
# Expand mean/std to match tensor dims (e.g., BCHW or BNCHW)
while mean.dim() < tensor.dim():
mean = mean.unsqueeze(0)
std = std.unsqueeze(0)
# Normalize: (image - mean) / std
obs[key] = (tensor - mean) / std
new_transition[TransitionKey.OBSERVATION] = obs
return new_transition
def transform_features(self, features):
"""ImageNet normalization doesn't change feature structure."""
return features
def get_config(self) -> dict[str, Any]:
"""Return serializable configuration."""
return {
"image_keys": self.image_keys,
}
@dataclass
@ProcessorStepRegistry.register(name="xvla_add_domain_id")
class XVLAAddDomainIdProcessorStep(ProcessorStep):
"""Add domain_id to complementary data.
This processor step adds a domain_id tensor to the complementary data,
which is used by XVLA to identify different robot embodiments or task domains.
Args:
domain_id: The domain ID to add (default: 3)
"""
domain_id: int = 0
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Add domain_id to complementary data."""
new_transition = transition.copy()
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
comp = {} if comp is None else comp.copy()
# Infer batch size from observation tensors
obs = new_transition.get(TransitionKey.OBSERVATION, {})
batch_size = 1
if obs:
for v in obs.values():
if isinstance(v, torch.Tensor):
batch_size = v.shape[0]
break
# Add domain_id tensor
comp["domain_id"] = torch.tensor([int(self.domain_id)] * batch_size, dtype=torch.long)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp
return new_transition
def transform_features(self, features):
"""Domain ID addition doesn't change feature structure."""
return features
def get_config(self) -> dict[str, Any]:
"""Return serializable configuration."""
return {
"domain_id": self.domain_id,
}
@dataclass
@ProcessorStepRegistry.register(name="xvla_rotation_6d_to_axis_angle")
class XVLARotation6DToAxisAngleProcessorStep(ProcessorStep):
"""Convert 6D rotation representation to axis-angle and reorganize action dimensions.
This processor step takes actions with 6D rotation representation and converts them to
axis-angle representation, reorganizing the action dimensions as:
- action[:, :3] -> target_eef (end-effector position)
- action[:, 3:9] -> 6D rotation (converted to axis-angle, 3D)
- action[:, 9:10] -> gripper action
Final output: [target_eef (3), axis_angle (3), gripper (1)] = 7D action
Args:
expected_action_dim: Expected input action dimension (default: 10, supports 6D rotation + extras)
"""
expected_action_dim: int = 10
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Convert 6D rotation to axis-angle in action."""
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
if action is None or not isinstance(action, torch.Tensor):
return new_transition
# Convert to numpy for processing
device = action.device
dtype = action.dtype
action_np = action.cpu().numpy()
# Extract components
# action shape: (B, D) where D >= 10
target_eef = action_np[:, :3] # (B, 3)
rotation_6d = action_np[:, 3:9] # (B, 6)
target_act = action_np[:, 9:10] # (B, 1)
# Convert 6D rotation to axis-angle
target_axis = rotate6d_to_axis_angle(rotation_6d) # (B, 3)
# Concatenate: [eef (3), axis_angle (3), gripper (1)] = 7D
action_np = np.concatenate([target_eef, target_axis, target_act], axis=-1)
# Convert gripper action to -1 or 1
action_np[:, -1] = np.where(action_np[:, -1] > 0.5, 1.0, -1.0)
# Convert back to tensor
action = torch.from_numpy(action_np).to(device=device, dtype=dtype)
new_transition[TransitionKey.ACTION] = action
return new_transition
def transform_features(self, features):
"""Rotation conversion changes action dimension from 10 to 7."""
# Note: This is a simplified version. In practice, you might want to
# update the action feature shape in the features dict.
return features
def get_config(self) -> dict[str, Any]:
"""Return serializable configuration."""
return {
"expected_action_dim": self.expected_action_dim,
}
def make_xvla_libero_pre_post_processors() -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Build the LeRobot processor pipelines for XVLA with LIBERO environment.
"""
pre_processor_steps: list[ProcessorStep] = []
post_processor_steps: list[ProcessorStep] = []
pre_processor_steps.extend(
[LiberoProcessorStep(), XVLAImageNetNormalizeProcessorStep(), XVLAAddDomainIdProcessorStep()]
)
post_processor_steps.extend([XVLARotation6DToAxisAngleProcessorStep()])
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=pre_processor_steps,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=post_processor_steps,
),
)

View File

@@ -1,415 +0,0 @@
# ------------------------------------------------------------------------------
# Copyright 2025 2toINF (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------
from __future__ import annotations
import math
from collections.abc import Iterable
from functools import partial
from typing import Final
import torch
import torch.nn as nn
import torch.nn.functional as functional
# ------------------------------- Small utils ----------------------------------
def _to_2tuple(x) -> tuple:
"""Minimal replacement for timm.layers.to_2tuple."""
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
t = tuple(x)
return (t[0], t[1]) if len(t) >= 2 else (t[0], t[0])
return (x, x)
def _has_sdp_attention() -> bool:
"""Check if we can use PyTorch fused scaled_dot_product_attention."""
return hasattr(functional, "scaled_dot_product_attention")
# ---------------------------------- MLP --------------------------------------
class Mlp(nn.Module):
"""
MLP used in ViT-style blocks.
Supports Linear or 1x1 Conv 'linear_layer' for token/channel mixing.
"""
def __init__(
self,
in_features: int,
hidden_features: int | None = None,
out_features: int | None = None,
norm_layer: type[nn.Module] | None = None,
bias: bool | tuple[bool, bool] = True,
drop: float | tuple[float, float] = 0.0,
use_conv: bool = False,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = _to_2tuple(bias)
drop_probs = _to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = nn.GELU(approximate="tanh")
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Expect [B, T, C] for Linear variant; caller is responsible for shapes.
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x
# -------------------------------- Attention ----------------------------------
class Attention(nn.Module):
"""
Multi-Head Self-Attention with optional fused SDPA fallback.
If PyTorch provides `scaled_dot_product_attention`, it will be used
(usually faster and more stable); otherwise we use a manual implementation.
"""
fused_attn: Final[bool]
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: type[nn.Module] = nn.LayerNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.fused_attn = _has_sdp_attention()
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Parameters
----------
x : Tensor, shape [batch_size, seq_len, channels]
Input sequence.
Returns
-------
Tensor, shape [batch_size, seq_len, channels]
Output sequence after MHSA + projection.
"""
batch_size, seq_len, channels = x.shape
qkv = (
self.qkv(x)
.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
.permute(2, 0, 3, 1, 4) # 3 x [batch_size, num_heads, seq_len, head_dim]
)
q, k, v = qkv.unbind(0) # each: [batch_size, num_heads, seq_len, head_dim]
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn:
x = functional.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
) # [batch_size, num_heads, seq_len, head_dim]
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1) # [batch_size, num_heads, seq_len, seq_len]
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v # [batch_size, num_heads, seq_len, head_dim]
x = x.transpose(1, 2).reshape(batch_size, seq_len, channels) # [batch_size, seq_len, channels]
x = self.proj(x)
x = self.proj_drop(x)
return x
# ------------------------------- Utilities -----------------------------------
def basic_init(module: nn.Module) -> None:
"""
Apply a basic initialization scheme to Linear layers.
- Weight: Xavier uniform initialization.
- Bias: Set to zero.
"""
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0.0)
def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torch.Tensor:
"""
Create sinusoidal timestep embeddings.
Parameters
----------
t : torch.Tensor
Shape [B]. Each element is a timestep index, may be fractional.
dim : int
Dimensionality of the output embedding.
max_period : int, default=100
Controls the minimum frequency of the sinusoids.
Returns
-------
torch.Tensor
Shape [B, dim]. Sinusoidal embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=t.dtype, device=t.device) / half
)
args = t[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2 == 1:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
# ------------------------------- Core Layers ----------------------------------
class DomainAwareLinear(nn.Module):
"""
Linear layer with domain-conditioned parameters (per-sample).
Each domain has its own weight and bias vectors, stored in embeddings.
"""
def __init__(self, input_size: int, output_size: int, num_domains: int = 20) -> None:
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.fc = nn.Embedding(num_domains, output_size * input_size)
self.bias = nn.Embedding(num_domains, output_size)
nn.init.xavier_uniform_(self.fc.weight)
nn.init.zeros_(self.bias.weight)
def forward(self, x: torch.Tensor, domain_id: torch.LongTensor) -> torch.Tensor:
"""
Parameters
----------
x : Tensor
[B, I] or [B, T, I]
domain_id : LongTensor
[B], domain indices.
Returns
-------
Tensor
[batch_size, output_size] or [batch_size, seq_len, output_size]
"""
batch_size = domain_id.shape[0]
squeeze_seq = False
if x.dim() == 2:
x = x.unsqueeze(1)
squeeze_seq = True
weight = self.fc(domain_id).view(batch_size, self.input_size, self.output_size)
bias = self.bias(domain_id).view(batch_size, self.output_size)
y = torch.matmul(x, weight) + bias.view(batch_size, 1, self.output_size)
if squeeze_seq:
y = y.squeeze(1)
return y
class TransformerBlock(nn.Module):
"""
Standard Transformer block (pre-LN): LN → MHSA → residual, LN → MLP → residual.
"""
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0) -> None:
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size)
self.norm2 = nn.LayerNorm(hidden_size)
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, attn_drop=0.1)
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=int(hidden_size * mlp_ratio),
drop=0.1,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Parameters
----------
x : Tensor, [B, T, H]
Returns
-------
Tensor, [B, T, H]
"""
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
# --------------------------- Main Model ---------------------------------------
class SoftPromptedTransformer(nn.Module):
"""
Multi-modal, domain-aware Transformer with optional soft prompts.
See parameter and forward I/O descriptions inside the docstrings.
"""
def __init__(
self,
hidden_size: int = 768,
multi_modal_input_size: int = 768,
depth: int = 24,
num_heads: int = 16,
mlp_ratio: float = 4.0,
num_domains: int = 20,
dim_action: int = 20,
dim_propio: int = 20,
dim_time: int = 32,
len_soft_prompts: int = 32,
max_len_seq: int = 512,
use_hetero_proj: bool = False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.dim_action = dim_action
self.dim_time = dim_time
self.len_soft_prompts = len_soft_prompts
self.use_hetero_proj = use_hetero_proj
self.blocks = nn.ModuleList(
[TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)]
)
if use_hetero_proj:
self.vlm_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
self.aux_visual_proj = DomainAwareLinear(
multi_modal_input_size, hidden_size, num_domains=num_domains
)
else:
self.vlm_proj = nn.Linear(multi_modal_input_size, hidden_size)
self.aux_visual_proj = nn.Linear(multi_modal_input_size, hidden_size)
self.pos_emb = nn.Parameter(torch.zeros(1, max_len_seq, hidden_size), requires_grad=True)
nn.init.normal_(self.pos_emb, std=0.02)
self.norm = nn.LayerNorm(hidden_size)
self.action_encoder = DomainAwareLinear(
dim_action + dim_time + dim_propio, hidden_size, num_domains=num_domains
)
self.action_decoder = DomainAwareLinear(hidden_size, dim_action, num_domains=num_domains)
if len_soft_prompts > 0:
self.soft_prompt_hub = nn.Embedding(num_domains, len_soft_prompts * hidden_size)
nn.init.normal_(self.soft_prompt_hub.weight, std=0.02)
self.apply(basic_init)
def forward(
self,
domain_id: torch.LongTensor,
vlm_features: torch.Tensor,
aux_visual_inputs: torch.Tensor,
action_with_noise: torch.Tensor,
proprio: torch.Tensor,
t: torch.Tensor,
) -> torch.Tensor:
"""
Forward pass.
Inputs
------
domain_id : [B]
vlm_features : [B, T_vlm, D]
aux_visual_inputs : [B, T_aux, D]
action_with_noise : [B, T_action, dim_action]
proprio : [B, dim_propio]
t : [B]
Returns
-------
Tensor
Predicted actions, [batch_size, num_actions, dim_action]
"""
batch_size, num_actions = action_with_noise.shape[:2]
# Encode (action + proprio + time) → tokens
time_emb = timestep_embedding(t, self.dim_time) # [batch_size, dim_time]
time_tokens = time_emb.unsqueeze(1).expand(batch_size, num_actions, self.dim_time)
proprio_tokens = proprio.unsqueeze(1).expand(batch_size, num_actions, proprio.shape[-1])
action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1)
x = self.action_encoder(action_tokens, domain_id) # [batch_size, num_actions, hidden_size]
# Project visual streams and concatenate
if self.use_hetero_proj:
x = torch.cat(
[
x,
self.vlm_proj(vlm_features, domain_id),
self.aux_visual_proj(aux_visual_inputs, domain_id),
],
dim=1,
)
else:
x = torch.cat([x, self.vlm_proj(vlm_features), self.aux_visual_proj(aux_visual_inputs)], dim=1)
# Add positional embeddings (truncate if needed)
seq_len = x.shape[1]
if seq_len > self.pos_emb.shape[1]:
raise ValueError(f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}.")
x = x + self.pos_emb[:, :seq_len, :]
# Append soft prompts
if self.len_soft_prompts > 0:
soft_prompts = self.soft_prompt_hub(domain_id).view(
batch_size, self.len_soft_prompts, self.hidden_size
)
x = torch.cat([x, soft_prompts], dim=1)
# Transformer backbone
for block in self.blocks:
x = block(x)
# Decode only the action segment
return self.action_decoder(self.norm(x[:, :num_actions]), domain_id)

View File

@@ -1,138 +0,0 @@
import math
import numpy as np
def mat2quat(rmat):
"""
Converts given rotation matrix to quaternion.
Args:
rmat (np.array): 3x3 rotation matrix
Returns:
np.array: (x,y,z,w) float quaternion angles
"""
mat = np.asarray(rmat).astype(np.float32)[:3, :3]
m00 = mat[0, 0]
m01 = mat[0, 1]
m02 = mat[0, 2]
m10 = mat[1, 0]
m11 = mat[1, 1]
m12 = mat[1, 2]
m20 = mat[2, 0]
m21 = mat[2, 1]
m22 = mat[2, 2]
# symmetric matrix k
k = np.array(
[
[m00 - m11 - m22, np.float32(0.0), np.float32(0.0), np.float32(0.0)],
[m01 + m10, m11 - m00 - m22, np.float32(0.0), np.float32(0.0)],
[m02 + m20, m12 + m21, m22 - m00 - m11, np.float32(0.0)],
[m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22],
]
)
k /= 3.0
# quaternion is Eigen vector of k that corresponds to largest eigenvalue
w, v = np.linalg.eigh(k)
inds = np.array([3, 0, 1, 2])
q1 = v[inds, np.argmax(w)]
if q1[0] < 0.0:
np.negative(q1, q1)
inds = np.array([1, 2, 3, 0])
return q1[inds]
def quat2axisangle(quat):
"""
Converts quaternion to axis-angle format.
Returns a unit vector direction scaled by its angle in radians.
Args:
quat (np.array): (x,y,z,w) vec4 float angles
Returns:
np.array: (ax,ay,az) axis-angle exponential coordinates
"""
# clip quaternion
if quat[3] > 1.0:
quat[3] = 1.0
elif quat[3] < -1.0:
quat[3] = -1.0
den = np.sqrt(1.0 - quat[3] * quat[3])
if math.isclose(den, 0.0):
# This is (close to) a zero degree rotation, immediately return
return np.zeros(3)
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
def rotate6d_to_axis_angle(r6d):
"""
r6d: np.ndarray, shape (N, 6)
return: np.ndarray, shape (N, 3), axis-angle vectors
"""
flag = 0
if len(r6d.shape) == 1:
r6d = r6d[None, ...]
flag = 1
a1 = r6d[:, 0:3]
a2 = r6d[:, 3:6]
# b1
b1 = a1 / (np.linalg.norm(a1, axis=-1, keepdims=True) + 1e-6)
# b2
dot_prod = np.sum(b1 * a2, axis=-1, keepdims=True)
b2_orth = a2 - dot_prod * b1
b2 = b2_orth / (np.linalg.norm(b2_orth, axis=-1, keepdims=True) + 1e-6)
# b3
b3 = np.cross(b1, b2, axis=-1)
rotation_matrix = np.stack([b1, b2, b3], axis=-1) # shape: (N, 3, 3)
axis_angle_list = []
for i in range(rotation_matrix.shape[0]):
quat = mat2quat(rotation_matrix[i])
axis_angle = quat2axisangle(quat)
axis_angle_list.append(axis_angle)
axis_angle_array = np.stack(axis_angle_list, axis=0) # shape: (N, 3)
if flag == 1:
axis_angle_array = axis_angle_array[0]
return axis_angle_array
def mat_to_rotate6d(abs_action):
if len(abs_action.shape) == 2:
return np.concatenate([abs_action[:3, 0], abs_action[:3, 1]], axis=-1)
elif len(abs_action.shape) == 3:
return np.concatenate([abs_action[:, :3, 0], abs_action[:, :3, 1]], axis=-1)
else:
raise NotImplementedError
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor

View File

@@ -1,154 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
"""
Processes LIBERO observations into the LeRobot format.
This step handles the specific observation structure from LIBERO environments,
which includes nested robot_state dictionaries and image observations.
**State Processing:**
- Processes the `robot_state` dictionary which contains nested end-effector,
gripper, and joint information.
- Extracts and concatenates:
- End-effector position (3D)
- End-effector quaternion converted to axis-angle (3D)
- Gripper joint positions (2D)
- Maps the concatenated state to `"observation.state"`.
**Image Processing:**
- Rotates images by 180 degrees by flipping both height and width dimensions.
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
"""
def _process_observation(self, observation):
"""
Processes both image and robot_state observations from LIBERO.
"""
processed_obs = observation.copy()
for key in list(processed_obs.keys()):
if key.startswith(f"{OBS_IMAGES}."):
img = processed_obs[key]
# Flip both H and W
img = torch.flip(img, dims=[2, 3])
processed_obs[key] = img
# Process robot_state into a flat state vector
if "observation.robot_state" in processed_obs:
robot_state = processed_obs.pop("observation.robot_state")
# Extract components
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
eef_quat = robot_state["eef"]["quat"] # (B, 4,)
gripper_qpos = robot_state["gripper"]["qpos"] # (B, 2,)
# Convert quaternion to axis-angle
eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
# Concatenate into a single state vector
state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
# ensure float32
state = state.float()
if state.dim() == 1:
state = state.unsqueeze(0)
processed_obs[OBS_STATE] = state
return processed_obs
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
Transforms feature keys from the LIBERO format to the LeRobot standard.
"""
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
# copy over non-STATE features
for ft, feats in features.items():
if ft != PipelineFeatureType.STATE:
new_features[ft] = feats.copy()
# rebuild STATE features
state_feats = {}
# add our new flattened state
state_feats["observation.state"] = PolicyFeature(
key="observation.state",
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
dtype="float32",
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
)
new_features[PipelineFeatureType.STATE] = state_feats
return new_features
def observation(self, observation):
return self._process_observation(observation)
def _quat2axisangle(self, quat: torch.Tensor) -> torch.Tensor:
"""
Convert batched quaternions to axis-angle format.
Only accepts torch tensors of shape (B, 4).
Args:
quat (Tensor): (B, 4) tensor of quaternions in (x, y, z, w) format
Returns:
Tensor: (B, 3) axis-angle vectors
Raises:
TypeError: if input is not a torch tensor
ValueError: if shape is not (B, 4)
"""
if not isinstance(quat, torch.Tensor):
raise TypeError(f"_quat2axisangle expected a torch.Tensor, got {type(quat)}")
if quat.ndim != 2 or quat.shape[1] != 4:
raise ValueError(f"_quat2axisangle expected shape (B, 4), got {tuple(quat.shape)}")
quat = quat.to(dtype=torch.float32)
device = quat.device
batch_size = quat.shape[0]
w = quat[:, 3].clamp(-1.0, 1.0)
den = torch.sqrt(torch.clamp(1.0 - w * w, min=0.0))
result = torch.zeros((batch_size, 3), device=device)
mask = den > 1e-10
if mask.any():
angle = 2.0 * torch.acos(w[mask]) # (M,)
axis = quat[mask, :3] / den[mask].unsqueeze(1)
result[mask] = axis * angle.unsqueeze(1)
return result

View File

@@ -78,7 +78,7 @@ from lerobot.transport.utils import (
transitions_to_bytes,
)
from lerobot.utils.random_utils import set_seed
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.transition import (
Transition,
move_state_dict_to_device,
@@ -398,7 +398,7 @@ def act_with_policy(
if cfg.env.fps is not None:
dt_time = time.perf_counter() - start_time
precise_sleep(1 / cfg.env.fps - dt_time)
busy_wait(1 / cfg.env.fps - dt_time)
# Communication Functions - Group all gRPC/messaging functions

View File

@@ -15,8 +15,7 @@
# limitations under the License.
import functools
import itertools
from collections.abc import Callable, Generator, Sequence
from collections.abc import Callable, Sequence
from contextlib import suppress
from typing import TypedDict
@@ -30,20 +29,13 @@ from lerobot.utils.transition import Transition
class BatchTransition(TypedDict):
"""Batch transition for single-step RL algorithms.
Uses Gymnasium terminology:
- terminated: True termination due to task success/failure
- truncated: Termination due to time limit or other external factors
"""
state: dict[str, torch.Tensor]
action: torch.Tensor
reward: torch.Tensor
next_state: dict[str, torch.Tensor]
terminated: torch.Tensor # True termination due to task success/failure
truncated: torch.Tensor # Termination due to time limit
complementary_info: dict[str, torch.Tensor] | None
done: torch.Tensor
truncated: torch.Tensor
complementary_info: dict[str, torch.Tensor | float | int] | None = None
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
@@ -86,8 +78,6 @@ def random_shift(images: torch.Tensor, pad: int = 4):
class ReplayBuffer:
"""Replay buffer for storing transitions used in RL training (e.g., SAC)."""
def __init__(
self,
capacity: int,
@@ -143,7 +133,7 @@ class ReplayBuffer:
self,
state: dict[str, torch.Tensor],
action: torch.Tensor,
complementary_info: dict[str, torch.Tensor | float | int] | None = None,
complementary_info: dict[str, torch.Tensor] | None = None,
):
"""Initialize the storage tensors based on the first transition."""
# Determine shapes from the first transition
@@ -169,8 +159,8 @@ class ReplayBuffer:
# Just create a reference to states for consistent API
self.next_states = self.states # Just a reference for API consistency
self.terminated = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
self.truncated = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
# Initialize storage for complementary_info
self.has_complementary_info = complementary_info is not None
@@ -205,7 +195,7 @@ class ReplayBuffer:
next_state: dict[str, torch.Tensor],
done: bool,
truncated: bool,
complementary_info: dict[str, torch.Tensor | float | int] | None = None,
complementary_info: dict[str, torch.Tensor] | None = None,
):
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
# Initialize storage if this is the first transition
@@ -222,8 +212,8 @@ class ReplayBuffer:
self.actions[self.position].copy_(action.squeeze(dim=0))
self.rewards[self.position] = reward
self.terminated[self.position] = done
self.truncated[self.position] = truncated
self.dones[self.position] = done
self.truncateds[self.position] = truncated
# Handle complementary_info if provided and storage is initialized
if complementary_info is not None and self.has_complementary_info:
@@ -293,8 +283,8 @@ class ReplayBuffer:
# Sample other tensors
batch_actions = self.actions[idx].to(self.device)
batch_rewards = self.rewards[idx].to(self.device)
batch_terminated = self.terminated[idx].to(self.device).float()
batch_truncated = self.truncated[idx].to(self.device).float()
batch_dones = self.dones[idx].to(self.device).float()
batch_truncateds = self.truncateds[idx].to(self.device).float()
# Sample complementary_info if available
batch_complementary_info = None
@@ -308,8 +298,8 @@ class ReplayBuffer:
action=batch_actions,
reward=batch_rewards,
next_state=batch_next_state,
terminated=batch_terminated,
truncated=batch_truncated,
done=batch_dones,
truncated=batch_truncateds,
complementary_info=batch_complementary_info,
)
@@ -441,6 +431,7 @@ class ReplayBuffer:
device (str): The device for sampling tensors. Defaults to "cuda:0".
state_keys (Sequence[str] | None): The list of keys that appear in `state` and `next_state`.
capacity (int | None): Buffer capacity. If None, uses dataset length.
action_mask (Sequence[int] | None): Indices of action dimensions to keep.
image_augmentation_function (Callable | None): Function for image augmentation.
If None, uses default random shift with pad=4.
use_drq (bool): Whether to use DrQ image augmentation when sampling.
@@ -469,16 +460,12 @@ class ReplayBuffer:
optimize_memory=optimize_memory,
)
# Convert dataset to transitions generator
transitions_generator = cls._lerobotdataset_to_transitions(
dataset=lerobot_dataset, state_keys=state_keys
)
# Get first transition to initialize storage
first_transition = next(transitions_generator, None)
# Convert dataset to transitions
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
# Initialize the buffer with the first transition to set up storage tensors
if first_transition is not None:
if list_transition:
first_transition = list_transition[0]
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
first_action = first_transition[ACTION].to(device)
@@ -496,28 +483,26 @@ class ReplayBuffer:
state=first_state, action=first_action, complementary_info=first_complementary_info
)
# Fill the buffer with all transitions (first + remaining)
if first_transition is not None:
for data in itertools.chain([first_transition], transitions_generator):
for k, v in data.items():
if isinstance(v, dict):
for key, tensor in v.items():
if isinstance(tensor, torch.Tensor):
v[key] = tensor.to(storage_device)
elif isinstance(v, torch.Tensor):
data[k] = v.to(storage_device)
# Fill the buffer with all transitions
for data in list_transition:
for k, v in data.items():
if isinstance(v, dict):
for key, tensor in v.items():
v[key] = tensor.to(storage_device)
elif isinstance(v, torch.Tensor):
data[k] = v.to(storage_device)
action = data[ACTION]
action = data[ACTION]
replay_buffer.add(
state=data["state"],
action=action,
reward=data["reward"],
next_state=data["next_state"],
done=data["done"],
truncated=data["truncated"],
complementary_info=data.get("complementary_info"),
)
replay_buffer.add(
state=data["state"],
action=action,
reward=data["reward"],
next_state=data["next_state"],
done=data["done"],
truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset
complementary_info=data.get("complementary_info", None),
)
return replay_buffer
@@ -591,12 +576,10 @@ class ReplayBuffer:
for key in self.states:
frame_dict[key] = self.states[key][actual_idx].cpu()
# Fill action, reward, done (done = terminated or truncated)
# Fill action, reward, done
frame_dict[ACTION] = self.actions[actual_idx].cpu()
frame_dict[REWARD] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
frame_dict[DONE] = torch.tensor(
[self.terminated[actual_idx] or self.truncated[actual_idx]], dtype=torch.bool
).cpu()
frame_dict[DONE] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
frame_dict["task"] = task_name
# Add complementary_info if available
@@ -616,7 +599,7 @@ class ReplayBuffer:
lerobot_dataset.add_frame(frame_dict)
# If we reached an episode boundary, call save_episode, reset counters
if self.terminated[actual_idx] or self.truncated[actual_idx]:
if self.dones[actual_idx] or self.truncateds[actual_idx]:
lerobot_dataset.save_episode()
# Save any remaining frames in the buffer
@@ -632,11 +615,9 @@ class ReplayBuffer:
def _lerobotdataset_to_transitions(
dataset: LeRobotDataset,
state_keys: Sequence[str] | None = None,
) -> Generator[Transition, None, None]:
) -> list[Transition]:
"""
Convert a LeRobotDataset into a generator of RL (s, a, r, s', done) transitions.
Using a generator instead of a list is more memory efficient for large datasets.
Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
Args:
dataset (LeRobotDataset):
@@ -656,12 +637,14 @@ class ReplayBuffer:
["observation.state", "observation.environment_state"].
If None, you must handle or define default keys.
Yields:
Transition: A transition dictionary.
Returns:
transitions (List[Transition]):
A list of Transition dictionaries with the same length as `dataset`.
"""
if state_keys is None:
raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")
transitions = []
num_frames = len(dataset)
# Check if the dataset has "next.done" key
@@ -704,17 +687,8 @@ class ReplayBuffer:
if next_sample["episode_index"] != current_sample["episode_index"]:
done = True
# Handle truncation separately from done
# This is important if the dataset has truncations (e.g., time limits)
truncated = False
if not done:
# If this is the last frame or if next frame is in a different episode, mark as truncated
if i == num_frames - 1:
truncated = True
elif i < num_frames - 1:
next_sample = dataset[i + 1]
if next_sample["episode_index"] != current_sample["episode_index"]:
truncated = True
# TODO: (azouitine) Handle truncation (using the same value as done for now)
truncated = done
# ----- 4) Next state -----
# If not done and the next sample is in the same episode, we pull the next sample's state.
@@ -742,6 +716,7 @@ class ReplayBuffer:
if isinstance(val, torch.Tensor):
complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension
else:
# TODO: (azouitine) Check if it's necessary to convert to tensor
# For non-tensor values, use directly
complementary_info[clean_key] = val
@@ -755,10 +730,12 @@ class ReplayBuffer:
truncated=truncated,
complementary_info=complementary_info,
)
transitions.append(transition)
yield transition
return transitions
# Utility function to guess shapes/dtypes from a tensor
def guess_feature_info(t, name: str):
"""
Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value.
@@ -828,9 +805,9 @@ def concatenate_batch_transitions(
for key in left_batch_transitions["next_state"]
}
# Concatenate terminated and truncated fields
left_batch_transitions["terminated"] = torch.cat(
[left_batch_transitions["terminated"], right_batch_transition["terminated"]], dim=0
# Concatenate done and truncated fields
left_batch_transitions["done"] = torch.cat(
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
)
left_batch_transitions["truncated"] = torch.cat(
[left_batch_transitions["truncated"], right_batch_transition["truncated"]],

View File

@@ -74,7 +74,7 @@ from lerobot.teleoperators import (
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
logging.basicConfig(level=logging.INFO)
@@ -114,7 +114,7 @@ def reset_follower_position(robot_arm: Robot, target_position: np.ndarray) -> No
for pose in trajectory:
action_dict = dict(zip(current_position_dict, pose, strict=False))
robot_arm.bus.sync_write("Goal_Position", action_dict)
precise_sleep(0.015)
busy_wait(0.015)
class RobotEnv(gym.Env):
@@ -238,7 +238,7 @@ class RobotEnv(gym.Env):
reset_follower_position(self.robot, np.array(self.reset_pose))
log_say("Reset the environment done.", play_sounds=True)
precise_sleep(self.reset_time_s - (time.perf_counter() - start_time))
busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
super().reset(seed=seed, options=options)
@@ -713,7 +713,7 @@ def control_loop(
transition = env_processor(transition)
# Maintain fps timing
precise_sleep(dt - (time.perf_counter() - step_start_time))
busy_wait(dt - (time.perf_counter() - step_start_time))
if dataset is not None and cfg.dataset.push_to_hub:
logging.info("Pushing dataset to hub")
@@ -745,7 +745,7 @@ def replay_trajectory(
)
transition = action_processor(transition)
env.step(transition[TransitionKey.ACTION])
precise_sleep(1 / cfg.env.fps - (time.perf_counter() - start_time))
busy_wait(1 / cfg.env.fps - (time.perf_counter() - start_time))
@parser.wrap()

View File

@@ -1,20 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
from .robot_earthrover_mini_plus import EarthRoverMiniPlus
__all__ = ["EarthRoverMiniPlus", "EarthRoverMiniPlusConfig"]

View File

@@ -1,35 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration for EarthRover Mini Plus robot."""
from dataclasses import dataclass
from ..config import RobotConfig
@RobotConfig.register_subclass("earthrover_mini_plus")
@dataclass
class EarthRoverMiniPlusConfig(RobotConfig):
"""Configuration for EarthRover Mini Plus robot using Frodobots SDK.
This robot uses cloud-based control via the Frodobots SDK HTTP API.
Camera frames are accessed directly through SDK HTTP endpoints.
Attributes:
sdk_url: URL of the Frodobots SDK server (default: http://localhost:8000)
"""
sdk_url: str = "http://localhost:8000"

View File

@@ -1 +0,0 @@
../../../../docs/source/earthrover_mini_plus.mdx

View File

@@ -1,473 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""EarthRover Mini Plus robot using Frodobots SDK."""
import base64
import logging
from functools import cached_property
from typing import Any
import cv2
import numpy as np
import requests
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..robot import Robot
from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
logger = logging.getLogger(__name__)
# Action feature keys
ACTION_LINEAR_VEL = "linear.vel"
ACTION_ANGULAR_VEL = "angular.vel"
# Observation feature keys
OBS_FRONT = "front"
OBS_REAR = "rear"
OBS_LINEAR_VEL = "linear.vel"
OBS_BATTERY_LEVEL = "battery.level"
OBS_ORIENTATION_DEG = "orientation.deg"
OBS_GPS_LATITUDE = "gps.latitude"
OBS_GPS_LONGITUDE = "gps.longitude"
OBS_GPS_SIGNAL = "gps.signal"
OBS_SIGNAL_LEVEL = "signal.level"
OBS_VIBRATION = "vibration"
OBS_LAMP_STATE = "lamp.state"
class EarthRoverMiniPlus(Robot):
"""
EarthRover Mini Plus robot controlled via Frodobots SDK HTTP API.
This robot uses cloud-based control through the Frodobots SDK instead of direct
hardware connection. Cameras stream via WebRTC through Agora cloud, and control
commands are sent via HTTP POST requests.
The robot supports:
- Dual cameras (front and rear) accessed via SDK HTTP endpoints
- Linear and angular velocity control
- Battery and orientation telemetry
Attributes:
config: Robot configuration
sdk_base_url: URL of the Frodobots SDK server (default: http://localhost:8000)
"""
config_class = EarthRoverMiniPlusConfig
name = "earthrover_mini_plus"
def __init__(self, config: EarthRoverMiniPlusConfig):
"""Initialize EarthRover Mini Plus robot.
Args:
config: Robot configuration including SDK URL
"""
super().__init__(config)
self.config = config
self.sdk_base_url = "http://localhost:8000"
# Empty cameras dict for compatibility with recording script
# Cameras are accessed directly via SDK, not through Camera objects
self.cameras = {}
self._is_connected = False
# Cache for camera frames (fallback when requests fail)
self._last_front_frame = None
self._last_rear_frame = None
# Cache for robot telemetry data (fallback when requests fail)
self._last_robot_data = None
logger.info(f"Initialized {self.name} with SDK at {self.sdk_base_url}")
@property
def is_connected(self) -> bool:
"""Check if robot is connected to SDK."""
return self._is_connected
def connect(self, calibrate: bool = True) -> None:
"""Connect to robot via Frodobots SDK.
Args:
calibrate: Not used for SDK-based robot (kept for API compatibility)
Raises:
DeviceAlreadyConnectedError: If robot is already connected
DeviceNotConnectedError: If cannot connect to SDK server
"""
if self._is_connected:
raise DeviceAlreadyConnectedError(f"{self.name} is already connected")
# Verify SDK is running and accessible
try:
response = requests.get(f"{self.sdk_base_url}/data", timeout=10.0)
if response.status_code != 200:
raise DeviceNotConnectedError(
f"Cannot connect to SDK at {self.sdk_base_url}. "
"Make sure it's running: hypercorn main:app --reload"
)
except requests.RequestException as e:
raise DeviceNotConnectedError(f"Cannot connect to SDK at {self.sdk_base_url}: {e}") from e
self._is_connected = True
logger.info(f"{self.name} connected to SDK")
if calibrate:
self.calibrate()
def calibrate(self) -> None:
"""Calibration not needed for SDK-based robot."""
logger.info("Calibration not required for SDK-based robot")
@property
def is_calibrated(self) -> bool:
"""SDK robot doesn't require calibration.
Returns:
bool: Always True for SDK-based robots
"""
return True
def configure(self) -> None:
"""Configure robot (no-op for SDK-based robot)."""
pass
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
"""Define the observation space for dataset recording.
Returns:
dict: Observation features with types/shapes:
- front: (480, 640, 3) - Front camera RGB image
- rear: (480, 640, 3) - Rear camera RGB image
- linear.vel: float - Current speed (0-1, SDK reports only positive speeds)
- battery.level: float - Battery level (0-1, normalized from 0-100)
- orientation.deg: float - Robot orientation (0-1, normalized from raw value)
- gps.latitude: float - GPS latitude coordinate
- gps.longitude: float - GPS longitude coordinate
- gps.signal: float - GPS signal strength (0-1, normalized from percentage)
- signal.level: float - Network signal level (0-1, normalized from 0-5)
- vibration: float - Vibration sensor reading
- lamp.state: float - Lamp state (0=off, 1=on)
"""
return {
# Cameras (height, width, channels)
OBS_FRONT: (480, 640, 3),
OBS_REAR: (480, 640, 3),
# Motion state
OBS_LINEAR_VEL: float,
# Robot state
OBS_BATTERY_LEVEL: float,
OBS_ORIENTATION_DEG: float,
# GPS
OBS_GPS_LATITUDE: float,
OBS_GPS_LONGITUDE: float,
OBS_GPS_SIGNAL: float,
# Sensors
OBS_SIGNAL_LEVEL: float,
OBS_VIBRATION: float,
OBS_LAMP_STATE: float,
}
@cached_property
def action_features(self) -> dict[str, type]:
"""Define the action space.
Returns:
dict: Action features with types:
- linear.vel: float - Target linear velocity
- angular.vel: float - Target angular velocity
"""
return {
ACTION_LINEAR_VEL: float,
ACTION_ANGULAR_VEL: float,
}
def get_observation(self) -> dict[str, Any]:
"""Get current robot observation from SDK.
Returns:
dict: Observation containing:
- front: Front camera image (480, 640, 3) in RGB format
- rear: Rear camera image (480, 640, 3) in RGB format
- linear.vel: Current speed (0-1, SDK reports only positive speeds)
- battery.level: Battery level (0-1, normalized from 0-100)
- orientation.deg: Robot orientation (0-1, normalized from raw value)
- gps.latitude: GPS latitude coordinate
- gps.longitude: GPS longitude coordinate
- gps.signal: GPS signal strength (0-1, normalized from percentage)
- signal.level: Network signal level (0-1, normalized from 0-5)
- vibration: Vibration sensor reading
- lamp.state: Lamp state (0=off, 1=on)
Raises:
DeviceNotConnectedError: If robot is not connected
Note:
Camera frames are retrieved from SDK endpoints /v2/front and /v2/rear.
Frames are decoded from base64 and converted from BGR to RGB format.
Robot telemetry is retrieved from /data endpoint.
All SDK values are normalized to appropriate ranges for dataset recording.
"""
if not self._is_connected:
raise DeviceNotConnectedError(f"{self.name} is not connected")
observation = {}
# Get camera images from SDK
frames = self._get_camera_frames()
observation[OBS_FRONT] = frames["front"]
observation[OBS_REAR] = frames["rear"]
# Get robot state from SDK
robot_data = self._get_robot_data()
# Motion state
observation[OBS_LINEAR_VEL] = robot_data["speed"] / 100.0 # Normalize 0-100 to 0-1
# Robot state
observation[OBS_BATTERY_LEVEL] = robot_data["battery"] / 100.0 # Normalize 0-100 to 0-1
observation[OBS_ORIENTATION_DEG] = robot_data["orientation"] / 360.0 # Normalize to 0-1
# GPS data
observation[OBS_GPS_LATITUDE] = robot_data["latitude"]
observation[OBS_GPS_LONGITUDE] = robot_data["longitude"]
observation[OBS_GPS_SIGNAL] = robot_data["gps_signal"] / 100.0 # Normalize percentage to 0-1
# Sensors
observation[OBS_SIGNAL_LEVEL] = robot_data["signal_level"] / 5.0 # Normalize 0-5 to 0-1
observation[OBS_VIBRATION] = robot_data["vibration"]
observation[OBS_LAMP_STATE] = float(robot_data["lamp"]) # 0 or 1
return observation
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
"""Send action to robot via SDK.
Args:
action: Action dict with keys:
- linear.vel: Target linear velocity (-1 to 1)
- angular.vel: Target angular velocity (-1 to 1)
Returns:
dict: The action that was sent (matches action_features keys)
Raises:
DeviceNotConnectedError: If robot is not connected
Note:
Actions are sent to SDK via POST /control endpoint.
SDK expects commands in range [-1, 1].
"""
if not self._is_connected:
raise DeviceNotConnectedError(f"{self.name} is not connected")
# Extract action values and convert to float
linear = float(action.get(ACTION_LINEAR_VEL, 0.0))
angular = float(action.get(ACTION_ANGULAR_VEL, 0.0))
# Send command to SDK
try:
self._send_command_to_sdk(linear, angular)
except Exception as e:
logger.error(f"Error sending action: {e}")
# Return action in format matching action_features
return {
ACTION_LINEAR_VEL: linear,
ACTION_ANGULAR_VEL: angular,
}
def disconnect(self) -> None:
"""Disconnect from robot.
Stops the robot and closes connection to SDK.
Raises:
DeviceNotConnectedError: If robot is not connected
"""
if not self._is_connected:
raise DeviceNotConnectedError(f"{self.name} is not connected")
# Stop the robot before disconnecting
try:
self._send_command_to_sdk(0.0, 0.0)
except Exception as e:
logger.warning(f"Failed to stop robot during disconnect: {e}")
self._is_connected = False
logger.info(f"{self.name} disconnected")
# Private helper methods for SDK communication
def _get_camera_frames(self) -> dict[str, np.ndarray]:
"""Get camera frames from SDK using v2 endpoints with caching fallback.
Returns:
dict: Dictionary with 'front' and 'rear' keys containing:
- Current frame (if request succeeds)
- Cached frame (if request fails but cache exists)
- Zero array (if request fails and no cache exists yet)
Note:
Uses /v2/front and /v2/rear endpoints which are 15x faster than /screenshot.
Images are base64 encoded, resized to 640x480, and converted from BGR to RGB.
If request fails, returns the last successfully retrieved frame (cached).
"""
frames = {}
# Get front camera
try:
response = requests.get(f"{self.sdk_base_url}/v2/front", timeout=2.0)
if response.status_code == 200:
data = response.json()
if "front_frame" in data and data["front_frame"]:
front_img = self._decode_base64_image(data["front_frame"])
if front_img is not None:
# Resize and convert BGR to RGB
front_img = cv2.resize(front_img, (640, 480))
front_rgb = cv2.cvtColor(front_img, cv2.COLOR_BGR2RGB)
frames["front"] = front_rgb
# Cache the successful frame
self._last_front_frame = front_rgb
except Exception as e:
logger.warning(f"Error fetching front camera: {e}")
# Fallback: use cache or zero array
if "front" not in frames:
if self._last_front_frame is not None:
frames["front"] = self._last_front_frame
else:
frames["front"] = np.zeros((480, 640, 3), dtype=np.uint8)
# Get rear camera
try:
response = requests.get(f"{self.sdk_base_url}/v2/rear", timeout=2.0)
if response.status_code == 200:
data = response.json()
if "rear_frame" in data and data["rear_frame"]:
rear_img = self._decode_base64_image(data["rear_frame"])
if rear_img is not None:
# Resize and convert BGR to RGB
rear_img = cv2.resize(rear_img, (640, 480))
rear_rgb = cv2.cvtColor(rear_img, cv2.COLOR_BGR2RGB)
frames["rear"] = rear_rgb
# Cache the successful frame
self._last_rear_frame = rear_rgb
except Exception as e:
logger.warning(f"Error fetching rear camera: {e}")
# Fallback: use cache or zero array
if "rear" not in frames:
if self._last_rear_frame is not None:
frames["rear"] = self._last_rear_frame
else:
frames["rear"] = np.zeros((480, 640, 3), dtype=np.uint8)
return frames
def _decode_base64_image(self, base64_string: str) -> np.ndarray | None:
"""Decode base64 string to image.
Args:
base64_string: Base64 encoded image string
Returns:
np.ndarray: Decoded image in BGR format (OpenCV default), or None if decoding fails
"""
try:
img_bytes = base64.b64decode(base64_string)
nparr = np.frombuffer(img_bytes, np.uint8)
img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
return img # Return in BGR format (OpenCV default)
except Exception as e:
logger.error(f"Error decoding image: {e}")
return None
def _get_robot_data(self) -> dict:
"""Get robot telemetry data from SDK.
Returns:
dict: Robot telemetry data including battery, speed, orientation, GPS, etc:
- Current data (if request succeeds)
- Cached data (if request fails but cache exists)
- Default values (if request fails and no cache exists yet)
Note:
Uses /data endpoint which provides comprehensive robot state.
If request fails, returns the last successfully retrieved data (cached).
"""
try:
response = requests.get(f"{self.sdk_base_url}/data", timeout=2.0)
if response.status_code == 200:
data = response.json()
# Cache the successful data
self._last_robot_data = data
return data
except Exception as e:
logger.warning(f"Error fetching robot data: {e}")
# Fallback: use cache or default values
if self._last_robot_data is not None:
return self._last_robot_data
else:
# Return dict with default values (used only on first failure before any cache exists)
return {
"speed": 0,
"battery": 0,
"orientation": 0,
"latitude": 0.0,
"longitude": 0.0,
"gps_signal": 0,
"signal_level": 0,
"vibration": 0.0,
"lamp": 0,
}
def _send_command_to_sdk(self, linear: float, angular: float, lamp: int = 0) -> bool:
"""Send control command to SDK.
Args:
linear: Linear velocity command (-1 to 1)
angular: Angular velocity command (-1 to 1)
lamp: Lamp control (0=off, 1=on)
Returns:
bool: True if command sent successfully, False otherwise
Note:
Uses POST /control endpoint. Commands are sent as JSON payload.
"""
try:
payload = {
"command": {
"linear": linear,
"angular": angular,
"lamp": lamp,
}
}
response = requests.post(
f"{self.sdk_base_url}/control",
json=payload,
timeout=1.0,
)
return response.status_code == 200
except Exception as e:
logger.error(f"Error sending command: {e}")
return False

View File

@@ -1,18 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_unitree_g1 import UnitreeG1Config
from .unitree_g1 import UnitreeG1

View File

@@ -1,55 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from ..config import RobotConfig
_GAINS: dict[str, dict[str, list[float]]] = {
"left_leg": {
"kp": [150, 150, 150, 300, 40, 40],
"kd": [2, 2, 2, 4, 2, 2],
}, # pitch, roll, yaw, knee, ankle_pitch, ankle_roll
"right_leg": {"kp": [150, 150, 150, 300, 40, 40], "kd": [2, 2, 2, 4, 2, 2]},
"waist": {"kp": [250, 250, 250], "kd": [5, 5, 5]}, # yaw, roll, pitch
"left_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]}, # shoulder_pitch/roll/yaw, elbow
"left_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]}, # roll, pitch, yaw
"right_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]},
"right_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]},
"other": {"kp": [80, 80, 80, 80, 80, 80], "kd": [3, 3, 3, 3, 3, 3]},
}
def _build_gains() -> tuple[list[float], list[float]]:
"""Build kp and kd lists from body-part groupings."""
kp = [v for g in _GAINS.values() for v in g["kp"]]
kd = [v for g in _GAINS.values() for v in g["kd"]]
return kp, kd
_DEFAULT_KP, _DEFAULT_KD = _build_gains()
@RobotConfig.register_subclass("unitree_g1")
@dataclass
class UnitreeG1Config(RobotConfig):
kp: list[float] = field(default_factory=lambda: _DEFAULT_KP.copy())
kd: list[float] = field(default_factory=lambda: _DEFAULT_KD.copy())
control_dt: float = 1.0 / 250.0 # 250Hz
# socket config for ZMQ bridge
robot_ip: str = "192.168.123.164"

View File

@@ -1,89 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import IntEnum
# ruff: noqa: N801, N815
NUM_MOTORS = 35
class G1_29_JointArmIndex(IntEnum):
# Left arm
kLeftShoulderPitch = 15
kLeftShoulderRoll = 16
kLeftShoulderYaw = 17
kLeftElbow = 18
kLeftWristRoll = 19
kLeftWristPitch = 20
kLeftWristyaw = 21
# Right arm
kRightShoulderPitch = 22
kRightShoulderRoll = 23
kRightShoulderYaw = 24
kRightElbow = 25
kRightWristRoll = 26
kRightWristPitch = 27
kRightWristYaw = 28
class G1_29_JointIndex(IntEnum):
# Left leg
kLeftHipPitch = 0
kLeftHipRoll = 1
kLeftHipYaw = 2
kLeftKnee = 3
kLeftAnklePitch = 4
kLeftAnkleRoll = 5
# Right leg
kRightHipPitch = 6
kRightHipRoll = 7
kRightHipYaw = 8
kRightKnee = 9
kRightAnklePitch = 10
kRightAnkleRoll = 11
kWaistYaw = 12
kWaistRoll = 13
kWaistPitch = 14
# Left arm
kLeftShoulderPitch = 15
kLeftShoulderRoll = 16
kLeftShoulderYaw = 17
kLeftElbow = 18
kLeftWristRoll = 19
kLeftWristPitch = 20
kLeftWristyaw = 21
# Right arm
kRightShoulderPitch = 22
kRightShoulderRoll = 23
kRightShoulderYaw = 24
kRightElbow = 25
kRightWristRoll = 26
kRightWristPitch = 27
kRightWristYaw = 28
# not used
kNotUsedJoint0 = 29
kNotUsedJoint1 = 30
kNotUsedJoint2 = 31
kNotUsedJoint3 = 32
kNotUsedJoint4 = 33
kNotUsedJoint5 = 34

View File

@@ -1,212 +0,0 @@
#!/usr/bin/env python3
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
DDS-to-ZMQ bridge server for Unitree G1 robot.
This server runs on the robot and forwards:
- Robot state (LowState) from DDS to ZMQ (for remote clients)
- Robot commands (LowCmd) from ZMQ to DDS (from remote clients)
Uses JSON for secure serialization instead of pickle.
"""
import base64
import contextlib
import json
import threading
import time
from typing import Any
import zmq
from unitree_sdk2py.comm.motion_switcher.motion_switcher_client import MotionSwitcherClient
from unitree_sdk2py.core.channel import ChannelFactoryInitialize, ChannelPublisher, ChannelSubscriber
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState
from unitree_sdk2py.utils.crc import CRC
# DDS topic names follow Unitree SDK naming conventions
# ruff: noqa: N816
kTopicLowCommand_Debug = "rt/lowcmd" # action to robot
kTopicLowState = "rt/lowstate" # observation from robot
LOWCMD_PORT = 6000
LOWSTATE_PORT = 6001
NUM_MOTORS = 35
def lowstate_to_dict(msg: hg_LowState) -> dict[str, Any]:
"""Convert LowState SDK message to a JSON-serializable dictionary."""
motor_states = []
for i in range(NUM_MOTORS):
temp = msg.motor_state[i].temperature
avg_temp = float(sum(temp) / len(temp)) if isinstance(temp, list) else float(temp)
motor_states.append(
{
"q": float(msg.motor_state[i].q),
"dq": float(msg.motor_state[i].dq),
"tau_est": float(msg.motor_state[i].tau_est),
"temperature": avg_temp,
}
)
return {
"motor_state": motor_states,
"imu_state": {
"quaternion": [float(x) for x in msg.imu_state.quaternion],
"gyroscope": [float(x) for x in msg.imu_state.gyroscope],
"accelerometer": [float(x) for x in msg.imu_state.accelerometer],
"rpy": [float(x) for x in msg.imu_state.rpy],
"temperature": float(msg.imu_state.temperature),
},
# Encode bytes as base64 for JSON compatibility
"wireless_remote": base64.b64encode(bytes(msg.wireless_remote)).decode("ascii"),
"mode_machine": int(msg.mode_machine),
}
def dict_to_lowcmd(data: dict[str, Any]) -> hg_LowCmd:
"""Convert dictionary back to LowCmd SDK message."""
cmd = unitree_hg_msg_dds__LowCmd_()
cmd.mode_pr = data.get("mode_pr", 0)
cmd.mode_machine = data.get("mode_machine", 0)
for i, motor_data in enumerate(data.get("motor_cmd", [])):
cmd.motor_cmd[i].mode = motor_data.get("mode", 0)
cmd.motor_cmd[i].q = motor_data.get("q", 0.0)
cmd.motor_cmd[i].dq = motor_data.get("dq", 0.0)
cmd.motor_cmd[i].kp = motor_data.get("kp", 0.0)
cmd.motor_cmd[i].kd = motor_data.get("kd", 0.0)
cmd.motor_cmd[i].tau = motor_data.get("tau", 0.0)
return cmd
def state_forward_loop(
lowstate_sub: ChannelSubscriber,
lowstate_sock: zmq.Socket,
state_period: float,
) -> None:
"""Read observation from DDS and forward to ZMQ clients."""
last_state_time = 0.0
while True:
# read from DDS
msg = lowstate_sub.Read()
if msg is None:
continue
now = time.time()
# optional downsampling (if robot dds rate > state_period)
if now - last_state_time >= state_period:
# Convert to dict and serialize with JSON
state_dict = lowstate_to_dict(msg)
payload = json.dumps({"topic": kTopicLowState, "data": state_dict}).encode("utf-8")
# if no subscribers / tx buffer full, just drop
with contextlib.suppress(zmq.Again):
lowstate_sock.send(payload, zmq.NOBLOCK)
last_state_time = now
def cmd_forward_loop(
lowcmd_sock: zmq.Socket,
lowcmd_pub_debug: ChannelPublisher,
crc: CRC,
) -> None:
"""Receive commands from ZMQ and forward to DDS."""
while True:
payload = lowcmd_sock.recv()
msg_dict = json.loads(payload.decode("utf-8"))
topic = msg_dict.get("topic", "")
cmd_data = msg_dict.get("data", {})
# Reconstruct LowCmd object from dict
cmd = dict_to_lowcmd(cmd_data)
# recompute crc
cmd.crc = crc.Crc(cmd)
if topic == kTopicLowCommand_Debug:
lowcmd_pub_debug.Write(cmd)
def main() -> None:
"""Main entry point for the robot server bridge."""
# initialize DDS
ChannelFactoryInitialize(0)
# stop all active publishers on the robot
msc = MotionSwitcherClient()
msc.SetTimeout(5.0)
msc.Init()
status, result = msc.CheckMode()
while result is not None and "name" in result and result["name"]:
msc.ReleaseMode()
status, result = msc.CheckMode()
time.sleep(1.0)
crc = CRC()
# initialize DDS publisher
lowcmd_pub_debug = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
lowcmd_pub_debug.Init()
# initialize DDS subscriber
lowstate_sub = ChannelSubscriber(kTopicLowState, hg_LowState)
lowstate_sub.Init()
# initialize ZMQ
ctx = zmq.Context.instance()
# receive commands from remote client
lowcmd_sock = ctx.socket(zmq.PULL)
lowcmd_sock.bind(f"tcp://0.0.0.0:{LOWCMD_PORT}")
# publish state to remote clients
lowstate_sock = ctx.socket(zmq.PUB)
lowstate_sock.bind(f"tcp://0.0.0.0:{LOWSTATE_PORT}")
state_period = 0.002 # ~500 hz
# start observation forwarding thread
t_state = threading.Thread(
target=state_forward_loop,
args=(lowstate_sub, lowstate_sock, state_period),
daemon=True,
)
t_state.start()
# start action forwarding thread
t_cmd = threading.Thread(
target=cmd_forward_loop,
args=(lowcmd_sock, lowcmd_pub_debug, crc),
daemon=True,
)
t_cmd.start()
print("bridge running (lowstate -> zmq, lowcmd -> dds)")
# keep main thread alive so daemon threads don't exit
try:
while True:
time.sleep(1.0)
except KeyboardInterrupt:
print("shutting down bridge...")
if __name__ == "__main__":
main()

View File

@@ -1,267 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import struct
import threading
import time
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any
import numpy as np
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
LowCmd_ as hg_LowCmd,
LowState_ as hg_LowState,
)
from unitree_sdk2py.utils.crc import CRC
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
from lerobot.robots.unitree_g1.unitree_sdk2_socket import (
ChannelFactoryInitialize,
ChannelPublisher,
ChannelSubscriber,
)
from ..robot import Robot
from .config_unitree_g1 import UnitreeG1Config
logger = logging.getLogger(__name__)
# DDS topic names follow Unitree SDK naming conventions
# ruff: noqa: N816
kTopicLowCommand_Debug = "rt/lowcmd"
kTopicLowState = "rt/lowstate"
G1_29_Num_Motors = 35
G1_23_Num_Motors = 35
H1_2_Num_Motors = 35
H1_Num_Motors = 20
@dataclass
class MotorState:
q: float | None = None # position
dq: float | None = None # velocity
tau_est: float | None = None # estimated torque
temperature: float | None = None # motor temperature
@dataclass
class IMUState:
quaternion: np.ndarray | None = None # [w, x, y, z]
gyroscope: np.ndarray | None = None # [x, y, z] angular velocity (rad/s)
accelerometer: np.ndarray | None = None # [x, y, z] linear acceleration (m/s²)
rpy: np.ndarray | None = None # [roll, pitch, yaw] (rad)
temperature: float | None = None # IMU temperature
# g1 observation class
@dataclass
class G1_29_LowState: # noqa: N801
motor_state: list[MotorState] = field(
default_factory=lambda: [MotorState() for _ in range(G1_29_Num_Motors)]
)
imu_state: IMUState = field(default_factory=IMUState)
wireless_remote: Any = None # Raw wireless remote data
mode_machine: int = 0 # Robot mode
class DataBuffer:
def __init__(self):
self.data = None
self.lock = threading.Lock()
def get_data(self):
with self.lock:
return self.data
def set_data(self, data):
with self.lock:
self.data = data
class UnitreeG1(Robot):
config_class = UnitreeG1Config
name = "unitree_g1"
# unitree remote controller
class RemoteController:
def __init__(self):
self.lx = 0
self.ly = 0
self.rx = 0
self.ry = 0
self.button = [0] * 16
def set(self, data):
# wireless_remote
keys = struct.unpack("H", data[2:4])[0]
for i in range(16):
self.button[i] = (keys & (1 << i)) >> i
self.lx = struct.unpack("f", data[4:8])[0]
self.rx = struct.unpack("f", data[8:12])[0]
self.ry = struct.unpack("f", data[12:16])[0]
self.ly = struct.unpack("f", data[20:24])[0]
def __init__(self, config: UnitreeG1Config):
super().__init__(config)
logger.info("Initialize UnitreeG1...")
self.config = config
self.control_dt = config.control_dt
# connect robot
self.connect()
# initialize direct motor control interface
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
self.lowcmd_publisher.Init()
self.lowstate_subscriber = ChannelSubscriber(kTopicLowState, hg_LowState)
self.lowstate_subscriber.Init()
self.lowstate_buffer = DataBuffer()
# initialize subscribe thread to read robot state
self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state)
self.subscribe_thread.daemon = True
self.subscribe_thread.start()
while not self.is_connected:
time.sleep(0.1)
# initialize hg's lowcmd msg
self.crc = CRC()
self.msg = unitree_hg_msg_dds__LowCmd_()
self.msg.mode_pr = 0
# Wait for first state message to arrive
lowstate = None
while lowstate is None:
lowstate = self.lowstate_buffer.get_data()
if lowstate is None:
time.sleep(0.01)
logger.warning("[UnitreeG1] Waiting for robot state...")
logger.warning("[UnitreeG1] Connected to robot.")
self.msg.mode_machine = lowstate.mode_machine
# initialize all motors with unified kp/kd from config
self.kp = np.array(config.kp, dtype=np.float32)
self.kd = np.array(config.kd, dtype=np.float32)
for id in G1_29_JointIndex:
self.msg.motor_cmd[id].mode = 1
self.msg.motor_cmd[id].kp = self.kp[id.value]
self.msg.motor_cmd[id].kd = self.kd[id.value]
self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q
# Initialize remote controller
self.remote_controller = self.RemoteController()
def _subscribe_motor_state(self): # polls robot state @ 250Hz
while True:
start_time = time.time()
msg = self.lowstate_subscriber.Read()
if msg is not None:
lowstate = G1_29_LowState()
# Capture motor states
for id in range(G1_29_Num_Motors):
lowstate.motor_state[id].q = msg.motor_state[id].q
lowstate.motor_state[id].dq = msg.motor_state[id].dq
lowstate.motor_state[id].tau_est = msg.motor_state[id].tau_est
lowstate.motor_state[id].temperature = msg.motor_state[id].temperature
# Capture IMU state
lowstate.imu_state.quaternion = list(msg.imu_state.quaternion)
lowstate.imu_state.gyroscope = list(msg.imu_state.gyroscope)
lowstate.imu_state.accelerometer = list(msg.imu_state.accelerometer)
lowstate.imu_state.rpy = list(msg.imu_state.rpy)
lowstate.imu_state.temperature = msg.imu_state.temperature
# Capture wireless remote data
lowstate.wireless_remote = msg.wireless_remote
# Capture mode_machine
lowstate.mode_machine = msg.mode_machine
self.lowstate_buffer.set_data(lowstate)
current_time = time.time()
all_t_elapsed = current_time - start_time
sleep_time = max(0, (self.control_dt - all_t_elapsed)) # maintain constant control dt
time.sleep(sleep_time)
@cached_property
def action_features(self) -> dict[str, type]:
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex}
def calibrate(self) -> None: # robot is already calibrated
pass
def configure(self) -> None:
pass
def connect(self, calibrate: bool = True) -> None: # connect to DDS
ChannelFactoryInitialize(0)
def disconnect(self):
pass
def get_observation(self) -> dict[str, Any]:
return self.lowstate_buffer.get_data()
@property
def is_calibrated(self) -> bool:
return True
@property
def is_connected(self) -> bool:
return self.lowstate_buffer.get_data() is not None
@property
def _motors_ft(self) -> dict[str, type]:
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex}
@property
def _cameras_ft(self) -> dict[str, tuple]:
return {
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
}
@cached_property
def observation_features(self) -> dict[str, type | tuple]:
return {**self._motors_ft, **self._cameras_ft}
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
self.msg.crc = self.crc.Crc(action)
self.lowcmd_publisher.Write(action)
return action
def get_gravity_orientation(self, quaternion): # get gravity orientation from quaternion
"""Get gravity orientation from quaternion."""
qw = quaternion[0]
qx = quaternion[1]
qy = quaternion[2]
qz = quaternion[3]
gravity_orientation = np.zeros(3)
gravity_orientation[0] = 2 * (-qz * qx + qw * qy)
gravity_orientation[1] = -2 * (qz * qy + qw * qx)
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
return gravity_orientation

View File

@@ -1,168 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import json
from typing import Any
import zmq
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
_ctx: zmq.Context | None = None
_lowcmd_sock: zmq.Socket | None = None
_lowstate_sock: zmq.Socket | None = None
LOWCMD_PORT = 6000
LOWSTATE_PORT = 6001
# DDS topic names follow Unitree SDK naming conventions
# ruff: noqa: N816
kTopicLowCommand_Debug = "rt/lowcmd"
class LowStateMsg:
"""
Wrapper class that mimics the Unitree SDK LowState_ message structure.
Reconstructs the message from deserialized JSON data to maintain
compatibility with existing code that expects SDK message objects.
"""
class MotorState:
"""Motor state data for a single joint."""
def __init__(self, data: dict[str, Any]) -> None:
self.q: float = data.get("q", 0.0)
self.dq: float = data.get("dq", 0.0)
self.tau_est: float = data.get("tau_est", 0.0)
self.temperature: float = data.get("temperature", 0.0)
class IMUState:
"""IMU sensor data."""
def __init__(self, data: dict[str, Any]) -> None:
self.quaternion: list[float] = data.get("quaternion", [1.0, 0.0, 0.0, 0.0])
self.gyroscope: list[float] = data.get("gyroscope", [0.0, 0.0, 0.0])
self.accelerometer: list[float] = data.get("accelerometer", [0.0, 0.0, 0.0])
self.rpy: list[float] = data.get("rpy", [0.0, 0.0, 0.0])
self.temperature: float = data.get("temperature", 0.0)
def __init__(self, data: dict[str, Any]) -> None:
"""Initialize from deserialized JSON data."""
self.motor_state = [self.MotorState(m) for m in data.get("motor_state", [])]
self.imu_state = self.IMUState(data.get("imu_state", {}))
# Decode base64-encoded wireless_remote bytes
wireless_b64 = data.get("wireless_remote", "")
self.wireless_remote: bytes = base64.b64decode(wireless_b64) if wireless_b64 else b""
self.mode_machine: int = data.get("mode_machine", 0)
def lowcmd_to_dict(topic: str, msg: Any) -> dict[str, Any]:
"""Convert LowCmd message to a JSON-serializable dictionary."""
motor_cmds = []
# Iterate over all motor commands in the message
for i in range(len(msg.motor_cmd)):
motor_cmds.append(
{
"mode": int(msg.motor_cmd[i].mode),
"q": float(msg.motor_cmd[i].q),
"dq": float(msg.motor_cmd[i].dq),
"kp": float(msg.motor_cmd[i].kp),
"kd": float(msg.motor_cmd[i].kd),
"tau": float(msg.motor_cmd[i].tau),
}
)
return {
"topic": topic,
"data": {
"mode_pr": int(msg.mode_pr),
"mode_machine": int(msg.mode_machine),
"motor_cmd": motor_cmds,
},
}
def ChannelFactoryInitialize(*args: Any, **kwargs: Any) -> None: # noqa: N802
"""
Initialize ZMQ sockets for robot communication.
This function mimics the Unitree SDK's ChannelFactoryInitialize but uses
ZMQ sockets to connect to the robot server bridge instead of DDS.
"""
global _ctx, _lowcmd_sock, _lowstate_sock
# read socket config
config = UnitreeG1Config()
robot_ip = config.robot_ip
ctx = zmq.Context.instance()
_ctx = ctx
# lowcmd: send robot commands
lowcmd_sock = ctx.socket(zmq.PUSH)
lowcmd_sock.setsockopt(zmq.CONFLATE, 1) # keep only last message
lowcmd_sock.connect(f"tcp://{robot_ip}:{LOWCMD_PORT}")
_lowcmd_sock = lowcmd_sock
# lowstate: receive robot observations
lowstate_sock = ctx.socket(zmq.SUB)
lowstate_sock.setsockopt(zmq.CONFLATE, 1) # keep only last message
lowstate_sock.connect(f"tcp://{robot_ip}:{LOWSTATE_PORT}")
lowstate_sock.setsockopt_string(zmq.SUBSCRIBE, "")
_lowstate_sock = lowstate_sock
class ChannelPublisher:
"""ZMQ-based publisher that sends commands to the robot server."""
def __init__(self, topic: str, msg_type: type) -> None:
self.topic = topic
self.msg_type = msg_type
def Init(self) -> None: # noqa: N802
"""Initialize the publisher (no-op for ZMQ)."""
pass
def Write(self, msg: Any) -> None: # noqa: N802
"""Serialize and send a command message to the robot."""
if _lowcmd_sock is None:
raise RuntimeError("ChannelFactoryInitialize must be called first")
payload = json.dumps(lowcmd_to_dict(self.topic, msg)).encode("utf-8")
_lowcmd_sock.send(payload)
class ChannelSubscriber:
"""ZMQ-based subscriber that receives state from the robot server."""
def __init__(self, topic: str, msg_type: type) -> None:
self.topic = topic
self.msg_type = msg_type
def Init(self) -> None: # noqa: N802
"""Initialize the subscriber (no-op for ZMQ)."""
pass
def Read(self) -> LowStateMsg: # noqa: N802
"""Receive and deserialize a state message from the robot."""
if _lowstate_sock is None:
raise RuntimeError("ChannelFactoryInitialize must be called first")
payload = _lowstate_sock.recv()
msg_dict = json.loads(payload.decode("utf-8"))
return LowStateMsg(msg_dict.get("data", {}))

View File

@@ -52,7 +52,7 @@ from lerobot.teleoperators import ( # noqa: F401
so100_leader,
so101_leader,
)
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.utils import init_logging
@@ -84,7 +84,7 @@ def calibrate(cfg: CalibrateConfig):
def main():
register_third_party_plugins()
register_third_party_devices()
calibrate()

View File

@@ -65,6 +65,7 @@ import argparse
import gc
import logging
import time
from collections.abc import Iterator
from pathlib import Path
import numpy as np
@@ -77,6 +78,19 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
class EpisodeSampler(torch.utils.data.Sampler):
def __init__(self, dataset: LeRobotDataset, episode_index: int):
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
self.frame_ids = range(from_idx, to_idx)
def __iter__(self) -> Iterator:
return iter(self.frame_ids)
def __len__(self) -> int:
return len(self.frame_ids)
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3
@@ -105,10 +119,12 @@ def visualize_dataset(
repo_id = dataset.repo_id
logging.info("Loading dataloader")
episode_sampler = EpisodeSampler(dataset, episode_index)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
batch_size=batch_size,
sampler=episode_sampler,
)
logging.info("Starting Rerun")

View File

@@ -18,8 +18,7 @@
Edit LeRobot datasets using various transformation tools.
This script allows you to delete episodes, split datasets, merge datasets,
remove features, and convert image datasets to video format.
When new_repo_id is specified, creates a new dataset.
and remove features. When new_repo_id is specified, creates a new dataset.
Usage Examples:
@@ -66,25 +65,6 @@ Remove camera feature:
--operation.type remove_feature \
--operation.feature_names "['observation.images.top']"
Convert image dataset to video format (saves locally):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_image \
--operation.type convert_to_video \
--operation.output_dir /path/to/output/pusht_video
Convert image dataset and save with new repo_id:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_to_video
Convert and push to hub:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_image \
--new_repo_id lerobot/pusht_video \
--operation.type convert_to_video \
--push_to_hub true
Using JSON config file:
python -m lerobot.scripts.lerobot_edit_dataset \
--config_path path/to/edit_config.json
@@ -92,13 +72,9 @@ Using JSON config file:
import logging
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
import pandas as pd
from tqdm import tqdm
from lerobot.configs import parser
from lerobot.datasets.dataset_tools import (
delete_episodes,
@@ -106,10 +82,8 @@ from lerobot.datasets.dataset_tools import (
remove_feature,
split_dataset,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import write_stats, write_tasks
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.utils import init_logging
@@ -137,23 +111,10 @@ class RemoveFeatureConfig:
feature_names: list[str] | None = None
@dataclass
class ConvertToVideoConfig:
type: str = "convert_to_video"
output_dir: str | None = None
vcodec: str = "libsvtav1"
pix_fmt: str = "yuv420p"
g: int = 2
crf: int = 30
fast_decode: int = 0
episode_indices: list[int] | None = None
num_workers: int = 4
@dataclass
class EditDatasetConfig:
repo_id: str
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertToVideoConfig
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig
root: str | None = None
new_repo_id: str | None = None
push_to_hub: bool = False
@@ -297,415 +258,6 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
def save_episode_images_for_video(
dataset: LeRobotDataset,
imgs_dir: Path,
img_key: str,
episode_index: int,
num_workers: int = 4,
) -> None:
"""Save images from a specific episode and camera to disk for video encoding.
Args:
dataset: The LeRobot dataset to extract images from
imgs_dir: Directory to save images to
img_key: The image key (camera) to extract
episode_index: Index of the episode to save
num_workers: Number of threads for parallel image saving
"""
# Create directory
imgs_dir.mkdir(parents=True, exist_ok=True)
# Get dataset without torch format for PIL image access
hf_dataset = dataset.hf_dataset.with_format(None)
# Select only this camera's images
imgs_dataset = hf_dataset.select_columns(img_key)
# Get episode start and end indices
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
# Get all items for this episode
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
# Define function to save a single image
def save_single_image(i_item_tuple):
i, item = i_item_tuple
img = item[img_key]
# Use frame-XXXXXX.png format to match encode_video_frames expectations
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
return i
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
items = list(enumerate(episode_dataset))
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(save_single_image, item) for item in items]
for future in as_completed(futures):
future.result() # This will raise any exceptions that occurred
def encode_episode_videos(
dataset: LeRobotDataset,
new_meta: LeRobotDatasetMetadata,
episode_index: int,
vcodec: str,
pix_fmt: str,
g: int,
crf: int,
fast_decode: int,
temp_dir: Path,
num_image_workers: int = 4,
) -> dict[str, dict]:
"""Encode videos for a single episode and return video metadata.
Args:
dataset: Source dataset with images
new_meta: Metadata object for the new video dataset
episode_index: Episode index to process
vcodec: Video codec
pix_fmt: Pixel format
g: Group of pictures size
crf: Constant rate factor
fast_decode: Fast decode tuning
temp_dir: Temporary directory for images
num_image_workers: Number of workers for saving images
Returns:
Dictionary mapping video keys to their metadata (chunk_index, file_index, timestamps)
"""
hf_dataset = dataset.hf_dataset.with_format(None)
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
video_metadata = {}
fps = int(dataset.fps) # Convert to int for PyAV compatibility
episode_length = dataset.meta.episodes["length"][episode_index]
episode_duration = episode_length / dataset.fps # Use original fps for duration calculation
for img_key in img_keys:
# Save images temporarily
imgs_dir = temp_dir / f"episode_{episode_index:06d}" / img_key
save_episode_images_for_video(dataset, imgs_dir, img_key, episode_index, num_image_workers)
# Determine chunk and file indices
# For simplicity, we'll put each episode in its own file
chunk_idx = episode_index // new_meta.chunks_size
file_idx = episode_index % new_meta.chunks_size
# Create video path in the new dataset structure
video_path = new_meta.root / new_meta.video_path.format(
video_key=img_key, chunk_index=chunk_idx, file_index=file_idx
)
video_path.parent.mkdir(parents=True, exist_ok=True)
# Encode video
encode_video_frames(
imgs_dir=imgs_dir,
video_path=video_path,
fps=fps,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
overwrite=True,
)
# Clean up temporary images
shutil.rmtree(imgs_dir)
# Store video metadata
video_metadata[img_key] = {
f"videos/{img_key}/chunk_index": chunk_idx,
f"videos/{img_key}/file_index": file_idx,
f"videos/{img_key}/from_timestamp": 0.0,
f"videos/{img_key}/to_timestamp": episode_duration,
}
return video_metadata
def convert_dataset_to_videos(
dataset: LeRobotDataset,
output_dir: Path,
repo_id: str | None = None,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int = 2,
crf: int = 30,
fast_decode: int = 0,
episode_indices: list[int] | None = None,
num_workers: int = 4,
) -> LeRobotDataset:
"""Convert image-based dataset to video-based dataset.
Creates a new LeRobotDataset with videos instead of images, following the proper
LeRobot dataset structure with videos stored in chunked MP4 files.
Args:
dataset: The source LeRobot dataset with images
output_dir: Directory to save the new video dataset
repo_id: Repository ID for the new dataset (default: original_id + "_video")
vcodec: Video codec (default: libsvtav1)
pix_fmt: Pixel format (default: yuv420p)
g: Group of pictures size (default: 2)
crf: Constant rate factor (default: 30)
fast_decode: Fast decode tuning (default: 0)
episode_indices: List of episode indices to convert (None = all episodes)
num_workers: Number of threads for parallel processing (default: 4)
Returns:
New LeRobotDataset with videos
"""
# Check that it's an image dataset
if len(dataset.meta.video_keys) > 0:
raise ValueError(
f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}"
)
# Get all image keys
hf_dataset = dataset.hf_dataset.with_format(None)
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
if len(img_keys) == 0:
raise ValueError(f"No image keys found in dataset {dataset.repo_id}")
# Determine which episodes to process
if episode_indices is None:
episode_indices = list(range(dataset.meta.total_episodes))
if repo_id is None:
repo_id = f"{dataset.repo_id}_video"
logging.info(
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
)
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
# Create new features dict, converting image features to video features
new_features = {}
for key, value in dataset.meta.features.items():
if key not in img_keys:
new_features[key] = value
else:
# Convert image key to video format
new_features[key] = value.copy()
new_features[key]["dtype"] = "video" # Change dtype from "image" to "video"
# Video info will be updated after episodes are encoded
# Create new metadata for video dataset
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
features=new_features,
robot_type=dataset.meta.robot_type,
root=output_dir,
use_videos=True,
chunks_size=dataset.meta.chunks_size,
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
)
# Create temporary directory for image extraction
temp_dir = output_dir / "temp_images"
temp_dir.mkdir(parents=True, exist_ok=True)
# Process each episode
all_episode_metadata = []
try:
for ep_idx in tqdm(episode_indices, desc="Converting episodes to videos"):
# Get episode metadata from source
src_episode = dataset.meta.episodes[ep_idx]
# Encode videos for this episode
video_metadata = encode_episode_videos(
dataset=dataset,
new_meta=new_meta,
episode_index=ep_idx,
vcodec=vcodec,
pix_fmt=pix_fmt,
g=g,
crf=crf,
fast_decode=fast_decode,
temp_dir=temp_dir,
num_image_workers=num_workers,
)
# Build episode metadata
episode_meta = {
"episode_index": ep_idx,
"length": src_episode["length"],
"dataset_from_index": ep_idx * src_episode["length"],
"dataset_to_index": (ep_idx + 1) * src_episode["length"],
}
# Add video metadata
for img_key in img_keys:
episode_meta.update(video_metadata[img_key])
# Add data chunk/file info (using same structure as source)
if "data/chunk_index" in src_episode:
episode_meta["data/chunk_index"] = src_episode["data/chunk_index"]
episode_meta["data/file_index"] = src_episode["data/file_index"]
all_episode_metadata.append(episode_meta)
# Copy and transform data files (removing image columns)
_copy_data_without_images(dataset, new_meta, episode_indices, img_keys)
# Save episode metadata
episodes_df = pd.DataFrame(all_episode_metadata)
episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet"
episodes_path.parent.mkdir(parents=True, exist_ok=True)
episodes_df.to_parquet(episodes_path, index=False)
# Update metadata info
new_meta.info["total_episodes"] = len(episode_indices)
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata)
new_meta.info["total_tasks"] = dataset.meta.total_tasks
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
# Update video info for all image keys (now videos)
# We need to manually set video info since update_video_info() checks video_keys first
for img_key in img_keys:
if not new_meta.features[img_key].get("info", None):
video_path = new_meta.root / new_meta.video_path.format(
video_key=img_key, chunk_index=0, file_index=0
)
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
from lerobot.datasets.utils import write_info
write_info(new_meta.info, new_meta.root)
# Copy stats and tasks
if dataset.meta.stats is not None:
# Remove image stats
new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys}
write_stats(new_stats, new_meta.root)
if dataset.meta.tasks is not None:
write_tasks(dataset.meta.tasks, new_meta.root)
finally:
# Clean up temporary directory
if temp_dir.exists():
shutil.rmtree(temp_dir)
logging.info(f"✓ Completed converting {dataset.repo_id} to video format")
logging.info(f"New dataset saved to: {output_dir}")
# Return new dataset
return LeRobotDataset(repo_id=repo_id, root=output_dir)
def _copy_data_without_images(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_indices: list[int],
img_keys: list[str],
) -> None:
"""Copy data files without image columns.
Args:
src_dataset: Source dataset
dst_meta: Destination metadata
episode_indices: Episodes to include
img_keys: Image keys to remove
"""
from lerobot.datasets.utils import DATA_DIR
data_dir = src_dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
if not parquet_files:
raise ValueError(f"No parquet files found in {data_dir}")
episode_set = set(episode_indices)
for src_path in tqdm(parquet_files, desc="Processing data files"):
df = pd.read_parquet(src_path).reset_index(drop=True)
# Filter to only include selected episodes
df = df[df["episode_index"].isin(episode_set)].copy()
if len(df) == 0:
continue
# Remove image columns
columns_to_drop = [col for col in img_keys if col in df.columns]
if columns_to_drop:
df = df.drop(columns=columns_to_drop)
# Get chunk and file indices from path
relative_path = src_path.relative_to(src_dataset.root)
chunk_dir = relative_path.parts[1]
file_name = relative_path.parts[2]
chunk_idx = int(chunk_dir.split("-")[1])
file_idx = int(file_name.split("-")[1].split(".")[0])
# Write to destination without pandas index
dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
dst_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(dst_path, index=False)
def handle_convert_to_video(cfg: EditDatasetConfig) -> None:
# Note: Parser may create any config type with the right fields, so we access fields directly
# instead of checking isinstance()
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
# Determine output directory and repo_id
# Priority: 1) new_repo_id, 2) operation.output_dir, 3) auto-generated name
output_dir_config = getattr(cfg.operation, "output_dir", None)
if cfg.new_repo_id:
# Use new_repo_id for both local storage and hub push
output_repo_id = cfg.new_repo_id
output_dir = Path(cfg.root) / cfg.new_repo_id if cfg.root else HF_LEROBOT_HOME / cfg.new_repo_id
logging.info(f"Saving to new dataset: {cfg.new_repo_id}")
elif output_dir_config:
# Use custom output directory for local-only storage
output_dir = Path(output_dir_config)
# Extract repo name from output_dir for the dataset
output_repo_id = output_dir.name
logging.info(f"Saving to local directory: {output_dir}")
else:
# Auto-generate name: append "_video" to original repo_id
output_repo_id = f"{cfg.repo_id}_video"
output_dir = Path(cfg.root) / output_repo_id if cfg.root else HF_LEROBOT_HOME / output_repo_id
logging.info(f"Saving to auto-generated location: {output_dir}")
logging.info(f"Converting dataset {cfg.repo_id} to video format")
new_dataset = convert_dataset_to_videos(
dataset=dataset,
output_dir=output_dir,
repo_id=output_repo_id,
vcodec=getattr(cfg.operation, "vcodec", "libsvtav1"),
pix_fmt=getattr(cfg.operation, "pix_fmt", "yuv420p"),
g=getattr(cfg.operation, "g", 2),
crf=getattr(cfg.operation, "crf", 30),
fast_decode=getattr(cfg.operation, "fast_decode", 0),
episode_indices=getattr(cfg.operation, "episode_indices", None),
num_workers=getattr(cfg.operation, "num_workers", 4),
)
logging.info("Video dataset created successfully!")
logging.info(f"Location: {output_dir}")
logging.info(f"Episodes: {new_dataset.meta.total_episodes}")
logging.info(f"Frames: {new_dataset.meta.total_frames}")
if cfg.push_to_hub:
logging.info(f"Pushing to hub as {output_repo_id}...")
new_dataset.push_to_hub()
logging.info("✓ Successfully pushed to hub!")
else:
logging.info("Dataset saved locally (not pushed to hub)")
@parser.wrap()
def edit_dataset(cfg: EditDatasetConfig) -> None:
operation_type = cfg.operation.type
@@ -718,12 +270,10 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
handle_merge(cfg)
elif operation_type == "remove_feature":
handle_remove_feature(cfg)
elif operation_type == "convert_to_video":
handle_convert_to_video(cfg)
else:
raise ValueError(
f"Unknown operation type: {operation_type}\n"
f"Available operations: delete_episodes, split, merge, remove_feature, convert_to_video"
f"Available operations: delete_episodes, split, merge, remove_feature"
)

View File

@@ -71,7 +71,7 @@ from tqdm import trange
from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env, make_env_pre_post_processors
from lerobot.envs.factory import make_env
from lerobot.envs.utils import (
add_envs_task,
check_env_attributes_and_types,
@@ -82,7 +82,6 @@ from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
@@ -95,8 +94,6 @@ from lerobot.utils.utils import (
def rollout(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
seeds: list[int] | None = None,
@@ -168,19 +165,11 @@ def rollout(
# Infer "task" from attributes of environments.
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
observation = add_envs_task(env, observation)
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
observation = env_preprocessor(observation)
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
action = postprocessor(action)
action_transition = {"action": action}
action_transition = env_postprocessor(action_transition)
action = action_transition["action"]
# Convert to CPU / numpy.
action_numpy: np.ndarray = action.to("cpu").numpy()
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
@@ -250,8 +239,6 @@ def rollout(
def eval_policy(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int,
@@ -332,8 +319,6 @@ def eval_policy(
rollout_data = rollout(
env=env,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
seeds=list(seeds) if seeds else None,
@@ -532,16 +517,10 @@ def eval_main(cfg: EvalPipelineConfig):
pretrained_path=cfg.policy.pretrained_path,
preprocessor_overrides=preprocessor_overrides,
)
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy_all(
envs=envs,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
@@ -582,8 +561,6 @@ def eval_one(
env: gym.vector.VectorEnv,
*,
policy: PreTrainedPolicy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int,
@@ -599,8 +576,6 @@ def eval_one(
task_result = eval_policy(
env=env,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
@@ -625,8 +600,6 @@ def run_one(
env,
*,
policy,
env_preprocessor,
env_postprocessor,
preprocessor,
postprocessor,
n_episodes: int,
@@ -649,8 +622,6 @@ def run_one(
metrics = eval_one(
env,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
@@ -668,8 +639,6 @@ def run_one(
def eval_policy_all(
envs: dict[str, dict[int, gym.vector.VectorEnv]],
policy,
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
n_episodes: int,
@@ -725,8 +694,6 @@ def eval_policy_all(
task_runner = partial(
run_one,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
@@ -793,7 +760,6 @@ def eval_policy_all(
def main():
init_logging()
register_third_party_plugins()
eval_main()

View File

@@ -15,23 +15,18 @@
# limitations under the License.
"""
Script to find joint limits and end-effector bounds via teleoperation.
Simple script to control a robot from teleoperation.
Example:
```shell
lerobot-find-joint-limits \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760432981 \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=black \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760434471 \
--teleop.id=blue \
--urdf_path=<user>/SO-ARM100-main/Simulation/SO101/so101_new_calib.urdf \
--target_frame_name=gripper \
--teleop_time_s=30 \
--warmup_time_s=5 \
--control_loop_fps=30
--teleop.port=/dev/tty.usbmodem58760431551 \
--teleop.id=blue
```
"""
@@ -47,7 +42,6 @@ from lerobot.robots import ( # noqa: F401
koch_follower,
make_robot_from_config,
so100_follower,
so101_follower,
)
from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig,
@@ -55,28 +49,18 @@ from lerobot.teleoperators import ( # noqa: F401
koch_leader,
make_teleoperator_from_config,
so100_leader,
so101_leader,
)
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.robot_utils import busy_wait
@dataclass
class FindJointLimitsConfig:
teleop: TeleoperatorConfig
robot: RobotConfig
# Path to URDF file for kinematics
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
urdf_path: str
target_frame_name: str = "gripper"
# Duration of the recording phase in seconds
# Limit the maximum frames per second. By default, no limit.
teleop_time_s: float = 30
# Duration of the warmup phase in seconds
warmup_time_s: float = 5
# Control loop frequency
control_loop_fps: int = 30
# Display all cameras on screen
display_data: bool = False
@draccus.wrap()
@@ -84,127 +68,53 @@ def find_joint_and_ee_bounds(cfg: FindJointLimitsConfig):
teleop = make_teleoperator_from_config(cfg.teleop)
robot = make_robot_from_config(cfg.robot)
print(f"Connecting to robot: {cfg.robot.type}...")
teleop.connect()
robot.connect()
print("Devices connected.")
# Initialize Kinematics
try:
kinematics = RobotKinematics(cfg.urdf_path, cfg.target_frame_name)
except Exception as e:
print(f"Error initializing kinematics: {e}")
print("Ensure URDF path and target frame name are correct.")
robot.disconnect()
teleop.disconnect()
return
start_episode_t = time.perf_counter()
robot_type = getattr(robot.config, "robot_type", "so101")
if "so100" in robot_type or "so101" in robot_type:
# Note to be compatible with the rest of the codebase,
# we are using the new calibration method for so101 and so100
robot_type = "so_new_calibration"
kinematics = RobotKinematics(cfg.robot.urdf_path, cfg.robot.target_frame_name)
# Initialize variables
max_pos = None
min_pos = None
max_ee = None
min_ee = None
# Initialize min/max values
observation = robot.get_observation()
joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors])
ee_pos = kinematics.forward_kinematics(joint_positions)[:3, 3]
start_t = time.perf_counter()
warmup_done = False
max_pos = joint_positions.copy()
min_pos = joint_positions.copy()
max_ee = ee_pos.copy()
min_ee = ee_pos.copy()
print("\n" + "=" * 40)
print(f" WARMUP PHASE ({cfg.warmup_time_s}s)")
print(" Move the robot freely to ensure control works.")
print(" Data is NOT being recorded yet.")
print("=" * 40 + "\n")
while True:
action = teleop.get_action()
robot.send_action(action)
try:
while True:
t0 = time.perf_counter()
observation = robot.get_observation()
joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors])
ee_pos = kinematics.forward_kinematics(joint_positions)[:3, 3]
# 1. Teleoperation Control Loop
action = teleop.get_action()
robot.send_action(action)
# Skip initial warmup period
if (time.perf_counter() - start_episode_t) < 5:
continue
# 2. Read Observations
observation = robot.get_observation()
joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors])
# Update min/max values
max_ee = np.maximum(max_ee, ee_pos)
min_ee = np.minimum(min_ee, ee_pos)
max_pos = np.maximum(max_pos, joint_positions)
min_pos = np.minimum(min_pos, joint_positions)
# 3. Calculate Kinematics
# Forward kinematics to get (x, y, z) translation
ee_pos = kinematics.forward_kinematics(joint_positions)[:3, 3]
if time.perf_counter() - start_episode_t > cfg.teleop_time_s:
print(f"Max ee position {np.round(max_ee, 4).tolist()}")
print(f"Min ee position {np.round(min_ee, 4).tolist()}")
print(f"Max joint pos position {np.round(max_pos, 4).tolist()}")
print(f"Min joint pos position {np.round(min_pos, 4).tolist()}")
break
current_time = time.perf_counter()
elapsed = current_time - start_t
# 4. Handle Phases
if elapsed < cfg.warmup_time_s:
# Still in warmup
pass
else:
# Phase Transition: Warmup -> Recording
if not warmup_done:
print("\n" + "=" * 40)
print(" RECORDING STARTED")
print(" Move robot to ALL joint limits.")
print(" Press Ctrl+C to stop early and save results.")
print("=" * 40 + "\n")
# Initialize limits with current position at start of recording
max_pos = joint_positions.copy()
min_pos = joint_positions.copy()
max_ee = ee_pos.copy()
min_ee = ee_pos.copy()
warmup_done = True
# Update Limits
max_ee = np.maximum(max_ee, ee_pos)
min_ee = np.minimum(min_ee, ee_pos)
max_pos = np.maximum(max_pos, joint_positions)
min_pos = np.minimum(min_pos, joint_positions)
# Time check
recording_time = elapsed - cfg.warmup_time_s
remaining = cfg.teleop_time_s - recording_time
# Simple throttle for print statements (every ~1 sec)
if int(recording_time * 100) % 100 == 0:
print(f"Time remaining: {remaining:.1f}s", end="\r")
if recording_time > cfg.teleop_time_s:
print("\nTime limit reached.")
break
precise_sleep(max(1.0 / cfg.control_loop_fps - (time.perf_counter() - t0), 0.0))
except KeyboardInterrupt:
print("\n\nInterrupted by user. Stopping safely...")
finally:
# Safety: Disconnect devices
print("\nDisconnecting devices...")
robot.disconnect()
teleop.disconnect()
# Results Output
if max_pos is not None:
print("\n" + "=" * 40)
print("FINAL RESULTS")
print("=" * 40)
# Rounding for readability
r_max_ee = np.round(max_ee, 4).tolist()
r_min_ee = np.round(min_ee, 4).tolist()
r_max_pos = np.round(max_pos, 4).tolist()
r_min_pos = np.round(min_pos, 4).tolist()
print("\n# End Effector Bounds (x, y, z):")
print(f"max_ee = {r_max_ee}")
print(f"min_ee = {r_min_ee}")
print("\n# Joint Position Limits (radians):")
print(f"max_pos = {r_max_pos}")
print(f"min_pos = {r_min_pos}")
else:
print("No data recorded (exited during warmup).")
busy_wait(0.01)
def main():

View File

@@ -93,7 +93,6 @@ from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_so100_follower,
earthrover_mini_plus,
hope_jr,
koch_follower,
make_robot_from_config,
@@ -119,8 +118,8 @@ from lerobot.utils.control_utils import (
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
)
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import (
get_safe_torch_device,
init_logging,
@@ -365,7 +364,7 @@ def record_loop(
log_rerun_data(observation=obs_processed, action=action_values)
dt_s = time.perf_counter() - start_loop_t
precise_sleep(1 / fps - dt_s)
busy_wait(1 / fps - dt_s)
timestamp = time.perf_counter() - start_episode_t
@@ -513,7 +512,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
def main():
register_third_party_plugins()
register_third_party_devices()
record()

View File

@@ -54,7 +54,6 @@ from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_so100_follower,
earthrover_mini_plus,
hope_jr,
koch_follower,
make_robot_from_config,
@@ -62,8 +61,8 @@ from lerobot.robots import ( # noqa: F401
so101_follower,
)
from lerobot.utils.constants import ACTION
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import (
init_logging,
log_say,
@@ -122,13 +121,13 @@ def replay(cfg: ReplayConfig):
_ = robot.send_action(processed_action)
dt_s = time.perf_counter() - start_episode_t
precise_sleep(1 / dataset.fps - dt_s)
busy_wait(1 / dataset.fps - dt_s)
robot.disconnect()
def main():
register_third_party_plugins()
register_third_party_devices()
replay()

View File

@@ -71,7 +71,6 @@ from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_so100_follower,
earthrover_mini_plus,
hope_jr,
koch_follower,
make_robot_from_config,
@@ -84,14 +83,13 @@ from lerobot.teleoperators import ( # noqa: F401
bi_so100_leader,
gamepad,
homunculus,
keyboard,
koch_leader,
make_teleoperator_from_config,
so100_leader,
so101_leader,
)
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import init_logging, move_cursor_up
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
@@ -172,13 +170,12 @@ def teleop_loop(
# Display the final robot action that was sent
for motor, value in robot_action_to_send.items():
print(f"{motor:<{display_len}} | {value:>7.2f}")
move_cursor_up(len(robot_action_to_send) + 3)
move_cursor_up(len(robot_action_to_send) + 5)
dt_s = time.perf_counter() - loop_start
precise_sleep(1 / fps - dt_s)
busy_wait(1 / fps - dt_s)
loop_s = time.perf_counter() - loop_start
print(f"Teleop loop time: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
move_cursor_up(1)
print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
if duration is not None and time.perf_counter() - start >= duration:
return
@@ -219,7 +216,7 @@ def teleoperate(cfg: TeleoperateConfig):
def main():
register_third_party_plugins()
register_third_party_devices()
teleoperate()

View File

@@ -29,14 +29,13 @@ from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.envs.factory import make_env, make_env_pre_post_processors
from lerobot.envs.factory import make_env
from lerobot.envs.utils import close_envs
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.rl.wandb_utils import WandBLogger
from lerobot.scripts.lerobot_eval import eval_policy_all
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import (
@@ -260,10 +259,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info("Creating environment processors")
env_preprocessor, env_postprocessor = make_env_pre_post_processors(
env_cfg=cfg.env, policy_cfg=cfg.policy
)
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_episodes=}")
@@ -279,7 +274,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True,
)
@@ -390,8 +384,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
eval_info = eval_policy_all(
envs=eval_env, # dict[suite][task_id] -> vec_env
policy=accelerator.unwrap_model(policy),
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
@@ -449,7 +441,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
def main():
register_third_party_plugins()
train()

Some files were not shown because too many files have changed in this diff Show More