Compare commits

..

42 Commits

Author SHA1 Message Date
Pepijn
8765b57c0a fix precommit and fix tests 2025-10-16 14:52:32 +02:00
Pepijn
ccdf06f0f1 Merge branch 'main' into feat/accelerate-melt-gpus 2025-10-16 14:46:52 +02:00
Pepijn
8a32764c38 use accelerate to determin logging 2025-10-15 13:04:34 +02:00
Pepijn
c775d8d07a add min memory to cpu tests 2025-10-15 13:00:30 +02:00
Pepijn
300d614ae5 encorperate feedback pr 2025-10-15 12:57:41 +02:00
Pepijn
1d86482101 Merge branch 'main' into feat/accelerate-melt-gpus 2025-10-14 20:48:46 +02:00
Pepijn
5730b0edfa Merge branch 'main' into feat/accelerate-melt-gpus 2025-10-14 17:43:16 +02:00
Pepijn
ebf64bd80e fix formatting 2025-10-14 17:37:47 +02:00
Pepijn
f8a185f753 cleanup logging 2025-10-14 17:05:47 +02:00
Pepijn
a66b50d372 scale lr decay if we reduce steps 2025-10-14 15:59:46 +02:00
Pepijn
9950bfd66f fix bug 2025-10-14 15:22:59 +02:00
Pepijn
4170d1b6f1 cleanup 2025-10-14 14:48:18 +02:00
Pepijn
d3f1ece680 cleanup update method 2025-10-14 14:33:58 +02:00
Pepijn
4061b3f5b3 always use accelerate 2025-10-14 14:24:55 +02:00
Pepijn
d2687e9486 add some debugging 2025-10-14 14:13:50 +02:00
Pepijn
bb824f2275 change accelerate detection 2025-10-14 14:08:20 +02:00
Pepijn
da78460b65 fix OOM bug 2025-10-14 14:01:51 +02:00
Pepijn
a0d0b00e04 small improvements in train 2025-10-14 13:53:38 +02:00
Pepijn
cabc47c5ad simplify accelerate main process detection 2025-10-14 13:38:36 +02:00
Pepijn
50ff388bf6 update docs, and small improvements in train 2025-10-14 13:31:52 +02:00
Pepijn
a86cea5708 fix path optimizer state 2025-10-14 11:26:57 +02:00
Pepijn
6486982ab4 small fixes 2025-10-14 10:46:19 +02:00
Pepijn
2bc154e706 Merge branch 'feat/accelerate-melt-gpus' of https://github.com/huggingface/lerobot into feat/accelerate-melt-gpus 2025-10-14 10:25:20 +02:00
Pepijn
0d79130729 pre download dataset in tests 2025-10-14 10:24:46 +02:00
Pepijn
ed267d4cf1 Merge branch 'main' into feat/accelerate-melt-gpus 2025-10-14 01:13:23 -07:00
Pepijn
252bca9354 dont push to hub in multi gpu tests 2025-10-14 10:06:32 +02:00
Pepijn
43bef1d91c fix test 2025-10-13 17:59:59 +02:00
Pepijn
4c40be57d8 change runner 2025-10-13 17:28:06 +02:00
Pepijn
c711a628b9 add tests 2025-10-13 16:25:46 +02:00
Pepijn
a74affad7c try with local rank 2025-10-10 15:52:49 +02:00
Pepijn
63fcebd5a7 main logging 2025-10-10 15:01:27 +02:00
Pepijn
8ebda30d1a Merge branch 'feat/accelerate-melt-gpus' of https://github.com/huggingface/lerobot into feat/accelerate-melt-gpus 2025-10-10 14:06:01 +02:00
Pepijn
b65172f819 only log in main process 2025-10-10 14:05:53 +02:00
Pepijn
deaeb4281c Merge branch 'main' into feat/accelerate-melt-gpus 2025-10-10 13:35:58 +02:00
Pepijn
771b03c30d fix pre commit 2025-10-10 13:35:26 +02:00
Pepijn
d709acfc55 Merge branch 'feat/accelerate-melt-gpus' of https://github.com/huggingface/lerobot into feat/accelerate-melt-gpus 2025-10-10 11:25:55 +02:00
Pepijn
95b6035baa Place logging under accelerate and update docs 2025-10-10 11:25:53 +02:00
Pepijn
629bbca96b Merge branch 'main' into feat/accelerate-melt-gpus 2025-10-10 10:09:53 +02:00
Pepijn
52751e8e6d Merge branch 'main' into feat/accelerate-melt-gpus 2025-10-09 15:19:48 +02:00
Pepijn
4b7cd7211a add docs and only push model once 2025-10-09 15:11:47 +02:00
AdilZouitine
dbce707db5 Initialize logging in training script for both main and non-main processes
- Added `init_logging` calls to ensure proper logging setup when using the accelerator and in standard training mode.
- This change enhances the clarity and consistency of logging during training sessions.
2025-10-03 16:43:05 +02:00
AdilZouitine
f30da2dec1 Enhance training and logging functionality with accelerator support
- Added support for multi-GPU training by introducing an `accelerator` parameter in training functions.
- Updated `update_policy` to handle gradient updates based on the presence of an accelerator.
- Modified logging to prevent duplicate messages in non-main processes.
- Enhanced `set_seed` and `get_safe_torch_device` functions to accommodate accelerator usage.
- Updated `MetricsTracker` to account for the number of processes when calculating metrics.
- Introduced a new feature in `pyproject.toml` for the `accelerate` library dependency.
2025-10-02 18:11:27 +02:00
165 changed files with 1909 additions and 21355 deletions

View File

@@ -78,7 +78,7 @@ jobs:
python-version: ${{ env.PYTHON_VERSION }}
- name: Install lerobot with all extras
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
run: uv sync --all-extras
- name: Run pytest (all extras)
run: uv run pytest tests -vv --maxfail=10

View File

@@ -189,6 +189,5 @@ jobs:
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
- name: Run multi-GPU training tests
# TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
run: pytest tests/training/test_multi_gpu.py -vv --maxfail=3
timeout-minutes: 10

View File

@@ -82,14 +82,6 @@ jobs:
exit 1
fi
- name: Remove Tags with Git dependencies
# TODO(Steven): Temporary patch to remove pi from PyPi 0.4.0 release due to its reliance on git dependencies.
run: |
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
grep -E '@ git\+https|lerobot\[pi\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
sed -E -i '/@ git\+https|lerobot\[pi\]/d' pyproject.toml
echo "::info:: Git dependencies removed. Proceeding with build."
- name: Install build dependencies
run: python -m pip install build
@@ -111,7 +103,7 @@ jobs:
- name: Publish to TestPyPI for pre-releases
# True for tags like 'v0.2.0-rc1'
if: startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '-')
uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
with:
repository-url: https://test.pypi.org/legacy/
verbose: true
@@ -119,7 +111,7 @@ jobs:
- name: Publish to PyPI
if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-')
uses: pypa/gh-action-pypi-publish@v1.13.0 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
with:
verbose: true
print-hash: true
@@ -146,7 +138,7 @@ jobs:
- name: Setup uv and Python
uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses]
with:
enable-cache: true # zizmor: ignore[cache-poisoning]
enable-cache: true
version: ${{ env.UV_VERSION }}
python-version: ${{ env.PYTHON_VERSION }}
- name: Create uv virtual environment

View File

@@ -27,17 +27,15 @@ env:
This issue was closed because it has been stalled for 14 days with no activity.
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
CLOSE_PR_MESSAGE: >
This PR was closed because it has been stalled for 21 days with no activity.
This PR was closed because it has been stalled for 14 days with no activity.
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
WARN_ISSUE_MESSAGE: >
This issue has been automatically marked as stale because it has not had
recent activity (6 months). It will be closed if no further activity occurs.
Any change, comment or update to this issue will reset this count.
Thank you for your contributions.
WARN_PR_MESSAGE: >
This PR has been automatically marked as stale because it has not had
recent activity (1 year). It will be closed if no further activity occurs.
Any change, comment or update to this PR will reset this count.
recent activity (6 months). It will be closed if no further activity occurs.
Thank you for your contributions.
jobs:
@@ -58,10 +56,10 @@ jobs:
stale-pr-label: stale
exempt-issue-labels: never-stale
exempt-pr-labels: never-stale
days-before-issue-stale: 180
days-before-issue-stale: 180 # TODO(Steven): Will modify this to 90 after initial cleanup
days-before-issue-close: 14
days-before-pr-stale: 365
days-before-pr-close: 21
days-before-pr-stale: 180
days-before-pr-close: 14
delete-branch: true
close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}

View File

@@ -70,7 +70,7 @@ jobs:
echo "Dependencies unbound:" && cat pyproject.toml
- name: Install lerobot with all extras
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
run: uv sync --all-extras
- name: Run pytest (all extras)
run: uv run pytest tests -vv

View File

@@ -26,7 +26,7 @@ repos:
##### General Code Quality & Formatting #####
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
rev: v5.0.0
hooks:
- id: check-added-large-files
args: ['--maxkb=1024']
@@ -39,20 +39,20 @@ repos:
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.1
rev: v0.12.4
hooks:
- id: ruff-format
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/adhtruong/mirrors-typos
rev: v1.38.1
rev: v1.34.0
hooks:
- id: typos
args: [--force-exclude]
- repo: https://github.com/asottile/pyupgrade
rev: v3.21.0
rev: v3.20.0
hooks:
- id: pyupgrade
args: [--py310-plus]
@@ -68,12 +68,12 @@ repos:
##### Security #####
- repo: https://github.com/gitleaks/gitleaks
rev: v8.28.0
rev: v8.27.2
hooks:
- id: gitleaks
- repo: https://github.com/woodruffw/zizmor-pre-commit
rev: v1.15.2
rev: v1.11.0
hooks:
- id: zizmor
@@ -87,7 +87,7 @@ repos:
# TODO(Steven): Uncomment when ready to use
##### Static Analysis & Typing #####
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.18.2
rev: v1.16.0
hooks:
- id: mypy
args: [--config-file=pyproject.toml]

View File

@@ -137,7 +137,7 @@ Follow these steps to start contributing:
4. for development, we advise to use a tool like `poetry` or `uv` instead of just `pip` to easily track our dependencies.
Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already.
Set up a development environment with conda:
Set up a development environment with conda or miniconda:
```bash
conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev

View File

@@ -104,14 +104,14 @@ LeRobot works with Python 3.10+ and PyTorch 2.2+.
### Environment Setup
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniforge`](https://conda-forge.org/download/):
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
```bash
conda create -y -n lerobot python=3.10
conda activate lerobot
```
When using `conda`, install `ffmpeg` in your environment:
When using `miniconda`, install `ffmpeg` in your environment:
```bash
conda install ffmpeg -c conda-forge
@@ -185,11 +185,6 @@ _Replace `[...]` with your desired features._
For a full list of optional dependencies, see:
https://pypi.org/project/lerobot/
> [!NOTE]
> For lerobot 0.4.0, if you want to install pi tags, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
### Weights & Biases
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
@@ -212,13 +207,13 @@ lerobot-dataset-viz \
--episode-index 0
```
or from a dataset in a local folder with the `root` option and the `--mode local` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
or from a dataset in a local folder with the `root` option and the `--local-files-only` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
```bash
lerobot-dataset-viz \
--repo-id lerobot/pusht \
--root ./my_local_data_dir \
--mode local \
--local-files-only 1 \
--episode-index 0
```
@@ -315,7 +310,7 @@ To upload these to the hub, run the following:
huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
```
See [lerobot_eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_eval.py) for an example of how other people may use your policy.
See [eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/eval.py) for an example of how other people may use your policy.
### Acknowledgment
@@ -342,3 +337,7 @@ If you want, you can cite this work with:
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=huggingface/lerobot&type=Timeline)](https://star-history.com/#huggingface/lerobot&Timeline)
```
```

View File

@@ -15,6 +15,8 @@
title: Train a Robot with RL
- local: hilserl_sim
title: Train RL in Simulation
- local: async
title: Use Async Inference
- local: multi_gpu_training
title: Multi GPU training
title: "Tutorials"
@@ -35,18 +37,8 @@
title: π₀ (Pi0)
- local: pi05
title: π₀.₅ (Pi05)
- local: groot
title: NVIDIA GR00T N1.5
title: "Policies"
- sections:
- local: async
title: Use Async Inference
- local: rtc
title: Real-Time Chunking (RTC)
title: "Inference"
- sections:
- local: envhub
title: Environments from the Hub
- local: il_sim
title: Imitation Learning in Sim
- local: libero
@@ -63,8 +55,6 @@
title: Implement your own processor
- local: processors_robots_teleop
title: Processors for Robots and Teleoperators
- local: env_processor
title: Environment Processors
title: "Robot Processors"
- sections:
- local: so101

View File

@@ -1,418 +0,0 @@
# Environment Processors
Environment processors are a critical layer in LeRobot's data processing architecture that handle **environment-specific** transformations, separate from policy-specific processing. This separation of concerns enables cleaner code, better modularity, and easier experimentation with different environments and policies.
## Why Environment Processors?
When working with different robot environments (LIBERO, MetaWorld, Aloha, etc.), each environment often has unique data formats, coordinate systems, and conventions that need standardization **before** policy processing. Without environment processors, these transformations would be:
1. **Hardcoded in environment code** - Making it difficult to experiment with different state representations
2. **Duplicated across policies** - Each policy would need to handle environment-specific quirks
3. **Mixed with policy logic** - Violating separation of concerns and making debugging harder
Environment processors solve this by providing a **dedicated processing layer** between raw environment observations and policy inputs.
## The Processing Pipeline
Here's how data flows through the complete processing pipeline during evaluation:
```python
# In lerobot_eval.py rollout() function:
# 1. Raw environment observation (numpy arrays, various formats)
raw_observation = env.step(action)
# 2. Convert numpy to torch, normalize images [0,1]
observation = preprocess_observation(raw_observation)
# 3. Add task metadata (for multi-task environments)
observation = add_envs_task(env, observation)
# 4. ENVIRONMENT-SPECIFIC preprocessing (NEW!)
# - Flatten robot states
# - Rotate images to match dataset conventions
# - Handle environment-specific coordinate systems
observation = env_preprocessor(observation)
# 5. POLICY-SPECIFIC preprocessing
# - Normalize with dataset statistics
# - Add batch dimensions
# - Move to GPU
# - Tokenize language instructions
observation = preprocessor(observation)
# 6. Policy inference
action = policy.select_action(observation)
# 7. POLICY-SPECIFIC postprocessing
# - Unnormalize actions
# - Remove batch dimensions
action = postprocessor(action)
# 8. ENVIRONMENT-SPECIFIC postprocessing (NEW!)
# - Convert action formats if needed
# - Apply environment-specific constraints
action_transition = {"action": action}
action_transition = env_postprocessor(action_transition)
action = action_transition["action"]
# 9. Execute in environment
env.step(action)
```
## The Benefits
### 1. **Separation of Concerns**
Environment processors handle transformations specific to the **environment's data format**, while policy processors handle transformations specific to the **model's requirements**.
```python
# ❌ Before: Mixed concerns
class LiberoVLAPolicy:
def preprocess(self, obs):
# Environment-specific: Flatten robot state (shouldn't be in policy!)
state = self._flatten_robot_state(obs["robot_state"])
# Policy-specific: Normalize with dataset stats
state = self.normalizer(state)
return state
# ✅ After: Clear separation
# Environment processor: Handles LIBERO's nested robot state
env_preprocessor = LiberoProcessorStep() # Flattens robot_state
# Policy processor: Handles model requirements
policy_preprocessor = NormalizerProcessorStep(stats=dataset_stats)
```
### 2. **Flexibility and Reusability**
The same policy can work with different environment processors, and the same environment processor can work with different policies:
```python
# Use SmolVLA policy with LIBERO environment
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
smolvla_preprocessor, smolvla_postprocessor = make_pre_post_processors(smolvla_cfg)
# Or use ACT policy with the same LIBERO environment
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
act_preprocessor, act_postprocessor = make_pre_post_processors(act_cfg)
```
### 3. **Easier Experimentation**
Want to try different state representations for LIBERO? Just create a new processor:
```python
# Original: 8D state (pos + quat→axisangle + gripper)
@ProcessorStepRegistry.register("libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
def _process_observation(self, obs):
eef_pos = robot_state["eef"]["pos"] # 3D
eef_axisangle = quat2axisangle(quat) # 3D
gripper = robot_state["gripper"]["qpos"] # 2D
state = torch.cat([eef_pos, eef_axisangle, gripper], dim=-1) # 8D
return state
# Experiment: Add velocity for better control
@ProcessorStepRegistry.register("libero_velocity_processor")
class LiberoVelocityProcessorStep(ObservationProcessorStep):
def _process_observation(self, obs):
# Include velocities for 14D state
eef_pos = robot_state["eef"]["pos"] # 3D
eef_axisangle = quat2axisangle(quat) # 3D
eef_vel = robot_state["eef"]["vel"] # 3D (NEW)
gripper_pos = robot_state["gripper"]["qpos"] # 2D
gripper_vel = robot_state["gripper"]["qvel"] # 3D (NEW)
state = torch.cat([eef_pos, eef_axisangle, eef_vel,
gripper_pos, gripper_vel], dim=-1) # 14D
return state
```
### 4. **Cleaner Environment Code**
Environments expose **all available data** without needing to know what downstream models will use:
```python
# LIBERO environment exposes full robot state
observation = {
"pixels": {"image": img, "image2": img2},
"robot_state": {
"eef": {"pos": ..., "quat": ..., "vel": ..., "mat": ..., "axisangle": ...},
"gripper": {"qpos": ..., "qvel": ...},
"joints": {"pos": ..., "vel": ...}
}
}
# Environment processor decides what to use
# Policy processor handles model-specific transformations
```
## Using Environment Processors
### Factory Function
The `make_env_pre_post_processors` function follows the same pattern as `make_pre_post_processors` for policies:
```python
from lerobot.envs.factory import make_env_pre_post_processors
from lerobot.envs.configs import LiberoEnv, PushtEnv
# For LIBERO: Returns LiberoProcessorStep in preprocessor
libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg)
# For other environments: Returns identity processors (no-op)
pusht_cfg = PushtEnv()
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg)
```
### Implementation in `envs/factory.py`
```python
def make_env_pre_post_processors(
env_cfg: EnvConfig,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
"""
Create preprocessor and postprocessor pipelines for environment observations.
Args:
env_cfg: The configuration of the environment.
Returns:
A tuple containing:
- preprocessor: Pipeline that processes environment observations
- postprocessor: Pipeline that processes environment outputs
"""
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
else:
# For all other environments, return an identity preprocessor
preprocessor = PolicyProcessorPipeline(steps=[])
# Postprocessor is currently identity for all environments
# Future: Could add environment-specific action transformations
postprocessor = PolicyProcessorPipeline(steps=[])
return preprocessor, postprocessor
```
### Integration in Evaluation
In `lerobot_eval.py`, the environment processors are created once and used throughout:
```python
def eval_main(cfg: EvalPipelineConfig):
# Create environment
envs = make_env(cfg.env, n_envs=cfg.eval.batch_size)
# Create policy
policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env)
# Create policy processors
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
)
# Create environment processors (NEW!)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
# Run evaluation with both processor types
eval_policy_all(
envs=envs,
policy=policy,
env_preprocessor=env_preprocessor, # Environment-specific
env_postprocessor=env_postprocessor, # Environment-specific
preprocessor=preprocessor, # Policy-specific
postprocessor=postprocessor, # Policy-specific
n_episodes=cfg.eval.n_episodes,
)
```
## Example: LIBERO Environment Processor
The `LiberoProcessorStep` demonstrates a real-world environment processor:
```python
from lerobot.processor.pipeline import ObservationProcessorStep
@dataclass
@ProcessorStepRegistry.register(name="libero_processor")
class LiberoProcessorStep(ObservationProcessorStep):
"""
Processes LIBERO observations into the LeRobot format.
**State Processing:**
- Extracts end-effector position (3D)
- Converts quaternion to axis-angle representation (3D)
- Extracts gripper joint positions (2D)
- Concatenates into 8D state vector
**Image Processing:**
- Rotates images 180° to match HuggingFaceVLA/libero convention
"""
def _process_observation(self, observation):
processed_obs = observation.copy()
# Process images: Flip 180° for camera convention
for key in list(processed_obs.keys()):
if key.startswith("observation.images."):
img = processed_obs[key]
img = torch.flip(img, dims=[2, 3]) # Flip H and W
processed_obs[key] = img
# Process robot_state: Flatten to 8D vector
if "observation.robot_state" in processed_obs:
robot_state = processed_obs.pop("observation.robot_state")
eef_pos = robot_state["eef"]["pos"] # (B, 3)
eef_quat = robot_state["eef"]["quat"] # (B, 4)
gripper_qpos = robot_state["gripper"]["qpos"] # (B, 2)
# Convert quaternion to axis-angle
eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
# Concatenate into single state vector
state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
state = state.float()
processed_obs["observation.state"] = state
return processed_obs
```
### Why These Transformations?
1. **Image Rotation**: The HuggingFaceVLA/libero dataset has images rotated 180° from the raw LIBERO simulator. The processor handles this convention mismatch so policies trained on the dataset work seamlessly.
2. **State Flattening**: The raw LIBERO environment exposes nested dictionaries with all available state information (position, quaternion, velocity, matrix representation, etc.). The processor:
- Selects the relevant components (pos, quat, gripper)
- Converts quaternion to axis-angle (more suitable for learning)
- Flattens to a single 8D vector that policies expect
3. **Flexibility**: The environment still exposes **all** raw data. If you want to try different state representations (e.g., including velocities, using matrix representation instead of axis-angle), you can create a new processor without modifying the environment code.
## Adding Environment Processors for New Environments
To add environment processors for a new environment:
### 1. Create the Processor Step
```python
# In src/lerobot/processor/env_processor.py
@dataclass
@ProcessorStepRegistry.register(name="myenv_processor")
class MyEnvProcessorStep(ObservationProcessorStep):
"""Process observations from MyEnv."""
def _process_observation(self, observation):
processed = observation.copy()
# Your environment-specific transformations
if "myenv.specific.state" in processed:
state = processed.pop("myenv.specific.state")
# Transform to standard format
processed["observation.state"] = self._transform_state(state)
return processed
```
### 2. Update the Factory
```python
# In src/lerobot/envs/factory.py
def make_env_pre_post_processors(env_cfg: EnvConfig):
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
elif isinstance(env_cfg, MyEnvConfig) or "myenv" in env_cfg.type:
preprocessor = PolicyProcessorPipeline(steps=[MyEnvProcessorStep()])
else:
preprocessor = PolicyProcessorPipeline(steps=[])
postprocessor = PolicyProcessorPipeline(steps=[])
return preprocessor, postprocessor
```
### 3. Use in Evaluation
No changes needed! The evaluation script automatically uses the appropriate processor:
```bash
lerobot-eval \
--policy.path=lerobot/my_policy \
--env.type=myenv \ # Automatically uses MyEnvProcessorStep
--eval.n_episodes=10
```
## Future: Environment Postprocessors
Currently, postprocessors are identity (no-op) for all environments. Future use cases include:
### Action Space Transformations
```python
@dataclass
class MyEnvActionPostprocessor(ProcessorStep):
"""Convert policy actions to environment-specific format."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition["action"]
# Example: Convert from Cartesian to joint space
if self.action_space == "joint":
action = self.ik_solver(action)
# Example: Apply environment-specific safety limits
action = torch.clamp(action, self.min_action, self.max_action)
transition["action"] = action
return transition
```
### Coordinate System Conversions
```python
@dataclass
class CoordinateTransformPostprocessor(ProcessorStep):
"""Transform actions between coordinate systems."""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition["action"]
# Example: Policy outputs in world frame, env expects base frame
action = self.world_to_base_transform(action)
transition["action"] = action
return transition
```
## Best Practices
1. **Keep environment processors simple**: They should only handle environment-specific data format issues, not complex learning-related transformations.
2. **Use policy processors for model requirements**: Normalization, batching, device placement, and tokenization belong in policy processors.
3. **Expose all data from environments**: Let processors decide what to use rather than hardcoding choices in the environment.
4. **Document conventions**: Clearly document any coordinate system conventions, camera orientations, or data formats that your processor handles.
5. **Test independently**: Environment processors should be testable without loading full policies or environments.
## Summary
Environment processors provide a **clean separation** between environment-specific data transformations and policy-specific model requirements. This architecture:
- ✅ Enables easy experimentation with different state representations
- ✅ Allows policies to work seamlessly across different environments
- ✅ Keeps environment code focused on simulation/hardware interface
- ✅ Makes processor pipelines more maintainable and debuggable
- ✅ Follows the single responsibility principle
The key insight: **Environments define data formats, processors standardize them, policies consume standardized data.** Each layer has a clear, focused responsibility.

View File

@@ -1,424 +0,0 @@
# Loading Environments from the Hub
The **EnvHub** feature allows you to load simulation environments directly from the Hugging Face Hub with a single line of code. This unlocks a powerful new model for collaboration: instead of environments being locked away inside monolithic libraries, anyone can publish custom environments and share them with the community.
## Overview
With EnvHub, you can:
- Load environments from the Hub instantly
- Share your custom simulation tasks with the community
- Version control your environments using Git
- Distribute complex physics simulations without packaging hassles
## Quick Start
Loading an environment from the Hub is as simple as:
```python
from lerobot.envs.factory import make_env
# Load a hub environment (requires explicit consent to run remote code)
env = make_env("lerobot/cartpole-env", trust_remote_code=True)
```
<Tip warning={true}>
**Security Notice**: Loading environments from the Hub executes Python code
from third-party repositories. Only use `trust_remote_code=True` with
repositories you trust. We strongly recommend pinning to a specific commit
hash for reproducibility and security.
</Tip>
## What is EnvHub?
EnvHub is a framework that allows researchers and developers to:
1. **Publish environments** to the Hugging Face Hub as Git repositories
2. **Load environments** dynamically without installing them as packages
3. **Version and track** environment changes using Git semantics
4. **Discover** new simulation tasks shared by the community
This design means you can go from discovering an interesting environment on the Hub to running experiments in seconds, without worrying about dependency conflicts or complex installation procedures.
## Repository Structure
To make your environment loadable from the Hub, your repository must contain at minimum:
### Required Files
**`env.py`** (or custom Python file)
- Must expose a `make_env(n_envs: int, use_async_envs: bool)` function
- This function should return one of:
- A `gym.vector.VectorEnv` (most common)
- A single `gym.Env` (will be automatically wrapped)
- A dict mapping `{suite_name: {task_id: VectorEnv}}` (for multi-task benchmarks)
### Optional Files
**`requirements.txt`**
- List any additional dependencies your environment needs
- Users will need to install these manually before loading your environment
**`README.md`**
- Document your environment: what task it implements, observation/action spaces, rewards, etc.
- Include usage examples and any special setup instructions
**`.gitignore`**
- Exclude unnecessary files from your repository
### Example Repository Structure
```
my-environment-repo/
├── env.py # Main environment definition (required)
├── requirements.txt # Dependencies (optional)
├── README.md # Documentation (recommended)
├── assets/ # Images, videos, etc. (optional)
│ └── demo.gif
└── configs/ # Config files if needed (optional)
└── task_config.yaml
```
## Creating Your Environment Repository
### Step 1: Define Your Environment
Create an `env.py` file with a `make_env` function:
```python
# env.py
import gymnasium as gym
def make_env(n_envs: int = 1, use_async_envs: bool = False):
"""
Create vectorized environments for your custom task.
Args:
n_envs: Number of parallel environments
use_async_envs: Whether to use AsyncVectorEnv or SyncVectorEnv
Returns:
gym.vector.VectorEnv or dict mapping suite names to vectorized envs
"""
def _make_single_env():
# Create your custom environment
return gym.make("CartPole-v1")
# Choose vector environment type
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
# Create vectorized environment
vec_env = env_cls([_make_single_env for _ in range(n_envs)])
return vec_env
```
### Step 2: Test Locally
Before uploading, test your environment locally:
```python
from lerobot.envs.utils import _load_module_from_path, _call_make_env, _normalize_hub_result
# Load your module
module = _load_module_from_path("./env.py")
# Test the make_env function
result = _call_make_env(module, n_envs=2, use_async_envs=False)
normalized = _normalize_hub_result(result)
# Verify it works
suite_name = next(iter(normalized))
env = normalized[suite_name][0]
obs, info = env.reset()
print(f"Observation shape: {obs.shape if hasattr(obs, 'shape') else type(obs)}")
env.close()
```
### Step 3: Upload to the Hub
Upload your repository to Hugging Face:
```bash
# Install huggingface_hub if needed
pip install huggingface_hub
# Login to Hugging Face
huggingface-cli login
# Create a new repository
huggingface-cli repo create my-custom-env --type space --org my-org
# Initialize git and push
git init
git add .
git commit -m "Initial environment implementation"
git remote add origin https://huggingface.co/my-org/my-custom-env
git push -u origin main
```
Alternatively, use the `huggingface_hub` Python API:
```python
from huggingface_hub import HfApi
api = HfApi()
# Create repository
api.create_repo("my-custom-env", repo_type="space")
# Upload files
api.upload_folder(
folder_path="./my-env-folder",
repo_id="username/my-custom-env",
repo_type="space",
)
```
## Loading Environments from the Hub
### Basic Usage
```python
from lerobot.envs.factory import make_env
# Load from the hub
envs_dict = make_env(
"username/my-custom-env",
n_envs=4,
trust_remote_code=True
)
# Access the environment
suite_name = next(iter(envs_dict))
env = envs_dict[suite_name][0]
# Use it like any gym environment
obs, info = env.reset()
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
```
### Advanced: Pinning to Specific Versions
For reproducibility and security, pin to a specific Git revision:
```python
# Pin to a specific branch
env = make_env("username/my-env@main", trust_remote_code=True)
# Pin to a specific commit (recommended for papers/experiments)
env = make_env("username/my-env@abc123def456", trust_remote_code=True)
# Pin to a tag
env = make_env("username/my-env@v1.0.0", trust_remote_code=True)
```
### Custom File Paths
If your environment definition is not in `env.py`:
```python
# Load from a custom file
env = make_env("username/my-env:custom_env.py", trust_remote_code=True)
# Combine with version pinning
env = make_env("username/my-env@v1.0:envs/task_a.py", trust_remote_code=True)
```
### Async Environments
For better performance with multiple environments:
```python
envs_dict = make_env(
"username/my-env",
n_envs=8,
use_async_envs=True, # Use AsyncVectorEnv for parallel execution
trust_remote_code=True
)
```
## URL Format Reference
The hub URL format supports several patterns:
| Pattern | Description | Example |
| -------------------- | ------------------------------ | -------------------------------------- |
| `user/repo` | Load `env.py` from main branch | `make_env("lerobot/pusht-env")` |
| `user/repo@revision` | Load from specific revision | `make_env("lerobot/pusht-env@main")` |
| `user/repo:path` | Load custom file | `make_env("lerobot/envs:pusht.py")` |
| `user/repo@rev:path` | Revision + custom file | `make_env("lerobot/envs@v1:pusht.py")` |
## Multi-Task Environments
For benchmarks with multiple tasks (like LIBERO), return a nested dictionary:
```python
def make_env(n_envs: int = 1, use_async_envs: bool = False):
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
# Return dict: {suite_name: {task_id: VectorEnv}}
return {
"suite_1": {
0: env_cls([lambda: gym.make("Task1-v0") for _ in range(n_envs)]),
1: env_cls([lambda: gym.make("Task2-v0") for _ in range(n_envs)]),
},
"suite_2": {
0: env_cls([lambda: gym.make("Task3-v0") for _ in range(n_envs)]),
}
}
```
## Security Considerations
<Tip warning={true}>
**Important**: The `trust_remote_code=True` flag is required to execute
environment code from the Hub. This is by design for security.
</Tip>
When loading environments from the Hub:
1. **Review the code first**: Visit the repository and inspect `env.py` before loading
2. **Pin to commits**: Use specific commit hashes for reproducibility
3. **Check dependencies**: Review `requirements.txt` for suspicious packages
4. **Use trusted sources**: Prefer official organizations or well-known researchers
5. **Sandbox if needed**: Run untrusted code in isolated environments (containers, VMs)
Example of safe usage:
```python
# ❌ BAD: Loading without inspection
env = make_env("random-user/untrusted-env", trust_remote_code=True)
# ✅ GOOD: Review code, then pin to specific commit
# 1. Visit https://huggingface.co/trusted-org/verified-env
# 2. Review the env.py file
# 3. Copy the commit hash
env = make_env("trusted-org/verified-env@a1b2c3d4", trust_remote_code=True)
```
## Example: CartPole from the Hub
Here's a complete example using the reference CartPole environment:
```python
from lerobot.envs.factory import make_env
import numpy as np
# Load the environment
envs_dict = make_env("lerobot/cartpole-env", n_envs=4, trust_remote_code=True)
# Get the vectorized environment
suite_name = next(iter(envs_dict))
env = envs_dict[suite_name][0]
# Run a simple episode
obs, info = env.reset()
done = np.zeros(env.num_envs, dtype=bool)
total_reward = np.zeros(env.num_envs)
while not done.all():
# Random policy
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
total_reward += reward
done = terminated | truncated
print(f"Average reward: {total_reward.mean():.2f}")
env.close()
```
## Benefits of EnvHub
### For Environment Authors
- **Easy distribution**: No PyPI packaging required
- **Version control**: Use Git for environment versioning
- **Rapid iteration**: Push updates instantly
- **Documentation**: Hub README renders beautifully
- **Community**: Reach LeRobot users directly
### For Researchers
- **Quick experiments**: Load any environment in one line
- **Reproducibility**: Pin to specific commits
- **Discovery**: Browse environments on the Hub
- **No conflicts**: No need to install conflicting packages
### For the Community
- **Growing ecosystem**: More diverse simulation tasks
- **Standardization**: Common `make_env` API
- **Collaboration**: Fork and improve existing environments
- **Accessibility**: Lower barrier to sharing research
## Troubleshooting
### "Refusing to execute remote code"
You must explicitly pass `trust_remote_code=True`:
```python
env = make_env("user/repo", trust_remote_code=True)
```
### "Module X not found"
The hub environment has dependencies you need to install:
```bash
# Check the repo's requirements.txt and install dependencies
pip install gymnasium numpy
```
### "make_env not found in module"
Your `env.py` must expose a `make_env` function:
```python
def make_env(n_envs: int, use_async_envs: bool):
# Your implementation
pass
```
### Environment returns wrong type
The `make_env` function must return:
- A `gym.vector.VectorEnv`, or
- A single `gym.Env`, or
- A dict `{suite_name: {task_id: VectorEnv}}`
## Best Practices
1. **Document your environment**: Include observation/action space descriptions, reward structure, and termination conditions in your README
2. **Add requirements.txt**: List all dependencies with versions
3. **Test thoroughly**: Verify your environment works locally before pushing
4. **Use semantic versioning**: Tag releases with version numbers
5. **Add examples**: Include usage examples in your README
6. **Keep it simple**: Minimize dependencies when possible
7. **License your work**: Add a LICENSE file to clarify usage terms
## Future Directions
The EnvHub ecosystem enables exciting possibilities:
- **GPU-accelerated physics**: Share Isaac Gym or Brax environments
- **Photorealistic rendering**: Distribute environments with advanced graphics
- **Multi-agent scenarios**: Complex interaction tasks
- **Real-world simulators**: Digital twins of physical setups
- **Procedural generation**: Infinite task variations
- **Domain randomization**: Pre-configured DR pipelines
As more researchers and developers contribute, the diversity and quality of available environments will grow, benefiting the entire robotics learning community.
## See Also
- [Hugging Face Hub Documentation](https://huggingface.co/docs/hub/en/index)
- [Gymnasium Documentation](https://gymnasium.farama.org/index.html)
- [Example Hub Environment](https://huggingface.co/lerobot/cartpole-env)

View File

@@ -1,125 +0,0 @@
# GR00T N1.5 Policy
GR00T N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments.
This document outlines the specifics of its integration and usage within the LeRobot framework.
## Model Overview
NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots.
Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.
Its strong performance comes from being trained on an expansive and diverse humanoid dataset, which includes:
- Real captured data from robots.
- Synthetic data generated using NVIDIA Isaac GR00T Blueprint.
- Internet-scale video data.
This approach allows the model to be highly adaptable through post-training for specific embodiments, tasks, and environments.
## Installation Requirements
As of today, GR00T N1.5 requires flash attention for it's internal working.
We are working on making this optional, but in the meantime that means that we require an extra installation step and it can only be used in CUDA enabled devices.
1. Following the Environment Setup of our [Installation Guide](./installation). **Attention** don't install `lerobot` in this step.
2. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) by running:
```bash
# Check https://pytorch.org/get-started/locally/ for your system
pip install "torch>=2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX
pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
```
3. Install LeRobot by running:
```bash
pip install lerobot[groot]
```
## Usage
To use GR00T in your LeRobot configuration, specify the policy type as:
```python
policy.type=groot
```
## Training
### Training Command Example
Here's a complete training command for finetuning the base GR00T model on your own dataset:
```bash
# Using a multi-GPU setup
accelerate launch \
--multi_gpu \
--num_processes=$NUM_GPUS \
$(which lerobot-train) \
--output_dir=$OUTPUT_DIR \
--save_checkpoint=true \
--batch_size=$BATCH_SIZE \
--steps=$NUM_STEPS \
--save_freq=$SAVE_FREQ \
--log_freq=$LOG_FREQ \
--policy.push_to_hub=true \
--policy.type=groot \
--policy.repo_id=$REPO_ID \
--policy.tune_diffusion_model=false \
--dataset.repo_id=$DATASET_ID \
--wandb.enable=true \
--wandb.disable_artifact=true \
--job_name=$JOB_NAME
```
## Performance Results
### Libero Benchmark Results
> [!NOTE]
> Follow our instructions for Libero usage: [Libero](./libero)
GR00T has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the Libero dataset and compared the results to the GR00T reference results.
| Benchmark | LeRobot Implementation | GR00T Reference |
| ------------------ | ---------------------- | --------------- |
| **Libero Spatial** | 82.0% | 92.0% |
| **Libero Object** | 99.0% | 92.0% |
| **Libero Long** | 82.0% | 76.0% |
| **Average** | 87.0% | 87.0% |
These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
### Evaluate in your hardware setup
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
```bash
lerobot-record \
--robot.type=bi_so100_follower \
--robot.left_arm_port=/dev/ttyACM1 \
--robot.right_arm_port=/dev/ttyACM0 \
--robot.id=bimanual_follower \
--robot.cameras='{ right: {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30},
left: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30},
top: {"type": "opencv", "index_or_path": 4, "width": 640, "height": 480, "fps": 30},
}' \
--display_data=true \
--dataset.repo_id=<user>/eval_groot-bimanual \
--dataset.num_episodes=10 \
--dataset.single_task="Grab and handover the red cube to the other arm"
--policy.path=<user>/groot-bimanual # your trained model
--dataset.episode_time_s=30
--dataset.reset_time_s=10
```
## License
This model follows the **Apache 2.0 License**, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T).

View File

@@ -165,7 +165,7 @@ huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
Then store your Hugging Face repository name in a variable:
```bash
HF_USER=$(hf auth whoami | head -n 1)
HF_USER=$(huggingface-cli whoami | head -n 1)
echo $HF_USER
```

View File

@@ -1,15 +1,8 @@
# Installation
## Install [`miniforge`](https://conda-forge.org/download/)
```bash
wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash Miniforge3-$(uname)-$(uname -m).sh
```
## Environment Setup
Create a virtual environment with Python 3.10, using conda:
Create a virtual environment with Python 3.10, using [`Miniconda`](https://docs.anaconda.com/miniconda/install/#quick-command-line-install)
```bash
conda create -y -n lerobot python=3.10
@@ -21,7 +14,7 @@ Then activate your conda environment, you have to do this each time you open a s
conda activate lerobot
```
When using `conda`, install `ffmpeg` in your environment:
When using `miniconda`, install `ffmpeg` in your environment:
```bash
conda install ffmpeg -c conda-forge
@@ -81,9 +74,6 @@ _Replace `[...]` with your desired features._
For a full list of optional dependencies, see:
https://pypi.org/project/lerobot/
> [!NOTE]
> For lerobot 0.4.0, if you want to install pi, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`
### Troubleshooting
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.

View File

@@ -208,36 +208,34 @@ LeRobot supports saving and loading calibration data automatically. This is usef
<!-- prettier-ignore-start -->
```python
@property
def is_calibrated(self) -> bool:
return True
def calibrate(self) -> None:
pass
```
<!-- prettier-ignore-end -->
> @property
> def is_calibrated(self) -> bool:
> return True
>
> def calibrate(self) -> None:
> pass
> ```
### `is_calibrated`
This should reflect whether your robot has the required calibration loaded.
<!-- prettier-ignore-start -->
```python
```
<!-- prettier-ignore-end -->python
@property
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
```
<!-- prettier-ignore-end -->
### `calibrate()`
The goal of the calibration is twofold:
- Know the physical range of motion of each motors in order to only send commands within this range.
- Normalize raw motors positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete value dependant on the specific motor used that will not replicate elsewhere.
- Know the physical range of motion of each motors in order to only send commands within this range.
- Normalize raw motors positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete value dependant on the specific motor used that will not replicate elsewhere.
It should implement the logic for calibration (if relevant) and update the `self.calibration` dictionary. If you are using Feetech or Dynamixel motors, our bus interfaces already include methods to help with this.
<!-- prettier-ignore-start -->
```python
def calibrate(self) -> None:

View File

@@ -28,11 +28,6 @@ As described by Physical Intelligence, while AI has achieved remarkable success
pip install -e ".[pi]"
```
> [!NOTE]
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
## Training Data and Capabilities
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:

View File

@@ -36,11 +36,6 @@ This diverse training mixture creates a "curriculum" that enables generalization
pip install -e ".[pi]"
```
> [!NOTE]
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
>
> This will be solved in the next patch release
## Usage
To use π₀.₅ in your LeRobot configuration, specify the policy type as:

View File

@@ -1,27 +0,0 @@
## Research Paper
Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
## Repository
Code: https://github.com/NVIDIA/Isaac-GR00T
## Citation
```bibtex
@inproceedings{gr00tn1_2025,
archivePrefix = {arxiv},
eprint = {2503.14734},
title = {{GR00T} {N1}: An Open Foundation Model for Generalist Humanoid Robots},
author = {NVIDIA and Johan Bjorck andFernando Castañeda, Nikita Cherniadev and Xingye Da and Runyu Ding and Linxi "Jim" Fan and Yu Fang and Dieter Fox and Fengyuan Hu and Spencer Huang and Joel Jang and Zhenyu Jiang and Jan Kautz and Kaushil Kundalia and Lawrence Lao and Zhiqi Li and Zongyu Lin and Kevin Lin and Guilin Liu and Edith Llontop and Loic Magne and Ajay Mandlekar and Avnish Narayan and Soroush Nasiriany and Scott Reed and You Liang Tan and Guanzhi Wang and Zu Wang and Jing Wang and Qi Wang and Jiannan Xiang and Yuqi Xie and Yinzhen Xu and Zhenjia Xu and Seonghyeon Ye and Zhiding Yu and Ao Zhang and Hao Zhang and Yizhou Zhao and Ruijie Zheng and Yuke Zhu},
month = {March},
year = {2025},
booktitle = {ArXiv Preprint},
}
```
## Additional Resources
Blog: https://developer.nvidia.com/isaac/gr00t
Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B

View File

@@ -1,188 +0,0 @@
# Real-Time Chunking (RTC)
Real-Time Chunking (RTC) is an inference-time method that allows large, flow-matching based robotic policies, such as [Pi0](./pi0), [Pi0.5](./pi05), and [SmolVLA](./smolvla), to produce smooth, continuous, and reactive motion despite having high inference latency.
These policies generate chunks of future actions (e.g., 50 steps at a time) instead of single actions.
Because the models are large, producing each chunk takes longer than the time it takes the robot to execute it.
Naively executing chunks leads to problems such as pauses, jerky transitions, or sudden changes in strategy whenever the next chunk arrives late or disagrees with the previously executed actions.
RTC solves this by asynchronously generating the next chunk while the robot continues executing the current one, and by guiding the new chunk so it aligns smoothly with the portion of the previous chunk that has already been executed.
## How RTC Works (simplified)
RTC lets the robot think ahead while its still moving. When the robot is carrying out one chunk of actions, RTC starts creating the next chunk early.
But since the robot has already moved a bit by the time the new chunk is ready, RTC has to make sure the new chunk still lines up smoothly with what the robot is currently doing.
To do this, RTC treats the beginning of the new chunk like an inpainting or “fill-in-the-gaps” problem:
it gently adjusts the first part of the new chunk so it blends naturally with the robots ongoing motion. The result is no pauses, no sudden jumps.
In technical terms, RTC adds a guidance term to the flow-matching denoising process that forces the overlapping timesteps of the new chunk to stay close to the executed portion of the previous chunk, typically using a soft transition mask.
## Quick Start
### Installation
RTC is built into LeRobot. Just install the policy dependencies you need:
```bash
# For Pi0 or Pi0.5
pip install -e ".[pi]"
# For SmolVLA
pip install -e ".[smolvla]"
```
### Using RTC with Pi0
You can find a complete reference implementation in [eval_with_real_robot.py](examples/rtc/eval_with_real_robot.py).
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
```python
from lerobot.policies.pi0 import PI0Policy, PI0Config
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.action_queue import ActionQueue
# Load Pi0 with RTC enabled
policy_cfg = PI0Config()
# Enable RTC
policy_cfg.rtc_config = RTCConfig(
enabled=True,
execution_horizon=10, # How many steps to blend with previous chunk
max_guidance_weight=10.0, # How strongly to enforce consistency
prefix_attention_schedule=RTCAttentionSchedule.EXP, # Exponential blend
)
# Load the policy
policy = PI0Policy.from_pretrained("lerobot/pi0_base", policy_cfg=policy_cfg, device="cuda")
# Now use predict_action_chunk with RTC parameters
inference_delay = 4 # How many steps of inference latency, this values should be calculated based on the inference latency of the policy
# Initialize the action queue
action_queue = ActionQueue(policy_cfg.rtc_config)
# Start in a separate thread with the following function
def get_actions():
while True:
if should_get_actions:
prev_actions = action_queue.get_left_over()
obs = get_robot_observations(robot)
# Generate actions WITH RTC
actions = policy.predict_action_chunk(
obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
action_queue.merge(
actions, actions, inference_delay
)
for step in range(num_steps):
action = action_queue.get()
# Execute the first N actions
execute_actions(action)
```
## Key Parameters
`RTCConfig` has the following parameters to tune:
**`execution_horizon`**: How many timesteps from the previous chunk to maintain consistency with. Higher values mean smoother transitions but potentially less reactivity.
Typical values: 8-12 steps
```python
RTCConfig(execution_horizon=10)
```
**`max_guidance_weight`**: How strongly to enforce consistency with the previous chunk. This is a hyperparameter that can be tuned to balance the smoothness of the transitions and the reactivity of the policy. For 10 steps flow matching (SmolVLA, Pi0, Pi0.5), a value of 10.0 is a optimal value.
**`prefix_attention_schedule`**: How to weight consistency across the overlap region.
- `LINEAR`: Linear decay from inference_delay to execution_horizon
- `EXP`: Exponential decay (recommended for getting started)
- `ONES`: Full weight across entire execution_horizon
- `ZEROS`: Binary (full weight up to inference_delay, then zero)
**`inference_delay`**: How many timesteps of inference latency your system has. This is passed to `predict_action_chunk()` rather than the config, since it may vary at runtime.
## Testing RTC Offline
Before running on a real robot, test RTC with dataset samples to visualize how it works:
```bash
python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi0_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=10 \
--rtc.max_guidance_weight=10.0 \
--device=cuda
```
The script generates a visualization of the denoising process, comparing standard generation (left) with RTC (right). In the RTC plots, you can see how the first few steps (blue/purple lines) are guided to match the red ground truth trajectory (previous chunk's tail), ensuring a smooth transition between chunks.
<p align="center">
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/flow_matching.png"
alt="Denoising steps with and without RTC"
width="100%"
/>
</p>
## Testing RTC with a Real Robot
```bash
python examples/rtc/eval_with_real_robot.py \
--policy.path=${HF_USERNAME}/policy_repo_id \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120 \
--device=cuda
```
## How It Differs from the Async Inference in LeRobot
Both RTC and [async inference](./async) improve real-time robot control, but they solve different problems.
| Aspect | Async Inference | RTC |
| ------------- | -------------------------------------------------------------------------- | --------------------------------------------------- |
| **Problem** | Idle frames while waiting for inference | Discontinuities between action chunks |
| **Solution** | Decouple prediction from execution | Guide new chunks to continue smoothly from previous |
| **Benefit** | No waiting, continuous action | Smooth transitions, natural motion |
| **Best Used** | Async inference is best used with large models with high inference latency | Flow-matching based policies |
**Use both together** for maximum smoothness and reactivity!
## Advanced: Debug Tracking
RTC includes built-in debug tracking to help you understand what's happening during inference:
```python
# Enable debug tracking
policy_cfg.rtc_config.debug = True
policy_cfg.rtc_config.debug_maxlen = 100
# After inference, access debug data
debug_data = policy.rtc_processor.get_debug_data()
# Visualize denoising steps, corrections, etc.
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
visualizer = RTCDebugVisualizer()
# ... create plots
```
See `examples/rtc/eval_dataset.py` for a complete example of visualization.
## References
- [Smooth-As-Butter Robot Policies](https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html) - Excellent technical explanation with real robot results
- [Physical Intelligence - Real-Time Chunking](https://www.physicalintelligence.company/research/real_time_chunking) - Original paper and research
- [Kinetix RTC Implementation](https://github.com/Physical-Intelligence/real-time-chunking-kinetix) - Reference implementation from Physical Intelligence

View File

@@ -132,15 +132,17 @@ print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
if __name__ == "__main__":
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=32,
shuffle=True,
)
for batch in dataloader:
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
print(f"{batch['action'].shape=}") # (32, 64, c)
break
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
# PyTorch datasets.
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=32,
shuffle=True,
)
for batch in dataloader:
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
print(f"{batch['action'].shape=}") # (32, 64, c)
break

File diff suppressed because it is too large Load Diff

View File

@@ -1,525 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Visualize SARM Subtask Annotations
This script creates visualizations of the subtask annotations generated by subtask_annotation.py.
For each episode, it shows:
- A timeline with dashed vertical lines at subtask boundaries
- Sample frames from the episode at key points (start, middle, end of each subtask)
- Color-coded subtask segments
Usage:
python visualize_subtask_annotations.py --repo-id pepijn223/mydataset --video-key observation.images.top --num-episodes 5
"""
import argparse
import random
from pathlib import Path
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D
from rich.console import Console
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import load_episodes
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
def timestamp_to_seconds(timestamp: str) -> float:
"""Convert MM:SS or SS timestamp to seconds"""
parts = timestamp.split(":")
if len(parts) == 2:
return int(parts[0]) * 60 + int(parts[1])
else:
return int(parts[0])
def load_annotations_from_dataset(dataset_path: Path) -> dict[int, SubtaskAnnotation]:
"""
Load annotations from LeRobot dataset parquet files.
Reads subtask annotations from the episodes metadata parquet files.
"""
episodes_dataset = load_episodes(dataset_path)
if episodes_dataset is None or len(episodes_dataset) == 0:
return {}
# Check if subtask columns exist
if "subtask_names" not in episodes_dataset.column_names:
return {}
# Convert to pandas DataFrame for easier access
episodes_df = episodes_dataset.to_pandas()
annotations = {}
for ep_idx in episodes_df.index:
subtask_names = episodes_df.loc[ep_idx, "subtask_names"]
# Skip episodes without annotations
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
continue
start_times = episodes_df.loc[ep_idx, "subtask_start_times"]
end_times = episodes_df.loc[ep_idx, "subtask_end_times"]
# Reconstruct SubtaskAnnotation from stored data
subtasks = []
for i, name in enumerate(subtask_names):
# Convert seconds back to MM:SS format
start_sec = int(start_times[i])
end_sec = int(end_times[i])
start_str = f"{start_sec // 60:02d}:{start_sec % 60:02d}"
end_str = f"{end_sec // 60:02d}:{end_sec % 60:02d}"
subtasks.append(
Subtask(
name=name,
timestamps=Timestamp(start=start_str, end=end_str)
)
)
annotations[int(ep_idx)] = SubtaskAnnotation(subtasks=subtasks)
return annotations
# Color palette for subtasks (colorblind-friendly)
SUBTASK_COLORS = [
"#E69F00", # Orange
"#56B4E9", # Sky blue
"#009E73", # Bluish green
"#F0E442", # Yellow
"#0072B2", # Blue
"#D55E00", # Vermillion
"#CC79A7", # Reddish purple
"#999999", # Gray
]
def extract_frame_from_video(video_path: Path, timestamp: float) -> np.ndarray | None:
"""Extract a single frame from video at given timestamp."""
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
return None
# Set position to timestamp
cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
ret, frame = cap.read()
cap.release()
if ret:
# Convert BGR to RGB
return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
return None
def visualize_episode(
episode_idx: int,
annotation,
video_path: Path,
video_start_timestamp: float,
video_end_timestamp: float,
fps: int,
output_path: Path,
video_key: str,
):
"""
Create visualization for a single episode.
Shows:
- Top row: Sample frames from the episode (one per subtask)
- Bottom: Timeline with subtask segments and boundary lines
"""
subtasks = annotation.subtasks
num_subtasks = len(subtasks)
if num_subtasks == 0:
print(f"No subtasks found for episode {episode_idx}")
return
# Calculate episode duration
episode_duration = video_end_timestamp - video_start_timestamp
# Extract sample frames - get frame from middle of each subtask
sample_frames = []
frame_timestamps = []
for subtask in subtasks:
start_sec = timestamp_to_seconds(subtask.timestamps.start)
end_sec = timestamp_to_seconds(subtask.timestamps.end)
mid_sec = (start_sec + end_sec) / 2
# Convert to video timestamp (add video_start_timestamp offset)
video_timestamp = video_start_timestamp + mid_sec
frame_timestamps.append(mid_sec)
frame = extract_frame_from_video(video_path, video_timestamp)
sample_frames.append(frame)
# Create figure
fig = plt.figure(figsize=(16, 10))
# Use a dark background for better contrast
fig.patch.set_facecolor('#1a1a2e')
# Calculate grid layout
# Top section: frames (variable number of columns based on subtasks)
# Bottom section: timeline
# Create gridspec
gs = fig.add_gridspec(
2, max(num_subtasks, 1),
height_ratios=[2, 1],
hspace=0.3,
wspace=0.1,
left=0.05, right=0.95,
top=0.88, bottom=0.1
)
# Add title
fig.suptitle(
f"Episode {episode_idx} - Subtask Annotations",
fontsize=18,
fontweight='bold',
color='white',
y=0.96
)
# Add subtitle with video info
fig.text(
0.5, 0.91,
f"Camera: {video_key} | Duration: {episode_duration:.1f}s | {num_subtasks} subtasks",
ha='center',
fontsize=11,
color='#888888'
)
# Plot sample frames
for i, (frame, subtask) in enumerate(zip(sample_frames, subtasks)):
ax = fig.add_subplot(gs[0, i])
ax.set_facecolor('#16213e')
if frame is not None:
ax.imshow(frame)
else:
ax.text(0.5, 0.5, "Frame\nN/A", ha='center', va='center',
fontsize=12, color='white', transform=ax.transAxes)
ax.set_title(
f"{subtask.name}",
fontsize=10,
fontweight='bold',
color=SUBTASK_COLORS[i % len(SUBTASK_COLORS)],
pad=8
)
ax.axis('off')
# Add frame timestamp below
ax.text(
0.5, -0.08,
f"t={frame_timestamps[i]:.1f}s",
ha='center',
fontsize=9,
color='#888888',
transform=ax.transAxes
)
# Create timeline subplot spanning all columns
ax_timeline = fig.add_subplot(gs[1, :])
ax_timeline.set_facecolor('#16213e')
# Get total duration from last subtask end time
total_duration = timestamp_to_seconds(subtasks[-1].timestamps.end)
# Draw subtask segments as colored bars
bar_height = 0.6
bar_y = 0.5
for i, subtask in enumerate(subtasks):
start_sec = timestamp_to_seconds(subtask.timestamps.start)
end_sec = timestamp_to_seconds(subtask.timestamps.end)
color = SUBTASK_COLORS[i % len(SUBTASK_COLORS)]
# Draw segment bar
rect = mpatches.FancyBboxPatch(
(start_sec, bar_y - bar_height/2),
end_sec - start_sec,
bar_height,
boxstyle="round,pad=0.02,rounding_size=0.1",
facecolor=color,
edgecolor='white',
linewidth=1.5,
alpha=0.85
)
ax_timeline.add_patch(rect)
# Add subtask label inside bar
mid_x = (start_sec + end_sec) / 2
duration = end_sec - start_sec
# Only add text if segment is wide enough
if duration > total_duration * 0.08:
ax_timeline.text(
mid_x, bar_y,
subtask.name,
ha='center', va='center',
fontsize=9,
fontweight='bold',
color='black' if i in [3] else 'white', # Yellow needs dark text
rotation=0 if duration > total_duration * 0.15 else 45
)
# Draw boundary lines (dashed vertical lines between subtasks)
boundary_times = []
for i, subtask in enumerate(subtasks):
start_sec = timestamp_to_seconds(subtask.timestamps.start)
end_sec = timestamp_to_seconds(subtask.timestamps.end)
# Add start boundary (except for first subtask at t=0)
if i == 0 and start_sec > 0:
boundary_times.append(start_sec)
elif i > 0:
boundary_times.append(start_sec)
# Add end boundary for last subtask
if i == len(subtasks) - 1:
boundary_times.append(end_sec)
# Draw dashed lines at boundaries
for t in boundary_times:
ax_timeline.axvline(
x=t,
ymin=0.1, ymax=0.9,
color='white',
linestyle='--',
linewidth=2,
alpha=0.9
)
# Add time label below line
ax_timeline.text(
t, 0.0,
f"{int(t//60):02d}:{int(t%60):02d}",
ha='center', va='top',
fontsize=8,
color='#cccccc'
)
# Add start line at t=0
ax_timeline.axvline(x=0, ymin=0.1, ymax=0.9, color='#00ff00', linestyle='-', linewidth=2.5, alpha=0.9)
ax_timeline.text(0, 0.0, "00:00", ha='center', va='top', fontsize=8, color='#00ff00', fontweight='bold')
# Configure timeline axes
ax_timeline.set_xlim(-total_duration * 0.02, total_duration * 1.02)
ax_timeline.set_ylim(-0.3, 1.2)
ax_timeline.set_xlabel("Time (seconds)", fontsize=11, color='white', labelpad=10)
ax_timeline.set_ylabel("")
# Style the axes
ax_timeline.spines['top'].set_visible(False)
ax_timeline.spines['right'].set_visible(False)
ax_timeline.spines['left'].set_visible(False)
ax_timeline.spines['bottom'].set_color('#444444')
ax_timeline.tick_params(axis='x', colors='#888888', labelsize=9)
ax_timeline.tick_params(axis='y', left=False, labelleft=False)
# Add x-axis ticks at regular intervals
tick_interval = max(1, int(total_duration / 10))
ax_timeline.set_xticks(np.arange(0, total_duration + tick_interval, tick_interval))
# Add legend explaining line styles
legend_elements = [
Line2D([0], [0], color='#00ff00', linewidth=2.5, linestyle='-', label='Start'),
Line2D([0], [0], color='white', linewidth=2, linestyle='--', label='Subtask boundary'),
]
ax_timeline.legend(
handles=legend_elements,
loc='upper right',
framealpha=0.3,
facecolor='#16213e',
edgecolor='#444444',
fontsize=9,
labelcolor='white'
)
# Save figure
plt.savefig(output_path, dpi=150, facecolor=fig.get_facecolor(), edgecolor='none', bbox_inches='tight')
plt.close()
return output_path
def main():
parser = argparse.ArgumentParser(
description="Visualize SARM subtask annotations",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="HuggingFace dataset repository ID",
)
parser.add_argument(
"--num-episodes",
type=int,
default=5,
help="Number of random episodes to visualize (default: 5)",
)
parser.add_argument(
"--episodes",
type=int,
nargs="+",
default=None,
help="Specific episode indices to visualize (overrides --num-episodes)",
)
parser.add_argument(
"--video-key",
type=str,
default=None,
help="Camera/video key to use. If not specified, uses first available.",
)
parser.add_argument(
"--output-dir",
type=str,
default="./subtask_viz",
help="Output directory for visualizations (default: ./subtask_viz)",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Random seed for reproducibility",
)
args = parser.parse_args()
console = Console()
# Set random seed if specified
if args.seed is not None:
random.seed(args.seed)
console.print(f"\n[cyan]Loading dataset: {args.repo_id}[/cyan]")
dataset = LeRobotDataset(args.repo_id, download_videos=True)
fps = dataset.fps
# Get video key
if args.video_key:
if args.video_key not in dataset.meta.video_keys:
console.print(f"[red]Error: Video key '{args.video_key}' not found[/red]")
console.print(f"[yellow]Available: {', '.join(dataset.meta.video_keys)}[/yellow]")
return
video_key = args.video_key
else:
video_key = dataset.meta.video_keys[0]
console.print(f"[cyan]Using camera: {video_key}[/cyan]")
console.print(f"[cyan]FPS: {fps}[/cyan]")
# Load annotations
console.print(f"\n[cyan]Loading annotations...[/cyan]")
annotations = load_annotations_from_dataset(dataset.root)
if not annotations:
console.print("[red]Error: No annotations found in dataset[/red]")
console.print("[yellow]Run subtask_annotation.py first to generate annotations[/yellow]")
return
console.print(f"[green]Found {len(annotations)} annotated episodes[/green]")
# Determine which episodes to visualize
if args.episodes:
episode_indices = args.episodes
# Validate episodes exist
for ep in episode_indices:
if ep not in annotations:
console.print(f"[yellow]Warning: Episode {ep} has no annotation, skipping[/yellow]")
episode_indices = [ep for ep in episode_indices if ep in annotations]
else:
# Random selection
available_episodes = list(annotations.keys())
num_to_select = min(args.num_episodes, len(available_episodes))
episode_indices = random.sample(available_episodes, num_to_select)
episode_indices.sort()
if not episode_indices:
console.print("[red]Error: No valid episodes to visualize[/red]")
return
console.print(f"[cyan]Visualizing episodes: {episode_indices}[/cyan]")
# Create output directory
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Generate visualizations
for ep_idx in episode_indices:
console.print(f"\n[cyan]Processing episode {ep_idx}...[/cyan]")
annotation = annotations[ep_idx]
# Get video path and timestamps
video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)
if not video_path.exists():
console.print(f"[red]Video not found: {video_path}[/red]")
continue
# Get episode-specific timestamps within the video file
video_path_key = f"videos/{video_key}/from_timestamp"
video_path_key_to = f"videos/{video_key}/to_timestamp"
video_start_timestamp = float(dataset.meta.episodes[video_path_key][ep_idx])
video_end_timestamp = float(dataset.meta.episodes[video_path_key_to][ep_idx])
# Create visualization
output_path = output_dir / f"episode_{ep_idx:04d}_subtasks.png"
try:
visualize_episode(
episode_idx=ep_idx,
annotation=annotation,
video_path=video_path,
video_start_timestamp=video_start_timestamp,
video_end_timestamp=video_end_timestamp,
fps=fps,
output_path=output_path,
video_key=video_key,
)
console.print(f"[green]✓ Saved: {output_path}[/green]")
except Exception as e:
console.print(f"[red]✗ Failed to visualize episode {ep_idx}: {e}[/red]")
# Print summary
console.print(f"\n[bold green]{'=' * 50}[/bold green]")
console.print(f"[bold green]Visualization Complete![/bold green]")
console.print(f"[bold green]{'=' * 50}[/bold green]")
console.print(f"Output directory: {output_dir.absolute()}")
console.print(f"Episodes visualized: {len(episode_indices)}")
if __name__ == "__main__":
main()

View File

@@ -15,12 +15,16 @@
# limitations under the License.
import argparse
import logging
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from port_droid import DROID_SHARDS
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.utils.utils import init_logging
class AggregateDatasets(PipelineStep):
@@ -34,11 +38,6 @@ class AggregateDatasets(PipelineStep):
self.aggr_repo_id = aggregated_repo_id
def run(self, data=None, rank: int = 0, world_size: int = 1):
import logging
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.utils.utils import init_logging
init_logging()
# Since aggregate_datasets already handles parallel processing internally,

View File

@@ -20,7 +20,7 @@ from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from port_droid import DROID_SHARDS
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
class PortDroidShards(PipelineStep):
@@ -35,7 +35,7 @@ class PortDroidShards(PipelineStep):
def run(self, data=None, rank: int = 0, world_size: int = 1):
from datasets.utils.tqdm import disable_progress_bars
from port_droid import port_droid, validate_dataset
from port_datasets.droid_rlds.port_droid import port_droid, validate_dataset
from lerobot.utils.utils import init_logging

View File

@@ -24,7 +24,7 @@ from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from huggingface_hub import HfApi
from huggingface_hub.constants import REPOCARD_NAME
from port_droid import DROID_SHARDS
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.utils import create_lerobot_dataset_card
@@ -185,11 +185,11 @@ class UploadDataset(PipelineStep):
def make_upload_executor(
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, private=False, slurm=True
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
):
kwargs = {
"pipeline": [
UploadDataset(repo_id, private=private),
UploadDataset(repo_id),
],
"logging_dir": str(logs_dir / job_name),
}
@@ -267,12 +267,6 @@ def main():
default="1950M",
help="Memory per cpu that each worker will use.",
)
parser.add_argument(
"--private",
action="store_true",
default=False,
help="Whether to create a private repository.",
)
init_logging()

View File

@@ -1,951 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Evaluate Real-Time Chunking (RTC) performance on dataset samples.
This script takes two random samples from a dataset:
- Uses actions from the first sample as previous chunk
- Generates new actions for the second sample with and without RTC
It compares action predictions with and without RTC on dataset samples,
measuring consistency and ground truth alignment.
Usage:
# Basic usage with smolvla policy
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--rtc.max_guidance_weight=10.0 \
--rtc.prefix_attention_schedule=EXP \
--seed=10
# Basic usage with pi0.5 policy
uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi05_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=10 \
--device=mps
--seed=10
# Basic usage with pi0.5 policy with cuda device
uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi05_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=8 \
--device=cuda
# Basic usage with pi0 policy with cuda device
uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi0_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=8 \
--device=cuda
uv run python examples/rtc/eval_dataset.py \
--policy.path=lipsop/reuben_pi0 \
--dataset.repo_id=ReubenLim/so101_cube_in_cup \
--rtc.execution_horizon=8 \
--device=cuda
# With torch.compile for faster inference (PyTorch 2.0+)
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=mps \
--use_torch_compile=true \
--torch_compile_mode=max-autotune
# With torch.compile on CUDA (CUDA graphs disabled by default)
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--rtc.execution_horizon=8 \
--device=cuda \
--use_torch_compile=true \
--torch_compile_mode=reduce-overhead
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \
--use_torch_compile=true \
--torch_compile_backend=inductor \
--torch_compile_mode=max-autotune \
--torch_compile_disable_cudagraphs=false
"""
import gc
import logging
import os
import random
from dataclasses import dataclass, field
import numpy as np
import torch
try:
import matplotlib.pyplot as plt
MATPLOTLIB_AVAILABLE = True
except ImportError:
MATPLOTLIB_AVAILABLE = False
plt = None
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.factory import resolve_delta_timestamps
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging
def set_seed(seed: int):
"""Set random seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if torch.backends.mps.is_available():
torch.mps.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def _check_matplotlib_available():
"""Check if matplotlib is available, raise helpful error if not."""
if not MATPLOTLIB_AVAILABLE:
raise ImportError(
"matplotlib is required for RTC debug visualizations. "
"Please install it by running:\n"
" uv pip install matplotlib"
)
@dataclass
class RTCEvalConfig(HubMixin):
"""Configuration for RTC evaluation."""
# Policy configuration
policy: PreTrainedConfig | None = None
# Dataset configuration
dataset: DatasetConfig = field(default_factory=DatasetConfig)
# RTC configuration
rtc: RTCConfig = field(
default_factory=lambda: RTCConfig(
enabled=True,
execution_horizon=20,
max_guidance_weight=10.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
debug=True,
debug_maxlen=1000,
)
)
# Device configuration
device: str | None = field(
default=None,
metadata={"help": "Device to run on (cuda, cpu, mps, auto)"},
)
# Output configuration
output_dir: str = field(
default="rtc_debug_output",
metadata={"help": "Directory to save debug visualizations"},
)
# Seed configuration
seed: int = field(
default=42,
metadata={"help": "Random seed for reproducibility"},
)
inference_delay: int = field(
default=4,
metadata={"help": "Inference delay for RTC"},
)
# Torch compile configuration
use_torch_compile: bool = field(
default=False,
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
)
torch_compile_backend: str = field(
default="inductor",
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
)
torch_compile_mode: str = field(
default="default",
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
)
torch_compile_disable_cudagraphs: bool = field(
default=True,
metadata={
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
},
)
def __post_init__(self):
# Parse policy path
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
else:
raise ValueError("Policy path is required (--policy.path)")
# Auto-detect device if not specified
if self.device is None or self.device == "auto":
if torch.cuda.is_available():
self.device = "cuda"
elif torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"
logging.info(f"Auto-detected device: {self.device}")
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
class RTCEvaluator:
"""Evaluator for RTC on dataset samples."""
def __init__(self, cfg: RTCEvalConfig):
self.cfg = cfg
self.device = cfg.device
# Load dataset with proper delta_timestamps based on policy configuration
# Calculate delta_timestamps using the same logic as make_dataset factory
logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
# Get dataset metadata to extract FPS
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id)
# Calculate delta_timestamps from policy's delta_indices
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
# Create dataset with calculated delta_timestamps
self.dataset = LeRobotDataset(
cfg.dataset.repo_id,
delta_timestamps=delta_timestamps,
)
logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
# Create preprocessor/postprocessor
self.preprocessor, self.postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
preprocessor_overrides={
"device_processor": {"device": self.device},
},
)
logging.info("=" * 80)
logging.info("Ready to run evaluation with sequential policy loading:")
logging.info(" 1. policy_prev_chunk - Generate reference chunk, then destroy")
logging.info(" 2. policy_no_rtc - Generate without RTC, then destroy")
logging.info(" 3. policy_rtc - Generate with RTC, then destroy")
logging.info(" Note: Only one policy in memory at a time for efficient memory usage")
logging.info("=" * 80)
def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
"""Initialize a single policy instance with specified RTC configuration.
Args:
name: Name identifier for logging purposes
rtc_enabled: Whether to enable RTC for this policy
rtc_debug: Whether to enable debug tracking for this policy
Returns:
Configured policy instance with optional torch.compile applied
"""
logging.info(f"Initializing {name}...")
# Load policy from pretrained
policy_class = get_policy_class(self.cfg.policy.type)
config = PreTrainedConfig.from_pretrained(self.cfg.policy.pretrained_path)
if self.cfg.policy.type == "pi05" or self.cfg.policy.type == "pi0":
config.compile_model = self.cfg.use_torch_compile
policy = policy_class.from_pretrained(self.cfg.policy.pretrained_path, config=config)
policy = policy.to(self.device)
policy.eval()
# Configure RTC
rtc_config = RTCConfig(
enabled=rtc_enabled,
execution_horizon=self.cfg.rtc.execution_horizon,
max_guidance_weight=self.cfg.rtc.max_guidance_weight,
prefix_attention_schedule=self.cfg.rtc.prefix_attention_schedule,
debug=rtc_debug,
debug_maxlen=self.cfg.rtc.debug_maxlen,
)
policy.config.rtc_config = rtc_config
policy.init_rtc_processor()
logging.info(f" RTC enabled: {rtc_enabled}")
logging.info(f" RTC debug: {rtc_debug}")
logging.info(f" Policy config: {config}")
# Apply torch.compile to predict_action_chunk method if enabled
if self.cfg.use_torch_compile:
policy = self._apply_torch_compile(policy, name)
logging.info(f"{name} initialized successfully")
return policy
def _apply_torch_compile(self, policy, policy_name: str):
"""Apply torch.compile to the policy's predict_action_chunk method.
Args:
policy: Policy instance to compile
policy_name: Name for logging purposes
Returns:
Policy with compiled predict_action_chunk method
"""
# PI models handle their own compilation
if policy.type == "pi05" or policy.type == "pi0":
return policy
try:
# Check if torch.compile is available (PyTorch 2.0+)
if not hasattr(torch, "compile"):
logging.warning(
f" [{policy_name}] torch.compile is not available. Requires PyTorch 2.0+. "
f"Current version: {torch.__version__}. Skipping compilation."
)
return policy
logging.info(f" [{policy_name}] Applying torch.compile to predict_action_chunk...")
logging.info(f" Backend: {self.cfg.torch_compile_backend}")
logging.info(f" Mode: {self.cfg.torch_compile_mode}")
logging.info(f" Disable CUDA graphs: {self.cfg.torch_compile_disable_cudagraphs}")
logging.info(" Note: Debug tracker excluded from compilation via @torch._dynamo.disable")
# Compile the predict_action_chunk method
# - Debug tracker is excluded from compilation via @torch._dynamo.disable
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
compile_kwargs = {
"backend": self.cfg.torch_compile_backend,
"mode": self.cfg.torch_compile_mode,
}
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
if self.cfg.torch_compile_disable_cudagraphs:
compile_kwargs["options"] = {"triton.cudagraphs": False}
original_method = policy.predict_action_chunk
compiled_method = torch.compile(original_method, **compile_kwargs)
policy.predict_action_chunk = compiled_method
logging.info(f" ✓ [{policy_name}] Successfully compiled predict_action_chunk")
except Exception as e:
logging.error(f" [{policy_name}] Failed to apply torch.compile: {e}")
logging.warning(f" [{policy_name}] Continuing without torch.compile")
return policy
def _destroy_policy(self, policy, policy_name: str):
"""Explicitly destroy a policy and free all associated memory.
This method performs aggressive cleanup to ensure maximum memory is freed,
which is critical for large models (e.g., VLAs with billions of parameters).
Args:
policy: Policy instance to destroy
policy_name: Name for logging purposes
"""
logging.info(f" Destroying {policy_name} and freeing memory...")
try:
# Step 1: Move policy to CPU to free GPU/MPS memory
policy.cpu()
# Step 2: Delete the policy object
del policy
# Step 3: Force garbage collection to reclaim memory immediately
gc.collect()
# Step 4: Clear device-specific caches
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize() # Ensure all operations complete
if torch.backends.mps.is_available():
torch.mps.empty_cache()
logging.info(f"{policy_name} destroyed and memory freed")
except Exception as e:
logging.warning(f" Warning: Error during {policy_name} cleanup: {e}")
def run_evaluation(self):
"""Run evaluation on two random dataset samples using three separate policies.
Note: Policies are deinitalized after each step to free memory. Large models
(e.g., VLA models with billions of parameters) cannot fit three instances in
memory simultaneously. By deleting and garbage collecting after each step,
we ensure only one policy is loaded at a time.
"""
# Create output directory
os.makedirs(self.cfg.output_dir, exist_ok=True)
logging.info(f"Output directory: {self.cfg.output_dir}")
logging.info("=" * 80)
logging.info("Starting RTC evaluation")
logging.info(f"Inference delay: {self.cfg.inference_delay}")
logging.info("=" * 80)
# Load two random samples from dataset
data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
loader_iter = iter(data_loader)
first_sample = next(loader_iter)
second_sample = next(loader_iter)
preprocessed_first_sample = self.preprocessor(first_sample)
preprocessed_second_sample = self.preprocessor(second_sample)
# ============================================================================
# Step 1: Generate previous chunk using policy_prev_chunk
# ============================================================================
# This policy is only used to generate the reference chunk and then freed
logging.info("=" * 80)
logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
logging.info("=" * 80)
# Initialize policy 1
policy_prev_chunk_policy = self._init_policy(
name="policy_prev_chunk",
rtc_enabled=False,
rtc_debug=False,
)
with torch.no_grad():
prev_chunk_left_over = policy_prev_chunk_policy.predict_action_chunk(
preprocessed_first_sample,
)[:, :25, :].squeeze(0)
logging.info(f" Generated prev_chunk shape: {prev_chunk_left_over.shape}")
# Destroy policy_prev_chunk to free memory for large models
self._destroy_policy(policy_prev_chunk_policy, "policy_prev_chunk")
# ============================================================================
# Step 2: Generate actions WITHOUT RTC using policy_no_rtc
# ============================================================================
logging.info("=" * 80)
logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
logging.info("=" * 80)
set_seed(self.cfg.seed)
# Initialize policy 2
policy_no_rtc_policy = self._init_policy(
name="policy_no_rtc",
rtc_enabled=False,
rtc_debug=True,
)
# Sample noise (use same noise for both RTC and non-RTC for fair comparison)
noise_size = (1, policy_no_rtc_policy.config.chunk_size, policy_no_rtc_policy.config.max_action_dim)
noise = policy_no_rtc_policy.model.sample_noise(noise_size, self.device)
noise_clone = noise.clone()
policy_no_rtc_policy.rtc_processor.reset_tracker()
with torch.no_grad():
no_rtc_actions = policy_no_rtc_policy.predict_action_chunk(
preprocessed_second_sample,
noise=noise,
)
no_rtc_tracked_steps = policy_no_rtc_policy.rtc_processor.tracker.get_all_steps()
logging.info(f" Tracked {len(no_rtc_tracked_steps)} steps without RTC")
logging.info(f" Generated no_rtc_actions shape: {no_rtc_actions.shape}")
# Destroy policy_no_rtc to free memory before loading policy_rtc
self._destroy_policy(policy_no_rtc_policy, "policy_no_rtc")
# ============================================================================
# Step 3: Generate actions WITH RTC using policy_rtc
# ============================================================================
logging.info("=" * 80)
logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
logging.info("=" * 80)
set_seed(self.cfg.seed)
# Initialize policy 3
policy_rtc_policy = self._init_policy(
name="policy_rtc",
rtc_enabled=True,
rtc_debug=True,
)
policy_rtc_policy.rtc_processor.reset_tracker()
with torch.no_grad():
rtc_actions = policy_rtc_policy.predict_action_chunk(
preprocessed_second_sample,
noise=noise_clone,
inference_delay=self.cfg.inference_delay,
prev_chunk_left_over=prev_chunk_left_over,
execution_horizon=self.cfg.rtc.execution_horizon,
)
rtc_tracked_steps = policy_rtc_policy.rtc_processor.get_all_debug_steps()
logging.info(f" Tracked {len(rtc_tracked_steps)} steps with RTC")
logging.info(f" Generated rtc_actions shape: {rtc_actions.shape}")
# Save num_steps before destroying policy (needed for plotting)
try:
num_steps = policy_rtc_policy.config.num_steps
except Exception as e:
logging.error(f" Error getting num_steps: {e}")
num_steps = policy_rtc_policy.config.num_inference_steps
logging.warning(f" Using num_inference_steps: {num_steps} instead of num_steps")
# Destroy policy_rtc after final use
self._destroy_policy(policy_rtc_policy, "policy_rtc")
# Plot and save results
logging.info("=" * 80)
logging.info("Plotting results...")
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps)
# Plot final actions comparison
logging.info("=" * 80)
logging.info("Plotting final actions comparison...")
self.plot_final_actions_comparison(rtc_actions, no_rtc_actions, prev_chunk_left_over)
logging.info("=" * 80)
logging.info("Evaluation completed successfully")
def plot_final_actions_comparison(self, rtc_actions, no_rtc_actions, prev_chunk_left_over):
"""Plot final action predictions comparison on a single chart.
Args:
rtc_actions: Final actions from RTC policy
no_rtc_actions: Final actions from non-RTC policy
prev_chunk_left_over: Previous chunk used as ground truth
"""
_check_matplotlib_available()
# Remove batch dimension if present
rtc_actions_plot = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu()
no_rtc_actions_plot = (
no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu()
)
prev_chunk_plot = prev_chunk_left_over.cpu()
# Create figure with 6 subplots (one per action dimension)
fig, axes = plt.subplots(6, 1, figsize=(16, 12))
fig.suptitle("Final Action Predictions Comparison (Raw)", fontsize=16)
# Plot each action dimension
for dim_idx, ax in enumerate(axes):
# Plot previous chunk (ground truth) in red
RTCDebugVisualizer.plot_waypoints(
[ax],
prev_chunk_plot[:, dim_idx : dim_idx + 1],
start_from=0,
color="red",
label="Previous Chunk (Ground Truth)",
linewidth=2.5,
alpha=0.8,
)
# Plot no-RTC actions in blue
RTCDebugVisualizer.plot_waypoints(
[ax],
no_rtc_actions_plot[:, dim_idx : dim_idx + 1],
start_from=0,
color="blue",
label="No RTC",
linewidth=2,
alpha=0.7,
)
# Plot RTC actions in green
RTCDebugVisualizer.plot_waypoints(
[ax],
rtc_actions_plot[:, dim_idx : dim_idx + 1],
start_from=0,
color="green",
label="RTC",
linewidth=2,
alpha=0.7,
)
# Add vertical lines for inference delay and execution horizon
inference_delay = self.cfg.inference_delay
execution_horizon = self.cfg.rtc.execution_horizon
if inference_delay > 0:
ax.axvline(
x=inference_delay - 1,
color="orange",
linestyle="--",
alpha=0.5,
label=f"Inference Delay ({inference_delay})",
)
if execution_horizon > 0:
ax.axvline(
x=execution_horizon,
color="purple",
linestyle="--",
alpha=0.5,
label=f"Execution Horizon ({execution_horizon})",
)
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
ax.grid(True, alpha=0.3)
# Set x-axis ticks to show all integer values
max_len = max(rtc_actions_plot.shape[0], no_rtc_actions_plot.shape[0], prev_chunk_plot.shape[0])
ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
ax.set_xlim(-0.5, max_len - 0.5)
axes[-1].set_xlabel("Step", fontsize=10)
# Collect legend handles and labels from first subplot
handles, labels = axes[0].get_legend_handles_labels()
# Remove duplicates while preserving order
seen = set()
unique_handles = []
unique_labels = []
for handle, label in zip(handles, labels, strict=True):
if label not in seen:
seen.add(label)
unique_handles.append(handle)
unique_labels.append(label)
# Add legend outside the plot area (to the right)
fig.legend(
unique_handles,
unique_labels,
loc="center right",
fontsize=9,
bbox_to_anchor=(1.0, 0.5),
framealpha=0.9,
)
# Save figure
output_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png")
fig.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend on right
fig.savefig(output_path, dpi=150, bbox_inches="tight")
logging.info(f"Saved final actions comparison to {output_path}")
plt.close(fig)
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
_check_matplotlib_available()
# Create side-by-side figures for denoising visualization
fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)")
fig_corr, axs_corr = self._create_figure("Correction: No RTC (left) vs RTC (right)")
fig_x1t, axs_x1t = self._create_figure(
"x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)"
)
self._plot_denoising_steps_from_tracker(
rtc_tracked_steps,
axs_xt[:, 1], # Right column for x_t
axs_vt[:, 1], # Right column for v_t
axs_corr[:, 1], # Right column for correction
axs_x1t[:, 1], # Right column for x1_t
num_steps,
add_labels=True, # Add labels for RTC (right column)
)
self._plot_denoising_steps_from_tracker(
no_rtc_tracked_steps,
axs_xt[:, 0], # Left column for x_t
axs_vt[:, 0], # Left column for v_t
axs_corr[:, 0], # Left column for correction
axs_x1t[:, 0], # Left column for x1_t
num_steps,
add_labels=False, # No labels for No RTC (left column)
)
# Plot no-RTC x_t data on right chart as orange dashed line for comparison
self._plot_no_rtc_xt_reference(no_rtc_tracked_steps, axs_xt[:, 1], num_steps)
# Plot ground truth on x_t axes
RTCDebugVisualizer.plot_waypoints(
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
# Plot ground truth on x1_t axes
RTCDebugVisualizer.plot_waypoints(
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
# Plot ground truth on x_t axes (no labels for left column)
RTCDebugVisualizer.plot_waypoints(
axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label=None
)
RTCDebugVisualizer.plot_waypoints(
axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label=None
)
# Add legends outside the plot area for each figure
self._add_figure_legend(fig_xt, axs_xt)
self._add_figure_legend(fig_vt, axs_vt)
self._add_figure_legend(fig_corr, axs_corr)
self._add_figure_legend(fig_x1t, axs_x1t)
# Save denoising plots
self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png"))
self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png"))
self._save_figure(fig_corr, os.path.join(self.cfg.output_dir, "denoising_correction_comparison.png"))
self._save_figure(fig_x1t, os.path.join(self.cfg.output_dir, "denoising_x1t_comparison.png"))
def _create_figure(self, title):
fig, axs = plt.subplots(6, 2, figsize=(24, 12))
fig.suptitle(title, fontsize=16)
for ax in axs[:, 0]:
ax.set_title("No RTC (N/A)" if ax == axs[0, 0] else "", fontsize=12)
for ax in axs[:, 1]:
ax.set_title("RTC" if ax == axs[0, 1] else "", fontsize=12)
return fig, axs
def _add_figure_legend(self, fig, axs):
"""Add a legend outside the plot area on the right side.
Args:
fig: Matplotlib figure to add legend to
axs: Array of axes to collect legend handles from
"""
# Collect all handles and labels from the first row of axes (right column)
handles, labels = axs[0, 1].get_legend_handles_labels()
# Remove duplicates while preserving order
seen = set()
unique_handles = []
unique_labels = []
for handle, label in zip(handles, labels, strict=True):
if label not in seen:
seen.add(label)
unique_handles.append(handle)
unique_labels.append(label)
# Add legend outside the plot area (to the right, close to charts)
if unique_handles:
fig.legend(
unique_handles,
unique_labels,
loc="center left",
fontsize=8,
bbox_to_anchor=(0.87, 0.5),
framealpha=0.9,
ncol=1,
)
def _save_figure(self, fig, path):
fig.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend/colorbar on right
fig.savefig(path, dpi=150, bbox_inches="tight")
logging.info(f"Saved figure to {path}")
plt.close(fig)
def _plot_denoising_steps_from_tracker(
self, tracked_steps, xt_axs, vt_axs, corr_axs, x1t_axs, num_steps, add_labels=True
):
"""Plot denoising steps from tracker data.
Args:
tracked_steps: List of DebugStep objects containing debug steps
xt_axs: Matplotlib axes for x_t plots (array of 6 axes)
vt_axs: Matplotlib axes for v_t plots (array of 6 axes)
corr_axs: Matplotlib axes for correction plots (array of 6 axes)
x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes)
num_steps: Total number of denoising steps for colormap
add_labels: Whether to add legend labels for the plots
"""
logging.info("=" * 80)
logging.info(f"Plotting {len(tracked_steps)} steps")
debug_steps = tracked_steps
if not debug_steps:
return
# Define colors for different denoise steps (using a colormap)
colors = plt.cm.viridis(np.linspace(0, 1, num_steps))
for step_idx, debug_step in enumerate(debug_steps):
color = colors[step_idx % len(colors)]
label = f"Step {step_idx}" if add_labels else None
# Plot x_t
if debug_step.x_t is not None:
RTCDebugVisualizer.plot_waypoints(
xt_axs, debug_step.x_t, start_from=0, color=color, label=label
)
# Plot v_t
if debug_step.v_t is not None:
RTCDebugVisualizer.plot_waypoints(
vt_axs, debug_step.v_t, start_from=0, color=color, label=label
)
# Plot correction on separate axes
if debug_step.correction is not None:
RTCDebugVisualizer.plot_waypoints(
corr_axs,
debug_step.correction,
start_from=0,
color=color,
label=label,
)
# Plot x1_t (predicted state)
if x1t_axs is not None and debug_step.x1_t is not None:
x1t_label = f"x1_t Step {step_idx}" if add_labels else None
RTCDebugVisualizer.plot_waypoints(
x1t_axs,
debug_step.x1_t,
start_from=0,
color=color,
label=x1t_label,
)
# Plot error in orange dashed
if x1t_axs is not None and debug_step.err is not None:
error_chunk = (
debug_step.err[0].cpu().numpy()
if len(debug_step.err.shape) == 3
else debug_step.err.cpu().numpy()
)
num_dims = min(error_chunk.shape[-1], 6)
error_label = f"error Step {step_idx}" if add_labels else None
for j in range(num_dims):
x1t_axs[j].plot(
np.arange(0, error_chunk.shape[0]),
error_chunk[:, j],
color="orange",
linestyle="--",
alpha=0.7,
label=error_label,
)
# Recalculate axis limits after plotting to ensure proper scaling
self._rescale_axes(xt_axs)
self._rescale_axes(vt_axs)
self._rescale_axes(corr_axs)
self._rescale_axes(x1t_axs)
def _plot_no_rtc_xt_reference(self, no_rtc_tracked_steps, xt_axs, num_steps):
"""Plot final no-RTC x_t data as orange dashed line on the RTC chart for comparison.
Args:
no_rtc_tracked_steps: List of DebugStep objects containing no-RTC debug steps
xt_axs: Matplotlib axes for x_t plots (array of 6 axes, right column)
num_steps: Total number of denoising steps for colormap
"""
debug_steps = no_rtc_tracked_steps
if not debug_steps:
return
# Plot only the final x_t step as orange dashed line
final_step = debug_steps[-1]
logging.info("Plotting final no-RTC x_t step as orange dashed reference")
if final_step.x_t is not None:
x_t_chunk = (
final_step.x_t[0].cpu().numpy()
if len(final_step.x_t.shape) == 3
else final_step.x_t.cpu().numpy()
)
num_dims = min(x_t_chunk.shape[-1], 6)
for j in range(num_dims):
xt_axs[j].plot(
np.arange(0, x_t_chunk.shape[0]),
x_t_chunk[:, j],
color="orange",
linestyle="--",
alpha=0.7,
linewidth=2,
label="No RTC (final)" if j == 0 else "",
)
def _rescale_axes(self, axes):
"""Rescale axes to show all data with proper margins.
Args:
axes: Array of matplotlib axes to rescale
"""
for ax in axes:
ax.relim()
ax.autoscale_view()
# Add 10% margin to y-axis for better visualization
ylim = ax.get_ylim()
y_range = ylim[1] - ylim[0]
if y_range > 0: # Avoid division by zero
margin = y_range * 0.1
ax.set_ylim(ylim[0] - margin, ylim[1] + margin)
# Set x-axis ticks to show all integer values
xlim = ax.get_xlim()
max_len = int(xlim[1]) + 1
if max_len > 0:
ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
ax.set_xlim(-0.5, max_len - 0.5)
@parser.wrap()
def main(cfg: RTCEvalConfig):
"""Main entry point for RTC evaluation."""
# Set random seed for reproducibility
set_seed(cfg.seed)
init_logging()
logging.info("=" * 80)
logging.info("RTC Dataset Evaluation")
logging.info(f"Config: {cfg}")
logging.info("=" * 80)
evaluator = RTCEvaluator(cfg)
evaluator.run_evaluation()
if __name__ == "__main__":
main()

View File

@@ -1,549 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
This script demonstrates:
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
2. Consuming actions from the policy while the robot executes
3. Periodically requesting new action chunks in the background using threads
4. Managing action buffers and timing for real-time operation
For simulation environments, see eval_with_simulation.py
Usage:
# Run RTC with Real robot with RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with Real robot without RTC
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \
--policy.device=mps \
--rtc.enabled=false \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
# Run RTC with Real robot with pi0.5 policy
uv run examples/rtc/eval_with_real_robot.py \
--policy.path=helper2424/pi05_check_rtc \
--policy.device=mps \
--rtc.enabled=true \
--rtc.execution_horizon=20 \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58FA0834591 \
--robot.id=so100_follower \
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
--task="Move green small object into the purple platform" \
--duration=120
"""
import logging
import math
import sys
import time
import traceback
from dataclasses import dataclass, field
from threading import Event, Lock, Thread
import torch
from torch import Tensor
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.processor.factory import (
make_default_robot_action_processor,
make_default_robot_observation_processor,
)
from lerobot.rl.process import ProcessSignalHandler
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
koch_follower,
so100_follower,
so101_follower,
)
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class RobotWrapper:
def __init__(self, robot: Robot):
self.robot = robot
self.lock = Lock()
def get_observation(self) -> dict[str, Tensor]:
with self.lock:
return self.robot.get_observation()
def send_action(self, action: Tensor):
with self.lock:
self.robot.send_action(action)
def observation_features(self) -> list[str]:
with self.lock:
return self.robot.observation_features
def action_features(self) -> list[str]:
with self.lock:
return self.robot.action_features
@dataclass
class RTCDemoConfig(HubMixin):
"""Configuration for RTC demo with action chunking policies and real robots."""
# Policy configuration
policy: PreTrainedConfig | None = None
# Robot configuration
robot: RobotConfig | None = None
# RTC configuration
rtc: RTCConfig = field(
default_factory=lambda: RTCConfig(
execution_horizon=10,
max_guidance_weight=1.0,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
)
)
# Demo parameters
duration: float = 30.0 # Duration to run the demo (seconds)
fps: float = 10.0 # Action execution frequency (Hz)
# Compute device
device: str | None = None # Device to run on (cuda, cpu, auto)
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
# It should be higher than inference delay + execution horizon.
action_queue_size_to_get_new_actions: int = 30
# Task to execute
task: str = field(default="", metadata={"help": "Task to execute"})
# Torch compile configuration
use_torch_compile: bool = field(
default=False,
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
)
torch_compile_backend: str = field(
default="inductor",
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
)
torch_compile_mode: str = field(
default="default",
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
)
torch_compile_disable_cudagraphs: bool = field(
default=True,
metadata={
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
},
)
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
else:
raise ValueError("Policy path is required")
# Validate that robot configuration is provided
if self.robot is None:
raise ValueError("Robot configuration must be provided")
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
def is_image_key(k: str) -> bool:
return k.startswith(OBS_IMAGES)
def get_actions(
policy,
robot: RobotWrapper,
robot_observation_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to request action chunks from the policy.
Args:
policy: The policy instance (SmolVLA, Pi0, etc.)
robot: The robot instance for getting observations
robot_observation_processor: Processor for raw robot observations
action_queue: Queue to put new action chunks
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[GET_ACTIONS] Starting get actions thread")
latency_tracker = LatencyTracker() # Track latency of action chunks
fps = cfg.fps
time_per_chunk = 1.0 / fps
dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
policy_device = policy.config.device
# Load preprocessor and postprocessor from pretrained files
# The stats are embedded in the processor .safetensors files
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=None, # Will load from pretrained processor files
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
},
)
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
if not cfg.rtc.enabled:
get_actions_threshold = 0
while not shutdown_event.is_set():
if action_queue.qsize() <= get_actions_threshold:
current_time = time.perf_counter()
action_index_before_inference = action_queue.get_action_index()
prev_actions = action_queue.get_left_over()
inference_latency = latency_tracker.max()
inference_delay = math.ceil(inference_latency / time_per_chunk)
obs = robot.get_observation()
# Apply robot observation processor
obs_processed = robot_observation_processor(obs)
obs_with_policy_features = build_dataset_frame(
dataset_features, obs_processed, prefix="observation"
)
for name in obs_with_policy_features:
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
if "image" in name:
obs_with_policy_features[name] = (
obs_with_policy_features[name].type(torch.float32) / 255
)
obs_with_policy_features[name] = (
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
)
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string!
obs_with_policy_features["robot_type"] = (
robot.robot.name if hasattr(robot.robot, "name") else ""
)
preproceseded_obs = preprocessor(obs_with_policy_features)
# Generate actions WITH RTC
actions = policy.predict_action_chunk(
preproceseded_obs,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
# Store original actions (before postprocessing) for RTC
original_actions = actions.squeeze(0).clone()
postprocessed_actions = postprocessor(actions)
postprocessed_actions = postprocessed_actions.squeeze(0)
new_latency = time.perf_counter() - current_time
new_delay = math.ceil(new_latency / time_per_chunk)
latency_tracker.add(new_latency)
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
logger.warning(
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
)
action_queue.merge(
original_actions, postprocessed_actions, new_delay, action_index_before_inference
)
else:
# Small sleep to prevent busy waiting
time.sleep(0.1)
logger.info("[GET_ACTIONS] get actions thread shutting down")
except Exception as e:
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
def actor_control(
robot: RobotWrapper,
robot_action_processor,
action_queue: ActionQueue,
shutdown_event: Event,
cfg: RTCDemoConfig,
):
"""Thread function to execute actions on the robot.
Args:
robot: The robot instance
action_queue: Queue to get actions from
shutdown_event: Event to signal shutdown
cfg: Demo configuration
"""
try:
logger.info("[ACTOR] Starting actor thread")
action_count = 0
action_interval = 1.0 / cfg.fps
while not shutdown_event.is_set():
start_time = time.perf_counter()
# Try to get an action from the queue with timeout
action = action_queue.get()
if action is not None:
action = action.cpu()
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed)
action_count += 1
dt_s = time.perf_counter() - start_time
time.sleep(max(0, (action_interval - dt_s) - 0.001))
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
except Exception as e:
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
logger.error(traceback.format_exc())
sys.exit(1)
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
"""Apply torch.compile to the policy's predict_action_chunk method.
Args:
policy: Policy instance to compile
cfg: Configuration containing torch compile settings
Returns:
Policy with compiled predict_action_chunk method
"""
# PI models handle their own compilation
if policy.type == "pi05" or policy.type == "pi0":
return policy
try:
# Check if torch.compile is available (PyTorch 2.0+)
if not hasattr(torch, "compile"):
logger.warning(
f"torch.compile is not available. Requires PyTorch 2.0+. "
f"Current version: {torch.__version__}. Skipping compilation."
)
return policy
logger.info("Applying torch.compile to predict_action_chunk...")
logger.info(f" Backend: {cfg.torch_compile_backend}")
logger.info(f" Mode: {cfg.torch_compile_mode}")
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
# Compile the predict_action_chunk method
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
compile_kwargs = {
"backend": cfg.torch_compile_backend,
"mode": cfg.torch_compile_mode,
}
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
if cfg.torch_compile_disable_cudagraphs:
compile_kwargs["options"] = {"triton.cudagraphs": False}
original_method = policy.predict_action_chunk
compiled_method = torch.compile(original_method, **compile_kwargs)
policy.predict_action_chunk = compiled_method
logger.info("✓ Successfully compiled predict_action_chunk")
except Exception as e:
logger.error(f"Failed to apply torch.compile: {e}")
logger.warning("Continuing without torch.compile")
return policy
@parser.wrap()
def demo_cli(cfg: RTCDemoConfig):
"""Main entry point for RTC demo with draccus configuration."""
# Initialize logging
init_logging()
logger.info(f"Using device: {cfg.device}")
# Setup signal handler for graceful shutdown
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
shutdown_event = signal_handler.shutdown_event
policy = None
robot = None
get_actions_thread = None
actor_thread = None
policy_class = get_policy_class(cfg.policy.type)
# Load config and set compile_model for pi0/pi05 models
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
config.compile_model = cfg.use_torch_compile
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
# Turn on RTC
policy.config.rtc_config = cfg.rtc
# Init RTC processort, as by default if RTC disabled in the config
# The processor won't be created
policy.init_rtc_processor()
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
policy = policy.to(cfg.device)
policy.eval()
# Apply torch.compile to predict_action_chunk method if enabled
if cfg.use_torch_compile:
policy = _apply_torch_compile(policy, cfg)
# Create robot
logger.info(f"Initializing robot: {cfg.robot.type}")
robot = make_robot_from_config(cfg.robot)
robot.connect()
robot_wrapper = RobotWrapper(robot)
# Create robot observation processor
robot_observation_processor = make_default_robot_observation_processor()
robot_action_processor = make_default_robot_action_processor()
# Create action queue for communication between threads
action_queue = ActionQueue(cfg.rtc)
# Start chunk requester thread
get_actions_thread = Thread(
target=get_actions,
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="GetActions",
)
get_actions_thread.start()
logger.info("Started get actions thread")
# Start action executor thread
actor_thread = Thread(
target=actor_control,
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
daemon=True,
name="Actor",
)
actor_thread.start()
logger.info("Started actor thread")
logger.info("Started stop by duration thread")
# Main thread monitors for duration or shutdown
logger.info(f"Running demo for {cfg.duration} seconds...")
start_time = time.time()
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
time.sleep(10)
# Log queue status periodically
if int(time.time() - start_time) % 5 == 0:
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
if time.time() - start_time > cfg.duration:
break
logger.info("Demo duration reached or shutdown requested")
# Signal shutdown
shutdown_event.set()
# Wait for threads to finish
if get_actions_thread and get_actions_thread.is_alive():
logger.info("Waiting for chunk requester thread to finish...")
get_actions_thread.join()
if actor_thread and actor_thread.is_alive():
logger.info("Waiting for action executor thread to finish...")
actor_thread.join()
# Cleanup robot
if robot:
robot.disconnect()
logger.info("Robot disconnected")
logger.info("Cleanup completed")
if __name__ == "__main__":
demo_cli()
logging.info("RTC demo finished")

View File

@@ -1,98 +0,0 @@
"""This script demonstrates how to train ACT Policy on a real-world dataset."""
from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
if delta_indices is None:
return [0]
return [i / fps for i in delta_indices]
output_directory = Path("outputs/robot_learning_tutorial/act")
output_directory.mkdir(parents=True, exist_ok=True)
# Select your device
device = torch.device("mps") # or "cuda" or "cpu"
dataset_id = "lerobot/svla_so101_pickplace"
# This specifies the inputs the model will be expecting and the outputs it will produce
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
cfg = ACTConfig(input_features=input_features, output_features=output_features)
policy = ACTPolicy(cfg)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
policy.train()
policy.to(device)
# To perform action chunking, ACT expects a given number of actions as targets
delta_timestamps = {
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
}
# add image features if they are present
delta_timestamps |= {
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
}
# Instantiate the dataset
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
# Create the optimizer and dataloader for offline training
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
batch_size = 32
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=device.type != "cpu",
drop_last=True,
)
# Number of training steps and logging frequency
training_steps = 1
log_freq = 1
# Run training loop
step = 0
done = False
while not done:
for batch in dataloader:
batch = preprocessor(batch)
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
if step % log_freq == 0:
print(f"step: {step} loss: {loss.item():.3f}")
step += 1
if step >= training_steps:
done = True
break
# Save the policy checkpoint, alongside the pre/post processors
policy.save_pretrained(output_directory)
preprocessor.save_pretrained(output_directory)
postprocessor.save_pretrained(output_directory)
# Save all assets to the Hub
policy.push_to_hub("fracapuano/robot_learning_tutorial_act")
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")

View File

@@ -1,57 +0,0 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "fracapuano/robot_learning_tutorial_act"
model = ACTPolicy.from_pretrained(model_id)
dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)
# # find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# # the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_metadata.features, device=device
)
obs = preprocess(obs_frame)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_metadata.features)
robot.send_action(action)
print("Episode finished! Starting new episode...")

View File

@@ -1,11 +0,0 @@
from lerobot.async_inference.configs import PolicyServerConfig
from lerobot.async_inference.policy_server import serve
host = ... # something like "127.0.0.1" if you're exposing to localhost
port = ... # something like 8080
config = PolicyServerConfig(
host=host,
port=port,
)
serve(config)

View File

@@ -1,55 +0,0 @@
import threading
from lerobot.async_inference.configs import RobotClientConfig
from lerobot.async_inference.helpers import visualize_action_queue_size
from lerobot.async_inference.robot_client import RobotClient
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.robots.so100_follower import SO100FollowerConfig
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
# check the config.json on the Hub for the policy you are using to see the expected camera specs
camera_cfg = {
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
# # find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# # the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
server_address = ... # something like "127.0.0.1:8080" if using localhost
# 3. Create client configuration
client_cfg = RobotClientConfig(
robot=robot_cfg,
server_address=server_address,
policy_device="mps",
policy_type="act",
pretrained_name_or_path="fracapuano/robot_learning_tutorial_act",
chunk_size_threshold=0.5, # g
actions_per_chunk=50, # make sure this is less than the max actions of the policy
)
# 4. Create and start client
client = RobotClient(client_cfg)
# 5. Provide a textual description of the task
task = ...
if client.start():
# Start action receiver thread
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
action_receiver_thread.start()
try:
# Run the control loop
client.control_loop(task)
except KeyboardInterrupt:
client.stop()
action_receiver_thread.join()
# (Optionally) plot the action queue size
visualize_action_queue_size(client.action_queue_size)

View File

@@ -1,99 +0,0 @@
"""This script demonstrates how to train Diffusion Policy on a real-world dataset."""
from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
if delta_indices is None:
return [0]
return [i / fps for i in delta_indices]
output_directory = Path("outputs/robot_learning_tutorial/diffusion")
output_directory.mkdir(parents=True, exist_ok=True)
# Select your device
device = torch.device("mps") # or "cuda" or "cpu"
dataset_id = "lerobot/svla_so101_pickplace"
# This specifies the inputs the model will be expecting and the outputs it will produce
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
policy = DiffusionPolicy(cfg)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
policy.train()
policy.to(device)
# To perform action chunking, ACT expects a given number of actions as targets
delta_timestamps = {
"observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
}
# add image features if they are present
delta_timestamps |= {
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
}
# Instantiate the dataset
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
# Create the optimizer and dataloader for offline training
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
batch_size = 32
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=device.type != "cpu",
drop_last=True,
)
# Number of training steps and logging frequency
training_steps = 1
log_freq = 1
# Run training loop
step = 0
done = False
while not done:
for batch in dataloader:
batch = preprocessor(batch)
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
if step % log_freq == 0:
print(f"step: {step} loss: {loss.item():.3f}")
step += 1
if step >= training_steps:
done = True
break
# Save the policy checkpoint, alongside the pre/post processors
policy.save_pretrained(output_directory)
preprocessor.save_pretrained(output_directory)
postprocessor.save_pretrained(output_directory)
# Save all assets to the Hub
policy.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")

View File

@@ -1,60 +0,0 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "fracapuano/robot_learning_tutorial_diffusion"
model = DiffusionPolicy.from_pretrained(model_id)
dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
preprocess, postprocess = make_pre_post_processors(
model.config, model_id, dataset_stats=dataset_metadata.stats
)
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
# # find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# # the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_metadata.features, device=device
)
obs = preprocess(obs_frame)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_metadata.features)
robot.send_action(action)
print("Episode finished! Starting new episode...")

View File

@@ -1,67 +0,0 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "lerobot/pi0_base"
model = PI0Policy.from_pretrained(model_id)
preprocess, postprocess = make_pre_post_processors(
model.config,
model_id,
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
preprocessor_overrides={"device_processor": {"device": str(device)}},
)
# find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"base_0_rgb": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"left_wrist_0_rgb": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
"right_wrist_0_rgb": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=30),
}
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
task = "" # something like "pick the red block"
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
# This is used to match the raw observation keys to the keys expected by the policy
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
)
obs = preprocess(obs_frame)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_features)
robot.send_action(action)
print("Episode finished! Starting new episode...")

View File

@@ -1,345 +0,0 @@
import multiprocessing as mp
import signal
from pathlib import Path
from queue import Empty, Full
import torch
import torch.optim as optim
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.gym_manipulator import make_robot_env
from lerobot.robots.so100_follower import SO100FollowerConfig
from lerobot.teleoperators.so100_leader import SO100LeaderConfig
from lerobot.teleoperators.utils import TeleopEvents
LOG_EVERY = 10
SEND_EVERY = 10
def run_learner(
transitions_queue: mp.Queue,
parameters_queue: mp.Queue,
shutdown_event: mp.Event,
policy_learner: SACPolicy,
online_buffer: ReplayBuffer,
offline_buffer: ReplayBuffer,
lr: float = 3e-4,
batch_size: int = 32,
device: torch.device = "mps",
):
"""The learner process - trains SAC policy on transitions streamed from the actor, updating parameters
for the actor to adopt."""
policy_learner.train()
policy_learner.to(device)
# Create Adam optimizer from scratch - simple and clean
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
training_step = 0
while not shutdown_event.is_set():
# retrieve incoming transitions from the actor process
try:
transitions = transitions_queue.get(timeout=0.1)
for transition in transitions:
# HIL-SERL: Add ALL transitions to online buffer
online_buffer.add(**transition)
# HIL-SERL: Add ONLY human intervention transitions to offline buffer
is_intervention = transition.get("complementary_info", {}).get("is_intervention", False)
if is_intervention:
offline_buffer.add(**transition)
print(
f"[LEARNER] Human intervention detected! Added to offline buffer (now {len(offline_buffer)} transitions)"
)
except Empty:
pass # No transitions available, continue
# Train if we have enough data
if len(online_buffer) >= policy_learner.config.online_step_before_learning:
# Sample from online buffer (autonomous + human data)
online_batch = online_buffer.sample(batch_size // 2)
# Sample from offline buffer (human demonstrations only, either precollected or at runtime)
offline_batch = offline_buffer.sample(batch_size // 2)
# Combine batches - this is the key HIL-SERL mechanism!
batch = {}
for key in online_batch:
if key in offline_batch:
batch[key] = torch.cat([online_batch[key], offline_batch[key]], dim=0)
else:
batch[key] = online_batch[key]
loss, _ = policy_learner.forward(batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
training_step += 1
if training_step % LOG_EVERY == 0:
print(
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
)
# Send updated parameters to actor every 10 training steps
if training_step % SEND_EVERY == 0:
try:
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
parameters_queue.put_nowait(state_dict)
print("[LEARNER] Sent updated parameters to actor")
except Full:
# Missing write due to queue not being consumed (should happen rarely)
pass
print("[LEARNER] Learner process finished")
def run_actor(
transitions_queue: mp.Queue,
parameters_queue: mp.Queue,
shutdown_event: mp.Event,
policy_actor: SACPolicy,
reward_classifier: Classifier,
env_cfg: HILSerlRobotEnvConfig,
device: torch.device = "mps",
output_directory: Path | None = None,
):
"""The actor process - interacts with environment and collects data.
The policy is frozen and only the parameters are updated, popping the most recent ones from a queue."""
policy_actor.eval()
policy_actor.to(device)
reward_classifier.eval()
reward_classifier.to(device)
# Create robot environment inside the actor process
env, teleop_device = make_robot_env(env_cfg)
try:
for episode in range(MAX_EPISODES):
if shutdown_event.is_set():
break
obs, _info = env.reset()
episode_reward = 0.0
step = 0
episode_transitions = []
print(f"[ACTOR] Starting episode {episode + 1}")
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
try:
new_params = parameters_queue.get_nowait()
policy_actor.load_state_dict(new_params)
print("[ACTOR] Updated policy parameters from learner")
except Empty: # No new updated parameters available from learner, waiting
pass
# Get action from policy
policy_obs = make_policy_obs(obs, device=device)
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
action = action_tensor.squeeze(0).cpu().numpy()
# Step environment
next_obs, _env_reward, terminated, truncated, _info = env.step(action)
done = terminated or truncated
# Predict reward
policy_next_obs = make_policy_obs(next_obs, device=device)
reward = reward_classifier.predict_reward(policy_next_obs)
if reward >= 1.0 and not done: # success detected! halt episode
terminated = True
done = True
# In HIL-SERL, human interventions come from the teleop device
is_intervention = False
if hasattr(teleop_device, "get_teleop_events"):
# Real intervention detection from teleop device
teleop_events = teleop_device.get_teleop_events()
is_intervention = teleop_events.get(TeleopEvents.IS_INTERVENTION, False)
# Store transition with intervention metadata
transition = {
"state": policy_obs,
"action": action,
"reward": float(reward) if hasattr(reward, "item") else reward,
"next_state": policy_next_obs,
"done": done,
"truncated": truncated,
"complementary_info": {
"is_intervention": is_intervention,
},
}
episode_transitions.append(transition)
episode_reward += reward
step += 1
obs = next_obs
if done:
break
# Send episode transitions to learner
transitions_queue.put_nowait(episode_transitions)
except KeyboardInterrupt:
print("[ACTOR] Interrupted by user")
finally:
# Clean up
if hasattr(env, "robot") and env.robot.is_connected:
env.robot.disconnect()
if teleop_device and hasattr(teleop_device, "disconnect"):
teleop_device.disconnect()
if output_directory is not None:
policy_actor.save_pretrained(output_directory)
print(f"[ACTOR] Latest actor policy saved at: {output_directory}")
print("[ACTOR] Actor process finished")
def make_policy_obs(obs, device: torch.device = "cpu"):
return {
"observation.state": torch.from_numpy(obs["agent_pos"]).float().unsqueeze(0).to(device),
**{
f"observation.image.{k}": torch.from_numpy(obs["pixels"][k]).float().unsqueeze(0).to(device)
for k in obs["pixels"]
},
}
"""Main function - coordinates actor and learner processes."""
device = "mps" # or "cuda" or "cpu"
output_directory = Path("outputs/robot_learning_tutorial/hil_serl")
output_directory.mkdir(parents=True, exist_ok=True)
# find ports using lerobot-find-port
follower_port = ...
leader_port = ...
# the robot ids are used the load the right calibration files
follower_id = ...
leader_id = ...
# A pretrained model (to be used in-distribution!)
reward_classifier_id = "fracapuano/reward_classifier_hil_serl_example"
reward_classifier = Classifier.from_pretrained(reward_classifier_id)
reward_classifier.to(device)
reward_classifier.eval()
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
# Robot and environment configuration
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id)
teleop_cfg = SO100LeaderConfig(port=leader_port, id=leader_id)
processor_cfg = HILSerlProcessorConfig(control_mode="leader")
env_cfg = HILSerlRobotEnvConfig(robot=robot_cfg, teleop=teleop_cfg, processor=processor_cfg)
# Create robot environment
env, teleop_device = make_robot_env(env_cfg)
obs_features = hw_to_dataset_features(env.robot.observation_features, "observation")
action_features = hw_to_dataset_features(env.robot.action_features, "action")
# Create SAC policy for action selection
policy_cfg = SACConfig(
device=device,
input_features=obs_features,
output_features=action_features,
)
policy_actor = SACPolicy(policy_cfg)
policy_learner = SACPolicy(policy_cfg)
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
# Online buffer: initialized from scratch
online_replay_buffer = ReplayBuffer(device=device, state_keys=list(obs_features.keys()))
# Offline buffer: Created from dataset (pre-populated it with demonstrations)
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=offline_dataset, device=device, state_keys=list(obs_features.keys())
)
# Create communication channels between learner and actor processes
transitions_queue = mp.Queue(maxsize=10)
parameters_queue = mp.Queue(maxsize=2)
shutdown_event = mp.Event()
# Signal handler for graceful shutdown
def signal_handler(sig):
print(f"\nSignal {sig} received, shutting down...")
shutdown_event.set()
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
# Create processes
learner_process = mp.Process(
target=run_learner,
args=(
transitions_queue,
parameters_queue,
shutdown_event,
policy_learner,
online_replay_buffer,
offline_replay_buffer,
),
kwargs={"device": device}, # can run on accelerated hardware for training
)
actor_process = mp.Process(
target=run_actor,
args=(
transitions_queue,
parameters_queue,
shutdown_event,
policy_actor,
reward_classifier,
env_cfg,
output_directory,
),
kwargs={"device": "cpu"}, # actor is frozen, can run on CPU or accelerate for inference
)
learner_process.start()
actor_process.start()
try:
# Wait for actor to finish (it controls the episode loop)
actor_process.join()
shutdown_event.set()
learner_process.join(timeout=10)
except KeyboardInterrupt:
print("Main process interrupted")
shutdown_event.set()
actor_process.join(timeout=5)
learner_process.join(timeout=10)
finally:
if learner_process.is_alive():
learner_process.terminate()
if actor_process.is_alive():
actor_process.terminate()

View File

@@ -1,62 +0,0 @@
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
# Device to use for training
device = "mps" # or "cuda", or "cpu"
# Load the dataset used for training
repo_id = "lerobot/example_hil_serl_dataset"
dataset = LeRobotDataset(repo_id)
# Configure the policy to extract features from the image frames
camera_keys = dataset.meta.camera_keys
config = RewardClassifierConfig(
num_cameras=len(camera_keys),
device=device,
# backbone model to extract features from the image frames
model_name="microsoft/resnet-18",
)
# Make policy, preprocessor, and optimizer
policy = make_policy(config, ds_meta=dataset.meta)
optimizer = config.get_optimizer_preset().build(policy.parameters())
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
classifier_id = "fracapuano/reward_classifier_hil_serl_example"
# Instantiate a dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
# Training loop
num_epochs = 5
for epoch in range(num_epochs):
total_loss = 0
total_accuracy = 0
for batch in dataloader:
# Preprocess the batch and move it to the correct device.
batch = preprocessor(batch)
# Forward pass
loss, output_dict = policy.forward(batch)
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
total_accuracy += output_dict["accuracy"]
avg_loss = total_loss / len(dataloader)
avg_accuracy = total_accuracy / len(dataloader)
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%")
print("Training finished!")
# You can now save the trained policy.
policy.push_to_hub(classifier_id)

View File

@@ -1,66 +0,0 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "lerobot/smolvla_base"
model = SmolVLAPolicy.from_pretrained(model_id)
preprocess, postprocess = make_pre_post_processors(
model.config,
model_id,
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
preprocessor_overrides={"device_processor": {"device": str(device)}},
)
# find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"camera1": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"camera2": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
task = "" # something like "pick the red block"
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
# This is used to match the raw observation keys to the keys expected by the policy
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
)
obs = preprocess(obs_frame)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_features)
robot.send_action(action)
print("Episode finished! Starting new episode...")

View File

@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
[project]
name = "lerobot"
version = "0.4.2"
version = "0.3.4"
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
readme = "README.md"
license = { text = "Apache-2.0" }
@@ -74,15 +74,15 @@ dependencies = [
"packaging>=24.2,<26.0",
"pynput>=1.7.7,<1.9.0",
"pyserial>=3.5,<4.0",
"wandb>=0.20.0,<0.22.0", # TODO: Bumb dependency (compatible with protobuf)
"wandb>=0.20.0,<0.23.0",
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
"draccus==0.10.0", # TODO: Remove ==
"gymnasium>=1.1.1,<2.0.0",
"rerun-sdk>=0.24.0,<0.27.0",
"gymnasium>=1.0.0",
"rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency
# Support dependencies
"deepdiff>=7.0.1,<9.0.0",
@@ -97,7 +97,7 @@ dependencies = [
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.10.0"]
transformers-dep = ["transformers>=4.53.0,<5.0.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
@@ -113,23 +113,17 @@ intelrealsense = [
"pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'",
"pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'",
]
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0"]
# stretch = [
# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'",
# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'",
# "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"
# ] # TODO: Currently not supported
# Policies
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
groot = [
"lerobot[transformers-dep]",
"peft>=0.13.0,<1.0.0",
"dm-tree>=0.1.8,<1.0.0",
"timm>=1.0.0,<1.1.0",
"safetensors>=0.4.3,<1.0.0",
"Pillow>=10.0.0,<13.0.0",
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
@@ -142,8 +136,8 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# Simulation
aloha = ["gym-aloha>=0.1.2,<0.2.0"]
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0"]
metaworld = ["metaworld==3.0.0"]
libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"]
metaworld = ["metaworld>=3.0.0"]
# All
all = [
@@ -156,7 +150,6 @@ all = [
"lerobot[intelrealsense]",
"lerobot[pi]",
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[hilserl]",
"lerobot[async]",
"lerobot[dev]",
@@ -241,6 +234,9 @@ exclude_dirs = [
"tests",
"benchmarks",
"src/lerobot/datasets/push_dataset_to_hub",
"src/lerobot/datasets/v2/convert_dataset_v1_to_v2",
"src/lerobot/policies/pi0/conversion_scripts",
"src/lerobot/scripts/push_dataset_to_hub.py",
]
skips = ["B101", "B311", "B404", "B603", "B615"]
@@ -255,8 +251,6 @@ default.extend-ignore-identifiers-re = [
"pn",
"ser",
"ein",
"thw",
"inpt",
]
# TODO: Uncomment when ready to use
@@ -295,6 +289,7 @@ ignore_errors = true
[[tool.mypy.overrides]]
module = "lerobot.envs.*"
# Enable type checking only for the envs module
ignore_errors = false
@@ -302,22 +297,17 @@ ignore_errors = false
# module = "lerobot.utils.*"
# ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.configs.*"
ignore_errors = false
# extra strictness for configs
disallow_untyped_defs = true
disallow_incomplete_defs = true
check_untyped_defs = true
# [[tool.mypy.overrides]]
# module = "lerobot.configs.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.optim.*"
# ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.model.*"
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.model.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.processor.*"
@@ -327,9 +317,9 @@ ignore_errors = false
# module = "lerobot.datasets.*"
# ignore_errors = false
[[tool.mypy.overrides]]
module = "lerobot.cameras.*"
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.cameras.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.motors.*"

View File

@@ -1,4 +1,3 @@
#
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
@@ -13,62 +12,47 @@ absl-py==2.3.1
# dm-tree
# labmaze
# mujoco
# tensorboard
accelerate==1.11.0
# via
# lerobot
# peft
accelerate==1.9.0
# via lerobot
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.13.1
aiohttp==3.12.15
# via fsspec
aiosignal==1.4.0
# via aiohttp
annotated-types==0.7.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via
# hydra-core
# omegaconf
anyio==4.11.0
# via
# starlette
# watchfiles
asttokens==3.0.0
# via stack-data
async-timeout==5.0.1
# via aiohttp
attrs==25.4.0
attrs==25.3.0
# via
# aiohttp
# dm-tree
# jsonlines
# jsonschema
# referencing
# rerun-sdk
av==15.1.0
av==15.0.0
# via lerobot
bddl==1.0.1
# via libero
certifi==2025.10.5
blinker==1.9.0
# via flask
certifi==2025.7.14
# via
# requests
# sentry-sdk
cffi==2.0.0
cffi==1.17.1
# via pymunk
cfgv==3.4.0
# via pre-commit
charset-normalizer==3.4.4
charset-normalizer==3.4.2
# via requests
click==8.3.0
click==8.2.1
# via
# uvicorn
# flask
# wandb
cloudpickle==3.1.1
# via
# gymnasium
# libero
cmake==4.1.0
# via gymnasium
cmake==4.0.3
# via lerobot
cmeel==0.57.3
# via
@@ -110,27 +94,27 @@ coal-library==3.0.1
# via pin
contourpy==1.3.2
# via matplotlib
coverage[toml]==7.11.0
coverage[toml]==7.10.1
# via pytest-cov
cycler==0.12.1
# via matplotlib
datasets==4.1.1
datasets==3.6.0
# via lerobot
debugpy==1.8.17
debugpy==1.8.15
# via lerobot
decorator==5.2.1
# via ipython
deepdiff==8.6.1
deepdiff==8.5.0
# via lerobot
diffusers==0.35.2
diffusers==0.34.0
# via lerobot
dill==0.4.0
dill==0.3.8
# via
# datasets
# multiprocess
distlib==0.4.0
# via virtualenv
dm-control==1.0.34
dm-control==1.0.14
# via gym-aloha
dm-env==1.6
# via dm-control
@@ -138,45 +122,29 @@ dm-tree==0.1.9
# via
# dm-control
# dm-env
# lerobot
docopt==0.6.2
# via num2words
draccus==0.10.0
# via lerobot
dynamixel-sdk==3.8.4
dynamixel-sdk==3.7.31
# via lerobot
easydict==1.13
# via libero
egl-probe @ git+https://github.com/huggingface/egl_probe.git
# via
# libero
# robomimic
eigenpy==3.10.3
# via coal-library
einops==0.8.1
# via
# lerobot
# libero
# via lerobot
eiquadprog==1.2.9
# via placo
etils[epath,epy]==1.13.0
# via mujoco
exceptiongroup==1.3.0
# via
# anyio
# ipython
# pytest
executing==2.2.1
executing==2.2.0
# via stack-data
farama-notifications==0.0.4
# via gymnasium
fastapi==0.119.1
# via teleop
fastjsonschema==2.21.2
# via nbformat
feetech-servo-sdk==1.0.0
# via lerobot
filelock==3.20.0
filelock==3.18.0
# via
# datasets
# diffusers
@@ -184,25 +152,24 @@ filelock==3.20.0
# torch
# transformers
# virtualenv
fonttools==4.60.1
flask==3.1.1
# via lerobot
fonttools==4.59.0
# via matplotlib
frozenlist==1.8.0
frozenlist==1.7.0
# via
# aiohttp
# aiosignal
fsspec[http]==2025.9.0
fsspec[http]==2025.3.0
# via
# datasets
# etils
# huggingface-hub
# torch
future==1.0.0
# via libero
gitdb==4.0.12
# via gitpython
gitpython==3.1.45
# via wandb
glfw==2.10.0
glfw==2.9.0
# via
# dm-control
# mujoco
@@ -210,79 +177,61 @@ grpcio==1.73.1
# via
# grpcio-tools
# lerobot
# reachy2-sdk
# reachy2-sdk-api
# tensorboard
grpcio-tools==1.73.1
# via
# lerobot
# reachy2-sdk-api
gym-aloha==0.1.3
# via lerobot
gym-hil==0.1.13
gym-aloha==0.1.1
# via lerobot
gym-pusht==0.1.6
gym-hil==0.1.10
# via lerobot
gymnasium==1.2.1
gym-pusht==0.1.5
# via lerobot
gym-xarm==0.1.1
# via lerobot
gymnasium==0.29.1
# via
# gym-aloha
# gym-hil
# gym-pusht
# gym-xarm
# gymnasium-robotics
# lerobot
# libero
# metaworld
h11==0.16.0
# via uvicorn
h5py==3.15.1
# via robomimic
hebi-py==2.11.0
# via lerobot
# pettingzoo
gymnasium-robotics==1.2.4
# via gym-xarm
hf-transfer==0.1.9
# via huggingface-hub
hf-xet==1.1.10
hf-xet==1.1.5
# via huggingface-hub
hidapi==0.14.0.post4
# via
# gym-hil
# lerobot
httptools==0.7.1
# via uvicorn
huggingface-hub[cli,hf-transfer]==0.35.3
huggingface-hub[cli,hf-transfer]==0.34.3
# via
# accelerate
# datasets
# diffusers
# lerobot
# peft
# timm
# tokenizers
# transformers
hydra-core==1.3.2
# via libero
identify==2.6.15
identify==2.6.12
# via pre-commit
idna==3.11
idna==3.10
# via
# anyio
# requests
# yarl
imageio[ffmpeg]==2.37.0
# via
# gym-aloha
# gym-hil
# gymnasium-robotics
# lerobot
# metaworld
# robomimic
# scikit-image
imageio-ffmpeg==0.6.0
# via
# imageio
# robomimic
# via imageio
importlib-metadata==8.7.0
# via diffusers
importlib-resources==6.5.2
# via etils
iniconfig==2.3.0
iniconfig==2.1.0
# via pytest
inquirerpy==0.3.4
# via huggingface-hub
@@ -290,71 +239,50 @@ ipython==8.37.0
# via meshcat
ischedule==1.2.7
# via placo
itsdangerous==2.2.0
# via flask
jedi==0.19.2
# via ipython
jinja2==3.1.6
# via torch
# via
# flask
# gymnasium-robotics
# torch
jsonlines==4.0.0
# via lerobot
jsonschema==4.25.1
# via nbformat
jsonschema-specifications==2025.9.1
# via jsonschema
jupyter-core==5.9.1
# via nbformat
jupytext==1.18.1
# via bddl
kiwisolver==1.4.9
kiwisolver==1.4.8
# via matplotlib
labmaze==1.0.6
# via dm-control
lazy-loader==0.4
# via scikit-image
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
# via lerobot
llvmlite==0.45.1
# via numba
lxml==6.0.2
lxml==6.0.0
# via dm-control
markdown==3.9
# via tensorboard
markdown-it-py==4.0.0
# via
# jupytext
# mdit-py-plugins
markupsafe==3.0.3
markupsafe==3.0.2
# via
# flask
# jinja2
# werkzeug
matplotlib==3.10.7
# via
# lerobot
# libero
matplotlib-inline==0.2.1
matplotlib==3.10.5
# via lerobot
matplotlib-inline==0.1.7
# via ipython
mdit-py-plugins==0.5.0
# via jupytext
mdurl==0.1.2
# via markdown-it-py
mergedeep==1.3.4
# via draccus
meshcat==0.3.2
# via placo
metaworld==3.0.0
# via lerobot
mock-serial==0.0.1
# via lerobot
mpmath==1.3.0
# via sympy
mujoco==3.3.7
mujoco==2.3.7
# via
# dm-control
# gym-aloha
# gym-hil
# libero
# metaworld
# robosuite
multidict==6.7.0
# gym-xarm
# gymnasium-robotics
multidict==6.6.3
# via
# aiohttp
# yarl
@@ -362,25 +290,17 @@ multiprocess==0.70.16
# via datasets
mypy-extensions==1.1.0
# via typing-inspect
nbformat==5.10.4
# via jupytext
networkx==3.4.2
# via
# bddl
# scikit-image
# torch
ninja==1.13.0
# via lerobot
nodeenv==1.9.1
# via pre-commit
num2words==0.5.14
# via lerobot
numba==0.62.1
# via robosuite
numpy==2.2.6
# via
# accelerate
# bddl
# cmeel-boost
# contourpy
# datasets
@@ -389,43 +309,25 @@ numpy==2.2.6
# dm-env
# dm-tree
# gymnasium
# h5py
# hebi-py
# gymnasium-robotics
# imageio
# labmaze
# libero
# matplotlib
# meshcat
# metaworld
# mujoco
# numba
# opencv-python
# opencv-python-headless
# pandas
# peft
# pyquaternion
# reachy2-sdk
# pettingzoo
# rerun-sdk
# robomimic
# robosuite
# scikit-image
# scipy
# shapely
# teleop
# tensorboard
# tensorboardx
# tifffile
# torchvision
# transformers
# transforms3d
omegaconf==2.3.0
# via hydra-core
opencv-python==4.12.0.88
# via
# gym-pusht
# libero
# reachy2-sdk
# robosuite
# via gym-pusht
opencv-python-headless==4.12.0.88
# via lerobot
orderly-set==5.5.0
@@ -435,63 +337,53 @@ packaging==25.0
# accelerate
# datasets
# huggingface-hub
# hydra-core
# jupytext
# lazy-loader
# lerobot
# matplotlib
# peft
# pytest
# reachy2-sdk
# scikit-image
# tensorboard
# tensorboardx
# transformers
# wandb
pandas==2.3.3
pandas==2.3.1
# via
# datasets
# lerobot
parso==0.8.5
parso==0.8.4
# via jedi
peft==0.17.1
# via lerobot
pettingzoo==1.24.3
# via gymnasium-robotics
pexpect==4.9.0
# via ipython
pfzy==0.3.4
# via inquirerpy
pillow==12.0.0
pillow==11.3.0
# via
# diffusers
# imageio
# lerobot
# matplotlib
# meshcat
# rerun-sdk
# robosuite
# scikit-image
# tensorboard
# torchvision
pin==3.4.0
# via placo
placo==0.9.14
# via lerobot
platformdirs==4.5.0
platformdirs==4.3.8
# via
# jupyter-core
# virtualenv
# wandb
pluggy==1.6.0
# via
# pytest
# pytest-cov
pre-commit==4.3.0
pre-commit==4.2.0
# via lerobot
prompt-toolkit==3.0.52
prompt-toolkit==3.0.51
# via
# inquirerpy
# ipython
propcache==0.4.1
propcache==0.3.2
# via
# aiohttp
# yarl
@@ -500,17 +392,11 @@ protobuf==6.31.0
# dm-control
# grpcio-tools
# lerobot
# reachy2-sdk
# reachy2-sdk-api
# tensorboard
# tensorboardx
# wandb
psutil==7.1.1
psutil==7.0.0
# via
# accelerate
# imageio
# peft
# robomimic
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.3
@@ -519,13 +405,11 @@ pyarrow==21.0.0
# via
# datasets
# rerun-sdk
pycparser==2.23
pycparser==2.22
# via cffi
pydantic==2.12.3
# via
# fastapi
# wandb
pydantic-core==2.41.4
pydantic==2.11.7
# via wandb
pydantic-core==2.33.2
# via pydantic
pygame==2.6.1
# via
@@ -540,42 +424,40 @@ pymunk==6.11.1
# via
# gym-pusht
# lerobot
pyngrok==7.4.1
pyngrok==7.2.12
# via meshcat
pynput==1.8.1
# via
# gym-hil
# lerobot
pyobjc-core==12.0
pyobjc-core==11.1
# via
# pyobjc-framework-applicationservices
# pyobjc-framework-cocoa
# pyobjc-framework-coretext
# pyobjc-framework-quartz
pyobjc-framework-applicationservices==12.0
pyobjc-framework-applicationservices==11.1
# via pynput
pyobjc-framework-cocoa==12.0
pyobjc-framework-cocoa==11.1
# via
# pyobjc-framework-applicationservices
# pyobjc-framework-coretext
# pyobjc-framework-quartz
pyobjc-framework-coretext==12.0
pyobjc-framework-coretext==11.1
# via pyobjc-framework-applicationservices
pyobjc-framework-quartz==12.0
pyobjc-framework-quartz==11.1
# via
# pynput
# pyobjc-framework-applicationservices
# pyobjc-framework-coretext
pyopengl==3.1.10
pyopengl==3.1.9
# via
# dm-control
# mujoco
pyparsing==3.2.5
pyparsing==3.2.3
# via
# dm-control
# matplotlib
pyquaternion==0.9.9
# via reachy2-sdk
pyrealsense2-macosx==2.54.2
# via lerobot
pyserial==3.5
@@ -583,14 +465,12 @@ pyserial==3.5
# dynamixel-sdk
# feetech-servo-sdk
# lerobot
pytest==8.4.2
pytest==8.4.1
# via
# bddl
# lerobot
# pytest-cov
# pytest-timeout
# teleop
pytest-cov==7.0.0
pytest-cov==6.2.1
# via lerobot
pytest-timeout==2.4.0
# via lerobot
@@ -598,73 +478,46 @@ python-dateutil==2.9.0.post0
# via
# matplotlib
# pandas
python-dotenv==1.1.1
# via uvicorn
pytz==2025.2
# via pandas
pyyaml==6.0.3
pyyaml==6.0.2
# via
# accelerate
# datasets
# draccus
# hebi-py
# huggingface-hub
# jupytext
# omegaconf
# peft
# pre-commit
# pyngrok
# pyyaml-include
# timm
# transformers
# uvicorn
# wandb
pyyaml-include==1.4.1
# via draccus
pyzmq==27.1.0
pyzmq==27.0.0
# via
# lerobot
# meshcat
reachy2-sdk==1.0.14
# via lerobot
reachy2-sdk-api==1.0.21
# via reachy2-sdk
referencing==0.37.0
# via
# jsonschema
# jsonschema-specifications
regex==2025.10.23
regex==2025.7.34
# via
# diffusers
# transformers
requests==2.32.5
requests==2.32.4
# via
# datasets
# diffusers
# dm-control
# huggingface-hub
# teleop
# transformers
# wandb
rerun-sdk==0.26.1
rerun-sdk==0.22.1
# via lerobot
rhoban-cmeel-jsoncpp==1.9.4.9
# via placo
robomimic==0.2.0
# via libero
robosuite==1.4.0
# via libero
rpds-py==0.28.0
# via
# jsonschema
# referencing
safetensors==0.6.2
safetensors==0.5.3
# via
# accelerate
# diffusers
# lerobot
# peft
# timm
# transformers
scikit-image==0.25.2
# via
@@ -673,12 +526,10 @@ scikit-image==0.25.2
scipy==1.15.3
# via
# dm-control
# metaworld
# robosuite
# scikit-image
sentry-sdk==2.42.1
sentry-sdk==2.34.1
# via wandb
shapely==2.1.2
shapely==2.1.1
# via gym-pusht
six==1.17.0
# via
@@ -686,106 +537,64 @@ six==1.17.0
# python-dateutil
smmap==5.0.2
# via gitdb
sniffio==1.3.1
# via anyio
stack-data==0.6.3
# via ipython
starlette==0.48.0
# via fastapi
sympy==1.14.0
# via torch
teleop==0.1.2
# via lerobot
tensorboard==2.20.0
# via robomimic
tensorboard-data-server==0.7.2
# via tensorboard
tensorboardx==2.6.4
# via robomimic
termcolor==3.1.0
# via
# lerobot
# robomimic
thop==0.1.1.post2209072238
# via libero
# via lerobot
tifffile==2025.5.10
# via scikit-image
timm==1.0.20
# via lerobot
tokenizers==0.22.1
tokenizers==0.21.4
# via transformers
toml==0.10.2
# via draccus
tomli==2.3.0
tomli==2.2.1
# via
# cmeel
# coverage
# jupytext
# pytest
torch==2.7.1
# via
# accelerate
# lerobot
# peft
# robomimic
# thop
# timm
# torchvision
torchcodec==0.5
# via lerobot
torchvision==0.22.1
# via
# lerobot
# robomimic
# timm
tornado==6.5.2
# via lerobot
tornado==6.5.1
# via meshcat
tqdm==4.67.1
# via
# datasets
# dm-control
# huggingface-hub
# peft
# robomimic
# transformers
traitlets==5.14.3
# via
# ipython
# jupyter-core
# matplotlib-inline
# nbformat
transformers==4.57.1
# via
# lerobot
# libero
# peft
transforms3d==0.4.2
# via teleop
typing-extensions==4.15.0
transformers==4.51.3
# via lerobot
typing-extensions==4.14.1
# via
# aiosignal
# anyio
# etils
# exceptiongroup
# fastapi
# gymnasium
# huggingface-hub
# ipython
# multidict
# pydantic
# pydantic-core
# referencing
# rerun-sdk
# starlette
# torch
# typing-inspect
# typing-inspection
# uvicorn
# virtualenv
# wandb
typing-inspect==0.9.0
# via draccus
typing-inspection==0.4.2
typing-inspection==0.4.1
# via pydantic
tzdata==2025.2
# via pandas
@@ -795,36 +604,22 @@ urllib3==2.5.0
# via
# requests
# sentry-sdk
uvicorn[standard]==0.38.0
# via teleop
uvloop==0.22.1
# via uvicorn
virtualenv==20.35.3
virtualenv==20.32.0
# via pre-commit
wandb==0.21.4
# via
# lerobot
# libero
watchfiles==1.1.1
# via uvicorn
wcwidth==0.2.14
wandb==0.21.0
# via lerobot
wcwidth==0.2.13
# via prompt-toolkit
websocket-client==1.9.0
# via teleop
websockets==15.0.1
# via uvicorn
werkzeug==3.1.3
# via tensorboard
wrapt==2.0.0
# via flask
wrapt==1.17.2
# via dm-tree
xxhash==3.6.0
xxhash==3.5.0
# via datasets
yarl==1.22.0
yarl==1.20.1
# via aiohttp
zipp==3.23.0
# via
# etils
# importlib-metadata
# via importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
# setuptools

View File

@@ -13,62 +13,47 @@ absl-py==2.3.1
# dm-tree
# labmaze
# mujoco
# tensorboard
accelerate==1.11.0
# via
# lerobot
# peft
accelerate==1.9.0
# via lerobot
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.13.1
aiohttp==3.12.15
# via fsspec
aiosignal==1.4.0
# via aiohttp
annotated-types==0.7.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via
# hydra-core
# omegaconf
anyio==4.11.0
# via
# starlette
# watchfiles
asttokens==3.0.0
# via stack-data
async-timeout==5.0.1
# via aiohttp
attrs==25.4.0
attrs==25.3.0
# via
# aiohttp
# dm-tree
# jsonlines
# jsonschema
# referencing
# rerun-sdk
av==15.1.0
av==15.0.0
# via lerobot
bddl==1.0.1
# via libero
certifi==2025.10.5
blinker==1.9.0
# via flask
certifi==2025.7.14
# via
# requests
# sentry-sdk
cffi==2.0.0
cffi==1.17.1
# via pymunk
cfgv==3.4.0
# via pre-commit
charset-normalizer==3.4.4
charset-normalizer==3.4.2
# via requests
click==8.3.0
click==8.2.1
# via
# uvicorn
# flask
# wandb
cloudpickle==3.1.1
# via
# gymnasium
# libero
cmake==4.1.0
# via gymnasium
cmake==4.0.3
# via lerobot
cmeel==0.57.3
# via
@@ -110,29 +95,27 @@ coal-library==3.0.1
# via pin
contourpy==1.3.2
# via matplotlib
coverage[toml]==7.11.0
coverage[toml]==7.10.1
# via pytest-cov
cycler==0.12.1
# via matplotlib
datasets==4.1.1
datasets==3.6.0
# via lerobot
debugpy==1.8.17
debugpy==1.8.15
# via lerobot
decorator==5.2.1
# via ipython
decord==0.6.0
deepdiff==8.5.0
# via lerobot
deepdiff==8.6.1
diffusers==0.34.0
# via lerobot
diffusers==0.35.2
# via lerobot
dill==0.4.0
dill==0.3.8
# via
# datasets
# multiprocess
distlib==0.4.0
# via virtualenv
dm-control==1.0.34
dm-control==1.0.14
# via gym-aloha
dm-env==1.6
# via dm-control
@@ -140,48 +123,31 @@ dm-tree==0.1.9
# via
# dm-control
# dm-env
# lerobot
docopt==0.6.2
# via num2words
draccus==0.10.0
# via lerobot
dynamixel-sdk==3.8.4
dynamixel-sdk==3.7.31
# via lerobot
easydict==1.13
# via libero
egl-probe @ git+https://github.com/huggingface/egl_probe.git
# via
# libero
# robomimic
eigenpy==3.10.3
# via coal-library
einops==0.8.1
# via
# flash-attn
# lerobot
# libero
# via lerobot
eiquadprog==1.2.9
# via placo
etils[epath,epy]==1.13.0
# via mujoco
evdev==1.9.2
# via pynput
exceptiongroup==1.3.0
# via
# anyio
# ipython
# pytest
executing==2.2.1
executing==2.2.0
# via stack-data
farama-notifications==0.0.4
# via gymnasium
fastapi==0.119.1
# via teleop
fastjsonschema==2.21.2
# via nbformat
feetech-servo-sdk==1.0.0
# via lerobot
filelock==3.20.0
filelock==3.18.0
# via
# datasets
# diffusers
@@ -189,27 +155,24 @@ filelock==3.20.0
# torch
# transformers
# virtualenv
flash-attn==2.8.3
flask==3.1.1
# via lerobot
fonttools==4.60.1
fonttools==4.59.0
# via matplotlib
frozenlist==1.8.0
frozenlist==1.7.0
# via
# aiohttp
# aiosignal
fsspec[http]==2025.9.0
fsspec[http]==2025.3.0
# via
# datasets
# etils
# huggingface-hub
# torch
future==1.0.0
# via libero
gitdb==4.0.12
# via gitpython
gitpython==3.1.45
# via wandb
glfw==2.10.0
glfw==2.9.0
# via
# dm-control
# mujoco
@@ -217,79 +180,61 @@ grpcio==1.73.1
# via
# grpcio-tools
# lerobot
# reachy2-sdk
# reachy2-sdk-api
# tensorboard
grpcio-tools==1.73.1
# via
# lerobot
# reachy2-sdk-api
gym-aloha==0.1.3
# via lerobot
gym-hil==0.1.13
gym-aloha==0.1.1
# via lerobot
gym-pusht==0.1.6
gym-hil==0.1.10
# via lerobot
gymnasium==1.2.1
gym-pusht==0.1.5
# via lerobot
gym-xarm==0.1.1
# via lerobot
gymnasium==0.29.1
# via
# gym-aloha
# gym-hil
# gym-pusht
# gym-xarm
# gymnasium-robotics
# lerobot
# libero
# metaworld
h11==0.16.0
# via uvicorn
h5py==3.15.1
# via robomimic
hebi-py==2.11.0
# via lerobot
# pettingzoo
gymnasium-robotics==1.2.4
# via gym-xarm
hf-transfer==0.1.9
# via huggingface-hub
hf-xet==1.1.10
hf-xet==1.1.5
# via huggingface-hub
hidapi==0.14.0.post4
# via
# gym-hil
# lerobot
httptools==0.7.1
# via uvicorn
huggingface-hub[cli,hf-transfer]==0.35.3
huggingface-hub[cli,hf-transfer]==0.34.3
# via
# accelerate
# datasets
# diffusers
# lerobot
# peft
# timm
# tokenizers
# transformers
hydra-core==1.3.2
# via libero
identify==2.6.15
identify==2.6.12
# via pre-commit
idna==3.11
idna==3.10
# via
# anyio
# requests
# yarl
imageio[ffmpeg]==2.37.0
# via
# gym-aloha
# gym-hil
# gymnasium-robotics
# lerobot
# metaworld
# robomimic
# scikit-image
imageio-ffmpeg==0.6.0
# via
# imageio
# robomimic
# via imageio
importlib-metadata==8.7.0
# via diffusers
importlib-resources==6.5.2
# via etils
iniconfig==2.3.0
iniconfig==2.1.0
# via pytest
inquirerpy==0.3.4
# via huggingface-hub
@@ -297,71 +242,50 @@ ipython==8.37.0
# via meshcat
ischedule==1.2.7
# via placo
itsdangerous==2.2.0
# via flask
jedi==0.19.2
# via ipython
jinja2==3.1.6
# via torch
# via
# flask
# gymnasium-robotics
# torch
jsonlines==4.0.0
# via lerobot
jsonschema==4.25.1
# via nbformat
jsonschema-specifications==2025.9.1
# via jsonschema
jupyter-core==5.9.1
# via nbformat
jupytext==1.18.1
# via bddl
kiwisolver==1.4.9
kiwisolver==1.4.8
# via matplotlib
labmaze==1.0.6
# via dm-control
lazy-loader==0.4
# via scikit-image
libero @ git+https://github.com/huggingface/lerobot-libero.git@main
# via lerobot
llvmlite==0.45.1
# via numba
lxml==6.0.2
lxml==6.0.0
# via dm-control
markdown==3.9
# via tensorboard
markdown-it-py==4.0.0
# via
# jupytext
# mdit-py-plugins
markupsafe==3.0.3
markupsafe==3.0.2
# via
# flask
# jinja2
# werkzeug
matplotlib==3.10.7
# via
# lerobot
# libero
matplotlib-inline==0.2.1
matplotlib==3.10.5
# via lerobot
matplotlib-inline==0.1.7
# via ipython
mdit-py-plugins==0.5.0
# via jupytext
mdurl==0.1.2
# via markdown-it-py
mergedeep==1.3.4
# via draccus
meshcat==0.3.2
# via placo
metaworld==3.0.0
# via lerobot
mock-serial==0.0.1
# via lerobot
mpmath==1.3.0
# via sympy
mujoco==3.3.7
mujoco==2.3.7
# via
# dm-control
# gym-aloha
# gym-hil
# libero
# metaworld
# robosuite
multidict==6.7.0
# gym-xarm
# gymnasium-robotics
multidict==6.6.3
# via
# aiohttp
# yarl
@@ -369,63 +293,42 @@ multiprocess==0.70.16
# via datasets
mypy-extensions==1.1.0
# via typing-inspect
nbformat==5.10.4
# via jupytext
networkx==3.4.2
# via
# bddl
# scikit-image
# torch
ninja==1.13.0
# via lerobot
nodeenv==1.9.1
# via pre-commit
num2words==0.5.14
# via lerobot
numba==0.62.1
# via robosuite
numpy==2.2.6
# via
# accelerate
# bddl
# cmeel-boost
# contourpy
# datasets
# decord
# diffusers
# dm-control
# dm-env
# dm-tree
# gymnasium
# h5py
# hebi-py
# gymnasium-robotics
# imageio
# labmaze
# libero
# matplotlib
# meshcat
# metaworld
# mujoco
# numba
# opencv-python
# opencv-python-headless
# pandas
# peft
# pyquaternion
# reachy2-sdk
# pettingzoo
# rerun-sdk
# robomimic
# robosuite
# scikit-image
# scipy
# shapely
# teleop
# tensorboard
# tensorboardx
# tifffile
# torchvision
# transformers
# transforms3d
nvidia-cublas-cu12==12.6.4.1
# via
# nvidia-cudnn-cu12
@@ -463,14 +366,8 @@ nvidia-nvjitlink-cu12==12.6.85
# torch
nvidia-nvtx-cu12==12.6.77
# via torch
omegaconf==2.3.0
# via hydra-core
opencv-python==4.12.0.88
# via
# gym-pusht
# libero
# reachy2-sdk
# robosuite
# via gym-pusht
opencv-python-headless==4.12.0.88
# via lerobot
orderly-set==5.5.0
@@ -480,63 +377,53 @@ packaging==25.0
# accelerate
# datasets
# huggingface-hub
# hydra-core
# jupytext
# lazy-loader
# lerobot
# matplotlib
# peft
# pytest
# reachy2-sdk
# scikit-image
# tensorboard
# tensorboardx
# transformers
# wandb
pandas==2.3.3
pandas==2.3.1
# via
# datasets
# lerobot
parso==0.8.5
parso==0.8.4
# via jedi
peft==0.17.1
# via lerobot
pettingzoo==1.24.3
# via gymnasium-robotics
pexpect==4.9.0
# via ipython
pfzy==0.3.4
# via inquirerpy
pillow==12.0.0
pillow==11.3.0
# via
# diffusers
# imageio
# lerobot
# matplotlib
# meshcat
# rerun-sdk
# robosuite
# scikit-image
# tensorboard
# torchvision
pin==3.4.0
# via placo
placo==0.9.14
# via lerobot
platformdirs==4.5.0
platformdirs==4.3.8
# via
# jupyter-core
# virtualenv
# wandb
pluggy==1.6.0
# via
# pytest
# pytest-cov
pre-commit==4.3.0
pre-commit==4.2.0
# via lerobot
prompt-toolkit==3.0.52
prompt-toolkit==3.0.51
# via
# inquirerpy
# ipython
propcache==0.4.1
propcache==0.3.2
# via
# aiohttp
# yarl
@@ -545,17 +432,11 @@ protobuf==6.31.0
# dm-control
# grpcio-tools
# lerobot
# reachy2-sdk
# reachy2-sdk-api
# tensorboard
# tensorboardx
# wandb
psutil==7.1.1
psutil==7.0.0
# via
# accelerate
# imageio
# peft
# robomimic
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.3
@@ -564,13 +445,11 @@ pyarrow==21.0.0
# via
# datasets
# rerun-sdk
pycparser==2.23
pycparser==2.22
# via cffi
pydantic==2.12.3
# via
# fastapi
# wandb
pydantic-core==2.41.4
pydantic==2.11.7
# via wandb
pydantic-core==2.33.2
# via pydantic
pygame==2.6.1
# via
@@ -585,22 +464,20 @@ pymunk==6.11.1
# via
# gym-pusht
# lerobot
pyngrok==7.4.1
pyngrok==7.2.12
# via meshcat
pynput==1.8.1
# via
# gym-hil
# lerobot
pyopengl==3.1.10
pyopengl==3.1.9
# via
# dm-control
# mujoco
pyparsing==3.2.5
pyparsing==3.2.3
# via
# dm-control
# matplotlib
pyquaternion==0.9.9
# via reachy2-sdk
pyrealsense2==2.56.5.9235
# via lerobot
pyserial==3.5
@@ -608,14 +485,12 @@ pyserial==3.5
# dynamixel-sdk
# feetech-servo-sdk
# lerobot
pytest==8.4.2
pytest==8.4.1
# via
# bddl
# lerobot
# pytest-cov
# pytest-timeout
# teleop
pytest-cov==7.0.0
pytest-cov==6.2.1
# via lerobot
pytest-timeout==2.4.0
# via lerobot
@@ -623,75 +498,48 @@ python-dateutil==2.9.0.post0
# via
# matplotlib
# pandas
python-dotenv==1.1.1
# via uvicorn
python-xlib==0.33
# via pynput
pytz==2025.2
# via pandas
pyyaml==6.0.3
pyyaml==6.0.2
# via
# accelerate
# datasets
# draccus
# hebi-py
# huggingface-hub
# jupytext
# omegaconf
# peft
# pre-commit
# pyngrok
# pyyaml-include
# timm
# transformers
# uvicorn
# wandb
pyyaml-include==1.4.1
# via draccus
pyzmq==27.1.0
pyzmq==27.0.0
# via
# lerobot
# meshcat
reachy2-sdk==1.0.14
# via lerobot
reachy2-sdk-api==1.0.21
# via reachy2-sdk
referencing==0.37.0
# via
# jsonschema
# jsonschema-specifications
regex==2025.10.23
regex==2025.7.34
# via
# diffusers
# transformers
requests==2.32.5
requests==2.32.4
# via
# datasets
# diffusers
# dm-control
# huggingface-hub
# teleop
# transformers
# wandb
rerun-sdk==0.26.1
rerun-sdk==0.22.1
# via lerobot
rhoban-cmeel-jsoncpp==1.9.4.9
# via placo
robomimic==0.2.0
# via libero
robosuite==1.4.0
# via libero
rpds-py==0.28.0
# via
# jsonschema
# referencing
safetensors==0.6.2
safetensors==0.5.3
# via
# accelerate
# diffusers
# lerobot
# peft
# timm
# transformers
scikit-image==0.25.2
# via
@@ -700,12 +548,10 @@ scikit-image==0.25.2
scipy==1.15.3
# via
# dm-control
# metaworld
# robosuite
# scikit-image
sentry-sdk==2.42.1
sentry-sdk==2.34.1
# via wandb
shapely==2.1.2
shapely==2.1.1
# via gym-pusht
six==1.17.0
# via
@@ -714,109 +560,66 @@ six==1.17.0
# python-xlib
smmap==5.0.2
# via gitdb
sniffio==1.3.1
# via anyio
stack-data==0.6.3
# via ipython
starlette==0.48.0
# via fastapi
sympy==1.14.0
# via torch
teleop==0.1.2
# via lerobot
tensorboard==2.20.0
# via robomimic
tensorboard-data-server==0.7.2
# via tensorboard
tensorboardx==2.6.4
# via robomimic
termcolor==3.1.0
# via
# lerobot
# robomimic
thop==0.1.1.post2209072238
# via libero
# via lerobot
tifffile==2025.5.10
# via scikit-image
timm==1.0.20
# via lerobot
tokenizers==0.22.1
tokenizers==0.21.4
# via transformers
toml==0.10.2
# via draccus
tomli==2.3.0
tomli==2.2.1
# via
# cmeel
# coverage
# jupytext
# pytest
torch==2.7.1
# via
# accelerate
# flash-attn
# lerobot
# peft
# robomimic
# thop
# timm
# torchvision
torchcodec==0.5
# via lerobot
torchvision==0.22.1
# via
# lerobot
# robomimic
# timm
tornado==6.5.2
# via lerobot
tornado==6.5.1
# via meshcat
tqdm==4.67.1
# via
# datasets
# dm-control
# huggingface-hub
# peft
# robomimic
# transformers
traitlets==5.14.3
# via
# ipython
# jupyter-core
# matplotlib-inline
# nbformat
transformers==4.57.1
# via
# lerobot
# libero
# peft
transforms3d==0.4.2
# via teleop
transformers==4.51.3
# via lerobot
triton==3.3.1
# via torch
typing-extensions==4.15.0
typing-extensions==4.14.1
# via
# aiosignal
# anyio
# etils
# exceptiongroup
# fastapi
# gymnasium
# huggingface-hub
# ipython
# multidict
# pydantic
# pydantic-core
# referencing
# rerun-sdk
# starlette
# torch
# typing-inspect
# typing-inspection
# uvicorn
# virtualenv
# wandb
typing-inspect==0.9.0
# via draccus
typing-inspection==0.4.2
typing-inspection==0.4.1
# via pydantic
tzdata==2025.2
# via pandas
@@ -826,36 +629,22 @@ urllib3==2.5.0
# via
# requests
# sentry-sdk
uvicorn[standard]==0.38.0
# via teleop
uvloop==0.22.1
# via uvicorn
virtualenv==20.35.3
virtualenv==20.32.0
# via pre-commit
wandb==0.21.4
# via
# lerobot
# libero
watchfiles==1.1.1
# via uvicorn
wcwidth==0.2.14
wandb==0.21.0
# via lerobot
wcwidth==0.2.13
# via prompt-toolkit
websocket-client==1.9.0
# via teleop
websockets==15.0.1
# via uvicorn
werkzeug==3.1.3
# via tensorboard
wrapt==2.0.0
# via flask
wrapt==1.17.2
# via dm-tree
xxhash==3.6.0
xxhash==3.5.0
# via datasets
yarl==1.22.0
yarl==1.20.1
# via aiohttp
zipp==3.23.0
# via
# etils
# importlib-metadata
# via importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
# setuptools

View File

@@ -1,9 +1,9 @@
# requirements.in
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 26.0.1 25A362 arm64).
# Darwin MacBook-Pro.local 25.0.0 Darwin Kernel Version 25.0.0: Wed Sep 17 21:42:08 PDT 2025; root:xnu-12377.1.9~141/RELEASE_ARM64_T8132 arm64
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 15.5 24F74 arm64).
# Darwin MacBook-Pro.local 24.5.0 Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:43 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8132 arm64
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.3 LTS x86_64).
# Linux mlerobot-linux 6.14.0-33-generic #33~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 17:02:30 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.2 LTS x86_64).
# Linux mlerobot-linux 6.14.0-27-generic #27~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Jul 22 17:38:49 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
-e .[all]

View File

@@ -1,761 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Inference script for SARM (Stage-Aware Reward Model).
This script loads a trained SARM model and runs inference on a dataset episode,
generating visualizations of the predicted task stages and progress over time.
Example usage:
python scripts/visualize_sarm_predictions.py \
--model-id username/sarm-model \
--dataset-repo lerobot/aloha_sim_insertion_human \
--episode-index 0 \
--output-dir outputs/sarm_viz \
--task-description "insert the peg into the socket"
"""
import argparse
import json
import logging
from pathlib import Path
from typing import Optional
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
from lerobot.policies.sarm.sarm_utils import (
pad_state_to_max_dim,
compute_tau,
compute_cumulative_progress_batch,
)
from lerobot.datasets.utils import load_stats
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Run SARM inference and visualize predictions")
# Model arguments
parser.add_argument(
"--model-id",
type=str,
required=True,
help="HuggingFace model ID or local path to trained SARM model"
)
# Dataset arguments
parser.add_argument(
"--dataset-repo",
type=str,
required=True,
help="HuggingFace dataset repository ID (e.g., lerobot/aloha_sim_insertion_human)"
)
parser.add_argument(
"--episode-index",
type=int,
default=0,
help="Index of the episode to visualize (default: 0)"
)
parser.add_argument(
"--task-description",
type=str,
default="perform the task",
help="Task description for the reward model (default: 'perform the task')"
)
# Output arguments
parser.add_argument(
"--output-dir",
type=str,
default="outputs/sarm_inference",
help="Directory to save visualization outputs (default: outputs/sarm_inference)"
)
parser.add_argument(
"--image-key",
type=str,
default=None,
help="Key for images in dataset (e.g., observation.images.image). If not specified, uses model config's image_key"
)
parser.add_argument(
"--state-key",
type=str,
default=None,
help="Key for joint states in dataset. If None, auto-detects from dataset"
)
# Visualization options
parser.add_argument(
"--show-frames",
action="store_true",
help="Include sample frames in the visualization"
)
parser.add_argument(
"--num-sample-frames",
type=int,
default=8,
help="Number of sample frames to show (default: 8)"
)
parser.add_argument(
"--figsize",
type=int,
nargs=2,
default=[14, 8],
help="Figure size as width height (default: 14 8)"
)
# Device
parser.add_argument(
"--device",
type=str,
default=None,
help="Device to run inference on (cuda/cpu, default: auto-detect)"
)
return parser.parse_args()
def load_episode_data(
dataset: LeRobotDataset,
episode_index: int,
image_key: str,
state_key: str | None = None
) -> tuple[np.ndarray, np.ndarray, int, int, str]:
"""
Load all frames and states from a specific episode.
Args:
dataset: LeRobotDataset instance
episode_index: Index of the episode to load
image_key: Key for accessing images in the dataset
state_key: Key for accessing joint states (auto-detected if None)
Returns:
Tuple of (frames, states, start_index, end_index, task_description)
"""
# Get episode boundaries
episode_data = dataset.meta.episodes
start_idx = episode_data["dataset_from_index"][episode_index]
end_idx = episode_data["dataset_to_index"][episode_index]
logger.info(f"Loading episode {episode_index}: frames {start_idx} to {end_idx} ({end_idx - start_idx} frames)")
# Auto-detect state key if not provided
if state_key is None:
first_item = dataset[start_idx]
state_keys = [k for k in first_item.keys() if 'state' in k.lower() or 'qpos' in k.lower()]
if state_keys:
state_key = state_keys[0]
logger.info(f"Auto-detected state key: {state_key}")
# Get task description from the dataset if available
task_description = None
first_item = dataset[start_idx]
if "task" in first_item:
task_description = first_item["task"]
logger.info(f"✓ Extracted task from episode {episode_index}: '{task_description}'")
# Load all frames and states from the episode
frames = []
states = []
for idx in tqdm(range(start_idx, end_idx), desc="Loading frames"):
item = dataset[idx]
# Get image
img = item[image_key]
# Convert to numpy if needed
if isinstance(img, torch.Tensor):
img = img.cpu().numpy()
# Handle different image formats (C, H, W) or (H, W, C)
if img.shape[0] in [1, 3]: # Channel first
img = np.transpose(img, (1, 2, 0))
# Convert to uint8 if needed
if img.dtype != np.uint8:
if img.max() <= 1.0:
img = (img * 255).astype(np.uint8)
else:
img = img.astype(np.uint8)
frames.append(img)
# Get state if available
if state_key and state_key in item:
state = item[state_key]
if isinstance(state, torch.Tensor):
state = state.cpu().numpy()
states.append(state)
frames = np.array(frames)
states = np.array(states) if states else None
logger.info(f"Loaded {len(frames)} frames with shape {frames[0].shape}")
if states is not None:
logger.info(f"Loaded states with shape {states.shape}")
return frames, states, start_idx, end_idx, task_description
@torch.no_grad()
def run_inference(
model: SARMRewardModel,
frames: np.ndarray,
states: Optional[np.ndarray],
task_description: str,
dataset_stats: dict | None = None,
state_key: str = "observation.state",
batch_size: int = 32
) -> tuple[np.ndarray, np.ndarray]:
"""
Run SARM inference on video frames and joint states.
(per SARM paper Section A.4):
- Frame 0: Initial frame of the episode (frame 0)
- Frames 1-8: 8 consecutive frames with frame_gap spacing ending at current frame t
Pattern: [frame_0, t-(7*gap), t-(6*gap), ..., t-gap, t]
Args:
model: SARM model
frames: Video frames (num_frames, H, W, C) - all frames from ONE episode
states: Joint states (num_frames, state_dim)
task_description: Task description text
dataset_stats: Dataset statistics for state normalization (same as training)
state_key: Key for state in dataset_stats
batch_size: Batch size for processing slices
Returns:
Tuple of (progress_predictions, stage_predictions)
- progress_predictions: (num_frames,)
- stage_predictions: (num_frames, num_stages)
"""
logger.info("Encoding video frames with CLIP...")
video_embeddings = model.encode_images(frames)
logger.info("Encoding task description with CLIP...")
text_embedding = model.encode_text(task_description)
# Get config values
num_frames_model = model.config.num_frames # 9
frame_gap = model.config.frame_gap # 30
logger.info("Creating video slices (SARM paper: initial frame + 8 consecutive)...")
# Convert to tensors
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
if states is not None:
state_embeddings = torch.tensor(states, dtype=torch.float32)
# Normalize states using dataset stats (same as training processor)
if dataset_stats is not None and state_key in dataset_stats:
mean = torch.tensor(dataset_stats[state_key]["mean"], dtype=torch.float32)
std = torch.tensor(dataset_stats[state_key]["std"], dtype=torch.float32)
state_embeddings = (state_embeddings - mean) / (std + 1e-8)
logger.info(f"✓ Applied MEAN_STD normalization to states using {state_key}")
else:
logger.warning("⚠ No dataset_stats provided - states not normalized (may differ from training)")
else:
state_embeddings = None
video_slices = []
state_slices = []
for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
# Compute frame indices using symmetric bidirectional pattern:
# [initial (0), t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
# Boundary handling: clamp to [0, last_valid]
deltas = model.config.observation_delta_indices
last_valid = len(video_embeddings) - 1
frame_indices = []
for delta in deltas:
idx = current_frame + delta
idx = max(0, min(idx, last_valid)) # Clamp to valid range
frame_indices.append(idx)
video_slice = video_embeddings[frame_indices]
video_slices.append(video_slice)
if state_embeddings is not None:
state_slice = state_embeddings[frame_indices]
state_slices.append(state_slice)
video_slices = torch.stack(video_slices) # (num_frames, num_frames_model, 512)
if state_embeddings is not None:
state_slices = torch.stack(state_slices) # (num_frames, num_frames_model, state_dim)
# Pad states to max_state_dim (same as training processor)
state_slices = pad_state_to_max_dim(state_slices, model.config.max_state_dim)
else:
state_slices = None
logger.info("Running SARM inference on all slices...")
# Process in batches
all_progress = []
all_stages = []
for i in tqdm(range(0, len(video_slices), batch_size), desc="Inference"):
batch_video = video_slices[i:i + batch_size].to(model.device)
batch_states = state_slices[i:i + batch_size].to(model.device) if state_slices is not None else None
batch_size_actual = batch_video.shape[0]
# Replicate text embedding for batch
batch_text = text_embedding.unsqueeze(0).repeat(batch_size_actual, 1).to(model.device)
# Get predictions
stage_logits, stage_probs, progress_preds = model.sarm_transformer(
batch_video, batch_text, batch_states
)
# Extract predictions at the "current frame" position
# With symmetric pattern [initial, t-4g, t-3g, t-2g, t-g, t, t+g, t+2g, t+3g],
# the current frame is at position 5 (0-indexed)
current_frame_idx = 5
batch_progress = progress_preds[:, current_frame_idx, 0].cpu().numpy()
batch_stages = stage_probs[:, current_frame_idx, :].cpu().numpy()
all_progress.extend(batch_progress)
all_stages.extend(batch_stages)
return np.array(all_progress), np.array(all_stages)
def compute_ground_truth_progress(
dataset: LeRobotDataset,
episode_index: int,
temporal_proportions: dict[str, float],
subtask_names_ordered: list[str],
) -> tuple[np.ndarray, np.ndarray] | tuple[None, None]:
"""
Compute ground truth progress and stage labels for an episode using annotations.
Uses SARM Paper Formula (2):
y_t = P_{k-1} + ᾱ_k × τ_t
where:
- τ_t = (t - s_k) / (e_k - s_k) is within-subtask progress
- P_{k-1} is cumulative prior (sum of previous subtask proportions)
- ᾱ_k is the temporal proportion for subtask k
Args:
dataset: LeRobotDataset instance
episode_index: Index of the episode
temporal_proportions: Dict mapping subtask name to proportion
subtask_names_ordered: Ordered list of subtask names (for consistent stage indexing)
Returns:
Tuple of (ground_truth_progress, ground_truth_stages) arrays, or (None, None) if no annotations
"""
# Load episode metadata
episodes_df = dataset.meta.episodes.to_pandas()
# Check if annotations exist
if "subtask_names" not in episodes_df.columns:
logger.warning("No subtask_names column found in episodes metadata")
return None, None
ep_subtask_names = episodes_df.loc[episode_index, "subtask_names"]
if ep_subtask_names is None or (isinstance(ep_subtask_names, float) and pd.isna(ep_subtask_names)):
logger.warning(f"No annotations found for episode {episode_index}")
return None, None
subtask_start_frames = episodes_df.loc[episode_index, "subtask_start_frames"]
subtask_end_frames = episodes_df.loc[episode_index, "subtask_end_frames"]
# Get episode boundaries
ep_start = dataset.meta.episodes["dataset_from_index"][episode_index]
ep_end = dataset.meta.episodes["dataset_to_index"][episode_index]
num_frames = ep_end - ep_start
# Get temporal proportions as ordered list
temporal_proportions_list = [
temporal_proportions.get(name, 0.0) for name in subtask_names_ordered
]
logger.info(f"Computing ground truth for {num_frames} frames using {len(ep_subtask_names)} annotated subtasks")
logger.info(f"Subtask names in episode: {ep_subtask_names}")
logger.info(f"Subtask start frames: {subtask_start_frames}")
logger.info(f"Subtask end frames: {subtask_end_frames}")
logger.info(f"Temporal proportions (ordered): {dict(zip(subtask_names_ordered, temporal_proportions_list))}")
# Compute ground truth for each frame
gt_progress = np.zeros(num_frames)
gt_stages = np.zeros(num_frames, dtype=np.int32)
for frame_rel in range(num_frames):
# Find which subtask this frame belongs to
found = False
for j, (name, start_frame, end_frame) in enumerate(zip(ep_subtask_names, subtask_start_frames, subtask_end_frames)):
if frame_rel >= start_frame and frame_rel <= end_frame:
# Found the subtask - get its global index
stage_idx = subtask_names_ordered.index(name) if name in subtask_names_ordered else 0
# Compute τ_t using utility function
tau = compute_tau(frame_rel, start_frame, end_frame)
# Compute cumulative progress using utility function
progress = compute_cumulative_progress_batch(tau, stage_idx, temporal_proportions_list)
gt_progress[frame_rel] = progress
gt_stages[frame_rel] = stage_idx
found = True
break
if not found:
# Handle frames outside annotated subtasks
if frame_rel < subtask_start_frames[0]:
gt_progress[frame_rel] = 0.0
gt_stages[frame_rel] = 0
elif frame_rel > subtask_end_frames[-1]:
gt_progress[frame_rel] = 1.0
gt_stages[frame_rel] = len(subtask_names_ordered) - 1
else:
# Between subtasks - find previous subtask
for j in range(len(ep_subtask_names) - 1):
if frame_rel > subtask_end_frames[j] and frame_rel < subtask_start_frames[j + 1]:
name = ep_subtask_names[j]
stage_idx = subtask_names_ordered.index(name) if name in subtask_names_ordered else j
progress = compute_cumulative_progress_batch(1.0, stage_idx, temporal_proportions_list)
gt_progress[frame_rel] = progress
gt_stages[frame_rel] = stage_idx
break
logger.info(f"✓ Ground truth computed: final={gt_progress[-1]:.3f}, max={gt_progress.max():.3f}")
return gt_progress, gt_stages
def visualize_predictions(
frames: np.ndarray,
progress_predictions: np.ndarray,
stage_predictions: np.ndarray,
task_description: str,
output_path: Path,
num_sample_frames: int = 8,
figsize: tuple = (14, 8),
subtask_names: list[str] | None = None,
temporal_proportions: dict[str, float] | None = None,
ground_truth_progress: np.ndarray | None = None,
ground_truth_stages: np.ndarray | None = None,
):
"""
Create visualization of SARM predictions with optional ground truth comparison.
Args:
frames: Video frames (num_frames, H, W, C)
progress_predictions: Progress predictions (num_frames,)
stage_predictions: Stage probabilities (num_frames, num_stages)
task_description: Task description
output_path: Path to save the figure
num_sample_frames: Number of frames to show
figsize: Figure size (width, height)
subtask_names: Optional list of subtask names for labeling
temporal_proportions: Optional dict of temporal proportions for each subtask
ground_truth_progress: Optional ground truth progress array (num_frames,)
ground_truth_stages: Optional ground truth stage indices array (num_frames,)
"""
num_stages = stage_predictions.shape[1]
stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages))
# Use subtask names if available, otherwise use generic labels
if subtask_names is not None and len(subtask_names) == num_stages:
stage_labels = subtask_names
else:
stage_labels = [f'Stage {i+1}' for i in range(num_stages)]
# Create figure with progress plot, stage plot, and sample frames
fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3)
ax_progress = fig.add_subplot(gs[0])
ax_stages = fig.add_subplot(gs[1], sharex=ax_progress)
ax_frames = fig.add_subplot(gs[2])
frame_indices = np.arange(len(progress_predictions))
# Plot 1: Progress over time
ax_progress.plot(frame_indices, progress_predictions, linewidth=2, color='#2E86AB', label='Predicted Progress')
ax_progress.fill_between(frame_indices, 0, progress_predictions, alpha=0.3, color='#2E86AB')
# Plot ground truth if available
if ground_truth_progress is not None:
ax_progress.plot(frame_indices, ground_truth_progress, linewidth=2, color='#28A745',
linestyle='--', label='Ground Truth Progress')
ax_progress.fill_between(frame_indices, 0, ground_truth_progress, alpha=0.15, color='#28A745')
ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
ax_progress.set_ylabel('Task Progress', fontsize=12)
ax_progress.set_title(f'Task: "{task_description}"', fontsize=14, fontweight='bold')
ax_progress.grid(True, alpha=0.3)
ax_progress.set_ylim(-0.05, 1.1)
ax_progress.legend(loc='upper left')
# Add statistics box
stats_text = (
f'Frames: {len(progress_predictions)}\n'
f'Final Progress: {progress_predictions[-1]:.3f}\n'
f'Max Progress: {progress_predictions.max():.3f}\n'
f'Mean Progress: {progress_predictions.mean():.3f}'
)
if ground_truth_progress is not None:
mse = np.mean((progress_predictions - ground_truth_progress) ** 2)
stats_text += f'\nMSE vs GT: {mse:.4f}'
stats_text += f'\nGT Final: {ground_truth_progress[-1]:.3f}'
ax_progress.text(0.98, 0.02, stats_text, transform=ax_progress.transAxes,
fontsize=10, verticalalignment='bottom', horizontalalignment='right',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
# Plot 2: Stage predictions (stacked area plot)
ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)],
colors=stage_colors, alpha=0.8, labels=stage_labels)
# Plot ground truth stage as vertical bands or markers
if ground_truth_stages is not None:
# Find stage transition points in ground truth
stage_changes = np.where(np.diff(ground_truth_stages) != 0)[0] + 1
for change_idx in stage_changes:
ax_stages.axvline(x=change_idx, color='black', linestyle='-', alpha=0.7, linewidth=1.5)
ax_progress.axvline(x=change_idx, color='black', linestyle='-', alpha=0.3, linewidth=1)
# Add small markers at bottom showing GT stage
gt_stage_normalized = ground_truth_stages / max(num_stages - 1, 1)
ax_stages.scatter(frame_indices[::30], np.zeros(len(frame_indices[::30])) + 0.02,
c=[stage_colors[s] for s in ground_truth_stages[::30]],
s=20, marker='|', alpha=0.8, label='GT Stage Markers')
ax_stages.set_xlabel('Frame Index', fontsize=12)
ax_stages.set_ylabel('Stage Probability', fontsize=12)
ax_stages.set_ylim(0, 1)
ax_stages.grid(True, alpha=0.3)
# Adjust legend based on number of stages and label lengths
if num_stages <= 5:
ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8)
else:
ax_stages.legend(loc='upper left', ncol=3, fontsize=7)
# Add vertical lines and labels for expected stage transitions (if temporal proportions available)
if temporal_proportions is not None and subtask_names is not None:
cumulative_progress = 0.0
for i, name in enumerate(stage_labels):
if name in temporal_proportions:
# Find approximate frame where this stage should end
stage_end_progress = cumulative_progress + temporal_proportions[name]
# Find frame index closest to this progress
progress_diffs = np.abs(progress_predictions - stage_end_progress)
stage_end_frame = np.argmin(progress_diffs)
# Draw vertical line
ax_progress.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
ax_stages.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
cumulative_progress = stage_end_progress
# Plot 3: Sample frames (if requested)
frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int)
ax_frames.axis('off')
# Create grid for frames
frame_height = frames[0].shape[0]
frame_width = frames[0].shape[1]
combined_width = frame_width * num_sample_frames
combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8)
for i, frame_idx in enumerate(frame_indices_to_show):
frame = frames[frame_idx]
if frame.shape[-1] == 1:
frame = np.repeat(frame, 3, axis=-1)
# Add frame to combined image
x_start = i * frame_width
x_end = (i + 1) * frame_width
combined_image[:, x_start:x_end] = frame
# Add frame number, progress, and stage
progress_val = progress_predictions[frame_idx]
stage_idx = np.argmax(stage_predictions[frame_idx])
stage_name = stage_labels[stage_idx] if stage_idx < len(stage_labels) else f'{stage_idx+1}'
# Truncate long stage names for display
if len(stage_name) > 15:
stage_name = stage_name[:12] + '...'
label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\n{stage_name}'
# Draw label on image
ax_frames.text(x_start + frame_width / 2, -10, label,
ha='center', va='top', fontsize=7,
bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
ax_frames.imshow(combined_image)
ax_frames.set_title('Sample Frames', fontsize=12, pad=20)
plt.tight_layout()
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path, dpi=150, bbox_inches='tight')
logger.info(f"Saved visualization to {output_path}")
plt.close()
def main():
args = parse_args()
# Setup device
if args.device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
else:
device = args.device
logger.info(f"Using device: {device}")
# Load model
logger.info(f"Loading SARM model from {args.model_id}...")
model = SARMRewardModel.from_pretrained(args.model_id)
model.to(device)
model.eval()
logger.info("Model loaded successfully")
# Load dataset
logger.info(f"Loading dataset {args.dataset_repo}...")
dataset = LeRobotDataset(args.dataset_repo)
logger.info(f"Dataset loaded: {len(dataset.meta.episodes)} episodes, {len(dataset)} frames")
# Validate episode index
if args.episode_index >= len(dataset.meta.episodes):
raise ValueError(
f"Episode index {args.episode_index} out of range. "
f"Dataset has {len(dataset.meta.episodes)} episodes."
)
image_key = args.image_key if args.image_key is not None else model.config.image_key
state_key = args.state_key if args.state_key is not None else model.config.state_key
logger.info(f"Using image key: {image_key}")
logger.info(f"Using state key: {state_key}")
# Load dataset stats for state normalization (same as training)
dataset_stats = load_stats(dataset.root)
if dataset_stats:
logger.info(f"✓ Loaded dataset stats from {dataset.root}")
else:
logger.warning("⚠ Could not load dataset stats - states will not be normalized")
# Load episode data
frames, states, start_idx, end_idx, dataset_task = load_episode_data(
dataset, args.episode_index, image_key, state_key
)
# Use task description from dataset if available, otherwise use command-line argument
task_description = dataset_task if dataset_task is not None else args.task_description
logger.info(f"Using task description: '{task_description}'")
# Run inference
progress_predictions, stage_predictions = run_inference(
model, frames, states, task_description,
dataset_stats=dataset_stats, state_key=state_key
)
# Extract subtask names and temporal proportions from model config if available
subtask_names = None
temporal_proportions = None
if hasattr(model.config, 'subtask_names') and model.config.subtask_names is not None:
subtask_names = model.config.subtask_names
logger.info(f"✓ Found {len(subtask_names)} subtask names in model config: {subtask_names}")
# Try to load temporal proportions from model config
if hasattr(model.config, 'temporal_proportions') and model.config.temporal_proportions is not None:
temporal_proportions = {
name: prop for name, prop in zip(model.config.subtask_names, model.config.temporal_proportions)
}
logger.info(f"✓ Loaded temporal proportions from model config: {temporal_proportions}")
# Fallback: try to load from dataset meta
if temporal_proportions is None:
proportions_path = dataset.root / "meta" / "temporal_proportions.json"
if proportions_path.exists():
with open(proportions_path, 'r') as f:
temporal_proportions = json.load(f)
logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")
# Also extract subtask names from proportions if not already set
if subtask_names is None:
subtask_names = sorted(temporal_proportions.keys())
logger.info(f"✓ Extracted subtask names from proportions: {subtask_names}")
# Compute ground truth progress if annotations are available
ground_truth_progress = None
ground_truth_stages = None
if temporal_proportions is not None and subtask_names is not None:
logger.info("Attempting to compute ground truth progress from annotations...")
ground_truth_progress, ground_truth_stages = compute_ground_truth_progress(
dataset,
args.episode_index,
temporal_proportions,
subtask_names
)
if ground_truth_progress is None:
logger.warning("⚠ Ground truth not available - annotations may be missing for this episode")
else:
logger.warning("⚠ Cannot compute ground truth - temporal_proportions or subtask_names not available")
output_dir = Path(args.output_dir)
output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"
visualize_predictions(
frames,
progress_predictions,
stage_predictions,
task_description,
output_path,
num_sample_frames=args.num_sample_frames,
figsize=tuple(args.figsize),
subtask_names=subtask_names,
temporal_proportions=temporal_proportions,
ground_truth_progress=ground_truth_progress,
ground_truth_stages=ground_truth_stages,
)
predictions_path = output_dir / f"predictions_ep{args.episode_index}.npz"
save_dict = {
'progress': progress_predictions,
'stages': stage_predictions
}
if ground_truth_progress is not None:
save_dict['gt_progress'] = ground_truth_progress
save_dict['gt_stages'] = ground_truth_stages
np.savez(predictions_path, **save_dict)
logger.info(f"Saved predictions to {predictions_path}")
logger.info(f"\nVisualization: {output_path}")
if __name__ == "__main__":
main()

View File

@@ -16,7 +16,7 @@ import logging
import logging.handlers
import os
import time
from dataclasses import dataclass, field
from dataclasses import dataclass
from pathlib import Path
import torch
@@ -268,7 +268,6 @@ class RemotePolicyConfig:
lerobot_features: dict[str, PolicyFeature]
actions_per_chunk: int
device: str = "cpu"
rename_map: dict[str, str] = field(default_factory=dict)
def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool:

View File

@@ -159,10 +159,7 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
self.preprocessor, self.postprocessor = make_pre_post_processors(
self.policy.config,
pretrained_path=policy_specs.pretrained_name_or_path,
preprocessor_overrides={
"device_processor": device_override,
"rename_observations_processor": {"rename_map": policy_specs.rename_map},
},
preprocessor_overrides={"device_processor": device_override},
postprocessor_overrides={"device_processor": device_override},
)

View File

@@ -17,7 +17,7 @@
import abc
from typing import Any
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
import numpy as np
from .configs import CameraConfig, ColorMode
@@ -89,7 +89,7 @@ class Camera(abc.ABC):
pass
@abc.abstractmethod
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
"""Capture and return a single frame from the camera.
Args:
@@ -102,7 +102,7 @@ class Camera(abc.ABC):
pass
@abc.abstractmethod
def async_read(self, timeout_ms: float = ...) -> NDArray[Any]:
def async_read(self, timeout_ms: float = ...) -> np.ndarray:
"""Asynchronously capture and return a single frame from the camera.
Args:

View File

@@ -18,7 +18,7 @@ import abc
from dataclasses import dataclass
from enum import Enum
import draccus # type: ignore # TODO: add type stubs for draccus
import draccus
class ColorMode(str, Enum):
@@ -34,11 +34,11 @@ class Cv2Rotation(int, Enum):
@dataclass(kw_only=True)
class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
fps: int | None = None
width: int | None = None
height: int | None = None
@property
def type(self) -> str:
return str(self.get_choice_name(self.__class__))
return self.get_choice_name(self.__class__)

View File

@@ -14,5 +14,3 @@
from .camera_opencv import OpenCVCamera
from .configuration_opencv import OpenCVCameraConfig
__all__ = ["OpenCVCamera", "OpenCVCameraConfig"]

View File

@@ -25,12 +25,11 @@ from pathlib import Path
from threading import Event, Lock, Thread
from typing import Any
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
# Fix MSMF hardware transform compatibility for Windows before importing cv2
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import cv2
import numpy as np
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
@@ -122,7 +121,7 @@ class OpenCVCamera(Camera):
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.latest_frame: np.ndarray | None = None
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
@@ -141,7 +140,7 @@ class OpenCVCamera(Camera):
"""Checks if the camera is currently connected and opened."""
return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened()
def connect(self, warmup: bool = True) -> None:
def connect(self, warmup: bool = True):
"""
Connects to the OpenCV camera specified in the configuration.
@@ -181,14 +180,12 @@ class OpenCVCamera(Camera):
def _configure_capture_settings(self) -> None:
"""
Applies the specified FOURCC, FPS, width, and height settings to the connected camera.
Applies the specified FPS, width, and height settings to the connected camera.
This method attempts to set the camera properties via OpenCV. It checks if
the camera successfully applied the settings and raises an error if not.
FOURCC is set first (if specified) as it can affect the available FPS and resolution options.
Args:
fourcc: The desired FOURCC code (e.g., "MJPG", "YUYV"). If None, auto-detect.
fps: The desired frames per second. If None, the setting is skipped.
width: The desired capture width. If None, the setting is skipped.
height: The desired capture height. If None, the setting is skipped.
@@ -202,11 +199,10 @@ class OpenCVCamera(Camera):
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.")
# Set FOURCC first (if specified) as it can affect available FPS/resolution options
if self.config.fourcc is not None:
self._validate_fourcc()
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
if self.fps is None:
self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
else:
self._validate_fps()
default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH)))
default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
@@ -220,56 +216,18 @@ class OpenCVCamera(Camera):
else:
self._validate_width_and_height()
if self.fps is None:
self.fps = self.videocapture.get(cv2.CAP_PROP_FPS)
else:
self._validate_fps()
def _validate_fps(self) -> None:
"""Validates and sets the camera's frames per second (FPS)."""
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
if self.fps is None:
raise ValueError(f"{self} FPS is not set")
success = self.videocapture.set(cv2.CAP_PROP_FPS, float(self.fps))
actual_fps = self.videocapture.get(cv2.CAP_PROP_FPS)
# Use math.isclose for robust float comparison
if not success or not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
raise RuntimeError(f"{self} failed to set fps={self.fps} ({actual_fps=}).")
def _validate_fourcc(self) -> None:
"""Validates and sets the camera's FOURCC code."""
fourcc_code = cv2.VideoWriter_fourcc(*self.config.fourcc)
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
success = self.videocapture.set(cv2.CAP_PROP_FOURCC, fourcc_code)
actual_fourcc_code = self.videocapture.get(cv2.CAP_PROP_FOURCC)
# Convert actual FOURCC code back to string for comparison
actual_fourcc_code_int = int(actual_fourcc_code)
actual_fourcc = "".join([chr((actual_fourcc_code_int >> 8 * i) & 0xFF) for i in range(4)])
if not success or actual_fourcc != self.config.fourcc:
logger.warning(
f"{self} failed to set fourcc={self.config.fourcc} (actual={actual_fourcc}, success={success}). "
f"Continuing with default format."
)
def _validate_width_and_height(self) -> None:
"""Validates and sets the camera's frame capture width and height."""
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
if self.capture_width is None or self.capture_height is None:
raise ValueError(f"{self} capture_width or capture_height is not set")
width_success = self.videocapture.set(cv2.CAP_PROP_FRAME_WIDTH, float(self.capture_width))
height_success = self.videocapture.set(cv2.CAP_PROP_FRAME_HEIGHT, float(self.capture_height))
@@ -300,12 +258,11 @@ class OpenCVCamera(Camera):
"""
found_cameras_info = []
targets_to_scan: list[str | int]
if platform.system() == "Linux":
possible_paths = sorted(Path("/dev").glob("video*"), key=lambda p: p.name)
targets_to_scan = [str(p) for p in possible_paths]
else:
targets_to_scan = [int(i) for i in range(MAX_OPENCV_INDEX)]
targets_to_scan = list(range(MAX_OPENCV_INDEX))
for target in targets_to_scan:
camera = cv2.VideoCapture(target)
@@ -314,12 +271,6 @@ class OpenCVCamera(Camera):
default_height = int(camera.get(cv2.CAP_PROP_FRAME_HEIGHT))
default_fps = camera.get(cv2.CAP_PROP_FPS)
default_format = camera.get(cv2.CAP_PROP_FORMAT)
# Get FOURCC code and convert to string
default_fourcc_code = camera.get(cv2.CAP_PROP_FOURCC)
default_fourcc_code_int = int(default_fourcc_code)
default_fourcc = "".join([chr((default_fourcc_code_int >> 8 * i) & 0xFF) for i in range(4)])
camera_info = {
"name": f"OpenCV Camera @ {target}",
"type": "OpenCV",
@@ -327,7 +278,6 @@ class OpenCVCamera(Camera):
"backend_api": camera.getBackendName(),
"default_stream_profile": {
"format": default_format,
"fourcc": default_fourcc,
"width": default_width,
"height": default_height,
"fps": default_fps,
@@ -339,7 +289,7 @@ class OpenCVCamera(Camera):
return found_cameras_info
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
"""
Reads a single frame synchronously from the camera.
@@ -367,9 +317,6 @@ class OpenCVCamera(Camera):
start_time = time.perf_counter()
if self.videocapture is None:
raise DeviceNotConnectedError(f"{self} videocapture is not initialized")
ret, frame = self.videocapture.read()
if not ret or frame is None:
@@ -382,7 +329,7 @@ class OpenCVCamera(Camera):
return processed_frame
def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]:
def _postprocess_image(self, image: np.ndarray, color_mode: ColorMode | None = None) -> np.ndarray:
"""
Applies color conversion, dimension validation, and rotation to a raw frame.
@@ -425,7 +372,7 @@ class OpenCVCamera(Camera):
return processed_image
def _read_loop(self) -> None:
def _read_loop(self):
"""
Internal loop run by the background thread for asynchronous reading.
@@ -436,9 +383,6 @@ class OpenCVCamera(Camera):
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
while not self.stop_event.is_set():
try:
color_image = self.read()
@@ -475,7 +419,7 @@ class OpenCVCamera(Camera):
self.thread = None
self.stop_event = None
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
"""
Reads the latest available frame asynchronously.
@@ -518,7 +462,7 @@ class OpenCVCamera(Camera):
return frame
def disconnect(self) -> None:
def disconnect(self):
"""
Disconnects from the camera and cleans up resources.

View File

@@ -17,8 +17,6 @@ from pathlib import Path
from ..configs import CameraConfig, ColorMode, Cv2Rotation
__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"]
@CameraConfig.register_subclass("opencv")
@dataclass
@@ -35,9 +33,8 @@ class OpenCVCameraConfig(CameraConfig):
OpenCVCameraConfig(0, 30, 1280, 720) # 1280x720 @ 30FPS
OpenCVCameraConfig(/dev/video4, 60, 640, 480) # 640x480 @ 60FPS
# Advanced configurations with FOURCC format
OpenCVCameraConfig(128422271347, 30, 640, 480, rotation=Cv2Rotation.ROTATE_90, fourcc="MJPG") # With 90° rotation and MJPG format
OpenCVCameraConfig(0, 30, 1280, 720, fourcc="YUYV") # With YUYV format
# Advanced configurations
OpenCVCameraConfig(128422271347, 30, 640, 480, rotation=Cv2Rotation.ROTATE_90) # With 90° rotation
```
Attributes:
@@ -49,21 +46,17 @@ class OpenCVCameraConfig(CameraConfig):
color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation.
warmup_s: Time reading frames before returning from connect (in seconds)
fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect).
Note:
- Only 3-channel color output (RGB/BGR) is currently supported.
- FOURCC codes must be 4-character strings (e.g., "MJPG", "YUYV"). Some common FOUCC codes: https://learn.microsoft.com/en-us/windows/win32/medfound/video-fourccs#fourcc-constants
- Setting FOURCC can help achieve higher frame rates on some cameras.
"""
index_or_path: int | Path
color_mode: ColorMode = ColorMode.RGB
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
warmup_s: int = 1
fourcc: str | None = None
def __post_init__(self) -> None:
def __post_init__(self):
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
@@ -78,8 +71,3 @@ class OpenCVCameraConfig(CameraConfig):
raise ValueError(
f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided."
)
if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4):
raise ValueError(
f"`fourcc` must be a 4-character string (e.g., 'MJPG', 'YUYV'), but '{self.fourcc}' is provided."
)

View File

@@ -16,8 +16,6 @@ from dataclasses import dataclass
from ..configs import CameraConfig, ColorMode
__all__ = ["CameraConfig", "ColorMode", "Reachy2CameraConfig"]
@CameraConfig.register_subclass("reachy2_camera")
@dataclass
@@ -64,7 +62,7 @@ class Reachy2CameraConfig(CameraConfig):
port: int = 50065
# use_depth: bool = False
def __post_init__(self) -> None:
def __post_init__(self):
if self.name not in ["teleop", "depth"]:
raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.")
if (self.name == "teleop" and self.image_type not in ["left", "right"]) or (

View File

@@ -23,17 +23,13 @@ import time
from threading import Event, Lock, Thread
from typing import Any
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
# Fix MSMF hardware transform compatibility for Windows before importing cv2
if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ:
os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0"
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy
from reachy2_sdk.media.camera import CameraView # type: ignore # TODO: add type stubs for reachy2_sdk
from reachy2_sdk.media.camera_manager import ( # type: ignore # TODO: add type stubs for reachy2_sdk
CameraManager,
)
import cv2
import numpy as np
from reachy2_sdk.media.camera import CameraView
from reachy2_sdk.media.camera_manager import CameraManager
from lerobot.utils.errors import DeviceNotConnectedError
@@ -77,7 +73,7 @@ class Reachy2Camera(Camera):
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.latest_frame: np.ndarray | None = None
self.new_frame_event: Event = Event()
def __str__(self) -> str:
@@ -87,17 +83,13 @@ class Reachy2Camera(Camera):
def is_connected(self) -> bool:
"""Checks if the camera is currently connected and opened."""
if self.config.name == "teleop":
return bool(
self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
)
return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False
elif self.config.name == "depth":
return bool(
self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
)
return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False
else:
raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.")
def connect(self, warmup: bool = True) -> None:
def connect(self, warmup: bool = True):
"""
Connects to the Reachy2 CameraManager as specified in the configuration.
"""
@@ -139,7 +131,7 @@ class Reachy2Camera(Camera):
camera_manager.disconnect()
return initialized_cameras
def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]:
def read(self, color_mode: ColorMode | None = None) -> np.ndarray:
"""
Reads a single frame synchronously from the camera.
@@ -160,7 +152,7 @@ class Reachy2Camera(Camera):
start_time = time.perf_counter()
frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8)
frame = None
if self.cam_manager is None:
raise DeviceNotConnectedError(f"{self} is not connected.")
@@ -187,7 +179,7 @@ class Reachy2Camera(Camera):
return frame
def _read_loop(self) -> None:
def _read_loop(self):
"""
Internal loop run by the background thread for asynchronous reading.
@@ -198,9 +190,6 @@ class Reachy2Camera(Camera):
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
while not self.stop_event.is_set():
try:
color_image = self.read()
@@ -237,7 +226,7 @@ class Reachy2Camera(Camera):
self.thread = None
self.stop_event = None
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
"""
Reads the latest available frame asynchronously.
@@ -280,7 +269,7 @@ class Reachy2Camera(Camera):
return frame
def disconnect(self) -> None:
def disconnect(self):
"""
Stops the background read thread (if running).

View File

@@ -21,12 +21,11 @@ import time
from threading import Event, Lock, Thread
from typing import Any
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import numpy as np # type: ignore # TODO: add type stubs for numpy
from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing
import cv2
import numpy as np
try:
import pyrealsense2 as rs # type: ignore # TODO: add type stubs for pyrealsense2
import pyrealsense2 as rs
except Exception as e:
logging.info(f"Could not import realsense: {e}")
@@ -133,7 +132,7 @@ class RealSenseCamera(Camera):
self.thread: Thread | None = None
self.stop_event: Event | None = None
self.frame_lock: Lock = Lock()
self.latest_frame: NDArray[Any] | None = None
self.latest_frame: np.ndarray | None = None
self.new_frame_event: Event = Event()
self.rotation: int | None = get_cv2_rotation(config.rotation)
@@ -151,7 +150,7 @@ class RealSenseCamera(Camera):
"""Checks if the camera pipeline is started and streams are active."""
return self.rs_pipeline is not None and self.rs_profile is not None
def connect(self, warmup: bool = True) -> None:
def connect(self, warmup: bool = True):
"""
Connects to the RealSense camera specified in the configuration.
@@ -265,7 +264,7 @@ class RealSenseCamera(Camera):
serial_number = str(found_devices[0]["serial_number"])
return serial_number
def _configure_rs_pipeline_config(self, rs_config: Any) -> None:
def _configure_rs_pipeline_config(self, rs_config):
"""Creates and configures the RealSense pipeline configuration object."""
rs.config.enable_device(rs_config, self.serial_number)
@@ -294,9 +293,6 @@ class RealSenseCamera(Camera):
if not self.is_connected:
raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.")
if self.rs_profile is None:
raise RuntimeError(f"{self}: rs_profile must be initialized before use.")
stream = self.rs_profile.get_stream(rs.stream.color).as_video_stream_profile()
if self.fps is None:
@@ -312,7 +308,7 @@ class RealSenseCamera(Camera):
self.width, self.height = actual_width, actual_height
self.capture_width, self.capture_height = actual_width, actual_height
def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]:
def read_depth(self, timeout_ms: int = 200) -> np.ndarray:
"""
Reads a single frame (depth) synchronously from the camera.
@@ -340,9 +336,6 @@ class RealSenseCamera(Camera):
start_time = time.perf_counter()
if self.rs_pipeline is None:
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
if not ret or frame is None:
@@ -358,7 +351,7 @@ class RealSenseCamera(Camera):
return depth_map_processed
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> NDArray[Any]:
def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> np.ndarray:
"""
Reads a single frame (color) synchronously from the camera.
@@ -383,9 +376,6 @@ class RealSenseCamera(Camera):
start_time = time.perf_counter()
if self.rs_pipeline is None:
raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.")
ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms)
if not ret or frame is None:
@@ -402,8 +392,8 @@ class RealSenseCamera(Camera):
return color_image_processed
def _postprocess_image(
self, image: NDArray[Any], color_mode: ColorMode | None = None, depth_frame: bool = False
) -> NDArray[Any]:
self, image: np.ndarray, color_mode: ColorMode | None = None, depth_frame: bool = False
) -> np.ndarray:
"""
Applies color conversion, dimension validation, and rotation to a raw color frame.
@@ -448,7 +438,7 @@ class RealSenseCamera(Camera):
return processed_image
def _read_loop(self) -> None:
def _read_loop(self):
"""
Internal loop run by the background thread for asynchronous reading.
@@ -459,9 +449,6 @@ class RealSenseCamera(Camera):
Stops on DeviceNotConnectedError, logs other errors and continues.
"""
if self.stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
while not self.stop_event.is_set():
try:
color_image = self.read(timeout_ms=500)
@@ -487,7 +474,7 @@ class RealSenseCamera(Camera):
self.thread.daemon = True
self.thread.start()
def _stop_read_thread(self) -> None:
def _stop_read_thread(self):
"""Signals the background read thread to stop and waits for it to join."""
if self.stop_event is not None:
self.stop_event.set()
@@ -499,7 +486,7 @@ class RealSenseCamera(Camera):
self.stop_event = None
# NOTE(Steven): Missing implementation for depth for now
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
def async_read(self, timeout_ms: float = 200) -> np.ndarray:
"""
Reads the latest available frame data (color) asynchronously.
@@ -542,7 +529,7 @@ class RealSenseCamera(Camera):
return frame
def disconnect(self) -> None:
def disconnect(self):
"""
Disconnects from the camera, stops the pipeline, and cleans up resources.

View File

@@ -59,7 +59,7 @@ class RealSenseCameraConfig(CameraConfig):
rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION
warmup_s: int = 1
def __post_init__(self) -> None:
def __post_init__(self):
if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
raise ValueError(
f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."

View File

@@ -53,14 +53,14 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
def get_cv2_rotation(rotation: Cv2Rotation) -> int | None:
import cv2 # type: ignore # TODO: add type stubs for OpenCV
import cv2
if rotation == Cv2Rotation.ROTATE_90:
return int(cv2.ROTATE_90_CLOCKWISE)
return cv2.ROTATE_90_CLOCKWISE
elif rotation == Cv2Rotation.ROTATE_180:
return int(cv2.ROTATE_180)
return cv2.ROTATE_180
elif rotation == Cv2Rotation.ROTATE_270:
return int(cv2.ROTATE_90_COUNTERCLOCKWISE)
return cv2.ROTATE_90_COUNTERCLOCKWISE
else:
return None
@@ -69,8 +69,8 @@ def get_cv2_backend() -> int:
import cv2
if platform.system() == "Windows":
return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION
return cv2.CAP_MSMF # Use MSMF for Windows instead of AVFOUNDATION
# elif platform.system() == "Darwin": # macOS
# return cv2.CAP_AVFOUNDATION
else: # Linux and others
return int(cv2.CAP_ANY)
return cv2.CAP_ANY

View File

@@ -57,7 +57,7 @@ class EvalConfig:
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
use_async_envs: bool = False
def __post_init__(self) -> None:
def __post_init__(self):
if self.batch_size > self.n_episodes:
raise ValueError(
"The eval batch size is greater than the number of eval episodes "

View File

@@ -13,8 +13,8 @@
# limitations under the License.
import datetime as dt
import logging
from dataclasses import dataclass, field
from logging import getLogger
from pathlib import Path
from lerobot import envs, policies # noqa: F401
@@ -22,8 +22,6 @@ from lerobot.configs import parser
from lerobot.configs.default import EvalConfig
from lerobot.configs.policies import PreTrainedConfig
logger = getLogger(__name__)
@dataclass
class EvalPipelineConfig:
@@ -36,31 +34,25 @@ class EvalPipelineConfig:
output_dir: Path | None = None
job_name: str | None = None
seed: int | None = 1000
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def __post_init__(self) -> None:
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = Path(policy_path)
self.policy.pretrained_path = policy_path
else:
logger.warning(
logging.warning(
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
)
if not self.job_name:
if self.env is None:
self.job_name = f"{self.policy.type if self.policy is not None else 'scratch'}"
self.job_name = f"{self.policy.type}"
else:
self.job_name = (
f"{self.env.type}_{self.policy.type if self.policy is not None else 'scratch'}"
)
logger.warning(f"No job name provided, using '{self.job_name}' as job name.")
self.job_name = f"{self.env.type}_{self.policy.type}"
if not self.output_dir:
now = dt.datetime.now()

View File

@@ -16,19 +16,14 @@ import inspect
import pkgutil
import sys
from argparse import ArgumentError
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Sequence
from functools import wraps
from pathlib import Path
from pkgutil import ModuleInfo
from types import ModuleType
from typing import Any, TypeVar, cast
import draccus
from lerobot.utils.utils import has_method
F = TypeVar("F", bound=Callable[..., object])
PATH_KEY = "path"
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
@@ -65,7 +60,7 @@ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
return None
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict[str, str]:
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
"""Parse plugin-related arguments from command-line arguments.
This function extracts arguments from command-line arguments that match a specified suffix pattern.
@@ -132,7 +127,7 @@ def load_plugin(plugin_path: str) -> None:
f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
) from e
def iter_namespace(ns_pkg: ModuleType) -> Iterable[ModuleInfo]:
def iter_namespace(ns_pkg):
return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
try:
@@ -153,8 +148,6 @@ def get_type_arg(field_name: str, args: Sequence[str] | None = None) -> str | No
def filter_arg(field_to_filter: str, args: Sequence[str] | None = None) -> list[str]:
if args is None:
return []
return [arg for arg in args if not arg.startswith(f"--{field_to_filter}=")]
@@ -178,8 +171,7 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
if isinstance(fields_to_filter, str):
fields_to_filter = [fields_to_filter]
filtered_args = [] if args is None else list(args)
filtered_args = args
for field in fields_to_filter:
if get_path_arg(field, args):
if get_type_arg(field, args):
@@ -192,7 +184,7 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
return filtered_args
def wrap(config_path: Path | None = None) -> Callable[[F], F]:
def wrap(config_path: Path | None = None):
"""
HACK: Similar to draccus.wrap but does three additional things:
- Will remove '.path' arguments from CLI in order to process them later on.
@@ -203,9 +195,9 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
from the CLI '.type' arguments
"""
def wrapper_outer(fn: F) -> F:
def wrapper_outer(fn):
@wraps(fn)
def wrapper_inner(*args: Any, **kwargs: Any) -> Any:
def wrapper_inner(*args, **kwargs):
argspec = inspect.getfullargspec(fn)
argtype = argspec.annotations[argspec.args[0]]
if len(args) > 0 and type(args[0]) is argtype:
@@ -233,6 +225,6 @@ def wrap(config_path: Path | None = None) -> Callable[[F], F]:
response = fn(cfg, *args, **kwargs)
return response
return cast(F, wrapper_inner)
return wrapper_inner
return cast(Callable[[F], F], wrapper_outer)
return wrapper_outer

View File

@@ -14,12 +14,12 @@
import abc
import builtins
import json
import logging
import os
import tempfile
from dataclasses import dataclass, field
from logging import getLogger
from pathlib import Path
from typing import Any, TypeVar
from typing import TypeVar
import draccus
from huggingface_hub import hf_hub_download
@@ -34,11 +34,10 @@ from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
T = TypeVar("T", bound="PreTrainedConfig")
logger = getLogger(__name__)
@dataclass
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: ignore[misc,name-defined] #TODO: draccus issue
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
"""
Base configuration class for policy models.
@@ -58,12 +57,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
device: str | None = None # e.g. "cuda", "cuda:0", "cpu", or "mps"
device: str | None = None # cuda | cpu | mp
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool = False
push_to_hub: bool = True # type: ignore[assignment] # TODO: use a different name to avoid override
push_to_hub: bool = True
repo_id: str | None = None
# Upload on private repository on the Hugging Face hub.
@@ -74,41 +73,38 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
license: str | None = None
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
pretrained_path: Path | None = None
pretrained_path: str | None = None
def __post_init__(self) -> None:
def __post_init__(self):
if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device.type
# Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device):
logger.warning(
logging.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
)
self.use_amp = False
@property
def type(self) -> str:
choice_name = self.get_choice_name(self.__class__)
if not isinstance(choice_name, str):
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
return choice_name
return self.get_choice_name(self.__class__)
@property
@abc.abstractmethod
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
def observation_delta_indices(self) -> list | None:
raise NotImplementedError
@property
@abc.abstractmethod
def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
def action_delta_indices(self) -> list | None:
raise NotImplementedError
@property
@abc.abstractmethod
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
def reward_delta_indices(self) -> list | None:
raise NotImplementedError
@abc.abstractmethod
@@ -158,13 +154,13 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict[Any, Any] | None = None,
resume_download: bool = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**policy_kwargs: Any,
**policy_kwargs,
) -> T:
model_id = str(pretrained_name_or_path)
config_file: str | None = None
@@ -172,7 +168,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
if CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, CONFIG_NAME)
else:
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
else:
try:
config_file = hf_hub_download(
@@ -198,9 +194,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
with draccus.config_type("json"):
orig_config = draccus.parse(cls, config_file, args=[])
if config_file is None:
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
with open(config_file) as f:
config = json.load(f)

View File

@@ -16,7 +16,6 @@ import datetime as dt
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import draccus
from huggingface_hub import hf_hub_download
@@ -64,35 +63,18 @@ class TrainPipelineConfig(HubMixin):
scheduler: LRSchedulerConfig | None = None
eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig)
# RA-BC (Reward-Aligned Behavior Cloning) parameters
use_rabc: bool = False # Enable reward-weighted training
reward_model_path: str | None = None # Path to pre-trained reward model (e.g., SARM)
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
rabc_update_freq: int = 1 # Compute rewards every N batches (1 = every batch)
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
checkpoint_path: Path | None = field(init=False, default=None)
def __post_init__(self):
self.checkpoint_path = None
def validate(self):
# Validate RA-BC configuration
if self.use_rabc and not self.reward_model_path:
raise ValueError(
"RA-BC is enabled (use_rabc=True) but no reward_model_path provided. "
"Please specify a pre-trained reward model (e.g., SARM) path."
)
def validate(self) -> None:
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
policy_path = parser.get_path_arg("policy")
if policy_path:
# Only load the policy config
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = Path(policy_path)
self.policy.pretrained_path = policy_path
elif self.resume:
# The entire train config is already loaded, we just need to get the checkpoint dir
config_path = parser.parse_arg("config_path")
@@ -100,22 +82,14 @@ class TrainPipelineConfig(HubMixin):
raise ValueError(
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
)
if not Path(config_path).resolve().exists():
raise NotADirectoryError(
f"{config_path=} is expected to be a local path. "
"Resuming from the hub is not supported for now."
)
policy_dir = Path(config_path).parent
if self.policy is not None:
self.policy.pretrained_path = policy_dir
self.checkpoint_path = policy_dir.parent
if self.policy is None:
raise ValueError(
"Policy is not configured. Please specify a pretrained policy with `--policy.path`."
)
policy_path = Path(config_path).parent
self.policy.pretrained_path = policy_path
self.checkpoint_path = policy_path.parent
if not self.job_name:
if self.env is None:
@@ -152,8 +126,8 @@ class TrainPipelineConfig(HubMixin):
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
def to_dict(self) -> dict[str, Any]:
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
def to_dict(self) -> dict:
return draccus.encode(self)
def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
@@ -165,13 +139,13 @@ class TrainPipelineConfig(HubMixin):
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool | None = None,
proxies: dict[Any, Any] | None = None,
resume_download: bool = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**kwargs: Any,
**kwargs,
) -> "TrainPipelineConfig":
model_id = str(pretrained_name_or_path)
config_file: str | None = None
@@ -207,6 +181,4 @@ class TrainPipelineConfig(HubMixin):
@dataclass(kw_only=True)
class TrainRLServerPipelineConfig(TrainPipelineConfig):
# NOTE: In RL, we don't need an offline dataset
# TODO: Make `TrainPipelineConfig.dataset` optional
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset

View File

@@ -42,11 +42,4 @@ class NormalizationMode(str, Enum):
@dataclass
class PolicyFeature:
type: FeatureType
shape: tuple[int, ...]
class RTCAttentionSchedule(str, Enum):
ZEROS = "ZEROS"
ONES = "ONES"
LINEAR = "LINEAR"
EXP = "EXP"
shape: tuple

View File

@@ -39,7 +39,6 @@ from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import (
DATA_DIR,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
@@ -963,23 +962,28 @@ def _copy_data_with_feature_changes(
remove_features: list[str] | None = None,
) -> None:
"""Copy data while adding or removing features."""
data_dir = dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
if dataset.meta.episodes is None:
dataset.meta.episodes = load_episodes(dataset.meta.root)
if not parquet_files:
raise ValueError(f"No parquet files found in {data_dir}")
# Map file paths to episode indices to extract chunk/file indices
file_to_episodes: dict[Path, set[int]] = {}
for ep_idx in range(dataset.meta.total_episodes):
file_path = dataset.meta.get_data_file_path(ep_idx)
if file_path not in file_to_episodes:
file_to_episodes[file_path] = set()
file_to_episodes[file_path].add(ep_idx)
frame_idx = 0
for src_path in tqdm(parquet_files, desc="Processing data files"):
df = pd.read_parquet(src_path).reset_index(drop=True)
for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"):
df = pd.read_parquet(dataset.root / src_path).reset_index(drop=True)
relative_path = src_path.relative_to(dataset.root)
chunk_dir = relative_path.parts[1]
file_name = relative_path.parts[2]
chunk_idx = int(chunk_dir.split("-")[1])
file_idx = int(file_name.split("-")[1].split(".")[0])
# Get chunk_idx and file_idx from the source file's first episode
episodes_in_file = file_to_episodes[src_path]
first_ep_idx = min(episodes_in_file)
src_ep = dataset.meta.episodes[first_ep_idx]
chunk_idx = src_ep["data/chunk_index"]
file_idx = src_ep["data/file_index"]
if remove_features:
df = df.drop(columns=remove_features, errors="ignore")
@@ -999,21 +1003,13 @@ def _copy_data_with_feature_changes(
df[feature_name] = feature_values
else:
feature_slice = values[frame_idx:end_idx]
if len(feature_slice.shape) == 1:
# 1D array - can assign directly
df[feature_name] = feature_slice
elif len(feature_slice.shape) == 2 and feature_slice.shape[1] == 1:
# 2D array with single column - flatten it
if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1:
df[feature_name] = feature_slice.flatten()
elif len(feature_slice.shape) == 2:
# 2D array with multiple columns (e.g., embeddings) - convert to list of lists
df[feature_name] = feature_slice.tolist()
else:
# Higher dimensional - convert to list
df[feature_name] = [row.tolist() for row in feature_slice]
df[feature_name] = feature_slice
frame_idx = end_idx
# Write using the same chunk/file structure as source
# Write using the preserved chunk_idx and file_idx from source
dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
dst_path.parent.mkdir(parents=True, exist_ok=True)

View File

@@ -1,146 +0,0 @@
# LeRobot Embedding Generation Script
Generate embeddings for LeRobot datasets to make them more lightweight and efficient for training.
## Overview
This script processes v3.0 LeRobot datasets and adds pre-computed embeddings for:
- **Task embeddings**: Language command embeddings using MiniLM
- **Image embeddings**: Frame embeddings using DinoV2
The resulting dataset can be used more efficiently during training by loading pre-computed embeddings instead of running encoders on-the-fly.
## Supported Encoders
### Image Encoders (DinoV2)
DinoV2 is a self-supervised vision transformer that produces high-quality image embeddings:
- **`dinov2_vits14`**: ViT-S/14 (384-dim) - Fastest, smaller model
- **`dinov2_vitb14`**: ViT-B/14 (768-dim) - **Recommended** - Good balance
- **`dinov2_vitl14`**: ViT-L/14 (1024-dim) - Best quality, slower
### Language Encoders (MiniLM)
MiniLM is a lightweight sentence transformer model:
- **`minilm-l6`**: MiniLM-L6-v2 (384-dim) - Faster
- **`minilm-l12`**: MiniLM-L12-v2 (384-dim) - **Recommended** - Better quality
## Usage
### Basic Command
```bash
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
--repo-id lerobot/utokyo_xarm_bimanual \
--output-repo-id your-username/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--push-to-hub
```
### Lightweight Version (No Videos)
Removes video files to significantly reduce storage:
```bash
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
--repo-id lerobot/utokyo_xarm_bimanual \
--output-repo-id your-username/utokyo_xarm_bimanual_lightweight \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--remove-videos \
--push-to-hub
```
## Output
The script adds new features to your dataset:
### New Features
1. **`task_embedding`**: Language embedding for each frame
- Shape: `[384]` (MiniLM)
- One embedding per frame based on its task
2. **`{camera_key}_embedding`**: Image embedding for each camera view
- Shape: `[384]`, `[768]`, or `[1024]` depending on DinoV2 model
- Examples: `observation.images.top_embedding`, `observation.images.wrist_embedding`
### Using Embeddings in Training
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Load dataset with embeddings
dataset = LeRobotDataset("your-username/utokyo_xarm_bimanual_embeddings")
# Access embeddings
item = dataset[0]
task_emb = item["task_embedding"] # Shape: [384]
img_emb = item["observation.images.top_embedding"] # Shape: [768]
# Use in your policy
# Instead of running encoders during training, use pre-computed embeddings
```
## Extending with New Encoders
The script is designed to be easily extensible. To add a new encoder:
### 1. Create Encoder Class
```python
class MyCustomImageEncoder(ImageEncoder):
"""Your custom image encoder."""
def __init__(self, device: str = "cuda"):
super().__init__(device)
# Load your model
self.model = load_my_model()
self.model = self.model.to(self.device)
self.model.eval()
def encode(self, images: list[np.ndarray]) -> np.ndarray:
"""Encode a batch of images."""
# Your encoding logic here
embeddings = []
for img in images:
emb = self.model(img)
embeddings.append(emb)
return np.array(embeddings)
@property
def embedding_dim(self) -> int:
"""Return embedding dimension."""
return 512 # Your embedding dimension
```
### 2. Add to Factory Function
```python
def get_image_encoder(encoder_name: str, device: str = "cuda") -> ImageEncoder:
encoders = {
"dinov2_vits14": lambda: DinoV2Encoder(model_name="dinov2_vits14", device=device),
"dinov2_vitb14": lambda: DinoV2Encoder(model_name="dinov2_vitb14", device=device),
"dinov2_vitl14": lambda: DinoV2Encoder(model_name="dinov2_vitl14", device=device),
# Add your encoder
"my_custom": lambda: MyCustomImageEncoder(device=device),
}
# ... rest of function
```
## Validating Embeddings
After generating embeddings, you can validate them using `validate_embeddings.py`:
```bash
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \
--original-repo-id lerobot/utokyo_xarm_bimanual \
--embeddings-repo-id pepijn223/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--num-samples 20
```

View File

@@ -1,147 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import torch
from PIL import Image
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ImageEncoder:
"""Base class for image encoders."""
def __init__(self, device: str = "cuda"):
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
def encode(self, images: list[np.ndarray]) -> np.ndarray:
"""Encode a batch of images."""
raise NotImplementedError
class DinoV2Encoder(ImageEncoder):
"""DinoV2 image encoder.
DinoV2 is a self-supervised vision transformer that produces high-quality image embeddings.
Supports multiple model sizes (ViT-S/14, ViT-B/14, ViT-L/14).
"""
def __init__(self, model_name: str = "dinov2_vitb14", device: str = "cuda", batch_size: int = 32):
super().__init__(device)
self.batch_size = batch_size
self.model_name = model_name
logger.info(f"Loading DinoV2 model: {model_name}")
self.model = torch.hub.load("facebookresearch/dinov2", model_name) # nosec B614
self.model = self.model.to(self.device)
self.model.eval()
# DinoV2 preprocessing
from torchvision import transforms
self.transform = transforms.Compose(
[
transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
def encode(self, images: list[np.ndarray]) -> np.ndarray:
"""Encode a batch of images."""
embeddings = []
with torch.inference_mode():
for i in range(0, len(images), self.batch_size):
batch_images = images[i : i + self.batch_size]
# Convert numpy arrays to PIL Images and apply transforms
pil_images = [Image.fromarray(img.astype(np.uint8)) for img in batch_images]
tensors = torch.stack([self.transform(img) for img in pil_images]).to(self.device)
# Get embeddings
batch_embeddings = self.model(tensors).cpu().numpy()
embeddings.append(batch_embeddings)
return np.concatenate(embeddings, axis=0)
@property
def embedding_dim(self) -> int:
"""Return the embedding dimension based on model size."""
if "vits14" in self.model_name:
return 384 # DinoV2 ViT-S/14
elif "vitb14" in self.model_name:
return 768 # DinoV2 ViT-B/14
elif "vitl14" in self.model_name:
return 1024 # DinoV2 ViT-L/14
else:
return 768 # Default to ViT-B/14
class LanguageEncoder:
"""Base class for language encoders."""
def __init__(self, device: str = "cuda"):
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
def encode(self, texts: list[str]) -> np.ndarray:
"""Encode a batch of texts."""
raise NotImplementedError
class MiniLMEncoder(LanguageEncoder):
"""MiniLM language encoder.
MiniLM is a lightweight sentence transformer model that produces high-quality text embeddings.
Supports L6 and L12 model sizes.
"""
def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L12-v2", device: str = "cuda"):
super().__init__(device)
self.model_name = model_name
logger.info(f"Loading MiniLM model: {model_name}")
from transformers import AutoModel, AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.model.eval()
def _mean_pooling(self, model_output, attention_mask):
"""Mean pooling to get sentence embeddings."""
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
input_mask_expanded.sum(1), min=1e-9
)
def encode(self, texts: list[str]) -> np.ndarray:
"""Encode a batch of texts."""
with torch.inference_mode():
encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
model_output = self.model(**encoded_input)
embeddings = self._mean_pooling(model_output, encoded_input["attention_mask"])
return embeddings.cpu().numpy()
@property
def embedding_dim(self) -> int:
"""Return the embedding dimension."""
return 384 # Both MiniLM-L6 and L12 output 384-dim embeddings

View File

@@ -1,329 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generate embeddings for LeRobot datasets to make them more lightweight and efficient.
This script:
1. Loads a v3.0 LeRobot dataset from the hub
2. Computes embeddings for tasks (language commands) and frames (images)
3. Stores embeddings as new features in the dataset
4. Optionally removes video files to reduce size
5. Pushes the converted dataset to the hub
Current supported encoders:
- Image: DinoV2 (dinov2_vits14, dinov2_vitb14, dinov2_vitl14)
- Language: MiniLM (minilm-l6, minilm-l12)
The architecture is extensible - you can add more encoders by:
1. Creating a new encoder class inheriting from ImageEncoder or LanguageEncoder
2. Implementing the encode() method and embedding_dim property
3. Adding it to the get_image_encoder() or get_language_encoder() factory function
Usage example:
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
--repo-id lerobot/utokyo_xarm_bimanual \
--output-repo-id lerobot/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--remove-videos \
--push-to-hub
"""
import argparse
import shutil
from pathlib import Path
import numpy as np
import torch
from tqdm import tqdm
from lerobot.datasets.generating_embeddings.encoders import (
DinoV2Encoder,
ImageEncoder,
LanguageEncoder,
MiniLMEncoder,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def get_image_encoder(encoder_name: str, device: str = "cuda") -> ImageEncoder:
"""Factory function to get image encoder.
To add a new encoder:
1. Create a new class inheriting from ImageEncoder
2. Implement encode() and embedding_dim property
3. Add it to the encoders dictionary below
"""
encoders = {
"dinov2_vits14": lambda: DinoV2Encoder(model_name="dinov2_vits14", device=device),
"dinov2_vitb14": lambda: DinoV2Encoder(model_name="dinov2_vitb14", device=device),
"dinov2_vitl14": lambda: DinoV2Encoder(model_name="dinov2_vitl14", device=device),
}
if encoder_name not in encoders:
raise ValueError(f"Unknown image encoder: {encoder_name}. Available options: {list(encoders.keys())}")
return encoders[encoder_name]()
def get_language_encoder(encoder_name: str, device: str = "cuda") -> LanguageEncoder:
"""Factory function to get language encoder.
To add a new encoder:
1. Create a new class inheriting from LanguageEncoder
2. Implement encode() and embedding_dim property
3. Add it to the encoders dictionary below
"""
encoders = {
"minilm-l6": lambda: MiniLMEncoder(
model_name="sentence-transformers/all-MiniLM-L6-v2", device=device
),
"minilm-l12": lambda: MiniLMEncoder(
model_name="sentence-transformers/all-MiniLM-L12-v2", device=device
),
}
if encoder_name not in encoders:
raise ValueError(
f"Unknown language encoder: {encoder_name}. Available options: {list(encoders.keys())}"
)
return encoders[encoder_name]()
def generate_embeddings_for_dataset(
repo_id: str,
output_repo_id: str,
image_encoder: ImageEncoder,
language_encoder: LanguageEncoder,
remove_videos: bool = False,
local_dir: Path | None = None,
output_local_dir: Path | None = None,
push_to_hub: bool = False,
):
"""Generate embeddings for a LeRobot dataset.
Args:
repo_id: Source dataset repository ID
output_repo_id: Output dataset repository ID
image_encoder: Image encoder instance
language_encoder: Language encoder instance
remove_videos: Whether to remove video files
local_dir: Local directory for source dataset
output_local_dir: Local directory for output dataset
push_to_hub: Whether to push to hub after conversion
"""
from lerobot.datasets.dataset_tools import modify_features
print(f"Loading dataset: {repo_id}")
dataset = LeRobotDataset(repo_id, root=local_dir, download_videos=True)
print(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
print("Computing task embeddings...")
unique_tasks = dataset.meta.tasks.index.tolist()
task_embeddings = {}
for task in tqdm(unique_tasks, desc="Encoding tasks"):
# Clean up task text
task_clean = task.strip().capitalize().strip(" .,!?-_")
embedding = language_encoder.encode([task_clean])[0]
task_embeddings[task] = embedding
print(f"Computed {len(task_embeddings)} task embeddings")
print("Processing frames and computing embeddings...")
all_task_embeddings = []
all_image_embeddings_dict = {cam_key: [] for cam_key in dataset.meta.camera_keys}
for frame_idx in tqdm(range(dataset.num_frames), desc="Processing frames"):
item = dataset.hf_dataset[frame_idx]
ep_idx = item["episode_index"].item()
task = dataset.meta.tasks.iloc[item["task_index"].item()].name
task_emb = task_embeddings[task]
all_task_embeddings.append(task_emb)
for cam_key in dataset.meta.camera_keys:
if cam_key in dataset.meta.video_keys:
current_ts = item["timestamp"].item()
video_frames = dataset._query_videos({cam_key: [current_ts]}, ep_idx)
img = video_frames[cam_key]
if isinstance(img, torch.Tensor):
if img.ndim == 4:
img = img[0] # (T, C, H, W) -> (C, H, W)
elif img.ndim != 3:
raise ValueError(f"Unexpected video frame shape {img.shape} for camera {cam_key}")
img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
else:
img_np = np.array(img)
else:
img = item[cam_key]
if isinstance(img, torch.Tensor):
if img.ndim == 3:
img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
else:
raise ValueError(f"Unexpected image shape {img.shape} for camera {cam_key}")
else:
img_np = np.array(img)
all_image_embeddings_dict[cam_key].append(img_np)
print("Computing image embeddings...")
image_embeddings_dict = {}
for cam_key, images in all_image_embeddings_dict.items():
print(f" {cam_key}: {len(images)} images")
embeddings = image_encoder.encode(images)
image_embeddings_dict[cam_key] = embeddings
all_task_embeddings = np.array(all_task_embeddings)
for cam_key in dataset.meta.camera_keys:
image_embeddings_dict[cam_key] = np.array(image_embeddings_dict[cam_key])
img_emb_dim = image_encoder.embedding_dim
lang_emb_dim = language_encoder.embedding_dim
add_features_dict = {
"task_embedding": (
all_task_embeddings,
{"dtype": "float32", "shape": [lang_emb_dim], "names": None},
),
}
for cam_key in dataset.meta.camera_keys:
add_features_dict[f"{cam_key}_embedding"] = (
image_embeddings_dict[cam_key],
{"dtype": "float32", "shape": [img_emb_dim], "names": None},
)
print("Adding embeddings to dataset...")
remove_features_list = None
if remove_videos:
remove_features_list = dataset.meta.video_keys
output_dataset = modify_features(
dataset=dataset,
add_features=add_features_dict,
remove_features=remove_features_list,
output_dir=output_local_dir,
repo_id=output_repo_id,
)
if remove_videos:
print("Removing video files...")
videos_dir = output_dataset.root / "videos"
if videos_dir.exists():
shutil.rmtree(videos_dir)
print(f"Saved to: {output_dataset.root}")
if push_to_hub:
print(f"Pushing to hub: {output_repo_id}")
output_dataset.push_to_hub(push_videos=not remove_videos)
print("Done!")
def main():
parser = argparse.ArgumentParser(
description="Generate embeddings for LeRobot datasets",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Basic usage with default encoders (DinoV2 ViT-B/14 + MiniLM-L12)
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \\
--repo-id lerobot/utokyo_xarm_bimanual \\
--output-repo-id your-username/utokyo_xarm_bimanual_embeddings \\
--image-encoder dinov2_vitb14 \\
--language-encoder minilm-l12 \\
--push-to-hub
# Generate embeddings and remove videos
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \\
--repo-id lerobot/utokyo_xarm_bimanual \\
--output-repo-id your-username/utokyo_xarm_bimanual_lightweight \\
--image-encoder dinov2_vitb14 \\
--language-encoder minilm-l12 \\
--remove-videos \\
--push-to-hub
Available image encoders:
- dinov2_vits14: DinoV2 ViT-S/14 (384-dim, faster)
- dinov2_vitb14: DinoV2 ViT-B/14 (768-dim, recommended)
- dinov2_vitl14: DinoV2 ViT-L/14 (1024-dim, best quality)
Available language encoders:
- minilm-l6: MiniLM-L6-v2 (384-dim, faster)
- minilm-l12: MiniLM-L12-v2 (384-dim, recommended)
""",
)
parser.add_argument("--repo-id", type=str, required=True, help="Source dataset repository ID")
parser.add_argument("--output-repo-id", type=str, required=True, help="Output dataset repository ID")
parser.add_argument(
"--image-encoder",
type=str,
default="dinov2_vitb14",
help="Image encoder to use (default: dinov2_vitb14)",
)
parser.add_argument(
"--language-encoder",
type=str,
default="minilm-l12",
help="Language encoder to use (default: minilm-l12)",
)
parser.add_argument(
"--remove-videos",
action="store_true",
help="Remove video files after generating embeddings",
)
parser.add_argument("--local-dir", type=str, default=None, help="Local directory for source dataset")
parser.add_argument(
"--output-local-dir", type=str, default=None, help="Local directory for output dataset"
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Push the converted dataset to the hub",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to use for encoding (default: cuda)",
)
args = parser.parse_args()
# Load encoders
image_encoder = get_image_encoder(args.image_encoder, device=args.device)
language_encoder = get_language_encoder(args.language_encoder, device=args.device)
# Generate embeddings
generate_embeddings_for_dataset(
repo_id=args.repo_id,
output_repo_id=args.output_repo_id,
image_encoder=image_encoder,
language_encoder=language_encoder,
remove_videos=args.remove_videos,
local_dir=Path(args.local_dir) if args.local_dir else None,
output_local_dir=Path(args.output_local_dir) if args.output_local_dir else None,
push_to_hub=args.push_to_hub,
)
if __name__ == "__main__":
main()

View File

@@ -1,222 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Validate pre-computed embeddings against on-the-fly computed embeddings.
Usage:
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \
--original-repo-id lerobot/utokyo_xarm_bimanual \
--embeddings-repo-id <your_username>/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--num-samples 10
"""
import argparse
import numpy as np
import torch
from tqdm import tqdm
from lerobot.datasets.generating_embeddings.encoders import ImageEncoder, LanguageEncoder
from lerobot.datasets.generating_embeddings.generate_embeddings import (
get_image_encoder,
get_language_encoder,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
"""Compute cosine similarity between two vectors."""
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
def validate_embeddings(
original_repo_id: str,
embeddings_repo_id: str,
image_encoder: ImageEncoder,
language_encoder: LanguageEncoder,
num_samples: int = 10,
device: str = "cuda",
):
"""Validate pre-computed embeddings against on-the-fly embeddings.
Args:
original_repo_id: Original dataset repository ID
embeddings_repo_id: Dataset with pre-computed embeddings repository ID
image_encoder: Image encoder instance
language_encoder: Language encoder instance
num_samples: Number of samples to validate
device: Device to use for encoding
"""
# Load both datasets
print("Loading datasets...")
original_dataset = LeRobotDataset(original_repo_id, download_videos=True)
embeddings_dataset = LeRobotDataset(embeddings_repo_id, download_videos=False)
# Verify both datasets have the same number of frames
assert original_dataset.num_frames == embeddings_dataset.num_frames, (
f"Frame count mismatch: original={original_dataset.num_frames}, "
f"embeddings={embeddings_dataset.num_frames}"
)
camera_keys = original_dataset.meta.camera_keys
# Check embedding features exist
expected_features = ["task_embedding"] + [f"{cam}_embedding" for cam in camera_keys]
for feat in expected_features:
if feat not in embeddings_dataset.features:
raise ValueError(f"Embedding feature not found: {feat}")
# Select random sample indices
sample_indices = np.random.choice(
original_dataset.num_frames, size=min(num_samples, original_dataset.num_frames), replace=False
)
print(f"Validating {len(sample_indices)} samples...")
# Track statistics
task_similarities = []
image_similarities = {cam: [] for cam in camera_keys}
for idx in tqdm(sample_indices, desc="Validating"):
idx = int(idx)
embeddings_item = embeddings_dataset[idx]
precomputed_task_emb = embeddings_item["task_embedding"].numpy()
precomputed_image_embs = {cam: embeddings_item[f"{cam}_embedding"].numpy() for cam in camera_keys}
original_item = original_dataset[idx]
# Get task and compute embedding
task = original_item["task"]
# Clean up task text (same as in generate_embeddings.py)
task_clean = task.strip().capitalize().strip(" .,!?-_")
onthefly_task_emb = language_encoder.encode([task_clean])[0]
# Get images and compute embeddings
onthefly_image_embs = {}
for cam in camera_keys:
img = original_item[cam]
# Convert to numpy if needed
if isinstance(img, torch.Tensor):
if img.ndim == 3: # (C, H, W)
img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
else:
raise ValueError(f"Unexpected image shape: {img.shape}")
else:
img_np = np.array(img)
onthefly_image_embs[cam] = image_encoder.encode([img_np])[0]
# Task embedding comparison
task_sim = cosine_similarity(precomputed_task_emb, onthefly_task_emb)
task_similarities.append(task_sim)
# Image embedding comparison
for cam in camera_keys:
img_sim = cosine_similarity(precomputed_image_embs[cam], onthefly_image_embs[cam])
image_similarities[cam].append(img_sim)
# Results
print("\nResults:")
task_sim_threshold = 0.99
img_sim_threshold = 0.99
task_mean_sim = np.mean(task_similarities)
task_pass = task_mean_sim >= task_sim_threshold
print(f" Task: {task_mean_sim:.4f} {'' if task_pass else ''}")
for cam in camera_keys:
cam_mean_sim = np.mean(image_similarities[cam])
cam_pass = cam_mean_sim >= img_sim_threshold
print(f" {cam}: {cam_mean_sim:.4f} {'' if cam_pass else ''}")
image_pass = all(np.mean(image_similarities[cam]) >= img_sim_threshold for cam in camera_keys)
print()
if task_pass and image_pass:
print("✓ PASSED")
else:
print("✗ FAILED")
def main():
parser = argparse.ArgumentParser(
description="Validate and compare pre-computed embeddings with on-the-fly embeddings",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Example:
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \\
--original-repo-id lerobot/utokyo_xarm_bimanual \\
--embeddings-repo-id lerobot/utokyo_xarm_bimanual_embeddings \\
--image-encoder dinov2_vitb14 \\
--language-encoder minilm-l12 \\
--num-samples 20
""",
)
parser.add_argument("--original-repo-id", type=str, required=True, help="Original dataset repository ID")
parser.add_argument(
"--embeddings-repo-id",
type=str,
required=True,
help="Dataset with pre-computed embeddings repository ID",
)
parser.add_argument(
"--image-encoder",
type=str,
default="dinov2_vitb14",
help="Image encoder to use (default: dinov2_vitb14)",
)
parser.add_argument(
"--language-encoder",
type=str,
default="minilm-l12",
help="Language encoder to use (default: minilm-l12)",
)
parser.add_argument(
"--num-samples",
type=int,
default=10,
help="Number of samples to validate (default: 10)",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to use for encoding (default: cuda)",
)
args = parser.parse_args()
# Load encoders
image_encoder = get_image_encoder(args.image_encoder, device=args.device)
language_encoder = get_language_encoder(args.language_encoder, device=args.device)
# Validate embeddings
validate_embeddings(
original_repo_id=args.original_repo_id,
embeddings_repo_id=args.embeddings_repo_id,
image_encoder=image_encoder,
language_encoder=language_encoder,
num_samples=args.num_samples,
device=args.device,
)
if __name__ == "__main__":
main()

View File

@@ -430,7 +430,9 @@ class LeRobotDatasetMetadata:
video_keys = [video_key] if video_key is not None else self.video_keys
for key in video_keys:
if not self.features[key].get("info", None):
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
video_path = self.root / self.video_path.format(
video_key=video_key, chunk_index=0, file_index=0
)
self.info["features"][key]["info"] = get_video_info(video_path)
def update_chunk_settings(
@@ -684,7 +686,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer = None
self.writer = None
self.latest_episode = None
self._current_file_start_frame = None # Track the starting frame index of the current parquet file
self.root.mkdir(exist_ok=True, parents=True)
@@ -707,20 +708,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
if not self._check_cached_episodes_sufficient():
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
except (AssertionError, FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
self.revision = get_safe_version(self.repo_id, self.revision)
self.download(download_videos)
self.hf_dataset = self.load_hf_dataset()
# Create mapping from absolute indices to relative indices when only a subset of the episodes are loaded
# Build a mapping: absolute_index -> relative_index_in_filtered_dataset
self._absolute_to_relative_idx = None
if self.episodes is not None:
self._absolute_to_relative_idx = {
abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx
for rel_idx, abs_idx in enumerate(self.hf_dataset["index"])
}
# Setup delta_indices
if self.delta_timestamps is not None:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
@@ -839,40 +830,31 @@ class LeRobotDataset(torch.utils.data.Dataset):
def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
features = get_hf_features_from_features(self.features)
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
hf_dataset = load_nested_dataset(self.root / "data", features=features)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def _check_cached_episodes_sufficient(self) -> bool:
"""Check if the cached dataset contains all requested episodes and their video files."""
"""Check if the cached dataset contains all requested episodes."""
if self.hf_dataset is None or len(self.hf_dataset) == 0:
return False
# Get available episode indices from cached dataset
available_episodes = {
ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx
for ep_idx in self.hf_dataset.unique("episode_index")
for ep_idx in self.hf_dataset["episode_index"]
}
# Determine requested episodes
if self.episodes is None:
# Requesting all episodes - check if we have all episodes from metadata
requested_episodes = set(range(self.meta.total_episodes))
else:
# Requesting specific episodes
requested_episodes = set(self.episodes)
# Check if all requested episodes are available in cached data
if not requested_episodes.issubset(available_episodes):
return False
# Check if all required video files exist
if len(self.meta.video_keys) > 0:
for ep_idx in requested_episodes:
for vid_key in self.meta.video_keys:
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
if not video_path.exists():
return False
return True
return requested_episodes.issubset(available_episodes)
def create_hf_dataset(self) -> datasets.Dataset:
features = get_hf_features_from_features(self.features)
@@ -939,11 +921,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_timestamps = {}
for key in self.meta.video_keys:
if query_indices is not None and key in query_indices:
if self._absolute_to_relative_idx is not None:
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
timestamps = self.hf_dataset[relative_indices]["timestamp"]
else:
timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist()
else:
query_timestamps[key] = [current_ts]
@@ -951,32 +929,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
return query_timestamps
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
"""
Query dataset for indices across keys, skipping video keys.
Tries column-first [key][indices] for speed, falls back to row-first.
Args:
query_indices: Dict mapping keys to index lists to retrieve
Returns:
Dict with stacked tensors of queried data (video keys excluded)
"""
result: dict = {}
for key, q_idx in query_indices.items():
if key in self.meta.video_keys:
continue
# Map absolute indices to relative indices if needed
relative_indices = (
q_idx
if self._absolute_to_relative_idx is None
else [self._absolute_to_relative_idx[idx] for idx in q_idx]
)
try:
result[key] = torch.stack(self.hf_dataset[key][relative_indices])
except (KeyError, TypeError, IndexError):
result[key] = torch.stack(self.hf_dataset[relative_indices][key])
return result
return {
key: torch.stack(self.hf_dataset[q_idx][key])
for key, q_idx in query_indices.items()
if key not in self.meta.video_keys
}
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
@@ -1274,7 +1231,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Initialize indices and frame count for a new dataset made of the first episode data
chunk_idx, file_idx = 0, 0
global_frame_index = 0
self._current_file_start_frame = 0
# However, if the episodes already exists
# It means we are resuming recording, so we need to load the latest episode
# Update the indices to avoid overwriting the latest episode
@@ -1286,7 +1242,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
# When resuming, move to the next file
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
self._current_file_start_frame = global_frame_index
else:
# Retrieve information from the latest parquet file
latest_ep = self.latest_episode
@@ -1297,7 +1252,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
latest_size_in_mb = get_file_size_in_mb(latest_path)
frames_in_current_file = global_frame_index - self._current_file_start_frame
frames_in_current_file = global_frame_index - latest_ep["dataset_from_index"]
av_size_per_frame = (
latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0
)
@@ -1311,7 +1266,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
self._close_writer()
self._writer_closed_for_reading = False
self._current_file_start_frame = global_frame_index
ep_dict["data/chunk_index"] = chunk_idx
ep_dict["data/file_index"] = file_idx
@@ -1515,11 +1469,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.image_transforms = None
obj.delta_timestamps = None
obj.delta_indices = None
obj._absolute_to_relative_idx = None
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj.writer = None
obj.latest_episode = None
obj._current_file_start_frame = None
# Initialize tracking for incremental recording
obj._lazy_loading = False
obj._recorded_frames = 0

View File

@@ -1,151 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SARM Temporal Sampler for reward model training.
Samples frames uniformly from episodes for SARM's 9-frame symmetric pattern:
- 1 initial frame + 4 frames before + current + 3 frames after
Boundary handling: clamp to first/last frame when indices go out of bounds.
This enables truly uniform sampling across entire episodes.
"""
import logging
from typing import Iterator, Optional
import numpy as np
import torch
from torch.utils.data import Sampler
import random
class SARMTemporalSampler(Sampler):
"""
Temporal sampler for SARM reward model training with symmetric/bidirectional sampling.
SARM uses 9 frames per sample:
- Frame 0: Initial frame of the episode (always frame 0)
- Frames 1-8: Symmetric context around current frame
Pattern: [t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
Boundary handling:
- Early frames: backward indices clamp to 0 (e.g., [0,0,0,5,35,65,95,125])
- Late frames: forward indices clamp to last frame (e.g., [850,880,910,940,970,1000,1000,1000])
This enables truly uniform sampling across entire episodes.
Args:
dataset_from_index: Start indices of episodes (global dataset indices)
dataset_to_index: End indices of episodes (global dataset indices)
frame_gap: Gap between consecutive frames (default: 30 = 1 second at 30fps)
shuffle: Whether to shuffle sampling order
seed: Random seed for reproducibility
samples_per_epoch: Number of samples per epoch (default: 6400)
min_episode_length: Minimum episode length to include (default: 1)
"""
def __init__(
self,
dataset_from_index: np.ndarray,
dataset_to_index: np.ndarray,
frame_gap: int = 30,
shuffle: bool = True,
seed: Optional[int] = None,
samples_per_epoch: int = 6400,
min_episode_length: int = 1,
):
self.dataset_from_index = np.array(dataset_from_index)
self.dataset_to_index = np.array(dataset_to_index)
self.frame_gap = frame_gap
self.shuffle = shuffle
self.samples_per_epoch = samples_per_epoch
self.min_episode_length = min_episode_length
if seed is not None:
self.seed = seed
random.seed(seed)
np.random.seed(seed)
self.generator = torch.Generator().manual_seed(seed)
else:
self.generator = torch.Generator()
# Compute valid episodes and sampling positions (ALL frames for uniform sampling)
self._compute_valid_positions()
logging.info(
f"SARMTemporalSampler: {len(self.valid_episodes)} valid episodes, "
f"{len(self.all_valid_positions)} positions (uniform sampling), "
f"{self.samples_per_epoch} samples per epoch, "
f"frame_gap={frame_gap}, symmetric bidirectional pattern"
)
def _compute_valid_positions(self):
"""Compute valid episodes and ALL sampling positions for uniform sampling.
With symmetric bidirectional sampling, we can sample from ANY frame:
- Early frames: backward indices clamp to first frame
- Late frames: forward indices clamp to last frame
"""
self.valid_episodes = []
self.all_valid_positions = []
for ep_idx in range(len(self.dataset_from_index)):
ep_start = self.dataset_from_index[ep_idx]
ep_end = self.dataset_to_index[ep_idx]
episode_length = ep_end - ep_start
# Include all episodes with at least min_episode_length frames
if episode_length >= self.min_episode_length:
self.valid_episodes.append((ep_idx, ep_start, ep_end))
# Include ALL positions in the episode (truly uniform sampling)
for pos in range(ep_start, ep_end):
self.all_valid_positions.append(pos)
self.valid_episodes = np.array(self.valid_episodes)
self.all_valid_positions = np.array(self.all_valid_positions)
if len(self.all_valid_positions) == 0:
raise ValueError(
f"No valid sampling positions found! "
f"Check that episodes have at least {self.min_episode_length} frames."
)
def __len__(self) -> int:
return self.samples_per_epoch
def __iter__(self) -> Iterator[int]:
"""
Yields global dataset indices for uniform sampling across episodes.
Each yielded index represents the "current frame" position.
The dataset's observation_delta_indices then handles loading:
- Frame 0: Episode initial frame (via large negative delta clamping)
- Frames 1-8: Symmetric context around current frame (with boundary clamping)
For early frames: backward indices clamp to first frame (progress ~0%)
For late frames: forward indices clamp to last frame (progress ~100%)
"""
if self.shuffle:
# Randomly sample from all valid positions
for _ in range(self.samples_per_epoch):
idx = np.random.randint(0, len(self.all_valid_positions))
yield int(self.all_valid_positions[idx])
else:
# Sequential sampling with wrap-around
for i in range(self.samples_per_epoch):
idx = i % len(self.all_valid_positions)
yield int(self.all_valid_positions[idx])

View File

@@ -206,11 +206,6 @@ class ImageTransformsConfig:
type="SharpnessJitter",
kwargs={"sharpness": (0.5, 1.5)},
),
"affine": ImageTransformConfig(
weight=1.0,
type="RandomAffine",
kwargs={"degrees": (-5.0, 5.0), "translate": (0.05, 0.05)},
),
}
)
@@ -222,8 +217,6 @@ def make_transform_from_config(cfg: ImageTransformConfig):
return v2.ColorJitter(**cfg.kwargs)
elif cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
elif cfg.type == "RandomAffine":
return v2.RandomAffine(**cfg.kwargs)
else:
raise ValueError(f"Transform '{cfg.type}' is not valid.")

View File

@@ -28,7 +28,6 @@ import numpy as np
import packaging.version
import pandas
import pandas as pd
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq
import torch
from datasets import Dataset
@@ -104,9 +103,7 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -
return chunk_idx, file_idx
def load_nested_dataset(
pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None
) -> Dataset:
def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) -> Dataset:
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
Concatenate all pyarrow references to return HF Dataset format
@@ -114,26 +111,15 @@ def load_nested_dataset(
Args:
pq_dir: Directory containing parquet files
features: Optional features schema to ensure consistent loading of complex types like images
episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency.
"""
paths = sorted(pq_dir.glob("*/*.parquet"))
if len(paths) == 0:
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
with SuppressProgressBars():
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
# PyArrow loads the entire dataset into memory
if episodes is None:
return Dataset.from_parquet([str(path) for path in paths], features=features)
arrow_dataset = pa_ds.dataset(paths, format="parquet")
filter_expr = pa_ds.field("episode_index").isin(episodes)
table = arrow_dataset.to_table(filter=filter_expr)
if features is not None:
table = table.cast(features.arrow_schema)
return Dataset(table)
datasets = Dataset.from_parquet([str(path) for path in paths], features=features)
return datasets
def get_parquet_num_frames(parquet_path: str | Path) -> int:

View File

@@ -98,7 +98,7 @@ OLD
videos/chunk-000/CAMERA/episode_000000.mp4
NEW
videos/CAMERA/chunk-000/file_000.mp4
videos/chunk-000/file_000.mp4
-------------------------
OLD
episodes.jsonl

View File

@@ -342,8 +342,8 @@ def encode_video_frames(
# Define video output frame size (assuming all input frames are the same size)
if len(input_list) == 0:
raise FileNotFoundError(f"No images found in {imgs_dir}.")
with Image.open(input_list[0]) as dummy_image:
width, height = dummy_image.size
dummy_image = Image.open(input_list[0])
width, height = dummy_image.size
# Define video codec options
video_options = {}
@@ -373,12 +373,11 @@ def encode_video_frames(
# Loop through input frames and encode them
for input_data in input_list:
with Image.open(input_data) as input_image:
input_image = input_image.convert("RGB")
input_frame = av.VideoFrame.from_image(input_image)
packet = output_stream.encode(input_frame)
if packet:
output.mux(packet)
input_image = Image.open(input_data).convert("RGB")
input_frame = av.VideoFrame.from_image(input_image)
packet = output_stream.encode(input_frame)
if packet:
output.mux(packet)
# Flush the encoder
packet = output_stream.encode()

View File

@@ -21,22 +21,7 @@ import draccus
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.robots import RobotConfig
from lerobot.teleoperators.config import TeleoperatorConfig
from lerobot.utils.constants import (
ACTION,
LIBERO_KEY_EEF_MAT,
LIBERO_KEY_EEF_POS,
LIBERO_KEY_EEF_QUAT,
LIBERO_KEY_GRIPPER_QPOS,
LIBERO_KEY_GRIPPER_QVEL,
LIBERO_KEY_JOINTS_POS,
LIBERO_KEY_JOINTS_VEL,
LIBERO_KEY_PIXELS_AGENTVIEW,
LIBERO_KEY_PIXELS_EYE_IN_HAND,
OBS_ENV_STATE,
OBS_IMAGE,
OBS_IMAGES,
OBS_STATE,
)
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
@dataclass
@@ -52,16 +37,6 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
def type(self) -> str:
return self.get_choice_name(self.__class__)
@property
def package_name(self) -> str:
"""Package name to import if environment not found in gym registry"""
return f"gym_{self.type}"
@property
def gym_id(self) -> str:
"""ID string used in gym.make() to instantiate the environment"""
return f"{self.package_name}/{self.task}"
@property
@abc.abstractmethod
def gym_kwargs(self) -> dict:
@@ -261,61 +236,28 @@ class LiberoEnv(EnvConfig):
features_map: dict[str, str] = field(
default_factory=lambda: {
ACTION: ACTION,
LIBERO_KEY_EEF_POS: f"{OBS_STATE}.eef_pos",
LIBERO_KEY_EEF_QUAT: f"{OBS_STATE}.eef_quat",
LIBERO_KEY_EEF_MAT: f"{OBS_STATE}.eef_mat",
LIBERO_KEY_GRIPPER_QPOS: f"{OBS_STATE}.gripper_qpos",
LIBERO_KEY_GRIPPER_QVEL: f"{OBS_STATE}.gripper_qvel",
LIBERO_KEY_JOINTS_POS: f"{OBS_STATE}.joint_pos",
LIBERO_KEY_JOINTS_VEL: f"{OBS_STATE}.joint_vel",
LIBERO_KEY_PIXELS_AGENTVIEW: f"{OBS_IMAGES}.image",
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
"agent_pos": OBS_STATE,
"pixels/agentview_image": f"{OBS_IMAGES}.image",
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
}
)
def __post_init__(self):
if self.obs_type == "pixels":
self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
self.features["pixels/agentview_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
elif self.obs_type == "pixels_agent_pos":
self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
self.features["pixels/agentview_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
)
self.features[LIBERO_KEY_EEF_POS] = PolicyFeature(
type=FeatureType.STATE,
shape=(3,),
)
self.features[LIBERO_KEY_EEF_QUAT] = PolicyFeature(
type=FeatureType.STATE,
shape=(4,),
)
self.features[LIBERO_KEY_EEF_MAT] = PolicyFeature(
type=FeatureType.STATE,
shape=(3, 3),
)
self.features[LIBERO_KEY_GRIPPER_QPOS] = PolicyFeature(
type=FeatureType.STATE,
shape=(2,),
)
self.features[LIBERO_KEY_GRIPPER_QVEL] = PolicyFeature(
type=FeatureType.STATE,
shape=(2,),
)
self.features[LIBERO_KEY_JOINTS_POS] = PolicyFeature(
type=FeatureType.STATE,
shape=(7,),
)
self.features[LIBERO_KEY_JOINTS_VEL] = PolicyFeature(
type=FeatureType.STATE,
shape=(7,),
)
else:
raise ValueError(f"Unsupported obs_type: {self.obs_type}")

View File

@@ -14,16 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import Any
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -37,60 +31,16 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
raise ValueError(f"Policy type '{env_type}' is not available.")
def make_env_pre_post_processors(
env_cfg: EnvConfig,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
]:
"""
Create preprocessor and postprocessor pipelines for environment observations.
This function creates processor pipelines that transform raw environment
observations and actions. By default, it returns identity processors that do nothing.
For specific environments like LIBERO, it adds environment-specific processing steps.
Args:
env_cfg: The configuration of the environment.
Returns:
A tuple containing:
- preprocessor: Pipeline that processes environment observations
- postprocessor: Pipeline that processes environment outputs (currently identity)
"""
# Preprocessor and Postprocessor steps are Identity for most environments
preprocessor_steps: list[ProcessorStep] = []
postprocessor_steps: list[ProcessorStep] = []
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
preprocessor_steps.append(LiberoProcessorStep())
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
return preprocessor, postprocessor
def make_env(
cfg: EnvConfig | str,
n_envs: int = 1,
use_async_envs: bool = False,
hub_cache_dir: str | None = None,
trust_remote_code: bool = False,
cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
"""Makes a gym vector environment according to the config or Hub reference.
"""Makes a gym vector environment according to the config.
Args:
cfg (EnvConfig | str): Either an `EnvConfig` object describing the environment to build locally,
or a Hugging Face Hub repository identifier (e.g. `"username/repo"`). In the latter case,
the repo must include a Python file (usually `env.py`).
cfg (EnvConfig): the config of the environment to instantiate.
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
use_async_envs (bool, optional): Whether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
False.
hub_cache_dir (str | None): Optional cache path for downloaded hub files.
trust_remote_code (bool): **Explicit consent** to execute remote code from the Hub.
Default False — must be set to True to import/exec hub `env.py`.
Raises:
ValueError: if n_envs < 1
@@ -103,21 +53,6 @@ def make_env(
- For single-task environments: a single suite entry (cfg.type) with task_id=0.
"""
# if user passed a hub id string (e.g., "username/repo", "username/repo@main:env.py")
# simplified: only support hub-provided `make_env`
if isinstance(cfg, str):
# _download_hub_file will raise the same RuntimeError if trust_remote_code is False
repo_id, file_path, local_file, revision = _download_hub_file(cfg, trust_remote_code, hub_cache_dir)
# import and surface clear import errors
module = _import_hub_module(local_file, repo_id)
# call the hub-provided make_env
raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs)
# normalize the return into {suite: {task_id: vec_env}}
return _normalize_hub_result(raw_result)
if n_envs < 1:
raise ValueError("`n_envs` must be at least 1")
@@ -149,24 +84,17 @@ def make_env(
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
)
package_name = f"gym_{cfg.type}"
try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
raise e
if cfg.gym_id not in gym_registry:
print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
try:
importlib.import_module(cfg.package_name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
f"Please install it or check PYTHONPATH."
) from e
if cfg.gym_id not in gym_registry:
raise gym.error.NameNotFound(
f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
)
gym_handle = f"{package_name}/{cfg.task}"
def _make_one():
return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)

View File

@@ -28,6 +28,7 @@ import torch
from gymnasium import spaces
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.transform_utils import quat2axisangle
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
@@ -174,36 +175,11 @@ class LiberoEnv(gym.Env):
self.observation_space = spaces.Dict(
{
"pixels": spaces.Dict(images),
"robot_state": spaces.Dict(
{
"eef": spaces.Dict(
{
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float64),
"quat": spaces.Box(
low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64
),
"mat": spaces.Box(
low=-np.inf, high=np.inf, shape=(3, 3), dtype=np.float64
),
}
),
"gripper": spaces.Dict(
{
"qpos": spaces.Box(
low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64
),
"qvel": spaces.Box(
low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64
),
}
),
"joints": spaces.Dict(
{
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
"vel": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
}
),
}
"agent_pos": spaces.Box(
low=AGENT_POS_LOW,
high=AGENT_POS_HIGH,
shape=(OBS_STATE_DIM,),
dtype=np.float64,
),
}
)
@@ -215,7 +191,6 @@ class LiberoEnv(gym.Env):
def render(self):
raw_obs = self._env.env._get_observations()
image = self._format_raw_obs(raw_obs)["pixels"]["image"]
image = image[::-1, ::-1] # flip both H and W for visualization
return image
def _make_envs_task(self, task_suite: Any, task_id: int = 0):
@@ -237,48 +212,23 @@ class LiberoEnv(gym.Env):
images = {}
for camera_name in self.camera_name:
image = raw_obs[camera_name]
image = image[::-1, ::-1] # rotate 180 degrees
images[self.camera_name_mapping[camera_name]] = image
eef_pos = raw_obs.get("robot0_eef_pos")
eef_quat = raw_obs.get("robot0_eef_quat")
# rotation matrix from controller
eef_mat = self._env.robots[0].controller.ee_ori_mat if eef_pos is not None else None
gripper_qpos = raw_obs.get("robot0_gripper_qpos")
gripper_qvel = raw_obs.get("robot0_gripper_qvel")
joint_pos = raw_obs.get("robot0_joint_pos")
joint_vel = raw_obs.get("robot0_joint_vel")
obs = {
"pixels": images,
"robot_state": {
"eef": {
"pos": eef_pos, # (3,)
"quat": eef_quat, # (4,)
"mat": eef_mat, # (3, 3)
},
"gripper": {
"qpos": gripper_qpos, # (2,)
"qvel": gripper_qvel, # (2,)
},
"joints": {
"pos": joint_pos, # (7,)
"vel": joint_vel, # (7,)
},
},
}
state = np.concatenate(
(
raw_obs["robot0_eef_pos"],
quat2axisangle(raw_obs["robot0_eef_quat"]),
raw_obs["robot0_gripper_qpos"],
)
)
agent_pos = state
if self.obs_type == "pixels":
return {"pixels": images.copy()}
if self.obs_type == "pixels_agent_pos":
# Validate required fields are present
if eef_pos is None or eef_quat is None or gripper_qpos is None:
raise ValueError(
f"Missing required robot state fields in raw observation. "
f"Got eef_pos={eef_pos is not None}, eef_quat={eef_quat is not None}, "
f"gripper_qpos={gripper_qpos is not None}"
)
return obs
return {
"pixels": images.copy(),
"agent_pos": agent_pos,
}
raise NotImplementedError(
f"The observation type '{self.obs_type}' is not supported in LiberoEnv. "
"Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
@@ -405,10 +355,12 @@ def create_libero_envs(
print(f"Restricting to task_ids={task_ids_filter}")
out: dict[str, dict[int, Any]] = defaultdict(dict)
for suite_name in suite_names:
suite = _get_suite(suite_name)
total = len(suite.tasks)
selected = _select_task_ids(total, task_ids_filter)
if not selected:
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")

View File

@@ -13,8 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import os
import warnings
from collections.abc import Mapping, Sequence
from functools import singledispatch
@@ -24,27 +22,14 @@ import einops
import gymnasium as gym
import numpy as np
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.utils.utils import get_channel_first_image_shape
def _convert_nested_dict(d):
result = {}
for k, v in d.items():
if isinstance(v, dict):
result[k] = _convert_nested_dict(v)
elif isinstance(v, np.ndarray):
result[k] = torch.from_numpy(v)
else:
result[k] = v
return result
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
"""Convert environment observation to LeRobot format observation.
@@ -90,14 +75,12 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return_observations[OBS_ENV_STATE] = env_state
if "agent_pos" in observations:
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations[OBS_STATE] = agent_pos
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations[OBS_STATE] = agent_pos
if "robot_state" in observations:
return_observations[f"{OBS_STR}.robot_state"] = _convert_nested_dict(observations["robot_state"])
return return_observations
@@ -212,132 +195,3 @@ def _(envs: Sequence) -> None:
@close_envs.register
def _(env: gym.Env) -> None:
_close_single_env(env)
# helper to safely load a python file as a module
def _load_module_from_path(path: str, module_name: str | None = None):
module_name = module_name or f"hub_env_{os.path.basename(path).replace('.', '_')}"
spec = importlib.util.spec_from_file_location(module_name, path)
if spec is None:
raise ImportError(f"Could not load module spec for {module_name} from {path}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module
# helper to parse hub string (supports "user/repo", "user/repo@rev", optional path)
# examples:
# "user/repo" -> will look for env.py at repo root
# "user/repo@main:envs/my_env.py" -> explicit revision and path
def _parse_hub_url(hub_uri: str):
# very small parser: [repo_id][@revision][:path]
# repo_id is required (user/repo or org/repo)
revision = None
file_path = "env.py"
if "@" in hub_uri:
repo_and_rev, *rest = hub_uri.split(":", 1)
repo_id, rev = repo_and_rev.split("@", 1)
revision = rev
if rest:
file_path = rest[0]
else:
repo_id, *rest = hub_uri.split(":", 1)
if rest:
file_path = rest[0]
return repo_id, revision, file_path
def _download_hub_file(
cfg_str: str,
trust_remote_code: bool,
hub_cache_dir: str | None,
) -> tuple[str, str, str, str]:
"""
Parse `cfg_str` (hub URL), enforce `trust_remote_code`, and return
(repo_id, file_path, local_file, revision).
"""
if not trust_remote_code:
raise RuntimeError(
f"Refusing to execute remote code from the Hub for '{cfg_str}'. "
"Executing hub env modules runs arbitrary Python code from third-party repositories. "
"If you trust this repo and understand the risks, call `make_env(..., trust_remote_code=True)` "
"and prefer pinning to a specific revision: 'user/repo@<commit-hash>:env.py'."
)
repo_id, revision, file_path = _parse_hub_url(cfg_str)
try:
local_file = hf_hub_download(
repo_id=repo_id, filename=file_path, revision=revision, cache_dir=hub_cache_dir
)
except Exception as e:
# fallback to snapshot download
snapshot_dir = snapshot_download(repo_id=repo_id, revision=revision, cache_dir=hub_cache_dir)
local_file = os.path.join(snapshot_dir, file_path)
if not os.path.exists(local_file):
raise FileNotFoundError(
f"Could not find {file_path} in repository {repo_id}@{revision or 'main'}"
) from e
return repo_id, file_path, local_file, revision
def _import_hub_module(local_file: str, repo_id: str) -> Any:
"""
Import the downloaded file as a module and surface helpful import error messages.
"""
module_name = f"hub_env_{repo_id.replace('/', '_')}"
try:
module = _load_module_from_path(local_file, module_name=module_name)
except ModuleNotFoundError as e:
missing = getattr(e, "name", None) or str(e)
raise ModuleNotFoundError(
f"Hub env '{repo_id}:{os.path.basename(local_file)}' failed to import because the dependency "
f"'{missing}' is not installed locally.\n\n"
) from e
except ImportError as e:
raise ImportError(
f"Failed to load hub env module '{repo_id}:{os.path.basename(local_file)}'. Import error: {e}\n\n"
) from e
return module
def _call_make_env(module: Any, n_envs: int, use_async_envs: bool) -> Any:
"""
Ensure module exposes make_env and call it.
"""
if not hasattr(module, "make_env"):
raise AttributeError(
f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool)`."
)
entry_fn = module.make_env
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs)
def _normalize_hub_result(result: Any) -> dict[str, dict[int, gym.vector.VectorEnv]]:
"""
Normalize possible return types from hub `make_env` into the mapping:
{ suite_name: { task_id: vector_env } }
Accepts:
- dict (assumed already correct)
- gym.vector.VectorEnv
- gym.Env (will be wrapped into SyncVectorEnv)
"""
if isinstance(result, dict):
return result
# VectorEnv: use its spec.id if available
if isinstance(result, gym.vector.VectorEnv):
suite_name = getattr(result, "spec", None) and getattr(result.spec, "id", None) or "hub_env"
return {suite_name: {0: result}}
# Single Env: wrap into SyncVectorEnv
if isinstance(result, gym.Env):
vec = gym.vector.SyncVectorEnv([lambda: result])
suite_name = getattr(result, "spec", None) and getattr(result.spec, "id", None) or "hub_env"
return {suite_name: {0: vec}}
raise ValueError(
"Hub `make_env` must return either a mapping {suite: {task_id: vec_env}}, "
"a gym.vector.VectorEnv, or a single gym.Env."
)

View File

@@ -22,18 +22,18 @@ class RobotKinematics:
self,
urdf_path: str,
target_frame_name: str = "gripper_frame_link",
joint_names: list[str] | None = None,
joint_names: list[str] = None,
):
"""
Initialize placo-based kinematics solver.
Args:
urdf_path (str): Path to the robot URDF file
target_frame_name (str): Name of the end-effector frame in the URDF
joint_names (list[str] | None): List of joint names to use for the kinematics solver
urdf_path: Path to the robot URDF file
target_frame_name: Name of the end-effector frame in the URDF
joint_names: List of joint names to use for the kinematics solver
"""
try:
import placo # type: ignore[import-not-found] # C++ library with Python bindings, no type stubs available. TODO: Create stub file or request upstream typing support.
import placo
except ImportError as e:
raise ImportError(
"placo is required for RobotKinematics. "
@@ -52,7 +52,7 @@ class RobotKinematics:
# Initialize frame task for IK
self.tip_frame = self.solver.add_frame_task(self.target_frame_name, np.eye(4))
def forward_kinematics(self, joint_pos_deg: np.ndarray) -> np.ndarray:
def forward_kinematics(self, joint_pos_deg):
"""
Compute forward kinematics for given joint configuration given the target frame name in the constructor.
@@ -77,12 +77,8 @@ class RobotKinematics:
return self.robot.get_T_world_frame(self.target_frame_name)
def inverse_kinematics(
self,
current_joint_pos: np.ndarray,
desired_ee_pose: np.ndarray,
position_weight: float = 1.0,
orientation_weight: float = 0.01,
) -> np.ndarray:
self, current_joint_pos, desired_ee_pose, position_weight=1.0, orientation_weight=0.01
):
"""
Compute inverse kinematics using placo solver.

View File

@@ -60,7 +60,7 @@ class OperatingMode(Enum):
# This mode controls position. This mode is identical to the Multi-turn Position Control from existing
# DYNAMIXEL. 512 turns are supported(-256[rev] ~ 256[rev]). This mode is ideal for multi-turn wrists or
# conveyor systems or a system that requires an additional reduction gear. Note that Max Position
# conveyer systems or a system that requires an additional reduction gear. Note that Max Position
# Limit(48), Min Position Limit(52) are not used on Extended Position Control Mode.
EXTENDED_POSITION = 4

View File

@@ -206,12 +206,8 @@ MODEL_BAUDRATE_TABLE = {
# Sign-Magnitude encoding bits
STS_SMS_SERIES_ENCODINGS_TABLE = {
"Homing_Offset": 11,
"Goal_Position": 15,
"Goal_Velocity": 15,
"Goal_Speed": 15,
"Present_Position": 15,
"Present_Velocity": 15,
"Present_Speed": 15,
}
MODEL_ENCODING_TABLE = {

View File

@@ -14,7 +14,6 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .groot.configuration_groot import GrootConfig as GrootConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
@@ -30,5 +29,4 @@ __all__ = [
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
"GrootConfig",
]

View File

@@ -626,8 +626,8 @@ class ACTDecoderLayer(nn.Module):
x: (Decoder Sequence, Batch, Channel) tensor of input tokens.
encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are
cross-attending with.
encoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
decoder_pos_embed: (DS, 1, C) positional embedding for the queries (from the decoder).
decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).
encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder).
Returns:
(DS, B, C) tensor of decoder output features.
"""

View File

@@ -30,16 +30,13 @@ from lerobot.envs.configs import EnvConfig
from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import validate_visual_features_consistency
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor.converters import (
@@ -104,14 +101,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "sarm":
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
return SARMRewardModel
elif name == "groot":
from lerobot.policies.groot.modeling_groot import GrootPolicy
return GrootPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -153,8 +142,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return SmolVLAConfig(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
elif policy_type == "groot":
return GrootConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
@@ -212,27 +199,6 @@ def make_pre_post_processors(
policy configuration type.
"""
if pretrained_path:
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
if isinstance(policy_cfg, GrootConfig):
# GROOT handles normalization in groot_pack_inputs_v3 step
# Need to override both stats AND normalize_min_max since saved config might be empty
preprocessor_overrides = {}
postprocessor_overrides = {}
preprocessor_overrides["groot_pack_inputs_v3"] = {
"stats": kwargs.get("dataset_stats"),
"normalize_min_max": True,
}
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
env_action_dim = policy_cfg.output_features["action"].shape[0]
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
"stats": kwargs.get("dataset_stats"),
"normalize_min_max": True,
"env_action_dim": env_action_dim,
}
kwargs["preprocessor_overrides"] = preprocessor_overrides
kwargs["postprocessor_overrides"] = postprocessor_overrides
return (
PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
@@ -327,22 +293,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SARMConfig):
from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
processors = make_sarm_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(policy_cfg, GrootConfig):
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
processors = make_groot_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
@@ -353,7 +303,6 @@ def make_policy(
cfg: PreTrainedConfig,
ds_meta: LeRobotDatasetMetadata | None = None,
env_cfg: EnvConfig | None = None,
rename_map: dict[str, str] | None = None,
) -> PreTrainedPolicy:
"""
Instantiate a policy model.
@@ -370,8 +319,6 @@ def make_policy(
statistics for normalization layers.
env_cfg: Environment configuration used to infer feature shapes and types.
One of `ds_meta` or `env_cfg` must be provided.
rename_map: Optional mapping of dataset or environment feature keys to match
expected policy feature names (e.g., `"left"` → `"camera1"`).
Returns:
An instantiated and device-placed policy model.
@@ -413,18 +360,9 @@ def make_policy(
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
features = env_to_policy_features(env_cfg)
if not cfg.output_features:
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
if ds_meta is not None and hasattr(ds_meta, 'stats'):
kwargs["dataset_stats"] = ds_meta.stats
if ds_meta is not None:
kwargs["dataset_meta"] = ds_meta
if cfg.pretrained_path:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
@@ -440,8 +378,4 @@ def make_policy(
# policy = torch.compile(policy, mode="reduce-overhead")
if not rename_map:
validate_visual_features_consistency(cfg, features)
# TODO: (jadechoghari) - add a check_state(cfg, features) and check_action(cfg, features)
return policy

View File

@@ -1 +0,0 @@
../../../../docs/source/policy_groot_README.md

View File

@@ -1,54 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
def swish(x):
return x * torch.sigmoid(x)
class SinusoidalPositionalEncoding(nn.Module):
"""
Produces a sinusoidal encoding of shape (B, T, w)
given timesteps of shape (B, T).
"""
def __init__(self, embedding_dim):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps):
# timesteps: shape (B, T)
# We'll compute sin/cos frequencies across dim T
timesteps = timesteps.float() # ensure float
b, t = timesteps.shape
device = timesteps.device
half_dim = self.embedding_dim // 2
# typical log space frequencies for sinusoidal encoding
exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
torch.log(torch.tensor(10000.0)) / half_dim
)
# Expand timesteps to (B, T, 1) then multiply
freqs = timesteps.unsqueeze(-1) * exponent.exp() # (B, T, half_dim)
sin = torch.sin(freqs)
cos = torch.cos(freqs)
enc = torch.cat([sin, cos], dim=-1) # (B, T, w)
return enc

View File

@@ -1,370 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F # noqa: N812
from diffusers import ConfigMixin, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import (
SinusoidalPositionalEmbedding,
TimestepEmbedding,
Timesteps,
)
from torch import nn
class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim, compute_dtype=torch.float32):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(self, timesteps):
dtype = next(self.parameters()).dtype
timesteps_proj = self.time_proj(timesteps).to(dtype)
timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
return timesteps_emb
class AdaLayerNorm(nn.Module):
def __init__(
self,
embedding_dim: int,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
chunk_dim: int = 0,
):
super().__init__()
self.chunk_dim = chunk_dim
output_dim = embedding_dim * 2
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
def forward(
self,
x: torch.Tensor,
temb: torch.Tensor | None = None,
) -> torch.Tensor:
temb = self.linear(self.silu(temb))
scale, shift = temb.chunk(2, dim=1)
x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
return x
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: int | None = None,
activation_fn: str = "geglu",
attention_bias: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: str | None = None,
num_positional_embeddings: int | None = None,
ff_inner_dim: int | None = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.dropout = dropout
self.cross_attention_dim = cross_attention_dim
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.norm_elementwise_affine = norm_elementwise_affine
self.positional_embeddings = positional_embeddings
self.num_positional_embeddings = num_positional_embeddings
self.norm_type = norm_type
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if norm_type == "ada_norm":
self.norm1 = AdaLayerNorm(dim)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
)
# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
if final_dropout:
self.final_dropout = nn.Dropout(dropout)
else:
self.final_dropout = None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
encoder_hidden_states: torch.Tensor | None = None,
encoder_attention_mask: torch.Tensor | None = None,
temb: torch.LongTensor | None = None,
) -> torch.Tensor:
# 0. Self-Attention
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm1(hidden_states, temb)
else:
norm_hidden_states = self.norm1(hidden_states)
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
# encoder_attention_mask=encoder_attention_mask,
)
if self.final_dropout:
attn_output = self.final_dropout(attn_output)
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class DiT(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 8,
attention_head_dim: int = 64,
output_dim: int = 26,
num_layers: int = 12,
dropout: float = 0.1,
attention_bias: bool = True,
activation_fn: str = "gelu-approximate",
num_embeds_ada_norm: int | None = 1000,
upcast_attention: bool = False,
norm_type: str = "ada_norm",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
max_num_positional_embeddings: int = 512,
compute_dtype=torch.float32,
final_dropout: bool = True,
positional_embeddings: str | None = "sinusoidal",
interleave_self_attention=False,
cross_attention_dim: int | None = None,
):
super().__init__()
self.attention_head_dim = attention_head_dim
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.gradient_checkpointing = False
# Timestep encoder
self.timestep_encoder = TimestepEncoder(
embedding_dim=self.inner_dim, compute_dtype=self.config.compute_dtype
)
all_blocks = []
for idx in range(self.config.num_layers):
use_self_attn = idx % 2 == 1 and interleave_self_attention
curr_cross_attention_dim = cross_attention_dim if not use_self_attn else None
all_blocks += [
BasicTransformerBlock(
self.inner_dim,
self.config.num_attention_heads,
self.config.attention_head_dim,
dropout=self.config.dropout,
activation_fn=self.config.activation_fn,
attention_bias=self.config.attention_bias,
upcast_attention=self.config.upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=self.config.norm_elementwise_affine,
norm_eps=self.config.norm_eps,
positional_embeddings=positional_embeddings,
num_positional_embeddings=self.config.max_num_positional_embeddings,
final_dropout=final_dropout,
cross_attention_dim=curr_cross_attention_dim,
)
]
self.transformer_blocks = nn.ModuleList(all_blocks)
# Output blocks
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
print(
"Total number of DiT parameters: ",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
def forward(
self,
hidden_states: torch.Tensor, # Shape: (B, T, D)
encoder_hidden_states: torch.Tensor, # Shape: (B, S, D)
timestep: torch.LongTensor | None = None,
encoder_attention_mask: torch.Tensor | None = None,
return_all_hidden_states: bool = False,
):
# Encode timesteps
temb = self.timestep_encoder(timestep)
# Process through transformer blocks - single pass through the blocks
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
all_hidden_states = [hidden_states]
# Process through transformer blocks
for idx, block in enumerate(self.transformer_blocks):
if idx % 2 == 1 and self.config.interleave_self_attention:
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
temb=temb,
)
else:
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=None,
temb=temb,
)
all_hidden_states.append(hidden_states)
# Output processing
conditioning = temb
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
if return_all_hidden_states:
return self.proj_out_2(hidden_states), all_hidden_states
else:
return self.proj_out_2(hidden_states)
class SelfAttentionTransformer(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 8,
attention_head_dim: int = 64,
output_dim: int = 26,
num_layers: int = 12,
dropout: float = 0.1,
attention_bias: bool = True,
activation_fn: str = "gelu-approximate",
num_embeds_ada_norm: int | None = 1000,
upcast_attention: bool = False,
max_num_positional_embeddings: int = 512,
compute_dtype=torch.float32,
final_dropout: bool = True,
positional_embeddings: str | None = "sinusoidal",
interleave_self_attention=False,
):
super().__init__()
self.attention_head_dim = attention_head_dim
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.gradient_checkpointing = False
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
self.inner_dim,
self.config.num_attention_heads,
self.config.attention_head_dim,
dropout=self.config.dropout,
activation_fn=self.config.activation_fn,
attention_bias=self.config.attention_bias,
upcast_attention=self.config.upcast_attention,
positional_embeddings=positional_embeddings,
num_positional_embeddings=self.config.max_num_positional_embeddings,
final_dropout=final_dropout,
)
for _ in range(self.config.num_layers)
]
)
print(
"Total number of SelfAttentionTransformer parameters: ",
sum(p.numel() for p in self.parameters() if p.requires_grad),
)
def forward(
self,
hidden_states: torch.Tensor, # Shape: (B, T, D)
return_all_hidden_states: bool = False,
):
# Process through transformer blocks - single pass through the blocks
hidden_states = hidden_states.contiguous()
all_hidden_states = [hidden_states]
# Process through transformer blocks
for _idx, block in enumerate(self.transformer_blocks):
hidden_states = block(hidden_states)
all_hidden_states.append(hidden_states)
if return_all_hidden_states:
return hidden_states, all_hidden_states
else:
return hidden_states

View File

@@ -1,406 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F # noqa: N812
from torch import nn
from torch.distributions import Beta
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import PretrainedConfig
from transformers.feature_extraction_utils import BatchFeature
else:
PretrainedConfig = object
BatchFeature = None
from lerobot.policies.groot.action_head.action_encoder import (
SinusoidalPositionalEncoding,
swish,
)
from .cross_attention_dit import DiT, SelfAttentionTransformer
class CategorySpecificLinear(nn.Module):
def __init__(self, num_categories, input_dim, hidden_dim):
super().__init__()
self.num_categories = num_categories
# For each category, we have separate weights and biases.
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
def forward(self, x, cat_ids):
selected_w = self.W[cat_ids]
selected_b = self.b[cat_ids]
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
class CategorySpecificMLP(nn.Module):
def __init__(self, num_categories, input_dim, hidden_dim, output_dim):
super().__init__()
self.num_categories = num_categories
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
def forward(self, x, cat_ids):
hidden = F.relu(self.layer1(x, cat_ids))
return self.layer2(hidden, cat_ids)
class MultiEmbodimentActionEncoder(nn.Module):
def __init__(self, action_dim, hidden_size, num_embodiments):
super().__init__()
self.hidden_size = hidden_size
self.num_embodiments = num_embodiments
# W1: R^{w x d}, W2: R^{w x 2w}, W3: R^{w x w}
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size) # (d -> w)
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size) # (2w -> w)
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size) # (w -> w)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions, timesteps, cat_ids):
"""
actions: shape (B, T, action_dim)
timesteps: shape (B,) -- a single scalar per batch item
cat_ids: shape (B,)
returns: shape (B, T, hidden_size)
"""
b, t, _ = actions.shape
# 1) Expand each batch's single scalar time 'tau' across all T steps
# so that shape => (B, T)
# e.g. if timesteps is (B,), replicate across T
if timesteps.dim() == 1 and timesteps.shape[0] == b:
# shape (B,) => (B,T)
timesteps = timesteps.unsqueeze(1).expand(-1, t)
else:
raise ValueError("Expected `timesteps` to have shape (B,) so we can replicate across T.")
# 2) Standard action MLP step for shape => (B, T, w)
a_emb = self.W1(actions, cat_ids)
# 3) Get the sinusoidal encoding (B, T, w)
tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)
# 4) Concat along last dim => (B, T, 2w), then W2 => (B, T, w), swish
x = torch.cat([a_emb, tau_emb], dim=-1)
x = swish(self.W2(x, cat_ids))
# 5) Finally W3 => (B, T, w)
x = self.W3(x, cat_ids)
return x
@dataclass
class FlowmatchingActionHeadConfig(PretrainedConfig):
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
add_pos_embed: bool = field(default=True, metadata={"help": "Whether to add positional embedding"})
model_dtype: str = field(default="float32", metadata={"help": "Model data type."})
diffusion_model_cfg: dict = field(default=None, metadata={"help": "Diffusion model configuration."})
input_embedding_dim: int = field(default=1536, metadata={"help": "Input embedding channel dimension."})
backbone_embedding_dim: int = field(
default=1536, metadata={"help": "Backbone embedding channel dimension."}
)
hidden_size: int = field(default=1024, metadata={"help": "Input embedding dimension."})
max_seq_len: int = field(default=1024, metadata={"help": "Maximum Sequence Length"})
action_dim: int = field(default=None, metadata={"help": "Action dimension."})
action_horizon: int = field(default=None, metadata={"help": "Action horizon."})
noise_beta_alpha: float = field(default=1.5, metadata={"help": ""})
noise_beta_beta: float = field(default=1.0, metadata={"help": ""})
noise_s: float = field(default=0.999, metadata={"help": "Flow matching noise Beta distribution s."})
num_timestep_buckets: int = field(
default=1000, metadata={"help": "Number of timestep discretization buckets."}
)
num_inference_timesteps: int = field(
default=None,
metadata={"help": "Number of inference steps for noise diffusion."},
)
max_num_embodiments: int = field(default=32, metadata={"help": "Number of embodiments."})
tune_projector: bool = field(default=True, metadata={"help": "Whether to tune the projector."})
tune_diffusion_model: bool = field(
default=True, metadata={"help": "Whether to tune the diffusion model."}
)
load_pretrained_det_decode_layer_path: str = field(
default=None, metadata={"help": "Path to pretrained detection model."}
)
detection_coeff: float = field(default=1.0, metadata={"help": "Detection coefficient."})
freeze_decode_layer: bool = field(default=False)
expand_batch: int = field(default=None)
use_vlln: bool = field(default=True)
vl_self_attention_cfg: dict = field(default=None)
num_target_vision_tokens: int = field(default=32, metadata={"help": "Number of target vision tokens."})
def __init__(self, **kwargs):
super().__init__(**kwargs)
for key, value in kwargs.items():
setattr(self, key, value)
class FlowmatchingActionHead(nn.Module):
config_class = FlowmatchingActionHeadConfig
supports_gradient_checkpointing = True
def __init__(
self,
config: FlowmatchingActionHeadConfig,
):
super().__init__()
self.hidden_size = config.hidden_size
self.input_embedding_dim = config.input_embedding_dim
self.model = DiT(**config.diffusion_model_cfg)
self.action_dim = config.action_dim
self.action_horizon = config.action_horizon
self.num_inference_timesteps = config.num_inference_timesteps
self.state_encoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=config.max_state_dim,
hidden_dim=self.hidden_size,
output_dim=self.input_embedding_dim,
)
self.action_encoder = MultiEmbodimentActionEncoder(
action_dim=config.action_dim,
hidden_size=self.input_embedding_dim,
num_embodiments=config.max_num_embodiments,
)
self.action_decoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=self.hidden_size,
hidden_dim=self.hidden_size,
output_dim=self.action_dim,
)
self.future_tokens = nn.Embedding(config.num_target_vision_tokens, self.input_embedding_dim)
nn.init.normal_(self.future_tokens.weight, mean=0.0, std=0.02)
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
self.vl_self_attention = (
SelfAttentionTransformer(**config.vl_self_attention_cfg) if config.use_vlln else nn.Identity()
)
if config.add_pos_embed:
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
self.beta_dist = Beta(config.noise_beta_alpha, config.noise_beta_beta)
self.num_timestep_buckets = config.num_timestep_buckets
self.config = config
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model)
def set_trainable_parameters(self, tune_projector: bool, tune_diffusion_model: bool):
self.tune_projector = tune_projector
self.tune_diffusion_model = tune_diffusion_model
for p in self.parameters():
p.requires_grad = True
if not tune_projector:
self.state_encoder.requires_grad_(False)
self.action_encoder.requires_grad_(False)
self.action_decoder.requires_grad_(False)
if self.config.add_pos_embed:
self.position_embedding.requires_grad_(False)
if not tune_diffusion_model:
self.model.requires_grad_(False)
print(f"Tune action head projector: {self.tune_projector}")
print(f"Tune action head diffusion model: {self.tune_diffusion_model}")
# Check if any parameters are still trainable. If not, print a warning.
if not tune_projector and not tune_diffusion_model:
for name, p in self.named_parameters():
if p.requires_grad:
print(f"Action head trainable parameter: {name}")
if not any(p.requires_grad for p in self.parameters()):
print("Warning: No action head trainable parameters found.")
def set_frozen_modules_to_eval_mode(self):
"""
Huggingface will call model.train() at each training_step. To ensure
the expected behaviors for modules like dropout, batchnorm, etc., we
need to call model.eval() for the frozen modules.
"""
if self.training:
if not self.tune_projector:
self.state_encoder.eval()
self.action_encoder.eval()
self.action_decoder.eval()
if self.config.add_pos_embed:
self.position_embedding.eval()
if not self.tune_diffusion_model:
self.model.eval()
def sample_time(self, batch_size, device, dtype):
sample = self.beta_dist.sample([batch_size]).to(device, dtype=dtype)
return (self.config.noise_s - sample) / self.config.noise_s
def prepare_input(self, batch: dict) -> BatchFeature:
return BatchFeature(data=batch)
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
backbone_features = backbone_output["backbone_features"]
backbone_features = self.vlln(backbone_features)
backbone_features = self.vl_self_attention(backbone_features)
backbone_output["backbone_features"] = backbone_features
return backbone_output
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
# Set frozen modules to eval
self.set_frozen_modules_to_eval_mode()
backbone_output = self.process_backbone_output(backbone_output)
if self.config.expand_batch is not None:
for k, v in backbone_output.items():
ndim = len(v.shape)
factors = [self.config.expand_batch]
while len(factors) < ndim:
factors.append(1)
factors = tuple(factors)
expanded = v.repeat(*factors)
backbone_output[k] = expanded
for k, v in action_input.items():
ndim = len(v.shape)
factors = [self.config.expand_batch]
while len(factors) < ndim:
factors.append(1)
factors = tuple(factors)
expanded = v.repeat(*factors)
action_input[k] = expanded
# Get vision and language embeddings.
vl_embs = backbone_output.backbone_features
device = vl_embs.device
# Get embodiment ID.
embodiment_id = action_input.embodiment_id
# Embed state.
state_features = self.state_encoder(action_input.state, embodiment_id)
# Embed noised action trajectory.
actions = action_input.action
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
t = t[:, None, None] # shape (B,1,1) for broadcast
noisy_trajectory = (1 - t) * noise + t * actions
velocity = actions - noise
# Convert (continuous) t -> discrete if needed
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
# Maybe add position embedding.
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
action_features = action_features + pos_embs
# Join vision, language, state and action embedding along sequence dimension.
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
vl_attn_mask = backbone_output.backbone_attention_mask
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
encoder_attention_mask=vl_attn_mask,
timestep=t_discretized,
return_all_hidden_states=False, # NOTE (YL): not using flare now
)
pred = self.action_decoder(model_output, embodiment_id)
pred_actions = pred[:, -actions.shape[1] :]
# Slice out only the action portion of pred and target.
action_mask = action_input.action_mask
loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
loss = loss.sum() / action_mask.sum()
output_dict = {
"loss": loss,
}
return BatchFeature(data=output_dict)
@torch.no_grad()
def get_action(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
backbone_output = self.process_backbone_output(backbone_output)
# Get vision and language embeddings.
vl_embs = backbone_output.backbone_features
embodiment_id = action_input.embodiment_id
# Embed state.
state_features = self.state_encoder(action_input.state, embodiment_id)
# Set initial actions as the sampled noise.
batch_size = vl_embs.shape[0]
device = vl_embs.device
actions = torch.randn(
size=(batch_size, self.config.action_horizon, self.config.action_dim),
dtype=vl_embs.dtype,
device=device,
)
num_steps = self.num_inference_timesteps
dt = 1.0 / num_steps
# Run denoising steps.
for t in range(num_steps):
t_cont = t / float(num_steps) # e.g. goes 0, 1/N, 2/N, ...
t_discretized = int(t_cont * self.num_timestep_buckets)
# Embed noised action trajectory.
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
# Maybe add position embedding.
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
pos_embs = self.position_embedding(pos_ids).unsqueeze(0)
action_features = action_features + pos_embs
# Join vision, language, state and action embedding along sequence dimension.
future_tokens = self.future_tokens.weight.unsqueeze(0).expand(vl_embs.shape[0], -1, -1)
sa_embs = torch.cat((state_features, future_tokens, action_features), dim=1)
# Run model forward.
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embs,
timestep=timesteps_tensor,
)
pred = self.action_decoder(model_output, embodiment_id)
pred_velocity = pred[:, -self.action_horizon :]
# Update actions using euler integration.
actions = actions + dt * pred_velocity
return BatchFeature(data={"action_pred": actions})
@property
def device(self):
return next(iter(self.parameters())).device
@property
def dtype(self):
return next(iter(self.parameters())).dtype

View File

@@ -1,201 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("groot")
@dataclass
class GrootConfig(PreTrainedConfig):
"""Configuration for Groot policy wrapper."""
# Basic policy settings
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50
# Dimension settings (must match pretrained GR00T model expectations)
# Maximum state dimension. Shorter states will be zero-padded.
max_state_dim: int = 64
# Maximum action dimension. Shorter actions will be zero-padded.
max_action_dim: int = 32
# Normalization (start with identity, adjust as needed)
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Image preprocessing (adjust to match Groot's expected input)
image_size: tuple[int, int] = (224, 224)
# Groot-specific model parameters (from groot_finetune_script.py)
# Path or HuggingFace model ID for the base Groot model
base_model_path: str = "nvidia/GR00T-N1.5-3B"
# HF repo ID (or local path) that hosts vocab.json and merges.txt for Eagle tokenizer.
tokenizer_assets_repo: str = "lerobot/eagle2hg-processor-groot-n1p5"
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
embodiment_tag: str = "new_embodiment"
# Fine-tuning control arguments
# Whether to fine-tune the llm backbone
tune_llm: bool = False
# Whether to fine-tune the vision tower
tune_visual: bool = False
# Whether to fine-tune the projector
tune_projector: bool = True
# Whether to fine-tune the diffusion model
tune_diffusion_model: bool = True
# LoRA parameters (from groot_finetune_script.py)
# Rank for the LORA model. If 0, no LORA will be used.
lora_rank: int = 0
# Alpha value for the LORA model
lora_alpha: int = 16
# Dropout rate for the LORA model
lora_dropout: float = 0.1
# Whether to use the full model for LORA
lora_full_model: bool = False
# Training parameters (matching groot_finetune_script.py)
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-5
warmup_ratio: float = 0.05
use_bf16: bool = True
# Dataset parameters
# Video backend to use for training ('decord' or 'torchvision_av')
video_backend: str = "decord"
# Whether to balance dataset weights in mixture datasets
balance_dataset_weights: bool = True
# Whether to sample trajectories weighted by their length
balance_trajectory_weights: bool = True
# Optional dataset paths for delegating training to Isaac-GR00T runner
dataset_paths: list[str] | None = None
output_dir: str = "./tmp/gr00t"
save_steps: int = 1000
max_steps: int = 10000
batch_size: int = 32
dataloader_num_workers: int = 8
report_to: str = "wandb"
resume: bool = False
def __post_init__(self):
super().__post_init__()
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})"
)
# groot_repo_path is now optional since we ported the components
# No validation needed
def validate_features(self) -> None:
"""Validate and set up input/output features for Groot."""
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
if not image_features:
raise ValueError(
"Groot policy requires at least one visual input feature. "
"No features of type FeatureType.VISUAL found in input_features."
)
if "observation.state" not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,),
)
self.input_features["observation.state"] = state_feature
else:
state_shape = self.input_features["observation.state"].shape
state_dim = state_shape[0] if state_shape else 0
if state_dim > self.max_state_dim:
raise ValueError(
f"State dimension {state_dim} exceeds max_state_dim {self.max_state_dim}. "
f"Either reduce state dimension or increase max_state_dim in config."
)
if "action" not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,),
)
self.output_features["action"] = action_feature
else:
action_shape = self.output_features["action"].shape
action_dim = action_shape[0] if action_shape else 0
if action_dim > self.max_action_dim:
raise ValueError(
f"Action dimension {action_dim} exceeds max_action_dim {self.max_action_dim}. "
f"Either reduce action dimension or increase max_action_dim in config."
)
def get_optimizer_preset(self) -> AdamWConfig:
"""Return optimizer configuration."""
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
"""Return scheduler configuration."""
return CosineDecayWithWarmupSchedulerConfig(
num_warmup_steps=int(10000 * self.warmup_ratio), # 5% warmup by default
num_decay_steps=10000, # Adjust based on training steps
peak_lr=self.optimizer_lr,
decay_lr=self.optimizer_lr * 0.1,
)
@property
def observation_delta_indices(self) -> None:
"""Return indices for delta observations (None for Groot)."""
return None
@property
def action_delta_indices(self) -> list[int]:
"""Return indices for delta actions."""
return list(range(min(self.chunk_size, 16)))
@property
def reward_delta_indices(self) -> None:
"""Return indices for delta rewards (None for Groot)."""
return None

View File

@@ -1,135 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from transformers.configuration_utils import PretrainedConfig
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Eagle25VLConfig(PretrainedConfig):
model_type = "eagle_2_5_vl"
is_composition = True
sub_configs = {"vision_config": SiglipVisionConfig, "text_config": Qwen2Config}
def __init__(
self,
vision_config=None,
text_config=None,
use_backbone_lora=0,
use_llm_lora=0,
pad2square=False,
select_layer=-4,
force_image_size=None,
downsample_ratio=0.5,
template=None,
dynamic_image_size=False,
use_thumbnail=False,
loss_version="v1",
min_dynamic_tiles=1,
max_dynamic_tiles=6,
mlp_checkpoint=False,
initializer_range=0.02,
_attn_implementation="flash_attention_2",
_attn_implementation_autoset=False,
llm_config=None,
image_token_index=None,
use_pixel_shuffle=True,
mlp_connector_layers=2,
**kwargs,
):
super().__init__(**kwargs)
if vision_config is None:
vision_config = {"model_type": "siglip_vision_model"}
logger.info("vision_config is None. Initializing the InternVisionConfig with default values.")
if text_config is None:
text_config = {"architectures": ["Qwen2ForCausalLM"]}
logger.info(
"text_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`)."
)
if vision_config["model_type"] == "siglip_vision_model":
self.vision_config = SiglipVisionConfig(**vision_config)
else:
raise ValueError("Unsupported model_type: {}".format(vision_config["model_type"]))
if text_config["architectures"][0] == "LlamaForCausalLM":
self.text_config = LlamaConfig(**text_config)
elif text_config["architectures"][0] == "Qwen2ForCausalLM":
self.text_config = Qwen2Config(**text_config)
elif text_config["architectures"][0] == "Qwen3ForCausalLM":
self.text_config = Qwen3Config(**text_config)
else:
raise ValueError("Unsupported architecture: {}".format(text_config["architectures"][0]))
self.use_backbone_lora = use_backbone_lora
self.use_llm_lora = use_llm_lora
self.mlp_checkpoint = mlp_checkpoint
self.pad2square = pad2square
self.select_layer = select_layer
self.force_image_size = force_image_size
self.downsample_ratio = downsample_ratio
self.template = template
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail = use_thumbnail
self.loss_version = loss_version
self.initializer_range = initializer_range
self.min_dynamic_tiles = min_dynamic_tiles
self.max_dynamic_tiles = max_dynamic_tiles
self.tie_word_embeddings = self.text_config.tie_word_embeddings
self._attn_implementation = _attn_implementation
self._attn_implementation_autoset = _attn_implementation_autoset
self.image_token_index = image_token_index
self.use_pixel_shuffle = use_pixel_shuffle
self.mlp_connector_layers = mlp_connector_layers
logger.info(f"min_dynamic_tiles: {self.min_dynamic_tiles}")
logger.info(f"max_dynamic_tiles: {self.max_dynamic_tiles}")
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["vision_config"] = self.vision_config.to_dict()
output["text_config"] = self.text_config.to_dict()
output["model_type"] = self.__class__.model_type
output["use_backbone_lora"] = self.use_backbone_lora
output["use_llm_lora"] = self.use_llm_lora
output["pad2square"] = self.pad2square
output["select_layer"] = self.select_layer
output["force_image_size"] = self.force_image_size
output["downsample_ratio"] = self.downsample_ratio
output["template"] = self.template
output["dynamic_image_size"] = self.dynamic_image_size
output["use_thumbnail"] = self.use_thumbnail
output["min_dynamic_tiles"] = self.min_dynamic_tiles
output["max_dynamic_tiles"] = self.max_dynamic_tiles
output["tie_word_embeddings"] = self.tie_word_embeddings
output["_attn_implementation"] = self._attn_implementation
output["_attn_implementation_autoset"] = self._attn_implementation_autoset
output["use_pixel_shuffle"] = self.use_pixel_shuffle
output["mlp_connector_layers"] = self.mlp_connector_layers
return output

View File

@@ -1,504 +0,0 @@
# --------------------------------------------------------
# NVIDIA
# Copyright (c) 2025 NVIDIA
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py
from typing import Optional
from transformers.image_processing_utils import (
BatchFeature,
get_patch_output_size,
)
from transformers.image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
group_images_by_shape,
reorder_images,
)
from transformers.image_utils import (
IMAGENET_STANDARD_MEAN, # 0.5, 0.5, 0.5
IMAGENET_STANDARD_STD, # 0.5, 0.5, 0.5
ChannelDimension,
ImageInput,
PILImageResampling,
SizeDict,
get_image_size,
make_flat_list_of_images,
validate_kwargs,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_v2_available,
)
from transformers.video_utils import VideoInput
if is_torch_available():
import torch
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F # noqa: N812
from transformers.image_utils import pil_torch_interpolation_mapping
else:
from torchvision.transforms import functional as F # noqa: N812
def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> torch.Tensor:
"""Crop the given numpy array.
Args:
img (torch.Tensor): Image to be cropped. Format should be (C, H, W).
left (int): The left coordinate of the crop box.
top (int): The top coordinate of the crop box.
right (int): The right coordinate of the crop box.
bottom (int): The bottom coordinate of the crop box.
Returns:
torch.Tensor: Cropped image.
"""
if not isinstance(img, torch.Tensor):
raise TypeError(f"img should be torch.Tensor. Got {type(img)}")
if img.ndim not in [2, 3]:
raise ValueError(f"Image should have 2 or 3 dimensions. Got {img.ndim}")
img_height = img.shape[1]
img_width = img.shape[2]
if top < 0 or left < 0 or bottom > img_height or right > img_width:
raise ValueError("Crop coordinates out of bounds")
if top >= bottom or left >= right:
raise ValueError("Invalid crop coordinates")
return img[:, top:bottom, left:right]
class Eagle25VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
max_dynamic_tiles: int | None
min_dynamic_tiles: int | None
use_thumbnail: bool | None
pad_during_tiling: bool | None
do_pad: bool | None
@add_start_docstrings(
"Constructs a fast ConvNeXT image processor. Based on [`SiglipImageProcessor`] with incorporation of processing each video frame.",
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, TODO: this was depreciated from transformers remove!
"""
image_grid_pinpoints (`List[List[int]]`, *optional*):
A list of possible resolutions to use for processing high resolution images. The best resolution is selected
based on the original size of the image. Can be overridden by `image_grid_pinpoints` in the `preprocess`
method. Not used for processing videos.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
""",
)
class Eagle25VLImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BICUBIC
image_mean = IMAGENET_STANDARD_MEAN
image_std = IMAGENET_STANDARD_STD
size = {"height": 448, "width": 448}
default_to_square = False
crop_size = None
do_resize = True
do_center_crop = None
do_rescale = True
do_normalize = True
do_convert_rgb = True
do_pad = True
max_dynamic_tiles = 12
min_dynamic_tiles = 1
use_thumbnail = True
pad_during_tiling = False
valid_kwargs = Eagle25VLFastImageProcessorKwargs
model_input_names = ["pixel_values_videos"]
def __init__(self, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]):
super().__init__(**kwargs)
@add_start_docstrings(
# BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, TODO: this was depreciated from transformers remove!
"""
max_dynamic_tiles (`int`, *optional*):
The maximum number of dynamic tiles to use for processing high resolution images.
min_dynamic_tiles (`int`, *optional*):
The minimum number of dynamic tiles to use for processing high resolution images.
use_thumbnail (`bool`, *optional*):
Whether to use a thumbnail for processing high resolution images.
pad_during_tiling (`bool`, *optional*):
Whether to pad the image during tiling.
do_pad (`bool`, *optional*):
Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
number of patches in the batch. Padding will be applied to the bottom and right with zeros.
""",
)
# NOTE(YL): we will overload the preprocess method to add the image_flags
# def preprocess(
# self, images: ImageInput, **kwargs: Unpack[Eagle25VLFastImageProcessorKwargs]
# ) -> BatchFeature:
# return super().preprocess(images, **kwargs)
def _prepare_images_structure(
self,
images: ImageInput,
expected_ndims: int = 3,
) -> ImageInput:
"""
Prepare the images structure for processing.
Args:
images (`ImageInput`):
The input images to process.
expected_ndims (`int`, *optional*, defaults to 3):
Expected number of dimensions for the images (added for transformers >=4.53.0 compatibility).
Returns:
`ImageInput`: The images with a valid nesting.
"""
return make_flat_list_of_images(images)
def _resize_for_patching(
self,
image: "torch.Tensor",
target_resolution: tuple,
interpolation: "F.InterpolationMode",
input_data_format: ChannelDimension,
) -> "torch.Tensor":
"""
Resizes an image to a target resolution while maintaining aspect ratio.
Args:
image ("torch.Tensor"):
The input image.
target_resolution (tuple):
The target resolution (height, width) of the image.
interpolation (`InterpolationMode`):
Resampling filter to use if resizing the image.
input_data_format (`ChannelDimension` or `str`):
The channel dimension format of the input image.
Returns:
"torch.Tensor": The resized and padded image.
"""
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
# Resize the image
resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation)
return resized_image
def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
"""
previous version mainly focus on ratio.
We also consider area ratio here.
"""
best_factor = float("-inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
# ratio_diff = abs(aspect_ratio - target_aspect_ratio)
# area_ratio = (ratio[0] * ratio[1] * image_size * image_size) / area
"""
new area > 60% of original image area is enough.
"""
factor_based_on_area_n_ratio = min(
(ratio[0] * ratio[1] * image_size * image_size) / area, 0.6
) * min(target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio)
if factor_based_on_area_n_ratio > best_factor:
best_factor = factor_based_on_area_n_ratio
best_ratio = ratio
return best_ratio
def _pad_for_patching(
self, image: "torch.Tensor", target_resolution: tuple, input_data_format: ChannelDimension
) -> "torch.Tensor":
"""
Pad an image to a target resolution while maintaining aspect ratio.
"""
target_height, target_width = target_resolution
new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
padded_image = F.pad(image, padding=[paste_x, paste_y, paste_x, paste_y])
return padded_image
def _get_image_patches(
self,
image: "torch.Tensor",
min_num: int,
max_num: int,
size: tuple,
tile_size: int,
use_thumbnail: bool,
interpolation: "F.InterpolationMode",
pad_during_tiling: bool,
) -> list["torch.Tensor"]:
image_size = get_image_size(image, channel_dim=ChannelDimension.FIRST)
orig_height, orig_width = image_size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = {
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
}
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = self.find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
)
# calculate the target width and height
target_width = tile_size * target_aspect_ratio[0]
target_height = tile_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
if pad_during_tiling:
resized_image = self._resize_for_patching(
image,
(target_height, target_width),
interpolation=interpolation,
input_data_format=ChannelDimension.FIRST,
)
padded_image = self._pad_for_patching(
resized_image,
(target_height, target_width),
input_data_format=ChannelDimension.FIRST,
)
image_used_to_split = padded_image
else:
image_used_to_split = F.resize(image, (target_height, target_width), interpolation=interpolation)
processed_tiles = []
for i in range(blocks):
box = (
(i % (target_width // tile_size)) * tile_size,
(i // (target_width // tile_size)) * tile_size,
((i % (target_width // tile_size)) + 1) * tile_size,
((i // (target_width // tile_size)) + 1) * tile_size,
)
# split the image
split_img = crop(image_used_to_split, box[0], box[1], box[2], box[3])
processed_tiles.append(split_img)
assert len(processed_tiles) == blocks
if use_thumbnail and len(processed_tiles) != 1:
thumbnail_img = F.resize(image, (tile_size, tile_size), interpolation=interpolation)
processed_tiles.append(thumbnail_img)
return processed_tiles
def _pad_for_batching(
self,
pixel_values: list["torch.Tensor"],
) -> list["torch.Tensor"]:
"""
Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.
Args:
pixel_values (`List[torch.Tensor]`):
An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`)
Returns:
List[`torch.Tensor`]: The padded images.
"""
max_patch = max(len(x) for x in pixel_values)
pixel_values = [
torch.nn.functional.pad(image, pad=[0, 0, 0, 0, 0, 0, 0, max_patch - image.shape[0]])
for image in pixel_values
]
return pixel_values
def _preprocess(
self,
images: list["torch.Tensor"],
do_resize: bool,
size: SizeDict,
max_dynamic_tiles: int,
min_dynamic_tiles: int,
use_thumbnail: bool,
pad_during_tiling: bool,
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: float | list[float] | None,
image_std: float | list[float] | None,
do_pad: bool,
return_tensors: str | TensorType | None,
pad_size: SizeDict | None = None, # Added for transformers >=4.53.0 compatibility
disable_grouping: bool | None = None, # Added for transformers >=4.53.0 compatibility
) -> BatchFeature:
processed_images = []
image_sizes = []
# Determine the size tuple
if size and size.height and size.width:
size_tuple = (size.height, size.width)
else:
size_tuple = (size.shortest_edge, size.shortest_edge)
# Determine the patch size
if crop_size and crop_size.height:
tile_size = crop_size.height
elif size and size.height:
tile_size = size.height
else:
tile_size = size.shortest_edge
for image in images:
image_patches = self._get_image_patches(
image,
min_num=min_dynamic_tiles,
max_num=max_dynamic_tiles,
size=size_tuple,
tile_size=tile_size,
use_thumbnail=use_thumbnail,
interpolation=interpolation,
pad_during_tiling=pad_during_tiling,
)
# Group images by size for batched processing
processed_image_patches_grouped = {}
# Added for transformers >=4.53.0 compatibility
grouped_image_patches, grouped_image_patches_index = group_images_by_shape(
image_patches,
disable_grouping=disable_grouping,
)
for shape, stacked_image_patches in grouped_image_patches.items():
if do_resize:
stacked_image_patches = self.resize(
image=stacked_image_patches,
size=size,
interpolation=interpolation,
)
if do_center_crop:
stacked_image_patches = self.center_crop(stacked_image_patches, crop_size)
# Fused rescale and normalize
stacked_image_patches = self.rescale_and_normalize(
stacked_image_patches,
do_rescale,
rescale_factor,
do_normalize,
image_mean,
image_std,
)
processed_image_patches_grouped[shape] = stacked_image_patches
processed_image_patches = reorder_images(
processed_image_patches_grouped, grouped_image_patches_index
)
processed_image_patches = (
torch.stack(processed_image_patches, dim=0) if return_tensors else processed_image_patches
)
processed_images.append(processed_image_patches)
image_sizes.append(get_image_size(image, ChannelDimension.FIRST))
if do_pad:
processed_images = self._pad_for_batching(processed_images)
# processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(
data={"pixel_values": processed_images, "image_sizes": image_sizes},
tensor_type=return_tensors,
)
def preprocess(
self,
images: ImageInput,
videos: VideoInput = None,
**kwargs: Unpack[Eagle25VLFastImageProcessorKwargs],
) -> BatchFeature:
validate_kwargs(
captured_kwargs=kwargs.keys(),
valid_processor_keys=self.valid_kwargs.__annotations__.keys(),
)
# Set default kwargs from self. This ensures that if a kwarg is not provided
# by the user, it gets its default value from the instance, or is set to None.
for kwarg_name in self.valid_kwargs.__annotations__:
kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
# Extract parameters that are only used for preparing the input images
do_convert_rgb = kwargs.pop("do_convert_rgb")
input_data_format = kwargs.pop("input_data_format")
device = kwargs.pop("device")
# Prepare input images
# transformers >= 4.53.0: uses _prepare_image_like_inputs instead of _prepare_input_images
if images is not None:
images = self._prepare_image_like_inputs(
images=images,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
)
if videos is not None:
videos = self._prepare_image_like_inputs(
images=videos,
do_convert_rgb=do_convert_rgb,
input_data_format=input_data_format,
device=device,
)
# Update kwargs that need further processing before being validated
kwargs = self._further_process_kwargs(**kwargs)
# Validate kwargs
self._validate_preprocess_kwargs(**kwargs)
# torch resize uses interpolation instead of resample
# Added for transformers >=4.53.0 compatibility
resample = kwargs.pop("resample", self.resample)
kwargs["interpolation"] = (
pil_torch_interpolation_mapping[resample]
if isinstance(resample, PILImageResampling | int)
else resample
)
# Filter kwargs to only include those accepted by _preprocess
valid_preprocess_kwargs = {
"do_resize",
"size",
"max_dynamic_tiles",
"min_dynamic_tiles",
"use_thumbnail",
"pad_during_tiling",
"interpolation",
"do_center_crop",
"crop_size",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"do_pad",
"return_tensors",
"pad_size",
"disable_grouping",
}
filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_preprocess_kwargs}
if images is not None:
return self._preprocess(images, **filtered_kwargs)
elif videos is not None:
return self._preprocess(videos, **filtered_kwargs)
__all__ = ["Eagle25VLImageProcessorFast"]

View File

@@ -1,395 +0,0 @@
# --------------------------------------------------------
# NVIDIA
# Copyright (c) 2025 NVIDIA
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import inspect
import torch
import torch.utils.checkpoint as cp
from peft import LoraConfig, get_peft_model
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from transformers.models.qwen3.modeling_qwen3 import Qwen3ForCausalLM
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
from transformers.utils import add_start_docstrings, logging
from .configuration_eagle2_5_vl import Eagle25VLConfig
logger = logging.get_logger(__name__)
# copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/modeling_llava_onevision.py#L241C1-L280C1
EAGLE2_5_VL_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`Eagle25VLConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare Eagle2_5_VL Model outputting raw hidden-states without any specific head on top.",
EAGLE2_5_VL_START_DOCSTRING,
)
class Eagle25VLPreTrainedModel(PreTrainedModel):
config_class = Eagle25VLConfig
base_model_prefix = "model"
main_input_name = "input_ids"
supports_gradient_checkpointing = True
_no_split_modules = [
"Qwen2DecoderLayer",
"LlamaDecoderLayer",
"Siglip2EncoderLayer",
"SiglipEncoderLayer",
]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = True
_supports_sdpa = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear | nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class Eagle25VLForConditionalGeneration(Eagle25VLPreTrainedModel, GenerationMixin):
config_class = Eagle25VLConfig
def __init__(self, config: Eagle25VLConfig, vision_model=None, language_model=None):
super().__init__(config)
image_size = config.force_image_size or config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.patch_size = patch_size
if config.use_pixel_shuffle:
self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio**2))
else:
self.num_image_token = int((image_size // patch_size) ** 2)
self.select_layer = config.select_layer
self.downsample_ratio = config.downsample_ratio
self.loss_version = config.loss_version
self.mlp_checkpoint = config.mlp_checkpoint
self.use_pixel_shuffle = config.use_pixel_shuffle
self.mlp_connector_layers = config.mlp_connector_layers
logger.info(f"num_image_token: {self.num_image_token}")
logger.info(f"mlp_checkpoint: {self.mlp_checkpoint}")
if vision_model is not None:
self.vision_model = vision_model
else:
if config.vision_config.model_type == "siglip_vision_model":
config.vision_config._attn_implementation = "flash_attention_2"
self.vision_model = SiglipVisionModel(config.vision_config)
else:
raise NotImplementedError(f"{config.vision_config.model_type} is not implemented.")
if language_model is not None:
self.language_model = language_model
else:
if config.text_config.architectures[0] == "LlamaForCausalLM":
self.language_model = LlamaForCausalLM(config.text_config)
elif config.text_config.architectures[0] == "Phi3ForCausalLM":
raise NotImplementedError("Phi3 is not implemented.")
# self.language_model = Phi3ForCausalLM(config.text_config)
elif config.text_config.architectures[0] == "Qwen2ForCausalLM":
assert config.text_config._attn_implementation == "flash_attention_2", (
f"Qwen2 must use flash_attention_2 but got {config.text_config._attn_implementation}"
)
self.language_model = Qwen2ForCausalLM(config.text_config)
elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
self.language_model = Qwen3ForCausalLM(config.text_config)
else:
raise NotImplementedError(f"{config.text_config.architectures[0]} is not implemented.")
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
if config.mlp_connector_layers == 2:
self.mlp1 = nn.Sequential(
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
nn.GELU(),
nn.Linear(llm_hidden_size, llm_hidden_size),
)
elif config.mlp_connector_layers == 1 and config.use_pixel_shuffle:
self.mlp1 = nn.Sequential(
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
)
elif config.mlp_connector_layers == 1 and not config.use_pixel_shuffle:
self.mlp1 = nn.Sequential(
nn.Linear(vit_hidden_size, llm_hidden_size),
)
else:
raise NotImplementedError(f"{config.mlp_connector_layers} is not implemented.")
self.image_token_index = config.image_token_index
self.neftune_alpha = None
if config.use_backbone_lora:
self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
self.use_llm_lora = config.use_llm_lora
if config.use_llm_lora:
self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
self.check_forward_kwargs()
def check_forward_kwargs(self):
# We intentionally avoid using **kwargs in forward because Hugging Face Transformers
# has special handling for functions with **kwargs parameters that would affect
# how our model is processed during training and inference.
forward_params = inspect.signature(self.forward).parameters
assert not any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
lora_config = LoraConfig(
r=r,
target_modules=[
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.out_proj",
"mlp.fc1",
"mlp.fc2",
],
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
self.vision_model = get_peft_model(self.vision_model, lora_config)
self.vision_model.print_trainable_parameters()
def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
lora_config = LoraConfig(
r=r,
target_modules=[
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.o_proj",
"mlp.gate_proj",
"mlp.down_proj",
"mlp.up_proj",
],
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
task_type="CAUSAL_LM",
)
self.language_model = get_peft_model(self.language_model, lora_config)
self.language_model.enable_input_require_grads()
self.language_model.print_trainable_parameters()
self.use_llm_lora = True
def forward(
self,
pixel_values: torch.FloatTensor,
input_ids: torch.LongTensor = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
image_flags: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
return_dict: bool | None = None,
num_tiles_list: list[torch.Tensor] | None = None,
) -> tuple | CausalLMOutputWithPast:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
input_embeds = self.language_model.get_input_embeddings()(input_ids)
vit_embeds = self.extract_feature(pixel_values)
if image_flags is not None:
image_flags = image_flags.view(-1)
vit_embeds = vit_embeds[image_flags == 1]
b, n, c = input_embeds.shape
input_embeds = input_embeds.reshape(b * n, c)
input_ids = input_ids.reshape(b * n)
selected = input_ids == self.image_token_index
try:
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, c)
except Exception as e:
vit_embeds = vit_embeds.reshape(-1, c)
print(
f"warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, "
f"vit_embeds.shape={vit_embeds.shape}"
)
n_token = selected.sum()
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
input_embeds = input_embeds.reshape(b, n, c)
outputs = self.language_model(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
logits = outputs.logits
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor * scale_factor)))
x = x.permute(0, 2, 1, 3).contiguous()
return x
def extract_feature(self, pixel_values):
if self.select_layer == -1:
vit_embeds = self.vision_model(
pixel_values=pixel_values, output_hidden_states=False, return_dict=True
)
if hasattr(vit_embeds, "last_hidden_state"):
vit_embeds = vit_embeds.last_hidden_state
else:
vit_embeds = self.vision_model(
pixel_values=pixel_values, output_hidden_states=True, return_dict=True
).hidden_states[self.select_layer]
if self.use_pixel_shuffle:
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.pixel_shuffle(
vit_embeds, scale_factor=self.downsample_ratio
) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
vit_embeds = vit_embeds.reshape(
vit_embeds.shape[0], -1, vit_embeds.shape[-1]
) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
if self.mlp_checkpoint and vit_embeds.requires_grad:
vit_embeds = cp.checkpoint(self.mlp1, vit_embeds)
else:
vit_embeds = self.mlp1(vit_embeds)
return vit_embeds
@torch.no_grad()
def generate(
self,
pixel_values: torch.FloatTensor | None = None,
input_ids: torch.FloatTensor | None = None,
attention_mask: torch.LongTensor | None = None,
visual_features: torch.FloatTensor | None = None,
generation_config: GenerationConfig | None = None,
output_hidden_states: bool | None = None,
image_sizes: list[tuple[int, int]] | None = None,
**generate_kwargs,
) -> torch.LongTensor:
if pixel_values is not None:
if visual_features is not None:
vit_embeds = visual_features
else:
vit_embeds = self.extract_feature(pixel_values)
input_embeds = self.language_model.get_input_embeddings()(input_ids)
b, n, c = input_embeds.shape
input_embeds = input_embeds.reshape(b * n, c)
input_ids = input_ids.reshape(b * n)
selected = input_ids == self.config.image_token_index
assert selected.sum() != 0
input_embeds[selected] = vit_embeds.reshape(-1, c).to(input_embeds.device)
input_embeds = input_embeds.reshape(b, n, c)
else:
input_embeds = self.language_model.get_input_embeddings()(input_ids)
if "use_cache" not in generate_kwargs:
generate_kwargs["use_cache"] = True
outputs = self.language_model.generate(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
generation_config=generation_config,
output_hidden_states=output_hidden_states,
**generate_kwargs,
)
return outputs
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
def get_output_embeddings(self):
return self.language_model.get_output_embeddings()
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
# Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
def get_decoder(self):
return self.language_model.get_decoder()

View File

@@ -1,518 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Eagle25VL.
copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/processing_llava_onevision.py
"""
import base64
import os
import re
from io import BytesIO
import requests
import torch
from PIL import Image
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging
from transformers.video_utils import VideoInput
logger = logging.get_logger(__name__)
FRAME_FACTOR = 2
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 256
def to_rgb(pil_image: Image.Image) -> Image.Image:
if pil_image.mode == "RGBA":
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
return white_background
else:
return pil_image.convert("RGB")
def fetch_image(ele: dict[str, str | Image.Image]) -> Image.Image:
image = ele["image"] if "image" in ele else ele["image_url"]
image_obj = None
if isinstance(image, Image.Image):
image_obj = image
elif image.startswith("http://") or image.startswith("https://"):
response = requests.get(image, stream=True, timeout=10)
image_obj = Image.open(BytesIO(response.content))
elif image.startswith("file://"):
image_obj = Image.open(image[7:])
elif image.startswith("data:image"):
if "base64," in image:
_, base64_data = image.split("base64,", 1)
data = base64.b64decode(base64_data)
image_obj = Image.open(BytesIO(data))
else:
image_obj = Image.open(image)
if image_obj is None:
raise ValueError(
f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
)
image = to_rgb(image_obj)
if "scale_factor" in ele:
scale_factor = ele["scale_factor"]
image = image.resize((image.width * scale_factor, image.height * scale_factor), Image.BILINEAR)
return image
class Eagle25VLProcessorKwargs(ProcessingKwargs, total=False):
# see processing_utils.ProcessingKwargs documentation for usage.
_defaults = {
"text_kwargs": {
"padding": False,
},
"images_kwargs": {},
"videos_kwargs": {"max_dynamic_tiles": 1},
}
class Eagle25VLProcessor(ProcessorMixin):
r"""
Constructs a Eagle25VL processor which wraps a Eagle25VL video processor, Eagle25VL image processor and a Eagle25VL tokenizer into a single processor.
[`Eagle25VLProcessor`] offers all the functionalities of [`Eagle25VLVideoProcessor`], [`Eagle25VLImageProcessor`] and [`Eagle25VLTokenizer`]. See the
[`~Eagle25VLVideoProcessor.__call__`], [`~Eagle25VLProcessor.__call__`] and [`~Eagle25VLProcessor.decode`] for more information.
Args:
image_processor ([`LlavaOnevisionImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
num_image_tokens (`int`, *optional*):
Number of image tokens for one imagethat will be returned by vision tower.
vision_feature_select_strategy (`str`, *optional*):
The feature selection strategy used to select the vision feature from the vision backbone.
Should be same as in model's config
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
image_token (`str`, *optional*, defaults to `"<image>"`):
Special token used to denote image location.
video_token (`str`, *optional*, defaults to `"<video>"`):
Special token used to denote video location.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = [
"chat_template",
"num_image_tokens",
"vision_feature_select_strategy",
"image_token",
"video_token",
"images_kwargs",
"videos_kwargs",
"text_kwargs",
]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor=None,
tokenizer=None,
vision_feature_select_strategy=None,
chat_template=None,
image_token="<IMG_CONTEXT>", # nosec: B107
video_token="<IMG_CONTEXT>", # nosec: B107
tokens_per_tile=256,
image_placeholder="image",
video_placeholder="video",
image_start_token="<img>",
image_end_token="</img>",
**kwargs,
):
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
self.image_token_id = (
tokenizer.image_token_id
if getattr(tokenizer, "image_token_id", None)
else tokenizer.convert_tokens_to_ids(self.image_token)
)
self.video_token_id = (
tokenizer.video_token_id
if getattr(tokenizer, "video_token_id", None)
else tokenizer.convert_tokens_to_ids(self.video_token)
)
self.image_placeholder = image_placeholder
self.video_placeholder = video_placeholder
self.tokens_per_tile = tokens_per_tile
self.image_start_token = image_start_token
self.image_end_token = image_end_token
if "auto_map" in kwargs:
self.auto_map = kwargs["auto_map"]
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def replace_media_placeholder(
self, text, image_list, video_list, timestamps_list, fps_list, **output_kwargs
):
num_of_images_in_this_sample = 0
num_of_videos_in_this_sample = 0
# Regular expression pattern to match formats like <image-1> or <video-2>
pattern = re.compile(rf"<({self.image_placeholder}|{self.video_placeholder})-(\d+)>")
unified_frame_list = []
# image_min_dynamic_tiles = output_kwargs["images_kwargs"].get(
# "min_dynamic_tiles", self.image_processor.min_dynamic_tiles
# )
# image_max_dynamic_tiles = output_kwargs["images_kwargs"].get(
# "max_dynamic_tiles", self.image_processor.max_dynamic_tiles
# )
# image_use_thumbnail = output_kwargs["images_kwargs"].get(
# "use_thumbnail", self.image_processor.use_thumbnail
# )
video_min_dynamic_tiles = output_kwargs["videos_kwargs"].get(
"min_dynamic_tiles", self.image_processor.min_dynamic_tiles
)
video_max_dynamic_tiles = output_kwargs["videos_kwargs"].get(
"max_dynamic_tiles", self.image_processor.max_dynamic_tiles
)
video_use_thumbnail = output_kwargs["videos_kwargs"].get(
"use_thumbnail", self.image_processor.use_thumbnail
)
tile_size = self.image_processor.size.get("height", 448)
# Function to replace tags in a single text
def replace_in_text(text):
# repl callback function for each match replacement operation
def repl(match):
nonlocal unified_frame_list
nonlocal num_of_images_in_this_sample
nonlocal num_of_videos_in_this_sample
media_type = match.group(1) # 'image' or 'video'
idx_in_list = int(match.group(2)) - 1 # Convert to list index (0-based)
# Select the corresponding path based on media type
idx_mapper = {
0: "first",
1: "second",
2: "third",
3: "fourth",
4: "fifth",
5: "sixth",
6: "seventh",
7: "eighth",
8: "ninth",
9: "tenth",
}
if media_type == "image":
image_inputs = self.image_processor(
images=[image_list[idx_in_list]],
videos=None,
**output_kwargs["images_kwargs"],
)
num_all_tiles = image_inputs["pixel_values"].shape[0]
special_placeholder = f"<image {idx_in_list + 1}>{self.image_start_token}{self.image_token * num_all_tiles * self.tokens_per_tile}{self.image_end_token}"
unified_frame_list.append(image_inputs)
num_of_images_in_this_sample += 1
elif media_type == "video":
video_inputs = self.image_processor(
images=None,
videos=[video_list[idx_in_list]],
**output_kwargs["videos_kwargs"],
)
num_all_tiles = video_inputs["pixel_values"].shape[0]
image_sizes = video_inputs["image_sizes"]
if timestamps_list is not None and -1 not in timestamps_list:
frame_timestamps = timestamps_list[idx_in_list]
else:
frame_timestamps = None
sampled_fps = fps_list[idx_in_list] if fps_list is not None else None
num_of_tiles_each_frame = [
self.get_number_tiles_based_on_image_size(
image_size,
video_min_dynamic_tiles,
video_max_dynamic_tiles,
video_use_thumbnail,
tile_size,
)
for image_size in image_sizes
]
assert sum(num_of_tiles_each_frame) == num_all_tiles, (
f"The number of tiles in each frame is not equal to the total number of tiles: {sum(num_of_tiles_each_frame)} != {num_all_tiles}"
)
if frame_timestamps is not None:
assert len(frame_timestamps) == len(num_of_tiles_each_frame), (
f"The number of timestamps is not equal to the number of frames: {len(frame_timestamps)} != {len(num_of_tiles_each_frame)}"
)
special_placeholder = [
f"Frame {i + 1} sample at {frame_timestamps[i]:.2f}s: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
]
else:
special_placeholder = [
f"Frame {i + 1}: {self.image_start_token}{self.image_token * num_of_tiles * self.tokens_per_tile}{self.image_end_token}"
for i, num_of_tiles in enumerate(num_of_tiles_each_frame)
]
if sampled_fps is not None:
special_placeholder = (
f"The {idx_mapper[idx_in_list]} video sampled with {sampled_fps:.2f} fps: "
+ "".join(special_placeholder)
)
else:
special_placeholder = f"The {idx_mapper[idx_in_list]} video: " + "".join(
special_placeholder
)
unified_frame_list.append(video_inputs)
num_of_videos_in_this_sample += 1
else:
raise ValueError(f"Unknown media type: {media_type}")
return special_placeholder
return pattern.sub(repl, text)
text = replace_in_text(text)
if len(unified_frame_list) > 0:
pixel_values = torch.cat([frame["pixel_values"] for frame in unified_frame_list])
image_sizes = torch.cat([frame["image_sizes"] for frame in unified_frame_list])
else:
pixel_values = None
image_sizes = None
return (
text,
pixel_values,
image_sizes,
num_of_images_in_this_sample,
num_of_videos_in_this_sample,
)
def __call__(
self,
images: ImageInput = None,
text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
audio=None,
videos: VideoInput = None,
**kwargs: Unpack[Eagle25VLProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
LlavaNextImageProcessor's [`~LlavaNextImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **pixel_values_videos** -- Pixel values of a video input to be fed to a model. Returned when `videos` is not `None`.
- **image_sizes** -- Size of each image that will be used to unpad an image. Returned when `images` is not `None`.
"""
output_kwargs = self._merge_kwargs(
Eagle25VLProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if isinstance(text, str):
text_list = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
elif isinstance(text, list) and isinstance(text[0], str):
text_list = text
if images is None:
images = []
if videos is None:
videos = []
pixel_values_list = []
image_sizes_list = []
new_sample_list = []
image_start_idx = 0
video_start_idx = 0
timestamps_batch = output_kwargs["videos_kwargs"].pop("timestamps", None)
fps_batch = output_kwargs["videos_kwargs"].pop("fps", None)
for sample in text_list:
timestamps_list = timestamps_batch[video_start_idx:] if timestamps_batch is not None else None
fps_list = fps_batch[video_start_idx:] if fps_batch is not None else None
(
sample,
pixel_values,
image_sizes,
num_of_images_in_this_sample,
num_of_videos_in_this_sample,
) = self.replace_media_placeholder(
sample,
images[image_start_idx:],
videos[video_start_idx:],
timestamps_list,
fps_list,
**output_kwargs,
)
new_sample_list.append(sample)
if pixel_values is not None:
pixel_values_list.append(pixel_values)
image_sizes_list.append(image_sizes)
image_start_idx += num_of_images_in_this_sample
video_start_idx += num_of_videos_in_this_sample
if len(pixel_values_list) > 0:
image_inputs = {
"pixel_values": torch.cat(pixel_values_list),
"image_sizes": torch.cat(image_sizes_list),
}
else:
image_inputs = {}
video_inputs = {}
text_inputs = self.tokenizer(new_sample_list, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs, **image_inputs, **video_inputs})
def get_number_tiles_based_on_image_size(
self, image_size: tuple, min_num: int, max_num: int, use_thumbnail: bool, tile_size: int
) -> int:
"""
Get the number of tiles based on the image size.
"""
orig_height, orig_width = image_size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = {
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
}
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = self.image_processor.find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, tile_size
)
tiles_num = target_aspect_ratio[0] * target_aspect_ratio[1]
if use_thumbnail and tiles_num > 1:
tiles_num += 1
return tiles_num
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
# override to save video-config in a separate config file
def save_pretrained(self, save_directory, **kwargs):
if os.path.isfile(save_directory):
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
outputs = super().save_pretrained(save_directory, **kwargs)
return outputs
# override to load video-config from a separate config file
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
processor = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
# if return_unused_kwargs a tuple is returned where the second element is 'unused_kwargs'
if isinstance(processor, tuple):
processor = processor[0]
return processor
# Copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
def process_vision_info(
self,
conversations: list[dict] | list[list[dict]],
return_video_kwargs: bool = False,
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, dict | None]:
vision_infos = self.extract_vision_info(conversations)
## Read images or videos
image_inputs = []
video_inputs = []
video_sample_fps_list = []
video_timestamps_list = []
for vision_info in vision_infos:
if "image" in vision_info or "image_url" in vision_info:
image_inputs.append(fetch_image(vision_info))
else:
raise ValueError("image, image_url or video should in content.")
if len(image_inputs) == 0:
image_inputs = None
if len(video_inputs) == 0:
video_inputs = None
if return_video_kwargs:
return (
image_inputs,
video_inputs,
{"fps": video_sample_fps_list, "timestamps": video_timestamps_list},
)
return image_inputs, video_inputs
def extract_vision_info(self, conversations: list[dict] | list[list[dict]]) -> list[dict]:
vision_infos = []
if isinstance(conversations[0], dict):
conversations = [conversations]
for conversation in conversations:
for message in conversation:
if isinstance(message["content"], list):
for ele in message["content"]:
if (
"image" in ele
or "image_url" in ele
or "video" in ele
or ele["type"] in ("image", "image_url", "video")
):
vision_infos.append(ele)
return vision_infos
__all__ = ["Eagle25VLProcessor"]

View File

@@ -1,376 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
from lerobot.utils.import_utils import _transformers_available
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.feature_extraction_utils import BatchFeature
else:
AutoConfig = None
AutoModel = None
PretrainedConfig = object
PreTrainedModel = object
BatchFeature = None
try:
import tree
except ImportError:
tree = None
from lerobot.policies.groot.action_head.flow_matching_action_head import (
FlowmatchingActionHead,
FlowmatchingActionHeadConfig,
)
from lerobot.policies.groot.utils import ensure_eagle_cache_ready
from lerobot.utils.constants import HF_LEROBOT_HOME
DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve())
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
class EagleBackbone(nn.Module):
def __init__(
self,
tune_llm: bool = False,
tune_visual: bool = False,
select_layer: int = -1,
reproject_vision: bool = False,
use_flash_attention: bool = False,
load_bf16: bool = False,
eagle_path: str = DEFAULT_VENDOR_EAGLE_PATH,
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO,
project_to_dim: int = 1536,
):
"""
Args:
tune_llm: whether to tune the LLM model (default: True)
tune_visual: whether to tune the visual model (default: False)
"""
super().__init__()
assert not reproject_vision, "Reproject vision is not implemented here, set to False"
# Prefer loading Eagle model config from the cache directory where vendor files were copied.
vendor_dir = DEFAULT_VENDOR_EAGLE_PATH
cache_dir = HF_LEROBOT_HOME / tokenizer_assets_repo
try:
ensure_eagle_cache_ready(vendor_dir, cache_dir, tokenizer_assets_repo)
except Exception as exc: # nosec: B110
print(f"[GROOT] Warning: failed to prepare Eagle cache for backbone: {exc}")
config = AutoConfig.from_pretrained(str(cache_dir), trust_remote_code=True)
self.eagle_model = AutoModel.from_config(config, trust_remote_code=True)
if project_to_dim is not None:
self.eagle_linear = torch.nn.Linear(2048, project_to_dim)
else:
self.eagle_linear = torch.nn.Identity()
# needed since we don't use these layers. Also saves compute
while len(self.eagle_model.language_model.model.layers) > select_layer:
self.eagle_model.language_model.model.layers.pop(-1)
self.select_layer = select_layer
self.set_trainable_parameters(tune_llm, tune_visual)
def set_trainable_parameters(self, tune_llm: bool, tune_visual: bool):
self.tune_llm = tune_llm
self.tune_visual = tune_visual
for p in self.parameters():
p.requires_grad = True
if not tune_llm:
self.eagle_model.language_model.requires_grad_(False)
if not tune_visual:
self.eagle_model.vision_model.requires_grad_(False)
self.eagle_model.mlp1.requires_grad_(False)
print(f"Tune backbone llm: {self.tune_llm}")
print(f"Tune backbone visual: {self.tune_visual}")
# Check if any parameters are still trainable. If not, print a warning.
if not tune_llm and not tune_visual:
for name, p in self.named_parameters():
if p.requires_grad:
print(f"Backbone trainable parameter: {name}")
if not any(p.requires_grad for p in self.parameters()):
print("Warning: No backbone trainable parameters found.")
def set_frozen_modules_to_eval_mode(self):
"""
Huggingface will call model.train() at each training_step. To ensure
the expected behaviors for modules like dropout, batchnorm, etc., we
need to call model.eval() for the frozen modules.
"""
if self.training:
if self.eagle_model.language_model and not self.tune_llm:
self.eagle_model.language_model.eval()
if self.eagle_model.vision_model and not self.tune_visual:
self.eagle_model.vision_model.eval()
def prepare_input(self, batch: dict) -> BatchFeature:
return BatchFeature(data=batch)
def forward_eagle(self, vl_input: BatchFeature) -> BatchFeature:
eagle_prefix = "eagle_"
eagle_input = {
k.removeprefix(eagle_prefix): v for k, v in vl_input.items() if k.startswith(eagle_prefix)
}
del eagle_input["image_sizes"]
eagle_output = self.eagle_model(**eagle_input, output_hidden_states=True, return_dict=True)
eagle_features = eagle_output.hidden_states[self.select_layer]
eagle_features = self.eagle_linear(eagle_features)
return eagle_features, eagle_input["attention_mask"]
def forward(self, vl_input: BatchFeature) -> BatchFeature:
self.set_frozen_modules_to_eval_mode()
eagle_embeds, eagle_mask = self.forward_eagle(vl_input)
# YL (TODO HACK): to resolve DDP issue when tune_visual=True
# Ensure all trainable parameters in vision_model are used in the forward pass for DDP compatibility
if self.training and self.tune_visual:
dummy_term = torch.tensor(
0.0, device=eagle_embeds.device, dtype=eagle_embeds.dtype, requires_grad=True
)
for param in self.eagle_model.vision_model.parameters():
if param.requires_grad:
dummy_term = dummy_term + 0.0 * param.sum()
eagle_embeds = eagle_embeds + dummy_term
return BatchFeature(
data={"backbone_features": eagle_embeds, "backbone_attention_mask": eagle_mask}
) # [B, T2, hidden_size]
BACKBONE_FEATURE_KEY = "backbone_features"
ACTION_KEY = "action_pred"
LOSS_KEY = "loss"
ERROR_MSG = "Error: unexpected input/output"
N_COLOR_CHANNELS = 3
# config
@dataclass
class GR00TN15Config(PretrainedConfig):
model_type = "gr00t_n1_5"
backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."})
action_head_cfg: dict = field(init=False, metadata={"help": "Action head configuration."})
action_horizon: int = field(init=False, metadata={"help": "Action horizon."})
action_dim: int = field(init=False, metadata={"help": "Action dimension."})
compute_dtype: str = field(default="float32", metadata={"help": "Compute dtype."})
def __init__(self, **kwargs):
super().__init__(**kwargs)
for key, value in kwargs.items():
setattr(self, key, value)
# real model
class GR00TN15(PreTrainedModel):
supports_gradient_checkpointing = True
config_class = GR00TN15Config
"""
we expect the backbone output to have a key 'backbone_features' with shape (batch_size, n, hidden_size)
here n is variable and can be e.g. time, 1 or user specified
we expect the action head output to have a key 'action_pred' with shape (batch_size, time, action_dim) during inference time
we expect these to have type BatchFeature, and they can of course have many other user specified keys too
"""
def __init__(
self,
config: GR00TN15Config,
local_model_path: str,
):
assert isinstance(config.backbone_cfg, dict)
assert isinstance(config.action_head_cfg, dict)
super().__init__(config)
self.local_model_path = local_model_path
self.backbone = EagleBackbone(**config.backbone_cfg)
action_head_cfg = FlowmatchingActionHeadConfig(**config.action_head_cfg)
self.action_head = FlowmatchingActionHead(action_head_cfg)
self.action_horizon = config.action_horizon
self.action_dim = config.action_dim
self.compute_dtype = config.compute_dtype
def validate_inputs(self, inputs):
# NOTE -- this should be handled internally by the model
# however, doing that will likely be breaking changes -- so we'll need to do it after the deadline
detected_error = False
error_msg = ERROR_MSG
if "action" in inputs:
action = inputs["action"]
# In inference, action may be omitted or None; validate only when it's a tensor.
if action is None:
pass # allow None during inference
elif isinstance(action, torch.Tensor):
shape_ok = (
len(action.shape) == 3
and action.shape[1] == self.action_horizon
and action.shape[2] == self.action_dim
)
if not shape_ok:
error_msg += f"\n{action.shape=}"
detected_error = True
else:
# Unexpected non-tensor type provided for action
error_msg += f"\nInvalid type for action: {type(action)}"
detected_error = True
if "video" in inputs:
video = inputs["video"]
type_ok = isinstance(video, np.ndarray)
dtype_ok = video.dtype == np.uint8
shape_ok = len(video.shape) == 6 and video.shape[3] == N_COLOR_CHANNELS
if not type_ok:
error_msg += f"\n{type(video)=}"
detected_error = True
if not dtype_ok:
error_msg += f"\n{video.dtype=}"
detected_error = True
if not shape_ok:
error_msg += f"\n{video.shape=}"
detected_error = True
if detected_error:
raise ValueError(error_msg)
def validate_data(self, action_head_outputs, backbone_outputs, is_training):
fail_backbone = (
not isinstance(backbone_outputs, BatchFeature) or BACKBONE_FEATURE_KEY not in backbone_outputs
)
if fail_backbone:
error_msg = ERROR_MSG
error_msg += f"\n{isinstance(backbone_outputs, BatchFeature)=}"
error_msg += f"\n{BACKBONE_FEATURE_KEY in backbone_outputs=}"
error_msg += f"\n{backbone_outputs[BACKBONE_FEATURE_KEY].shape=}"
raise ValueError(error_msg)
fail_action_head = (not isinstance(action_head_outputs, BatchFeature)) or not (
(
LOSS_KEY in action_head_outputs and is_training
) # there might not be an action prediction during training
or (
ACTION_KEY in action_head_outputs
and action_head_outputs[ACTION_KEY].shape[1] == self.action_horizon
and action_head_outputs[ACTION_KEY].shape[2] == self.action_dim
)
)
if fail_action_head:
error_msg = ERROR_MSG
error_msg += f"\n{isinstance(action_head_outputs, BatchFeature)=}"
error_msg += f"\n{LOSS_KEY in action_head_outputs=}"
error_msg += f"\n{action_head_outputs[ACTION_KEY].shape=}"
error_msg += f"\n{self.action_horizon=}"
error_msg += f"\n{self.action_dim=}"
raise ValueError(error_msg)
def forward(
self,
inputs: dict,
) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
backbone_outputs = self.backbone(backbone_inputs)
action_head_outputs = self.action_head(backbone_outputs, action_inputs)
self.validate_data(action_head_outputs, backbone_outputs, is_training=True)
return action_head_outputs
def get_action(
self,
inputs: dict,
) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
# Because the behavior of backbones remains the same for training and inference, we can use `forward` for backbones.
backbone_outputs = self.backbone(backbone_inputs)
action_head_outputs = self.action_head.get_action(backbone_outputs, action_inputs)
self.validate_data(action_head_outputs, backbone_outputs, is_training=False)
return action_head_outputs
def prepare_input(self, inputs) -> tuple[BatchFeature, BatchFeature]:
self.validate_inputs(inputs)
backbone_inputs = self.backbone.prepare_input(inputs)
action_inputs = self.action_head.prepare_input(inputs)
def to_device_with_maybe_dtype(x):
# Cast floating tensors to a memory-efficient compute dtype when requested.
# Rationale: Upcasting backbone activations to fp32 significantly increases VRAM.
# When compute_dtype is bfloat16, prefer bf16 for activations to match AMP behavior.
if not isinstance(x, torch.Tensor):
return x
if torch.is_floating_point(x):
if getattr(self, "compute_dtype", None) == "bfloat16":
return x.to(self.device, dtype=torch.bfloat16)
# Fallback: preserve previous behavior if not using bf16 compute
return x.to(self.device, dtype=self.action_head.dtype)
# Non-floating tensors: move device only
return x.to(self.device)
backbone_inputs = tree.map_structure(to_device_with_maybe_dtype, backbone_inputs)
action_inputs = tree.map_structure(to_device_with_maybe_dtype, action_inputs)
return backbone_inputs, action_inputs
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
tune_visual = kwargs.pop("tune_visual", True)
tune_llm = kwargs.pop("tune_llm", False)
tune_projector = kwargs.pop("tune_projector", True)
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
print(f"Loading pretrained dual brain from {pretrained_model_name_or_path}")
print(f"Tune backbone vision tower: {tune_visual}")
print(f"Tune backbone LLM: {tune_llm}")
print(f"Tune action head projector: {tune_projector}")
print(f"Tune action head DiT: {tune_diffusion_model}")
# get the current model path being downloaded
try:
# NOTE(YL) This downloads the model to the local cache and returns the local path to the model
# saved in ~/.cache/huggingface/hub/
local_model_path = snapshot_download(pretrained_model_name_or_path, repo_type="model")
# HFValidationError, RepositoryNotFoundError
except (HFValidationError, RepositoryNotFoundError):
print(
f"Model not found or avail in the huggingface hub. Loading from local path: {pretrained_model_name_or_path}"
)
local_model_path = pretrained_model_name_or_path
pretrained_model = super().from_pretrained(
local_model_path, local_model_path=local_model_path, **kwargs
)
pretrained_model.backbone.set_trainable_parameters(tune_visual=tune_visual, tune_llm=tune_llm)
pretrained_model.action_head.set_trainable_parameters(
tune_projector=tune_projector, tune_diffusion_model=tune_diffusion_model
)
return pretrained_model

View File

@@ -1,198 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Groot Policy Wrapper for LeRobot Integration
Minimal integration that delegates to Isaac-GR00T components where possible
without porting their code. The intent is to:
- Download and load the pretrained GR00T model via GR00TN15.from_pretrained
- Optionally align action horizon similar to gr00t_finetune.py
- Expose predict_action via GR00T model.get_action
- Provide a training forward that can call the GR00T model forward if batch
structure matches.
Notes:
- Dataset loading and full training orchestration is handled by Isaac-GR00T
TrainRunner in their codebase. If you want to invoke that flow end-to-end
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
"""
import os
from collections import deque
import torch
from torch import Tensor
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.groot.groot_n1 import GR00TN15
from lerobot.policies.pretrained import PreTrainedPolicy
class GrootPolicy(PreTrainedPolicy):
"""Wrapper around external Groot model for LeRobot integration."""
name = "groot"
config_class = GrootConfig
def __init__(self, config: GrootConfig):
"""Initialize Groot policy wrapper."""
super().__init__(config)
config.validate_features()
self.config = config
# Initialize GR00T model using ported components
self._groot_model = self._create_groot_model()
self.reset()
def _create_groot_model(self):
"""Create and initialize the GR00T model using Isaac-GR00T API.
This is only called when creating a NEW policy (not when loading from checkpoint).
Steps (delegating to Isaac-GR00T):
1) Download and load pretrained model via GR00TN15.from_pretrained
2) Align action horizon with data_config if provided
"""
# Handle Flash Attention compatibility issues
self._handle_flash_attention_compatibility()
model = GR00TN15.from_pretrained(
pretrained_model_name_or_path=self.config.base_model_path,
tune_llm=self.config.tune_llm,
tune_visual=self.config.tune_visual,
tune_projector=self.config.tune_projector,
tune_diffusion_model=self.config.tune_diffusion_model,
)
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
model.config.compute_dtype = model.compute_dtype
return model
def reset(self):
"""Reset policy state when environment resets."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
def get_optim_params(self) -> dict:
return self.parameters()
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Training forward pass.
Delegates to Isaac-GR00T model.forward when inputs are compatible.
"""
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
allowed_base = {"state", "state_mask", "action", "action_mask", "embodiment_id"}
groot_inputs = {
k: v
for k, v in batch.items()
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
}
# Get device from model parameters
device = next(self.parameters()).device
# Run GR00T forward under bf16 autocast when enabled to reduce activation memory
# Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
outputs = self._groot_model.forward(groot_inputs)
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
loss = outputs.get("loss")
loss_dict = {"loss": loss.item()}
return loss, loss_dict
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions for inference by delegating to Isaac-GR00T.
Returns a tensor of shape (B, n_action_steps, action_dim).
"""
self.eval()
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
# Preprocessing is handled by the processor pipeline, so we just filter the batch
# NOTE: During inference, we should NOT pass action/action_mask (that's what we're predicting)
allowed_base = {"state", "state_mask", "embodiment_id"}
groot_inputs = {
k: v
for k, v in batch.items()
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
}
# Get device from model parameters
device = next(self.parameters()).device
# Use bf16 autocast for inference to keep memory low and match backbone dtype
with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
outputs = self._groot_model.get_action(groot_inputs)
actions = outputs.get("action_pred")
original_action_dim = self.config.output_features["action"].shape[0]
actions = actions[:, :, :original_action_dim]
return actions
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select single action from action queue."""
self.eval()
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch)
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
# -------------------------
# Internal helpers
# -------------------------
def _handle_flash_attention_compatibility(self) -> None:
"""Handle Flash Attention compatibility issues by setting environment variables.
This addresses the common 'undefined symbol' error that occurs when Flash Attention
is compiled against a different PyTorch version than what's currently installed.
"""
# Set environment variables to handle Flash Attention compatibility
# These help with symbol resolution issues
os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
# Try to import flash_attn and handle failures gracefully
try:
import flash_attn
print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
except ImportError as e:
print(f"[GROOT] Flash Attention not available: {e}")
print("[GROOT] Will use fallback attention mechanism")
except Exception as e:
if "undefined symbol" in str(e):
print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
print(" pip uninstall flash-attn")
print(" pip install --no-build-isolation flash-attn==2.6.3")
print("[GROOT] Continuing with fallback attention mechanism")
else:
print(f"[GROOT] Flash Attention error: {e}")
print("[GROOT] Continuing with fallback attention mechanism")

View File

@@ -1,664 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
from einops import rearrange
from PIL import Image
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor, ProcessorMixin
else:
AutoProcessor = None
ProcessorMixin = object
from lerobot.configs.types import (
FeatureType,
NormalizationMode,
PolicyFeature,
)
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
)
from lerobot.processor.converters import (
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
HF_LEROBOT_HOME,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
# Defaults for Eagle processor locations
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
def make_groot_pre_post_processors(
config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""Create preprocessor and postprocessor for Groot policy.
This creates a processing pipeline that transforms LeRobot data format into
the format expected by Isaac-GR00T models:
Preprocessing steps:
1. Optional key renaming (dataset-specific key mapping)
2. Add batch dimension to unbatched data
3. Pack video/state/action/language/embodiment and apply optional min-max normalization before padding
4. Encode video+language with Eagle VLM into intermediate eagle_content
5. Collate eagle_content into batched eagle_* tensors
6. Move tensors to device (GPU)
NOTE: We optionally apply min-max normalization to STATE and ACTION using
dataset-provided statistics prior to padding, mapping values to [-1, 1].
This mirrors SO100-style preprocessing and keeps scales consistent with GR00T.
Args:
config: Groot configuration containing data_config, embodiment_tag, etc.
dataset_stats: Optional per-key min/max statistics for normalization before padding.
Returns:
Tuple of (preprocessor, postprocessor) pipelines
"""
# Get horizon/dimension parameters from config
# These should match the config used for the pretrained model
# Default values match most GR00T configs (state_horizon=1, action_horizon=16)
state_horizon = 1
# CRITICAL: Pretrained GR00T models use action_horizon=16 max!
# The model architecture hardcodes this limit
action_horizon = min(config.chunk_size, 16)
max_state_dim = config.max_state_dim
max_action_dim = config.max_action_dim
# Pass raw dataset_stats; normalization will occur inside pack step before padding
padded_stats = dataset_stats or {}
# Define feature specs for optional normalization steps
_features: dict[str, PolicyFeature] = {
# Observation features (only add those we may normalize)
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_horizon, max_state_dim)),
# Action feature
"action": PolicyFeature(type=FeatureType.ACTION, shape=(action_horizon, max_action_dim)),
}
# Normalize STATE and ACTION with min_max (SO100-like default)
_norm_map = {
FeatureType.ACTION: NormalizationMode.MIN_MAX,
FeatureType.STATE: NormalizationMode.MIN_MAX,
}
# Determine env action dimension from config (simple, object-like PolicyFeature)
try:
env_action_dim = int(config.output_features["action"].shape[0])
except Exception:
env_action_dim = 0
input_steps: list[ProcessorStep] = [
# 1. Rename keys if needed (e.g., dataset-specific camera names)
# Leave empty for now - add mappings if your dataset uses different key names
RenameObservationsProcessorStep(rename_map={}),
# 2. Add batch dimension for single samples
AddBatchDimensionProcessorStep(),
# 3. Pack video/state/action/language/embodiment; apply optional min-max normalization before padding
GrootPackInputsStep(
state_horizon=state_horizon,
action_horizon=action_horizon,
max_state_dim=max_state_dim,
max_action_dim=max_action_dim,
language_key="task",
formalize_language=False,
embodiment_tag=config.embodiment_tag,
normalize_min_max=True,
stats=padded_stats,
),
# 4. Eagle encode (creates eagle_content)
GrootEagleEncodeStep(
tokenizer_assets_repo=config.tokenizer_assets_repo,
),
# 5. Collate eagle_content -> eagle_* tensors
GrootEagleCollateStep(
tokenizer_assets_repo=config.tokenizer_assets_repo,
),
# 6. Move to device
DeviceProcessorStep(device=config.device),
]
# Postprocessing: slice to env action dim and unnormalize to env scale, then move to CPU
output_steps: list[ProcessorStep] = [
GrootActionUnpackUnnormalizeStep(
env_action_dim=env_action_dim,
stats=padded_stats,
normalize_min_max=True,
),
# Finally, move to CPU for env interaction
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
# GR00T specific processor steps
def _to_uint8_np_bhwc(img_t: torch.Tensor) -> np.ndarray:
# img_t: (B, C, H, W) float in [0,1] or uint8
if img_t.dtype.is_floating_point:
img_t = (img_t.clamp(0, 1) * 255.0).to(torch.uint8)
return rearrange(img_t.cpu().numpy(), "b c h w -> b h w c")
def _build_eagle_processor(tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO) -> ProcessorMixin:
# Validate that the cache directory is ready. If not, instruct the user.
cache_dir = HF_LEROBOT_HOME / tokenizer_assets_repo
required = [
cache_dir / "processor_config.json",
cache_dir / "preprocessor_config.json",
cache_dir / "image_processing_eagle2_5_vl_fast.py",
]
if not all(p.exists() for p in required):
raise FileNotFoundError(
f"[GROOT] Eagle processor cache at '{cache_dir}' is not populated. "
"Vendor files are copied during model creation. Create the policy/model first, "
"or call ensure_eagle_cache_ready() before building processors."
)
proc = AutoProcessor.from_pretrained(str(cache_dir), trust_remote_code=True, use_fast=True)
proc.tokenizer.padding_side = "left"
return proc
@dataclass
@ProcessorStepRegistry.register(name="groot_pack_inputs_v3")
class GrootPackInputsStep(ProcessorStep):
state_horizon: int = 1
action_horizon: int = 16
max_state_dim: int = 64
max_action_dim: int = 32
language_key: str = "task"
formalize_language: bool = False
embodiment_tag: str = "new_embodiment"
embodiment_mapping: dict[str, int] = field(
default_factory=lambda: {
"new_embodiment": 31, # Match original GR00T EMBODIMENT_TAG_MAPPING
"oxe_droid": 17,
"agibot_genie1": 26,
"gr1": 24,
"so100": 2,
"unitree_g1": 3,
}
)
# Min-max normalization (SO100-like) applied BEFORE padding
normalize_min_max: bool = True
stats: dict[str, dict[str, Any]] | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
def _align_vec(vec: Any, target_dim: int, *, default: float) -> torch.Tensor:
t = torch.as_tensor(vec)
t = t.flatten().to(
dtype=torch.float32,
device=next(
(v.device for v in obs.values() if isinstance(v, torch.Tensor)), torch.device("cpu")
),
)
d = int(t.shape[-1]) if t.numel() > 0 else 0
if d == target_dim:
return t
if d < target_dim:
pad = torch.full((target_dim - d,), default, dtype=t.dtype, device=t.device)
return torch.cat([t, pad], dim=0)
return t[:target_dim]
def _min_max_norm(x: torch.Tensor, key: str) -> torch.Tensor:
if not self.normalize_min_max:
return x
if self.stats is None or key not in self.stats:
return x
stats_k = self.stats[key]
last_dim = x.shape[-1]
min_v = _align_vec(stats_k.get("min", torch.zeros(last_dim)), last_dim, default=0.0)
max_v = _align_vec(stats_k.get("max", torch.ones(last_dim)), last_dim, default=1.0)
denom = max_v - min_v
mask = denom != 0
safe_denom = torch.where(mask, denom, torch.ones_like(denom))
mapped = 2 * (x - min_v) / safe_denom - 1
return torch.where(mask, mapped, torch.zeros_like(mapped))
# 1) Video (B, T=1, V, H, W, C) uint8
img_keys = sorted([k for k in obs if k.startswith("observation.images.")])
if not img_keys and "observation.image" in obs:
img_keys = ["observation.image"]
if img_keys:
cams = [_to_uint8_np_bhwc(obs[k]) for k in img_keys]
video = np.stack(cams, axis=1) # (B, V, H, W, C)
video = np.expand_dims(video, axis=1) # (B, 1, V, H, W, C)
# GR00T validates that video.shape[3] == 3 (channels), so reorder to (B, T, V, C, H, W)
video = np.transpose(video, (0, 1, 2, 5, 3, 4)) # (B, 1, V, C, H, W)
obs["video"] = video
# Drop raw images to avoid confusion downstream
for k in img_keys:
obs.pop(k, None)
# 2) Language (string)
lang = comp.get(self.language_key)
if isinstance(lang, list):
lang = lang[0] if len(lang) > 0 else None
if not lang:
lang = "Perform the task."
if self.formalize_language:
lang = (lang or "").lower()
lang = "".join(ch for ch in lang if ch.isalnum() or ch.isspace())
comp["language"] = lang
# 3) State/state_mask -> (B, 1, max_state_dim)
if "observation.state" in obs:
state = obs["observation.state"] # (B, D)
if state.dim() != 2:
raise ValueError(f"state must be (B, D), got {tuple(state.shape)}")
bsz, d = state.shape
# Normalize BEFORE padding
if self.normalize_min_max:
state = _min_max_norm(state, "observation.state")
state = state.unsqueeze(1) # (B, 1, D)
if d > self.max_state_dim:
state = state[:, :, : self.max_state_dim]
d = self.max_state_dim
elif d < self.max_state_dim:
pad = torch.zeros(bsz, 1, self.max_state_dim - d, dtype=state.dtype, device=state.device)
state = torch.cat([state, pad], dim=2)
state_mask = torch.zeros(bsz, 1, self.max_state_dim, dtype=torch.bool, device=state.device)
state_mask[:, :, :d] = True
obs["state"] = state
obs["state_mask"] = state_mask
# 4) Action/action_mask -> (B, action_horizon, max_action_dim)
action = transition.get(TransitionKey.ACTION)
if isinstance(action, torch.Tensor):
# Normalize BEFORE temporal expansion/padding
if self.normalize_min_max:
if action.dim() == 2:
action = _min_max_norm(action, "action")
elif action.dim() == 3:
b, t, d = action.shape
flat = action.reshape(b * t, d)
flat = _min_max_norm(flat, "action")
action = flat.view(b, t, d)
if action.dim() == 2:
action = action.unsqueeze(1).repeat(1, self.action_horizon, 1)
elif action.dim() == 3:
b, t, d = action.shape
if t < self.action_horizon:
last = action[:, -1:, :]
pad = last.repeat(1, self.action_horizon - t, 1)
action = torch.cat([action, pad], dim=1)
elif t > self.action_horizon:
action = action[:, : self.action_horizon, :]
else:
raise ValueError(f"action must be (B, D) or (B, T, D), got {tuple(action.shape)}")
b, t, d = action.shape
if d > self.max_action_dim:
action = action[:, :, : self.max_action_dim]
d = self.max_action_dim
elif d < self.max_action_dim:
pad = torch.zeros(b, t, self.max_action_dim - d, dtype=action.dtype, device=action.device)
action = torch.cat([action, pad], dim=2)
action_mask = torch.zeros(b, t, self.max_action_dim, dtype=torch.bool, device=action.device)
action_mask[:, :, :d] = True
transition[TransitionKey.ACTION] = action
comp["action_mask"] = action_mask
# 5) Embodiment id as LongTensor (B,)
emb_id = self.embodiment_mapping.get(self.embodiment_tag, 0)
# Infer batch size/device from any tensor in obs or action
bsz = None
device = torch.device("cpu")
for v in list(obs.values()) + [transition.get(TransitionKey.ACTION)]:
if isinstance(v, torch.Tensor):
bsz = v.shape[0]
device = v.device
break
if bsz is None and "video" in obs and isinstance(obs["video"], np.ndarray):
bsz = obs["video"].shape[0]
if bsz is None:
bsz = 1
comp["embodiment_id"] = torch.full((bsz,), emb_id, dtype=torch.long, device=device)
transition[TransitionKey.OBSERVATION] = obs
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
return transition
# Pipeline API requirement: declare how features change (we keep it simple)
def transform_features(self, features):
return features
def get_config(self) -> dict[str, Any]:
"""
Returns a serializable dictionary of the processor's configuration.
Excludes 'stats' since they are saved separately via state_dict().
"""
return {
"state_horizon": self.state_horizon,
"action_horizon": self.action_horizon,
"max_state_dim": self.max_state_dim,
"max_action_dim": self.max_action_dim,
"language_key": self.language_key,
"formalize_language": self.formalize_language,
"embodiment_tag": self.embodiment_tag,
"embodiment_mapping": self.embodiment_mapping,
"normalize_min_max": self.normalize_min_max,
}
def state_dict(self) -> dict[str, torch.Tensor]:
"""
Returns normalization statistics as a flat state dictionary.
This enables saving stats to safetensors files, similar to normalizer_processor.
"""
if not self.stats:
return {}
flat: dict[str, torch.Tensor] = {}
for key, sub in self.stats.items():
for stat_name, value in sub.items():
tensor = torch.as_tensor(value).cpu()
flat[f"{key}.{stat_name}"] = tensor
return flat
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""
Loads normalization statistics from a flat state dictionary.
This enables loading stats from safetensors files during from_pretrained.
"""
if not state:
return
reconstructed: dict[str, dict[str, Any]] = {}
for flat_key, tensor in state.items():
if "." in flat_key:
key, stat_name = flat_key.rsplit(".", 1)
if key not in reconstructed:
reconstructed[key] = {}
reconstructed[key][stat_name] = tensor
if reconstructed:
self.stats = reconstructed
@dataclass
@ProcessorStepRegistry.register(name="groot_eagle_encode_v3")
class GrootEagleEncodeStep(ProcessorStep):
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
@property
def proc(self) -> ProcessorMixin:
if self._proc is None:
self._proc = _build_eagle_processor(self.tokenizer_assets_repo)
return self._proc
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
if "video" not in obs:
return transition
video = obs["video"] # (B, T, V, H, W, C) uint8
lang = comp.get("language", "Perform the task.")
if isinstance(lang, list):
lang = lang[0] if len(lang) > 0 else "Perform the task."
bsz = video.shape[0]
eagle_contents: list[dict[str, Any]] = []
for b in range(bsz):
vt = video[b] # (T, V, C, H, W) after reorder
if vt.ndim != 5:
# Fallback: assume (T, V, H, W, C)
t, v, h, w, c = vt.shape
flat = rearrange(vt, "t v h w c -> (t v) h w c")
else:
t, v, c, h, w = vt.shape
flat = rearrange(vt, "t v c h w -> (t v) h w c")
images = [Image.fromarray(flat[i]) for i in range(t * v)]
# Format language as string list representation to match Original GROOT
lang_formatted = str([lang])
text_content = [{"type": "text", "text": lang_formatted}]
image_content = [{"type": "image", "image": img} for img in images]
conv = [{"role": "user", "content": image_content + text_content}]
text_list = [self.proc.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)]
img_inputs, vid_inputs = self.proc.process_vision_info(conv)
eagle_contents.append(
{
"text_list": text_list,
"image_inputs": img_inputs,
"video_inputs": vid_inputs,
}
)
comp["eagle_content"] = eagle_contents
transition[TransitionKey.OBSERVATION] = obs
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
return transition
# Pipeline API requirement: declare how features change (no schema change here)
def transform_features(self, features):
return features
# Original GR00T-style collate: converts eagle_content -> eagle_* tensors
def collate(features: list[dict[str, Any]], eagle_processor: ProcessorMixin) -> dict[str, Any]:
batch: dict[str, Any] = {}
keys = features[0].keys()
for key in keys:
values = [elem[key] for elem in features]
if key == "eagle_content":
text_list: list[str] = []
image_inputs: list[Any] = []
for v in values:
curr_text_list = v["text_list"]
curr_image_inputs = v["image_inputs"]
text_list += curr_text_list
image_inputs += curr_image_inputs
eagle_inputs = eagle_processor(
text=text_list,
images=image_inputs,
images_kwargs={"min_dynamic_tiles": 1, "max_dynamic_tiles": 1, "use_thumbnail": False},
return_tensors="pt",
padding=True,
)
for k, v in eagle_inputs.items():
k = "eagle_" + k
batch[k] = v
elif key in ("pixel_values", "image_grid_thw", "attention_mask", "input_ids"):
# Concat in existing batch dimension.
batch[key] = torch.cat(values)
else:
# state, state_mask, action and action_mask.
# Stack to form the batch dimension.
batch[key] = torch.from_numpy(np.stack(values))
return batch
@dataclass
@ProcessorStepRegistry.register(name="groot_eagle_collate_v3")
class GrootEagleCollateStep(ProcessorStep):
tokenizer_assets_repo: str = DEFAULT_TOKENIZER_ASSETS_REPO
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
@property
def proc(self) -> ProcessorMixin:
if self._proc is None:
self._proc = _build_eagle_processor(self.tokenizer_assets_repo)
return self._proc
def __call__(self, transition: EnvTransition) -> EnvTransition:
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
contents = comp.get("eagle_content")
if not contents:
return transition
# Build features list as original API expects: one dict per batch item
features = [{"eagle_content": content} for content in contents]
batched = collate(features, self.proc)
# Inject eagle_* tensors and remove the temporary content and raw video to free memory
for k, v in batched.items():
comp[k] = v
comp.pop("eagle_content", None)
obs.pop(
"video", None
) # The video has been fully encoded into eagle_* tensors, so we don't need the raw video anymore
transition[TransitionKey.OBSERVATION] = obs
transition[TransitionKey.COMPLEMENTARY_DATA] = comp
return transition
def transform_features(self, features):
return features
@dataclass
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v1")
class GrootActionUnpackUnnormalizeStep(ProcessorStep):
env_action_dim: int = 0
# Apply inverse of min-max normalization if it was used in preprocessor
normalize_min_max: bool = True
stats: dict[str, dict[str, Any]] | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Expect model outputs to be in TransitionKey.ACTION as (B, T, D_model)
action = transition.get(TransitionKey.ACTION)
if not isinstance(action, torch.Tensor):
return transition
# Select last timestep and slice to env dimension
if action.dim() == 3:
action = action[:, -1, :]
# Now action is (B, D_model)
if self.env_action_dim and action.shape[-1] >= self.env_action_dim:
action = action[..., : self.env_action_dim]
# Inverse min-max normalization mirroring _min_max_norm:
# forward: y = 2 * (x - min) / denom - 1, with y=0 when denom==0
# inverse: x = (y+1)/2 * denom + min, and when denom==0 -> x = min
if self.normalize_min_max and self.stats is not None:
stats_k = self.stats.get("action", {})
d = action.shape[-1]
min_v = torch.as_tensor(
stats_k.get("min", torch.zeros(d)), dtype=action.dtype, device=action.device
)
max_v = torch.as_tensor(
stats_k.get("max", torch.ones(d)), dtype=action.dtype, device=action.device
)
if min_v.numel() != d:
min_v = torch.nn.functional.pad(min_v.flatten()[:d], (0, max(0, d - min_v.numel())))
min_v = min_v.to(action.device, dtype=action.dtype)
if max_v.numel() != d:
max_v = torch.nn.functional.pad(max_v.flatten()[:d], (0, max(0, d - max_v.numel())))
max_v = max_v.to(action.device, dtype=action.dtype)
denom = max_v - min_v
mask = denom != 0
safe_denom = torch.where(mask, denom, torch.ones_like(denom))
inv = (action + 1.0) * 0.5 * safe_denom + min_v
action = torch.where(mask, inv, min_v)
transition[TransitionKey.ACTION] = action
return transition
def transform_features(self, features):
return features
def get_config(self) -> dict[str, Any]:
"""
Returns a serializable dictionary of the processor's configuration.
Excludes 'stats' since they are saved separately via state_dict().
"""
return {
"env_action_dim": self.env_action_dim,
"normalize_min_max": self.normalize_min_max,
}
def state_dict(self) -> dict[str, torch.Tensor]:
"""
Returns normalization statistics as a flat state dictionary.
This enables saving stats to safetensors files, similar to normalizer_processor.
"""
if not self.stats:
return {}
flat: dict[str, torch.Tensor] = {}
for key, sub in self.stats.items():
for stat_name, value in sub.items():
tensor = torch.as_tensor(value).cpu()
flat[f"{key}.{stat_name}"] = tensor
return flat
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""
Loads normalization statistics from a flat state dictionary.
This enables loading stats from safetensors files during from_pretrained.
"""
if not state:
return
reconstructed: dict[str, dict[str, Any]] = {}
for flat_key, tensor in state.items():
if "." in flat_key:
key, stat_name = flat_key.rsplit(".", 1)
if key not in reconstructed:
reconstructed[key] = {}
reconstructed[key][stat_name] = tensor
if reconstructed:
self.stats = reconstructed

View File

@@ -1,47 +0,0 @@
from pathlib import Path
from shutil import copytree
from huggingface_hub import hf_hub_download
def ensure_eagle_cache_ready(vendor_dir: Path, cache_dir: Path, assets_repo: str) -> None:
"""Populate the Eagle processor directory in cache and ensure tokenizer assets exist.
- Copies the vendored Eagle files into cache_dir (overwriting when needed).
- Downloads vocab.json and merges.txt into the same cache_dir if missing.
"""
cache_dir = Path(cache_dir)
vendor_dir = Path(vendor_dir)
try:
# Populate/refresh cache with vendor files to ensure a complete processor directory
print(f"[GROOT] Copying vendor Eagle files to cache: {vendor_dir} -> {cache_dir}")
copytree(vendor_dir, cache_dir, dirs_exist_ok=True)
except Exception as exc: # nosec: B110
print(f"[GROOT] Warning: Failed to copy vendor Eagle files to cache: {exc}")
required_assets = [
"vocab.json",
"merges.txt",
"added_tokens.json",
"chat_template.json",
"special_tokens_map.json",
"config.json",
"generation_config.json",
"preprocessor_config.json",
"processor_config.json",
"tokenizer_config.json",
]
print(f"[GROOT] Assets repo: {assets_repo} \n Cache dir: {cache_dir}")
for fname in required_assets:
dst = cache_dir / fname
if not dst.exists():
print(f"[GROOT] Fetching {fname}")
hf_hub_download(
repo_id=assets_repo,
filename=fname,
repo_type="model",
local_dir=str(cache_dir),
)

View File

@@ -20,7 +20,6 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import OBS_IMAGES
@@ -48,9 +47,6 @@ class PI0Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.

View File

@@ -19,12 +19,11 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict
from typing import TYPE_CHECKING, Literal
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available
@@ -43,7 +42,6 @@ else:
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -53,12 +51,6 @@ from lerobot.utils.constants import (
)
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type."""
if device_type == "mps" and target_dtype == torch.float64:
@@ -511,10 +503,9 @@ class PaliGemmaWithExpertModel(
class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
"""Core PI0 PyTorch model."""
def __init__(self, config: PI0Config, rtc_processor: RTCProcessor | None = None):
def __init__(self, config: PI0Config):
super().__init__()
self.config = config
self.rtc_processor = rtc_processor
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)
@@ -569,9 +560,6 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -768,15 +756,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
def sample_actions(
self,
images,
img_masks,
lang_tokens,
lang_masks,
state,
noise=None,
num_steps=None,
**kwargs: Unpack[ActionSelectKwargs],
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None
) -> Tensor:
"""Do a full inference forward and compute the action."""
if num_steps is None:
@@ -818,41 +798,14 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
return self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
v_t = self.denoise_step(
state,
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
x_t = x_t + dt * v_t
time += dt
return x_t
@@ -916,8 +869,7 @@ class PI0Policy(PreTrainedPolicy):
self.config = config
# Initialize the core PI0 model
self.init_rtc_processor()
self.model = PI0Pytorch(config, rtc_processor=self.rtc_processor)
self.model = PI0Pytorch(config)
# Enable gradient checkpointing if requested
if config.gradient_checkpointing:
@@ -1107,22 +1059,6 @@ class PI0Policy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
def init_rtc_processor(self):
"""Initialize RTC processor if RTC is enabled in config."""
self.rtc_processor = None
# Create processor if config provided
# If RTC is not enabled - we can still track the denoising data
if self.config.rtc_config is not None:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
model_value = getattr(self, "model", None)
if model_value is not None:
model_value.rtc_processor = self.rtc_processor
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
"""Preprocess images for the model.
@@ -1201,10 +1137,6 @@ class PI0Policy(PreTrainedPolicy):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
assert not self._rtc_enabled(), (
"RTC is not supported for select_action, use it with predict_action_chunk"
)
self.eval()
# Action queue logic for n_action_steps > 1
@@ -1216,7 +1148,7 @@ class PI0Policy(PreTrainedPolicy):
return self._action_queue.popleft()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()
@@ -1225,8 +1157,8 @@ class PI0Policy(PreTrainedPolicy):
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
state = self.prepare_state(batch)
# Sample actions using the model (pass through RTC kwargs)
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, **kwargs)
# Sample actions using the model
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
# Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0]

View File

@@ -20,7 +20,6 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
@PreTrainedConfig.register_subclass("pi05")
@@ -47,9 +46,6 @@ class PI05Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.

View File

@@ -19,12 +19,11 @@ import logging
import math
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict
from typing import TYPE_CHECKING, Literal
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available
@@ -43,7 +42,6 @@ else:
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -52,12 +50,6 @@ from lerobot.utils.constants import (
)
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type."""
if device_type == "mps" and target_dtype == torch.float64:
@@ -510,10 +502,9 @@ class PaliGemmaWithExpertModel(
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
"""Core PI05 PyTorch model."""
def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
def __init__(self, config: PI05Config):
super().__init__()
self.config = config
self.rtc_processor = rtc_processor
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)
@@ -565,9 +556,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -743,16 +731,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
return F.mse_loss(u_t, v_t, reduction="none")
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
def sample_actions(
self,
images,
img_masks,
tokens,
masks,
noise=None,
num_steps=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor:
"""Do a full inference forward and compute the action."""
if num_steps is None:
num_steps = self.config.num_inference_steps
@@ -791,40 +770,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
return self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
v_t = self.denoise_step(
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
x_t = x_t + dt * v_t
time += dt
return x_t
@@ -887,8 +839,7 @@ class PI05Policy(PreTrainedPolicy):
self.config = config
# Initialize the core PI05 model
self.init_rtc_processor()
self.model = PI05Pytorch(config, rtc_processor=self.rtc_processor)
self.model = PI05Pytorch(config)
# Enable gradient checkpointing if requested
if config.gradient_checkpointing:
@@ -1084,22 +1035,6 @@ class PI05Policy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps),
}
def init_rtc_processor(self):
"""Initialize RTC processor if RTC is enabled in config."""
self.rtc_processor = None
# Create processor if config provided
# If RTC is not enabled - we can still track the denoising data
if self.config.rtc_config is not None:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
model_value = getattr(self, "model", None)
if model_value is not None:
model_value.rtc_processor = self.rtc_processor
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
"""Preprocess images for the model.
@@ -1174,10 +1109,6 @@ class PI05Policy(PreTrainedPolicy):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
assert not self._rtc_enabled(), (
"RTC is not supported for select_action, use it with predict_action_chunk"
)
self.eval()
# Action queue logic for n_action_steps > 1
@@ -1189,7 +1120,7 @@ class PI05Policy(PreTrainedPolicy):
return self._action_queue.popleft()
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()
@@ -1197,8 +1128,8 @@ class PI05Policy(PreTrainedPolicy):
images, img_masks = self._preprocess_images(batch)
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
# Sample actions using the model (no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, tokens, masks)
# Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0]

View File

@@ -1,38 +0,0 @@
# Real-Time Chunking (RTC)
This module contains the LeRobot implementation of **Real-Time Chunking (RTC)**, an inference-time technique for flow-matching based policies.
**Note**: RTC is not a policy itself, but rather an inference enhancement that works with flow-matching based policies including [π₀](../pi0/), [π₀.₅](../pi05/), and [SmolVLA](../smolvla/).
---
## Citation
If you use Real-Time Chunking in your work, please cite:
```bibtex
@misc{openpi2024,
author = {Physical Intelligence Lab},
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
year = {2024},
publisher = {GitHub},
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
license = {Apache-2.0}
}
@misc{black2025realtimeexecutionactionchunking,
title={Real-Time Execution of Action Chunking Flow Policies},
author={Kevin Black and Manuel Y. Galliker and Sergey Levine},
year={2025},
eprint={2506.07339},
archivePrefix={arXiv},
primaryClass={cs.RO},
url={https://arxiv.org/abs/2506.07339},
}
```
---
## License
This implementation follows the **Apache 2.0 License**, consistent with the LeRobot project.

View File

@@ -1,219 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action queue management for Real-Time Chunking (RTC).
This module provides ActionQueue, a thread-safe queue for managing action chunks
in real-time control scenarios. It supports both RTC-enabled and non-RTC modes,
handling action merging and leftover tracking.
"""
import logging
from threading import Lock
import torch
from torch import Tensor
from lerobot.policies.rtc.configuration_rtc import RTCConfig
logger = logging.getLogger(__name__)
class ActionQueue:
"""Thread-safe queue for managing action chunks in real-time control.
This queue handles two types of action sequences:
- Original actions: Used for RTC to compute leftovers from previous chunks
- Processed actions: Post-processed actions ready for robot execution
The queue operates in two modes:
1. RTC-enabled: Replaces the entire queue with new actions, accounting for inference delay
2. RTC-disabled: Appends new actions to the queue, maintaining continuity
Args:
cfg (RTCConfig): Configuration for Real-Time Chunking behavior.
Attributes:
queue (Tensor | None): Processed actions for robot rollout (time_steps, action_dim).
original_queue (Tensor | None): Original actions for RTC computation (time_steps, action_dim).
last_index (int): Current consumption index in the queue.
"""
def __init__(self, cfg: RTCConfig):
"""Initialize the action queue.
Args:
cfg: RTC configuration controlling queue behavior.
"""
self.queue = None # Processed actions for robot rollout
self.original_queue = None # Original actions for RTC
self.lock = Lock()
self.last_index = 0
self.cfg = cfg
def get(self) -> Tensor | None:
"""Get the next action from the queue.
Returns:
Tensor | None: The next action (action_dim,) or None if queue is empty.
Returns a clone to prevent external modifications.
"""
with self.lock:
if self.queue is None or self.last_index >= len(self.queue):
return None
action = self.queue[self.last_index]
self.last_index += 1
return action.clone()
def qsize(self) -> int:
"""Get the number of remaining actions in the queue.
Returns:
int: Number of unconsumed actions.
"""
if self.queue is None:
return 0
length = len(self.queue)
return length - self.last_index
def empty(self) -> bool:
"""Check if the queue is empty.
Returns:
bool: True if no actions remain, False otherwise.
"""
if self.queue is None:
return True
length = len(self.queue)
return length - self.last_index <= 0
def get_action_index(self) -> int:
"""Get the current action consumption index.
Returns:
int: Index of the next action to be consumed.
"""
return self.last_index
def get_left_over(self) -> Tensor | None:
"""Get leftover original actions for RTC prev_chunk_left_over.
These are the unconsumed actions from the current chunk, which will be
used by RTC to compute corrections for the next chunk.
Returns:
Tensor | None: Remaining original actions (remaining_steps, action_dim),
or None if no original queue exists.
"""
with self.lock:
if self.original_queue is None:
return None
return self.original_queue[self.last_index :]
def merge(
self,
original_actions: Tensor,
processed_actions: Tensor,
real_delay: int,
action_index_before_inference: int | None = 0,
):
"""Merge new actions into the queue.
This method operates differently based on RTC mode:
- RTC enabled: Replaces the queue, accounting for inference delay
- RTC disabled: Appends to the queue, maintaining continuity
Args:
original_actions: Unprocessed actions from policy (time_steps, action_dim).
processed_actions: Post-processed actions for robot (time_steps, action_dim).
real_delay: Number of time steps of inference delay.
action_index_before_inference: Index before inference started, for validation.
"""
with self.lock:
self._check_delays(real_delay, action_index_before_inference)
if self.cfg.enabled:
self._replace_actions_queue(original_actions, processed_actions, real_delay)
return
self._append_actions_queue(original_actions, processed_actions)
def _replace_actions_queue(self, original_actions: Tensor, processed_actions: Tensor, real_delay: int):
"""Replace the queue with new actions (RTC mode).
Discards the first `real_delay` actions since they correspond to the time
spent during inference, when the robot was executing previous actions.
Args:
original_actions: Unprocessed actions from policy.
processed_actions: Post-processed actions for robot.
real_delay: Number of time steps to skip due to inference delay.
"""
self.original_queue = original_actions[real_delay:].clone()
self.queue = processed_actions[real_delay:].clone()
logger.debug(f"original_actions shape: {self.original_queue.shape}")
logger.debug(f"processed_actions shape: {self.queue.shape}")
logger.debug(f"real_delay: {real_delay}")
self.last_index = 0
def _append_actions_queue(self, original_actions: Tensor, processed_actions: Tensor):
"""Append new actions to the queue (non-RTC mode).
Removes already-consumed actions and appends new ones, maintaining
queue continuity without replacement.
Args:
original_actions: Unprocessed actions from policy.
processed_actions: Post-processed actions for robot.
"""
if self.queue is None:
self.original_queue = original_actions.clone()
self.queue = processed_actions.clone()
return
self.original_queue = torch.cat([self.original_queue, original_actions.clone()])
self.original_queue = self.original_queue[self.last_index :]
self.queue = torch.cat([self.queue, processed_actions.clone()])
self.queue = self.queue[self.last_index :]
self.last_index = 0
def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None):
"""Validate that computed delays match expectations.
Compares the delay computed from inference latency with the actual
number of actions consumed during inference.
Args:
real_delay: Delay computed from inference latency.
action_index_before_inference: Action index when inference started.
"""
if action_index_before_inference is None:
return
indexes_diff = self.last_index - action_index_before_inference
if indexes_diff != real_delay:
# Let's check that action index difference (real delay calculated based on action queue)
# is the same as delay calculated based on inference latency
logger.warning(
f"[ACTION_QUEUE] Indexes diff is not equal to real delay. "
f"Indexes diff: {indexes_diff}, real delay: {real_delay}"
)

Some files were not shown because too many files have changed in this diff Show More