mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
Compare commits
16 Commits
pr/3545
...
codex/fix-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
727ca1a92c | ||
|
|
ee737b72d0 | ||
|
|
cf2e42f557 | ||
|
|
eb9519eb91 | ||
|
|
f29ce39f39 | ||
|
|
682d5f2f95 | ||
|
|
6ee093db9a | ||
|
|
4499519dbf | ||
|
|
d70c3baf7c | ||
|
|
bff2d50dc1 | ||
|
|
8d1401abe3 | ||
|
|
195b777367 | ||
|
|
b49e4016f2 | ||
|
|
02d8a34829 | ||
|
|
14c7a25ce4 | ||
|
|
bc06cb44ca |
6
.github/workflows/benchmark_tests.yml
vendored
6
.github/workflows/benchmark_tests.yml
vendored
@@ -382,7 +382,6 @@ jobs:
|
||||
--policy.path=\"\$ROBOTWIN_POLICY\" \
|
||||
--env.type=robotwin \
|
||||
--env.task=\"\$ROBOTWIN_TASKS\" \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -483,7 +482,6 @@ jobs:
|
||||
--policy.path=lerobot/smolvla_robocasa \
|
||||
--env.type=robocasa \
|
||||
--env.task=CloseFridge,OpenCabinet,OpenDrawer,TurnOnMicrowave,TurnOffStove,CloseToasterOvenDoor,SlideDishwasherRack,TurnOnSinkFaucet,NavigateKitchen,TurnOnElectricKettle \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -695,7 +693,6 @@ jobs:
|
||||
--env.task=\"\$ROBOMME_TASKS\" \
|
||||
--env.dataset_split=test \
|
||||
--env.task_ids=[0] \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -803,7 +800,6 @@ jobs:
|
||||
--env.type=libero_plus \
|
||||
--env.task=\"\$LIBERO_PLUS_SUITE\" \
|
||||
--env.task_ids=\"\$LIBERO_PLUS_TASK_IDS\" \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
@@ -904,8 +900,6 @@ jobs:
|
||||
--policy.path=lerobot/smolvla_vlabench \
|
||||
--env.type=vlabench \
|
||||
--env.task=select_fruit,select_toy,select_book,select_painting,select_drink,select_ingredient,select_billiards,select_poker,add_condiment,insert_flower \
|
||||
--env.episode_length=50 \
|
||||
--env.max_parallel_tasks=5 \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--eval.use_async_envs=false \
|
||||
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
github.event.workflow_run.event == 'pull_request' &&
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.repository == 'huggingface/lerobot'
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@9ad2de8582b56c017cb530c1165116d40433f1c6 # main
|
||||
with:
|
||||
package_name: lerobot
|
||||
secrets:
|
||||
|
||||
4
.github/workflows/documentation.yml
vendored
4
.github/workflows/documentation.yml
vendored
@@ -55,7 +55,7 @@ jobs:
|
||||
github.repository == 'huggingface/lerobot'
|
||||
permissions:
|
||||
contents: read
|
||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
||||
with:
|
||||
commit_sha: ${{ github.sha }}
|
||||
package: lerobot
|
||||
@@ -78,7 +78,7 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
||||
with:
|
||||
commit_sha: ${{ github.event.pull_request.head.sha }}
|
||||
pr_number: ${{ github.event.number }}
|
||||
|
||||
16
.github/workflows/stale.yml
vendored
16
.github/workflows/stale.yml
vendored
@@ -19,19 +19,19 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
# Runs at 02:00
|
||||
# schedule:
|
||||
# - cron: "0 2 * * *"
|
||||
schedule:
|
||||
- cron: "0 2 * * *"
|
||||
|
||||
env:
|
||||
CLOSE_ISSUE_MESSAGE: >
|
||||
This issue was closed because it has been stalled for 30 days with no activity.
|
||||
This issue was closed because it has been stalled for 14 days with no activity.
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
CLOSE_PR_MESSAGE: >
|
||||
This PR was closed because it has been stalled for 30 days with no activity.
|
||||
This PR was closed because it has been stalled for 21 days with no activity.
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
WARN_ISSUE_MESSAGE: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity (1 year). It will be closed if no further activity occurs.
|
||||
recent activity (6 months). It will be closed if no further activity occurs.
|
||||
Any change, comment or update to this issue will reset this count.
|
||||
Thank you for your contributions.
|
||||
WARN_PR_MESSAGE: >
|
||||
@@ -59,10 +59,10 @@ jobs:
|
||||
stale-pr-label: stale
|
||||
exempt-issue-labels: never-stale
|
||||
exempt-pr-labels: never-stale
|
||||
days-before-issue-stale: 365
|
||||
days-before-issue-close: 30
|
||||
days-before-issue-stale: 180
|
||||
days-before-issue-close: 14
|
||||
days-before-pr-stale: 365
|
||||
days-before-pr-close: 30
|
||||
days-before-pr-close: 21
|
||||
delete-branch: true
|
||||
close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
|
||||
close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
include src/lerobot/templates/lerobot_modelcard_template.md
|
||||
include src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md
|
||||
include src/lerobot/datasets/card_template.md
|
||||
include src/lerobot/envs/metaworld_config.json
|
||||
|
||||
@@ -35,7 +35,7 @@ USER root
|
||||
ARG ROBOTWIN_SHA=0aeea2d669c0f8516f4d5785f0aa33ba812c14b4
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
cuda-nvcc-12-6 cuda-cudart-dev-12-6 \
|
||||
cuda-nvcc-12-4 cuda-cudart-dev-12-4 \
|
||||
libvulkan1 vulkan-tools \
|
||||
&& mkdir -p /usr/share/vulkan/icd.d \
|
||||
&& echo '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.3.0"}}' \
|
||||
|
||||
@@ -18,8 +18,9 @@
|
||||
# docker build -f docker/Dockerfile.internal -t lerobot-internal .
|
||||
|
||||
# Configure the base image for CI with GPU access
|
||||
ARG CUDA_VERSION=12.6.3
|
||||
ARG OS_VERSION=24.04
|
||||
# TODO(Steven): Bump these versions
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
ARG OS_VERSION=22.04
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||
|
||||
# Define Python version argument
|
||||
@@ -35,13 +36,16 @@ ENV DEBIAN_FRONTEND=noninteractive \
|
||||
|
||||
# Install Python, system dependencies, and uv (as root)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential git curl \
|
||||
libglib2.0-0 libgl1 libegl1 ffmpeg \
|
||||
software-properties-common build-essential git curl \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
|
||||
cmake pkg-config ninja-build \
|
||||
python${PYTHON_VERSION} \
|
||||
python${PYTHON_VERSION}-venv \
|
||||
python${PYTHON_VERSION}-dev \
|
||||
&& add-apt-repository -y ppa:deadsnakes/ppa \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
python${PYTHON_VERSION} \
|
||||
python${PYTHON_VERSION}-venv \
|
||||
python${PYTHON_VERSION}-dev \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& mv /root/.local/bin/uv /usr/local/bin/uv \
|
||||
&& useradd --create-home --shell /bin/bash user_lerobot \
|
||||
|
||||
@@ -47,10 +47,6 @@
|
||||
title: π₀-FAST (Pi0Fast)
|
||||
- local: pi05
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: eo1
|
||||
title: EO-1
|
||||
- local: evo1
|
||||
title: EVO1
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
# EO-1
|
||||
|
||||
EO-1 is a **Vision-Language-Action policy for robot control**. The LeRobot implementation integrates EO-1 with the standard LeRobot training, evaluation, processor interface.
|
||||
|
||||
## Model Overview
|
||||
|
||||
EO-1 uses a Qwen2.5-VL backbone for vision-language understanding and adds a continuous flow-matching action head for robot control. The policy formats each robot-control sample as a multimodal conversation: camera images are passed to Qwen2.5-VL, the robot state is represented with EO-1 state tokens, and the future action chunk is represented with EO-1 action tokens.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/HaomingSong/lerobot-documentation-images/resolve/main/lerobot/eo_pipeline.png"
|
||||
alt="An overview of EO-1"
|
||||
width="85%"
|
||||
/>
|
||||
|
||||
During training, EO-1 learns to denoise continuous action chunks at the action-token positions. During inference, it samples an action chunk, returns continuous actions, and executes `n_action_steps` from the chunk before sampling again.
|
||||
|
||||
### What the LeRobot Integration Covers
|
||||
|
||||
- Standard `policy.type=eo1` configuration through LeRobot
|
||||
- Qwen2.5-VL image and text preprocessing through policy processors
|
||||
- Continuous flow-matching action prediction
|
||||
- Checkpoint save/load through LeRobot policy APIs
|
||||
- Training with `lerobot-train` and evaluation with `lerobot-eval`
|
||||
|
||||
The broader EO-1 project also includes interleaved vision-text-action pretraining and multimodal reasoning workflows. This page focuses on the LeRobot robot-control policy path.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||
2. Install EO-1 dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[eo1]"
|
||||
```
|
||||
|
||||
3. If you want to train or evaluate on LIBERO, install the LIBERO dependencies too:
|
||||
|
||||
```bash
|
||||
pip install -e ".[eo1,libero]"
|
||||
```
|
||||
|
||||
EO-1 can use the standard PyTorch scaled-dot-product attention backend through `policy.attn_implementation=sdpa`. If your environment has a compatible `flash_attn` installation, you can request `policy.attn_implementation=flash_attention_2`.
|
||||
|
||||
## Data Requirements
|
||||
|
||||
EO-1 expects a LeRobot dataset with:
|
||||
|
||||
- At least one visual observation, for example `observation.images.image`
|
||||
- `observation.state`
|
||||
- `action`
|
||||
- A language task instruction through the dataset `task` field
|
||||
|
||||
If your dataset uses different observation names, use `rename_map` to align them with the names expected by your training or evaluation setup.
|
||||
|
||||
## Usage
|
||||
|
||||
To use EO-1 in a LeRobot configuration, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=eo1
|
||||
```
|
||||
|
||||
By default, a new EO-1 policy initializes its backbone from:
|
||||
|
||||
```python
|
||||
policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct
|
||||
```
|
||||
|
||||
Once a LeRobot-format EO-1 checkpoint is available, load it with:
|
||||
|
||||
```python
|
||||
policy.path=your-org/your-eo1-checkpoint
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Training Command Example
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_org/your_dataset \
|
||||
--policy.type=eo1 \
|
||||
--policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.attn_implementation=sdpa \
|
||||
--policy.gradient_checkpointing=false \
|
||||
--output_dir=./outputs/eo1_training \
|
||||
--job_name=eo1_training \
|
||||
--steps=300000 \
|
||||
--batch_size=16 \
|
||||
--policy.device=cuda
|
||||
```
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| -------------------------------------- | ----------------------------- | ----------------------------------------------------------------------- |
|
||||
| `policy.vlm_base` | `Qwen/Qwen2.5-VL-3B-Instruct` | Qwen2.5-VL checkpoint used to initialize a new policy |
|
||||
| `policy.dtype` | `auto` | Backbone dtype request: `auto`, `bfloat16`, or `float32` |
|
||||
| `policy.attn_implementation` | `None` | Optional Qwen attention backend, such as `sdpa` |
|
||||
| `policy.gradient_checkpointing` | `false` | Reduces memory usage during training |
|
||||
| `policy.chunk_size` | `8` | Number of future actions predicted per chunk |
|
||||
| `policy.n_action_steps` | `8` | Number of actions consumed from a sampled chunk |
|
||||
| `policy.num_denoise_steps` | `10` | Number of flow-matching denoising steps used during sampling |
|
||||
| `policy.max_state_dim` | `32` | State padding dimension |
|
||||
| `policy.max_action_dim` | `32` | Action padding dimension |
|
||||
| `policy.force_fp32_autocast` | `true` | Keeps the flow head in fp32 even when the backbone uses mixed precision |
|
||||
| `policy.supervise_padding_action_dims` | `true` | Controls whether padded action dimensions are supervised |
|
||||
| `policy.supervise_padding_actions` | `true` | Controls whether padded future action rows are supervised |
|
||||
|
||||
## Evaluation
|
||||
|
||||
EO-1 can be evaluated through `lerobot-eval` once you have a LeRobot-format checkpoint:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=your-org/your-eo1-checkpoint \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=20
|
||||
```
|
||||
|
||||
For datasets or environments whose camera names differ from the checkpoint configuration, pass a `rename_map`:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=your-org/your-eo1-checkpoint \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--rename_map='{"observation.images.image2":"observation.images.wrist_image"}'
|
||||
```
|
||||
|
||||
## Configuration Notes
|
||||
|
||||
### Image Processing
|
||||
|
||||
EO-1 uses the Qwen2.5-VL processor. The `policy.image_min_pixels` and `policy.image_max_pixels` settings control the image resizing bounds before the visual tokens are passed into the backbone.
|
||||
|
||||
### State and Action Dimensions
|
||||
|
||||
The policy pads state and action vectors to `policy.max_state_dim` and `policy.max_action_dim` before the EO-1 flow head. Predictions are cropped back to the original action dimension before being returned by the policy.
|
||||
|
||||
### Attention Backend
|
||||
|
||||
Use `policy.attn_implementation=sdpa` for a portable setup. Use `flash_attention_2` only when `flash_attn` is installed and compatible with your environment.
|
||||
|
||||
## References
|
||||
|
||||
- [EO-1 project](https://github.com/EO-Robotics/EO1)
|
||||
- [EO-1 paper](https://arxiv.org/abs/2508.21112)
|
||||
- [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{eo1,
|
||||
title={EO-1: Interleaved Vision-Text-Action Pretraining for General Robot Control},
|
||||
author={Delin Qu and Haoming Song and Qizhi Chen and Zhaoqing Chen and Xianqiang Gao and Xinyi Ye and Qi Lv and Modi Shi and Guanghui Ren and Cheng Ruan and Maoqing Yao and Haoran Yang and Jiacheng Bao and Bin Zhao and Dong Wang},
|
||||
journal={arXiv preprint},
|
||||
year={2025},
|
||||
url={https://arxiv.org/abs/2508.21112}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream EO-1 model and dataset pages for the licenses of released EO-1 checkpoints and data.
|
||||
@@ -1,132 +0,0 @@
|
||||
# EVO1
|
||||
|
||||
EVO1 is a Vision-Language-Action policy for robot control built around an InternVL3 backbone and a continuous flow-matching action head. This LeRobot integration exposes EVO1 as a standard policy type so it can be trained and evaluated with the usual LeRobot dataset, checkpoint, and processor APIs.
|
||||
|
||||
## Model Overview
|
||||
|
||||
The policy embeds one or more camera images and the language task prompt with InternVL3, pads robot state/action vectors to fixed maximum dimensions, and predicts future action chunks with a flow-matching action head. During inference, the policy samples an action chunk and returns `n_action_steps` actions from that chunk before sampling again.
|
||||
|
||||
### What the LeRobot Integration Covers
|
||||
|
||||
- Standard `policy.type=evo1` configuration through LeRobot
|
||||
- InternVL3 image/text embedding with optional FlashAttention fallback
|
||||
- Stage-based finetuning controls for action-head-only and VLM finetuning runs
|
||||
- Continuous flow-matching action prediction
|
||||
- Checkpoint save/load through LeRobot policy APIs
|
||||
- Training with `lerobot-train` and evaluation with standard policy inference APIs
|
||||
|
||||
The broader EVO1 project may include additional training scripts and dataset tooling. This page focuses on the LeRobot robot-control policy path.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||
2. Install EVO1 dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e ".[evo1]"
|
||||
```
|
||||
|
||||
3. Install a `flash-attn` wheel only if it is compatible with your Python, PyTorch, CUDA, and GPU stack. EVO1 falls back to standard attention when `flash_attn` is not available.
|
||||
|
||||
EVO1 uses InternVL3 through the Hugging Face `transformers` remote-code path, so the first run may download the configured VLM checkpoint unless `policy.vlm_model_name` points to a local model directory.
|
||||
|
||||
## Data Requirements
|
||||
|
||||
EVO1 expects a LeRobot dataset with:
|
||||
|
||||
- One to `policy.max_views` visual observations, for example `observation.images.image`
|
||||
- `observation.state`
|
||||
- `action`
|
||||
- A language task instruction in the dataset `task` field, or another field configured with `policy.task_field`
|
||||
|
||||
State and action vectors are padded to `policy.max_state_dim` and `policy.max_action_dim`. Predictions are cropped back to the dataset action dimension before being returned.
|
||||
|
||||
## Usage
|
||||
|
||||
To use EVO1 in a LeRobot configuration, specify:
|
||||
|
||||
```python
|
||||
policy.type=evo1
|
||||
```
|
||||
|
||||
By default, a new EVO1 policy initializes its VLM from:
|
||||
|
||||
```python
|
||||
policy.vlm_model_name=OpenGVLab/InternVL3-1B
|
||||
```
|
||||
|
||||
Once a LeRobot-format EVO1 checkpoint is available, load it with:
|
||||
|
||||
```python
|
||||
policy.path=your-org/your-evo1-checkpoint
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Stage 1
|
||||
|
||||
Stage 1 freezes the VLM and trains the action head:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_org/your_dataset \
|
||||
--policy.type=evo1 \
|
||||
--policy.training_stage=stage1 \
|
||||
--policy.vlm_model_name=OpenGVLab/InternVL3-1B \
|
||||
--policy.device=cuda \
|
||||
--policy.chunk_size=50 \
|
||||
--policy.n_action_steps=50 \
|
||||
--policy.max_state_dim=24 \
|
||||
--policy.max_action_dim=24 \
|
||||
--policy.optimizer_lr=1e-5 \
|
||||
--batch_size=4 \
|
||||
--steps=5000 \
|
||||
--output_dir=./outputs/evo1_stage1
|
||||
```
|
||||
|
||||
### Stage 2
|
||||
|
||||
Stage 2 finetunes the VLM branches and action head. A common workflow starts from a Stage 1 checkpoint:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_org/your_dataset \
|
||||
--policy.path=./outputs/evo1_stage1/checkpoints/005000/pretrained_model \
|
||||
--policy.training_stage=stage2 \
|
||||
--policy.vlm_model_name=OpenGVLab/InternVL3-1B \
|
||||
--policy.device=cuda \
|
||||
--policy.chunk_size=50 \
|
||||
--policy.n_action_steps=50 \
|
||||
--policy.max_state_dim=24 \
|
||||
--policy.max_action_dim=24 \
|
||||
--policy.optimizer_lr=1e-5 \
|
||||
--batch_size=4 \
|
||||
--steps=80000 \
|
||||
--output_dir=./outputs/evo1_stage2
|
||||
```
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| --------------------------------------------- | ------------------------ | ----------------------------------------------------------------- |
|
||||
| `policy.vlm_model_name` | `OpenGVLab/InternVL3-1B` | InternVL3 checkpoint or local model directory |
|
||||
| `policy.training_stage` | `stage1` | `stage1` trains the action head; `stage2` finetunes VLM branches |
|
||||
| `policy.vlm_num_layers` | `14` | Number of InternVL3 language layers kept for the policy |
|
||||
| `policy.vlm_dtype` | `bfloat16` | Requested VLM dtype |
|
||||
| `policy.use_flash_attn` | `true` | Requests FlashAttention when installed; otherwise falls back |
|
||||
| `policy.enable_gradient_checkpointing` | `true` | Enables checkpointing on supported InternVL3 modules |
|
||||
| `policy.gradient_checkpointing_use_reentrant` | `false` | Reentrant setting passed to gradient checkpointing when supported |
|
||||
| `policy.chunk_size` | `50` | Number of future actions predicted per chunk |
|
||||
| `policy.n_action_steps` | `50` | Number of actions consumed from a sampled chunk |
|
||||
| `policy.max_state_dim` | `24` | State padding dimension |
|
||||
| `policy.max_action_dim` | `24` | Action padding dimension |
|
||||
| `policy.task_field` | `task` | Batch field used as the language prompt |
|
||||
|
||||
## References
|
||||
|
||||
- [EVO1 repository](https://github.com/MINT-SJTU/Evo-1)
|
||||
- [InternVL3-1B](https://huggingface.co/OpenGVLab/InternVL3-1B)
|
||||
|
||||
## License
|
||||
|
||||
This LeRobot integration follows the Apache 2.0 License used by LeRobot. Check the upstream EVO1 and InternVL3 model pages for the licenses of released checkpoints and data.
|
||||
@@ -46,7 +46,7 @@ This ensures identical task states map to consistent progress values, even acros
|
||||
|
||||
## Inputs and Targets (What the new code expects)
|
||||
|
||||
SARM is trained through its processor (`src/lerobot/rewards/sarm/processor_sarm.py`), which:
|
||||
SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which:
|
||||
|
||||
- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
|
||||
- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
|
||||
@@ -347,7 +347,7 @@ Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predict
|
||||
<hfoption id="single_stage">
|
||||
|
||||
```bash
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
@@ -360,7 +360,7 @@ python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
<hfoption id="dense_only">
|
||||
|
||||
```bash
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
@@ -373,7 +373,7 @@ python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
<hfoption id="dual">
|
||||
|
||||
```bash
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
@@ -429,7 +429,7 @@ The weighting follows **Equations 8-9** from the paper:
|
||||
First, run the SARM model on all frames in your dataset to compute progress values:
|
||||
|
||||
```bash
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--head-mode sparse \
|
||||
@@ -465,15 +465,15 @@ This script:
|
||||
|
||||
### Step 5b: Train Policy with RA-BC
|
||||
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`) if not explicitly provided. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--sample_weighting.type=rabc \
|
||||
--sample_weighting.head_mode=sparse \
|
||||
--sample_weighting.kappa=0.01 \
|
||||
--use_rabc=true \
|
||||
--rabc_head_mode=sparse \
|
||||
--rabc_kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
--steps=40000
|
||||
@@ -488,13 +488,12 @@ The training script automatically:
|
||||
|
||||
**RA-BC Arguments:**
|
||||
|
||||
| Argument | Description | Default |
|
||||
| ---------------------------------- | ------------------------------------------------------ | ----------------------- |
|
||||
| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
|
||||
| `--sample_weighting.progress_path` | Path to progress parquet file | `sarm_progress.parquet` |
|
||||
| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||
| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||
| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
|
||||
| Argument | Description | Default |
|
||||
| ---------------------- | ---------------------------------------------------------- | ---------------------------------- |
|
||||
| `--use_rabc` | Enable RA-BC sample weighting | `false` |
|
||||
| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
|
||||
| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||
| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||
|
||||
### Tuning RA-BC Kappa
|
||||
|
||||
@@ -512,30 +511,30 @@ The `kappa` parameter is the threshold that determines which samples get full we
|
||||
|
||||
Monitor these WandB metrics during training:
|
||||
|
||||
| Metric | Healthy Range | Problem Indicator |
|
||||
| ----------------------------- | ------------- | ------------------------- |
|
||||
| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||
| `sample_weighting/delta_mean` | > 0 | Should be positive |
|
||||
| `sample_weighting/delta_std` | > 0 | Variance in data quality |
|
||||
| Metric | Healthy Range | Problem Indicator |
|
||||
| ------------------ | ------------- | ------------------------- |
|
||||
| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||
| `rabc_delta_mean` | > 0 | Should be positive |
|
||||
| `rabc_delta_std` | > 0 | Variance in data quality |
|
||||
|
||||
**If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
||||
**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
||||
|
||||
**Setting kappa based on your data:**
|
||||
|
||||
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `sample_weighting/delta_mean` and `sample_weighting/delta_std`:
|
||||
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:
|
||||
|
||||
```
|
||||
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
|
||||
# Most deltas fall in range [0.01, 0.05]
|
||||
|
||||
# Option 1: Set kappa = delta_mean (medium selectivity)
|
||||
--sample_weighting.kappa=0.03
|
||||
--rabc_kappa=0.03
|
||||
|
||||
# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
|
||||
--sample_weighting.kappa=0.05
|
||||
--rabc_kappa=0.05
|
||||
|
||||
# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
|
||||
--sample_weighting.kappa=0.07
|
||||
--rabc_kappa=0.07
|
||||
```
|
||||
|
||||
**When RA-BC may not help:**
|
||||
@@ -551,8 +550,8 @@ accelerate launch \
|
||||
src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--sample_weighting.type=rabc \
|
||||
--sample_weighting.kappa=0.01 \
|
||||
--use_rabc=true \
|
||||
--rabc_kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
--steps=40000
|
||||
@@ -577,7 +576,7 @@ accelerate launch \
|
||||
### RA-BC
|
||||
|
||||
1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
|
||||
2. **Monitor `sample_weight_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
||||
2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -220,7 +220,7 @@ REAL_DIM = 12
|
||||
# Postprocessing: Trim 20D predictions to 12D for deployment
|
||||
```
|
||||
|
||||
See the [action_hub.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py) implementation for details.
|
||||
See the [action_hub.py](/home/jade_choghari/robot/lerobot/src/lerobot/policies/xvla/action_hub.py) implementation for details.
|
||||
|
||||
#### Auto Action Mode (Recommended)
|
||||
|
||||
@@ -519,9 +519,9 @@ If you use X-VLA in your research, please cite:
|
||||
|
||||
- [X-VLA Paper](https://arxiv.org/pdf/2510.10274)
|
||||
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
||||
- [Action Registry Implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py)
|
||||
- [Processor Implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/processor_xvla.py)
|
||||
- [Model Configuration](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/configuration_xvla.py)
|
||||
- [Action Registry Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/action_hub.py)
|
||||
- [Processor Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/processor_xvla.py)
|
||||
- [Model Configuration](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/configuration_xvla.py)
|
||||
|
||||
## Contributing
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ class ComputeProgressShards(PipelineStep):
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.rewards.sarm.compute_rabc_weights import (
|
||||
from lerobot.policies.sarm.compute_rabc_weights import (
|
||||
generate_all_frame_indices,
|
||||
interpolate_progress,
|
||||
load_sarm_resources,
|
||||
|
||||
@@ -10,7 +10,7 @@ from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.policies import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.rewards import RewardClassifierConfig, make_reward_model, make_reward_pre_post_processors
|
||||
from lerobot.policies import RewardClassifierConfig, make_policy, make_pre_post_processors
|
||||
|
||||
|
||||
def main():
|
||||
@@ -22,10 +22,10 @@ def main():
|
||||
model_name="microsoft/resnet-18",
|
||||
)
|
||||
|
||||
# Make reward model, preprocessor, and optimizer
|
||||
reward_model = make_reward_model(config, dataset_stats=dataset.meta.stats)
|
||||
optimizer = config.get_optimizer_preset().build(reward_model.parameters())
|
||||
preprocessor, _ = make_reward_pre_post_processors(config, dataset_stats=dataset.meta.stats)
|
||||
# Make policy, preprocessor, and optimizer
|
||||
policy = make_policy(config, ds_meta=dataset.meta)
|
||||
optimizer = config.get_optimizer_preset().build(policy.parameters())
|
||||
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
|
||||
|
||||
classifier_id = "<user>/reward_classifier_hil_serl_example"
|
||||
|
||||
@@ -42,7 +42,7 @@ def main():
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Forward pass
|
||||
loss, output_dict = reward_model.forward(batch)
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
@@ -58,8 +58,8 @@ def main():
|
||||
|
||||
print("Training finished!")
|
||||
|
||||
# You can now save the trained reward model.
|
||||
reward_model.push_to_hub(classifier_id)
|
||||
# You can now save the trained policy.
|
||||
policy.push_to_hub(classifier_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -59,8 +59,8 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
|
||||
|
||||
dependencies = [
|
||||
# Core ML
|
||||
"torch>=2.7,<2.13.0",
|
||||
"torchvision>=0.22.0,<0.28.0",
|
||||
"torch>=2.7,<2.11.0",
|
||||
"torchvision>=0.22.0,<0.26.0",
|
||||
"numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
|
||||
"opencv-python-headless>=4.9.0,<4.14.0",
|
||||
"Pillow>=10.0.0,<13.0.0",
|
||||
@@ -99,7 +99,7 @@ dataset = [
|
||||
"pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
|
||||
"pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
|
||||
"lerobot[av-dep]",
|
||||
"torchcodec>=0.3.0,<0.13.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10), 0.11 needs torch==2.11, 0.12 needs torch==2.12.
|
||||
"torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10).
|
||||
"jsonlines>=4.0.0,<5.0.0",
|
||||
]
|
||||
training = [
|
||||
@@ -128,7 +128,7 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
|
||||
av-dep = ["av>=15.0.0,<16.0.0"]
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
transformers-dep = ["transformers==5.3.0"] # TODO(Steven): https://github.com/huggingface/lerobot/pull/3249
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
@@ -194,8 +194,6 @@ groot = [
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
evo1 = ["lerobot[transformers-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
@@ -259,7 +257,6 @@ all = [
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
"lerobot[evo1]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
@@ -336,7 +333,6 @@ ignore = [
|
||||
# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
|
||||
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
|
||||
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
|
||||
"src/lerobot/policies/evo1/**" = ["N801", "N812"]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
combine-as-imports = true
|
||||
|
||||
@@ -17,7 +17,6 @@ Provides the RealSenseCamera class for capturing frames from Intel RealSense cam
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@@ -42,7 +41,6 @@ from ..utils import get_cv2_rotation
|
||||
from .configuration_realsense import RealSenseCameraConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
pkg_name = "pyrealsense2-macosx" if sys.platform == "darwin" else "pyrealsense2"
|
||||
|
||||
|
||||
class RealSenseCamera(Camera):
|
||||
@@ -116,7 +114,7 @@ class RealSenseCamera(Camera):
|
||||
Args:
|
||||
config: The configuration settings for the camera.
|
||||
"""
|
||||
require_package(pkg_name, extra="intelrealsense", import_name="pyrealsense2")
|
||||
require_package("pyrealsense2", extra="intelrealsense")
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
|
||||
@@ -41,12 +41,8 @@ def cfg_to_group(
|
||||
return tag
|
||||
return tag[:max_tag_length]
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
trainable_tag = f"reward_model:{cfg.reward_model.type}"
|
||||
else:
|
||||
trainable_tag = f"policy:{cfg.policy.type}"
|
||||
lst = [
|
||||
trainable_tag,
|
||||
f"policy:{cfg.policy.type}",
|
||||
f"seed:{cfg.seed}",
|
||||
]
|
||||
if cfg.dataset is not None:
|
||||
|
||||
@@ -1,163 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import builtins
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.optim.optimizers import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
T = TypeVar("T", bound="RewardModelConfig")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
"""Base configuration for reward models.
|
||||
|
||||
Args:
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the reward. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the reward. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
"""
|
||||
|
||||
# Reuses PolicyFeature
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
device: str | None = None
|
||||
|
||||
pretrained_path: str | None = None
|
||||
|
||||
push_to_hub: bool = False
|
||||
repo_id: str | None = None
|
||||
|
||||
# Hub metadata
|
||||
license: str | None = None
|
||||
tags: list[str] | None = None
|
||||
private: bool | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
self.device = auto_device.type
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
choice_name = self.get_choice_name(self.__class__)
|
||||
if not isinstance(choice_name, str):
|
||||
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
|
||||
return choice_name
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
pass
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
draccus.dump(self, f, indent=4)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict[Any, Any] | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**reward_kwargs: Any,
|
||||
) -> T:
|
||||
model_id = str(pretrained_name_or_path)
|
||||
config_file: str | None = None
|
||||
if Path(model_id).is_dir():
|
||||
if CONFIG_NAME in os.listdir(model_id):
|
||||
config_file = os.path.join(model_id, CONFIG_NAME)
|
||||
else:
|
||||
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||
else:
|
||||
try:
|
||||
config_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=CONFIG_NAME,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
if config_file is None:
|
||||
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
|
||||
|
||||
# HACK: Parse the original config to get the config subclass, so that we can
|
||||
# apply cli overrides.
|
||||
with draccus.config_type("json"):
|
||||
orig_config = draccus.parse(cls, config_file, args=[])
|
||||
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
|
||||
config.pop("type", None)
|
||||
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
||||
json.dump(config, f)
|
||||
config_file = f.name
|
||||
|
||||
cli_overrides = reward_kwargs.pop("cli_overrides", [])
|
||||
with draccus.config_type("json"):
|
||||
return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)
|
||||
@@ -13,9 +13,7 @@
|
||||
# limitations under the License.
|
||||
import builtins
|
||||
import datetime as dt
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -28,57 +26,18 @@ from lerobot import envs
|
||||
from lerobot.configs import parser
|
||||
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .rewards import RewardModelConfig
|
||||
|
||||
TRAIN_CONFIG_NAME = "train_config.json"
|
||||
|
||||
|
||||
def _migrate_legacy_rabc_fields(config: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Return migrated payload for legacy RA-BC fields, or None when no migration is needed."""
|
||||
legacy_fields = (
|
||||
"use_rabc",
|
||||
"rabc_progress_path",
|
||||
"rabc_kappa",
|
||||
"rabc_epsilon",
|
||||
"rabc_head_mode",
|
||||
)
|
||||
if not any(key in config for key in legacy_fields):
|
||||
return None
|
||||
|
||||
migrated_config = dict(config)
|
||||
use_rabc = bool(migrated_config.pop("use_rabc", False))
|
||||
rabc_progress_path = migrated_config.pop("rabc_progress_path", None)
|
||||
rabc_kappa = migrated_config.pop("rabc_kappa", None)
|
||||
rabc_epsilon = migrated_config.pop("rabc_epsilon", None)
|
||||
rabc_head_mode = migrated_config.pop("rabc_head_mode", None)
|
||||
|
||||
# New configs may already define sample_weighting explicitly. In that case,
|
||||
# legacy fields are ignored after being stripped from the payload.
|
||||
if migrated_config.get("sample_weighting") is None and use_rabc:
|
||||
sample_weighting: dict[str, Any] = {"type": "rabc"}
|
||||
if rabc_progress_path is not None:
|
||||
sample_weighting["progress_path"] = rabc_progress_path
|
||||
if rabc_kappa is not None:
|
||||
sample_weighting["kappa"] = rabc_kappa
|
||||
if rabc_epsilon is not None:
|
||||
sample_weighting["epsilon"] = rabc_epsilon
|
||||
if rabc_head_mode is not None:
|
||||
sample_weighting["head_mode"] = rabc_head_mode
|
||||
migrated_config["sample_weighting"] = sample_weighting
|
||||
|
||||
return migrated_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainPipelineConfig(HubMixin):
|
||||
dataset: DatasetConfig
|
||||
env: envs.EnvConfig | None = None
|
||||
policy: PreTrainedConfig | None = None
|
||||
reward_model: RewardModelConfig | None = None
|
||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
||||
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
output_dir: Path | None = None
|
||||
@@ -113,41 +72,27 @@ class TrainPipelineConfig(HubMixin):
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
peft: PeftConfig | None = None
|
||||
|
||||
# Sample weighting configuration (e.g., for RA-BC training)
|
||||
sample_weighting: SampleWeightingConfig | None = None
|
||||
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
||||
use_rabc: bool = False # Enable reward-weighted training
|
||||
rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
|
||||
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
|
||||
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
|
||||
rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense"
|
||||
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
checkpoint_path: Path | None = field(init=False, default=None)
|
||||
|
||||
@property
|
||||
def is_reward_model_training(self) -> bool:
|
||||
"""True when the config targets a reward model rather than a policy."""
|
||||
return self.reward_model is not None
|
||||
|
||||
@property
|
||||
def trainable_config(self) -> PreTrainedConfig | RewardModelConfig:
|
||||
"""Return whichever config (policy or reward_model) is active."""
|
||||
if self.is_reward_model_training:
|
||||
return self.reward_model # type: ignore[return-value]
|
||||
return self.policy # type: ignore[return-value]
|
||||
|
||||
def validate(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
reward_model_path = parser.get_path_arg("reward_model")
|
||||
|
||||
if reward_model_path:
|
||||
cli_overrides = parser.get_cli_overrides("reward_model")
|
||||
self.reward_model = RewardModelConfig.from_pretrained(
|
||||
reward_model_path, cli_overrides=cli_overrides
|
||||
)
|
||||
self.reward_model.pretrained_path = str(Path(reward_model_path))
|
||||
elif policy_path:
|
||||
if policy_path:
|
||||
# Only load the policy config
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
elif self.resume:
|
||||
# The entire train config is already loaded, we just need to get the checkpoint dir
|
||||
config_path = parser.parse_arg("config_path")
|
||||
if not config_path:
|
||||
raise ValueError(
|
||||
@@ -163,22 +108,18 @@ class TrainPipelineConfig(HubMixin):
|
||||
policy_dir = Path(config_path).parent
|
||||
if self.policy is not None:
|
||||
self.policy.pretrained_path = policy_dir
|
||||
if self.reward_model is not None:
|
||||
self.reward_model.pretrained_path = str(policy_dir)
|
||||
self.checkpoint_path = policy_dir.parent
|
||||
|
||||
if self.policy is None and self.reward_model is None:
|
||||
if self.policy is None:
|
||||
raise ValueError(
|
||||
"Neither policy nor reward_model is configured. "
|
||||
"Please specify one with `--policy.path` or `--reward_model.path`."
|
||||
"Policy is not configured. Please specify a pretrained policy with `--policy.path`."
|
||||
)
|
||||
|
||||
active_cfg = self.trainable_config
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
self.job_name = f"{active_cfg.type}"
|
||||
self.job_name = f"{self.policy.type}"
|
||||
else:
|
||||
self.job_name = f"{self.env.type}_{active_cfg.type}"
|
||||
self.job_name = f"{self.env.type}_{self.policy.type}"
|
||||
|
||||
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
|
||||
raise FileExistsError(
|
||||
@@ -196,16 +137,26 @@ class TrainPipelineConfig(HubMixin):
|
||||
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
||||
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
||||
elif self.use_policy_training_preset and not self.resume:
|
||||
self.optimizer = active_cfg.get_optimizer_preset()
|
||||
self.scheduler = active_cfg.get_scheduler_preset()
|
||||
self.optimizer = self.policy.get_optimizer_preset()
|
||||
self.scheduler = self.policy.get_scheduler_preset()
|
||||
|
||||
if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
|
||||
raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
|
||||
if self.policy.push_to_hub and not self.policy.repo_id:
|
||||
raise ValueError(
|
||||
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
|
||||
)
|
||||
|
||||
if self.use_rabc and not self.rabc_progress_path:
|
||||
# Auto-detect from dataset path
|
||||
repo_id = self.dataset.repo_id
|
||||
if self.dataset.root:
|
||||
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
|
||||
else:
|
||||
self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""Keys for draccus pretrained-path loading."""
|
||||
return ["policy", "reward_model"]
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
|
||||
@@ -256,17 +207,6 @@ class TrainPipelineConfig(HubMixin):
|
||||
) from e
|
||||
|
||||
cli_args = kwargs.pop("cli_args", [])
|
||||
# Legacy RA-BC migration only applies to framework-saved checkpoints (always JSON).
|
||||
# Hand-written YAML/TOML configs are expected to use the current sample_weighting schema.
|
||||
if config_file is not None and config_file.endswith(".json"):
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
migrated_config = _migrate_legacy_rabc_fields(config)
|
||||
if migrated_config is not None:
|
||||
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
||||
json.dump(migrated_config, f)
|
||||
config_file = f.name
|
||||
|
||||
with draccus.config_type("json"):
|
||||
return draccus.parse(cls, config_file, args=cli_args)
|
||||
|
||||
|
||||
@@ -97,8 +97,8 @@ def update_data_df(df, src_meta, dst_meta):
|
||||
pd.DataFrame: Updated DataFrame with adjusted indices.
|
||||
"""
|
||||
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info.total_episodes
|
||||
df["index"] = df["index"] + dst_meta.info.total_frames
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
||||
df["index"] = df["index"] + dst_meta.info["total_frames"]
|
||||
|
||||
src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
|
||||
df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
|
||||
@@ -225,9 +225,9 @@ def update_meta_data(
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||
|
||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info.total_frames
|
||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info.total_frames
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info.total_episodes
|
||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
|
||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
||||
|
||||
return df
|
||||
|
||||
@@ -237,8 +237,8 @@ def aggregate_datasets(
|
||||
aggr_repo_id: str,
|
||||
roots: list[Path] | None = None,
|
||||
aggr_root: Path | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
data_files_size_in_mb: float | None = None,
|
||||
video_files_size_in_mb: float | None = None,
|
||||
chunk_size: int | None = None,
|
||||
):
|
||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||
@@ -313,8 +313,8 @@ def aggregate_datasets(
|
||||
# to avoid interference between different source datasets
|
||||
data_idx.pop("src_to_dst", None)
|
||||
|
||||
dst_meta.info.total_episodes += src_meta.total_episodes
|
||||
dst_meta.info.total_frames += src_meta.total_frames
|
||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||
|
||||
finalize_aggregation(dst_meta, all_metadata)
|
||||
logging.info("Aggregation complete.")
|
||||
@@ -640,10 +640,14 @@ def finalize_aggregation(aggr_meta, all_metadata):
|
||||
write_tasks(aggr_meta.tasks, aggr_meta.root)
|
||||
|
||||
logging.info("write info")
|
||||
aggr_meta.info.total_tasks = len(aggr_meta.tasks)
|
||||
aggr_meta.info.total_episodes = sum(m.total_episodes for m in all_metadata)
|
||||
aggr_meta.info.total_frames = sum(m.total_frames for m in all_metadata)
|
||||
aggr_meta.info.splits = {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"}
|
||||
aggr_meta.info.update(
|
||||
{
|
||||
"total_tasks": len(aggr_meta.tasks),
|
||||
"total_episodes": sum(m.total_episodes for m in all_metadata),
|
||||
"total_frames": sum(m.total_frames for m in all_metadata),
|
||||
"splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"},
|
||||
}
|
||||
)
|
||||
write_info(aggr_meta.info, aggr_meta.root)
|
||||
|
||||
logging.info("write stats")
|
||||
|
||||
@@ -37,11 +37,13 @@ from .io_utils import (
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
write_info,
|
||||
write_json,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
INFO_PATH,
|
||||
check_version_compatibility,
|
||||
get_safe_version,
|
||||
has_legacy_hub_download_metadata,
|
||||
@@ -226,7 +228,7 @@ class LeRobotDatasetMetadata:
|
||||
@property
|
||||
def _version(self) -> packaging.version.Version:
|
||||
"""Codebase version used to create this dataset."""
|
||||
return packaging.version.parse(self.info.codebase_version)
|
||||
return packaging.version.parse(self.info["codebase_version"])
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
"""Return the relative parquet file path for the given episode index.
|
||||
@@ -281,27 +283,27 @@ class LeRobotDatasetMetadata:
|
||||
@property
|
||||
def data_path(self) -> str:
|
||||
"""Formattable string for the parquet files."""
|
||||
return self.info.data_path
|
||||
return self.info["data_path"]
|
||||
|
||||
@property
|
||||
def video_path(self) -> str | None:
|
||||
"""Formattable string for the video files."""
|
||||
return self.info.video_path
|
||||
return self.info["video_path"]
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str | None:
|
||||
"""Robot type used in recording this dataset."""
|
||||
return self.info.robot_type
|
||||
return self.info["robot_type"]
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection."""
|
||||
return self.info.fps
|
||||
return self.info["fps"]
|
||||
|
||||
@property
|
||||
def features(self) -> dict[str, dict]:
|
||||
"""All features contained in the dataset."""
|
||||
return self.info.features
|
||||
return self.info["features"]
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list[str]:
|
||||
@@ -331,32 +333,32 @@ class LeRobotDatasetMetadata:
|
||||
@property
|
||||
def total_episodes(self) -> int:
|
||||
"""Total number of episodes available."""
|
||||
return self.info.total_episodes
|
||||
return self.info["total_episodes"]
|
||||
|
||||
@property
|
||||
def total_frames(self) -> int:
|
||||
"""Total number of frames saved in this dataset."""
|
||||
return self.info.total_frames
|
||||
return self.info["total_frames"]
|
||||
|
||||
@property
|
||||
def total_tasks(self) -> int:
|
||||
"""Total number of different tasks performed in this dataset."""
|
||||
return self.info.total_tasks
|
||||
return self.info["total_tasks"]
|
||||
|
||||
@property
|
||||
def chunks_size(self) -> int:
|
||||
"""Max number of files per chunk."""
|
||||
return self.info.chunks_size
|
||||
return self.info["chunks_size"]
|
||||
|
||||
@property
|
||||
def data_files_size_in_mb(self) -> int:
|
||||
"""Max size of data file in mega bytes."""
|
||||
return self.info.data_files_size_in_mb
|
||||
return self.info["data_files_size_in_mb"]
|
||||
|
||||
@property
|
||||
def video_files_size_in_mb(self) -> int:
|
||||
"""Max size of video file in mega bytes."""
|
||||
return self.info.video_files_size_in_mb
|
||||
return self.info["video_files_size_in_mb"]
|
||||
|
||||
def get_task_index(self, task: str) -> int | None:
|
||||
"""
|
||||
@@ -500,10 +502,10 @@ class LeRobotDatasetMetadata:
|
||||
self._save_episode_metadata(episode_dict)
|
||||
|
||||
# Update info
|
||||
self.info.total_episodes += 1
|
||||
self.info.total_frames += episode_length
|
||||
self.info.total_tasks = len(self.tasks)
|
||||
self.info.splits = {"train": f"0:{self.info.total_episodes}"}
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
self.info["total_tasks"] = len(self.tasks)
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
|
||||
write_info(self.info, self.root)
|
||||
|
||||
@@ -522,7 +524,7 @@ class LeRobotDatasetMetadata:
|
||||
for key in video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info.features[key]["info"] = get_video_info(video_path)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
@@ -544,17 +546,17 @@ class LeRobotDatasetMetadata:
|
||||
if chunks_size is not None:
|
||||
if chunks_size <= 0:
|
||||
raise ValueError(f"chunks_size must be positive, got {chunks_size}")
|
||||
self.info.chunks_size = chunks_size
|
||||
self.info["chunks_size"] = chunks_size
|
||||
|
||||
if data_files_size_in_mb is not None:
|
||||
if data_files_size_in_mb <= 0:
|
||||
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
|
||||
self.info.data_files_size_in_mb = data_files_size_in_mb
|
||||
self.info["data_files_size_in_mb"] = data_files_size_in_mb
|
||||
|
||||
if video_files_size_in_mb is not None:
|
||||
if video_files_size_in_mb <= 0:
|
||||
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
|
||||
self.info.video_files_size_in_mb = video_files_size_in_mb
|
||||
self.info["video_files_size_in_mb"] = video_files_size_in_mb
|
||||
|
||||
# Update the info file on disk
|
||||
write_info(self.info, self.root)
|
||||
@@ -651,7 +653,7 @@ class LeRobotDatasetMetadata:
|
||||
f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. "
|
||||
"Either remove video features from the features dict, or set 'use_videos=True'."
|
||||
)
|
||||
write_info(obj.info, obj.root)
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
obj.revision = None
|
||||
obj._pq_writer = None
|
||||
obj.latest_episode = None
|
||||
|
||||
@@ -897,10 +897,14 @@ def _copy_and_reindex_episodes_metadata(
|
||||
|
||||
dst_meta.finalize()
|
||||
|
||||
dst_meta.info.total_episodes = len(episode_mapping)
|
||||
dst_meta.info.total_frames = total_frames
|
||||
dst_meta.info.total_tasks = len(dst_meta.tasks) if dst_meta.tasks is not None else 0
|
||||
dst_meta.info.splits = {"train": f"0:{len(episode_mapping)}"}
|
||||
dst_meta.info.update(
|
||||
{
|
||||
"total_episodes": len(episode_mapping),
|
||||
"total_frames": total_frames,
|
||||
"total_tasks": len(dst_meta.tasks) if dst_meta.tasks is not None else 0,
|
||||
"splits": {"train": f"0:{len(episode_mapping)}"},
|
||||
}
|
||||
)
|
||||
write_info(dst_meta.info, dst_meta.root)
|
||||
|
||||
if not all_stats:
|
||||
@@ -1065,20 +1069,21 @@ def _copy_episodes_metadata_and_stats(
|
||||
if episodes_dir.exists():
|
||||
shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True)
|
||||
|
||||
dst_meta.info.total_episodes = src_dataset.meta.total_episodes
|
||||
dst_meta.info.total_frames = src_dataset.meta.total_frames
|
||||
dst_meta.info.total_tasks = src_dataset.meta.total_tasks
|
||||
# Preserve original splits if available, otherwise create default
|
||||
dst_meta.info.splits = (
|
||||
src_dataset.meta.info.splits
|
||||
if src_dataset.meta.info.splits
|
||||
else {"train": f"0:{src_dataset.meta.total_episodes}"}
|
||||
dst_meta.info.update(
|
||||
{
|
||||
"total_episodes": src_dataset.meta.total_episodes,
|
||||
"total_frames": src_dataset.meta.total_frames,
|
||||
"total_tasks": src_dataset.meta.total_tasks,
|
||||
"splits": src_dataset.meta.info.get("splits", {"train": f"0:{src_dataset.meta.total_episodes}"}),
|
||||
}
|
||||
)
|
||||
|
||||
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
||||
for key in dst_meta.video_keys:
|
||||
if key in src_dataset.meta.features:
|
||||
dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {})
|
||||
dst_meta.info["features"][key]["info"] = src_dataset.meta.info["features"][key].get(
|
||||
"info", {}
|
||||
)
|
||||
|
||||
write_info(dst_meta.info, dst_meta.root)
|
||||
|
||||
@@ -1520,7 +1525,7 @@ def modify_tasks(
|
||||
write_tasks(new_task_df, root)
|
||||
|
||||
# Update info.json
|
||||
dataset.meta.info.total_tasks = len(unique_tasks)
|
||||
dataset.meta.info["total_tasks"] = len(unique_tasks)
|
||||
write_info(dataset.meta.info, root)
|
||||
|
||||
# Reload metadata to reflect changes
|
||||
@@ -1853,10 +1858,10 @@ def convert_image_to_video_dataset(
|
||||
episodes_df.to_parquet(episodes_path, index=False)
|
||||
|
||||
# Update metadata info
|
||||
new_meta.info.total_episodes = len(episode_indices)
|
||||
new_meta.info.total_frames = sum(ep["length"] for ep in all_episode_metadata.values())
|
||||
new_meta.info.total_tasks = dataset.meta.total_tasks
|
||||
new_meta.info.splits = {"train": f"0:{len(episode_indices)}"}
|
||||
new_meta.info["total_episodes"] = len(episode_indices)
|
||||
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata.values())
|
||||
new_meta.info["total_tasks"] = dataset.meta.total_tasks
|
||||
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
|
||||
|
||||
# Update video info for all image keys (now videos)
|
||||
# We need to manually set video info since update_video_info() checks video_keys first
|
||||
@@ -1865,7 +1870,7 @@ def convert_image_to_video_dataset(
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(video_path)
|
||||
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from pprint import pformat
|
||||
import torch
|
||||
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.transforms import ImageTransforms
|
||||
from lerobot.utils.constants import ACTION, IMAGENET_STATS, OBS_PREFIX, REWARD
|
||||
@@ -31,14 +30,12 @@ from .streaming_dataset import StreamingLeRobotDataset
|
||||
|
||||
|
||||
def resolve_delta_timestamps(
|
||||
cfg: PreTrainedConfig | RewardModelConfig, ds_meta: LeRobotDatasetMetadata
|
||||
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
|
||||
) -> dict[str, list] | None:
|
||||
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the config.
|
||||
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.
|
||||
|
||||
Args:
|
||||
cfg (PreTrainedConfig | RewardModelConfig): The config to read delta_indices from. Both
|
||||
``PreTrainedConfig`` and concrete ``RewardModelConfig`` subclasses expose the
|
||||
``{observation,action,reward}_delta_indices`` properties used below.
|
||||
cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
|
||||
ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
|
||||
delta_timestamps against.
|
||||
|
||||
@@ -85,7 +82,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
ds_meta = LeRobotDatasetMetadata(
|
||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||
)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, ds_meta)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
if not cfg.dataset.streaming:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
|
||||
@@ -28,7 +28,6 @@ from .utils import (
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
DatasetInfo,
|
||||
)
|
||||
|
||||
|
||||
@@ -79,8 +78,8 @@ def create_empty_dataset_info(
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
) -> DatasetInfo:
|
||||
"""Create a template ``DatasetInfo`` object for a new dataset's ``meta/info.json``.
|
||||
) -> dict:
|
||||
"""Create a template dictionary for a new dataset's `info.json`.
|
||||
|
||||
Args:
|
||||
codebase_version (str): The version of the LeRobot codebase.
|
||||
@@ -88,24 +87,25 @@ def create_empty_dataset_info(
|
||||
features (dict): The LeRobot features dictionary for the dataset.
|
||||
use_videos (bool): Whether the dataset will store videos.
|
||||
robot_type (str | None): The type of robot used, if any.
|
||||
chunks_size (int | None): Max files per chunk directory. Defaults to ``DEFAULT_CHUNK_SIZE``.
|
||||
data_files_size_in_mb (int | None): Max parquet file size in MB. Defaults to ``DEFAULT_DATA_FILE_SIZE_IN_MB``.
|
||||
video_files_size_in_mb (int | None): Max video file size in MB. Defaults to ``DEFAULT_VIDEO_FILE_SIZE_IN_MB``.
|
||||
|
||||
Returns:
|
||||
DatasetInfo: A typed dataset information object with initial metadata.
|
||||
dict: A dictionary with the initial dataset metadata.
|
||||
"""
|
||||
return DatasetInfo(
|
||||
codebase_version=codebase_version,
|
||||
fps=fps,
|
||||
features=features,
|
||||
robot_type=robot_type,
|
||||
chunks_size=chunks_size or DEFAULT_CHUNK_SIZE,
|
||||
data_files_size_in_mb=data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
video_files_size_in_mb=video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
data_path=DEFAULT_DATA_PATH,
|
||||
video_path=DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
)
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": 0,
|
||||
"total_frames": 0,
|
||||
"total_tasks": 0,
|
||||
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
|
||||
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": DEFAULT_DATA_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
|
||||
def check_delta_timestamps(
|
||||
|
||||
@@ -39,7 +39,6 @@ from .utils import (
|
||||
EPISODES_DIR,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
DatasetInfo,
|
||||
serialize_dict,
|
||||
)
|
||||
|
||||
@@ -116,21 +115,25 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
return dataset
|
||||
|
||||
|
||||
def write_info(info: DatasetInfo, local_dir: Path) -> None:
|
||||
write_json(info.to_dict(), local_dir / INFO_PATH)
|
||||
def write_info(info: dict, local_dir: Path) -> None:
|
||||
write_json(info, local_dir / INFO_PATH)
|
||||
|
||||
|
||||
def load_info(local_dir: Path) -> DatasetInfo:
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
"""Load dataset info metadata from its standard file path.
|
||||
|
||||
Also converts shape lists to tuples for consistency.
|
||||
|
||||
Args:
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
|
||||
Returns:
|
||||
DatasetInfo: The typed dataset information object.
|
||||
dict: The dataset information dictionary.
|
||||
"""
|
||||
raw = load_json(local_dir / INFO_PATH)
|
||||
return DatasetInfo.from_dict(raw)
|
||||
info = load_json(local_dir / INFO_PATH)
|
||||
for ft in info["features"].values():
|
||||
ft["shape"] = tuple(ft["shape"])
|
||||
return info
|
||||
|
||||
|
||||
def write_stats(stats: dict, local_dir: Path) -> None:
|
||||
|
||||
@@ -630,8 +630,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a new LeRobotDataset from scratch for recording data.
|
||||
|
||||
@@ -679,8 +677,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
root=root,
|
||||
use_videos=use_videos,
|
||||
metadata_buffer_size=metadata_buffer_size,
|
||||
video_files_size_in_mb=video_files_size_in_mb,
|
||||
data_files_size_in_mb=data_files_size_in_mb,
|
||||
)
|
||||
obj.repo_id = obj.meta.repo_id
|
||||
obj._requested_root = obj.meta.root
|
||||
|
||||
@@ -123,7 +123,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return self._datasets[0].meta.info.fps
|
||||
return self._datasets[0].meta.info["fps"]
|
||||
|
||||
@property
|
||||
def video(self) -> bool:
|
||||
@@ -133,7 +133,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return len(self._datasets[0].meta.video_keys) > 0
|
||||
return self._datasets[0].meta.info.get("video", False)
|
||||
|
||||
@property
|
||||
def features(self) -> datasets.Features:
|
||||
|
||||
@@ -434,7 +434,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
def _make_padding_camera_frame(self, camera_key: str):
|
||||
"""Variable-shape padding frame for given camera keys, given in (H, W, C)"""
|
||||
return torch.zeros(self.meta.info.features[camera_key]["shape"]).permute(-1, 0, 1)
|
||||
return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1)
|
||||
|
||||
def _get_video_frame_padding_mask(
|
||||
self,
|
||||
|
||||
@@ -14,11 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -72,12 +70,9 @@ class ForwardCompatibilityError(CompatibilityError):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 50 # Max size per file
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
|
||||
INFO_PATH = "meta/info.json"
|
||||
STATS_PATH = "meta/stats.json"
|
||||
@@ -99,123 +94,6 @@ LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetInfo:
|
||||
"""Typed representation of the ``meta/info.json`` file for a LeRobot dataset.
|
||||
|
||||
Replaces the previously untyped ``dict`` returned by ``load_info()`` and
|
||||
created by ``create_empty_dataset_info()``. Using a dataclass provides
|
||||
explicit field definitions, IDE auto-completion, and validation at
|
||||
construction time.
|
||||
"""
|
||||
|
||||
codebase_version: str
|
||||
fps: int
|
||||
features: dict[str, dict]
|
||||
|
||||
# Episode / frame counters — start at zero for new datasets
|
||||
total_episodes: int = 0
|
||||
total_frames: int = 0
|
||||
total_tasks: int = 0
|
||||
|
||||
# Storage settings
|
||||
chunks_size: int = field(default=DEFAULT_CHUNK_SIZE)
|
||||
data_files_size_in_mb: int = field(default=DEFAULT_DATA_FILE_SIZE_IN_MB)
|
||||
video_files_size_in_mb: int = field(default=DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
|
||||
# File path templates
|
||||
data_path: str = field(default=DEFAULT_DATA_PATH)
|
||||
video_path: str | None = field(default=DEFAULT_VIDEO_PATH)
|
||||
|
||||
# Optional metadata
|
||||
robot_type: str | None = None
|
||||
splits: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Coerce feature shapes from list to tuple — JSON deserialisation
|
||||
# returns lists, but the rest of the codebase expects tuples.
|
||||
for ft in self.features.values():
|
||||
if isinstance(ft.get("shape"), list):
|
||||
ft["shape"] = tuple(ft["shape"])
|
||||
|
||||
if self.fps <= 0:
|
||||
raise ValueError(f"fps must be positive, got {self.fps}")
|
||||
if self.chunks_size <= 0:
|
||||
raise ValueError(f"chunks_size must be positive, got {self.chunks_size}")
|
||||
if self.data_files_size_in_mb <= 0:
|
||||
raise ValueError(f"data_files_size_in_mb must be positive, got {self.data_files_size_in_mb}")
|
||||
if self.video_files_size_in_mb <= 0:
|
||||
raise ValueError(f"video_files_size_in_mb must be positive, got {self.video_files_size_in_mb}")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Return a JSON-serialisable dict.
|
||||
|
||||
Converts tuple shapes back to lists so ``json.dump`` can handle them.
|
||||
"""
|
||||
d = dataclasses.asdict(self)
|
||||
for ft in d["features"].values():
|
||||
if isinstance(ft.get("shape"), tuple):
|
||||
ft["shape"] = list(ft["shape"])
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "DatasetInfo":
|
||||
"""Construct from a raw dict (e.g. loaded directly from JSON).
|
||||
|
||||
Unknown keys are ignored for forward compatibility with datasets that
|
||||
carry additional fields (e.g. ``total_videos`` from v2.x). A warning is
|
||||
logged when such fields are present.
|
||||
"""
|
||||
known = {f.name for f in dataclasses.fields(cls)}
|
||||
unknown = sorted(k for k in data if k not in known)
|
||||
if unknown:
|
||||
logger.warning(f"Unknown fields in DatasetInfo: {unknown}. These will be ignored.")
|
||||
return cls(**{k: v for k, v in data.items() if k in known})
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Temporary dict-style compatibility layer
|
||||
# Allows existing ``info["key"]`` call-sites to keep working without changes.
|
||||
# Once all callers have been migrated to attribute access, remove these.
|
||||
# ---------------------------------------------------------------------------
|
||||
def __getitem__(self, key: str):
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
f"Accessing DatasetInfo with dict-style syntax info['{key}'] is deprecated. "
|
||||
f"Use attribute access info.{key} instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
try:
|
||||
return getattr(self, key)
|
||||
except AttributeError as err:
|
||||
raise KeyError(key) from err
|
||||
|
||||
def __setitem__(self, key: str, value) -> None:
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
f"Setting DatasetInfo with dict-style syntax info['{key}'] = ... is deprecated. "
|
||||
f"Use attribute assignment info.{key} = ... instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if not hasattr(self, key):
|
||||
raise KeyError(f"DatasetInfo has no field '{key}'")
|
||||
setattr(self, key, value)
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
"""Check if a field exists (dict-like interface)."""
|
||||
return hasattr(self, key)
|
||||
|
||||
def get(self, key: str, default=None):
|
||||
"""Get attribute value with default fallback (dict-like interface)."""
|
||||
try:
|
||||
return getattr(self, key)
|
||||
except AttributeError:
|
||||
return default
|
||||
|
||||
|
||||
def has_legacy_hub_download_metadata(root: Path) -> bool:
|
||||
"""Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.
|
||||
|
||||
@@ -416,7 +294,7 @@ def create_branch(repo_id: str, *, branch: str, repo_type: str | None = None) ->
|
||||
|
||||
def create_lerobot_dataset_card(
|
||||
tags: list | None = None,
|
||||
dataset_info: DatasetInfo | None = None,
|
||||
dataset_info: dict | None = None,
|
||||
**kwargs,
|
||||
) -> DatasetCard:
|
||||
"""Create a `DatasetCard` for a LeRobot dataset.
|
||||
@@ -427,7 +305,7 @@ def create_lerobot_dataset_card(
|
||||
|
||||
Args:
|
||||
tags (list | None): A list of tags to add to the dataset card.
|
||||
dataset_info (DatasetInfo | None): The dataset's info object, which will
|
||||
dataset_info (dict | None): The dataset's info dictionary, which will
|
||||
be displayed on the card.
|
||||
**kwargs: Additional keyword arguments to populate the card template.
|
||||
|
||||
@@ -440,7 +318,7 @@ def create_lerobot_dataset_card(
|
||||
card_tags += tags
|
||||
if dataset_info:
|
||||
dataset_structure = "[meta/info.json](meta/info.json):\n"
|
||||
dataset_structure += f"```json\n{json.dumps(dataset_info.to_dict(), indent=4)}\n```\n"
|
||||
dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n"
|
||||
kwargs = {**kwargs, "dataset_structure": dataset_structure}
|
||||
card_data = DatasetCardData(
|
||||
license=kwargs.get("license"),
|
||||
|
||||
@@ -16,8 +16,6 @@ from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterp
|
||||
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||
from .evo1.configuration_evo1 import Evo1Config as Evo1Config
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||
@@ -26,6 +24,8 @@ from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
||||
from .sac.configuration_sac import SACConfig as SACConfig
|
||||
from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .utils import make_robot_action, prepare_observation_for_inference
|
||||
@@ -41,14 +41,14 @@ __all__ = [
|
||||
# Configuration classes
|
||||
"ACTConfig",
|
||||
"DiffusionConfig",
|
||||
"Evo1Config",
|
||||
"GrootConfig",
|
||||
"MultiTaskDiTConfig",
|
||||
"EO1Config",
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
"PI05Config",
|
||||
"RewardClassifierConfig",
|
||||
"SACConfig",
|
||||
"SARMConfig",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
|
||||
@@ -142,10 +142,9 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
abs_err = F.l1_loss(batch[ACTION], actions_hat, reduction="none")
|
||||
valid_mask = ~batch["action_is_pad"].unsqueeze(-1)
|
||||
num_valid = valid_mask.sum() * abs_err.shape[-1]
|
||||
l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1)
|
||||
l1_loss = (
|
||||
F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
).mean()
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss.item()}
|
||||
if self.config.use_vae:
|
||||
|
||||
@@ -100,8 +100,8 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
|
||||
# Inputs / output structure.
|
||||
n_obs_steps: int = 2
|
||||
horizon: int = 64
|
||||
n_action_steps: int = 32
|
||||
horizon: int = 16
|
||||
n_action_steps: int = 8
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
@@ -122,10 +122,10 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
crop_ratio: float = 1.0
|
||||
crop_shape: tuple[int, int] | None = None
|
||||
crop_is_random: bool = True
|
||||
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||
use_group_norm: bool = False
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
spatial_softmax_num_keypoints: int = 32
|
||||
use_separate_rgb_encoder_per_camera: bool = True
|
||||
use_separate_rgb_encoder_per_camera: bool = False
|
||||
# Unet.
|
||||
down_dims: tuple[int, ...] = (512, 1024, 2048)
|
||||
kernel_size: int = 5
|
||||
|
||||
@@ -380,9 +380,7 @@ class DiffusionModel(nn.Module):
|
||||
f"{self.config.do_mask_loss_for_padding=}."
|
||||
)
|
||||
in_episode_bound = ~batch["action_is_pad"]
|
||||
mask = in_episode_bound.unsqueeze(-1)
|
||||
num_valid = mask.sum() * loss.shape[-1]
|
||||
return (loss * mask).sum() / num_valid.clamp_min(1)
|
||||
loss = loss * in_episode_bound.unsqueeze(-1)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
../../../../docs/source/eo1.mdx
|
||||
@@ -1,7 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from .configuration_eo1 import EO1Config
|
||||
from .modeling_eo1 import EO1Policy
|
||||
from .processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
__all__ = ["EO1Config", "EO1Policy", "make_eo1_pre_post_processors"]
|
||||
@@ -1,193 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
Qwen2_5_VLConfig,
|
||||
Qwen2_5_VLTextConfig,
|
||||
Qwen2_5_VLVisionConfig,
|
||||
)
|
||||
else:
|
||||
Qwen2_5_VLConfig = None
|
||||
Qwen2_5_VLTextConfig = None
|
||||
Qwen2_5_VLVisionConfig = None
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("eo1")
|
||||
@dataclass
|
||||
class EO1Config(PreTrainedConfig):
|
||||
"""Configuration for native EO1 policy integration in LeRobot."""
|
||||
|
||||
vlm_base: str = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
vlm_config: dict | None = None
|
||||
|
||||
# Vision processor settings.
|
||||
image_min_pixels: int | None = 64 * 28 * 28
|
||||
image_max_pixels: int | None = 128 * 28 * 28
|
||||
use_fast_processor: bool = False
|
||||
|
||||
# Execution and action horizon.
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 8
|
||||
n_action_steps: int = 8
|
||||
|
||||
# State/action padding to match EO1 flow head dimensionality.
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Flow matching sampling.
|
||||
num_denoise_steps: int = 10
|
||||
num_action_layers: int = 2
|
||||
action_act: str = "linear"
|
||||
time_sampling_beta_alpha: float = 1.5
|
||||
time_sampling_beta_beta: float = 1.0
|
||||
time_sampling_scale: float = 0.999
|
||||
time_sampling_offset: float = 0.001
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
supervise_padding_action_dims: bool = True
|
||||
supervise_padding_actions: bool = True
|
||||
|
||||
# Policy-level dtype request for the Qwen backbone.
|
||||
# - "auto": follow the backbone config/checkpoint default dtype. For Qwen2.5-VL this resolves to bf16.
|
||||
# The EO1 flow-matching head still keeps its own parameters in fp32.
|
||||
# - "bfloat16": force the backbone to initialize/load in bf16 regardless of the saved config default.
|
||||
# - "float32": force the backbone to initialize/load in fp32 for maximum numerical conservatism.
|
||||
dtype: str = "auto" # Options: "auto", "bfloat16", "float32"
|
||||
force_fp32_autocast: bool = True
|
||||
|
||||
# Optional attention backend request passed through to the Qwen backbone.
|
||||
# Common values: None, "eager", "sdpa", "flash_attention_2".
|
||||
attn_implementation: str | None = None
|
||||
|
||||
# Training settings.
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Optimizer settings aligned with EO1/experiments/2_libero/train.sh and EO1 TrainPipelineConfig defaults.
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.1
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
# Scheduler settings aligned with EO1 train.sh: cosine schedule with warmup_ratio=0.03.
|
||||
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
|
||||
scheduler_warmup_steps: int = 900 # 0.03 * 30_000 long-run steps
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
# Populate the serialized backbone config only when the caller did not provide one.
|
||||
if self.vlm_config is None:
|
||||
require_package("transformers", extra="eo1")
|
||||
self.vlm_config = Qwen2_5_VLConfig.from_pretrained(self.vlm_base).to_dict()
|
||||
|
||||
@property
|
||||
def vlm_backbone_config(self) -> Qwen2_5_VLConfig:
|
||||
require_package("transformers", extra="eo1")
|
||||
config_dict = deepcopy(self.vlm_config)
|
||||
if self.attn_implementation is not None:
|
||||
config_dict["attn_implementation"] = self.attn_implementation
|
||||
return Qwen2_5_VLConfig(**config_dict)
|
||||
|
||||
@property
|
||||
def text_config(self) -> Qwen2_5_VLTextConfig:
|
||||
return self.vlm_backbone_config.text_config
|
||||
|
||||
@property
|
||||
def vision_config(self) -> Qwen2_5_VLVisionConfig:
|
||||
return self.vlm_backbone_config.vision_config
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up EO1 input and output features."""
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"EO1 policy requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
|
||||
if OBS_STATE not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,),
|
||||
)
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,),
|
||||
)
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -1,620 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torch.utils.checkpoint
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||
from transformers.utils import torch_compilable_check
|
||||
else:
|
||||
ACT2FN = None
|
||||
Qwen2_5_VLForConditionalGeneration = None
|
||||
torch_compilable_check = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
Can be (batch_size x sequence_length x features_dimension)
|
||||
or (batch_size x features_dimension)
|
||||
"""
|
||||
if vector.shape[-1] >= new_dim:
|
||||
return vector
|
||||
return F.pad(vector, (0, new_dim - vector.shape[-1]))
|
||||
|
||||
|
||||
class EO1Policy(PreTrainedPolicy):
|
||||
"""EO1 policy wrapper for LeRobot robot-only training/evaluation."""
|
||||
|
||||
config_class = EO1Config
|
||||
name = "eo1"
|
||||
|
||||
def __init__(self, config: EO1Config, **kwargs):
|
||||
require_package("transformers", extra="eo1")
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
if config.pretrained_path is None:
|
||||
# Initialize from pretrained VLM
|
||||
vlm_backbone = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
config.vlm_base,
|
||||
dtype=config.dtype,
|
||||
attn_implementation=config.attn_implementation,
|
||||
)
|
||||
else:
|
||||
vlm_backbone = Qwen2_5_VLForConditionalGeneration._from_config(
|
||||
config.vlm_backbone_config,
|
||||
dtype=config.vlm_backbone_config.dtype if config.dtype == "auto" else config.dtype,
|
||||
)
|
||||
|
||||
self.model = EO1VisionFlowMatchingModel(config, vlm_backbone)
|
||||
if config.gradient_checkpointing:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
|
||||
self.model.to(config.device)
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self._action_queue = deque(maxlen=self.config.n_action_steps)
|
||||
|
||||
@staticmethod
|
||||
def _get_model_inputs(batch: dict[str, Tensor], excluded_keys: set[str]) -> dict[str, Tensor]:
|
||||
return {key: value for key, value in batch.items() if key not in excluded_keys}
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
state = self.prepare_state(batch[OBS_STATE])
|
||||
actions = self.prepare_action(batch[ACTION])
|
||||
model_inputs = self._get_model_inputs(batch, {OBS_STATE, ACTION})
|
||||
loss = self.model(states=state, action=actions, **model_inputs)
|
||||
|
||||
loss_dict = {"loss": loss.item()}
|
||||
return loss, loss_dict
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
states = self.prepare_state(batch[OBS_STATE])
|
||||
model_inputs = self._get_model_inputs(batch, {OBS_STATE})
|
||||
actions = self.model.sample_actions(states=states, **model_inputs).to(torch.float32)
|
||||
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
return actions[:, :, :original_action_dim]
|
||||
|
||||
def prepare_state(self, state: Tensor) -> Tensor:
|
||||
return pad_vector(state, self.config.max_state_dim)
|
||||
|
||||
def prepare_action(self, action: Tensor) -> Tensor:
|
||||
return pad_vector(action, self.config.max_action_dim)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "mps" and target_dtype == torch.float64:
|
||||
return torch.float32
|
||||
if device_type == "cpu":
|
||||
# CPU doesn't support bfloat16, use float32 instead
|
||||
if target_dtype == torch.bfloat16:
|
||||
return torch.float32
|
||||
if target_dtype == torch.float64:
|
||||
return torch.float64
|
||||
return target_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
|
||||
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
period = min_period * (max_period / min_period) ** fraction
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
|
||||
alpha_t = torch.tensor(alpha, dtype=torch.float32)
|
||||
beta_t = torch.tensor(beta, dtype=torch.float32)
|
||||
dist = torch.distributions.Beta(alpha_t, beta_t)
|
||||
return dist.sample((bsize,)).to(device)
|
||||
|
||||
|
||||
class EO1VisionActionProjector(torch.nn.Sequential):
|
||||
"""This block implements the multi-layer perceptron (MLP) module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_layers: int = 2,
|
||||
activation_layer: str = "linear",
|
||||
bias: bool = True,
|
||||
device: Any = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
):
|
||||
layers = []
|
||||
in_dim = in_channels
|
||||
hidden_channels = [in_dim] * (num_layers - 1) + [out_channels]
|
||||
for hidden_dim in hidden_channels[:-1]:
|
||||
layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device))
|
||||
layers.append(ACT2FN[activation_layer])
|
||||
in_dim = hidden_dim
|
||||
layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias, dtype=dtype, device=device))
|
||||
super().__init__(*layers)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self[0].weight.dtype
|
||||
|
||||
|
||||
class EO1VisionFlowMatchingModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: EO1Config,
|
||||
vlm_backbone: Qwen2_5_VLForConditionalGeneration | None = None,
|
||||
):
|
||||
require_package("transformers", extra="eo1")
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
# Preserve the backbone dtype selected at construction time so Qwen's fp32 rotary buffers stay intact.
|
||||
self.vlm_backbone = vlm_backbone
|
||||
self.hidden_size = self.vlm_backbone.config.text_config.hidden_size
|
||||
max_state_dim = config.max_state_dim
|
||||
max_action_dim = config.max_action_dim
|
||||
self.state_proj = nn.Linear(max_state_dim, self.hidden_size, dtype=torch.float32)
|
||||
self.action_in_proj = nn.Linear(max_action_dim, self.hidden_size, dtype=torch.float32)
|
||||
self.action_out_proj = EO1VisionActionProjector(
|
||||
self.hidden_size,
|
||||
max_action_dim,
|
||||
config.num_action_layers,
|
||||
config.action_act,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
self.action_time_mlp_in = nn.Linear(self.hidden_size * 2, self.hidden_size, dtype=torch.float32)
|
||||
self.action_time_mlp_out = nn.Linear(self.hidden_size, self.hidden_size, dtype=torch.float32)
|
||||
self.gradient_checkpointing_enabled = False
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.vlm_backbone.get_input_embeddings()
|
||||
|
||||
def flow_head_autocast_context(self):
|
||||
if self.config.force_fp32_autocast:
|
||||
return torch.autocast(
|
||||
device_type=self.state_proj.weight.device.type,
|
||||
enabled=False,
|
||||
)
|
||||
return contextlib.nullcontext()
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for the Qwen2.5-VL backbone."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
self.vlm_backbone.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
logger.info("Enabled gradient checkpointing for EO1VisionFlowMatchingModel")
|
||||
|
||||
def gradient_checkpointing_disable(self):
|
||||
"""Disable gradient checkpointing for the Qwen2.5-VL backbone."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
self.vlm_backbone.gradient_checkpointing_disable()
|
||||
logger.info("Disabled gradient checkpointing for EO1VisionFlowMatchingModel")
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Apply manual gradient checkpointing to EO1 flow-head computations when training."""
|
||||
if self.gradient_checkpointing_enabled and self.training and torch.is_grad_enabled():
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
noise = torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
return noise
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(
|
||||
self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
|
||||
)
|
||||
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def get_placeholder_mask(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None,
|
||||
inputs_embeds: torch.FloatTensor | None,
|
||||
state_features: torch.FloatTensor | None = None,
|
||||
action_features: torch.FloatTensor | None = None,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
) -> tuple[torch.BoolTensor, torch.BoolTensor]:
|
||||
"""Return EO1 state/action placeholder masks, following Qwen's multimodal mask style."""
|
||||
if input_ids is None:
|
||||
special_state_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(state_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_state_mask = special_state_mask.all(-1)
|
||||
special_action_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(action_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_action_mask = special_action_mask.all(-1)
|
||||
else:
|
||||
special_state_mask = input_ids == state_token_id
|
||||
special_action_mask = input_ids == action_token_id
|
||||
|
||||
n_state_tokens = special_state_mask.sum()
|
||||
special_state_mask = (
|
||||
special_state_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
)
|
||||
if state_features is not None:
|
||||
torch_compilable_check(
|
||||
inputs_embeds[special_state_mask].numel() == state_features.numel(),
|
||||
f"State features and state tokens do not match, tokens: {n_state_tokens}, features: {state_features.shape[0]}",
|
||||
)
|
||||
|
||||
n_action_tokens = special_action_mask.sum()
|
||||
special_action_mask = (
|
||||
special_action_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
)
|
||||
if action_features is not None:
|
||||
torch_compilable_check(
|
||||
inputs_embeds[special_action_mask].numel() == action_features.numel(),
|
||||
f"Action features and action tokens do not match, tokens: {n_action_tokens}, features: {action_features.shape[0]}",
|
||||
)
|
||||
|
||||
return special_state_mask, special_action_mask
|
||||
|
||||
def embed_prefix(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
states: torch.Tensor,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
) -> torch.FloatTensor:
|
||||
"""Embed the EO1 prefix tokens before native Qwen injects multimodal features."""
|
||||
|
||||
# Get the input embeddings for the input IDs
|
||||
def input_embed_func(input_ids: torch.LongTensor) -> torch.FloatTensor:
|
||||
return self.get_input_embeddings()(input_ids)
|
||||
|
||||
inputs_embeds = self._apply_checkpoint(input_embed_func, input_ids)
|
||||
|
||||
# Project the states to the hidden size
|
||||
def state_proj_func(states: torch.Tensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
states = states.to(dtype=self.state_proj.weight.dtype)
|
||||
return self.state_proj(states)
|
||||
|
||||
state_embs = self._apply_checkpoint(state_proj_func, states)
|
||||
state_mask, _ = self.get_placeholder_mask(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
state_features=state_embs,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
state_embs = state_embs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(state_mask, state_embs)
|
||||
return inputs_embeds
|
||||
|
||||
def embed_suffix(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
noisy_actions: torch.Tensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""Embed the suffix"""
|
||||
|
||||
def action_proj_func(noisy_actions: torch.Tensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
noisy_actions = noisy_actions.to(dtype=self.action_in_proj.weight.dtype)
|
||||
return self.action_in_proj(noisy_actions)
|
||||
|
||||
action_embs = self._apply_checkpoint(action_proj_func, noisy_actions)
|
||||
time_embs = create_sinusoidal_pos_embedding(
|
||||
timestep,
|
||||
self.hidden_size,
|
||||
min_period=self.config.min_period,
|
||||
max_period=self.config.max_period,
|
||||
device=action_embs.device,
|
||||
)
|
||||
time_embs = time_embs.to(dtype=action_embs.dtype)
|
||||
time_embs = time_embs[:, None, :].expand_as(action_embs)
|
||||
action_time_embs = torch.cat([action_embs, time_embs], dim=2)
|
||||
|
||||
def mlp_func(action_time_embs: torch.Tensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
action_time_embs = action_time_embs.to(dtype=self.action_time_mlp_in.weight.dtype)
|
||||
action_time_embs = self.action_time_mlp_in(action_time_embs)
|
||||
action_time_embs = F.silu(action_time_embs)
|
||||
return self.action_time_mlp_out(action_time_embs)
|
||||
|
||||
action_time_embs = self._apply_checkpoint(mlp_func, action_time_embs)
|
||||
return action_time_embs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.LongTensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
image_grid_thw: torch.LongTensor | None = None,
|
||||
mm_token_type_ids: torch.IntTensor | None = None,
|
||||
states: torch.FloatTensor | None = None,
|
||||
action: torch.FloatTensor | None = None,
|
||||
action_is_pad: torch.BoolTensor | None = None,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
"""Run the EO1 training forward pass and compute the flow-matching loss."""
|
||||
|
||||
# 1. Build the EO1 prefix with state placeholders resolved.
|
||||
inputs_embeds = self.embed_prefix(
|
||||
input_ids,
|
||||
states=states,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
|
||||
# 2. Sample the diffusion target and replace the action placeholders.
|
||||
time = self.sample_time(action.shape[0], inputs_embeds.device)
|
||||
noise = self.sample_noise(action.shape, inputs_embeds.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * action
|
||||
u_t = noise - action
|
||||
action_time_embs = self.embed_suffix(time, x_t)
|
||||
_, action_mask = self.get_placeholder_mask(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
action_features=action_time_embs,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
action_time_embs = action_time_embs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(action_mask, action_time_embs)
|
||||
|
||||
# 3. Optionally drop padded action tokens from backbone attention.
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
if not self.config.supervise_padding_actions:
|
||||
action_is_pad = action_is_pad.to(device=inputs_embeds.device, dtype=torch.bool)
|
||||
action_token_mask = action_mask[..., 0]
|
||||
action_padding_mask = torch.zeros_like(action_token_mask)
|
||||
action_padding_mask = action_padding_mask.masked_scatter(
|
||||
action_token_mask,
|
||||
action_is_pad.reshape(-1),
|
||||
)
|
||||
attention_mask = attention_mask.masked_fill(action_padding_mask, 0)
|
||||
|
||||
# 4. Run the Qwen backbone on the fused EO1 sequence.
|
||||
def vlm_forward_func(
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
inputs_embeds: torch.FloatTensor,
|
||||
pixel_values: torch.Tensor | None,
|
||||
image_grid_thw: torch.LongTensor | None,
|
||||
mm_token_type_ids: torch.IntTensor | None,
|
||||
) -> torch.FloatTensor:
|
||||
outputs = self.vlm_backbone.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
mm_token_type_ids=mm_token_type_ids,
|
||||
use_cache=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
)
|
||||
return outputs.last_hidden_state
|
||||
|
||||
hidden_states = self._apply_checkpoint(
|
||||
vlm_forward_func,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
pixel_values,
|
||||
image_grid_thw,
|
||||
mm_token_type_ids,
|
||||
)
|
||||
action_hidden_states = hidden_states[action_mask[..., 0]]
|
||||
|
||||
# 5. Project the action-token hidden states back to the flow target space.
|
||||
def action_out_proj_func(action_hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
action_hidden_states = action_hidden_states.to(dtype=self.action_out_proj.dtype)
|
||||
return self.action_out_proj(action_hidden_states)
|
||||
|
||||
v_t = self._apply_checkpoint(action_out_proj_func, action_hidden_states)
|
||||
v_t = v_t.reshape(u_t.shape).to(dtype=u_t.dtype)
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
# 6. Apply the configured supervision mask and reduce the loss.
|
||||
if not self.config.supervise_padding_action_dims:
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
losses = losses[..., :original_action_dim]
|
||||
|
||||
if not self.config.supervise_padding_actions:
|
||||
losses = losses[~action_is_pad]
|
||||
|
||||
return losses.mean()
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_actions(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
pixel_values: torch.Tensor | None = None,
|
||||
image_grid_thw: torch.LongTensor | None = None,
|
||||
mm_token_type_ids: torch.IntTensor | None = None,
|
||||
states: torch.Tensor | None = None,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
"""Sample actions from the model."""
|
||||
if states is None:
|
||||
raise ValueError("states are required for EO1 action sampling.")
|
||||
if mm_token_type_ids is None:
|
||||
raise ValueError("mm_token_type_ids are required for EO1 action sampling.")
|
||||
|
||||
# 1. Resolve the left-padded rollout prompt and locate the action span.
|
||||
chunk_size = self.config.chunk_size
|
||||
|
||||
inputs_embeds = self.embed_prefix(
|
||||
input_ids,
|
||||
states=states,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
).clone()
|
||||
_, action_placeholder_mask = self.get_placeholder_mask(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
action_mask = action_placeholder_mask[..., 0]
|
||||
token_counts = action_mask.sum(dim=1)
|
||||
if not torch.all(token_counts == chunk_size):
|
||||
raise ValueError(
|
||||
f"Each sample must contain exactly {chunk_size} action tokens, got {token_counts.tolist()}."
|
||||
)
|
||||
if action_mask.ne(action_mask[:1]).any():
|
||||
raise ValueError(
|
||||
"Batch inference expects all samples to share the same action token mask after left padding."
|
||||
)
|
||||
act_start = int(action_mask[0].to(torch.int64).argmax().item())
|
||||
act_end = act_start + self.config.chunk_size
|
||||
if not torch.all(action_mask[:, act_start:act_end]):
|
||||
raise ValueError("Action tokens must form a contiguous chunk of length chunk_size.")
|
||||
act_slice = slice(act_start, act_end)
|
||||
|
||||
# 2. Encode the fixed prefix once and cache its KV state.
|
||||
batch_size = input_ids.shape[0]
|
||||
device = inputs_embeds.device
|
||||
attention_mask = attention_mask.to(device)
|
||||
mm_token_type_ids = mm_token_type_ids.to(device)
|
||||
position_ids, _ = self.vlm_backbone.model.get_rope_index(
|
||||
input_ids,
|
||||
image_grid_thw=image_grid_thw,
|
||||
attention_mask=attention_mask,
|
||||
mm_token_type_ids=mm_token_type_ids,
|
||||
)
|
||||
position_ids = position_ids.to(device)
|
||||
|
||||
outputs = self.vlm_backbone.model(
|
||||
input_ids=input_ids[:, :act_start],
|
||||
attention_mask=attention_mask[:, :act_start],
|
||||
position_ids=position_ids[..., :act_start],
|
||||
inputs_embeds=inputs_embeds[:, :act_start],
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
mm_token_type_ids=mm_token_type_ids[:, :act_start],
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
x_t = self.sample_noise(
|
||||
(batch_size, chunk_size, self.config.max_action_dim),
|
||||
device,
|
||||
).to(dtype=self.action_in_proj.weight.dtype)
|
||||
dt = -1.0 / self.config.num_denoise_steps
|
||||
past_key_values = outputs.past_key_values
|
||||
|
||||
# 3. Denoise only the action chunk while keeping the prefix cache invariant.
|
||||
for step in range(self.config.num_denoise_steps):
|
||||
time = torch.full(
|
||||
(batch_size,),
|
||||
1.0 + step * dt,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
action_time_embs = self.embed_suffix(time, x_t)
|
||||
inputs_embeds[:, act_slice] = action_time_embs.to(inputs_embeds.dtype)
|
||||
|
||||
# Keep the prefix KV cache invariant across denoising steps.
|
||||
past_key_values.crop(act_start)
|
||||
outputs = self.vlm_backbone.model(
|
||||
attention_mask=attention_mask[:, :act_end],
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds[:, act_slice],
|
||||
position_ids=position_ids[..., act_slice],
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
)
|
||||
with self.flow_head_autocast_context():
|
||||
hidden_states = outputs.last_hidden_state[:, :chunk_size]
|
||||
hidden_states = hidden_states.to(dtype=self.action_out_proj.dtype)
|
||||
v_t = self.action_out_proj(hidden_states)
|
||||
|
||||
x_t += dt * v_t.reshape(x_t.shape)
|
||||
|
||||
return x_t
|
||||
@@ -1,282 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
||||
else:
|
||||
Qwen2_5_VLProcessor = None
|
||||
|
||||
SYSTEM_MESSAGE = "You are a helpful physical assistant."
|
||||
|
||||
# EO-1 special tokens
|
||||
ACTION_START_TOKEN = "<|action_start|>" # nosec B105
|
||||
DEFAULT_ACTION_TOKEN = "<|action_pad|>" # nosec B105
|
||||
ACTION_END_TOKEN = "<|action_end|>" # nosec B105
|
||||
STATE_START_TOKEN = "<|state_start|>" # nosec B105
|
||||
DEFAULT_STATE_TOKEN = "<|state_pad|>" # nosec B105
|
||||
STATE_END_TOKEN = "<|state_end|>" # nosec B105
|
||||
TASK_VLA_TOKEN = "<|vla|>" # nosec B105
|
||||
|
||||
EO1_SPECIAL_TOKENS = [
|
||||
ACTION_START_TOKEN,
|
||||
DEFAULT_ACTION_TOKEN,
|
||||
ACTION_END_TOKEN,
|
||||
STATE_START_TOKEN,
|
||||
DEFAULT_STATE_TOKEN,
|
||||
STATE_END_TOKEN,
|
||||
TASK_VLA_TOKEN,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="eo1_conversation_template_processor")
|
||||
class EO1ConversationTemplateStep(ComplementaryDataProcessorStep):
|
||||
input_features: dict[str, PolicyFeature] | dict[str, dict[str, Any]]
|
||||
chunk_size: int
|
||||
|
||||
_image_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
# Robust JSON deserialization handling (guard empty maps).
|
||||
if self.input_features:
|
||||
first_val = next(iter(self.input_features.values()))
|
||||
if isinstance(first_val, dict):
|
||||
reconstructed = {}
|
||||
for key, ft_dict in self.input_features.items():
|
||||
reconstructed[key] = PolicyFeature(
|
||||
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||
)
|
||||
self.input_features = reconstructed
|
||||
|
||||
self._image_keys = [
|
||||
key for key, value in self.input_features.items() if value.type == FeatureType.VISUAL
|
||||
]
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
tasks = complementary_data.get("task")
|
||||
if tasks is None:
|
||||
raise ValueError("Task is required for EO1ConversationTemplateStep.")
|
||||
|
||||
observation = self.transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None:
|
||||
raise ValueError("Observation is required for EO1ConversationTemplateStep.")
|
||||
|
||||
if OBS_STATE in observation and observation[OBS_STATE].shape[0] != len(tasks):
|
||||
raise ValueError("Batch size mismatch between observation.state and task list.")
|
||||
|
||||
# LeRobot visual observations reach in processor as float32 tensors in [0, 1].
|
||||
# Convert to uint8 in [0, 255] to meet the input requirement of Qwen2.5-VL-3B-Instruct.
|
||||
images = {
|
||||
key: observation[key].clamp(0, 1).mul(255.0).round().to(torch.uint8) for key in self._image_keys
|
||||
}
|
||||
messages = []
|
||||
for i in range(len(tasks)):
|
||||
content = [
|
||||
*[{"type": "image", "image": images[key][i]} for key in self._image_keys],
|
||||
{
|
||||
"type": "text",
|
||||
"text": (
|
||||
f"{STATE_START_TOKEN}{DEFAULT_STATE_TOKEN}{STATE_END_TOKEN}{tasks[i]}{TASK_VLA_TOKEN}"
|
||||
),
|
||||
},
|
||||
]
|
||||
messages.append(
|
||||
[
|
||||
{"role": "system", "content": [{"type": "text", "text": SYSTEM_MESSAGE}]},
|
||||
{"role": "user", "content": content},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"{ACTION_START_TOKEN}{DEFAULT_ACTION_TOKEN * self.chunk_size}{ACTION_END_TOKEN}",
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
complementary_data["messages"] = messages
|
||||
|
||||
return complementary_data
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step only materializes EO1-specific message objects in complementary_data.
|
||||
PipelineFeatureType tracks only ACTION and OBSERVATION, so there is no static
|
||||
feature contract change to record here.
|
||||
"""
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"input_features": {
|
||||
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.input_features.items()
|
||||
},
|
||||
"chunk_size": self.chunk_size,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="eo1_qwen_processor")
|
||||
class EO1QwenProcessorStep(ComplementaryDataProcessorStep):
|
||||
processor_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
image_min_pixels: int | None = 64 * 28 * 28
|
||||
image_max_pixels: int | None = 128 * 28 * 28
|
||||
use_fast_processor: bool = False
|
||||
|
||||
_processor: Qwen2_5_VLProcessor | None = field(default=None, init=False, repr=False)
|
||||
_state_token_id: int | None = field(default=None, init=False, repr=False)
|
||||
_action_token_id: int | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
require_package("transformers", extra="eo1")
|
||||
self._processor = Qwen2_5_VLProcessor.from_pretrained(
|
||||
self.processor_name,
|
||||
use_fast=self.use_fast_processor,
|
||||
)
|
||||
self._processor.tokenizer.add_tokens(EO1_SPECIAL_TOKENS, special_tokens=True)
|
||||
self._state_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_STATE_TOKEN)
|
||||
self._action_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_ACTION_TOKEN)
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
messages = complementary_data.pop("messages", None)
|
||||
if messages is None:
|
||||
raise ValueError("Messages are required for EO1QwenProcessorStep.")
|
||||
|
||||
# Rollout batches use left padding so action spans stay aligned across samples.
|
||||
# Supervised batches use right padding to match standard training collation.
|
||||
padding_side = "right" if self.transition.get(TransitionKey.ACTION) is not None else "left"
|
||||
|
||||
inputs = self._processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
padding=True,
|
||||
padding_side=padding_side,
|
||||
min_pixels=self.image_min_pixels,
|
||||
max_pixels=self.image_max_pixels,
|
||||
add_generation_prompt=False,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
complementary_data["input_ids"] = inputs["input_ids"]
|
||||
complementary_data["pixel_values"] = inputs["pixel_values"]
|
||||
complementary_data["image_grid_thw"] = inputs["image_grid_thw"]
|
||||
complementary_data["attention_mask"] = inputs["attention_mask"]
|
||||
complementary_data["mm_token_type_ids"] = inputs["mm_token_type_ids"]
|
||||
complementary_data["state_token_id"] = self._state_token_id
|
||||
complementary_data["action_token_id"] = self._action_token_id
|
||||
|
||||
return complementary_data
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"processor_name": self.processor_name,
|
||||
"image_min_pixels": self.image_min_pixels,
|
||||
"image_max_pixels": self.image_max_pixels,
|
||||
"use_fast_processor": self.use_fast_processor,
|
||||
}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step only converts the messages to the model input format.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
def make_eo1_pre_post_processors(
|
||||
config: EO1Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Build pre/post processor pipelines for EO1."""
|
||||
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
EO1ConversationTemplateStep(input_features=config.input_features, chunk_size=config.chunk_size),
|
||||
EO1QwenProcessorStep(
|
||||
processor_name=config.vlm_base,
|
||||
image_min_pixels=config.image_min_pixels,
|
||||
image_max_pixels=config.image_max_pixels,
|
||||
use_fast_processor=config.use_fast_processor,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -1 +0,0 @@
|
||||
../../../../docs/source/evo1.mdx
|
||||
@@ -1,19 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_evo1 import Evo1Config
|
||||
from .modeling_evo1 import EVO1Policy
|
||||
from .processor_evo1 import make_evo1_pre_post_processors
|
||||
|
||||
__all__ = ["Evo1Config", "EVO1Policy", "make_evo1_pre_post_processors"]
|
||||
@@ -1,211 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
@LRSchedulerConfig.register_subclass("evo1_exact")
|
||||
@dataclass
|
||||
class Evo1SchedulerConfig(LRSchedulerConfig):
|
||||
num_warmup_steps: int
|
||||
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
def lr_lambda(current_step: int) -> float:
|
||||
if current_step < self.num_warmup_steps:
|
||||
return current_step / max(1, self.num_warmup_steps)
|
||||
progress = (current_step - self.num_warmup_steps) / max(
|
||||
1, num_training_steps - self.num_warmup_steps
|
||||
)
|
||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("evo1")
|
||||
@dataclass
|
||||
class Evo1Config(PreTrainedConfig):
|
||||
training_stage: str = "stage1"
|
||||
use_amp: bool = True
|
||||
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 50
|
||||
n_action_steps: int = 50
|
||||
|
||||
max_state_dim: int = 24
|
||||
max_action_dim: int = 24
|
||||
max_views: int = 3
|
||||
image_resolution: tuple[int, int] = (448, 448)
|
||||
empty_cameras: int = 0
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
vlm_model_name: str = "OpenGVLab/InternVL3-1B"
|
||||
vlm_num_layers: int | None = 14
|
||||
vlm_dtype: str = "bfloat16"
|
||||
use_flash_attn: bool = True
|
||||
action_head: str = "flowmatching"
|
||||
embed_dim: int = 896
|
||||
hidden_dim: int = 1024
|
||||
state_hidden_dim: int = 1024
|
||||
num_heads: int = 8
|
||||
num_layers: int = 8
|
||||
dropout: float = 0.0
|
||||
num_inference_timesteps: int = 32
|
||||
num_categories: int = 1
|
||||
return_cls_only: bool = False
|
||||
enable_gradient_checkpointing: bool = True
|
||||
gradient_checkpointing_use_reentrant: bool = False
|
||||
|
||||
finetune_vlm: bool | None = None
|
||||
finetune_language_model: bool | None = None
|
||||
finetune_vision_model: bool | None = None
|
||||
finetune_action_head: bool | None = None
|
||||
|
||||
task_field: str = "task"
|
||||
embodiment_id_field: str | None = None
|
||||
default_embodiment_id: int = 0
|
||||
|
||||
optimizer_lr: float = 1e-5
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-5
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
scheduler_warmup_steps: int = 300
|
||||
drop_last: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.training_stage not in {"stage1", "stage2"}:
|
||||
raise ValueError(
|
||||
f"Unsupported EVO1 training_stage '{self.training_stage}', expected 'stage1' or 'stage2'"
|
||||
)
|
||||
|
||||
if self.training_stage == "stage1":
|
||||
if self.finetune_vlm is None:
|
||||
self.finetune_vlm = False
|
||||
if self.finetune_language_model is None:
|
||||
self.finetune_language_model = False
|
||||
if self.finetune_vision_model is None:
|
||||
self.finetune_vision_model = False
|
||||
if self.finetune_action_head is None:
|
||||
self.finetune_action_head = True
|
||||
elif self.training_stage == "stage2":
|
||||
has_explicit_branch_flags = any(
|
||||
flag is not None for flag in (self.finetune_language_model, self.finetune_vision_model)
|
||||
)
|
||||
if not has_explicit_branch_flags:
|
||||
if self.finetune_vlm is None:
|
||||
self.finetune_vlm = True
|
||||
if self.finetune_language_model is None:
|
||||
self.finetune_language_model = True
|
||||
if self.finetune_vision_model is None:
|
||||
self.finetune_vision_model = True
|
||||
elif self.finetune_vlm is None:
|
||||
self.finetune_vlm = bool(self.finetune_language_model or self.finetune_vision_model)
|
||||
if self.finetune_action_head is None:
|
||||
self.finetune_action_head = True
|
||||
|
||||
if self.finetune_vlm is None:
|
||||
self.finetune_vlm = False
|
||||
if self.finetune_language_model is None:
|
||||
self.finetune_language_model = False
|
||||
if self.finetune_vision_model is None:
|
||||
self.finetune_vision_model = False
|
||||
if self.finetune_action_head is None:
|
||||
self.finetune_action_head = False
|
||||
|
||||
branch_vlm = self.finetune_language_model or self.finetune_vision_model
|
||||
if self.finetune_vlm != branch_vlm:
|
||||
raise ValueError(
|
||||
"Inconsistent EVO1 finetune config: "
|
||||
f"finetune_vlm={self.finetune_vlm} but "
|
||||
f"(finetune_language_model or finetune_vision_model)={branch_vlm}. "
|
||||
"When branch-level flags are used, finetune_vlm must match their effective union."
|
||||
)
|
||||
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) must be <= chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if self.input_features is None:
|
||||
self.input_features = {}
|
||||
if self.output_features is None:
|
||||
self.output_features = {}
|
||||
|
||||
for i in range(self.empty_cameras):
|
||||
key = OBS_IMAGES + f".empty_camera_{i}"
|
||||
if key not in self.input_features:
|
||||
self.input_features[key] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, *self.image_resolution),
|
||||
)
|
||||
|
||||
if OBS_STATE not in self.input_features:
|
||||
self.input_features[OBS_STATE] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,),
|
||||
)
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
self.output_features[ACTION] = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,),
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return Evo1SchedulerConfig(
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
return [0]
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -1,234 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.policies.evo1.flow_matching import FlowmatchingActionHead
|
||||
from lerobot.policies.evo1.internvl3_embedder import InternVL3Embedder
|
||||
|
||||
|
||||
def _cfgget(config: Any, key: str, default=None):
|
||||
if isinstance(config, dict):
|
||||
return config.get(key, default)
|
||||
return getattr(config, key, default)
|
||||
|
||||
|
||||
class EVO1(nn.Module):
|
||||
def __init__(self, config: dict):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._device = _cfgget(config, "device", "cuda")
|
||||
self.return_cls_only = _cfgget(config, "return_cls_only", False)
|
||||
vlm_name = _cfgget(config, "vlm_name", "OpenGVLab/InternVL3-1B")
|
||||
image_size = _cfgget(config, "image_size", 448)
|
||||
if image_size is None:
|
||||
image_resolution = _cfgget(config, "image_resolution", (448, 448))
|
||||
image_size = int(image_resolution[0])
|
||||
|
||||
self.embedder = InternVL3Embedder(
|
||||
model_name=vlm_name,
|
||||
image_size=image_size,
|
||||
device=self._device,
|
||||
num_language_layers=_cfgget(config, "vlm_num_layers", 14),
|
||||
model_dtype=_cfgget(config, "vlm_dtype", "bfloat16"),
|
||||
use_flash_attn=_cfgget(config, "use_flash_attn", True),
|
||||
enable_gradient_checkpointing=_cfgget(config, "enable_gradient_checkpointing", True),
|
||||
gradient_checkpointing_use_reentrant=_cfgget(
|
||||
config, "gradient_checkpointing_use_reentrant", False
|
||||
),
|
||||
)
|
||||
|
||||
action_head_type = _cfgget(config, "action_head", "flowmatching").lower()
|
||||
if action_head_type != "flowmatching":
|
||||
raise NotImplementedError(f"Unknown action_head: {action_head_type}")
|
||||
|
||||
horizon = _cfgget(config, "action_horizon", _cfgget(config, "horizon", 16))
|
||||
per_action_dim = _cfgget(config, "per_action_dim", 7)
|
||||
action_dim = horizon * per_action_dim
|
||||
|
||||
if isinstance(config, dict):
|
||||
config["horizon"] = horizon
|
||||
config["per_action_dim"] = per_action_dim
|
||||
config["action_dim"] = action_dim
|
||||
|
||||
self.horizon = horizon
|
||||
self.per_action_dim = per_action_dim
|
||||
self.action_head = FlowmatchingActionHead(config=config).to(self._device)
|
||||
|
||||
def _normalize_image_batches(
|
||||
self,
|
||||
images: Sequence[Image.Image | torch.Tensor] | Sequence[Sequence[Image.Image | torch.Tensor]],
|
||||
prompt: str | list[str] | None,
|
||||
image_mask: torch.Tensor,
|
||||
) -> tuple[list[list[Image.Image | torch.Tensor]], list[str], torch.Tensor]:
|
||||
if not images:
|
||||
raise ValueError("EVO1 expects at least one image per sample.")
|
||||
|
||||
first = images[0]
|
||||
if isinstance(first, (Image.Image, torch.Tensor)):
|
||||
image_batches = [list(images)] # type: ignore[arg-type]
|
||||
else:
|
||||
image_batches = [list(sample) for sample in images] # type: ignore[arg-type]
|
||||
|
||||
batch_size = len(image_batches)
|
||||
if prompt is None:
|
||||
prompts = [""] * batch_size
|
||||
elif isinstance(prompt, str):
|
||||
prompts = [prompt] * batch_size
|
||||
else:
|
||||
prompts = [str(p) for p in prompt]
|
||||
if len(prompts) != batch_size:
|
||||
raise ValueError(
|
||||
f"Prompt batch size {len(prompts)} does not match image batch size {batch_size}"
|
||||
)
|
||||
|
||||
if image_mask.dim() == 1:
|
||||
image_mask = image_mask.unsqueeze(0)
|
||||
if image_mask.shape[0] != batch_size:
|
||||
raise ValueError(
|
||||
f"image_mask batch size {image_mask.shape[0]} does not match image batch size {batch_size}"
|
||||
)
|
||||
|
||||
return image_batches, prompts, image_mask
|
||||
|
||||
def get_vl_embeddings(
|
||||
self,
|
||||
images: list[Image.Image | torch.Tensor] | list[list[Image.Image | torch.Tensor]],
|
||||
image_mask: torch.Tensor,
|
||||
prompt: str | list[str] | None = None,
|
||||
return_cls_only: bool | None = None,
|
||||
) -> torch.Tensor:
|
||||
if return_cls_only is None:
|
||||
return_cls_only = self.return_cls_only
|
||||
|
||||
image_batches, prompts, image_mask = self._normalize_image_batches(images, prompt, image_mask)
|
||||
return self.embedder.get_fused_image_text_embedding_from_tensor_images(
|
||||
image_tensors_batch=image_batches,
|
||||
image_masks=image_mask,
|
||||
text_prompts=prompts,
|
||||
return_cls_only=return_cls_only,
|
||||
)
|
||||
|
||||
def prepare_state(self, state_input: list | torch.Tensor) -> torch.Tensor:
|
||||
if isinstance(state_input, list):
|
||||
state_tensor = torch.tensor(state_input)
|
||||
elif isinstance(state_input, torch.Tensor):
|
||||
state_tensor = state_input
|
||||
else:
|
||||
raise TypeError(f"Unsupported state input type: {type(state_input)}")
|
||||
|
||||
if state_tensor.ndim == 1:
|
||||
state_tensor = state_tensor.unsqueeze(0)
|
||||
|
||||
return state_tensor.to(self._device)
|
||||
|
||||
def predict_action(
|
||||
self,
|
||||
fused_tokens: torch.Tensor,
|
||||
state: torch.Tensor,
|
||||
actions_gt: torch.Tensor | None = None,
|
||||
action_mask: torch.Tensor | None = None,
|
||||
embodiment_ids: torch.Tensor | None = None,
|
||||
):
|
||||
if actions_gt is None:
|
||||
return self.action_head.get_action(
|
||||
fused_tokens,
|
||||
state=state,
|
||||
action_mask=action_mask,
|
||||
embodiment_id=embodiment_ids,
|
||||
)
|
||||
return self.action_head(
|
||||
fused_tokens,
|
||||
state=state,
|
||||
actions_gt=actions_gt,
|
||||
action_mask=action_mask,
|
||||
embodiment_id=embodiment_ids,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def run_inference(
|
||||
self,
|
||||
images: list[Image.Image | torch.Tensor],
|
||||
image_mask: torch.Tensor,
|
||||
prompt: str,
|
||||
state_input: list | torch.Tensor,
|
||||
return_cls_only: bool | None = None,
|
||||
action_mask: torch.Tensor | None = None,
|
||||
embodiment_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if image_mask.dim() == 1:
|
||||
image_mask = image_mask.unsqueeze(0)
|
||||
|
||||
fused_tokens = self.get_vl_embeddings(
|
||||
images=[images],
|
||||
image_mask=image_mask,
|
||||
prompt=[prompt],
|
||||
return_cls_only=return_cls_only,
|
||||
)
|
||||
state_tensor = self.prepare_state(state_input)
|
||||
action = self.predict_action(
|
||||
fused_tokens,
|
||||
state_tensor,
|
||||
action_mask=action_mask,
|
||||
embodiment_ids=embodiment_ids,
|
||||
)
|
||||
if isinstance(action, torch.Tensor) and action.dtype == torch.bfloat16:
|
||||
action = action.to(torch.float32)
|
||||
return action
|
||||
|
||||
def forward(
|
||||
self,
|
||||
fused_tokens: torch.Tensor,
|
||||
state: torch.Tensor | None = None,
|
||||
actions_gt: torch.Tensor | None = None,
|
||||
action_mask: torch.Tensor | None = None,
|
||||
embodiment_ids: torch.Tensor | None = None,
|
||||
):
|
||||
return self.predict_action(fused_tokens, state, actions_gt, action_mask, embodiment_ids)
|
||||
|
||||
def _set_module_trainable(self, module: nn.Module, trainable: bool):
|
||||
for param in module.parameters():
|
||||
param.requires_grad = trainable
|
||||
|
||||
def set_finetune_flags(self):
|
||||
finetune_vlm = _cfgget(self.config, "finetune_vlm", False)
|
||||
finetune_language_model = _cfgget(self.config, "finetune_language_model", False)
|
||||
finetune_vision_model = _cfgget(self.config, "finetune_vision_model", False)
|
||||
has_explicit_branch_flags = any(
|
||||
flag is not None for flag in (finetune_language_model, finetune_vision_model)
|
||||
)
|
||||
finetune_language_model = bool(finetune_language_model)
|
||||
finetune_vision_model = bool(finetune_vision_model)
|
||||
finetune_vlm = bool(finetune_vlm)
|
||||
|
||||
if has_explicit_branch_flags:
|
||||
self._set_module_trainable(self.embedder, False)
|
||||
if hasattr(self.embedder.model, "language_model"):
|
||||
self._set_module_trainable(self.embedder.model.language_model, finetune_language_model)
|
||||
if hasattr(self.embedder.model, "vision_model"):
|
||||
self._set_module_trainable(self.embedder.model.vision_model, finetune_vision_model)
|
||||
if hasattr(self.embedder.model, "mlp1"):
|
||||
self._set_module_trainable(self.embedder.model.mlp1, finetune_vision_model)
|
||||
elif not finetune_vlm:
|
||||
self._set_module_trainable(self.embedder, False)
|
||||
|
||||
if not _cfgget(self.config, "finetune_action_head", False):
|
||||
self._set_module_trainable(self.action_head, False)
|
||||
@@ -1,456 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from types import SimpleNamespace
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cfgget(config, key: str, default=None):
|
||||
if isinstance(config, dict):
|
||||
return config.get(key, default)
|
||||
return getattr(config, key, default)
|
||||
|
||||
|
||||
class SinusoidalPositionalEncoding(nn.Module):
|
||||
def __init__(self, dim: int, max_len: int = 1000):
|
||||
super().__init__()
|
||||
pe = torch.zeros(max_len, dim)
|
||||
position = torch.arange(0, max_len).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.0) / dim))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, seq_len: int):
|
||||
if seq_len > self.pe.size(1):
|
||||
self._extend_pe(seq_len)
|
||||
return self.pe[:, :seq_len, :]
|
||||
|
||||
def _extend_pe(self, new_max_len):
|
||||
old_max_len, dim = self.pe.size(1), self.pe.size(2)
|
||||
if new_max_len <= old_max_len:
|
||||
return
|
||||
extra_positions = torch.arange(old_max_len, new_max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
|
||||
extra_pe = torch.zeros(new_max_len - old_max_len, dim)
|
||||
extra_pe[:, 0::2] = torch.sin(extra_positions * div_term)
|
||||
extra_pe[:, 1::2] = torch.cos(extra_positions * div_term)
|
||||
extra_pe = extra_pe.unsqueeze(0)
|
||||
new_pe = torch.cat([self.pe, extra_pe.to(self.pe.device)], dim=1)
|
||||
self.pe = new_pe
|
||||
|
||||
|
||||
class CategorySpecificLinear(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int, num_categories: int = 1):
|
||||
super().__init__()
|
||||
self.num_categories = num_categories
|
||||
if num_categories <= 1:
|
||||
self.linear = nn.Linear(in_dim, out_dim)
|
||||
else:
|
||||
self.weight = nn.Parameter(torch.empty(num_categories, in_dim, out_dim))
|
||||
self.bias = nn.Parameter(torch.zeros(num_categories, out_dim))
|
||||
nn.init.xavier_uniform_(self.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor, category_id: torch.LongTensor):
|
||||
if self.num_categories <= 1:
|
||||
if x.dtype != self.linear.weight.dtype:
|
||||
x = x.to(dtype=self.linear.weight.dtype)
|
||||
return self.linear(x)
|
||||
|
||||
if x.dtype != self.weight.dtype:
|
||||
x = x.to(dtype=self.weight.dtype)
|
||||
|
||||
orig_shape = x.shape
|
||||
x_flat = x.reshape(-1, orig_shape[-1])
|
||||
if category_id.dim() == 0:
|
||||
cid = category_id.item()
|
||||
out = x_flat @ self.weight[cid] + self.bias[cid]
|
||||
else:
|
||||
category_id = category_id.reshape(-1)
|
||||
if category_id.numel() != x_flat.size(0):
|
||||
raise ValueError(
|
||||
f"category_id length {category_id.numel()} does not match flattened batch {x_flat.size(0)}"
|
||||
)
|
||||
weight_selected = self.weight[category_id]
|
||||
bias_selected = self.bias[category_id]
|
||||
out = torch.bmm(x_flat.unsqueeze(1), weight_selected).squeeze(1) + bias_selected
|
||||
out_shape = orig_shape[:-1] + (out.shape[-1],)
|
||||
return out.view(out_shape)
|
||||
|
||||
|
||||
class CategorySpecificMLP(nn.Module):
|
||||
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_categories: int = 1):
|
||||
super().__init__()
|
||||
self.fc1 = CategorySpecificLinear(input_dim, hidden_dim, num_categories)
|
||||
self.fc2 = CategorySpecificLinear(hidden_dim, output_dim, num_categories)
|
||||
self.activation = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x: torch.Tensor, category_id: torch.LongTensor):
|
||||
out = self.activation(self.fc1(x, category_id))
|
||||
out = self.fc2(out, category_id)
|
||||
return out
|
||||
|
||||
|
||||
class MultiEmbodimentActionEncoder(nn.Module):
|
||||
def __init__(
|
||||
self, action_dim: int, embed_dim: int, hidden_dim: int, horizon: int, num_categories: int = 1
|
||||
):
|
||||
super().__init__()
|
||||
self.horizon = horizon
|
||||
self.embed_dim = embed_dim
|
||||
self.num_categories = num_categories
|
||||
|
||||
self.W1 = CategorySpecificLinear(action_dim, hidden_dim, num_categories)
|
||||
self.W2 = CategorySpecificLinear(hidden_dim, hidden_dim, num_categories)
|
||||
self.W3 = CategorySpecificLinear(hidden_dim, embed_dim, num_categories)
|
||||
|
||||
self.pos_encoding = SinusoidalPositionalEncoding(hidden_dim, max_len=horizon)
|
||||
self.activation = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, action_seq: torch.Tensor, category_id: torch.LongTensor):
|
||||
batch_size, horizon, action_dim = action_seq.shape
|
||||
assert self.horizon == horizon, "Action sequence length must match horizon"
|
||||
|
||||
x = action_seq.reshape(batch_size * horizon, action_dim)
|
||||
if category_id.dim() == 0:
|
||||
cat_ids = category_id.expand(horizon * batch_size)
|
||||
else:
|
||||
cat_ids = category_id.unsqueeze(1).expand(batch_size, horizon).reshape(batch_size * horizon)
|
||||
|
||||
out = self.activation(self.W1(x, cat_ids))
|
||||
pos_enc = self.pos_encoding(horizon).to(device=out.device, dtype=out.dtype)
|
||||
out = out.view(batch_size, horizon, -1) + pos_enc
|
||||
out = out.view(batch_size * horizon, -1)
|
||||
out = self.activation(self.W2(out, cat_ids))
|
||||
out = self.W3(out, cat_ids)
|
||||
return out.view(batch_size, horizon, self.embed_dim)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, embed_dim: int, num_heads: int, hidden_dim: int, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
|
||||
self.norm1 = nn.LayerNorm(embed_dim)
|
||||
self.norm2 = nn.LayerNorm(embed_dim)
|
||||
self.ff = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, embed_dim))
|
||||
|
||||
def forward(self, action_tokens: torch.Tensor, context_tokens: torch.Tensor, time_emb: torch.Tensor):
|
||||
x = self.norm1(action_tokens)
|
||||
attn_out, _ = self.attn(x, context_tokens, context_tokens)
|
||||
x = action_tokens + attn_out
|
||||
x2 = self.norm2(x)
|
||||
if time_emb is not None:
|
||||
x2 = x2 + time_emb.unsqueeze(1)
|
||||
ff_out = self.ff(x2)
|
||||
return x + ff_out
|
||||
|
||||
|
||||
class FlowmatchingActionHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config=None,
|
||||
embed_dim: int = 896,
|
||||
hidden_dim: int = 1024,
|
||||
action_dim: int = 16 * 7,
|
||||
horizon: int = 16,
|
||||
per_action_dim: int = 7,
|
||||
num_heads: int = 8,
|
||||
num_layers: int = 8,
|
||||
dropout: float = 0.0,
|
||||
num_inference_timesteps: int = 20,
|
||||
num_categories: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if config is not None:
|
||||
embed_dim = _cfgget(config, "embed_dim", embed_dim)
|
||||
hidden_dim = _cfgget(config, "hidden_dim", hidden_dim)
|
||||
action_dim = _cfgget(config, "action_dim", action_dim)
|
||||
horizon = _cfgget(config, "horizon", horizon)
|
||||
per_action_dim = _cfgget(config, "per_action_dim", per_action_dim)
|
||||
num_heads = _cfgget(config, "num_heads", num_heads)
|
||||
num_layers = _cfgget(config, "num_layers", num_layers)
|
||||
dropout = _cfgget(config, "dropout", dropout)
|
||||
num_inference_timesteps = _cfgget(config, "num_inference_timesteps", num_inference_timesteps)
|
||||
num_categories = _cfgget(config, "num_categories", num_categories)
|
||||
self.config = config
|
||||
else:
|
||||
self.config = SimpleNamespace(
|
||||
embed_dim=embed_dim,
|
||||
hidden_dim=hidden_dim,
|
||||
action_dim=action_dim,
|
||||
horizon=horizon,
|
||||
per_action_dim=per_action_dim,
|
||||
num_heads=num_heads,
|
||||
num_layers=num_layers,
|
||||
dropout=dropout,
|
||||
num_inference_timesteps=num_inference_timesteps,
|
||||
num_categories=num_categories,
|
||||
)
|
||||
|
||||
logger.info("FlowmatchingActionHead num_inference_timesteps=%s", num_inference_timesteps)
|
||||
self.embed_dim = embed_dim
|
||||
self.horizon = horizon
|
||||
self.per_action_dim = _cfgget(self.config, "per_action_dim", per_action_dim)
|
||||
self.action_dim = _cfgget(self.config, "action_dim", action_dim)
|
||||
|
||||
self.time_pos_enc = SinusoidalPositionalEncoding(embed_dim, max_len=1000)
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
hidden_dim=embed_dim * 4,
|
||||
dropout=dropout,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_out = nn.LayerNorm(embed_dim)
|
||||
self.seq_pool_proj = nn.Linear(self.horizon * self.embed_dim, self.embed_dim)
|
||||
self.mlp_head = CategorySpecificMLP(
|
||||
input_dim=embed_dim,
|
||||
hidden_dim=hidden_dim,
|
||||
output_dim=action_dim,
|
||||
num_categories=num_categories,
|
||||
)
|
||||
|
||||
self.state_encoder = None
|
||||
state_dim = _cfgget(self.config, "state_dim")
|
||||
if state_dim is not None:
|
||||
state_hidden = _cfgget(self.config, "state_hidden_dim", embed_dim)
|
||||
self.state_encoder = CategorySpecificMLP(
|
||||
input_dim=state_dim,
|
||||
hidden_dim=state_hidden,
|
||||
output_dim=embed_dim,
|
||||
num_categories=num_categories,
|
||||
)
|
||||
|
||||
if horizon > 1:
|
||||
self.action_encoder = MultiEmbodimentActionEncoder(
|
||||
action_dim=self.per_action_dim,
|
||||
embed_dim=embed_dim,
|
||||
hidden_dim=embed_dim,
|
||||
horizon=horizon,
|
||||
num_categories=num_categories,
|
||||
)
|
||||
self.single_action_proj = None
|
||||
else:
|
||||
self.action_encoder = None
|
||||
self.single_action_proj = nn.Linear(self.per_action_dim, self.embed_dim)
|
||||
|
||||
def _project_actions(self, action_seq: torch.Tensor, embodiment_id: torch.LongTensor) -> torch.Tensor:
|
||||
if self.horizon > 1 and self.action_encoder is not None:
|
||||
return self.action_encoder(action_seq, embodiment_id)
|
||||
if self.single_action_proj is None:
|
||||
raise RuntimeError("single_action_proj is not initialized for horizon <= 1.")
|
||||
return self.single_action_proj(action_seq)
|
||||
|
||||
def _expand_action_mask(
|
||||
self,
|
||||
action_mask: torch.Tensor,
|
||||
batch_size: int,
|
||||
per_action_dim: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
if action_mask is None:
|
||||
raise ValueError("action_mask must be provided for flow matching inference.")
|
||||
|
||||
if action_mask.dim() == 2:
|
||||
expected_last_dim = self.horizon * per_action_dim
|
||||
if action_mask.shape == (batch_size, expected_last_dim):
|
||||
expanded_mask = action_mask.reshape(batch_size, self.horizon, per_action_dim)
|
||||
elif action_mask.shape == (batch_size, per_action_dim):
|
||||
expanded_mask = action_mask.unsqueeze(1).expand(batch_size, self.horizon, per_action_dim)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected action_mask shape {(batch_size, expected_last_dim)} or "
|
||||
f"{(batch_size, per_action_dim)}, got {tuple(action_mask.shape)}"
|
||||
)
|
||||
elif action_mask.dim() == 3:
|
||||
expected_shape = (batch_size, self.horizon, per_action_dim)
|
||||
if tuple(action_mask.shape) != expected_shape:
|
||||
raise ValueError(
|
||||
f"Expected action_mask shape {expected_shape}, got {tuple(action_mask.shape)}"
|
||||
)
|
||||
expanded_mask = action_mask
|
||||
else:
|
||||
raise ValueError(f"Unsupported action_mask rank: {action_mask.dim()}")
|
||||
|
||||
return expanded_mask.to(device=device, dtype=dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
fused_tokens: torch.Tensor,
|
||||
state: torch.Tensor = None,
|
||||
actions_gt: torch.Tensor = None,
|
||||
embodiment_id: torch.LongTensor = None,
|
||||
state_mask: torch.Tensor = None,
|
||||
action_mask: torch.Tensor = None,
|
||||
):
|
||||
if actions_gt is None:
|
||||
return self.get_action(
|
||||
fused_tokens, state=state, embodiment_id=embodiment_id, action_mask=action_mask
|
||||
)
|
||||
|
||||
batch_size = fused_tokens.size(0)
|
||||
device = fused_tokens.device
|
||||
if embodiment_id is None:
|
||||
embodiment_id = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
|
||||
context_tokens = fused_tokens
|
||||
if state is not None and self.state_encoder is not None:
|
||||
state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1)
|
||||
context_tokens = torch.cat([context_tokens, state_emb], dim=1)
|
||||
|
||||
t = (
|
||||
torch.distributions.Beta(2, 2)
|
||||
.sample((batch_size,))
|
||||
.clamp(0.02, 0.98)
|
||||
.to(device)
|
||||
.to(dtype=self.dtype)
|
||||
)
|
||||
time_index = (t * 999).long().clamp_(0, 999)
|
||||
time_emb = self.time_pos_enc(1000)[:, time_index, :].squeeze(0).to(dtype=context_tokens.dtype)
|
||||
|
||||
actions_gt_seq = actions_gt
|
||||
noise = torch.rand_like(actions_gt) * 2 - 1
|
||||
if action_mask is not None:
|
||||
action_mask = action_mask.to(dtype=noise.dtype, device=noise.device)
|
||||
if action_mask.shape != noise.shape:
|
||||
raise ValueError(f"action_mask shape {action_mask.shape} != noise shape {noise.shape}")
|
||||
actions_gt_seq = actions_gt_seq * action_mask
|
||||
noise = noise * action_mask
|
||||
|
||||
if self.horizon > 1:
|
||||
noise_seq = noise.view(batch_size, self.horizon, self.per_action_dim)
|
||||
else:
|
||||
noise_seq = noise if noise.dim() == 3 else noise.unsqueeze(1)
|
||||
t_broadcast = t.view(batch_size, 1, 1)
|
||||
action_intermediate_seq = (1 - t_broadcast) * noise_seq + t_broadcast * actions_gt_seq
|
||||
|
||||
action_tokens = self._project_actions(action_intermediate_seq, embodiment_id)
|
||||
target_dtype = self.dtype
|
||||
action_tokens = action_tokens.to(dtype=target_dtype)
|
||||
context_tokens = context_tokens.to(dtype=target_dtype)
|
||||
time_emb = time_emb.to(dtype=target_dtype)
|
||||
|
||||
x = action_tokens
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context_tokens, time_emb)
|
||||
x = self.norm_out(x)
|
||||
|
||||
if self.horizon > 1:
|
||||
x_flat = x.reshape(batch_size, -1)
|
||||
x_pooled = self.seq_pool_proj(x_flat)
|
||||
else:
|
||||
x_pooled = x.squeeze(1)
|
||||
|
||||
pred_velocity = self.mlp_head(x_pooled, embodiment_id)
|
||||
return pred_velocity, noise
|
||||
|
||||
def get_action(
|
||||
self,
|
||||
fused_tokens: torch.Tensor,
|
||||
state: torch.Tensor = None,
|
||||
embodiment_id: torch.LongTensor = None,
|
||||
action_mask: torch.Tensor = None,
|
||||
):
|
||||
batch_size = fused_tokens.size(0)
|
||||
device = fused_tokens.device
|
||||
if embodiment_id is None:
|
||||
embodiment_id = torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
|
||||
context_tokens = fused_tokens
|
||||
if state is not None and self.state_encoder is not None:
|
||||
state_emb = self.state_encoder(state, embodiment_id).unsqueeze(1)
|
||||
context_tokens = torch.cat([context_tokens, state_emb], dim=1)
|
||||
|
||||
action_dim_total = _cfgget(self.config, "action_dim", self.action_dim)
|
||||
per_action_dim = _cfgget(self.config, "per_action_dim", action_dim_total // max(self.horizon, 1))
|
||||
|
||||
action = torch.rand(batch_size, action_dim_total, device=device, dtype=context_tokens.dtype) * 2 - 1
|
||||
action_seq = (
|
||||
action.view(batch_size, self.horizon, per_action_dim)
|
||||
if self.horizon > 1
|
||||
else action.view(batch_size, 1, per_action_dim)
|
||||
)
|
||||
action_mask = self._expand_action_mask(
|
||||
action_mask,
|
||||
batch_size=batch_size,
|
||||
per_action_dim=per_action_dim,
|
||||
device=action_seq.device,
|
||||
dtype=action_seq.dtype,
|
||||
)
|
||||
action_seq = action_seq * action_mask
|
||||
|
||||
target_dtype = self.dtype
|
||||
context_tokens = context_tokens.to(dtype=target_dtype)
|
||||
|
||||
num_steps = int(_cfgget(self.config, "num_inference_timesteps", 32))
|
||||
if num_steps <= 0:
|
||||
raise ValueError(f"num_inference_timesteps must be positive, got {num_steps}")
|
||||
dt = 1.0 / num_steps
|
||||
|
||||
for i in range(num_steps):
|
||||
t = i / num_steps
|
||||
time_index = min(int(t * 999), 999)
|
||||
time_emb = (
|
||||
self.time_pos_enc(1000)[:, time_index, :].to(device).squeeze(0).to(dtype=context_tokens.dtype)
|
||||
)
|
||||
time_emb = time_emb.unsqueeze(0).repeat(batch_size, 1)
|
||||
|
||||
action_seq = action_seq * action_mask
|
||||
action_tokens = self._project_actions(action_seq, embodiment_id).to(dtype=target_dtype)
|
||||
time_emb = time_emb.to(dtype=target_dtype)
|
||||
|
||||
x = action_tokens
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context_tokens, time_emb)
|
||||
x = self.norm_out(x)
|
||||
|
||||
if self.horizon > 1:
|
||||
x_flat = x.reshape(batch_size, -1)
|
||||
x_pooled = self.seq_pool_proj(x_flat)
|
||||
else:
|
||||
x_pooled = x.squeeze(1)
|
||||
|
||||
pred = self.mlp_head(x_pooled, embodiment_id)
|
||||
action = action + dt * pred
|
||||
action_seq = (
|
||||
action.view(batch_size, self.horizon, per_action_dim)
|
||||
if self.horizon > 1
|
||||
else action.view(batch_size, 1, per_action_dim)
|
||||
)
|
||||
|
||||
action_seq = action_seq * action_mask
|
||||
return action_seq.reshape(batch_size, -1)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
@@ -1,372 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms.functional as TF
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import to_pil_image
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
else:
|
||||
AutoModel = None
|
||||
AutoTokenizer = None
|
||||
|
||||
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_STD = (0.229, 0.224, 0.225)
|
||||
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>" # nosec B105
|
||||
IMG_START_TOKEN = "<img>" # nosec B105
|
||||
IMG_END_TOKEN = "</img>" # nosec B105
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def flash_attn_is_available() -> bool:
|
||||
try:
|
||||
import flash_attn # noqa: F401
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=10000)
|
||||
def get_target_aspect_ratio(orig_width: int, orig_height: int, image_size: int, min_num: int, max_num: int):
|
||||
aspect_ratio = orig_width / orig_height
|
||||
target_ratios = {
|
||||
(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num
|
||||
}
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
best_ratio_diff = float("inf")
|
||||
best_ratio = (1, 1)
|
||||
area = orig_width * orig_height
|
||||
for ratio in target_ratios:
|
||||
target_ar = ratio[0] / ratio[1]
|
||||
diff = abs(aspect_ratio - target_ar)
|
||||
if diff < best_ratio_diff:
|
||||
best_ratio_diff = diff
|
||||
best_ratio = ratio
|
||||
elif diff == best_ratio_diff and area > 0.5 * image_size**2 * ratio[0] * ratio[1]:
|
||||
best_ratio = ratio
|
||||
return best_ratio
|
||||
|
||||
|
||||
def dynamic_preprocess(image, min_num=1, max_num=1, image_size=448, use_thumbnail=False):
|
||||
orig_width, orig_height = image.size
|
||||
ratio_w, ratio_h = get_target_aspect_ratio(orig_width, orig_height, image_size, min_num, max_num)
|
||||
target_width = image_size * ratio_w
|
||||
target_height = image_size * ratio_h
|
||||
blocks = ratio_w * ratio_h
|
||||
resized_img = image.resize((target_width, target_height))
|
||||
processed_images = []
|
||||
for i in range(blocks):
|
||||
box = (
|
||||
(i % (target_width // image_size)) * image_size,
|
||||
(i // (target_width // image_size)) * image_size,
|
||||
((i % (target_width // image_size)) + 1) * image_size,
|
||||
((i // (target_width // image_size)) + 1) * image_size,
|
||||
)
|
||||
processed_images.append(resized_img.crop(box))
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
processed_images.append(image.resize((image_size, image_size)))
|
||||
return processed_images
|
||||
|
||||
|
||||
class InternVL3Embedder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_name="OpenGVLab/InternVL3-1B",
|
||||
image_size=448,
|
||||
device="cuda",
|
||||
num_language_layers: int | None = 14,
|
||||
model_dtype: str | torch.dtype = "bfloat16",
|
||||
use_flash_attn: bool = True,
|
||||
enable_gradient_checkpointing: bool = True,
|
||||
gradient_checkpointing_use_reentrant: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self._requested_device = device
|
||||
self.image_size = image_size
|
||||
self.num_language_layers = num_language_layers
|
||||
self.max_text_length = 1024
|
||||
self.enable_gradient_checkpointing = bool(enable_gradient_checkpointing)
|
||||
self.gradient_checkpointing_use_reentrant = bool(gradient_checkpointing_use_reentrant)
|
||||
|
||||
require_package("transformers", extra="evo1")
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
|
||||
if isinstance(model_dtype, str):
|
||||
try:
|
||||
model_dtype = getattr(torch, model_dtype)
|
||||
except AttributeError as exc:
|
||||
raise ValueError(f"Unsupported EVO1 vlm_dtype '{model_dtype}'") from exc
|
||||
|
||||
resolved_use_flash_attn = bool(use_flash_attn and flash_attn_is_available())
|
||||
if use_flash_attn and not resolved_use_flash_attn:
|
||||
logger.warning("flash_attn is not installed. Falling back to standard attention.")
|
||||
|
||||
self.model = AutoModel.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=model_dtype,
|
||||
trust_remote_code=True,
|
||||
use_flash_attn=resolved_use_flash_attn,
|
||||
low_cpu_mem_usage=True,
|
||||
_fast_init=False,
|
||||
).to(self._requested_device)
|
||||
|
||||
if hasattr(self.model.language_model, "model"):
|
||||
layers = self.model.language_model.model.layers
|
||||
else:
|
||||
layers = self.model.language_model.layers
|
||||
if self.num_language_layers is not None:
|
||||
layers = layers[: self.num_language_layers]
|
||||
|
||||
if hasattr(self.model.language_model, "model"):
|
||||
self.model.language_model.model.layers = torch.nn.ModuleList(layers)
|
||||
else:
|
||||
self.model.language_model.layers = torch.nn.ModuleList(layers)
|
||||
self.model.language_model.lm_head = torch.nn.Identity()
|
||||
|
||||
self._configure_memory_features()
|
||||
self.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
|
||||
|
||||
def _configure_memory_features(self) -> None:
|
||||
checkpoint_kwargs = {"use_reentrant": self.gradient_checkpointing_use_reentrant}
|
||||
|
||||
if not self.enable_gradient_checkpointing:
|
||||
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
|
||||
self.model.vision_model.encoder.gradient_checkpointing = False
|
||||
language_model = getattr(self.model, "language_model", None)
|
||||
if language_model is not None:
|
||||
if hasattr(language_model, "gradient_checkpointing_disable"):
|
||||
language_model.gradient_checkpointing_disable()
|
||||
elif hasattr(language_model, "gradient_checkpointing"):
|
||||
language_model.gradient_checkpointing = False
|
||||
if hasattr(language_model, "model"):
|
||||
inner = language_model.model
|
||||
if hasattr(inner, "gradient_checkpointing_disable"):
|
||||
inner.gradient_checkpointing_disable()
|
||||
elif hasattr(inner, "gradient_checkpointing"):
|
||||
inner.gradient_checkpointing = False
|
||||
return
|
||||
|
||||
def _enable_ckpt(module: nn.Module | None) -> bool:
|
||||
if module is None:
|
||||
return False
|
||||
if hasattr(module, "gradient_checkpointing_enable"):
|
||||
try:
|
||||
module.gradient_checkpointing_enable(gradient_checkpointing_kwargs=checkpoint_kwargs)
|
||||
except TypeError:
|
||||
module.gradient_checkpointing_enable()
|
||||
return True
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = True
|
||||
return True
|
||||
return False
|
||||
|
||||
enabled_any = _enable_ckpt(self.model)
|
||||
|
||||
if hasattr(self.model, "vision_model") and hasattr(self.model.vision_model, "encoder"):
|
||||
self.model.vision_model.encoder.gradient_checkpointing = True
|
||||
enabled_any = True
|
||||
|
||||
language_model = getattr(self.model, "language_model", None)
|
||||
if language_model is not None:
|
||||
enabled_any = _enable_ckpt(language_model) or enabled_any
|
||||
if hasattr(language_model, "model"):
|
||||
enabled_any = _enable_ckpt(language_model.model) or enabled_any
|
||||
if hasattr(language_model, "config"):
|
||||
language_model.config.use_cache = False
|
||||
|
||||
if hasattr(self.model, "config"):
|
||||
self.model.config.use_cache = False
|
||||
if hasattr(self.model, "enable_input_require_grads"):
|
||||
self.model.enable_input_require_grads()
|
||||
|
||||
if enabled_any:
|
||||
logger.info("Gradient checkpointing enabled for InternVL3 embedder.")
|
||||
else:
|
||||
logger.warning(
|
||||
"Requested gradient checkpointing, but model does not expose checkpointing controls."
|
||||
)
|
||||
|
||||
def _preprocess_single_image(self, image: Image.Image | torch.Tensor) -> torch.Tensor:
|
||||
if isinstance(image, torch.Tensor):
|
||||
pil_image = to_pil_image(image.detach().cpu())
|
||||
else:
|
||||
pil_image = image.convert("RGB")
|
||||
tiles = dynamic_preprocess(pil_image, image_size=self.image_size)
|
||||
tile_tensors = torch.stack([TF.to_tensor(tile) for tile in tiles]).to(
|
||||
device=self.device, dtype=torch.bfloat16
|
||||
)
|
||||
mean = torch.tensor(IMAGENET_MEAN, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
|
||||
std = torch.tensor(IMAGENET_STD, device=self.device, dtype=torch.bfloat16).view(1, 3, 1, 1)
|
||||
return (tile_tensors - mean) / std
|
||||
|
||||
def _preprocess_images(
|
||||
self,
|
||||
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
|
||||
) -> tuple[torch.Tensor, list[list[int]]]:
|
||||
pixel_values_list = []
|
||||
batch_num_tiles_list: list[list[int]] = []
|
||||
|
||||
for image_tensors in image_tensors_batch:
|
||||
num_tiles_list: list[int] = []
|
||||
for image in image_tensors:
|
||||
tiles = self._preprocess_single_image(image)
|
||||
pixel_values_list.append(tiles)
|
||||
num_tiles_list.append(int(tiles.shape[0]))
|
||||
batch_num_tiles_list.append(num_tiles_list)
|
||||
|
||||
if pixel_values_list:
|
||||
pixel_values = torch.cat(pixel_values_list, dim=0)
|
||||
else:
|
||||
pixel_values = torch.empty(
|
||||
0, 3, self.image_size, self.image_size, dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
return pixel_values, batch_num_tiles_list
|
||||
|
||||
def _build_multimodal_prompts(
|
||||
self,
|
||||
batch_num_tiles_list: list[list[int]],
|
||||
text_prompts: Sequence[str],
|
||||
) -> list[str]:
|
||||
prompts = []
|
||||
for num_tiles_list, text_prompt in zip(batch_num_tiles_list, text_prompts, strict=True):
|
||||
prompt_segments = []
|
||||
for i, tile_count in enumerate(num_tiles_list):
|
||||
token_count = self.model.num_image_token * tile_count
|
||||
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * token_count + IMG_END_TOKEN
|
||||
prompt_segments.append(f"Image-{i + 1}: {image_tokens}\n")
|
||||
prompts.append("".join(prompt_segments) + text_prompt.strip())
|
||||
return prompts
|
||||
|
||||
def _prepare_and_fuse_embeddings(
|
||||
self,
|
||||
prompts: Sequence[str],
|
||||
vit_embeds: torch.Tensor,
|
||||
image_masks: torch.Tensor,
|
||||
batch_num_tiles_list: list[list[int]],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
untruncated_ids = self.tokenizer(list(prompts), padding=False, truncation=False)["input_ids"]
|
||||
true_sequence_length = max((len(ids) for ids in untruncated_ids), default=0)
|
||||
if true_sequence_length > self.max_text_length:
|
||||
logger.warning(
|
||||
"InternVL3 prompt truncated in batch: max_length=%s actual_max_length=%s",
|
||||
self.max_text_length,
|
||||
true_sequence_length,
|
||||
)
|
||||
|
||||
model_inputs = self.tokenizer(
|
||||
list(prompts),
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.max_text_length,
|
||||
).to(self.device)
|
||||
input_ids = model_inputs["input_ids"]
|
||||
attention_mask = model_inputs["attention_mask"]
|
||||
|
||||
img_token_mask = input_ids == self.img_context_token_id
|
||||
input_embeds = self.model.language_model.get_input_embeddings()(input_ids).clone()
|
||||
|
||||
batch_size, _, channels = input_embeds.shape
|
||||
vit_embeds = vit_embeds.reshape(-1, channels).to(dtype=input_embeds.dtype, device=input_embeds.device)
|
||||
tokens_per_tile = self.model.num_image_token
|
||||
actual_vis_tokens_list = img_token_mask.sum(dim=1).tolist()
|
||||
|
||||
vit_idx = 0
|
||||
for batch_index in range(batch_size):
|
||||
expected_vis_tokens = sum(batch_num_tiles_list[batch_index]) * tokens_per_tile
|
||||
mask_b = img_token_mask[batch_index]
|
||||
actual_vis_tokens = actual_vis_tokens_list[batch_index]
|
||||
|
||||
item_vit_embeds = vit_embeds[vit_idx : vit_idx + expected_vis_tokens]
|
||||
vit_idx += expected_vis_tokens
|
||||
if actual_vis_tokens > 0:
|
||||
if item_vit_embeds.shape[0] < actual_vis_tokens:
|
||||
raise ValueError(
|
||||
f"InternVL3 produced fewer image tokens than expected for sample {batch_index}: "
|
||||
f"got {item_vit_embeds.shape[0]}, need {actual_vis_tokens}"
|
||||
)
|
||||
input_embeds[batch_index, mask_b] = item_vit_embeds[:actual_vis_tokens]
|
||||
|
||||
current_token_idx = 0
|
||||
img_token_locations = torch.where(mask_b)[0]
|
||||
for image_index, num_tiles in enumerate(batch_num_tiles_list[batch_index]):
|
||||
num_tokens_for_image = num_tiles * tokens_per_tile
|
||||
if not bool(image_masks[batch_index, image_index].item()):
|
||||
start_offset = current_token_idx
|
||||
end_offset = min(current_token_idx + num_tokens_for_image, len(img_token_locations))
|
||||
if start_offset < end_offset:
|
||||
idxs = img_token_locations[start_offset:end_offset]
|
||||
attention_mask[batch_index, idxs] = 0
|
||||
current_token_idx += num_tokens_for_image
|
||||
|
||||
return input_embeds, attention_mask
|
||||
|
||||
def get_fused_image_text_embedding_from_tensor_images(
|
||||
self,
|
||||
image_tensors_batch: Sequence[Sequence[Image.Image | torch.Tensor]],
|
||||
image_masks: torch.Tensor,
|
||||
text_prompts: Sequence[str],
|
||||
return_cls_only: bool = True,
|
||||
):
|
||||
pixel_values, batch_num_tiles_list = self._preprocess_images(image_tensors_batch)
|
||||
if pixel_values.shape[0] == 0:
|
||||
logger.warning("InternVL3 received an empty image batch after preprocessing.")
|
||||
hidden_size = getattr(self.model.config, "hidden_size", None)
|
||||
if hidden_size is None and hasattr(self.model.language_model, "config"):
|
||||
hidden_size = getattr(self.model.language_model.config, "hidden_size", None)
|
||||
if hidden_size is None:
|
||||
raise RuntimeError("Unable to infer hidden size for empty InternVL3 batch.")
|
||||
empty = torch.empty(0, hidden_size, device=self.device, dtype=torch.float32)
|
||||
return empty
|
||||
|
||||
prompts = self._build_multimodal_prompts(batch_num_tiles_list, text_prompts)
|
||||
vit_embeds = self.model.extract_feature(pixel_values)
|
||||
inputs_embeds, attention_mask = self._prepare_and_fuse_embeddings(
|
||||
prompts,
|
||||
vit_embeds,
|
||||
image_masks.to(device=self.device),
|
||||
batch_num_tiles_list,
|
||||
)
|
||||
|
||||
outputs = self.model.language_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
return_dict=True,
|
||||
)
|
||||
fused_hidden = outputs.hidden_states[-1].to(torch.float32)
|
||||
return fused_hidden[:, 0, :] if return_cls_only else fused_hidden
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(self.model.parameters()).device
|
||||
@@ -1,426 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
from collections import deque
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
|
||||
from lerobot.policies.evo1.evo1_model import EVO1
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
class EVO1Policy(PreTrainedPolicy):
|
||||
config_class = Evo1Config
|
||||
name = "evo1"
|
||||
|
||||
def __init__(self, config: Evo1Config, **kwargs):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
|
||||
if len(config.image_features) > config.max_views:
|
||||
raise ValueError(
|
||||
f"EVO1 supports at most {config.max_views} camera streams, got {len(config.image_features)}"
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.model = EVO1(self._build_model_config(config))
|
||||
self.model.set_finetune_flags()
|
||||
self.reset()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: PreTrainedConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool | None = None,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
if strict is None:
|
||||
strict = not (config is not None and getattr(config, "training_stage", None) == "stage2")
|
||||
return super().from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
config=config,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
strict=strict,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_model_config(config: Evo1Config) -> dict:
|
||||
return {
|
||||
"device": config.device,
|
||||
"return_cls_only": config.return_cls_only,
|
||||
"vlm_name": config.vlm_model_name,
|
||||
"vlm_num_layers": config.vlm_num_layers,
|
||||
"vlm_dtype": config.vlm_dtype,
|
||||
"use_flash_attn": config.use_flash_attn,
|
||||
"action_head": config.action_head,
|
||||
"action_horizon": config.chunk_size,
|
||||
"per_action_dim": config.max_action_dim,
|
||||
"state_dim": config.max_state_dim,
|
||||
"embed_dim": config.embed_dim,
|
||||
"hidden_dim": config.hidden_dim,
|
||||
"state_hidden_dim": config.state_hidden_dim,
|
||||
"num_heads": config.num_heads,
|
||||
"num_layers": config.num_layers,
|
||||
"dropout": config.dropout,
|
||||
"num_inference_timesteps": config.num_inference_timesteps,
|
||||
"num_categories": config.num_categories,
|
||||
"enable_gradient_checkpointing": config.enable_gradient_checkpointing,
|
||||
"gradient_checkpointing_use_reentrant": config.gradient_checkpointing_use_reentrant,
|
||||
"finetune_vlm": config.finetune_vlm,
|
||||
"finetune_language_model": config.finetune_language_model,
|
||||
"finetune_vision_model": config.finetune_vision_model,
|
||||
"finetune_action_head": config.finetune_action_head,
|
||||
}
|
||||
|
||||
@property
|
||||
def _camera_keys(self) -> list[str]:
|
||||
return list(self.config.image_features)
|
||||
|
||||
@property
|
||||
def _env_action_dim(self) -> int:
|
||||
action_feature = self.config.action_feature
|
||||
if action_feature is None:
|
||||
return self.config.max_action_dim
|
||||
return int(action_feature.shape[0])
|
||||
|
||||
@property
|
||||
def _compute_dtype(self) -> torch.dtype:
|
||||
return next(self.model.action_head.parameters()).dtype
|
||||
|
||||
@property
|
||||
def _training_compute_dtype(self) -> torch.dtype:
|
||||
if str(self.config.device).startswith("cuda"):
|
||||
return torch.bfloat16
|
||||
return self._compute_dtype
|
||||
|
||||
@property
|
||||
def _inference_compute_dtype(self) -> torch.dtype:
|
||||
if str(self.config.device).startswith("cuda") and self.config.use_amp:
|
||||
return torch.bfloat16
|
||||
return self._compute_dtype
|
||||
|
||||
def get_optim_params(self) -> list[dict]:
|
||||
decay, no_decay = [], []
|
||||
for name, param in self.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
is_bias = name.endswith("bias") or ".bias" in name
|
||||
is_norm = param.dim() == 1 or "norm" in name.lower()
|
||||
if is_bias or is_norm:
|
||||
no_decay.append(param)
|
||||
else:
|
||||
decay.append(param)
|
||||
return [
|
||||
{"params": decay, "weight_decay": self.config.optimizer_weight_decay},
|
||||
{"params": no_decay, "weight_decay": 0.0},
|
||||
]
|
||||
|
||||
def reset(self):
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
def _normalize_task_batch(self, batch: dict[str, Tensor | list[str] | str]) -> list[str]:
|
||||
prompts = batch.get(self.config.task_field)
|
||||
if prompts is None and self.config.task_field != "task":
|
||||
prompts = batch.get("task")
|
||||
if prompts is None:
|
||||
raise ValueError(f"EVO1 expects a '{self.config.task_field}' text field in the batch.")
|
||||
if isinstance(prompts, str):
|
||||
return [prompts]
|
||||
if isinstance(prompts, (list, tuple)):
|
||||
return [str(prompt) for prompt in prompts]
|
||||
raise TypeError(f"Unsupported prompt batch type: {type(prompts)}")
|
||||
|
||||
def _prepare_state(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
||||
if OBS_STATE not in batch:
|
||||
raise ValueError(f"EVO1 requires '{OBS_STATE}' in the batch.")
|
||||
state = batch[OBS_STATE]
|
||||
if state.dim() == 1:
|
||||
state = state.unsqueeze(0)
|
||||
elif state.dim() == 3:
|
||||
state = state[:, -1]
|
||||
elif state.dim() != 2:
|
||||
raise ValueError(f"Unsupported state tensor shape for EVO1: {tuple(state.shape)}")
|
||||
batch_size, state_dim = state.shape
|
||||
if state_dim > self.config.max_state_dim:
|
||||
raise ValueError(
|
||||
f"State dim {state_dim} exceeds configured max_state_dim {self.config.max_state_dim}"
|
||||
)
|
||||
explicit_mask = batch.get("state_mask")
|
||||
if explicit_mask is not None:
|
||||
if explicit_mask.dim() == 1:
|
||||
explicit_mask = explicit_mask.unsqueeze(0)
|
||||
elif explicit_mask.dim() == 3:
|
||||
explicit_mask = explicit_mask[:, -1]
|
||||
elif explicit_mask.dim() != 2:
|
||||
raise ValueError(
|
||||
f"Unsupported state_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}"
|
||||
)
|
||||
if explicit_mask.shape != (batch_size, state_dim):
|
||||
raise ValueError(
|
||||
f"state_mask shape {tuple(explicit_mask.shape)} does not match state shape {(batch_size, state_dim)}"
|
||||
)
|
||||
padded = torch.zeros(
|
||||
batch_size,
|
||||
self.config.max_state_dim,
|
||||
dtype=state.dtype,
|
||||
device=self.config.device,
|
||||
)
|
||||
padded[:, :state_dim] = state.to(device=self.config.device)
|
||||
mask = torch.zeros(
|
||||
batch_size,
|
||||
self.config.max_state_dim,
|
||||
dtype=torch.bool,
|
||||
device=self.config.device,
|
||||
)
|
||||
if explicit_mask is None:
|
||||
mask[:, :state_dim] = True
|
||||
else:
|
||||
mask[:, :state_dim] = explicit_mask.to(device=self.config.device, dtype=torch.bool)
|
||||
return padded.to(dtype=self._compute_dtype), mask
|
||||
|
||||
def _prepare_actions(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
||||
if ACTION not in batch:
|
||||
raise ValueError(f"EVO1 requires '{ACTION}' in the batch for training.")
|
||||
action = batch[ACTION]
|
||||
if action.dim() == 2:
|
||||
action = action.unsqueeze(1)
|
||||
batch_size, horizon, action_dim = action.shape
|
||||
if horizon != self.config.chunk_size:
|
||||
raise ValueError(
|
||||
f"EVO1 expects chunk_size={self.config.chunk_size}, got action horizon {horizon}"
|
||||
)
|
||||
if action_dim > self.config.max_action_dim:
|
||||
raise ValueError(
|
||||
f"Action dim {action_dim} exceeds configured max_action_dim {self.config.max_action_dim}"
|
||||
)
|
||||
explicit_mask = batch.get("action_mask")
|
||||
if explicit_mask is not None:
|
||||
if explicit_mask.dim() == 2:
|
||||
if horizon == 1:
|
||||
explicit_mask = explicit_mask.unsqueeze(1)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"2D action_mask is only supported when chunk_size=1, got action horizon {horizon}"
|
||||
)
|
||||
elif explicit_mask.dim() != 3:
|
||||
raise ValueError(
|
||||
f"Unsupported action_mask tensor shape for EVO1: {tuple(explicit_mask.shape)}"
|
||||
)
|
||||
if explicit_mask.shape != (batch_size, horizon, action_dim):
|
||||
raise ValueError(
|
||||
"action_mask shape "
|
||||
f"{tuple(explicit_mask.shape)} does not match action shape {(batch_size, horizon, action_dim)}"
|
||||
)
|
||||
padded = torch.zeros(
|
||||
batch_size,
|
||||
horizon,
|
||||
self.config.max_action_dim,
|
||||
dtype=action.dtype,
|
||||
device=self.config.device,
|
||||
)
|
||||
padded[:, :, :action_dim] = action.to(device=self.config.device)
|
||||
mask = torch.zeros(
|
||||
batch_size,
|
||||
horizon,
|
||||
self.config.max_action_dim,
|
||||
dtype=torch.bool,
|
||||
device=self.config.device,
|
||||
)
|
||||
if explicit_mask is None:
|
||||
mask[:, :, :action_dim] = True
|
||||
else:
|
||||
mask[:, :, :action_dim] = explicit_mask.to(device=self.config.device, dtype=torch.bool)
|
||||
return padded.to(dtype=self._compute_dtype), mask
|
||||
|
||||
def _prepare_inference_action_mask(self, batch_size: int) -> Tensor:
|
||||
mask = torch.zeros(
|
||||
batch_size,
|
||||
self.config.max_action_dim,
|
||||
dtype=torch.bool,
|
||||
device=self.config.device,
|
||||
)
|
||||
mask[:, : self._env_action_dim] = True
|
||||
return mask
|
||||
|
||||
def _get_embodiment_ids(self, batch: dict[str, Tensor], batch_size: int) -> Tensor:
|
||||
embodiment_ids = batch.get("embodiment_id")
|
||||
if embodiment_ids is None and self.config.embodiment_id_field:
|
||||
embodiment_ids = batch.get(self.config.embodiment_id_field)
|
||||
if embodiment_ids is None:
|
||||
return torch.full(
|
||||
(batch_size,),
|
||||
self.config.default_embodiment_id,
|
||||
dtype=torch.long,
|
||||
device=self.config.device,
|
||||
)
|
||||
if embodiment_ids.dim() == 0:
|
||||
embodiment_ids = embodiment_ids.unsqueeze(0)
|
||||
elif embodiment_ids.dim() > 1:
|
||||
embodiment_ids = embodiment_ids[:, -1]
|
||||
return embodiment_ids.to(device=self.config.device, dtype=torch.long)
|
||||
|
||||
def _collect_image_batches(self, batch: dict[str, Tensor]) -> tuple[list[list[Tensor]], Tensor]:
|
||||
camera_keys = self._camera_keys or sorted(key for key in batch if key.startswith(f"{OBS_IMAGES}."))
|
||||
if not camera_keys:
|
||||
raise ValueError("EVO1 requires at least one visual observation feature.")
|
||||
|
||||
# Normalize each camera tensor to (B, C, H, W) up-front so that batch_size is read
|
||||
# from a real batch dim and not from C in the unbatched (C, H, W) case.
|
||||
normalized: dict[str, Tensor] = {}
|
||||
for camera_key in camera_keys[: self.config.max_views]:
|
||||
image = batch[camera_key]
|
||||
if image.dim() == 3:
|
||||
image = image.unsqueeze(0)
|
||||
elif image.dim() == 5:
|
||||
image = image[:, -1]
|
||||
elif image.dim() != 4:
|
||||
raise ValueError(
|
||||
f"Unsupported image tensor shape for EVO1: key={camera_key} shape={tuple(image.shape)}"
|
||||
)
|
||||
normalized[camera_key] = image
|
||||
|
||||
batch_size = normalized[camera_keys[0]].shape[0]
|
||||
image_batches: list[list[Tensor]] = []
|
||||
image_masks = torch.zeros(batch_size, self.config.max_views, dtype=torch.bool)
|
||||
|
||||
for batch_index in range(batch_size):
|
||||
sample_images: list[Tensor] = []
|
||||
for camera_key in camera_keys[: self.config.max_views]:
|
||||
sample_images.append(normalized[camera_key][batch_index].detach().cpu())
|
||||
if not sample_images:
|
||||
raise ValueError("EVO1 received a batch without any image tensor.")
|
||||
while len(sample_images) < self.config.max_views:
|
||||
sample_images.append(torch.zeros_like(sample_images[0]))
|
||||
image_batches.append(sample_images[: self.config.max_views])
|
||||
image_masks[batch_index, : min(len(camera_keys), self.config.max_views)] = True
|
||||
|
||||
return image_batches, image_masks
|
||||
|
||||
def _compute_fused_tokens(
|
||||
self,
|
||||
prompts: list[str],
|
||||
image_batches: list[list[Tensor]],
|
||||
image_masks: Tensor,
|
||||
) -> Tensor:
|
||||
fused_tokens = self.model.get_vl_embeddings(
|
||||
images=image_batches,
|
||||
image_mask=image_masks,
|
||||
prompt=prompts,
|
||||
return_cls_only=self.config.return_cls_only,
|
||||
)
|
||||
return fused_tokens.to(device=self.config.device, dtype=self._compute_dtype)
|
||||
|
||||
def _compute_masked_loss(
|
||||
self,
|
||||
pred_velocity: Tensor,
|
||||
target_velocity: Tensor,
|
||||
action_mask: Tensor,
|
||||
reduction: str,
|
||||
) -> Tensor:
|
||||
flat_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=pred_velocity.dtype)
|
||||
sq_error = ((pred_velocity - target_velocity) * flat_mask).pow(2)
|
||||
active = flat_mask.sum(dim=1).clamp_min(1.0)
|
||||
per_sample_loss = sq_error.sum(dim=1) / active
|
||||
if reduction == "none":
|
||||
return per_sample_loss
|
||||
if reduction != "mean":
|
||||
raise ValueError(f"Unsupported reduction '{reduction}'")
|
||||
return sq_error.sum() / active.sum()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
|
||||
prompts = self._normalize_task_batch(batch)
|
||||
image_batches, image_masks = self._collect_image_batches(batch)
|
||||
states, _state_mask = self._prepare_state(batch)
|
||||
actions_gt, action_mask = self._prepare_actions(batch)
|
||||
fused_tokens = self._compute_fused_tokens(prompts, image_batches, image_masks)
|
||||
states = states.to(dtype=self._training_compute_dtype)
|
||||
actions_gt = actions_gt.to(dtype=self._training_compute_dtype)
|
||||
fused_tokens = fused_tokens.to(dtype=self._training_compute_dtype)
|
||||
embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])
|
||||
|
||||
pred_velocity, noise = self.model(
|
||||
fused_tokens,
|
||||
state=states,
|
||||
actions_gt=actions_gt,
|
||||
action_mask=action_mask.to(device=self.config.device, dtype=self._compute_dtype),
|
||||
embodiment_ids=embodiment_ids,
|
||||
)
|
||||
flat_action_mask = action_mask.view(action_mask.shape[0], -1).to(dtype=actions_gt.dtype)
|
||||
target_velocity = (actions_gt - noise).view(actions_gt.shape[0], -1) * flat_action_mask
|
||||
loss = self._compute_masked_loss(pred_velocity, target_velocity, action_mask, reduction)
|
||||
loss_mean = loss.mean().item() if loss.ndim > 0 else loss.item()
|
||||
return loss, {
|
||||
"loss": loss_mean,
|
||||
"active_action_dims": float(action_mask.sum(dim=(1, 2)).float().mean().item()),
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
prompts = self._normalize_task_batch(batch)
|
||||
image_batches, image_masks = self._collect_image_batches(batch)
|
||||
states, _state_mask = self._prepare_state(batch)
|
||||
fused_tokens = self._compute_fused_tokens(prompts, image_batches, image_masks)
|
||||
states = states.to(dtype=self._inference_compute_dtype)
|
||||
fused_tokens = fused_tokens.to(dtype=self._inference_compute_dtype)
|
||||
embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])
|
||||
action_mask = self._prepare_inference_action_mask(states.shape[0])
|
||||
|
||||
with (
|
||||
torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
if self.config.use_amp and str(self.config.device).startswith("cuda")
|
||||
else nullcontext()
|
||||
):
|
||||
actions = self.model(
|
||||
fused_tokens,
|
||||
state=states,
|
||||
action_mask=action_mask,
|
||||
embodiment_ids=embodiment_ids,
|
||||
)
|
||||
actions = actions.view(states.shape[0], self.config.chunk_size, self.config.max_action_dim)
|
||||
return actions[:, :, : self._env_action_dim]
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
self.eval()
|
||||
if len(self._action_queue) == 0:
|
||||
action_chunk = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
||||
self._action_queue.extend(action_chunk.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
@@ -1,106 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.evo1.configuration_evo1 import Evo1Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
batch_to_transition,
|
||||
create_transition,
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
DONE,
|
||||
INFO,
|
||||
OBS_PREFIX,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
REWARD,
|
||||
TRUNCATED,
|
||||
)
|
||||
|
||||
|
||||
def evo1_batch_to_transition(batch: dict[str, Any]):
|
||||
transition = batch_to_transition(batch)
|
||||
complementary_data = dict(transition.get("complementary_data") or {})
|
||||
reserved = {ACTION, REWARD, DONE, TRUNCATED, INFO}
|
||||
for key, value in batch.items():
|
||||
if key in reserved or key.startswith(OBS_PREFIX):
|
||||
continue
|
||||
complementary_data.setdefault(key, value)
|
||||
return create_transition(
|
||||
observation=transition.get("observation"),
|
||||
action=transition.get("action"),
|
||||
reward=transition.get("reward", 0.0),
|
||||
done=transition.get("done", False),
|
||||
truncated=transition.get("truncated", False),
|
||||
info=transition.get("info", {}),
|
||||
complementary_data=complementary_data,
|
||||
)
|
||||
|
||||
|
||||
def make_evo1_pre_post_processors(
|
||||
config: Evo1Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=evo1_batch_to_transition,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -46,14 +46,14 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||
|
||||
from .act.configuration_act import ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config
|
||||
from .evo1.configuration_evo1 import Evo1Config
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config
|
||||
from .pi05.configuration_pi05 import PI05Config
|
||||
from .pretrained import PreTrainedPolicy
|
||||
from .sac.configuration_sac import SACConfig
|
||||
from .sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from .sarm.configuration_sarm import SARMConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from .utils import validate_visual_features_consistency
|
||||
@@ -89,7 +89,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x", "eo1", "evo1".
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
|
||||
@@ -132,10 +132,18 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .sac.modeling_sac import SACPolicy
|
||||
|
||||
return SACPolicy
|
||||
elif name == "reward_classifier":
|
||||
from .sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
return Classifier
|
||||
elif name == "smolvla":
|
||||
from .smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "sarm":
|
||||
from .sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
return SARMRewardModel
|
||||
elif name == "groot":
|
||||
from .groot.modeling_groot import GrootPolicy
|
||||
|
||||
@@ -148,14 +156,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .wall_x.modeling_wall_x import WallXPolicy
|
||||
|
||||
return WallXPolicy
|
||||
elif name == "eo1":
|
||||
from .eo1.modeling_eo1 import EO1Policy
|
||||
|
||||
return EO1Policy
|
||||
elif name == "evo1":
|
||||
from .evo1.modeling_evo1 import EVO1Policy
|
||||
|
||||
return EVO1Policy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -173,7 +173,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
|
||||
"smolvla", "wall_x", "eo1", "evo1".
|
||||
"smolvla", "reward_classifier", "wall_x".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -200,16 +200,14 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif policy_type == "groot":
|
||||
return GrootConfig(**kwargs)
|
||||
elif policy_type == "xvla":
|
||||
return XVLAConfig(**kwargs)
|
||||
elif policy_type == "wall_x":
|
||||
return WallXConfig(**kwargs)
|
||||
elif policy_type == "eo1":
|
||||
return EO1Config(**kwargs)
|
||||
elif policy_type == "evo1":
|
||||
return Evo1Config(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -380,6 +378,14 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, RewardClassifierConfig):
|
||||
from .sac.reward_model.processor_classifier import make_classifier_processor
|
||||
|
||||
processors = make_classifier_processor(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SmolVLAConfig):
|
||||
from .smolvla.processor_smolvla import make_smolvla_pre_post_processors
|
||||
|
||||
@@ -388,6 +394,14 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SARMConfig):
|
||||
from .sarm.processor_sarm import make_sarm_pre_post_processors
|
||||
|
||||
processors = make_sarm_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
elif isinstance(policy_cfg, GrootConfig):
|
||||
from .groot.processor_groot import make_groot_pre_post_processors
|
||||
|
||||
@@ -413,20 +427,6 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
elif isinstance(policy_cfg, EO1Config):
|
||||
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
processors = make_eo1_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
elif isinstance(policy_cfg, Evo1Config):
|
||||
from .evo1.processor_evo1 import make_evo1_pre_post_processors
|
||||
|
||||
processors = make_evo1_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
@@ -542,7 +542,7 @@ def make_policy(
|
||||
|
||||
logging.info("Loading policy's PEFT adapter.")
|
||||
|
||||
peft_pretrained_path = str(cfg.pretrained_path)
|
||||
peft_pretrained_path = cfg.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
||||
|
||||
kwargs["pretrained_name_or_path"] = peft_config.base_model_name_or_path
|
||||
@@ -555,9 +555,7 @@ def make_policy(
|
||||
)
|
||||
|
||||
policy = policy_cls.from_pretrained(**kwargs)
|
||||
policy = PeftModel.from_pretrained(
|
||||
policy, peft_pretrained_path, config=peft_config, is_trainable=True
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
|
||||
|
||||
else:
|
||||
# Make a fresh policy.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import field
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
@@ -109,6 +109,7 @@ class MultiEmbodimentActionEncoder(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlowmatchingActionHeadConfig(PretrainedConfig):
|
||||
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -173,14 +174,17 @@ N_COLOR_CHANNELS = 3
|
||||
|
||||
|
||||
# config
|
||||
@dataclass
|
||||
class GR00TN15Config(PretrainedConfig):
|
||||
model_type = "gr00t_n1_5"
|
||||
backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."})
|
||||
|
||||
backbone_cfg: dict
|
||||
action_head_cfg: dict
|
||||
action_horizon: int
|
||||
action_dim: int
|
||||
compute_dtype: str = "float32"
|
||||
action_head_cfg: dict = field(init=False, metadata={"help": "Action head configuration."})
|
||||
|
||||
action_horizon: int = field(init=False, metadata={"help": "Action horizon."})
|
||||
|
||||
action_dim: int = field(init=False, metadata={"help": "Action dimension."})
|
||||
compute_dtype: str = field(default="float32", metadata={"help": "Compute dtype."})
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -688,9 +688,8 @@ class DiffusionObjective(nn.Module):
|
||||
loss = F.mse_loss(predicted, target, reduction="none")
|
||||
|
||||
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
|
||||
mask = ~batch["action_is_pad"].unsqueeze(-1)
|
||||
num_valid = mask.sum() * loss.shape[-1]
|
||||
return (loss * mask).sum() / num_valid.clamp_min(1)
|
||||
valid_actions = ~batch["action_is_pad"]
|
||||
loss = loss * valid_actions.unsqueeze(-1)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
@@ -753,9 +752,8 @@ class FlowMatchingObjective(nn.Module):
|
||||
loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")
|
||||
|
||||
if self.do_mask_loss_for_padding and "action_is_pad" in batch:
|
||||
mask = ~batch["action_is_pad"].unsqueeze(-1)
|
||||
num_valid = mask.sum() * loss.shape[-1]
|
||||
return (loss * mask).sum() / num_valid.clamp_min(1)
|
||||
valid_mask = ~batch["action_is_pad"]
|
||||
loss = loss * valid_mask.unsqueeze(-1)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
|
||||
@@ -444,13 +444,13 @@ class PaliGemmaWithExpertModel(
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -666,7 +666,8 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Process language tokens
|
||||
def lang_embed_func(lang_tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
||||
return lang_emb
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
return lang_emb * math.sqrt(lang_emb_dim)
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
|
||||
embs.append(lang_emb)
|
||||
@@ -747,8 +748,16 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
return embs, pad_masks, att_masks, adarms_cond
|
||||
|
||||
def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) -> Tensor:
|
||||
def forward(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
||||
) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss."""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
@@ -1283,11 +1292,8 @@ class PI0Policy(PreTrainedPolicy):
|
||||
state = self.prepare_state(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
# Compute loss
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -728,8 +728,14 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
return embs, pad_masks, att_masks, adarms_cond
|
||||
|
||||
def forward(self, images, img_masks, tokens, masks, actions, noise, time) -> Tensor:
|
||||
def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss."""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
@@ -1256,11 +1262,8 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
# Compute loss (no separate state needed for PI05)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions, noise, time)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||
@@ -226,7 +227,6 @@ class PI0FastPaliGemma(nn.Module):
|
||||
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
|
||||
if use_adarms[0]:
|
||||
text_config = self.paligemma.config.text_config
|
||||
del self.paligemma.model.language_model
|
||||
self.paligemma.model.language_model = PiGemmaModel(text_config)
|
||||
|
||||
self.to_bfloat16_for_selected_params(precision)
|
||||
@@ -260,15 +260,13 @@ class PI0FastPaliGemma(nn.Module):
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output
|
||||
norm = 2048**0.5
|
||||
features = features / norm * norm
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -418,7 +416,8 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Process language instruction tokens
|
||||
def lang_embed_func(tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||
return lang_emb
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
return lang_emb * math.sqrt(lang_emb_dim)
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
||||
embs.append(lang_emb)
|
||||
@@ -432,7 +431,8 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
def fast_action_embed_func(fast_action_tokens):
|
||||
fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens)
|
||||
return fast_emb
|
||||
fast_emb_dim = fast_emb.shape[-1]
|
||||
return fast_emb * math.sqrt(fast_emb_dim)
|
||||
|
||||
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
|
||||
embs.append(fast_action_emb)
|
||||
@@ -665,6 +665,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
if t < max_decoding_steps - 1:
|
||||
# embed the newly generated token
|
||||
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
||||
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
|
||||
if prefix_embs.dtype == torch.bfloat16:
|
||||
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
@@ -769,6 +770,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Embed the single previous token
|
||||
# We use embed_language_tokens directly to avoid overhead of full prefix embedding
|
||||
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
||||
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
|
||||
if prefix_embs.dtype == torch.bfloat16:
|
||||
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
|
||||
@@ -197,9 +197,6 @@ class PiGemmaModel(GemmaModel): # type: ignore[misc]
|
||||
|
||||
def __init__(self, config: GemmaConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
# Free parent-allocated layers/norm before replacing to avoid ~2x peak memory.
|
||||
del self.layers
|
||||
del self.norm
|
||||
# if not getattr(config, "use_adarms", False):
|
||||
# return
|
||||
cond_dim = getattr(config, "adarms_cond_dim", None)
|
||||
@@ -331,7 +328,6 @@ class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc]
|
||||
|
||||
def __init__(self, config: GemmaConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
del self.model
|
||||
self.model = PiGemmaModel(config)
|
||||
|
||||
|
||||
@@ -340,7 +336,6 @@ class PaliGemmaModelWithPiGemma(PaliGemmaModel):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
del self.language_model
|
||||
self.language_model = PiGemmaModel(config.text_config)
|
||||
|
||||
|
||||
@@ -349,7 +344,6 @@ class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGenera
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
del self.model
|
||||
self.model = PaliGemmaModelWithPiGemma(config)
|
||||
|
||||
# Make modules available through conditional class for BC
|
||||
|
||||
@@ -19,7 +19,6 @@ from .action_queue import ActionQueue
|
||||
from .configuration_rtc import RTCConfig
|
||||
from .latency_tracker import LatencyTracker
|
||||
from .modeling_rtc import RTCProcessor
|
||||
from .relative import reanchor_relative_rtc_prefix
|
||||
|
||||
__all__ = [
|
||||
"ActionInterpolator",
|
||||
@@ -27,5 +26,4 @@ __all__ = [
|
||||
"LatencyTracker",
|
||||
"RTCConfig",
|
||||
"RTCProcessor",
|
||||
"reanchor_relative_rtc_prefix",
|
||||
]
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Relative-action helpers for Real-Time Chunking (RTC)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.processor import (
|
||||
NormalizerProcessorStep,
|
||||
RelativeActionsProcessorStep,
|
||||
TransitionKey,
|
||||
create_transition,
|
||||
to_relative_actions,
|
||||
)
|
||||
|
||||
|
||||
def reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute: torch.Tensor,
|
||||
current_state: torch.Tensor,
|
||||
relative_step: RelativeActionsProcessorStep,
|
||||
normalizer_step: NormalizerProcessorStep | None,
|
||||
policy_device: torch.device | str,
|
||||
) -> torch.Tensor:
|
||||
"""Convert absolute leftover actions into model-space for relative-action RTC policies.
|
||||
|
||||
When using relative actions, the RTC prefix (previous chunk's unexecuted tail)
|
||||
is stored in absolute coordinates. Before feeding it back to the policy, this
|
||||
helper re-expresses those actions relative to the robot's current joint state
|
||||
and optionally normalizes them so the policy receives correctly scaled inputs.
|
||||
"""
|
||||
state = current_state.detach().cpu()
|
||||
if state.dim() == 1:
|
||||
state = state.unsqueeze(0)
|
||||
|
||||
action_cpu = prev_actions_absolute.detach().cpu()
|
||||
mask = relative_step._build_mask(action_cpu.shape[-1])
|
||||
relative_actions = to_relative_actions(action_cpu, state, mask)
|
||||
|
||||
transition = create_transition(action=relative_actions)
|
||||
if normalizer_step is not None:
|
||||
transition = normalizer_step(transition)
|
||||
|
||||
return transition[TransitionKey.ACTION].to(policy_device)
|
||||
@@ -1,3 +1,5 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -13,15 +15,14 @@
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs import NormalizationMode
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.configs import NormalizationMode, PreTrainedConfig
|
||||
from lerobot.optim import AdamWConfig, LRSchedulerConfig, OptimizerConfig
|
||||
from lerobot.utils.constants import OBS_IMAGE
|
||||
|
||||
|
||||
@RewardModelConfig.register_subclass(name="reward_classifier")
|
||||
@PreTrainedConfig.register_subclass(name="reward_classifier")
|
||||
@dataclass
|
||||
class RewardClassifierConfig(RewardModelConfig):
|
||||
class RewardClassifierConfig(PreTrainedConfig):
|
||||
"""Configuration for the Reward Classifier model."""
|
||||
|
||||
name: str = "reward_classifier"
|
||||
@@ -1,3 +1,5 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -17,10 +19,11 @@ import logging
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.utils.constants import OBS_IMAGE, REWARD
|
||||
|
||||
from ...pretrained import PreTrainedPolicy
|
||||
from .configuration_classifier import RewardClassifierConfig
|
||||
|
||||
|
||||
class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
@@ -96,7 +99,7 @@ class SpatialLearnedEmbeddings(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
class Classifier(PreTrainedRewardModel):
|
||||
class Classifier(PreTrainedPolicy):
|
||||
"""Image classifier built on top of a pre-trained encoder."""
|
||||
|
||||
name = "reward_classifier"
|
||||
@@ -232,16 +235,6 @@ class Classifier(PreTrainedRewardModel):
|
||||
|
||||
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
|
||||
|
||||
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Returns 1.0 for success, 0.0 for failure based on image observations."""
|
||||
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
||||
output = self.predict(images)
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
return (output.probabilities > 0.5).float()
|
||||
else:
|
||||
return torch.argmax(output.probabilities, dim=1).float()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
"""Standard forward pass for training compatible with train.py."""
|
||||
# Extract images and labels
|
||||
@@ -276,6 +269,10 @@ class Classifier(PreTrainedRewardModel):
|
||||
|
||||
def predict_reward(self, batch, threshold=0.5):
|
||||
"""Eval method. Returns predicted reward with the decision threshold as argument."""
|
||||
# Check for both OBS_IMAGE and OBS_IMAGES prefixes
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Extract images from batch dict
|
||||
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
@@ -285,3 +282,28 @@ class Classifier(PreTrainedRewardModel):
|
||||
return (probs > threshold).float()
|
||||
else:
|
||||
return torch.argmax(self.predict(images).probabilities, dim=1)
|
||||
|
||||
def get_optim_params(self):
|
||||
"""Return optimizer parameters for the policy."""
|
||||
return self.parameters()
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||
The reward classifier is not an actor and does not select actions.
|
||||
"""
|
||||
raise NotImplementedError("Reward classifiers do not select actions")
|
||||
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||
The reward classifier is not an actor and does not produce action chunks.
|
||||
"""
|
||||
raise NotImplementedError("Reward classifiers do not predict action chunks")
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||
The reward classifier is not an actor and does not select actions.
|
||||
"""
|
||||
pass
|
||||
@@ -1,3 +1,5 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -25,7 +27,8 @@ from lerobot.processor import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
from .configuration_classifier import RewardClassifierConfig
|
||||
|
||||
|
||||
def make_classifier_processor(
|
||||
@@ -49,6 +52,8 @@ def make_classifier_processor(
|
||||
Args:
|
||||
config: The configuration object for the RewardClassifier.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
1
src/lerobot/policies/sarm/README.md
Symbolic link
1
src/lerobot/policies/sarm/README.md
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../docs/source/policy_sarm_README.md
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,6 +14,5 @@
|
||||
|
||||
from .configuration_sarm import SARMConfig
|
||||
from .modeling_sarm import SARMRewardModel
|
||||
from .processor_sarm import make_sarm_pre_post_processors
|
||||
|
||||
__all__ = ["SARMConfig", "SARMRewardModel", "make_sarm_pre_post_processors"]
|
||||
__all__ = ["SARMConfig", "SARMRewardModel"]
|
||||
@@ -25,18 +25,18 @@ need ~num_frames/30 queries instead of one per frame (~30x speedup).
|
||||
|
||||
Usage:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4
|
||||
|
||||
# Faster computation with stride (compute every 5 frames, interpolate the rest)
|
||||
python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--stride 5
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
@@ -58,9 +58,10 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
|
||||
from lerobot.rewards.sarm.processor_sarm import make_sarm_pre_post_processors
|
||||
from lerobot.rewards.sarm.sarm_utils import normalize_stage_tau
|
||||
|
||||
from .modeling_sarm import SARMRewardModel
|
||||
from .processor_sarm import make_sarm_pre_post_processors
|
||||
from .sarm_utils import normalize_stage_tau
|
||||
|
||||
|
||||
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
|
||||
@@ -712,12 +713,12 @@ def main():
|
||||
epilog="""
|
||||
Examples:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
@@ -1,3 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
@@ -20,15 +22,14 @@ Paper: https://arxiv.org/abs/2509.25358
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
@RewardModelConfig.register_subclass("sarm")
|
||||
@PreTrainedConfig.register_subclass("sarm")
|
||||
@dataclass
|
||||
class SARMConfig(RewardModelConfig):
|
||||
class SARMConfig(PreTrainedConfig):
|
||||
"""Configuration class for SARM (Stage-Aware Reward Modeling).
|
||||
|
||||
Supports three annotation modes:
|
||||
@@ -109,6 +110,7 @@ class SARMConfig(RewardModelConfig):
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.annotation_mode not in ["single_stage", "dense_only", "dual"]:
|
||||
raise ValueError(
|
||||
f"annotation_mode must be 'single_stage', 'dense_only', or 'dual', got {self.annotation_mode}"
|
||||
@@ -1,3 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
@@ -32,13 +34,14 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.rewards.sarm.sarm_utils import (
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from .configuration_sarm import SARMConfig
|
||||
from .sarm_utils import (
|
||||
normalize_stage_tau,
|
||||
pad_state_to_max_dim,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
|
||||
|
||||
class StageTransformer(nn.Module):
|
||||
@@ -350,7 +353,7 @@ def gen_stage_emb(num_classes: int, targets: torch.Tensor) -> torch.Tensor:
|
||||
return stage_onehot
|
||||
|
||||
|
||||
class SARMRewardModel(PreTrainedRewardModel):
|
||||
class SARMRewardModel(PreTrainedPolicy):
|
||||
"""
|
||||
SARM Reward Model for stage-aware task completion rewards.
|
||||
|
||||
@@ -468,23 +471,6 @@ class SARMRewardModel(PreTrainedRewardModel):
|
||||
self.subtask_model.to(device)
|
||||
return self
|
||||
|
||||
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Compute dense progress reward in [0, 1] from batch.
|
||||
|
||||
Expects batch to contain:
|
||||
- "observation_features" or video embeddings: (B, T, 512)
|
||||
- "language_embedding" or text embeddings: (B, 512)
|
||||
- optionally "observation.state": (B, T, state_dim)
|
||||
"""
|
||||
text_emb = batch.get("language_embedding", batch.get("text_features"))
|
||||
video_emb = batch.get("observation_features", batch.get("video_features"))
|
||||
state = batch.get("observation.state", batch.get("state_features"))
|
||||
|
||||
rewards = self.calculate_rewards(text_emb, video_emb, state)
|
||||
if isinstance(rewards, np.ndarray):
|
||||
rewards = torch.from_numpy(rewards).float()
|
||||
return rewards
|
||||
|
||||
@torch.no_grad()
|
||||
def calculate_rewards(
|
||||
self,
|
||||
@@ -645,9 +631,17 @@ class SARMRewardModel(PreTrainedRewardModel):
|
||||
return self.parameters()
|
||||
|
||||
def reset(self):
|
||||
"""SARM has no episode-level state to reset."""
|
||||
"""Required by PreTrainedPolicy but not used for reward models."""
|
||||
pass
|
||||
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Required by PreTrainedPolicy but not used for reward models."""
|
||||
raise NotImplementedError("SARM model does not predict action chunks")
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Required by PreTrainedPolicy but not used for SARM."""
|
||||
raise NotImplementedError("SARM model does not select actions")
|
||||
|
||||
def _train_step(
|
||||
self,
|
||||
img_emb: torch.Tensor, # (B, N, T, D)
|
||||
@@ -1,3 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -58,15 +60,16 @@ from lerobot.processor import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.rewards.sarm.sarm_utils import (
|
||||
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
from .configuration_sarm import SARMConfig
|
||||
from .sarm_utils import (
|
||||
apply_rewind_augmentation,
|
||||
compute_absolute_indices,
|
||||
find_stage_and_tau,
|
||||
pad_state_to_max_dim,
|
||||
)
|
||||
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
class SARMEncodingProcessorStep(ProcessorStep):
|
||||
@@ -452,13 +455,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
# Get image embeddings
|
||||
# transformers 5.x returns BaseModelOutputWithPooling instead of a plain tensor
|
||||
output = self.clip_model.get_image_features(**inputs)
|
||||
if not isinstance(output, torch.Tensor):
|
||||
output = output.pooler_output
|
||||
if output is None:
|
||||
raise ValueError("pooler_output should not be None for CLIP models.")
|
||||
embeddings = output.detach().cpu()
|
||||
embeddings = self.clip_model.get_image_features(**inputs).detach().cpu()
|
||||
|
||||
# Handle single frame case
|
||||
if embeddings.dim() == 1:
|
||||
@@ -485,13 +482,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
inputs = self.clip_processor.tokenizer([text], return_tensors="pt", padding=True, truncation=True)
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
# transformers 5.x returns BaseModelOutputWithPooling instead of a plain tensor
|
||||
output = self.clip_model.get_text_features(**inputs)
|
||||
if not isinstance(output, torch.Tensor):
|
||||
output = output.pooler_output
|
||||
if output is None:
|
||||
raise ValueError("pooler_output should not be None for CLIP models.")
|
||||
text_embedding = output.detach().cpu()
|
||||
text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
|
||||
text_embedding = text_embedding.expand(batch_size, -1)
|
||||
|
||||
return text_embedding
|
||||
@@ -1,3 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -394,21 +394,13 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
loss_dict["losses_after_rm_padding"] = losses.clone().mean().item()
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over valid (time, action) entries
|
||||
if actions_is_pad is None:
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
else:
|
||||
num_valid = ((~actions_is_pad).sum(dim=1) * losses.shape[-1]).clamp_min(1)
|
||||
per_sample_loss = losses.sum(dim=(1, 2)) / num_valid
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss over valid (time, action) entries
|
||||
if actions_is_pad is None:
|
||||
loss = losses.mean()
|
||||
else:
|
||||
num_valid = ((~actions_is_pad).sum() * losses.shape[-1]).clamp_min(1)
|
||||
loss = losses.sum() / num_valid
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
@@ -97,8 +97,8 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
vision_backbone: str = "resnet18"
|
||||
crop_shape: tuple[int, int] | None = (84, 84)
|
||||
crop_is_random: bool = True
|
||||
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||
use_group_norm: bool = False
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
spatial_softmax_num_keypoints: int = 32
|
||||
# VQ-VAE
|
||||
n_vqvae_training_steps: int = 20000
|
||||
|
||||
@@ -22,7 +22,7 @@ from transformers.utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
@@ -890,7 +890,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0")
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -939,7 +939,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_dtype(query_states.device.type)
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
|
||||
@@ -45,7 +45,7 @@ from transformers.utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@@ -909,7 +909,7 @@ class Florence2FlashAttention2(Florence2Attention):
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0")
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
||||
@@ -985,7 +985,7 @@ class Florence2FlashAttention2(Florence2Attention):
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_dtype(query_states.device.type)
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
|
||||
@@ -557,7 +557,7 @@ class RewardClassifierProcessorStep(ProcessorStep):
|
||||
def __post_init__(self):
|
||||
"""Initializes the reward classifier model after the dataclass is created."""
|
||||
if self.pretrained_path is not None:
|
||||
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
self.reward_classifier = Classifier.from_pretrained(self.pretrained_path)
|
||||
self.reward_classifier.to(self.device)
|
||||
|
||||
@@ -142,10 +142,6 @@ class RelativeActionsProcessorStep(ProcessorStep):
|
||||
new_transition[TransitionKey.ACTION] = to_relative_actions(action, state, mask)
|
||||
return new_transition
|
||||
|
||||
def get_cached_state(self) -> torch.Tensor | None:
|
||||
"""Return the cached ``observation.state`` used as the reference point for relative/absolute action conversions."""
|
||||
return self._last_state
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"enabled": self.enabled,
|
||||
@@ -186,8 +182,7 @@ class AbsoluteActionsProcessorStep(ProcessorStep):
|
||||
"but relative_step is None. Ensure relative_step is set when constructing the postprocessor."
|
||||
)
|
||||
|
||||
cached_state = self.relative_step.get_cached_state()
|
||||
if cached_state is None:
|
||||
if self.relative_step._last_state is None:
|
||||
raise RuntimeError(
|
||||
"AbsoluteActionsProcessorStep requires state from RelativeActionsProcessorStep "
|
||||
"but no state has been cached. Ensure the preprocessor runs before the postprocessor."
|
||||
@@ -199,7 +194,9 @@ class AbsoluteActionsProcessorStep(ProcessorStep):
|
||||
return new_transition
|
||||
|
||||
mask = self.relative_step._build_mask(action.shape[-1])
|
||||
new_transition[TransitionKey.ACTION] = to_absolute_actions(action, cached_state, mask)
|
||||
new_transition[TransitionKey.ACTION] = to_absolute_actions(
|
||||
action, self.relative_step._last_state, mask
|
||||
)
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||
from .factory import (
|
||||
get_reward_model_class as get_reward_model_class,
|
||||
make_reward_model as make_reward_model,
|
||||
make_reward_model_config as make_reward_model_config,
|
||||
make_reward_pre_post_processors as make_reward_pre_post_processors,
|
||||
)
|
||||
from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"RewardClassifierConfig",
|
||||
"SARMConfig",
|
||||
# Base class
|
||||
"PreTrainedRewardModel",
|
||||
# Factory functions
|
||||
"get_reward_model_class",
|
||||
"make_reward_model",
|
||||
"make_reward_model_config",
|
||||
"make_reward_pre_post_processors",
|
||||
]
|
||||
@@ -1,238 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
|
||||
|
||||
|
||||
def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
"""
|
||||
Retrieves a reward model class by its registered name.
|
||||
|
||||
This function uses dynamic imports to avoid loading all reward model classes into
|
||||
memory at once, improving startup time and reducing dependencies.
|
||||
|
||||
Args:
|
||||
name: The name of the reward model. Supported names are "reward_classifier",
|
||||
"sarm".
|
||||
|
||||
Returns:
|
||||
The reward model class corresponding to the given name.
|
||||
|
||||
Raises:
|
||||
ValueError: If the reward model name is not recognized.
|
||||
"""
|
||||
if name == "reward_classifier":
|
||||
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||
|
||||
return Classifier
|
||||
elif name == "sarm":
|
||||
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
return SARMRewardModel
|
||||
else:
|
||||
try:
|
||||
return _get_reward_model_cls_from_name(name=name)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Reward model type '{name}' is not available.") from e
|
||||
|
||||
|
||||
def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
"""
|
||||
Instantiates a reward model configuration object based on the reward type.
|
||||
|
||||
This factory function simplifies the creation of reward model configuration objects
|
||||
by mapping a string identifier to the corresponding config class.
|
||||
|
||||
Args:
|
||||
reward_type: The type of the reward model. Supported types include
|
||||
"reward_classifier", "sarm".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
An instance of a `RewardModelConfig` subclass.
|
||||
|
||||
Raises:
|
||||
ValueError: If the `reward_type` is not recognized.
|
||||
"""
|
||||
if reward_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif reward_type == "sarm":
|
||||
return SARMConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = RewardModelConfig.get_choice_class(reward_type)
|
||||
return config_cls(**kwargs)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Reward model type '{reward_type}' is not available.") from e
|
||||
|
||||
|
||||
def make_reward_model(cfg: RewardModelConfig, **kwargs) -> PreTrainedRewardModel:
|
||||
"""
|
||||
Instantiate a reward model from its configuration.
|
||||
|
||||
Args:
|
||||
cfg: The configuration for the reward model to be created. If
|
||||
`cfg.pretrained_path` is set, the model will be loaded with weights
|
||||
from that path.
|
||||
**kwargs: Additional keyword arguments forwarded to the model constructor
|
||||
(e.g., ``dataset_stats``, ``dataset_meta``).
|
||||
|
||||
Returns:
|
||||
An instantiated and device-placed reward model.
|
||||
"""
|
||||
reward_cls = get_reward_model_class(cfg.type)
|
||||
|
||||
kwargs["config"] = cfg
|
||||
|
||||
if cfg.pretrained_path:
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
reward_model = reward_cls.from_pretrained(**kwargs)
|
||||
else:
|
||||
reward_model = reward_cls(**kwargs)
|
||||
|
||||
reward_model.to(cfg.device)
|
||||
assert isinstance(reward_model, torch.nn.Module)
|
||||
|
||||
return reward_model
|
||||
|
||||
|
||||
def make_reward_pre_post_processors(
|
||||
reward_cfg: RewardModelConfig,
|
||||
**kwargs,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Create pre- and post-processor pipelines for a given reward model.
|
||||
|
||||
Each reward model type has a dedicated factory function for its processors.
|
||||
|
||||
Args:
|
||||
reward_cfg: The configuration of the reward model for which to create processors.
|
||||
**kwargs: Additional keyword arguments passed to the processor factory
|
||||
(e.g., ``dataset_stats``, ``dataset_meta``).
|
||||
|
||||
Returns:
|
||||
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
|
||||
|
||||
Raises:
|
||||
ValueError: If a processor factory is not implemented for the given reward
|
||||
model configuration type.
|
||||
"""
|
||||
# Create a new processor based on reward model type
|
||||
if isinstance(reward_cfg, RewardClassifierConfig):
|
||||
from lerobot.rewards.classifier.processor_classifier import make_classifier_processor
|
||||
|
||||
return make_classifier_processor(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(reward_cfg, SARMConfig):
|
||||
from lerobot.rewards.sarm.processor_sarm import make_sarm_pre_post_processors
|
||||
|
||||
return make_sarm_pre_post_processors(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_reward_model_config(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Processor for reward model type '{reward_cfg.type}' is not implemented."
|
||||
) from e
|
||||
return processors
|
||||
|
||||
|
||||
def _get_reward_model_cls_from_name(name: str) -> type[PreTrainedRewardModel]:
|
||||
"""Get reward model class from its registered name using dynamic imports.
|
||||
|
||||
This is used as a helper function to import reward models from 3rd party lerobot
|
||||
plugins.
|
||||
|
||||
Args:
|
||||
name: The name of the reward model.
|
||||
|
||||
Returns:
|
||||
The reward model class corresponding to the given name.
|
||||
"""
|
||||
if name not in RewardModelConfig.get_known_choices():
|
||||
raise ValueError(
|
||||
f"Unknown reward model name '{name}'. "
|
||||
f"Available reward models: {RewardModelConfig.get_known_choices()}"
|
||||
)
|
||||
|
||||
config_cls = RewardModelConfig.get_choice_class(name)
|
||||
config_cls_name = config_cls.__name__
|
||||
|
||||
model_name = config_cls_name.removesuffix("Config")
|
||||
if model_name == config_cls_name:
|
||||
raise ValueError(
|
||||
f"The config class name '{config_cls_name}' does not follow the expected naming convention. "
|
||||
f"Make sure it ends with 'Config'!"
|
||||
)
|
||||
|
||||
cls_name = model_name + "RewardModel"
|
||||
module_path = config_cls.__module__.replace("configuration_", "modeling_")
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
reward_cls = getattr(module, cls_name)
|
||||
return reward_cls
|
||||
|
||||
|
||||
def _make_processors_from_reward_model_config(
|
||||
config: RewardModelConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[Any, Any]:
|
||||
"""Create pre- and post-processors from a reward model configuration using dynamic imports.
|
||||
|
||||
This is used as a helper function to import processor factories from 3rd party
|
||||
lerobot reward model plugins.
|
||||
|
||||
Args:
|
||||
config: The reward model configuration object.
|
||||
dataset_stats: Dataset statistics for normalization.
|
||||
|
||||
Returns:
|
||||
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
|
||||
"""
|
||||
reward_type = config.type
|
||||
function_name = f"make_{reward_type}_pre_post_processors"
|
||||
module_path = config.__class__.__module__.replace("configuration_", "processor_")
|
||||
logging.debug(
|
||||
f"Instantiating reward pre/post processors using function '{function_name}' "
|
||||
f"from module '{module_path}'"
|
||||
)
|
||||
module = importlib.import_module(module_path)
|
||||
function = getattr(module, function_name)
|
||||
return function(config, dataset_stats=dataset_stats)
|
||||
@@ -1,244 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import builtins
|
||||
import logging
|
||||
import os
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
import packaging
|
||||
import safetensors
|
||||
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
T = TypeVar("T", bound="PreTrainedRewardModel")
|
||||
|
||||
|
||||
class PreTrainedRewardModel(nn.Module, HubMixin, abc.ABC):
|
||||
"""Base class for reward models."""
|
||||
|
||||
config_class: None
|
||||
name: None
|
||||
|
||||
def __init__(self, config: RewardModelConfig, *inputs, **kwargs):
|
||||
super().__init__()
|
||||
if not isinstance(config, RewardModelConfig):
|
||||
raise ValueError(
|
||||
f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
|
||||
"`RewardModelConfig`. To create a model from a pretrained model use "
|
||||
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
||||
)
|
||||
self.config = config
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if not getattr(cls, "config_class", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'config_class'")
|
||||
if not getattr(cls, "name", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'name'")
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
self.config._save_pretrained(save_directory)
|
||||
model_to_save = self.module if hasattr(self, "module") else self
|
||||
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: RewardModelConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool = False,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""
|
||||
The reward model is set in evaluation mode by default using `reward.eval()` (dropout modules are
|
||||
deactivated). To train it, you should first set it back in training mode with `reward.train()`.
|
||||
"""
|
||||
if config is None:
|
||||
config = RewardModelConfig.from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
model_id = str(pretrained_name_or_path)
|
||||
instance = cls(config, **kwargs)
|
||||
if os.path.isdir(model_id):
|
||||
print("Loading weights from local directory")
|
||||
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
|
||||
reward = cls._load_as_safetensor(instance, model_file, config.device or "cpu", strict)
|
||||
else:
|
||||
try:
|
||||
model_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=SAFETENSORS_SINGLE_FILE,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
reward = cls._load_as_safetensor(instance, model_file, config.device or "cpu", strict)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
reward.to(config.device)
|
||||
reward.eval()
|
||||
return reward
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||
# Create base kwargs
|
||||
kwargs = {"strict": strict}
|
||||
|
||||
# Add device parameter for newer versions that support it
|
||||
if packaging.version.parse(safetensors.__version__) >= packaging.version.parse("0.4.3"):
|
||||
kwargs["device"] = map_location
|
||||
|
||||
# Load the model with appropriate kwargs
|
||||
missing_keys, unexpected_keys = load_model_as_safetensor(model, model_file, **kwargs)
|
||||
if missing_keys:
|
||||
logging.warning(f"Missing key(s) when loading model: {missing_keys}")
|
||||
if unexpected_keys:
|
||||
logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")
|
||||
|
||||
# For older versions, manually move to device if needed
|
||||
if "device" not in kwargs and map_location != "cpu":
|
||||
logging.warning(
|
||||
"Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
|
||||
" This means that the model is loaded on 'cpu' first and then copied to the device."
|
||||
" This leads to a slower loading time."
|
||||
" Please update safetensors to version 0.4.3 or above for improved performance."
|
||||
)
|
||||
model.to(map_location)
|
||||
return model
|
||||
|
||||
def get_optim_params(self):
|
||||
"""
|
||||
Returns the reward-model-specific parameters dict to be passed on to the optimizer.
|
||||
"""
|
||||
return self.parameters()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset any internal state."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Compute a scalar reward signal for a batch of observations.
|
||||
|
||||
Args:
|
||||
batch: Dictionary containing at minimum observation tensors.
|
||||
May also contain "action", "next_observation.*", etc.
|
||||
|
||||
Returns:
|
||||
Tensor of shape ``(batch_size,)`` with reward values.
|
||||
"""
|
||||
...
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
|
||||
"""Training forward pass — override for trainable reward models."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} is not trainable. Only use compute_reward() for inference."
|
||||
)
|
||||
|
||||
@property
|
||||
def is_trainable(self) -> bool:
|
||||
"""Whether this reward model can be trained via ``lerobot-train``.
|
||||
|
||||
Trainable reward models override :meth:`forward`; zero-shot models
|
||||
inherit the base implementation that raises ``NotImplementedError``.
|
||||
"""
|
||||
return type(self).forward is not PreTrainedRewardModel.forward
|
||||
|
||||
def push_model_to_hub(self, cfg: "TrainPipelineConfig"):
|
||||
api = HfApi()
|
||||
repo_id = api.create_repo(
|
||||
repo_id=self.config.repo_id, private=self.config.private, exist_ok=True
|
||||
).repo_id
|
||||
|
||||
# Push the files to the repo in a single commit
|
||||
with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
|
||||
saved_path = Path(tmp) / repo_id
|
||||
|
||||
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
|
||||
|
||||
card = self.generate_model_card(
|
||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
|
||||
)
|
||||
card.save(str(saved_path / "README.md"))
|
||||
|
||||
cfg.save_pretrained(saved_path) # Calls _save_pretrained and stores train config
|
||||
|
||||
commit_info = api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
folder_path=saved_path,
|
||||
commit_message="Upload reward model weights, train config and readme",
|
||||
allow_patterns=["*.safetensors", "*.json", "*.yaml", "*.md"],
|
||||
ignore_patterns=["*.tmp", "*.log"],
|
||||
)
|
||||
|
||||
logging.info(f"Model pushed to {commit_info.repo_url.url}")
|
||||
|
||||
def generate_model_card(
|
||||
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
|
||||
) -> ModelCard:
|
||||
card_data = ModelCardData(
|
||||
license=license or "apache-2.0",
|
||||
library_name="lerobot",
|
||||
pipeline_tag="robotics",
|
||||
tags=list(set(tags or []).union({"robotics", "lerobot", "reward-model", model_type})),
|
||||
model_name=model_type,
|
||||
datasets=dataset_repo_id,
|
||||
)
|
||||
|
||||
template_card = (
|
||||
files("lerobot.templates")
|
||||
.joinpath("lerobot_rewardmodel_modelcard_template.md")
|
||||
.read_text(encoding="utf-8")
|
||||
)
|
||||
card = ModelCard.from_template(card_data, template_str=template_card)
|
||||
card.validate()
|
||||
return card
|
||||
@@ -193,15 +193,15 @@ def convert_lerobot_dataset_to_cropped_lerobot_dataset(
|
||||
fps=int(original_dataset.fps),
|
||||
root=new_dataset_root,
|
||||
robot_type=original_dataset.meta.robot_type,
|
||||
features=original_dataset.meta.info.features,
|
||||
features=original_dataset.meta.info["features"],
|
||||
use_videos=len(original_dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
# Update the metadata for every image key that will be cropped:
|
||||
# (Here we simply set the shape to be the final resize_size.)
|
||||
for key in crop_params_dict:
|
||||
if key in new_dataset.meta.info.features:
|
||||
new_dataset.meta.info.features[key]["shape"] = (3, *resize_size)
|
||||
if key in new_dataset.meta.info["features"]:
|
||||
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
|
||||
|
||||
# TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset
|
||||
prev_episode_index = 0
|
||||
|
||||
@@ -54,7 +54,6 @@ class BiOpenArmFollower(Robot):
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.left_arm_config.port,
|
||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
|
||||
max_relative_target=config.left_arm_config.max_relative_target,
|
||||
cameras=left_cameras,
|
||||
side=config.left_arm_config.side,
|
||||
@@ -73,7 +72,6 @@ class BiOpenArmFollower(Robot):
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.right_arm_config.port,
|
||||
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
|
||||
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
|
||||
max_relative_target=config.right_arm_config.max_relative_target,
|
||||
cameras=right_cameras,
|
||||
side=config.right_arm_config.side,
|
||||
|
||||
@@ -46,7 +46,7 @@ class LeKiwiConfig(RobotConfig):
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
|
||||
|
||||
# Set to `True` for backward compatibility with previous policies/dataset
|
||||
use_degrees: bool = True
|
||||
use_degrees: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -66,10 +66,6 @@ class OpenArmFollowerConfigBase:
|
||||
# Whether to disable torque when disconnecting
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# When True, expose `.vel` and `.torque` per motor in observation features.
|
||||
# Default False for compatibility with the position-only openarm_mini teleoperator.
|
||||
use_velocity_and_torque: bool = False
|
||||
|
||||
# Safety limit for relative target positions
|
||||
# Set to a positive scalar for all motors, or a dict mapping motor names to limits
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
|
||||
@@ -93,9 +93,8 @@ class OpenArmFollower(Robot):
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus.motors:
|
||||
features[f"{motor}.pos"] = float
|
||||
if self.config.use_velocity_and_torque:
|
||||
features[f"{motor}.vel"] = float
|
||||
features[f"{motor}.torque"] = float
|
||||
features[f"{motor}.vel"] = float # Add this
|
||||
features[f"{motor}.torque"] = float # Add this
|
||||
return features
|
||||
|
||||
@property
|
||||
@@ -236,9 +235,8 @@ class OpenArmFollower(Robot):
|
||||
for motor in self.bus.motors:
|
||||
state = states.get(motor, {})
|
||||
obs_dict[f"{motor}.pos"] = state.get("position", 0.0)
|
||||
if self.config.use_velocity_and_torque:
|
||||
obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0)
|
||||
obs_dict[f"{motor}.torque"] = state.get("torque", 0.0)
|
||||
obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0)
|
||||
obs_dict[f"{motor}.torque"] = state.get("torque", 0.0)
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
|
||||
@@ -75,7 +75,7 @@ class SentryStrategyConfig(RolloutStrategyConfig):
|
||||
# Target video file size in MB for episode rotation. Episodes are
|
||||
# saved once the estimated video duration would exceed this limit.
|
||||
# Defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB when set to None.
|
||||
target_video_file_size_mb: int | None = None
|
||||
target_video_file_size_mb: float | None = None
|
||||
|
||||
|
||||
@RolloutStrategyConfig.register_subclass("highlight")
|
||||
@@ -90,7 +90,7 @@ class HighlightStrategyConfig(RolloutStrategyConfig):
|
||||
"""
|
||||
|
||||
ring_buffer_seconds: float = 10.0
|
||||
ring_buffer_max_memory_mb: int = 1024
|
||||
ring_buffer_max_memory_mb: float = 1024.0
|
||||
save_key: str = "s"
|
||||
push_key: str = "h"
|
||||
|
||||
@@ -150,7 +150,7 @@ class DAggerStrategyConfig(RolloutStrategyConfig):
|
||||
upload_every_n_episodes: int = 5
|
||||
# Target video file size in MB for episode rotation (record_autonomous
|
||||
# mode only). Defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB when None.
|
||||
target_video_file_size_mb: int | None = None
|
||||
target_video_file_size_mb: float | None = None
|
||||
input_device: str = "keyboard"
|
||||
keyboard: DAggerKeyboardConfig = field(default_factory=DAggerKeyboardConfig)
|
||||
pedal: DAggerPedalConfig = field(default_factory=DAggerPedalConfig)
|
||||
@@ -209,12 +209,6 @@ class RolloutConfig:
|
||||
# Rename map for mapping robot/dataset observation keys to policy keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
# Hardware teardown
|
||||
# When True (default), smoothly interpolate the robot back to the joint
|
||||
# positions captured at startup before disconnecting. Set to False to
|
||||
# leave the robot in its final achieved pose at shutdown.
|
||||
return_to_initial_position: bool = True
|
||||
|
||||
# Torch compile
|
||||
use_torch_compile: bool = False
|
||||
torch_compile_backend: str = "inductor"
|
||||
|
||||
@@ -27,7 +27,7 @@ from threading import Event
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType
|
||||
from lerobot.configs import FeatureType, PreTrainedConfig
|
||||
from lerobot.datasets import (
|
||||
LeRobotDataset,
|
||||
aggregate_pipeline_dataset_features,
|
||||
@@ -43,7 +43,6 @@ from lerobot.processor import (
|
||||
make_default_processors,
|
||||
rename_stats,
|
||||
)
|
||||
from lerobot.processor.relative_action_processor import RelativeActionsProcessorStep
|
||||
from lerobot.robots import make_robot_from_config
|
||||
from lerobot.teleoperators import Teleoperator, make_teleoperator_from_config
|
||||
from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_features
|
||||
@@ -52,7 +51,6 @@ from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig
|
||||
from .inference import (
|
||||
InferenceEngine,
|
||||
RTCInferenceConfig,
|
||||
SyncInferenceConfig,
|
||||
create_inference_engine,
|
||||
)
|
||||
from .robot_wrapper import ThreadSafeRobot
|
||||
@@ -178,26 +176,33 @@ def build_rollout_context(
|
||||
policy_config = cfg.policy
|
||||
policy_class = get_policy_class(policy_config.type)
|
||||
|
||||
if hasattr(policy_config, "compile_model"):
|
||||
policy_config.compile_model = cfg.use_torch_compile
|
||||
full_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
for attr in ("device", "use_amp"):
|
||||
if hasattr(cfg.policy, attr) and hasattr(full_config, attr):
|
||||
cli_val = getattr(cfg.policy, attr)
|
||||
if cli_val is not None:
|
||||
setattr(full_config, attr, cli_val)
|
||||
|
||||
if policy_config.type == "vqbet" and cfg.device == "mps":
|
||||
if hasattr(full_config, "compile_model"):
|
||||
full_config.compile_model = cfg.use_torch_compile
|
||||
|
||||
if full_config.type == "vqbet" and cfg.device == "mps":
|
||||
raise NotImplementedError(
|
||||
"Current implementation of VQBeT does not support `mps` backend. "
|
||||
"Please use `cpu` or `cuda` backend."
|
||||
)
|
||||
|
||||
if policy_config.use_peft:
|
||||
if full_config.use_peft:
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
peft_path = policy_config.pretrained_path
|
||||
peft_path = cfg.policy.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_path)
|
||||
policy = policy_class.from_pretrained(
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=policy_config
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=full_config
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
|
||||
else:
|
||||
policy = policy_class.from_pretrained(policy_config.pretrained_path, config=policy_config)
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=full_config)
|
||||
|
||||
if is_rtc:
|
||||
policy.config.rtc_config = cfg.inference.rtc
|
||||
@@ -252,12 +257,10 @@ def build_rollout_context(
|
||||
teleop.connect()
|
||||
logger.info("Teleoperator connected")
|
||||
|
||||
# TODO(Steven): once Teleoperator motor-control methods are standardised
|
||||
# (``enable_torque`` / ``disable_torque`` / ``write_goal_positions``), gate
|
||||
# the DAgger strategy on their presence here and fail fast with a helpful
|
||||
# message instead of relying on the operator to pre-align the leader by
|
||||
# hand. See :func:`DAggerStrategy._apply_transition` for the matching
|
||||
# disabled call sites.
|
||||
# DAgger requires teleop with motor control capabilities (enable_torque,
|
||||
# disable_torque, write_goal_positions).
|
||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||
# if isinstance(cfg.strategy, DAggerStrategyConfig) and teleop is not None:
|
||||
# required_teleop_methods = ("enable_torque", "disable_torque", "write_goal_positions")
|
||||
# missing = [m for m in required_teleop_methods if not callable(getattr(teleop, m, None))]
|
||||
@@ -269,13 +272,10 @@ def build_rollout_context(
|
||||
# )
|
||||
|
||||
# --- 4. Features + action-key reconciliation ---------------------
|
||||
# TODO(Steven):Only ``.pos`` joint features are routed to the policy as state and as the
|
||||
# action target; velocity and torque channels (when present) are kept in
|
||||
# the raw observation but excluded from the policy-facing tensors.
|
||||
# TODO(Steven): Only `.pos` joint features are used for policy inference — velocity and
|
||||
# torque channels are observation-only and must be excluded from the state
|
||||
# and action tensors that the policy sees.
|
||||
all_obs_features = robot.observation_features
|
||||
# ``observation_features`` values are either a tuple (camera shape) or the
|
||||
# ``float`` type itself used as a sentinel for scalar motor features —
|
||||
# see ``dict[str, type | tuple]`` annotation on ``Robot.observation_features``.
|
||||
observation_features_hw = {
|
||||
k: v
|
||||
for k, v in all_obs_features.items()
|
||||
@@ -308,9 +308,7 @@ def build_rollout_context(
|
||||
# Validate visual features if no rename_map is active
|
||||
rename_map = cfg.rename_map
|
||||
if not rename_map:
|
||||
expected_visuals = {
|
||||
k for k, v in policy_config.input_features.items() if v.type == FeatureType.VISUAL
|
||||
}
|
||||
expected_visuals = {k for k, v in full_config.input_features.items() if v.type == FeatureType.VISUAL}
|
||||
provided_visuals = {
|
||||
f"observation.images.{k}" for k, v in robot.observation_features.items() if isinstance(v, tuple)
|
||||
}
|
||||
@@ -355,7 +353,6 @@ def build_rollout_context(
|
||||
"Use --dataset.repo_id=<user>/rollout_<name> for policy deployment datasets."
|
||||
)
|
||||
cfg.dataset.stamp_repo_id()
|
||||
target_video_mb = getattr(cfg.strategy, "target_video_file_size_mb", None)
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
@@ -371,7 +368,6 @@ def build_rollout_context(
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
video_files_size_in_mb=target_video_mb,
|
||||
)
|
||||
|
||||
if dataset is not None:
|
||||
@@ -395,15 +391,6 @@ def build_rollout_context(
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(cfg.inference, SyncInferenceConfig) and any(
|
||||
isinstance(step, RelativeActionsProcessorStep) and step.enabled
|
||||
for step in getattr(preprocessor, "steps", ())
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"SyncInferenceEngine does not support policies with relative actions for now."
|
||||
"Use --inference.type=rtc or remove relative action processor steps from the policy pipeline."
|
||||
)
|
||||
|
||||
# --- 7. Inference strategy (needs policy + pre/post + hardware) --
|
||||
logger.info(
|
||||
"Creating inference engine (type=%s)...",
|
||||
|
||||
@@ -14,8 +14,8 @@
|
||||
|
||||
"""Inference engine package — backend-agnostic action production.
|
||||
|
||||
Concrete backends (``sync``, ``rtc``, ...) expose the same small interface so
|
||||
rollout strategies never branch on which backend is in use.
|
||||
Concrete strategies (sync, RTC, …) expose the same small interface so
|
||||
rollout strategies never branch on the inference backend.
|
||||
"""
|
||||
|
||||
from .base import InferenceEngine
|
||||
|
||||
@@ -15,8 +15,8 @@
|
||||
"""Inference engine ABC.
|
||||
|
||||
Rollout strategies consume actions through this small interface so they
|
||||
do not need to know whether inference happens inline on the control thread
|
||||
or asynchronously in a background thread (RTC).
|
||||
do not need to know whether the inference engine is synchronous, runs in
|
||||
a background thread (RTC), or comes from an external source.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -29,10 +29,9 @@ import torch
|
||||
class InferenceEngine(abc.ABC):
|
||||
"""Abstract backend for producing actions during rollout.
|
||||
|
||||
Subclasses decide whether inference happens inline on the control
|
||||
thread or asynchronously in a background thread. The contract is
|
||||
minimal so additional backends can be plugged in without touching
|
||||
rollout strategies.
|
||||
Subclasses decide whether inference happens inline, in a background
|
||||
thread, or externally. The contract is minimal so new backends can
|
||||
be added without touching rollout strategies.
|
||||
|
||||
Lifecycle
|
||||
---------
|
||||
@@ -44,8 +43,8 @@ class InferenceEngine(abc.ABC):
|
||||
-----------------
|
||||
``get_action(obs_frame)`` — return the next action tensor, or
|
||||
``None`` if none is available (e.g. async queue empty). Sync
|
||||
backends always compute from ``obs_frame``; async backends ignore
|
||||
it (they receive observations via ``notify_observation``).
|
||||
backends always compute from ``obs_frame``; async backends may
|
||||
ignore it (they get observations via ``notify_observation``).
|
||||
|
||||
Optional hooks
|
||||
--------------
|
||||
|
||||
@@ -68,8 +68,9 @@ class SyncInferenceConfig(InferenceEngineConfig):
|
||||
class RTCInferenceConfig(InferenceEngineConfig):
|
||||
"""Real-Time Chunking: async policy inference in a background thread."""
|
||||
|
||||
# Eagerly constructed so draccus exposes nested fields directly on the CLI
|
||||
# (e.g. ``--inference.rtc.execution_horizon=...``).
|
||||
# ``RTCConfig`` is a small dataclass with default-only fields, so eagerly
|
||||
# constructing one here costs nothing and keeps draccus' CLI surface flat
|
||||
# (``--inference.rtc.execution_horizon=...`` etc.). No need to lazy-init.
|
||||
rtc: RTCConfig = field(default_factory=RTCConfig)
|
||||
queue_threshold: int = 30
|
||||
|
||||
|
||||
@@ -32,14 +32,18 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc import ActionQueue, LatencyTracker, reanchor_relative_rtc_prefix
|
||||
from lerobot.policies.rtc import ActionQueue, LatencyTracker
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.utils import prepare_observation_for_inference
|
||||
from lerobot.processor import (
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
RelativeActionsProcessorStep,
|
||||
TransitionKey,
|
||||
create_transition,
|
||||
to_relative_actions,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
|
||||
from ..robot_wrapper import ThreadSafeRobot
|
||||
@@ -62,6 +66,35 @@ _RTC_JOIN_TIMEOUT_S: float = 3.0
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute: torch.Tensor,
|
||||
current_state: torch.Tensor,
|
||||
relative_step: RelativeActionsProcessorStep,
|
||||
normalizer_step: NormalizerProcessorStep | None,
|
||||
policy_device: torch.device | str,
|
||||
) -> torch.Tensor:
|
||||
"""Convert absolute leftover actions into model-space for relative-action RTC policies.
|
||||
|
||||
When using relative actions, the RTC prefix (previous chunk's unexecuted tail)
|
||||
is stored in absolute coordinates. Before feeding it back to the policy, this
|
||||
helper re-expresses those actions relative to the robot's current joint state
|
||||
and optionally normalizes them so the policy receives correctly scaled inputs.
|
||||
"""
|
||||
state = current_state.detach().cpu()
|
||||
if state.dim() == 1:
|
||||
state = state.unsqueeze(0)
|
||||
|
||||
action_cpu = prev_actions_absolute.detach().cpu()
|
||||
mask = relative_step._build_mask(action_cpu.shape[-1])
|
||||
relative_actions = to_relative_actions(action_cpu, state, mask)
|
||||
|
||||
transition = create_transition(action=relative_actions)
|
||||
if normalizer_step is not None:
|
||||
transition = normalizer_step(transition)
|
||||
|
||||
return transition[TransitionKey.ACTION].to(policy_device)
|
||||
|
||||
|
||||
def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int) -> torch.Tensor:
|
||||
"""Pad or truncate RTC prefix actions to a fixed length for stable compiled inference."""
|
||||
if prev_actions.ndim != 2:
|
||||
@@ -76,6 +109,21 @@ def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int
|
||||
return padded
|
||||
|
||||
|
||||
def _get_current_raw_state(
|
||||
relative_step: RelativeActionsProcessorStep,
|
||||
fallback_state: torch.Tensor | None,
|
||||
) -> torch.Tensor | None:
|
||||
"""Return the current raw state cached by the relative-action step.
|
||||
|
||||
``RelativeActionsProcessorStep`` caches the observation state before any
|
||||
observation normalization. Re-anchoring RTC leftovers must use that raw
|
||||
state rather than the normalized observation that the policy consumes.
|
||||
"""
|
||||
if relative_step._last_state is not None:
|
||||
return relative_step._last_state
|
||||
return fallback_state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RTCInferenceEngine
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -285,15 +333,15 @@ class RTCInferenceEngine(InferenceEngine):
|
||||
preprocessed = self._preprocessor(obs_batch)
|
||||
|
||||
if prev_actions is not None and self._relative_step is not None:
|
||||
# Rebase against the raw cached state so the leftover tail stays in
|
||||
# the training-time coordinate frame.
|
||||
raw_state = self._relative_step.get_cached_state()
|
||||
if raw_state is not None:
|
||||
state_tensor = _get_current_raw_state(
|
||||
self._relative_step, obs_batch.get(OBS_STATE)
|
||||
)
|
||||
if state_tensor is not None:
|
||||
prev_abs = queue.get_processed_left_over()
|
||||
if prev_abs is not None and prev_abs.numel() > 0:
|
||||
prev_actions = reanchor_relative_rtc_prefix(
|
||||
prev_actions = _reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=prev_abs,
|
||||
current_state=raw_state,
|
||||
current_state=state_tensor,
|
||||
relative_step=self._relative_step,
|
||||
normalizer_step=self._normalizer_step,
|
||||
policy_device=policy_device,
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
|
||||
@@ -31,27 +32,12 @@ from .base import InferenceEngine
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO(Steven): support relative-action policies. The per-tick flow refreshes
|
||||
# ``RelativeActionsProcessorStep._last_state`` every call, so cached chunk
|
||||
# actions popped on later ticks get reanchored to the *current* robot state and
|
||||
# absolute targets drift through the chunk. Relative-action policies are
|
||||
# rejected at context-build time today; RTC postprocesses the whole chunk and
|
||||
# is unaffected.
|
||||
#
|
||||
# Candidate fix: drive the policy via ``predict_action_chunk`` and serve a
|
||||
# local FIFO of postprocessed actions. Eliminates drift by construction and
|
||||
# saves per-tick pre/post work, but bypasses ``select_action`` — needs
|
||||
# fallbacks for SAC (raises), ACT temporal ensembling (ensembler lives in
|
||||
# ``select_action``), and Diffusion-family (obs-history queues populated as a
|
||||
# side effect of ``select_action``).
|
||||
|
||||
|
||||
class SyncInferenceEngine(InferenceEngine):
|
||||
"""Inline synchronous inference: compute one action per call.
|
||||
|
||||
``get_action`` runs the full policy pipeline (pre/post-processor +
|
||||
``select_action``) on the given observation frame and returns a
|
||||
CPU action tensor reordered to match the dataset action keys.
|
||||
``get_action`` runs the full policy pipeline when its local action
|
||||
queue is empty, postprocesses the whole predicted chunk immediately,
|
||||
and then returns one already-postprocessed CPU action at a time.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -73,6 +59,8 @@ class SyncInferenceEngine(InferenceEngine):
|
||||
self._task = task
|
||||
self._device = torch.device(device or "cpu")
|
||||
self._robot_type = robot_type
|
||||
self._processed_action_queue: deque[torch.Tensor] = deque()
|
||||
|
||||
logger.info(
|
||||
"SyncInferenceEngine initialized (device=%s, action_keys=%d)",
|
||||
self._device,
|
||||
@@ -93,9 +81,28 @@ class SyncInferenceEngine(InferenceEngine):
|
||||
self._policy.reset()
|
||||
self._preprocessor.reset()
|
||||
self._postprocessor.reset()
|
||||
self._processed_action_queue.clear()
|
||||
|
||||
def _enqueue_processed_chunk(self, action_chunk: torch.Tensor) -> None:
|
||||
"""Convert a postprocessed action chunk into ordered per-step CPU tensors."""
|
||||
if action_chunk.ndim == 2:
|
||||
action_chunk = action_chunk.unsqueeze(0)
|
||||
|
||||
n_action_steps = getattr(self._policy.config, "n_action_steps", action_chunk.shape[1])
|
||||
action_chunk = action_chunk[:, : min(n_action_steps, action_chunk.shape[1])]
|
||||
|
||||
for action in action_chunk.squeeze(0):
|
||||
action_tensor = action.cpu()
|
||||
action_dict = make_robot_action(action_tensor, self._dataset_features)
|
||||
ordered_action = torch.tensor(
|
||||
[action_dict[k] for k in self._ordered_action_keys], dtype=action_tensor.dtype
|
||||
)
|
||||
self._processed_action_queue.append(ordered_action)
|
||||
|
||||
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
|
||||
"""Run the full inference pipeline on ``obs_frame`` and return an action tensor."""
|
||||
if self._processed_action_queue:
|
||||
return self._processed_action_queue.popleft().clone()
|
||||
if obs_frame is None:
|
||||
return None
|
||||
# Shallow copy is intentional: the caller (`send_next_action`) builds
|
||||
@@ -112,11 +119,10 @@ class SyncInferenceEngine(InferenceEngine):
|
||||
observation, self._device, self._task, self._robot_type
|
||||
)
|
||||
observation = self._preprocessor(observation)
|
||||
action = self._policy.select_action(observation)
|
||||
action = self._postprocessor(action)
|
||||
action_tensor = action.squeeze(0).cpu()
|
||||
action_chunk = self._policy.predict_action_chunk(observation)
|
||||
processed_chunk = self._postprocessor(action_chunk)
|
||||
|
||||
# Reorder to match dataset action ordering so the caller can treat
|
||||
# the returned tensor uniformly across backends.
|
||||
action_dict = make_robot_action(action_tensor, self._dataset_features)
|
||||
return torch.tensor([action_dict[k] for k in self._ordered_action_keys])
|
||||
self._enqueue_processed_chunk(processed_chunk)
|
||||
if not self._processed_action_queue:
|
||||
return None
|
||||
return self._processed_action_queue.popleft().clone()
|
||||
|
||||
@@ -47,7 +47,7 @@ class RolloutRingBuffer:
|
||||
count.
|
||||
"""
|
||||
|
||||
def __init__(self, max_seconds: float = 30.0, max_memory_mb: int = 2048, fps: float = 30.0) -> None:
|
||||
def __init__(self, max_seconds: float = 30.0, max_memory_mb: float = 2048.0, fps: float = 30.0) -> None:
|
||||
self._max_frames = int(max_seconds * fps)
|
||||
self._max_bytes = int(max_memory_mb * 1024 * 1024)
|
||||
self._buffer: deque[dict] = deque(maxlen=self._max_frames)
|
||||
|
||||
@@ -47,8 +47,12 @@ class BaseStrategy(RolloutStrategy):
|
||||
interpolator = self._interpolator
|
||||
|
||||
control_interval = interpolator.get_control_interval(cfg.fps)
|
||||
observation_interval = 1.0 / cfg.fps
|
||||
|
||||
start_time = time.perf_counter()
|
||||
next_observation_time = 0.0
|
||||
obs = None
|
||||
obs_processed = None
|
||||
engine.resume()
|
||||
logger.info("Base strategy control loop started")
|
||||
|
||||
@@ -59,12 +63,18 @@ class BaseStrategy(RolloutStrategy):
|
||||
logger.info("Duration limit reached (%.0fs)", cfg.duration)
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = self._process_observation_and_notify(ctx.processors, obs)
|
||||
if obs is None or loop_start >= next_observation_time:
|
||||
obs = robot.get_observation()
|
||||
obs_processed = ctx.processors.robot_observation_processor(obs)
|
||||
engine.notify_observation(obs_processed)
|
||||
next_observation_time = loop_start + observation_interval
|
||||
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
|
||||
if obs_processed is None:
|
||||
continue
|
||||
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
|
||||
|
||||
@@ -78,8 +88,5 @@ class BaseStrategy(RolloutStrategy):
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
"""Disconnect hardware and stop inference."""
|
||||
self._teardown_hardware(
|
||||
ctx.hardware,
|
||||
return_to_initial_position=ctx.runtime.cfg.return_to_initial_position,
|
||||
)
|
||||
self._teardown_hardware(ctx.hardware)
|
||||
logger.info("Base strategy teardown complete")
|
||||
|
||||
@@ -32,7 +32,7 @@ from ..inference import InferenceEngine
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..configs import RolloutStrategyConfig
|
||||
from ..context import HardwareContext, ProcessorContext, RolloutContext, RuntimeContext
|
||||
from ..context import HardwareContext, RolloutContext, RuntimeContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,7 +50,6 @@ class RolloutStrategy(abc.ABC):
|
||||
self._engine: InferenceEngine | None = None
|
||||
self._interpolator: ActionInterpolator | None = None
|
||||
self._warmup_flushed: bool = False
|
||||
self._cached_obs_processed: dict | None = None
|
||||
|
||||
def _init_engine(self, ctx: RolloutContext) -> None:
|
||||
"""Attach the inference engine and action interpolator, then start the backend.
|
||||
@@ -66,32 +65,8 @@ class RolloutStrategy(abc.ABC):
|
||||
self._engine.reset()
|
||||
self._engine.start()
|
||||
self._warmup_flushed = False
|
||||
self._cached_obs_processed = None
|
||||
logger.info("Inference engine started")
|
||||
|
||||
def _process_observation_and_notify(self, processors: ProcessorContext, obs_raw: dict) -> dict:
|
||||
"""Run the observation processor and notify the engine — throttled to policy ticks.
|
||||
|
||||
Callers are responsible for calling ``robot.get_observation()`` every loop
|
||||
iteration so ``obs_raw`` stays fresh for the action post-processor. This
|
||||
helper gates only the comparatively expensive bits — the processor pipeline
|
||||
and ``engine.notify_observation`` — to fire when the interpolator signals
|
||||
it needs a new action (once per ``interpolation_multiplier`` ticks). On
|
||||
interpolated ticks the cached ``obs_processed`` is reused.
|
||||
|
||||
With ``interpolation_multiplier == 1`` this is equivalent to the unthrottled
|
||||
path: ``needs_new_action()`` is True every tick.
|
||||
|
||||
The cache is implicitly invalidated whenever ``interpolator.reset()`` is
|
||||
called (warmup completion, DAgger phase transitions back to AUTONOMOUS),
|
||||
because reset makes ``needs_new_action()`` return True on the next call.
|
||||
"""
|
||||
if self._cached_obs_processed is None or self._interpolator.needs_new_action():
|
||||
obs_processed = processors.robot_observation_processor(obs_raw)
|
||||
self._engine.notify_observation(obs_processed)
|
||||
self._cached_obs_processed = obs_processed
|
||||
return self._cached_obs_processed
|
||||
|
||||
def _handle_warmup(self, use_torch_compile: bool, loop_start: float, control_interval: float) -> bool:
|
||||
"""Handle torch.compile warmup phase.
|
||||
|
||||
@@ -116,20 +91,16 @@ class RolloutStrategy(abc.ABC):
|
||||
engine.resume()
|
||||
return False
|
||||
|
||||
def _teardown_hardware(self, hw: HardwareContext, return_to_initial_position: bool = True) -> None:
|
||||
"""Stop the inference engine, optionally return robot to initial position, and disconnect hardware."""
|
||||
def _teardown_hardware(self, hw: HardwareContext) -> None:
|
||||
"""Stop the inference engine, return robot to initial position, and disconnect hardware."""
|
||||
if self._engine is not None:
|
||||
logger.info("Stopping inference engine...")
|
||||
self._engine.stop()
|
||||
robot = hw.robot_wrapper.inner
|
||||
if robot.is_connected:
|
||||
if return_to_initial_position and hw.initial_position:
|
||||
if hw.initial_position:
|
||||
logger.info("Returning robot to initial position before shutdown...")
|
||||
self._return_to_initial_position(hw)
|
||||
elif not return_to_initial_position:
|
||||
logger.info(
|
||||
"Skipping return-to-initial-position (disabled by config); leaving robot in final pose."
|
||||
)
|
||||
logger.info("Disconnecting robot...")
|
||||
robot.disconnect()
|
||||
teleop = hw.teleop
|
||||
@@ -223,7 +194,7 @@ def estimate_max_episode_seconds(
|
||||
The estimate ignores codec-specific settings (CRF, preset) on purpose:
|
||||
we only need a rough lower bound on bitrate, not a precise prediction.
|
||||
|
||||
Falls back to 300 s (5 min) when no video features are present.
|
||||
Falls back to 600 s (10 min) when no video features are present.
|
||||
"""
|
||||
# 0.1 bits-per-pixel is a *low* estimate for CRF-30 streaming video of
|
||||
# robot footage (real-world is typically 0.1 – 0.3 bpp). Under-
|
||||
@@ -237,16 +208,16 @@ def estimate_max_episode_seconds(
|
||||
if feat.get("dtype") == "video":
|
||||
shape = feat.get("shape", ())
|
||||
|
||||
# (H, W, C) — bits-per-pixel is a per-spatial-pixel metric,
|
||||
# so we exclude the channel dimension from the count.
|
||||
if len(shape) == 3:
|
||||
pixels = shape[0] * shape[1]
|
||||
camera_pixels.append(pixels)
|
||||
else:
|
||||
raise ValueError(f"Unexpected video feature shape: {shape}")
|
||||
# Assuming shape could be (C, H, W) or (T, C, H, W)
|
||||
# We want to extract the spatial dimensions.
|
||||
if len(shape) >= 3:
|
||||
h, w = shape[-2], shape[-1]
|
||||
pixels = h * w
|
||||
if pixels > 0:
|
||||
camera_pixels.append(pixels)
|
||||
|
||||
if not camera_pixels:
|
||||
return 300.0
|
||||
return 600.0
|
||||
|
||||
# Use the smallest camera: it produces the lowest bitrate and therefore
|
||||
# takes the longest to reach the target — the conservative choice.
|
||||
@@ -256,7 +227,7 @@ def estimate_max_episode_seconds(
|
||||
|
||||
# Guard against division by zero just in case
|
||||
if bytes_per_second <= 0:
|
||||
return 300.0
|
||||
return 600.0
|
||||
|
||||
return (target_size_mb * 1024 * 1024) / bytes_per_second
|
||||
|
||||
|
||||
@@ -24,22 +24,14 @@ the ``input_device`` config field. Each device exposes three actions:
|
||||
1. **pause_resume** — Toggle policy execution (AUTONOMOUS <-> PAUSED).
|
||||
2. **correction** — Toggle correction recording (PAUSED <-> CORRECTING).
|
||||
3. **upload** — Push dataset to hub on demand (corrections-only mode).
|
||||
ESC (keyboard only) — Stop session.
|
||||
ESC (keyboard only) — Stop session.
|
||||
|
||||
Recording modes:
|
||||
Recording Modes:
|
||||
``record_autonomous=True``: Sentry-like continuous recording with
|
||||
time-based episode rotation. Both autonomous and correction
|
||||
frames are recorded; corrections tagged ``intervention=True``.
|
||||
``record_autonomous=False``: Only correction windows are recorded.
|
||||
Each correction (start to stop) becomes one episode.
|
||||
|
||||
Teleoperator handover:
|
||||
On AUTONOMOUS → PAUSED, actuated teleops (those with non-empty
|
||||
``feedback_features``, e.g. SO-101, OpenArmMini) are smoothly driven to
|
||||
the follower's last position via ``send_feedback`` so the operator takes
|
||||
over without a jerk. Non-actuated teleops cannot be driven,
|
||||
so on PAUSED → CORRECTING the follower is instead slid to the teleop's
|
||||
current pose before the correction begins.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -176,27 +168,15 @@ class DAggerEvents:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _teleop_supports_feedback(teleop: Teleoperator) -> bool:
|
||||
"""Return True when the teleop can receive position feedback (is actuated).
|
||||
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
||||
"""
|
||||
return (
|
||||
bool(teleop.feedback_features)
|
||||
and hasattr(teleop, "disable_torque")
|
||||
and hasattr(teleop, "enable_torque")
|
||||
)
|
||||
|
||||
|
||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||
def _teleop_smooth_move_to(
|
||||
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 30
|
||||
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50
|
||||
) -> None:
|
||||
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
||||
"""Smoothly move teleop to target position via linear interpolation.
|
||||
|
||||
Requires the teleoperator to support feedback
|
||||
(i.e. have non-empty ``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
||||
|
||||
TODO(Maxime): This blocks up to ``duration_s`` seconds, during this time
|
||||
the follower robot doesn't receive new actions, this could be an issue on LeKiwi.
|
||||
Requires the teleoperator to support motor control methods
|
||||
(``enable_torque``, ``write_goal_positions``, ``get_action``).
|
||||
"""
|
||||
teleop.enable_torque()
|
||||
current = teleop.get_action()
|
||||
@@ -204,28 +184,13 @@ def _teleop_smooth_move_to(
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {
|
||||
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
||||
}
|
||||
teleop.send_feedback(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def _follower_smooth_move_to(
|
||||
robot: ThreadSafeRobot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
||||
|
||||
Used when the teleop is non-actuated: instead of driving the leader arm
|
||||
to the follower, we bring the follower to the teleop's current pose.
|
||||
Both ``current`` and ``target`` must be in robot-action key space.
|
||||
"""
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
||||
robot.send_action(interp)
|
||||
interp = {}
|
||||
for k in current:
|
||||
if k in target_pos:
|
||||
interp[k] = current[k] * (1 - t) + target_pos[k] * t
|
||||
else:
|
||||
interp[k] = current[k]
|
||||
teleop.write_goal_positions(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
@@ -406,10 +371,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
logger.info("Dataset uploaded to hub")
|
||||
log_say("Dataset uploaded to hub", play_sounds)
|
||||
|
||||
self._teardown_hardware(
|
||||
ctx.hardware,
|
||||
return_to_initial_position=ctx.runtime.cfg.return_to_initial_position,
|
||||
)
|
||||
self._teardown_hardware(ctx.hardware)
|
||||
logger.info("DAgger strategy teardown complete")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -441,6 +403,9 @@ class DAggerStrategy(RolloutStrategy):
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
events.reset()
|
||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||
# teleop.disable_torque()
|
||||
engine.resume()
|
||||
|
||||
last_action: dict[str, Any] | None = None
|
||||
@@ -464,35 +429,24 @@ class DAggerStrategy(RolloutStrategy):
|
||||
transition = events.consume_transition()
|
||||
if transition is not None:
|
||||
old_phase, new_phase = transition
|
||||
self._apply_transition(
|
||||
old_phase,
|
||||
new_phase,
|
||||
engine,
|
||||
interpolator,
|
||||
ctx,
|
||||
last_action,
|
||||
)
|
||||
if new_phase == DAggerPhase.AUTONOMOUS:
|
||||
last_action = None
|
||||
self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
|
||||
last_action = None
|
||||
|
||||
phase = events.phase
|
||||
obs = robot.get_observation()
|
||||
obs_processed = ctx.processors.robot_observation_processor(obs)
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
|
||||
# --- CORRECTING: human teleop control ---
|
||||
# TODO(Steven): teleop runs at the same FPS as the policy. To
|
||||
# decouple the two, sample teleop at its native rate and
|
||||
# interpolate to the control loop's tick rate.
|
||||
if phase == DAggerPhase.CORRECTING:
|
||||
obs_processed = ctx.processors.robot_observation_processor(obs)
|
||||
teleop_action = teleop.get_action()
|
||||
processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
|
||||
robot.send_action(robot_action_to_send)
|
||||
last_action = robot_action_to_send
|
||||
self._log_telemetry(obs_processed, processed_teleop, ctx.runtime)
|
||||
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
|
||||
if record_tick % record_stride == 0:
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
|
||||
frame = {
|
||||
**obs_frame,
|
||||
**action_frame,
|
||||
@@ -509,7 +463,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
|
||||
# --- AUTONOMOUS: policy control ---
|
||||
else:
|
||||
obs_processed = self._process_observation_and_notify(ctx.processors, obs)
|
||||
engine.notify_observation(obs_processed)
|
||||
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
@@ -518,9 +472,8 @@ class DAggerStrategy(RolloutStrategy):
|
||||
if action_dict is not None:
|
||||
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
|
||||
last_action = ctx.processors.robot_action_processor((action_dict, obs))
|
||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||
if record_tick % record_stride == 0:
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||
frame = {
|
||||
**obs_frame,
|
||||
**action_frame,
|
||||
@@ -530,9 +483,9 @@ class DAggerStrategy(RolloutStrategy):
|
||||
dataset.add_frame(frame)
|
||||
record_tick += 1
|
||||
|
||||
# Episode rotation derived from the video file-size target.
|
||||
# Saving is deferred while a correction is ongoing so the
|
||||
# episode boundary lands on a clean autonomous frame.
|
||||
# Episode rotation derived from video file-size target.
|
||||
# Do NOT save mid-correction — wait for the correction
|
||||
# to finish so the episode boundary is clean.
|
||||
elapsed = time.perf_counter() - episode_start
|
||||
if elapsed >= episode_duration_s and phase != DAggerPhase.CORRECTING:
|
||||
with self._episode_lock:
|
||||
@@ -563,6 +516,9 @@ class DAggerStrategy(RolloutStrategy):
|
||||
finally:
|
||||
logger.info("DAgger continuous control loop ended — pausing engine")
|
||||
engine.pause()
|
||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||
# teleop.disable_torque()
|
||||
with contextlib.suppress(Exception):
|
||||
with self._episode_lock:
|
||||
dataset.save_episode()
|
||||
@@ -598,6 +554,9 @@ class DAggerStrategy(RolloutStrategy):
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
events.reset()
|
||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||
# teleop.disable_torque()
|
||||
engine.resume()
|
||||
|
||||
last_action: dict[str, Any] | None = None
|
||||
@@ -625,16 +584,8 @@ class DAggerStrategy(RolloutStrategy):
|
||||
transition = events.consume_transition()
|
||||
if transition is not None:
|
||||
old_phase, new_phase = transition
|
||||
self._apply_transition(
|
||||
old_phase,
|
||||
new_phase,
|
||||
engine,
|
||||
interpolator,
|
||||
ctx,
|
||||
last_action,
|
||||
)
|
||||
if new_phase == DAggerPhase.AUTONOMOUS:
|
||||
last_action = None
|
||||
self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
|
||||
last_action = None
|
||||
|
||||
# Correction ended -> save episode (blocking if not streaming)
|
||||
if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
||||
@@ -657,13 +608,10 @@ class DAggerStrategy(RolloutStrategy):
|
||||
|
||||
phase = events.phase
|
||||
obs = robot.get_observation()
|
||||
obs_processed = ctx.processors.robot_observation_processor(obs)
|
||||
|
||||
# --- CORRECTING: human teleop control + recording ---
|
||||
# TODO(Steven): teleop runs at the same FPS as the policy. To
|
||||
# decouple the two, sample teleop at its native rate and
|
||||
# interpolate to the control loop's tick rate.
|
||||
if phase == DAggerPhase.CORRECTING:
|
||||
obs_processed = ctx.processors.robot_observation_processor(obs)
|
||||
teleop_action = teleop.get_action()
|
||||
processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs))
|
||||
@@ -671,9 +619,9 @@ class DAggerStrategy(RolloutStrategy):
|
||||
last_action = robot_action_to_send
|
||||
self._log_telemetry(obs_processed, processed_teleop, ctx.runtime)
|
||||
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
|
||||
if record_tick % record_stride == 0:
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
|
||||
dataset.add_frame(
|
||||
{
|
||||
**obs_frame,
|
||||
@@ -691,7 +639,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
|
||||
# --- AUTONOMOUS: policy control (no recording) ---
|
||||
else:
|
||||
obs_processed = self._process_observation_and_notify(ctx.processors, obs)
|
||||
engine.notify_observation(obs_processed)
|
||||
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
@@ -712,6 +660,9 @@ class DAggerStrategy(RolloutStrategy):
|
||||
finally:
|
||||
logger.info("DAgger corrections-only loop ended — pausing engine")
|
||||
engine.pause()
|
||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||
# teleop.disable_torque()
|
||||
with contextlib.suppress(Exception):
|
||||
with self._episode_lock:
|
||||
dataset.save_episode()
|
||||
@@ -728,71 +679,35 @@ class DAggerStrategy(RolloutStrategy):
|
||||
new_phase: DAggerPhase,
|
||||
engine,
|
||||
interpolator,
|
||||
ctx: RolloutContext,
|
||||
prev_action: dict | None,
|
||||
robot: ThreadSafeRobot,
|
||||
teleop: Teleoperator,
|
||||
) -> None:
|
||||
"""Execute side-effects for a validated phase transition, including smooth handovers.
|
||||
|
||||
AUTONOMOUS -> PAUSED (actuated teleop):
|
||||
Pause the engine, then drive the leader arm to the follower's last
|
||||
commanded position so the operator takes over without a jerk.
|
||||
|
||||
PAUSED -> CORRECTING (non-actuated teleop):
|
||||
Slide the follower to the teleop's current pose so the robot meets
|
||||
the operator's hand rather than jumping to it on the first frame.
|
||||
|
||||
CORRECTING -> PAUSED (actuated teleop):
|
||||
Re-enable torque to hold position after correction.
|
||||
This will be potentially useful if cancelling the correction recording
|
||||
|
||||
PAUSED -> AUTONOMOUS:
|
||||
Reset and resume the inference engine.
|
||||
"""
|
||||
teleop = ctx.hardware.teleop
|
||||
robot = ctx.hardware.robot_wrapper
|
||||
|
||||
"""Execute side-effects for a validated phase transition."""
|
||||
logger.info("Phase transition: %s -> %s", old_phase.value, new_phase.value)
|
||||
if old_phase == DAggerPhase.AUTONOMOUS and new_phase == DAggerPhase.PAUSED:
|
||||
logger.info("Pausing engine - robot holds position")
|
||||
logger.info("Pausing engine — robot holds position")
|
||||
engine.pause()
|
||||
obs = robot.get_observation()
|
||||
_robot_pos = {
|
||||
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
|
||||
}
|
||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||
# Consider also a method that moves the robot to the teleop smoothly (similar to what we do at HW shutdown).
|
||||
# _teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
if _teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
# TODO(Maxime): prev_action is in robot action key space (output of robot_action_processor).
|
||||
# send_feedback expects teleop feedback key space. For homogeneous setups (e.g. SO-101
|
||||
# leader + SO-101 follower) the keys are identical so this works. If the processor pipeline
|
||||
# does non-trivial key renaming (e.g. a rename_map on action keys), the interpolation in
|
||||
# _teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
||||
logger.info("Smooth handover: moving leader arm to follower position")
|
||||
_teleop_smooth_move_to(teleop, prev_action)
|
||||
|
||||
elif old_phase == DAggerPhase.PAUSED and new_phase == DAggerPhase.CORRECTING:
|
||||
logger.info("Entering correction mode - human teleop control")
|
||||
if not _teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
logger.info("Smooth handover: sliding follower to teleop position")
|
||||
obs = robot.get_observation()
|
||||
teleop_action = teleop.get_action()
|
||||
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
target = ctx.processors.robot_action_processor((processed, obs))
|
||||
_follower_smooth_move_to(robot, prev_action, target)
|
||||
|
||||
# unlock the teleop for human control
|
||||
if _teleop_supports_feedback(teleop):
|
||||
teleop.disable_torque()
|
||||
|
||||
elif old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
||||
if _teleop_supports_feedback(teleop):
|
||||
teleop.enable_torque()
|
||||
elif new_phase == DAggerPhase.CORRECTING:
|
||||
logger.info("Entering correction mode — human teleop control")
|
||||
# TODO(Steven): either enforce this (meaning all teleop must implement these methods) or
|
||||
# user is responsible for moving the teleop to the same position as the robot when starting the correction.
|
||||
# teleop.disable_torque()
|
||||
|
||||
elif new_phase == DAggerPhase.AUTONOMOUS:
|
||||
logger.info("Resuming autonomous mode - resetting engine and interpolator")
|
||||
logger.info("Resuming autonomous mode — resetting engine and interpolator")
|
||||
interpolator.reset()
|
||||
engine.reset()
|
||||
engine.resume()
|
||||
|
||||
# release teleop before resuming the policy
|
||||
if _teleop_supports_feedback(teleop):
|
||||
teleop.disable_torque()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Background push (shared by both modes)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -64,8 +64,8 @@ class HighlightStrategy(RolloutStrategy):
|
||||
3. The episode is saved and the ring buffer resumes capturing.
|
||||
|
||||
Requires ``streaming_encoding=True`` (enforced in config validation)
|
||||
so that ``dataset.add_frame`` is a non-blocking queue put — flushing
|
||||
the entire ring buffer in one tick must not stall the control loop.
|
||||
so that ``dataset.add_frame`` is a non-blocking queue put — draining
|
||||
900 frames stays sub-ms per frame.
|
||||
"""
|
||||
|
||||
config: HighlightStrategyConfig
|
||||
@@ -135,7 +135,8 @@ class HighlightStrategy(RolloutStrategy):
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = self._process_observation_and_notify(ctx.processors, obs)
|
||||
obs_processed = ctx.processors.robot_observation_processor(obs)
|
||||
engine.notify_observation(obs_processed)
|
||||
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
@@ -227,10 +228,7 @@ class HighlightStrategy(RolloutStrategy):
|
||||
logger.info("Dataset uploaded to hub")
|
||||
log_say("Dataset uploaded to hub", play_sounds)
|
||||
|
||||
self._teardown_hardware(
|
||||
ctx.hardware,
|
||||
return_to_initial_position=ctx.runtime.cfg.return_to_initial_position,
|
||||
)
|
||||
self._teardown_hardware(ctx.hardware)
|
||||
logger.info("Highlight strategy teardown complete")
|
||||
|
||||
def _setup_keyboard(self, shutdown_event: ThreadingEvent) -> None:
|
||||
|
||||
@@ -111,7 +111,8 @@ class SentryStrategy(RolloutStrategy):
|
||||
break
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = self._process_observation_and_notify(ctx.processors, obs)
|
||||
obs_processed = ctx.processors.robot_observation_processor(obs)
|
||||
engine.notify_observation(obs_processed)
|
||||
|
||||
if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval):
|
||||
continue
|
||||
@@ -196,10 +197,7 @@ class SentryStrategy(RolloutStrategy):
|
||||
logger.info("Dataset uploaded to hub")
|
||||
log_say("Dataset uploaded to hub", play_sounds)
|
||||
|
||||
self._teardown_hardware(
|
||||
ctx.hardware,
|
||||
return_to_initial_position=ctx.runtime.cfg.return_to_initial_position,
|
||||
)
|
||||
self._teardown_hardware(ctx.hardware)
|
||||
logger.info("Sentry strategy teardown complete")
|
||||
|
||||
def _background_push(self, dataset, cfg) -> None:
|
||||
|
||||
@@ -70,7 +70,6 @@ from lerobot.datasets.io_utils import (
|
||||
get_parquet_file_size_in_mb,
|
||||
get_parquet_num_frames,
|
||||
load_info,
|
||||
load_json,
|
||||
write_episodes,
|
||||
write_info,
|
||||
write_stats,
|
||||
@@ -82,11 +81,9 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
INFO_PATH,
|
||||
LEGACY_EPISODES_PATH,
|
||||
LEGACY_EPISODES_STATS_PATH,
|
||||
LEGACY_TASKS_PATH,
|
||||
DatasetInfo,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
@@ -168,7 +165,7 @@ def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
def validate_local_dataset_version(local_path: Path) -> None:
|
||||
"""Validate that the local dataset has the expected v2.1 version."""
|
||||
info = load_info(local_path)
|
||||
dataset_version = info.codebase_version or "unknown"
|
||||
dataset_version = info.get("codebase_version", "unknown")
|
||||
if dataset_version != V21:
|
||||
raise ValueError(
|
||||
f"Local dataset has codebase version '{dataset_version}', expected '{V21}'. "
|
||||
@@ -259,14 +256,14 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
||||
|
||||
def get_video_keys(root):
|
||||
info = load_info(root)
|
||||
features = info.features
|
||||
features = info["features"]
|
||||
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
|
||||
return video_keys
|
||||
|
||||
|
||||
def get_image_keys(root):
|
||||
info = load_info(root)
|
||||
features = info.features
|
||||
features = info["features"]
|
||||
image_keys = [key for key, ft in features.items() if ft["dtype"] == "image"]
|
||||
return image_keys
|
||||
|
||||
@@ -437,8 +434,7 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_
|
||||
|
||||
|
||||
def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb):
|
||||
# Load as raw dict to remove legacy v2.1 fields before constructing DatasetInfo.
|
||||
info = load_json(root / INFO_PATH)
|
||||
info = load_info(root)
|
||||
info["codebase_version"] = V30
|
||||
del info["total_chunks"]
|
||||
del info["total_videos"]
|
||||
@@ -453,9 +449,7 @@ def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb):
|
||||
# already has fps in video_info
|
||||
continue
|
||||
info["features"][key]["fps"] = info["fps"]
|
||||
# Convert raw dict to typed DatasetInfo before writing
|
||||
dataset_info = DatasetInfo.from_dict(info)
|
||||
write_info(dataset_info, new_root)
|
||||
write_info(info, new_root)
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
|
||||
@@ -150,24 +150,11 @@ Show dataset information without feature details:
|
||||
--operation.type info \
|
||||
--operation.show_features false
|
||||
|
||||
Recompute dataset statistics (saves to lerobot/pusht_recomputed_stats by default):
|
||||
Recompute dataset statistics:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type recompute_stats
|
||||
|
||||
Recompute stats and save to a specific new repo_id:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_new_stats \
|
||||
--operation.type recompute_stats
|
||||
|
||||
Recompute stats in-place (overwrites original dataset stats):
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht \
|
||||
--operation.type recompute_stats \
|
||||
--operation.overwrite true
|
||||
|
||||
Recompute stats for relative actions and push to hub:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
@@ -269,7 +256,6 @@ class RecomputeStatsConfig(OperationConfig):
|
||||
relative_exclude_joints: list[str] | None = None
|
||||
chunk_size: int = 50
|
||||
num_workers: int = 0
|
||||
overwrite: bool = False
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("info")
|
||||
@@ -294,30 +280,16 @@ class EditDatasetConfig:
|
||||
push_to_hub: bool = False
|
||||
|
||||
|
||||
def _resolve_io_paths(
|
||||
repo_id: str,
|
||||
new_repo_id: str | None,
|
||||
root: Path | str | None,
|
||||
new_root: Path | str | None,
|
||||
default_new_repo_id: str | None = None,
|
||||
) -> tuple[str, Path, Path]:
|
||||
"""Resolve input/output paths and repo_id for dataset operations.
|
||||
|
||||
Returns (output_repo_id, input_path, output_path) with resolved (symlink-safe) paths.
|
||||
"""
|
||||
input_path = (Path(root) if root else HF_LEROBOT_HOME / repo_id).resolve()
|
||||
output_repo_id = new_repo_id or default_new_repo_id or repo_id
|
||||
output_path = (Path(new_root) if new_root else HF_LEROBOT_HOME / output_repo_id).resolve()
|
||||
return output_repo_id, input_path, output_path
|
||||
|
||||
|
||||
def get_output_path(
|
||||
repo_id: str,
|
||||
new_repo_id: str | None,
|
||||
root: Path | str | None,
|
||||
new_root: Path | str | None,
|
||||
) -> tuple[str, Path]:
|
||||
output_repo_id, input_path, output_path = _resolve_io_paths(repo_id, new_repo_id, root, new_root)
|
||||
input_path = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
output_repo_id = new_repo_id if new_repo_id else repo_id
|
||||
output_path = Path(new_root) if new_root else HF_LEROBOT_HOME / output_repo_id
|
||||
|
||||
# In case of in-place modification, create a backup of the original dataset (if it exists)
|
||||
if output_path == input_path:
|
||||
@@ -585,39 +557,7 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
|
||||
if not isinstance(cfg.operation, RecomputeStatsConfig):
|
||||
raise ValueError("Operation config must be RecomputeStatsConfig")
|
||||
|
||||
# Determine whether this is an in-place operation
|
||||
output_repo_id, input_root, output_root = _resolve_io_paths(
|
||||
cfg.repo_id,
|
||||
cfg.new_repo_id,
|
||||
cfg.root,
|
||||
cfg.new_root,
|
||||
default_new_repo_id=f"{cfg.repo_id}_recomputed_stats",
|
||||
)
|
||||
in_place = output_root == input_root
|
||||
|
||||
if in_place and not cfg.operation.overwrite:
|
||||
raise ValueError(
|
||||
f"recompute_stats would overwrite the dataset in-place at {input_root}. "
|
||||
"Pass --operation.overwrite true to allow in-place modification, "
|
||||
"or use --new_repo_id / --new_root to write to a different location. "
|
||||
f"Default output repo_id when neither is set: '{cfg.repo_id}_recomputed_stats'."
|
||||
)
|
||||
|
||||
if in_place:
|
||||
logging.warning(
|
||||
f"Overwriting dataset stats in-place at {input_root}. The original stats will be lost."
|
||||
)
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=input_root)
|
||||
else:
|
||||
logging.info(f"Copying dataset from {input_root} to {output_root}")
|
||||
if output_root.exists():
|
||||
backup_path = output_root.with_name(output_root.name + "_old")
|
||||
logging.warning(f"Output directory {output_root} already exists. Moving to {backup_path}")
|
||||
if backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
shutil.move(output_root, backup_path)
|
||||
shutil.copytree(input_root, output_root)
|
||||
dataset = LeRobotDataset(output_repo_id, root=output_root)
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
|
||||
logging.info(f"Recomputing stats for {cfg.repo_id}")
|
||||
if cfg.operation.relative_action:
|
||||
@@ -638,7 +578,7 @@ def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
|
||||
logging.info(f"Stats written to {dataset.root}")
|
||||
|
||||
if cfg.push_to_hub:
|
||||
logging.info(f"Pushing to hub as {dataset.repo_id}...")
|
||||
logging.info(f"Pushing to hub as {dataset.meta.repo_id}...")
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
|
||||
@@ -389,8 +389,7 @@ def record(
|
||||
sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features)
|
||||
else:
|
||||
# Reject eval_ prefix — for policy evaluation use lerobot-rollout
|
||||
repo_name = cfg.dataset.repo_id.split("/", 1)[-1]
|
||||
if repo_name.startswith("eval_"):
|
||||
if cfg.dataset.repo_id.startswith("eval_"):
|
||||
raise ValueError(
|
||||
"Dataset names starting with 'eval_' are reserved for policy evaluation. "
|
||||
"lerobot-record is for data collection only. Use lerobot-rollout for policy deployment."
|
||||
|
||||
@@ -47,7 +47,6 @@ from lerobot.datasets import EpisodeAwareSampler, make_dataset
|
||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.rewards import make_reward_pre_post_processors
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
@@ -71,8 +70,8 @@ def update_policy(
|
||||
accelerator: "Accelerator",
|
||||
lr_scheduler=None,
|
||||
lock=None,
|
||||
sample_weighter=None,
|
||||
) -> tuple[MetricsTracker, dict | None]:
|
||||
rabc_weights_provider=None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
"""
|
||||
Performs a single training step to update the policy's weights.
|
||||
|
||||
@@ -88,7 +87,7 @@ def update_policy(
|
||||
accelerator: The Accelerator instance for distributed training and mixed precision.
|
||||
lr_scheduler: An optional learning rate scheduler.
|
||||
lock: An optional lock for thread-safe optimizer updates.
|
||||
sample_weighter: Optional SampleWeighter instance for per-sample loss weighting.
|
||||
rabc_weights_provider: Optional RABCWeights instance for sample weighting.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
@@ -98,31 +97,27 @@ def update_policy(
|
||||
start_time = time.perf_counter()
|
||||
policy.train()
|
||||
|
||||
# Compute sample weights if a weighter is provided
|
||||
sample_weights = None
|
||||
weight_stats = None
|
||||
if sample_weighter is not None:
|
||||
sample_weights, weight_stats = sample_weighter.compute_batch_weights(batch)
|
||||
# Get RA-BC weights if enabled
|
||||
rabc_batch_weights = None
|
||||
rabc_batch_stats = None
|
||||
if rabc_weights_provider is not None:
|
||||
rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
|
||||
|
||||
# Let accelerator handle mixed precision
|
||||
with accelerator.autocast():
|
||||
if sample_weights is not None:
|
||||
# Use per-sample loss for weighted training
|
||||
# Note: Policies supporting sample weighting must implement forward(batch, reduction="none")
|
||||
# Use per-sample loss when RA-BC is enabled for proper weighting
|
||||
if rabc_batch_weights is not None:
|
||||
# Get per-sample losses
|
||||
per_sample_loss, output_dict = policy.forward(batch, reduction="none")
|
||||
|
||||
# Weighted loss: each sample's contribution is scaled by its weight.
|
||||
# We divide by weight sum (not batch size) so that if some weights are zero,
|
||||
# the remaining samples contribute proportionally more, preserving gradient scale.
|
||||
# Weights are pre-normalized to sum to batch_size for stable training dynamics.
|
||||
# Apply RA-BC weights: L_RA-BC = Σ(w_i * l_i) / (Σw_i + ε)
|
||||
# rabc_batch_weights is already normalized to sum to batch_size
|
||||
epsilon = 1e-6
|
||||
loss = (per_sample_loss * sample_weights).sum() / (sample_weights.sum() + epsilon)
|
||||
|
||||
# Log weighting statistics
|
||||
if output_dict is None:
|
||||
output_dict = {}
|
||||
for key, value in weight_stats.items():
|
||||
output_dict[f"sample_weight_{key}"] = value
|
||||
loss = (per_sample_loss * rabc_batch_weights).sum() / (rabc_batch_weights.sum() + epsilon)
|
||||
# Log raw mean weight (before normalization) - this is the meaningful metric
|
||||
output_dict["rabc_mean_weight"] = rabc_batch_stats["raw_mean_weight"]
|
||||
output_dict["rabc_num_zero_weight"] = rabc_batch_stats["num_zero_weight"]
|
||||
output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"]
|
||||
else:
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
@@ -193,8 +188,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
|
||||
force_cpu = cfg.trainable_config.device == "cpu"
|
||||
# Force the device to be CPU when policy.device is set to CPU.
|
||||
force_cpu = cfg.policy.device == "cpu"
|
||||
accelerator = Accelerator(
|
||||
step_scheduler_with_optimizer=False,
|
||||
kwargs_handlers=[ddp_kwargs],
|
||||
@@ -250,49 +245,26 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
logging.info("Creating env")
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
if is_main_process:
|
||||
logging.info("Creating reward model")
|
||||
from lerobot.rewards import make_reward_model
|
||||
|
||||
policy = make_reward_model(
|
||||
cfg=cfg.reward_model,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
dataset_meta=dataset.meta,
|
||||
)
|
||||
if not policy.is_trainable:
|
||||
raise ValueError(
|
||||
f"Reward model '{policy.name}' is zero-shot and cannot be trained via lerobot-train. "
|
||||
"Use it directly for inference via compute_reward() (e.g. offline precompute)."
|
||||
)
|
||||
else:
|
||||
if is_main_process:
|
||||
logging.info("Creating policy")
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
rename_map=cfg.rename_map,
|
||||
)
|
||||
if is_main_process:
|
||||
logging.info("Creating policy")
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
rename_map=cfg.rename_map,
|
||||
)
|
||||
|
||||
if cfg.peft is not None:
|
||||
if cfg.is_reward_model_training:
|
||||
raise ValueError("PEFT is only supported for policy training. ")
|
||||
from peft import PeftModel
|
||||
logging.info("Using PEFT! Wrapping model.")
|
||||
# Convert CLI peft config to dict for overrides
|
||||
peft_cli_overrides = dataclasses.asdict(cfg.peft)
|
||||
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
|
||||
|
||||
if isinstance(policy, PeftModel):
|
||||
logging.info("PEFT adapter already loaded from checkpoint, skipping wrap_with_peft.")
|
||||
else:
|
||||
logging.info("Using PEFT! Wrapping model.")
|
||||
peft_cli_overrides = dataclasses.asdict(cfg.peft)
|
||||
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
|
||||
|
||||
# Wait for all processes to finish model creation before continuing
|
||||
# Wait for all processes to finish policy creation before continuing
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
active_cfg = cfg.trainable_config
|
||||
processor_pretrained_path = active_cfg.pretrained_path
|
||||
processor_pretrained_path = cfg.policy.pretrained_path
|
||||
if (
|
||||
getattr(active_cfg, "use_relative_actions", False)
|
||||
getattr(cfg.policy, "use_relative_actions", False)
|
||||
and processor_pretrained_path is not None
|
||||
and not cfg.resume
|
||||
):
|
||||
@@ -302,15 +274,18 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
)
|
||||
processor_pretrained_path = None
|
||||
|
||||
# Create processors - only provide dataset_stats if not resuming from saved processors
|
||||
processor_kwargs = {}
|
||||
postprocessor_kwargs = {}
|
||||
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
|
||||
# Only provide dataset_stats when not resuming from saved processor state
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
# For SARM, always provide dataset_meta for progress normalization
|
||||
if cfg.policy.type == "sarm":
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
if not cfg.is_reward_model_training and processor_pretrained_path is not None:
|
||||
if processor_pretrained_path is not None:
|
||||
processor_kwargs["preprocessor_overrides"] = {
|
||||
"device_processor": {"device": device.type},
|
||||
"normalizer_processor": {
|
||||
@@ -330,36 +305,38 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
},
|
||||
}
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
preprocessor, postprocessor = make_reward_pre_post_processors(
|
||||
cfg.reward_model,
|
||||
**processor_kwargs,
|
||||
)
|
||||
else:
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
|
||||
# Create sample weighter if configured (e.g., for RA-BC training)
|
||||
sample_weighter = None
|
||||
if cfg.sample_weighting is not None:
|
||||
from lerobot.utils.sample_weighting import make_sample_weighter
|
||||
# Load precomputed SARM progress for RA-BC if enabled
|
||||
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
|
||||
rabc_weights = None
|
||||
if cfg.use_rabc:
|
||||
from lerobot.utils.rabc import RABCWeights
|
||||
|
||||
if is_main_process:
|
||||
logging.info(f"Creating sample weighter: {cfg.sample_weighting.type}")
|
||||
sample_weighter = make_sample_weighter(
|
||||
cfg.sample_weighting,
|
||||
policy,
|
||||
device,
|
||||
dataset_root=cfg.dataset.root,
|
||||
dataset_repo_id=cfg.dataset.repo_id,
|
||||
# Get chunk_size from policy config
|
||||
chunk_size = getattr(policy.config, "chunk_size", None)
|
||||
if chunk_size is None:
|
||||
raise ValueError("Chunk size is not found in policy config")
|
||||
|
||||
head_mode = getattr(cfg, "rabc_head_mode", "sparse")
|
||||
logging.info(f"Loading SARM progress for RA-BC from {cfg.rabc_progress_path}")
|
||||
logging.info(f"Using chunk_size={chunk_size} from policy config, head_mode={head_mode}")
|
||||
rabc_weights = RABCWeights(
|
||||
progress_path=cfg.rabc_progress_path,
|
||||
chunk_size=chunk_size,
|
||||
head_mode=head_mode,
|
||||
kappa=getattr(cfg, "rabc_kappa", 0.01),
|
||||
epsilon=getattr(cfg, "rabc_epsilon", 1e-6),
|
||||
device=device,
|
||||
)
|
||||
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
@@ -388,13 +365,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
@@ -471,7 +448,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
accelerator=accelerator,
|
||||
lr_scheduler=lr_scheduler,
|
||||
sample_weighter=sample_weighter,
|
||||
rabc_weights_provider=rabc_weights,
|
||||
)
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
@@ -490,10 +467,16 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
# Log sample weighting statistics if enabled
|
||||
if sample_weighter is not None:
|
||||
weighter_stats = sample_weighter.get_stats()
|
||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||
# Log RA-BC statistics if enabled
|
||||
if rabc_weights is not None:
|
||||
rabc_stats = rabc_weights.get_stats()
|
||||
wandb_log_dict.update(
|
||||
{
|
||||
"rabc_delta_mean": rabc_stats["delta_mean"],
|
||||
"rabc_delta_std": rabc_stats["delta_std"],
|
||||
"rabc_num_frames": rabc_stats["num_frames"],
|
||||
}
|
||||
)
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
@@ -575,15 +558,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if is_main_process:
|
||||
logging.info("End of training")
|
||||
|
||||
if getattr(active_cfg, "push_to_hub", False):
|
||||
unwrapped_model = accelerator.unwrap_model(policy)
|
||||
# PEFT only applies when training a policy — reward models use the plain path.
|
||||
if not cfg.is_reward_model_training and cfg.policy.use_peft:
|
||||
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
|
||||
if cfg.policy.push_to_hub:
|
||||
unwrapped_policy = accelerator.unwrap_model(policy)
|
||||
if cfg.policy.use_peft:
|
||||
unwrapped_policy.push_model_to_hub(cfg, peft_model=unwrapped_policy)
|
||||
else:
|
||||
unwrapped_model.push_model_to_hub(cfg)
|
||||
preprocessor.push_to_hub(active_cfg.repo_id)
|
||||
postprocessor.push_to_hub(active_cfg.repo_id)
|
||||
unwrapped_policy.push_model_to_hub(cfg)
|
||||
preprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
postprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
|
||||
# Properly clean up the distributed process group
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -49,7 +49,6 @@ class BiOpenArmLeader(Teleoperator):
|
||||
can_data_bitrate=config.left_arm_config.can_data_bitrate,
|
||||
motor_config=config.left_arm_config.motor_config,
|
||||
manual_control=config.left_arm_config.manual_control,
|
||||
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
|
||||
position_kd=config.left_arm_config.position_kd,
|
||||
position_kp=config.left_arm_config.position_kp,
|
||||
)
|
||||
@@ -64,7 +63,6 @@ class BiOpenArmLeader(Teleoperator):
|
||||
can_data_bitrate=config.right_arm_config.can_data_bitrate,
|
||||
motor_config=config.right_arm_config.motor_config,
|
||||
manual_control=config.right_arm_config.manual_control,
|
||||
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
|
||||
position_kd=config.right_arm_config.position_kd,
|
||||
position_kp=config.right_arm_config.position_kp,
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user