mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
Compare commits
51 Commits
feat/depth
...
feat/leade
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5c444302c1 | ||
|
|
c868f874f1 | ||
|
|
e228f0880f | ||
|
|
fe2c32d9e7 | ||
|
|
6ed80f5a59 | ||
|
|
ef6b3b5b0f | ||
|
|
e298474bf3 | ||
|
|
577f14337a | ||
|
|
47be90f040 | ||
|
|
47dd65347e | ||
|
|
fd5a788120 | ||
|
|
9ce9e01469 | ||
|
|
21c16a27f0 | ||
|
|
b3164543f4 | ||
|
|
f3993cbbb1 | ||
|
|
c278cfa026 | ||
|
|
77d18659b1 | ||
|
|
6347edefb1 | ||
|
|
eda47eca18 | ||
|
|
a64e6f5070 | ||
|
|
3def86c2c3 | ||
|
|
356a64d8c4 | ||
|
|
38b88c414c | ||
|
|
1ed32210c7 | ||
|
|
06255996ea | ||
|
|
8065bf15c7 | ||
|
|
8191d2d87f | ||
|
|
6b93f31238 | ||
|
|
a4c0c9e358 | ||
|
|
a84b0e8132 | ||
|
|
2487a6ee6d | ||
|
|
72fb0faf62 | ||
|
|
2c97cb23c8 | ||
|
|
87d4c9879c | ||
|
|
e4c1a8472d | ||
|
|
d7e25c8326 | ||
|
|
a5ad273b62 | ||
|
|
23bece96a4 | ||
|
|
7a1c9e74c3 | ||
|
|
c88cf979f1 | ||
|
|
79a9ebdaa6 | ||
|
|
da6e36fd03 | ||
|
|
64dc08cb7b | ||
|
|
e6d282108d | ||
|
|
a8838c081b | ||
|
|
ee0814ef60 | ||
|
|
7b0bdf2a98 | ||
|
|
9422dc98c2 | ||
|
|
11a0b0174f | ||
|
|
036b310a97 | ||
|
|
e022207c75 |
@@ -33,7 +33,7 @@ jobs:
|
||||
github.event.workflow_run.event == 'pull_request' &&
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.repository == 'huggingface/lerobot'
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@9ad2de8582b56c017cb530c1165116d40433f1c6 # main
|
||||
with:
|
||||
package_name: lerobot
|
||||
secrets:
|
||||
|
||||
4
.github/workflows/documentation.yml
vendored
4
.github/workflows/documentation.yml
vendored
@@ -55,7 +55,7 @@ jobs:
|
||||
github.repository == 'huggingface/lerobot'
|
||||
permissions:
|
||||
contents: read
|
||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
||||
with:
|
||||
commit_sha: ${{ github.sha }}
|
||||
package: lerobot
|
||||
@@ -78,7 +78,7 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
|
||||
with:
|
||||
commit_sha: ${{ github.event.pull_request.head.sha }}
|
||||
pr_number: ${{ github.event.number }}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
include src/lerobot/templates/lerobot_modelcard_template.md
|
||||
include src/lerobot/templates/lerobot_rewardmodel_modelcard_template.md
|
||||
include src/lerobot/datasets/card_template.md
|
||||
include src/lerobot/envs/metaworld_config.json
|
||||
|
||||
@@ -39,7 +39,6 @@ from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.video_utils import (
|
||||
VideoEncoderConfig,
|
||||
decode_video_frames,
|
||||
encode_video_frames,
|
||||
)
|
||||
@@ -252,13 +251,10 @@ def benchmark_encoding_decoding(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
camera_encoder_config=VideoEncoderConfig(
|
||||
vcodec=encoding_cfg["vcodec"],
|
||||
pix_fmt=encoding_cfg["pix_fmt"],
|
||||
g=encoding_cfg.get("g"),
|
||||
crf=encoding_cfg.get("crf"),
|
||||
preset=encoding_cfg.get("preset"),
|
||||
),
|
||||
vcodec=encoding_cfg["vcodec"],
|
||||
pix_fmt=encoding_cfg["pix_fmt"],
|
||||
g=encoding_cfg.get("g"),
|
||||
crf=encoding_cfg.get("crf"),
|
||||
# fast_decode=encoding_cfg.get("fastdecode"),
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -90,6 +90,6 @@ lerobot-record \
|
||||
--dataset.single_task="Your task description" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
--policy.path=${HF_USER}/act_policy
|
||||
```
|
||||
|
||||
@@ -194,7 +194,7 @@ lerobot-record \
|
||||
--dataset.single_task="Navigate around obstacles" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ lerobot-record \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
--policy.path=<user>/groot-bimanual \ # your trained model
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10
|
||||
|
||||
@@ -820,10 +820,10 @@ The LeRobot system uses a distributed actor-learner architecture for training. T
|
||||
|
||||
Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`.
|
||||
|
||||
1. Configure the policy settings (`type="sac"`, `device`, etc.)
|
||||
1. Configure the policy settings (`type="gaussian_actor"`, `device`, etc.)
|
||||
2. Set `dataset` to your cropped dataset
|
||||
3. Configure environment settings with crop parameters
|
||||
4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/sac/configuration_sac.py#L79).
|
||||
4. Check the other parameters related to the Gaussian Actor in [configuration_gaussian_actor.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/gaussian_actor/configuration_gaussian_actor.py#L79).
|
||||
5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task.
|
||||
|
||||
**Starting the Learner**
|
||||
@@ -926,7 +926,7 @@ The ideal behaviour is that your intervention rate should drop gradually during
|
||||
|
||||
Some configuration values have a disproportionate impact on training stability and speed:
|
||||
|
||||
- **`temperature_init`** (`policy.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
|
||||
- **`temperature_init`** (`algorithm.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning.
|
||||
- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency.
|
||||
- **`storage_device`** (`policy.storage_device`) – device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second.
|
||||
|
||||
|
||||
@@ -232,7 +232,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
@@ -278,6 +278,6 @@ lerobot-record \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
@@ -193,7 +193,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
</hfoption>
|
||||
|
||||
@@ -43,7 +43,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
|
||||
|
||||
@@ -161,7 +161,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
@@ -203,7 +203,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ This ensures identical task states map to consistent progress values, even acros
|
||||
|
||||
## Inputs and Targets (What the new code expects)
|
||||
|
||||
SARM is trained through its processor (`src/lerobot/rewards/sarm/processor_sarm.py`), which:
|
||||
SARM is trained through its processor (`src/lerobot/policies/sarm/processor_sarm.py`), which:
|
||||
|
||||
- **Encodes** images and task text with CLIP (ViT-B/32) into `video_features` and `text_features`
|
||||
- **Pads/truncates** robot state into `state_features` (up to `max_state_dim`)
|
||||
@@ -347,7 +347,7 @@ Use `compute_rabc_weights.py` with `--visualize-only` to visualize model predict
|
||||
<hfoption id="single_stage">
|
||||
|
||||
```bash
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
@@ -360,7 +360,7 @@ python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
<hfoption id="dense_only">
|
||||
|
||||
```bash
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
@@ -373,7 +373,7 @@ python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
<hfoption id="dual">
|
||||
|
||||
```bash
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--visualize-only \
|
||||
@@ -429,7 +429,7 @@ The weighting follows **Equations 8-9** from the paper:
|
||||
First, run the SARM model on all frames in your dataset to compute progress values:
|
||||
|
||||
```bash
|
||||
python -m lerobot.rewards.sarm.compute_rabc_weights \
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \
|
||||
--dataset-repo-id your-username/your-dataset \
|
||||
--reward-model-path your-username/sarm-model \
|
||||
--head-mode sparse \
|
||||
@@ -465,15 +465,15 @@ This script:
|
||||
|
||||
### Step 5b: Train Policy with RA-BC
|
||||
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`) if not explicitly provided. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--sample_weighting.type=rabc \
|
||||
--sample_weighting.head_mode=sparse \
|
||||
--sample_weighting.kappa=0.01 \
|
||||
--use_rabc=true \
|
||||
--rabc_head_mode=sparse \
|
||||
--rabc_kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
--steps=40000
|
||||
@@ -488,13 +488,12 @@ The training script automatically:
|
||||
|
||||
**RA-BC Arguments:**
|
||||
|
||||
| Argument | Description | Default |
|
||||
| ---------------------------------- | ------------------------------------------------------ | ----------------------- |
|
||||
| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
|
||||
| `--sample_weighting.progress_path` | Path to progress parquet file | `sarm_progress.parquet` |
|
||||
| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||
| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||
| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
|
||||
| Argument | Description | Default |
|
||||
| ---------------------- | ---------------------------------------------------------- | ---------------------------------- |
|
||||
| `--use_rabc` | Enable RA-BC sample weighting | `false` |
|
||||
| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset |
|
||||
| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
|
||||
| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` |
|
||||
|
||||
### Tuning RA-BC Kappa
|
||||
|
||||
@@ -512,30 +511,30 @@ The `kappa` parameter is the threshold that determines which samples get full we
|
||||
|
||||
Monitor these WandB metrics during training:
|
||||
|
||||
| Metric | Healthy Range | Problem Indicator |
|
||||
| ----------------------------- | ------------- | ------------------------- |
|
||||
| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||
| `sample_weighting/delta_mean` | > 0 | Should be positive |
|
||||
| `sample_weighting/delta_std` | > 0 | Variance in data quality |
|
||||
| Metric | Healthy Range | Problem Indicator |
|
||||
| ------------------ | ------------- | ------------------------- |
|
||||
| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
|
||||
| `rabc_delta_mean` | > 0 | Should be positive |
|
||||
| `rabc_delta_std` | > 0 | Variance in data quality |
|
||||
|
||||
**If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
||||
**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
|
||||
|
||||
**Setting kappa based on your data:**
|
||||
|
||||
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `sample_weighting/delta_mean` and `sample_weighting/delta_std`:
|
||||
The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`:
|
||||
|
||||
```
|
||||
# If delta_mean ≈ 0.03 and delta_std ≈ 0.02:
|
||||
# Most deltas fall in range [0.01, 0.05]
|
||||
|
||||
# Option 1: Set kappa = delta_mean (medium selectivity)
|
||||
--sample_weighting.kappa=0.03
|
||||
--rabc_kappa=0.03
|
||||
|
||||
# Option 2: Set kappa = delta_mean + delta_std (high selectivity)
|
||||
--sample_weighting.kappa=0.05
|
||||
--rabc_kappa=0.05
|
||||
|
||||
# Option 3: Set kappa = delta_mean + 2*delta_std (very selective)
|
||||
--sample_weighting.kappa=0.07
|
||||
--rabc_kappa=0.07
|
||||
```
|
||||
|
||||
**When RA-BC may not help:**
|
||||
@@ -551,8 +550,8 @@ accelerate launch \
|
||||
src/lerobot/scripts/lerobot_train.py \
|
||||
--dataset.repo_id=your-username/your-dataset \
|
||||
--policy.type=pi0 \
|
||||
--sample_weighting.type=rabc \
|
||||
--sample_weighting.kappa=0.01 \
|
||||
--use_rabc=true \
|
||||
--rabc_kappa=0.01 \
|
||||
--output_dir=outputs/train/policy_rabc \
|
||||
--batch_size=32 \
|
||||
--steps=40000
|
||||
@@ -577,7 +576,7 @@ accelerate launch \
|
||||
### RA-BC
|
||||
|
||||
1. **Train SARM first**: RA-BC quality depends entirely on SARM quality
|
||||
2. **Monitor `sample_weight_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
||||
2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa))
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -108,7 +108,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/ttyACM0 \
|
||||
|
||||
@@ -14,22 +14,12 @@ This makes `save_episode()` near-instant (the video is already encoded by the ti
|
||||
|
||||
## 2. Tuning Parameters
|
||||
|
||||
All encoding parameters are grouped under `camera_encoder_config` (a `VideoEncoderConfig` dataclass), accessible from the CLI via `--dataset.camera_encoder_config.<field>`.
|
||||
|
||||
| Parameter | CLI Flag | Type | Default | Description |
|
||||
| ----------------------- | --------------------------------------------- | ------------- | ------------- | ------------------------------------------------------------------- |
|
||||
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
|
||||
| `vcodec` | `--dataset.camera_encoder_config.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
|
||||
| `pix_fmt` | `--dataset.camera_encoder_config.pix_fmt` | `str` | `"yuv420p"` | Pixel format |
|
||||
| `g` | `--dataset.camera_encoder_config.g` | `int \| None` | `2` | GOP size (keyframe interval) |
|
||||
| `crf` | `--dataset.camera_encoder_config.crf` | `int \| None` | `30` | Quality level (mapped to codec-specific parameter) |
|
||||
| `preset` | `--dataset.camera_encoder_config.preset` | `int \| None` | `12` | Speed preset (libsvtav1 only, 0 = slowest … 13 = fastest) |
|
||||
| `fast_decode` | `--dataset.camera_encoder_config.fast_decode` | `int` | `0` | Fast-decode tuning level |
|
||||
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance (global). `None` lets the codec decide |
|
||||
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
|
||||
|
||||
> [!TIP]
|
||||
> Not all parameters apply to every codec. `VideoEncoderConfig` will warn at startup if you set a parameter that your chosen codec ignores (e.g. `preset` with `h264_nvenc`).
|
||||
| Parameter | CLI Flag | Type | Default | Description |
|
||||
| ----------------------- | --------------------------------- | ------------- | ------------- | ----------------------------------------------------------------- |
|
||||
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
|
||||
| `vcodec` | `--dataset.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
|
||||
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance. `None` will leave the vcoded decide |
|
||||
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
|
||||
|
||||
## 3. Performance Considerations
|
||||
|
||||
@@ -50,7 +40,7 @@ Streaming encoding means the CPU is encoding video **during** the capture loop,
|
||||
|
||||
### `encoder_threads` Tuning
|
||||
|
||||
This parameter (`--dataset.encoder_threads`) controls how many threads each encoder instance uses internally:
|
||||
This parameter controls how many threads each encoder instance uses internally:
|
||||
|
||||
- **Higher values** (e.g., 4-5): Faster encoding, but uses more CPU cores per camera. Good for high-end systems with many cores.
|
||||
- **Lower values** (e.g., 1-2): Less CPU per camera, freeing cores for capture and visualization. Good for low-res images and capable CPUs.
|
||||
@@ -92,15 +82,15 @@ Use HW encoding when:
|
||||
|
||||
### Available HW Encoders
|
||||
|
||||
| Encoder | Platform | Hardware | CLI Value |
|
||||
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ---------------------------------------------------------- |
|
||||
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder_config.vcodec=h264_videotoolbox` |
|
||||
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder_config.vcodec=hevc_videotoolbox` |
|
||||
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder_config.vcodec=h264_nvenc` |
|
||||
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder_config.vcodec=hevc_nvenc` |
|
||||
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.camera_encoder_config.vcodec=h264_vaapi` |
|
||||
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.camera_encoder_config.vcodec=h264_qsv` |
|
||||
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.camera_encoder_config.vcodec=auto` |
|
||||
| Encoder | Platform | Hardware | CLI Value |
|
||||
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------ |
|
||||
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=h264_videotoolbox` |
|
||||
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=hevc_videotoolbox` |
|
||||
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=h264_nvenc` |
|
||||
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=hevc_nvenc` |
|
||||
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.vcodec=h264_vaapi` |
|
||||
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.vcodec=h264_qsv` |
|
||||
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.vcodec=auto` |
|
||||
|
||||
> [!NOTE]
|
||||
> In order to use the HW accelerated encoders you might need to upgrade your GPU drivers.
|
||||
@@ -110,15 +100,15 @@ Use HW encoding when:
|
||||
|
||||
## 5. Troubleshooting
|
||||
|
||||
| Symptom | Likely Cause | Fix |
|
||||
| ------------------------------------------------------------------ | -------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.camera_encoder_config.vcodec=auto`) |
|
||||
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.camera_encoder_config.vcodec=auto`). |
|
||||
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
|
||||
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
|
||||
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
|
||||
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.camera_encoder_config.vcodec=auto` |
|
||||
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
|
||||
| Symptom | Likely Cause | Fix |
|
||||
| ------------------------------------------------------------------ | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.vcodec=auto`) |
|
||||
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.vcodec=auto`). |
|
||||
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
|
||||
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
|
||||
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
|
||||
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.vcodec=auto` |
|
||||
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
|
||||
|
||||
## 6. Recommended Configurations
|
||||
|
||||
@@ -156,10 +146,10 @@ On very constrained systems, streaming encoding may compete too heavily with the
|
||||
# 2camsx 640x480x3 @30fps: Requires some tuning.
|
||||
|
||||
# Use H.264, disable streaming, consider batching encoding
|
||||
lerobot-record --dataset.camera_encoder_config.vcodec=h264 --dataset.streaming_encoding=false ...
|
||||
lerobot-record --dataset.vcodec=h264 --dataset.streaming_encoding=false ...
|
||||
```
|
||||
|
||||
## 7. Closing note
|
||||
|
||||
Performance ultimately depends on your exact setup — frames-per-second, resolution, CPU cores and load, available memory, episode length, and the encoder you choose. Always test with your target workload, be mindful about your CPU & system capabilities and tune `encoder_threads`, `encoder_queue_maxsize`, and
|
||||
`camera_encoder_config.vcodec` reasonably. That said, a common practical configuration (for many applications) is three cameras at 640×480x3 @30fps; this usually runs fine with the default streaming video encoding settings in modern systems. Always verify your recorded dataset is healthy by comparing the video duration to the CLI episode duration and confirming the row count equals FPS × CLI duration.
|
||||
`vcodec` reasonably. That said, a common practical configuration (for many applications) is three cameras at 640×480x3 @30fps; this usually runs fine with the default streaming video encoding settings in modern systems. Always verify your recorded dataset is healthy by comparing the video duration to the CLI episode duration and confirming the row count equals FPS × CLI duration.
|
||||
|
||||
@@ -117,10 +117,10 @@ lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.camera_encoder_config.vcodec libsvtav1 \
|
||||
--operation.camera_encoder_config.pix_fmt yuv420p \
|
||||
--operation.camera_encoder_config.g 2 \
|
||||
--operation.camera_encoder_config.crf 30
|
||||
--operation.vcodec libsvtav1 \
|
||||
--operation.pix_fmt yuv420p \
|
||||
--operation.g 2 \
|
||||
--operation.crf 30
|
||||
|
||||
# Convert only specific episodes
|
||||
lerobot-edit-dataset \
|
||||
@@ -147,14 +147,11 @@ lerobot-edit-dataset \
|
||||
**Parameters:**
|
||||
|
||||
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
|
||||
- `camera_encoder_config`: Video encoder settings — all sub-fields accessible via `--operation.camera_encoder_config.<field>`:
|
||||
- `vcodec`: Video codec — `h264`, `hevc`, `libsvtav1`, `auto`, or hardware codecs (default: `libsvtav1`)
|
||||
- `pix_fmt`: Pixel format — `yuv420p`, `yuv444p` (default: `yuv420p`)
|
||||
- `g`: GOP size — lower values give better quality but larger files (default: 2)
|
||||
- `crf`: Quality level — lower is better, 0 is lossless (default: 30)
|
||||
- `preset`: Speed preset, libsvtav1 only (default: 12)
|
||||
- `fast_decode`: Fast-decode tuning (default: 0)
|
||||
- `encoder_threads`: Threads per encoder instance — global setting, separate from `camera_encoder_config` (default: None)
|
||||
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
|
||||
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
|
||||
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
|
||||
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
|
||||
- `fast_decode`: Fast decode tuning option (default: 0)
|
||||
- `episode_indices`: List of specific episodes to convert (default: all episodes)
|
||||
- `num_workers`: Number of parallel workers for processing (default: 4)
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ class ComputeProgressShards(PipelineStep):
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.rewards.sarm.compute_rabc_weights import (
|
||||
from lerobot.policies.sarm.compute_rabc_weights import (
|
||||
generate_all_frame_indices,
|
||||
interpolate_progress,
|
||||
load_sarm_resources,
|
||||
|
||||
175
examples/so100_teleop/teleop.py
Normal file
175
examples/so100_teleop/teleop.py
Normal file
@@ -0,0 +1,175 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Simple SO100/SO101 leader-follower teleoperation with spacebar intervention toggle.
|
||||
|
||||
Modes:
|
||||
- Default (not intervening): follower holds its current position.
|
||||
The leader arm has torque ENABLED and mirrors the follower so there is no
|
||||
large position jump when intervention starts.
|
||||
- Intervention (SPACE pressed): leader torque DISABLED, human moves the leader
|
||||
freely, and the follower mirrors the leader joint-by-joint.
|
||||
|
||||
Usage:
|
||||
uv run python examples/so100_teleop/teleop.py
|
||||
|
||||
Controls:
|
||||
SPACE — toggle intervention on/off
|
||||
Ctrl+C — exit
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from threading import Event, Thread
|
||||
|
||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||
from lerobot.teleoperators.so_leader import SO101Leader
|
||||
from lerobot.teleoperators.so_leader.config_so_leader import SOLeaderTeleopConfig
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── pynput keyboard listener ─────────────────────────────────────────────────
|
||||
PYNPUT_AVAILABLE = True
|
||||
try:
|
||||
if "DISPLAY" not in os.environ and "linux" in sys.platform:
|
||||
raise ImportError("No DISPLAY set, pynput skipped.")
|
||||
from pynput import keyboard as pynput_keyboard
|
||||
except Exception:
|
||||
pynput_keyboard = None
|
||||
PYNPUT_AVAILABLE = False
|
||||
|
||||
# ── Configure ports ──────────────────────────────────────────────────────────
|
||||
FOLLOWER_PORT = "/dev/ttyUSB0" # ← change to your follower port
|
||||
LEADER_PORT = "/dev/ttyUSB1" # ← change to your leader port
|
||||
FPS = 30
|
||||
|
||||
|
||||
def hold_position(robot) -> dict:
|
||||
"""Read current joint positions and write them back as the goal.
|
||||
|
||||
This prevents the motors from snapping to a stale Goal_Position register
|
||||
value (which can happen when torque is re-enabled after calibration).
|
||||
Returns the current position dict for reuse.
|
||||
"""
|
||||
current = robot.bus.sync_read("Present_Position")
|
||||
robot.bus.sync_write("Goal_Position", current)
|
||||
return {f"{motor}.pos": val for motor, val in current.items()}
|
||||
|
||||
|
||||
# ── Connect ───────────────────────────────────────────────────────────────────
|
||||
follower_config = SO101FollowerConfig(
|
||||
port=FOLLOWER_PORT,
|
||||
id="follower_arm",
|
||||
use_degrees=True,
|
||||
)
|
||||
leader_config = SOLeaderTeleopConfig(
|
||||
port=LEADER_PORT,
|
||||
id="leader_arm",
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
follower = SO101Follower(follower_config)
|
||||
leader = SO101Leader(leader_config)
|
||||
|
||||
follower.connect()
|
||||
leader.connect()
|
||||
|
||||
# ── CRITICAL: hold both arms at their current position before doing anything ─
|
||||
# configure() enables follower torque, and the Goal_Position register may contain
|
||||
# a stale value from a previous session. Writing current→goal prevents sudden motion.
|
||||
follower_current = hold_position(follower)
|
||||
leader_current = hold_position(leader) # leader torque is still off here, but sets the register
|
||||
|
||||
# ── Intervention state + keyboard listener ───────────────────────────────────
|
||||
is_intervening = False
|
||||
stop_event = Event()
|
||||
|
||||
|
||||
def _start_keyboard_listener():
|
||||
if not PYNPUT_AVAILABLE:
|
||||
logger.warning("pynput not available — spacebar toggle disabled.")
|
||||
return None
|
||||
|
||||
def on_press(key):
|
||||
global is_intervening
|
||||
if key == pynput_keyboard.Key.space:
|
||||
is_intervening = not is_intervening
|
||||
state = "INTERVENTION (leader → follower)" if is_intervening else "IDLE (follower holds)"
|
||||
print(f"\n[SPACE] {state}\n")
|
||||
|
||||
def listen():
|
||||
with pynput_keyboard.Listener(on_press=on_press) as listener:
|
||||
while not stop_event.is_set():
|
||||
time.sleep(0.05)
|
||||
listener.stop()
|
||||
|
||||
t = Thread(target=listen, daemon=True)
|
||||
t.start()
|
||||
return t
|
||||
|
||||
|
||||
kbd_thread = _start_keyboard_listener()
|
||||
|
||||
# Enable leader torque AFTER writing its goal to current position, so it holds in place.
|
||||
leader.bus.sync_write("Torque_Enable", 1)
|
||||
leader_torque_on = True
|
||||
|
||||
print("\nTeleoperation ready.")
|
||||
print(" SPACE → toggle intervention (leader controls follower)")
|
||||
print(" Ctrl+C → exit\n")
|
||||
|
||||
try:
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
if is_intervening:
|
||||
# ── Intervention: leader torque OFF, follower mirrors leader ──────
|
||||
if leader_torque_on:
|
||||
leader.bus.sync_write("Torque_Enable", 0)
|
||||
leader_torque_on = False
|
||||
|
||||
leader_action = leader.get_action() # reads present leader joints
|
||||
follower.send_action(leader_action) # follower tracks leader
|
||||
|
||||
else:
|
||||
# ── Idle: leader torque ON, leader mirrors follower, follower holds
|
||||
if not leader_torque_on:
|
||||
# Before re-enabling torque, set the leader's goal to its current
|
||||
# position so it doesn't snap to the follower position suddenly.
|
||||
hold_position(leader)
|
||||
leader.bus.sync_write("Torque_Enable", 1)
|
||||
leader_torque_on = True
|
||||
|
||||
follower_obs = follower.get_observation()
|
||||
# Command leader to match follower (so next intervention has no jump)
|
||||
goal_pos = {motor: follower_obs[f"{motor}.pos"] for motor in leader.bus.motors}
|
||||
leader.bus.sync_write("Goal_Position", goal_pos)
|
||||
# Follower holds — no send_action call
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\nExiting...")
|
||||
finally:
|
||||
stop_event.set()
|
||||
leader.bus.sync_write("Torque_Enable", 0)
|
||||
follower.disconnect()
|
||||
leader.disconnect()
|
||||
365
examples/so100_to_so100_EE_deltas/teleoperate.py
Normal file
365
examples/so100_to_so100_EE_deltas/teleoperate.py
Normal file
@@ -0,0 +1,365 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import (
|
||||
ProcessorStepRegistry,
|
||||
RobotAction,
|
||||
RobotActionProcessorStep,
|
||||
RobotObservation,
|
||||
RobotProcessorPipeline,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
create_transition,
|
||||
identity_transition,
|
||||
)
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
EEReferenceAndDelta,
|
||||
GripperVelocityToJoint,
|
||||
InverseKinematicsRLStep,
|
||||
)
|
||||
from lerobot.robots.so101_follower.config_so101_follower import SO101FollowerConfig
|
||||
from lerobot.robots.so101_follower.so101_follower import SO101Follower
|
||||
from lerobot.teleoperators.so101_leader.config_so101_leader import SO101LeaderConfig
|
||||
from lerobot.teleoperators.so101_leader.so101_leader import SO101Leader
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.rotation import Rotation
|
||||
|
||||
|
||||
def reset_follower_position(robot_arm: Robot, target_position: np.ndarray) -> None:
|
||||
"""Reset robot arm to target position using smooth trajectory."""
|
||||
current_position_dict = robot_arm.bus.sync_read("Present_Position")
|
||||
current_position = np.array(
|
||||
[current_position_dict[name] for name in current_position_dict],
|
||||
dtype=np.float32,
|
||||
)
|
||||
trajectory = torch.from_numpy(
|
||||
np.linspace(current_position, target_position, 50)
|
||||
) # NOTE: 30 is just an arbitrary number
|
||||
for pose in trajectory:
|
||||
action_dict = dict(zip(current_position_dict, pose, strict=False))
|
||||
robot_arm.bus.sync_write("Goal_Position", action_dict)
|
||||
precise_sleep(0.015)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LogRobotAction(RobotActionProcessorStep):
|
||||
def action(self, action: RobotAction) -> RobotAction:
|
||||
print(f"Robot action: {action}")
|
||||
return action
|
||||
|
||||
def transform_features(self, features):
|
||||
# features[PipelineFeatureType.ACTION][ACTION] = PolicyFeature(
|
||||
# type=FeatureType.ACTION, shape=(len(self.motor_names),)
|
||||
# )
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee_target_action")
|
||||
@dataclass
|
||||
class ForwardKinematicsJointsToEETargetAction(RobotActionProcessorStep):
|
||||
"""
|
||||
Computes the end-effector pose from joint positions using forward kinematics (FK).
|
||||
|
||||
This step is typically used to add the robot's Cartesian pose to the observation space,
|
||||
which can be useful for visualization or as an input to a policy.
|
||||
|
||||
Attributes:
|
||||
kinematics: The robot's kinematic model.
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
motor_names: list[str]
|
||||
end_effector_step_sizes: dict
|
||||
max_gripper_pos: float
|
||||
use_ik_solution: bool = False
|
||||
|
||||
def action(self, action: RobotAction) -> RobotAction:
|
||||
# return compute_forward_kinematics_joints_to_ee(action, self.kinematics, self.motor_names)
|
||||
teleop_action = action
|
||||
raw_joint_pos = self.transition.get(TransitionKey.OBSERVATION)
|
||||
|
||||
leader_pos = np.array([teleop_action[f"{motor}.pos"] for motor in self.motor_names])
|
||||
|
||||
leader_ee = self.kinematics.forward_kinematics(leader_pos)
|
||||
|
||||
if self.use_ik_solution and "IK_solution" in self.transition.get(TransitionKey.COMPLEMENTARY_DATA):
|
||||
follower_pos = transition.get(TransitionKey.COMPLEMENTARY_DATA)["IK_solution"]
|
||||
else:
|
||||
follower_pos = np.array([raw_joint_pos[f"{motor}.pos"] for motor in self.motor_names])
|
||||
|
||||
follower_ee = self.kinematics.forward_kinematics(follower_pos)
|
||||
|
||||
follower_ee_pos = follower_ee[:3, 3]
|
||||
follower_ee_rvec = Rotation.from_matrix(follower_ee[:3, :3]).as_rotvec()
|
||||
# follower_gripper_pos = raw_joint_pos["gripper.pos"]
|
||||
follower_gripper_pos = follower_pos[-1] # assuming gripper is the last motor
|
||||
|
||||
leader_ee_pos = leader_ee[:3, 3]
|
||||
leader_ee_rvec = Rotation.from_matrix(leader_ee[:3, :3]).as_rotvec()
|
||||
leader_gripper_pos = np.clip(
|
||||
teleop_action["gripper.pos"], -self.max_gripper_pos, self.max_gripper_pos
|
||||
)
|
||||
|
||||
print("f pos:", follower_ee_pos)
|
||||
print("l pos:", leader_ee_pos)
|
||||
|
||||
print("f rvec:", follower_ee_rvec)
|
||||
print("l rvec:", leader_ee_rvec)
|
||||
|
||||
# follower_ee_pos = follower_ee[:3, 3]
|
||||
# follower_ee_rvec = Rotation.from_matrix(follower_ee[:3, :3]).as_rotvec()
|
||||
|
||||
delta_pos = leader_ee_pos - follower_ee_pos
|
||||
|
||||
# For rotation: compute relative rotation from follower to leader
|
||||
# R_leader = R_follower * R_delta => R_delta = R_follower^T * R_leader
|
||||
r_delta = follower_ee[:3, :3].T @ leader_ee[:3, :3]
|
||||
delta_rvec = Rotation.from_matrix(r_delta).as_rotvec()
|
||||
delta_gripper = leader_gripper_pos - follower_gripper_pos
|
||||
|
||||
desired = np.eye(4, dtype=float)
|
||||
desired[:3, :3] = follower_ee[:3, :3] @ r_delta
|
||||
desired[:3, 3] = follower_ee[:3, 3] + delta_pos
|
||||
|
||||
pos = desired[:3, 3]
|
||||
tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
|
||||
|
||||
assert np.allclose(pos, leader_ee_pos), "Position delta computation error"
|
||||
assert np.allclose(tw, leader_ee_rvec), "Orientation delta computation error"
|
||||
assert np.isclose(follower_gripper_pos + delta_gripper, leader_gripper_pos), (
|
||||
"Gripper delta computation error"
|
||||
)
|
||||
|
||||
# Normalize the action to the range [-1, 1]
|
||||
delta_pos = delta_pos / np.array(
|
||||
[
|
||||
self.end_effector_step_sizes["x"],
|
||||
self.end_effector_step_sizes["y"],
|
||||
self.end_effector_step_sizes["z"],
|
||||
]
|
||||
)
|
||||
delta_rvec = delta_rvec / np.array(
|
||||
[
|
||||
self.end_effector_step_sizes["wx"],
|
||||
self.end_effector_step_sizes["wy"],
|
||||
self.end_effector_step_sizes["wz"],
|
||||
]
|
||||
)
|
||||
|
||||
# Check if any of the normalized deltas exceed 1.0
|
||||
|
||||
max_normalized_pos = max(
|
||||
abs(delta_pos[0]),
|
||||
abs(delta_pos[1]),
|
||||
abs(delta_pos[2]),
|
||||
)
|
||||
|
||||
max_normalized_rot = max(
|
||||
abs(delta_rvec[0]),
|
||||
abs(delta_rvec[1]),
|
||||
abs(delta_rvec[2]),
|
||||
)
|
||||
|
||||
# Use the same scaling factor for both position and rotation
|
||||
max_normalized = max(max_normalized_pos, max_normalized_rot)
|
||||
if max_normalized > 1.0:
|
||||
print(f"Warning: EE delta too large, scaling. Max normalized delta: {max_normalized_pos}")
|
||||
print(f"Original delta_pos: {delta_pos}, delta_rvec: {delta_rvec}")
|
||||
# Scale proportionally
|
||||
delta_pos = delta_pos / max_normalized
|
||||
delta_rvec = delta_rvec / max_normalized
|
||||
|
||||
new_action = {}
|
||||
new_action["enabled"] = True
|
||||
new_action["target_x"] = float(delta_pos[0])
|
||||
new_action["target_y"] = float(delta_pos[1])
|
||||
new_action["target_z"] = float(delta_pos[2])
|
||||
new_action["target_wx"] = float(delta_rvec[0])
|
||||
new_action["target_wy"] = float(delta_rvec[1])
|
||||
new_action["target_wz"] = float(delta_rvec[2])
|
||||
new_action["gripper_vel"] = float(
|
||||
np.clip(delta_gripper, -self.max_gripper_pos, self.max_gripper_pos) / self.max_gripper_pos
|
||||
)
|
||||
return new_action
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# TODO: implement feature transformation
|
||||
return features
|
||||
|
||||
|
||||
FPS = 20
|
||||
|
||||
# Initialize the robot and teleoperator config
|
||||
follower_config = SO101FollowerConfig(port="/dev/usb_follower_arm_a", id="follower_arm_a", use_degrees=True)
|
||||
leader_config = SO101LeaderConfig(port="/dev/usb_leader_arm_a", id="leader_arm_a", use_degrees=True)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
follower = SO101Follower(follower_config)
|
||||
leader = SO101Leader(leader_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="../SO-ARM100/Simulation/SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="../SO-ARM100/Simulation/SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
|
||||
end_effector_step_sizes = {
|
||||
"x": 0.004,
|
||||
"y": 0.004,
|
||||
"z": 0.004,
|
||||
"wx": 5 * np.pi / 180,
|
||||
"wy": 5 * np.pi / 180,
|
||||
"wz": 5 * np.pi / 180,
|
||||
}
|
||||
|
||||
|
||||
# Build pipeline to convert teleop joints to EE action
|
||||
leader_to_ee = RobotProcessorPipeline[RobotAction, RobotAction](
|
||||
steps=[
|
||||
LogRobotAction(),
|
||||
ForwardKinematicsJointsToEETargetAction(
|
||||
kinematics=leader_kinematics_solver,
|
||||
motor_names=list(leader.bus.motors.keys()),
|
||||
end_effector_step_sizes=end_effector_step_sizes,
|
||||
max_gripper_pos=30.0,
|
||||
use_ik_solution=True,
|
||||
),
|
||||
LogRobotAction(),
|
||||
],
|
||||
to_transition=identity_transition,
|
||||
to_output=identity_transition,
|
||||
)
|
||||
|
||||
# build pipeline to convert EE action to robot joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
LogRobotAction(),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=follower_kinematics_solver,
|
||||
# end_effector_step_sizes={"x": 0.006, "y": 0.01, "z": 0.005},
|
||||
end_effector_step_sizes=end_effector_step_sizes,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
use_latched_reference=False,
|
||||
use_ik_solution=True,
|
||||
),
|
||||
LogRobotAction(),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={
|
||||
"min": [-0.05, -0.55, -0.0075],
|
||||
"max": [0.55, 0.55, 0.55],
|
||||
},
|
||||
# end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.05,
|
||||
),
|
||||
LogRobotAction(),
|
||||
GripperVelocityToJoint(
|
||||
clip_max=30.0,
|
||||
speed_factor=0.2,
|
||||
discrete_gripper=False,
|
||||
scale_velocity=True,
|
||||
use_ik_solution=True,
|
||||
),
|
||||
LogRobotAction(),
|
||||
InverseKinematicsRLStep(
|
||||
kinematics=follower_kinematics_solver,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
initial_guess_current_joints=False,
|
||||
),
|
||||
LogRobotAction(),
|
||||
],
|
||||
to_transition=identity_transition,
|
||||
to_output=identity_transition,
|
||||
)
|
||||
|
||||
# Connect to the robot and teleoperator
|
||||
follower.connect()
|
||||
leader.connect()
|
||||
|
||||
reset_pose = [0.0, 10, 20, 60.00, 90.00, 10.00]
|
||||
|
||||
start_time = time.perf_counter()
|
||||
reset_follower_position(follower, np.array(reset_pose))
|
||||
reset_follower_position(leader, np.array(reset_pose))
|
||||
precise_sleep(5.0 - (time.perf_counter() - start_time))
|
||||
# time.sleep(10)
|
||||
leader.bus.sync_write("Torque_Enable", 0)
|
||||
|
||||
# Init rerun viewer
|
||||
# init_rerun(session_name="so100_so100_EE_teleop")
|
||||
|
||||
transition = None
|
||||
|
||||
print("Starting teleop loop...")
|
||||
while True:
|
||||
print("New loop iteration")
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = follower.get_observation()
|
||||
|
||||
# Get teleop observation
|
||||
leader_joints_obs = leader.get_action()
|
||||
|
||||
# teleop joints -> teleop EE action
|
||||
if transition is None:
|
||||
transition = create_transition(action=leader_joints_obs, observation=robot_obs)
|
||||
else:
|
||||
transition = create_transition(
|
||||
action=leader_joints_obs,
|
||||
observation=robot_obs,
|
||||
complementary_data=transition.get(TransitionKey.COMPLEMENTARY_DATA),
|
||||
)
|
||||
|
||||
transition = leader_to_ee(transition)
|
||||
leader_ee_act = transition[TransitionKey.ACTION]
|
||||
|
||||
# teleop EE -> robot joints
|
||||
transition = create_transition(
|
||||
action=leader_ee_act,
|
||||
observation=robot_obs,
|
||||
complementary_data=transition.get(TransitionKey.COMPLEMENTARY_DATA),
|
||||
)
|
||||
transition = ee_to_follower_joints(transition)
|
||||
follower_joints_act = transition[TransitionKey.ACTION]
|
||||
|
||||
# Send action to robot
|
||||
_ = follower.send_action(follower_joints_act)
|
||||
|
||||
# Visualize
|
||||
# log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
@@ -4,13 +4,13 @@ from pathlib import Path
|
||||
from queue import Empty, Full
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.policies import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||
from lerobot.policies import GaussianActorConfig
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
@@ -28,7 +28,7 @@ def run_learner(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_learner: SACPolicy,
|
||||
policy_learner: GaussianActorPolicy,
|
||||
online_buffer: ReplayBuffer,
|
||||
offline_buffer: ReplayBuffer,
|
||||
lr: float = 3e-4,
|
||||
@@ -40,8 +40,9 @@ def run_learner(
|
||||
policy_learner.train()
|
||||
policy_learner.to(device)
|
||||
|
||||
# Create Adam optimizer from scratch - simple and clean
|
||||
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config)
|
||||
algorithm = SACAlgorithm(policy=policy_learner, config=algo_config)
|
||||
algorithm.make_optimizers_and_scheduler()
|
||||
|
||||
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
|
||||
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
|
||||
@@ -83,24 +84,26 @@ def run_learner(
|
||||
else:
|
||||
batch[key] = online_batch[key]
|
||||
|
||||
loss, _ = policy_learner.forward(batch)
|
||||
def batch_iter(b=batch):
|
||||
while True:
|
||||
yield b
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
stats = algorithm.update(batch_iter())
|
||||
training_step += 1
|
||||
|
||||
if training_step % LOG_EVERY == 0:
|
||||
log_dict = stats.to_log_dict()
|
||||
print(
|
||||
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
|
||||
f"[LEARNER] Training step {training_step}, "
|
||||
f"critic_loss: {log_dict.get('critic', 'N/A'):.4f}, "
|
||||
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
|
||||
)
|
||||
|
||||
# Send updated parameters to actor every 10 training steps
|
||||
if training_step % SEND_EVERY == 0:
|
||||
try:
|
||||
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
|
||||
parameters_queue.put_nowait(state_dict)
|
||||
weights = algorithm.get_weights()
|
||||
parameters_queue.put_nowait(weights)
|
||||
print("[LEARNER] Sent updated parameters to actor")
|
||||
except Full:
|
||||
# Missing write due to queue not being consumed (should happen rarely)
|
||||
@@ -113,7 +116,7 @@ def run_actor(
|
||||
transitions_queue: mp.Queue,
|
||||
parameters_queue: mp.Queue,
|
||||
shutdown_event: mp.Event,
|
||||
policy_actor: SACPolicy,
|
||||
policy_actor: GaussianActorPolicy,
|
||||
reward_classifier: Classifier,
|
||||
env_cfg: HILSerlRobotEnvConfig,
|
||||
device: torch.device = "mps",
|
||||
@@ -144,15 +147,15 @@ def run_actor(
|
||||
|
||||
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
|
||||
try:
|
||||
new_params = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_params)
|
||||
new_weights = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_weights)
|
||||
print("[ACTOR] Updated policy parameters from learner")
|
||||
except Empty: # No new updated parameters available from learner, waiting
|
||||
pass
|
||||
|
||||
# Get action from policy
|
||||
# Get action from policy (returns full action: continuous + discrete)
|
||||
policy_obs = make_policy_obs(obs, device=device)
|
||||
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
|
||||
action_tensor = policy_actor.select_action(policy_obs)
|
||||
action = action_tensor.squeeze(0).cpu().numpy()
|
||||
|
||||
# Step environment
|
||||
@@ -261,14 +264,14 @@ def main():
|
||||
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
||||
|
||||
# Create SAC policy for action selection
|
||||
policy_cfg = SACConfig(
|
||||
policy_cfg = GaussianActorConfig(
|
||||
device=device,
|
||||
input_features=obs_features,
|
||||
output_features=action_features,
|
||||
)
|
||||
|
||||
policy_actor = SACPolicy(policy_cfg)
|
||||
policy_learner = SACPolicy(policy_cfg)
|
||||
policy_actor = GaussianActorPolicy(policy_cfg)
|
||||
policy_learner = GaussianActorPolicy(policy_cfg)
|
||||
|
||||
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
||||
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.rewards import RewardClassifierConfig, make_reward_model, make_reward_pre_post_processors
|
||||
from lerobot.policies import RewardClassifierConfig, make_policy, make_pre_post_processors
|
||||
|
||||
|
||||
def main():
|
||||
@@ -22,10 +22,10 @@ def main():
|
||||
model_name="microsoft/resnet-18",
|
||||
)
|
||||
|
||||
# Make reward model, preprocessor, and optimizer
|
||||
reward_model = make_reward_model(config, dataset_stats=dataset.meta.stats)
|
||||
optimizer = config.get_optimizer_preset().build(reward_model.parameters())
|
||||
preprocessor, _ = make_reward_pre_post_processors(config, dataset_stats=dataset.meta.stats)
|
||||
# Make policy, preprocessor, and optimizer
|
||||
policy = make_policy(config, ds_meta=dataset.meta)
|
||||
optimizer = config.get_optimizer_preset().build(policy.parameters())
|
||||
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
|
||||
|
||||
classifier_id = "<user>/reward_classifier_hil_serl_example"
|
||||
|
||||
@@ -42,7 +42,7 @@ def main():
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Forward pass
|
||||
loss, output_dict = reward_model.forward(batch)
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
@@ -58,8 +58,8 @@ def main():
|
||||
|
||||
print("Training finished!")
|
||||
|
||||
# You can now save the trained reward model.
|
||||
reward_model.push_to_hub(classifier_id)
|
||||
# You can now save the trained policy.
|
||||
policy.push_to_hub(classifier_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -133,9 +133,6 @@ class RealSenseCamera(Camera):
|
||||
|
||||
self.rs_pipeline: rs.pipeline | None = None
|
||||
self.rs_profile: rs.pipeline_profile | None = None
|
||||
# Meters per uint16 unit on the depth stream. Queried from the device
|
||||
# at connect() time. Typical D-series value is 0.001 (= 1 mm/unit).
|
||||
self.depth_scale: float | None = None
|
||||
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
@@ -193,17 +190,6 @@ class RealSenseCamera(Camera):
|
||||
) from e
|
||||
|
||||
self._configure_capture_settings()
|
||||
|
||||
# Query depth scale (meters per uint16 unit) when depth is enabled so
|
||||
# consumers can convert the raw z16 stream to metric distances.
|
||||
if self.use_depth and self.rs_profile is not None:
|
||||
try:
|
||||
depth_sensor = self.rs_profile.get_device().first_depth_sensor()
|
||||
self.depth_scale = float(depth_sensor.get_depth_scale())
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"{self}: failed to query depth scale ({e}); falling back to 0.001 m/unit.")
|
||||
self.depth_scale = 0.001
|
||||
|
||||
self._start_read_thread()
|
||||
|
||||
# NOTE(Steven/Caroline): Enforcing at least one second of warmup as RS cameras need a bit of time before the first read. If we don't wait, the first read from the warmup will raise.
|
||||
@@ -546,6 +532,7 @@ class RealSenseCamera(Camera):
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
@@ -588,6 +575,7 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
@@ -623,78 +611,6 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
@check_if_not_connected
|
||||
def async_read_depth(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""Read the latest depth frame asynchronously, in metric meters.
|
||||
|
||||
Mirrors :meth:`async_read` but returns the depth stream rather than the
|
||||
color stream. Output is ``np.uint16`` of shape ``(H, W)``.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
|
||||
the background read thread is not running.
|
||||
TimeoutError: If no frame becomes available within ``timeout_ms``.
|
||||
"""
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(
|
||||
f"{self}: cannot read depth — camera was configured with use_depth=False."
|
||||
)
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for depth frame from camera {self} after {timeout_ms} ms."
|
||||
)
|
||||
|
||||
with self.frame_lock:
|
||||
depth_frame = self.latest_depth_frame
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if depth_frame is None:
|
||||
raise RuntimeError(f"Internal error: Event set but no depth frame available for {self}.")
|
||||
|
||||
return depth_frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest_depth(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent depth frame in metric meters (peeking).
|
||||
|
||||
Non-blocking counterpart of :meth:`read_latest` for the depth stream.
|
||||
Output is ``np.float32`` of shape ``(H, W)`` in meters.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
|
||||
no depth frame has been captured yet.
|
||||
TimeoutError: If the latest depth frame is older than ``max_age_ms``.
|
||||
"""
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(
|
||||
f"{self}: cannot read depth — camera was configured with use_depth=False."
|
||||
)
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
depth_frame = self.latest_depth_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if depth_frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any depth frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest depth frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return depth_frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera, stops the pipeline, and cleans up resources.
|
||||
@@ -718,8 +634,6 @@ class RealSenseCamera(Camera):
|
||||
self.rs_pipeline = None
|
||||
self.rs_profile = None
|
||||
|
||||
self.depth_scale = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_color_frame = None
|
||||
self.latest_depth_frame = None
|
||||
|
||||
@@ -99,6 +99,7 @@ def save_checkpoint(
|
||||
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
|
||||
@@ -41,12 +41,8 @@ def cfg_to_group(
|
||||
return tag
|
||||
return tag[:max_tag_length]
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
trainable_tag = f"reward_model:{cfg.reward_model.type}"
|
||||
else:
|
||||
trainable_tag = f"policy:{cfg.policy.type}"
|
||||
lst = [
|
||||
trainable_tag,
|
||||
f"policy:{cfg.policy.type}",
|
||||
f"seed:{cfg.seed}",
|
||||
]
|
||||
if cfg.dataset is not None:
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.transforms import ImageTransformsConfig
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
from lerobot.utils.import_utils import get_safe_default_codec
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -34,7 +34,7 @@ class DatasetConfig:
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_video_backend)
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
|
||||
@@ -1,163 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import builtins
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.optim.optimizers import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
T = TypeVar("T", bound="RewardModelConfig")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
"""Base configuration for reward models.
|
||||
|
||||
Args:
|
||||
input_features: A dictionary defining the PolicyFeature of the input data for the reward. The key represents
|
||||
the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
output_features: A dictionary defining the PolicyFeature of the output data for the reward. The key represents
|
||||
the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes.
|
||||
"""
|
||||
|
||||
# Reuses PolicyFeature
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
device: str | None = None
|
||||
|
||||
pretrained_path: str | None = None
|
||||
|
||||
push_to_hub: bool = False
|
||||
repo_id: str | None = None
|
||||
|
||||
# Hub metadata
|
||||
license: str | None = None
|
||||
tags: list[str] | None = None
|
||||
private: bool | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logger.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
self.device = auto_device.type
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
choice_name = self.get_choice_name(self.__class__)
|
||||
if not isinstance(choice_name, str):
|
||||
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
|
||||
return choice_name
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
pass
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
draccus.dump(self, f, indent=4)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict[Any, Any] | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**reward_kwargs: Any,
|
||||
) -> T:
|
||||
model_id = str(pretrained_name_or_path)
|
||||
config_file: str | None = None
|
||||
if Path(model_id).is_dir():
|
||||
if CONFIG_NAME in os.listdir(model_id):
|
||||
config_file = os.path.join(model_id, CONFIG_NAME)
|
||||
else:
|
||||
logger.error(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||
else:
|
||||
try:
|
||||
config_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=CONFIG_NAME,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
if config_file is None:
|
||||
raise FileNotFoundError(f"{CONFIG_NAME} not found in {model_id}")
|
||||
|
||||
# HACK: Parse the original config to get the config subclass, so that we can
|
||||
# apply cli overrides.
|
||||
with draccus.config_type("json"):
|
||||
orig_config = draccus.parse(cls, config_file, args=[])
|
||||
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
|
||||
config.pop("type", None)
|
||||
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
||||
json.dump(config, f)
|
||||
config_file = f.name
|
||||
|
||||
cli_overrides = reward_kwargs.pop("cli_overrides", [])
|
||||
with draccus.config_type("json"):
|
||||
return draccus.parse(orig_config.__class__, config_file, args=cli_overrides)
|
||||
@@ -13,9 +13,7 @@
|
||||
# limitations under the License.
|
||||
import builtins
|
||||
import datetime as dt
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -28,57 +26,18 @@ from lerobot import envs
|
||||
from lerobot.configs import parser
|
||||
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .rewards import RewardModelConfig
|
||||
|
||||
TRAIN_CONFIG_NAME = "train_config.json"
|
||||
|
||||
|
||||
def _migrate_legacy_rabc_fields(config: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Return migrated payload for legacy RA-BC fields, or None when no migration is needed."""
|
||||
legacy_fields = (
|
||||
"use_rabc",
|
||||
"rabc_progress_path",
|
||||
"rabc_kappa",
|
||||
"rabc_epsilon",
|
||||
"rabc_head_mode",
|
||||
)
|
||||
if not any(key in config for key in legacy_fields):
|
||||
return None
|
||||
|
||||
migrated_config = dict(config)
|
||||
use_rabc = bool(migrated_config.pop("use_rabc", False))
|
||||
rabc_progress_path = migrated_config.pop("rabc_progress_path", None)
|
||||
rabc_kappa = migrated_config.pop("rabc_kappa", None)
|
||||
rabc_epsilon = migrated_config.pop("rabc_epsilon", None)
|
||||
rabc_head_mode = migrated_config.pop("rabc_head_mode", None)
|
||||
|
||||
# New configs may already define sample_weighting explicitly. In that case,
|
||||
# legacy fields are ignored after being stripped from the payload.
|
||||
if migrated_config.get("sample_weighting") is None and use_rabc:
|
||||
sample_weighting: dict[str, Any] = {"type": "rabc"}
|
||||
if rabc_progress_path is not None:
|
||||
sample_weighting["progress_path"] = rabc_progress_path
|
||||
if rabc_kappa is not None:
|
||||
sample_weighting["kappa"] = rabc_kappa
|
||||
if rabc_epsilon is not None:
|
||||
sample_weighting["epsilon"] = rabc_epsilon
|
||||
if rabc_head_mode is not None:
|
||||
sample_weighting["head_mode"] = rabc_head_mode
|
||||
migrated_config["sample_weighting"] = sample_weighting
|
||||
|
||||
return migrated_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainPipelineConfig(HubMixin):
|
||||
dataset: DatasetConfig
|
||||
env: envs.EnvConfig | None = None
|
||||
policy: PreTrainedConfig | None = None
|
||||
reward_model: RewardModelConfig | None = None
|
||||
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
||||
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
output_dir: Path | None = None
|
||||
@@ -113,41 +72,27 @@ class TrainPipelineConfig(HubMixin):
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
peft: PeftConfig | None = None
|
||||
|
||||
# Sample weighting configuration (e.g., for RA-BC training)
|
||||
sample_weighting: SampleWeightingConfig | None = None
|
||||
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
||||
use_rabc: bool = False # Enable reward-weighted training
|
||||
rabc_progress_path: str | None = None # Path to precomputed SARM progress parquet file
|
||||
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
|
||||
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
|
||||
rabc_head_mode: str | None = "sparse" # For dual-head models: "sparse" or "dense"
|
||||
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
checkpoint_path: Path | None = field(init=False, default=None)
|
||||
|
||||
@property
|
||||
def is_reward_model_training(self) -> bool:
|
||||
"""True when the config targets a reward model rather than a policy."""
|
||||
return self.reward_model is not None
|
||||
|
||||
@property
|
||||
def trainable_config(self) -> PreTrainedConfig | RewardModelConfig:
|
||||
"""Return whichever config (policy or reward_model) is active."""
|
||||
if self.is_reward_model_training:
|
||||
return self.reward_model # type: ignore[return-value]
|
||||
return self.policy # type: ignore[return-value]
|
||||
|
||||
def validate(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
reward_model_path = parser.get_path_arg("reward_model")
|
||||
|
||||
if reward_model_path:
|
||||
cli_overrides = parser.get_cli_overrides("reward_model")
|
||||
self.reward_model = RewardModelConfig.from_pretrained(
|
||||
reward_model_path, cli_overrides=cli_overrides
|
||||
)
|
||||
self.reward_model.pretrained_path = str(Path(reward_model_path))
|
||||
elif policy_path:
|
||||
if policy_path:
|
||||
# Only load the policy config
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
elif self.resume:
|
||||
# The entire train config is already loaded, we just need to get the checkpoint dir
|
||||
config_path = parser.parse_arg("config_path")
|
||||
if not config_path:
|
||||
raise ValueError(
|
||||
@@ -163,22 +108,18 @@ class TrainPipelineConfig(HubMixin):
|
||||
policy_dir = Path(config_path).parent
|
||||
if self.policy is not None:
|
||||
self.policy.pretrained_path = policy_dir
|
||||
if self.reward_model is not None:
|
||||
self.reward_model.pretrained_path = str(policy_dir)
|
||||
self.checkpoint_path = policy_dir.parent
|
||||
|
||||
if self.policy is None and self.reward_model is None:
|
||||
if self.policy is None:
|
||||
raise ValueError(
|
||||
"Neither policy nor reward_model is configured. "
|
||||
"Please specify one with `--policy.path` or `--reward_model.path`."
|
||||
"Policy is not configured. Please specify a pretrained policy with `--policy.path`."
|
||||
)
|
||||
|
||||
active_cfg = self.trainable_config
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
self.job_name = f"{active_cfg.type}"
|
||||
self.job_name = f"{self.policy.type}"
|
||||
else:
|
||||
self.job_name = f"{self.env.type}_{active_cfg.type}"
|
||||
self.job_name = f"{self.env.type}_{self.policy.type}"
|
||||
|
||||
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
|
||||
raise FileExistsError(
|
||||
@@ -196,16 +137,26 @@ class TrainPipelineConfig(HubMixin):
|
||||
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
||||
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
||||
elif self.use_policy_training_preset and not self.resume:
|
||||
self.optimizer = active_cfg.get_optimizer_preset()
|
||||
self.scheduler = active_cfg.get_scheduler_preset()
|
||||
self.optimizer = self.policy.get_optimizer_preset()
|
||||
self.scheduler = self.policy.get_scheduler_preset()
|
||||
|
||||
if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
|
||||
raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
|
||||
if self.policy.push_to_hub and not self.policy.repo_id:
|
||||
raise ValueError(
|
||||
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
|
||||
)
|
||||
|
||||
if self.use_rabc and not self.rabc_progress_path:
|
||||
# Auto-detect from dataset path
|
||||
repo_id = self.dataset.repo_id
|
||||
if self.dataset.root:
|
||||
self.rabc_progress_path = str(Path(self.dataset.root) / "sarm_progress.parquet")
|
||||
else:
|
||||
self.rabc_progress_path = f"hf://datasets/{repo_id}/sarm_progress.parquet"
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""Keys for draccus pretrained-path loading."""
|
||||
return ["policy", "reward_model"]
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return draccus.encode(self) # type: ignore[no-any-return] # because of the third-party library draccus uses Any as the return type
|
||||
@@ -256,21 +207,5 @@ class TrainPipelineConfig(HubMixin):
|
||||
) from e
|
||||
|
||||
cli_args = kwargs.pop("cli_args", [])
|
||||
if config_file is not None:
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
migrated_config = _migrate_legacy_rabc_fields(config)
|
||||
if migrated_config is not None:
|
||||
with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f:
|
||||
json.dump(migrated_config, f)
|
||||
config_file = f.name
|
||||
|
||||
with draccus.config_type("json"):
|
||||
return draccus.parse(cls, config_file, args=cli_args)
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
# NOTE: In RL, we don't need an offline dataset
|
||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||
|
||||
@@ -40,21 +40,10 @@ from .io_utils import load_episodes, write_stats
|
||||
from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from .pyav_utils import (
|
||||
check_video_encoder_config_pyav,
|
||||
detect_available_encoders_pyav,
|
||||
get_codec,
|
||||
)
|
||||
from .sampler import EpisodeAwareSampler
|
||||
from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
from .video_utils import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
VideoEncodingManager,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
)
|
||||
from .video_utils import VideoEncodingManager
|
||||
|
||||
# NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.)
|
||||
# and legacy migration constants are intentionally NOT re-exported here.
|
||||
@@ -69,22 +58,15 @@ __all__ = [
|
||||
"LeRobotDatasetMetadata",
|
||||
"MultiLeRobotDataset",
|
||||
"StreamingLeRobotDataset",
|
||||
"DepthEncoderConfig",
|
||||
"VideoEncoderConfig",
|
||||
"VideoEncodingManager",
|
||||
"camera_encoder_defaults",
|
||||
"depth_encoder_defaults",
|
||||
"add_features",
|
||||
"aggregate_datasets",
|
||||
"aggregate_pipeline_dataset_features",
|
||||
"aggregate_stats",
|
||||
"check_video_encoder_config_pyav",
|
||||
"convert_image_to_video_dataset",
|
||||
"create_initial_features",
|
||||
"create_lerobot_dataset_card",
|
||||
"delete_episodes",
|
||||
"detect_available_encoders_pyav",
|
||||
"get_codec",
|
||||
"get_feature_stats",
|
||||
"load_episodes",
|
||||
"make_dataset",
|
||||
|
||||
@@ -97,8 +97,8 @@ def update_data_df(df, src_meta, dst_meta):
|
||||
pd.DataFrame: Updated DataFrame with adjusted indices.
|
||||
"""
|
||||
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info.total_episodes
|
||||
df["index"] = df["index"] + dst_meta.info.total_frames
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
||||
df["index"] = df["index"] + dst_meta.info["total_frames"]
|
||||
|
||||
src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
|
||||
df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
|
||||
@@ -225,9 +225,9 @@ def update_meta_data(
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||
|
||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info.total_frames
|
||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info.total_frames
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info.total_episodes
|
||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
|
||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
||||
|
||||
return df
|
||||
|
||||
@@ -237,8 +237,8 @@ def aggregate_datasets(
|
||||
aggr_repo_id: str,
|
||||
roots: list[Path] | None = None,
|
||||
aggr_root: Path | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
data_files_size_in_mb: float | None = None,
|
||||
video_files_size_in_mb: float | None = None,
|
||||
chunk_size: int | None = None,
|
||||
):
|
||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||
@@ -313,8 +313,8 @@ def aggregate_datasets(
|
||||
# to avoid interference between different source datasets
|
||||
data_idx.pop("src_to_dst", None)
|
||||
|
||||
dst_meta.info.total_episodes += src_meta.total_episodes
|
||||
dst_meta.info.total_frames += src_meta.total_frames
|
||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||
|
||||
finalize_aggregation(dst_meta, all_metadata)
|
||||
logging.info("Aggregation complete.")
|
||||
@@ -332,6 +332,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
videos_idx: Dictionary tracking video chunk and file indices.
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
|
||||
Returns:
|
||||
dict: Updated videos_idx with current chunk and file indices.
|
||||
"""
|
||||
@@ -416,7 +417,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
concatenate_video_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
compatibility_check=True,
|
||||
)
|
||||
# Update duration of this destination file
|
||||
dst_file_durations[dst_key] = current_dst_duration + src_duration
|
||||
@@ -640,10 +640,14 @@ def finalize_aggregation(aggr_meta, all_metadata):
|
||||
write_tasks(aggr_meta.tasks, aggr_meta.root)
|
||||
|
||||
logging.info("write info")
|
||||
aggr_meta.info.total_tasks = len(aggr_meta.tasks)
|
||||
aggr_meta.info.total_episodes = sum(m.total_episodes for m in all_metadata)
|
||||
aggr_meta.info.total_frames = sum(m.total_frames for m in all_metadata)
|
||||
aggr_meta.info.splits = {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"}
|
||||
aggr_meta.info.update(
|
||||
{
|
||||
"total_tasks": len(aggr_meta.tasks),
|
||||
"total_episodes": sum(m.total_episodes for m in all_metadata),
|
||||
"total_frames": sum(m.total_frames for m in all_metadata),
|
||||
"splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"},
|
||||
}
|
||||
)
|
||||
write_info(aggr_meta.info, aggr_meta.root)
|
||||
|
||||
logging.info("write stats")
|
||||
|
||||
@@ -37,18 +37,20 @@ from .io_utils import (
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
write_info,
|
||||
write_json,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
INFO_PATH,
|
||||
check_version_compatibility,
|
||||
get_safe_version,
|
||||
has_legacy_hub_download_metadata,
|
||||
is_valid_version,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from .video_utils import VideoEncoderConfig, get_video_info
|
||||
from .video_utils import get_video_info
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
|
||||
@@ -226,7 +228,7 @@ class LeRobotDatasetMetadata:
|
||||
@property
|
||||
def _version(self) -> packaging.version.Version:
|
||||
"""Codebase version used to create this dataset."""
|
||||
return packaging.version.parse(self.info.codebase_version)
|
||||
return packaging.version.parse(self.info["codebase_version"])
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
"""Return the relative parquet file path for the given episode index.
|
||||
@@ -281,27 +283,27 @@ class LeRobotDatasetMetadata:
|
||||
@property
|
||||
def data_path(self) -> str:
|
||||
"""Formattable string for the parquet files."""
|
||||
return self.info.data_path
|
||||
return self.info["data_path"]
|
||||
|
||||
@property
|
||||
def video_path(self) -> str | None:
|
||||
"""Formattable string for the video files."""
|
||||
return self.info.video_path
|
||||
return self.info["video_path"]
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str | None:
|
||||
"""Robot type used in recording this dataset."""
|
||||
return self.info.robot_type
|
||||
return self.info["robot_type"]
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection."""
|
||||
return self.info.fps
|
||||
return self.info["fps"]
|
||||
|
||||
@property
|
||||
def features(self) -> dict[str, dict]:
|
||||
"""All features contained in the dataset."""
|
||||
return self.info.features
|
||||
return self.info["features"]
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list[str]:
|
||||
@@ -313,20 +315,6 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
||||
|
||||
@property
|
||||
def depth_keys(self) -> list[str]:
|
||||
"""Keys to access depth-map modalities stored as videos.
|
||||
|
||||
A depth video key is a feature whose ``info`` dict carries
|
||||
``"video.is_depth_map": True`` (set either at creation time by the user
|
||||
or after the first encoded episode by :meth:`update_video_info`).
|
||||
"""
|
||||
return [
|
||||
key
|
||||
for key, ft in self.features.items()
|
||||
if ft["dtype"] == "video" and ft.get("info", {}).get("video.is_depth_map", False)
|
||||
]
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
@@ -345,32 +333,32 @@ class LeRobotDatasetMetadata:
|
||||
@property
|
||||
def total_episodes(self) -> int:
|
||||
"""Total number of episodes available."""
|
||||
return self.info.total_episodes
|
||||
return self.info["total_episodes"]
|
||||
|
||||
@property
|
||||
def total_frames(self) -> int:
|
||||
"""Total number of frames saved in this dataset."""
|
||||
return self.info.total_frames
|
||||
return self.info["total_frames"]
|
||||
|
||||
@property
|
||||
def total_tasks(self) -> int:
|
||||
"""Total number of different tasks performed in this dataset."""
|
||||
return self.info.total_tasks
|
||||
return self.info["total_tasks"]
|
||||
|
||||
@property
|
||||
def chunks_size(self) -> int:
|
||||
"""Max number of files per chunk."""
|
||||
return self.info.chunks_size
|
||||
return self.info["chunks_size"]
|
||||
|
||||
@property
|
||||
def data_files_size_in_mb(self) -> int:
|
||||
"""Max size of data file in mega bytes."""
|
||||
return self.info.data_files_size_in_mb
|
||||
return self.info["data_files_size_in_mb"]
|
||||
|
||||
@property
|
||||
def video_files_size_in_mb(self) -> int:
|
||||
"""Max size of video file in mega bytes."""
|
||||
return self.info.video_files_size_in_mb
|
||||
return self.info["video_files_size_in_mb"]
|
||||
|
||||
def get_task_index(self, task: str) -> int | None:
|
||||
"""
|
||||
@@ -514,48 +502,29 @@ class LeRobotDatasetMetadata:
|
||||
self._save_episode_metadata(episode_dict)
|
||||
|
||||
# Update info
|
||||
self.info.total_episodes += 1
|
||||
self.info.total_frames += episode_length
|
||||
self.info.total_tasks = len(self.tasks)
|
||||
self.info.splits = {"train": f"0:{self.info.total_episodes}"}
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
self.info["total_tasks"] = len(self.tasks)
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
|
||||
write_info(self.info, self.root)
|
||||
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
|
||||
write_stats(self.stats, self.root)
|
||||
|
||||
def update_video_info(
|
||||
self,
|
||||
video_key: str | None = None,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
) -> None:
|
||||
"""Populate per-feature video info in ``info.json``.
|
||||
|
||||
def update_video_info(self, video_key: str | None = None) -> None:
|
||||
"""
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
|
||||
Args:
|
||||
video_key: If provided, only update this video key. Otherwise update
|
||||
all video keys in the dataset.
|
||||
camera_encoder_config: Encoder configuration used to produce the
|
||||
videos. When provided, its fields are recorded as
|
||||
``video.<field>`` entries alongside the stream-derived
|
||||
``video.*`` entries (see :func:`get_video_info`).
|
||||
"""
|
||||
if video_key is not None and video_key not in self.video_keys:
|
||||
raise ValueError(f"Video key {video_key} not found in dataset")
|
||||
|
||||
video_keys = [video_key] if video_key is not None else self.video_keys
|
||||
for key in video_keys:
|
||||
existing = self.features[key].get("info") or {}
|
||||
# Repopulate when codec metadata is missing — preserves user-provided
|
||||
# markers like ``video.is_depth_map`` while still recording stream
|
||||
# info on the first episode.
|
||||
if not existing or "video.codec" not in existing:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
stream_info = get_video_info(video_path, camera_encoder_config=camera_encoder_config)
|
||||
merged = {**existing, **stream_info}
|
||||
self.info.features[key]["info"] = merged
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
@@ -577,17 +546,17 @@ class LeRobotDatasetMetadata:
|
||||
if chunks_size is not None:
|
||||
if chunks_size <= 0:
|
||||
raise ValueError(f"chunks_size must be positive, got {chunks_size}")
|
||||
self.info.chunks_size = chunks_size
|
||||
self.info["chunks_size"] = chunks_size
|
||||
|
||||
if data_files_size_in_mb is not None:
|
||||
if data_files_size_in_mb <= 0:
|
||||
raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}")
|
||||
self.info.data_files_size_in_mb = data_files_size_in_mb
|
||||
self.info["data_files_size_in_mb"] = data_files_size_in_mb
|
||||
|
||||
if video_files_size_in_mb is not None:
|
||||
if video_files_size_in_mb <= 0:
|
||||
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
|
||||
self.info.video_files_size_in_mb = video_files_size_in_mb
|
||||
self.info["video_files_size_in_mb"] = video_files_size_in_mb
|
||||
|
||||
# Update the info file on disk
|
||||
write_info(self.info, self.root)
|
||||
@@ -684,7 +653,7 @@ class LeRobotDatasetMetadata:
|
||||
f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. "
|
||||
"Either remove video features from the features dict, or set 'use_videos=True'."
|
||||
)
|
||||
write_info(obj.info, obj.root)
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
obj.revision = None
|
||||
obj._pq_writer = None
|
||||
obj.latest_episode = None
|
||||
|
||||
@@ -32,13 +32,7 @@ from .io_utils import (
|
||||
hf_transform_to_torch,
|
||||
load_nested_dataset,
|
||||
)
|
||||
from .video_utils import decode_depth_frames, decode_video_frames
|
||||
from .depth_utils import (
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_SHIFT,
|
||||
DEFAULT_DEPTH_USE_LOG,
|
||||
)
|
||||
from .video_utils import decode_video_frames
|
||||
|
||||
|
||||
class DatasetReader:
|
||||
@@ -243,31 +237,17 @@ class DatasetReader:
|
||||
"""
|
||||
ep = self._meta.episodes[ep_idx]
|
||||
|
||||
depth_keys = set(self._meta.depth_keys)
|
||||
|
||||
def _decode_single(vid_key: str, query_ts: list[float]) -> tuple[str, torch.Tensor]:
|
||||
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
|
||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
|
||||
if vid_key in depth_keys:
|
||||
feature_info = self._meta.features[vid_key].get("info") or {}
|
||||
frames = decode_depth_frames(
|
||||
video_path,
|
||||
shifted_query_ts,
|
||||
self._tolerance_s,
|
||||
depth_min=feature_info.get("video.depth_min", DEFAULT_DEPTH_MIN),
|
||||
depth_max=feature_info.get("video.depth_max", DEFAULT_DEPTH_MAX),
|
||||
shift=feature_info.get("video.shift", DEFAULT_DEPTH_SHIFT),
|
||||
use_log=feature_info.get("video.use_log", DEFAULT_DEPTH_USE_LOG),
|
||||
)
|
||||
else:
|
||||
frames = decode_video_frames(
|
||||
video_path,
|
||||
shifted_query_ts,
|
||||
self._tolerance_s,
|
||||
self._video_backend,
|
||||
return_uint8=self._return_uint8,
|
||||
)
|
||||
frames = decode_video_frames(
|
||||
video_path,
|
||||
shifted_query_ts,
|
||||
self._tolerance_s,
|
||||
self._video_backend,
|
||||
return_uint8=self._return_uint8,
|
||||
)
|
||||
return vid_key, frames.squeeze(0)
|
||||
|
||||
items = list(query_timestamps.items())
|
||||
|
||||
@@ -62,7 +62,7 @@ from .utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from .video_utils import VideoEncoderConfig, encode_video_frames, get_video_info
|
||||
from .video_utils import encode_video_frames, get_video_info
|
||||
|
||||
|
||||
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
|
||||
@@ -92,7 +92,6 @@ def delete_episodes(
|
||||
episode_indices: list[int],
|
||||
output_dir: str | Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Delete episodes from a LeRobotDataset and create a new dataset.
|
||||
|
||||
@@ -101,7 +100,6 @@ def delete_episodes(
|
||||
episode_indices: List of episode indices to delete.
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
camera_encoder_config: Video encoder settings used when re-encoding video segments (default: :class:`VideoEncoderConfig()`).
|
||||
"""
|
||||
if not episode_indices:
|
||||
raise ValueError("No episodes to delete")
|
||||
@@ -134,7 +132,7 @@ def delete_episodes(
|
||||
|
||||
video_metadata = None
|
||||
if dataset.meta.video_keys:
|
||||
video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping, camera_encoder_config)
|
||||
video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping)
|
||||
|
||||
data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping)
|
||||
|
||||
@@ -156,7 +154,6 @@ def split_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
splits: dict[str, float | list[int]],
|
||||
output_dir: str | Path | None = None,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
) -> dict[str, LeRobotDataset]:
|
||||
"""Split a LeRobotDataset into multiple smaller datasets.
|
||||
|
||||
@@ -165,7 +162,6 @@ def split_dataset(
|
||||
splits: Either a dict mapping split names to episode indices, or a dict mapping
|
||||
split names to fractions (must sum to <= 1.0).
|
||||
output_dir: Root directory where the split datasets will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
camera_encoder_config: Video encoder settings used when re-encoding video segments (default: :class:`VideoEncoderConfig()`).
|
||||
|
||||
Examples:
|
||||
Split by specific episodes
|
||||
@@ -226,9 +222,7 @@ def split_dataset(
|
||||
|
||||
video_metadata = None
|
||||
if dataset.meta.video_keys:
|
||||
video_metadata = _copy_and_reindex_videos(
|
||||
dataset, new_meta, episode_mapping, camera_encoder_config
|
||||
)
|
||||
video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping)
|
||||
|
||||
data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping)
|
||||
|
||||
@@ -584,7 +578,8 @@ def _keep_episodes_from_video_with_av(
|
||||
output_path: Path,
|
||||
episodes_to_keep: list[tuple[int, int]],
|
||||
fps: float,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
) -> None:
|
||||
"""Keep only specified episodes from a video file using PyAV.
|
||||
|
||||
@@ -598,10 +593,9 @@ def _keep_episodes_from_video_with_av(
|
||||
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
|
||||
is inclusive and end_frame is exclusive.
|
||||
fps: Frame rate of the video.
|
||||
camera_encoder_config: Video encoder settings (default: :class:`VideoEncoderConfig()`).
|
||||
vcodec: Video codec to use for encoding.
|
||||
pix_fmt: Pixel format for output video.
|
||||
"""
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
from fractions import Fraction
|
||||
|
||||
import av
|
||||
@@ -625,12 +619,12 @@ def _keep_episodes_from_video_with_av(
|
||||
|
||||
# Convert fps to Fraction for PyAV compatibility.
|
||||
fps_fraction = Fraction(fps).limit_denominator(1000)
|
||||
v_out = out.add_stream(camera_encoder_config.vcodec, rate=fps_fraction)
|
||||
v_out = out.add_stream(vcodec, rate=fps_fraction)
|
||||
|
||||
# PyAV type stubs don't distinguish video streams from audio/subtitle streams.
|
||||
v_out.width = v_in.codec_context.width
|
||||
v_out.height = v_in.codec_context.height
|
||||
v_out.pix_fmt = camera_encoder_config.pix_fmt
|
||||
v_out.pix_fmt = pix_fmt
|
||||
|
||||
# Set time_base to match the frame rate for proper timestamp handling.
|
||||
v_out.time_base = Fraction(1, int(fps))
|
||||
@@ -693,7 +687,8 @@ def _copy_and_reindex_videos(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_mapping: dict[int, int],
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
) -> dict[int, dict]:
|
||||
"""Copy and filter video files, only re-encoding files with deleted episodes.
|
||||
|
||||
@@ -705,13 +700,10 @@ def _copy_and_reindex_videos(
|
||||
src_dataset: Source dataset to copy from
|
||||
dst_meta: Destination metadata object
|
||||
episode_mapping: Mapping from old episode indices to new indices
|
||||
camera_encoder_config: Video encoder settings used when re-encoding segments (default: :class:`VideoEncoderConfig()`).
|
||||
|
||||
Returns:
|
||||
dict mapping episode index to its video metadata (chunk_index, file_index, timestamps)
|
||||
"""
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
if src_dataset.meta.episodes is None:
|
||||
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
|
||||
|
||||
@@ -800,7 +792,8 @@ def _copy_and_reindex_videos(
|
||||
dst_video_path,
|
||||
episodes_to_keep_ranges,
|
||||
src_dataset.meta.fps,
|
||||
camera_encoder_config,
|
||||
vcodec,
|
||||
pix_fmt,
|
||||
)
|
||||
|
||||
cumulative_ts = 0.0
|
||||
@@ -904,10 +897,14 @@ def _copy_and_reindex_episodes_metadata(
|
||||
|
||||
dst_meta.finalize()
|
||||
|
||||
dst_meta.info.total_episodes = len(episode_mapping)
|
||||
dst_meta.info.total_frames = total_frames
|
||||
dst_meta.info.total_tasks = len(dst_meta.tasks) if dst_meta.tasks is not None else 0
|
||||
dst_meta.info.splits = {"train": f"0:{len(episode_mapping)}"}
|
||||
dst_meta.info.update(
|
||||
{
|
||||
"total_episodes": len(episode_mapping),
|
||||
"total_frames": total_frames,
|
||||
"total_tasks": len(dst_meta.tasks) if dst_meta.tasks is not None else 0,
|
||||
"splits": {"train": f"0:{len(episode_mapping)}"},
|
||||
}
|
||||
)
|
||||
write_info(dst_meta.info, dst_meta.root)
|
||||
|
||||
if not all_stats:
|
||||
@@ -1072,20 +1069,21 @@ def _copy_episodes_metadata_and_stats(
|
||||
if episodes_dir.exists():
|
||||
shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True)
|
||||
|
||||
dst_meta.info.total_episodes = src_dataset.meta.total_episodes
|
||||
dst_meta.info.total_frames = src_dataset.meta.total_frames
|
||||
dst_meta.info.total_tasks = src_dataset.meta.total_tasks
|
||||
# Preserve original splits if available, otherwise create default
|
||||
dst_meta.info.splits = (
|
||||
src_dataset.meta.info.splits
|
||||
if src_dataset.meta.info.splits
|
||||
else {"train": f"0:{src_dataset.meta.total_episodes}"}
|
||||
dst_meta.info.update(
|
||||
{
|
||||
"total_episodes": src_dataset.meta.total_episodes,
|
||||
"total_frames": src_dataset.meta.total_frames,
|
||||
"total_tasks": src_dataset.meta.total_tasks,
|
||||
"splits": src_dataset.meta.info.get("splits", {"train": f"0:{src_dataset.meta.total_episodes}"}),
|
||||
}
|
||||
)
|
||||
|
||||
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
||||
for key in dst_meta.video_keys:
|
||||
if key in src_dataset.meta.features:
|
||||
dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {})
|
||||
dst_meta.info["features"][key]["info"] = src_dataset.meta.info["features"][key].get(
|
||||
"info", {}
|
||||
)
|
||||
|
||||
write_info(dst_meta.info, dst_meta.root)
|
||||
|
||||
@@ -1271,7 +1269,11 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: list[int],
|
||||
temp_dir: Path,
|
||||
fps: int,
|
||||
camera_encoder_config: VideoEncoderConfig,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int,
|
||||
crf: int,
|
||||
fast_decode: int,
|
||||
num_calibration_frames: int = 30,
|
||||
) -> float:
|
||||
"""Estimate MB per frame by encoding a small calibration sample.
|
||||
@@ -1285,7 +1287,11 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: List of episode indices being processed.
|
||||
temp_dir: Temporary directory for calibration files.
|
||||
fps: Frames per second for video encoding.
|
||||
camera_encoder_config: Video encoder settings used for calibration encoding.
|
||||
vcodec: Video codec (libsvtav1, h264, hevc).
|
||||
pix_fmt: Pixel format (yuv420p, etc.).
|
||||
g: GOP size (group of pictures).
|
||||
crf: Constant Rate Factor (quality).
|
||||
fast_decode: Fast decode tuning parameter.
|
||||
num_calibration_frames: Number of frames to use for calibration (default: 30).
|
||||
|
||||
Returns:
|
||||
@@ -1321,7 +1327,11 @@ def _estimate_frame_size_via_calibration(
|
||||
imgs_dir=calibration_dir,
|
||||
video_path=calibration_video_path,
|
||||
fps=fps,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1515,7 +1525,7 @@ def modify_tasks(
|
||||
write_tasks(new_task_df, root)
|
||||
|
||||
# Update info.json
|
||||
dataset.meta.info.total_tasks = len(unique_tasks)
|
||||
dataset.meta.info["total_tasks"] = len(unique_tasks)
|
||||
write_info(dataset.meta.info, root)
|
||||
|
||||
# Reload metadata to reflect changes
|
||||
@@ -1639,7 +1649,11 @@ def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int = 2,
|
||||
crf: int = 30,
|
||||
fast_decode: int = 0,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
max_episodes_per_batch: int | None = None,
|
||||
@@ -1654,7 +1668,11 @@ def convert_image_to_video_dataset(
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
camera_encoder_config: Video encoder settings (default: :class:`VideoEncoderConfig()`).
|
||||
vcodec: Video codec (default: libsvtav1)
|
||||
pix_fmt: Pixel format (default: yuv420p)
|
||||
g: Group of pictures size (default: 2)
|
||||
crf: Constant rate factor (default: 30)
|
||||
fast_decode: Fast decode tuning (default: 0)
|
||||
episode_indices: List of episode indices to convert (None = all episodes)
|
||||
num_workers: Number of threads for parallel processing (default: 4)
|
||||
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
|
||||
@@ -1663,9 +1681,6 @@ def convert_image_to_video_dataset(
|
||||
Returns:
|
||||
New LeRobotDataset with images encoded as videos
|
||||
"""
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
@@ -1689,10 +1704,7 @@ def convert_image_to_video_dataset(
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(
|
||||
f"Video codec: {camera_encoder_config.vcodec}, pixel format: {camera_encoder_config.pix_fmt}, "
|
||||
f"GOP: {camera_encoder_config.g}, CRF: {camera_encoder_config.crf}"
|
||||
)
|
||||
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
@@ -1762,7 +1774,11 @@ def convert_image_to_video_dataset(
|
||||
episode_indices=episode_indices,
|
||||
temp_dir=temp_dir,
|
||||
fps=fps,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
)
|
||||
|
||||
logging.info(f"Processing camera: {img_key}")
|
||||
@@ -1804,7 +1820,11 @@ def convert_image_to_video_dataset(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1838,10 +1858,10 @@ def convert_image_to_video_dataset(
|
||||
episodes_df.to_parquet(episodes_path, index=False)
|
||||
|
||||
# Update metadata info
|
||||
new_meta.info.total_episodes = len(episode_indices)
|
||||
new_meta.info.total_frames = sum(ep["length"] for ep in all_episode_metadata.values())
|
||||
new_meta.info.total_tasks = dataset.meta.total_tasks
|
||||
new_meta.info.splits = {"train": f"0:{len(episode_indices)}"}
|
||||
new_meta.info["total_episodes"] = len(episode_indices)
|
||||
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata.values())
|
||||
new_meta.info["total_tasks"] = dataset.meta.total_tasks
|
||||
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
|
||||
|
||||
# Update video info for all image keys (now videos)
|
||||
# We need to manually set video info since update_video_info() checks video_keys first
|
||||
@@ -1850,9 +1870,7 @@ def convert_image_to_video_dataset(
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(
|
||||
video_path, camera_encoder_config=camera_encoder_config
|
||||
)
|
||||
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
|
||||
@@ -46,19 +46,15 @@ from .io_utils import (
|
||||
write_info,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_DEPTH_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from .video_utils import (
|
||||
DepthEncoderConfig,
|
||||
StreamingVideoEncoder,
|
||||
VideoEncoderConfig,
|
||||
concatenate_video_files,
|
||||
encode_video_frames,
|
||||
get_video_duration_in_s,
|
||||
is_depth_feature,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -69,19 +65,14 @@ def _encode_video_worker(
|
||||
episode_index: int,
|
||||
root: Path,
|
||||
fps: int,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
encoder_threads: int | None = None,
|
||||
) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(
|
||||
img_dir,
|
||||
temp_path,
|
||||
fps,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
|
||||
)
|
||||
shutil.rmtree(img_dir)
|
||||
return temp_path
|
||||
@@ -98,40 +89,33 @@ class DatasetWriter:
|
||||
self,
|
||||
meta: LeRobotDatasetMetadata,
|
||||
root: Path,
|
||||
camera_encoder_config: VideoEncoderConfig,
|
||||
vcodec: str,
|
||||
encoder_threads: int | None,
|
||||
batch_encoding_size: int,
|
||||
streaming_encoder: StreamingVideoEncoder | None = None,
|
||||
initial_frames: int = 0,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
):
|
||||
"""Initialize the writer with metadata, codec, and encoder config.
|
||||
"""Initialize the writer with metadata, codec, and encoding config.
|
||||
|
||||
Args:
|
||||
meta: Dataset metadata instance (used for feature schema, chunk
|
||||
settings, and episode persistence).
|
||||
root: Local dataset root directory.
|
||||
camera_encoder_config: Video encoder settings applied to all cameras.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
vcodec: Video codec for encoding (e.g. ``'libsvtav1'``, ``'h264'``).
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos.
|
||||
streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder`
|
||||
for real-time encoding. ``None`` disables streaming mode.
|
||||
initial_frames: Starting frame count (non-zero when resuming).
|
||||
depth_encoder_config: Optional depth-map encoder config used in
|
||||
place of ``camera_encoder_config`` for keys present in
|
||||
``meta.depth_keys``.
|
||||
"""
|
||||
self._meta = meta
|
||||
self._root = root
|
||||
self._camera_encoder_config = camera_encoder_config
|
||||
self._depth_encoder_config = depth_encoder_config
|
||||
self._vcodec = vcodec
|
||||
self._encoder_threads = encoder_threads
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._streaming_encoder = streaming_encoder
|
||||
|
||||
|
||||
# Writer state
|
||||
self.image_writer: AsyncImageWriter | None = None
|
||||
self.episode_buffer: dict = self._create_episode_buffer()
|
||||
@@ -151,16 +135,8 @@ class DatasetWriter:
|
||||
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
|
||||
return ep_buffer
|
||||
|
||||
def _is_depth_image_key(self, image_key: str) -> bool:
|
||||
"""Whether *image_key* is a depth feature stored as per-frame images."""
|
||||
ft = self._meta.features.get(image_key)
|
||||
if ft is None or ft.get("dtype") != "image":
|
||||
return False
|
||||
return is_depth_feature(ft.get("info") or {})
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
path_template = DEFAULT_DEPTH_PATH if self._is_depth_image_key(image_key) else DEFAULT_IMAGE_PATH
|
||||
fpath = path_template.format(
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self._root / fpath
|
||||
@@ -308,7 +284,7 @@ class DatasetWriter:
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._camera_encoder_config,
|
||||
self._vcodec,
|
||||
self._encoder_threads,
|
||||
): video_key
|
||||
for video_key in self._meta.video_keys
|
||||
@@ -519,13 +495,7 @@ class DatasetWriter:
|
||||
|
||||
# Update video info (only needed when first episode is encoded)
|
||||
if episode_index == 0:
|
||||
is_depth_key = video_key in set(self._meta.depth_keys)
|
||||
cfg_for_info = (
|
||||
self._depth_encoder_config
|
||||
if is_depth_key and self._depth_encoder_config is not None
|
||||
else self._camera_encoder_config
|
||||
)
|
||||
self._meta.update_video_info(video_key, camera_encoder_config=cfg_for_info)
|
||||
self._meta.update_video_info(video_key)
|
||||
write_info(self._meta.info, self._meta.root)
|
||||
|
||||
metadata = {
|
||||
@@ -594,12 +564,7 @@ class DatasetWriter:
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
|
||||
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
|
||||
return _encode_video_worker(
|
||||
video_key,
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._camera_encoder_config,
|
||||
self._encoder_threads,
|
||||
video_key, episode_index, self._root, self._meta.fps, self._vcodec, self._encoder_threads
|
||||
)
|
||||
|
||||
def close_writer(self) -> None:
|
||||
|
||||
@@ -1,189 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Depth encoding/decoding helpers for :class:`VideoEncoderConfig`.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
|
||||
DEPTH_QUANT_BITS: int = 12
|
||||
DEPTH_QMAX: int = (1 << DEPTH_QUANT_BITS) - 1 # 4095
|
||||
_MM_PER_METRE: float = 1000.0
|
||||
_UINT16_MAX: int = 65535
|
||||
|
||||
DEFAULT_DEPTH_MIN: float = 0.01
|
||||
DEFAULT_DEPTH_MAX: float = 10.0
|
||||
DEFAULT_DEPTH_SHIFT: float = 3.5
|
||||
DEFAULT_DEPTH_USE_LOG: bool = True
|
||||
|
||||
|
||||
def _validate_log_quant_params(depth_min: float, shift: float) -> None:
|
||||
"""Ensure ``log(depth_min + shift)`` is finite."""
|
||||
if depth_min + shift <= 0:
|
||||
raise ValueError(
|
||||
f"depth_min + shift must be positive for logarithmic quantization, "
|
||||
f"got depth_min={depth_min} + shift={shift} = {depth_min + shift}"
|
||||
)
|
||||
|
||||
|
||||
def _depth_input_to_float32_and_unit(
|
||||
depth: NDArray[np.uint16] | NDArray[np.floating] | torch.Tensor,
|
||||
input_unit: Literal["auto", "m", "mm"],
|
||||
) -> tuple[NDArray[np.float32], Literal["m", "mm"]]:
|
||||
"""Depth as float32 in the chosen unit, plus the resolved unit."""
|
||||
if isinstance(depth, torch.Tensor):
|
||||
t = depth.detach().cpu()
|
||||
arr = t.numpy()
|
||||
is_floating = t.is_floating_point()
|
||||
else:
|
||||
arr = np.asarray(depth)
|
||||
is_floating = np.issubdtype(arr.dtype, np.floating)
|
||||
|
||||
resolved_unit: Literal["m", "mm"]
|
||||
if input_unit == "auto":
|
||||
resolved_unit = "m" if is_floating else "mm"
|
||||
else:
|
||||
resolved_unit = input_unit
|
||||
|
||||
# Convert to float32 to keep typing consistency
|
||||
return np.asarray(arr, dtype=np.float32, order="K"), resolved_unit
|
||||
|
||||
|
||||
def quantize_depth(
|
||||
depth: NDArray[np.uint16] | NDArray[np.floating] | torch.Tensor,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
*,
|
||||
input_unit: Literal["auto", "m", "mm"] = "auto",
|
||||
) -> NDArray[np.uint16]:
|
||||
"""Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``).
|
||||
|
||||
Depth maps are packed into 12-bit integer frames so they fit in standard
|
||||
high-bit-depth pixel formats (e.g. ``yuv420p12le`` / ``gray12le``)
|
||||
and can be encoded by widely supported video codecs (HEVC Main 12, ffv1).
|
||||
Logarithmic quantization is the default because it allocates more quanta
|
||||
to near-range depth, which matches the (1/depth) error profile of typical
|
||||
depth sensors. Math is ported from BEHAVIOR-1K's ``obs_utils.py``.
|
||||
|
||||
**Input units**:
|
||||
|
||||
- ``input_unit="auto"`` (default): infer from dtype (floating = m, non-floating = mm).
|
||||
- ``input_unit="mm"``: interpret input values as millimetres.
|
||||
- ``input_unit="m"``: interpret input values as metres.
|
||||
|
||||
Quantization math runs in the **resolved input unit**.
|
||||
|
||||
``depth_min``, ``depth_max``, and ``shift`` are always in **metres**.
|
||||
|
||||
Args:
|
||||
depth: Depth map; ``torch.Tensor`` is moved to CPU for conversion.
|
||||
depth_min: Depth (metres) at quantum ``0``.
|
||||
depth_max: Depth (metres) at quantum :data:`DEPTH_QMAX`.
|
||||
shift: Depth shift (metres); used in log mode. Must satisfy ``depth_min + shift > 0``.
|
||||
use_log: If ``True`` (default), quantize in log space.
|
||||
input_unit: Input unit policy (``"auto"``, ``"mm"``, ``"m"``).
|
||||
|
||||
Returns:
|
||||
``numpy.ndarray``, ``dtype=uint16``, same shape as ``depth``, values in
|
||||
``[0, DEPTH_QMAX]``.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``input_unit`` is not ``"auto"``, ``"mm"``, or ``"m"``.
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
"""
|
||||
if input_unit not in ("auto", "m", "mm"):
|
||||
raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}")
|
||||
|
||||
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
|
||||
depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE)
|
||||
depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE)
|
||||
shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE)
|
||||
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
log_min = math.log(float(depth_min_u + shift_u))
|
||||
log_max = math.log(float(depth_max_u + shift_u))
|
||||
norm = (np.log(depth_f + shift_u) - log_min) / (log_max - log_min)
|
||||
else:
|
||||
norm = (depth_f - depth_min_u) / (depth_max_u - depth_min_u)
|
||||
|
||||
out = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX)
|
||||
return out.astype(np.uint16, copy=False)
|
||||
|
||||
|
||||
def dequantize_depth(
|
||||
quantized: NDArray[np.uint16] | torch.Tensor,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
*,
|
||||
output_unit: Literal["m", "mm"] = "mm",
|
||||
) -> NDArray[np.uint16] | NDArray[np.float32]:
|
||||
"""Inverse of :func:`quantize_depth`.
|
||||
|
||||
Tuning arguments **must match** :func:`quantize_depth`.
|
||||
|
||||
Decoding inverts the same normalized code mapping as :func:`quantize_depth`
|
||||
using ``depth_min`` / ``depth_max`` / ``shift`` (in metres), then returns
|
||||
the requested output unit.
|
||||
|
||||
Args:
|
||||
quantized: 12-bit codes ``[0, DEPTH_QMAX]``, ``dtype=uint16``.
|
||||
depth_min, depth_max, shift, use_log: Same as :func:`quantize_depth` (metres).
|
||||
output_unit: ``\"mm\"`` returns ``uint16`` millimetres (``rint``, clip
|
||||
``[0, 65535]``). ``\"m\"`` returns ``float32`` metres in
|
||||
``[depth_min, depth_max]``.
|
||||
|
||||
Returns:
|
||||
Depth map in the requested unit and dtype.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
ValueError: If ``output_unit`` is not ``\"m\"`` or ``\"mm\"``.
|
||||
"""
|
||||
if output_unit not in ("m", "mm"):
|
||||
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
|
||||
|
||||
if isinstance(quantized, torch.Tensor):
|
||||
quantized = quantized.detach().cpu().numpy()
|
||||
q = np.asarray(quantized, dtype=np.uint16, order="K")
|
||||
norm = q.astype(np.float32, copy=False) / DEPTH_QMAX
|
||||
|
||||
depth_min_mm = np.float32(depth_min * _MM_PER_METRE)
|
||||
depth_max_mm = np.float32(depth_max * _MM_PER_METRE)
|
||||
shift_mm = np.float32(shift * _MM_PER_METRE)
|
||||
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
log_min = math.log(float(depth_min_mm + shift_mm))
|
||||
log_max = math.log(float(depth_max_mm + shift_mm))
|
||||
depth_mm = np.exp(norm * (log_max - log_min) + log_min) - shift_mm
|
||||
else:
|
||||
depth_mm = norm * (depth_max_mm - depth_min_mm) + depth_min_mm
|
||||
|
||||
depth_mm = np.clip(depth_mm, depth_min_mm, depth_max_mm).astype(np.float32, copy=False)
|
||||
if output_unit == "m":
|
||||
return (depth_mm / np.float32(_MM_PER_METRE)).astype(np.float32, copy=False)
|
||||
mm = np.rint(depth_mm).clip(0, _UINT16_MAX)
|
||||
return mm.astype(np.uint16, copy=False)
|
||||
@@ -19,7 +19,6 @@ from pprint import pformat
|
||||
import torch
|
||||
|
||||
from lerobot.configs import PreTrainedConfig
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.transforms import ImageTransforms
|
||||
from lerobot.utils.constants import ACTION, IMAGENET_STATS, OBS_PREFIX, REWARD
|
||||
@@ -31,14 +30,12 @@ from .streaming_dataset import StreamingLeRobotDataset
|
||||
|
||||
|
||||
def resolve_delta_timestamps(
|
||||
cfg: PreTrainedConfig | RewardModelConfig, ds_meta: LeRobotDatasetMetadata
|
||||
cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
|
||||
) -> dict[str, list] | None:
|
||||
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the config.
|
||||
"""Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.
|
||||
|
||||
Args:
|
||||
cfg (PreTrainedConfig | RewardModelConfig): The config to read delta_indices from. Both
|
||||
``PreTrainedConfig`` and concrete ``RewardModelConfig`` subclasses expose the
|
||||
``{observation,action,reward}_delta_indices`` properties used below.
|
||||
cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
|
||||
ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
|
||||
delta_timestamps against.
|
||||
|
||||
@@ -85,7 +82,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
ds_meta = LeRobotDatasetMetadata(
|
||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||
)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, ds_meta)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
if not cfg.dataset.streaming:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
|
||||
@@ -28,7 +28,6 @@ from .utils import (
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
DatasetInfo,
|
||||
)
|
||||
|
||||
|
||||
@@ -79,8 +78,8 @@ def create_empty_dataset_info(
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
) -> DatasetInfo:
|
||||
"""Create a template ``DatasetInfo`` object for a new dataset's ``meta/info.json``.
|
||||
) -> dict:
|
||||
"""Create a template dictionary for a new dataset's `info.json`.
|
||||
|
||||
Args:
|
||||
codebase_version (str): The version of the LeRobot codebase.
|
||||
@@ -88,24 +87,25 @@ def create_empty_dataset_info(
|
||||
features (dict): The LeRobot features dictionary for the dataset.
|
||||
use_videos (bool): Whether the dataset will store videos.
|
||||
robot_type (str | None): The type of robot used, if any.
|
||||
chunks_size (int | None): Max files per chunk directory. Defaults to ``DEFAULT_CHUNK_SIZE``.
|
||||
data_files_size_in_mb (int | None): Max parquet file size in MB. Defaults to ``DEFAULT_DATA_FILE_SIZE_IN_MB``.
|
||||
video_files_size_in_mb (int | None): Max video file size in MB. Defaults to ``DEFAULT_VIDEO_FILE_SIZE_IN_MB``.
|
||||
|
||||
Returns:
|
||||
DatasetInfo: A typed dataset information object with initial metadata.
|
||||
dict: A dictionary with the initial dataset metadata.
|
||||
"""
|
||||
return DatasetInfo(
|
||||
codebase_version=codebase_version,
|
||||
fps=fps,
|
||||
features=features,
|
||||
robot_type=robot_type,
|
||||
chunks_size=chunks_size or DEFAULT_CHUNK_SIZE,
|
||||
data_files_size_in_mb=data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
video_files_size_in_mb=video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
data_path=DEFAULT_DATA_PATH,
|
||||
video_path=DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
)
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": 0,
|
||||
"total_frames": 0,
|
||||
"total_tasks": 0,
|
||||
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
|
||||
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": DEFAULT_DATA_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
|
||||
def check_delta_timestamps(
|
||||
@@ -294,20 +294,10 @@ def validate_feature_image_or_video(
|
||||
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_shape = tuple(value.shape)
|
||||
expected = tuple(expected_shape)
|
||||
if len(expected) == 2:
|
||||
# Single-channel features (e.g. depth maps) — accept (H,W), (1,H,W), (H,W,1)
|
||||
h, w = expected
|
||||
valid = actual_shape in {(h, w), (1, h, w), (h, w, 1)}
|
||||
if not valid:
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(h, w)}', '{(1, h, w)}', or '{(h, w, 1)}'.\n"
|
||||
elif len(expected) == 3:
|
||||
c, h, w = expected
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||
else:
|
||||
error_message += f"The feature '{name}' has an unsupported expected_shape '{expected}'.\n"
|
||||
actual_shape = value.shape
|
||||
c, h, w = expected_shape
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||
elif isinstance(value, PILImage.Image):
|
||||
pass
|
||||
else:
|
||||
|
||||
@@ -41,56 +41,15 @@ def safe_stop_image_writer(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
# Single-channel dtypes that PIL natively maps to the matching mode
|
||||
# (``uint8`` → ``L``, ``uint16`` → ``I;16``, ``float32`` → ``F``).
|
||||
GRAYSCALE_DTYPES: tuple[np.dtype, ...] = (
|
||||
np.dtype("uint8"),
|
||||
np.dtype("uint16"),
|
||||
np.dtype("float32"),
|
||||
)
|
||||
|
||||
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
"""Convert a NumPy array to a PIL Image, preserving precision for grayscale.
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim != 3:
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
|
||||
|
||||
Behaviour by shape:
|
||||
|
||||
- ``(H, W)`` or ``(1, H, W)`` / ``(H, W, 1)``: single-channel grayscale.
|
||||
The native dtype is preserved using the matching PIL mode
|
||||
(``L`` / ``I;16`` / ``F``). This is the path used for raw depth maps (no rescaling, clamping, or downcasting)
|
||||
- ``(3, H, W)`` / ``(H, W, 3)``: RGB. Channels-first inputs are transposed
|
||||
to channels-last. Float inputs in ``[0, 1]`` are scaled to ``uint8``
|
||||
(existing behaviour, gated by ``range_check``).
|
||||
|
||||
Other shapes / channel counts raise ``NotImplementedError`` or
|
||||
``ValueError``.
|
||||
"""
|
||||
if image_array.ndim not in (2, 3):
|
||||
raise ValueError(
|
||||
f"The array has {image_array.ndim} dimensions, but 2 or 3 is expected for an image."
|
||||
)
|
||||
|
||||
# Squeeze 3D single-channel inputs to 2D so depth maps work whether the
|
||||
# caller emits (H, W), (1, H, W), or (H, W, 1).
|
||||
if image_array.ndim == 3:
|
||||
if image_array.shape[0] == 1:
|
||||
image_array = image_array[0]
|
||||
elif image_array.shape[-1] == 1:
|
||||
image_array = image_array[..., 0]
|
||||
|
||||
if image_array.ndim == 2:
|
||||
if image_array.dtype not in GRAYSCALE_DTYPES:
|
||||
raise ValueError(
|
||||
f"Unsupported single-channel image dtype: {image_array.dtype}. "
|
||||
f"Supported dtypes: {sorted(str(d) for d in GRAYSCALE_DTYPES)}."
|
||||
)
|
||||
|
||||
return PIL.Image.fromarray(np.ascontiguousarray(image_array))
|
||||
|
||||
# 3D path: must be RGB (3 channels), channels-first or channels-last.
|
||||
if image_array.shape[0] == 3:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
|
||||
elif image_array.shape[-1] != 3:
|
||||
raise NotImplementedError(
|
||||
f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
|
||||
@@ -112,28 +71,13 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict:
|
||||
"""Pick the right format-specific kwargs for :meth:`PIL.Image.Image.save`.
|
||||
|
||||
PNG uses ``compress_level`` (0–9, zlib). TIFF uses ``compression`` (raw) for lossless raw depth maps.
|
||||
"""
|
||||
suffix = Path(fpath).suffix.lower()
|
||||
if suffix == ".png":
|
||||
return {"compress_level": compress_level}
|
||||
if suffix in (".tif", ".tiff"):
|
||||
return {"compression": "raw"}
|
||||
return {}
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
|
||||
"""
|
||||
Saves a NumPy array or PIL Image to a file.
|
||||
|
||||
This function handles both NumPy arrays and PIL Image objects, converting
|
||||
the former to a PIL Image before saving. It includes error handling for
|
||||
the save operation. The output format is inferred from the *fpath*
|
||||
extension: ``.png`` → PNG with ``compress_level``, ``.tiff`` / ``.tif``
|
||||
→ lossless raw depth maps (TIFF).
|
||||
the save operation.
|
||||
|
||||
Args:
|
||||
image (np.ndarray | PIL.Image.Image): The image data to save.
|
||||
@@ -157,7 +101,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
|
||||
img = image
|
||||
else:
|
||||
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||
img.save(fpath, **save_kwargs_for_path(Path(fpath), compress_level))
|
||||
img.save(fpath, compress_level=compress_level)
|
||||
except Exception as e:
|
||||
logger.error("Error writing image %s: %s", fpath, e)
|
||||
|
||||
|
||||
@@ -39,7 +39,6 @@ from .utils import (
|
||||
EPISODES_DIR,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
DatasetInfo,
|
||||
serialize_dict,
|
||||
)
|
||||
|
||||
@@ -116,21 +115,25 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
return dataset
|
||||
|
||||
|
||||
def write_info(info: DatasetInfo, local_dir: Path) -> None:
|
||||
write_json(info.to_dict(), local_dir / INFO_PATH)
|
||||
def write_info(info: dict, local_dir: Path) -> None:
|
||||
write_json(info, local_dir / INFO_PATH)
|
||||
|
||||
|
||||
def load_info(local_dir: Path) -> DatasetInfo:
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
"""Load dataset info metadata from its standard file path.
|
||||
|
||||
Also converts shape lists to tuples for consistency.
|
||||
|
||||
Args:
|
||||
local_dir (Path): The root directory of the dataset.
|
||||
|
||||
Returns:
|
||||
DatasetInfo: The typed dataset information object.
|
||||
dict: The dataset information dictionary.
|
||||
"""
|
||||
raw = load_json(local_dir / INFO_PATH)
|
||||
return DatasetInfo.from_dict(raw)
|
||||
info = load_json(local_dir / INFO_PATH)
|
||||
for ft in info["features"].values():
|
||||
ft["shape"] = tuple(ft["shape"])
|
||||
return info
|
||||
|
||||
|
||||
def write_stats(stats: dict, local_dir: Path) -> None:
|
||||
|
||||
@@ -35,11 +35,9 @@ from .utils import (
|
||||
is_valid_version,
|
||||
)
|
||||
from .video_utils import (
|
||||
DepthEncoderConfig,
|
||||
StreamingVideoEncoder,
|
||||
VideoEncoderConfig,
|
||||
get_safe_default_video_backend,
|
||||
seed_depth_feature_info,
|
||||
get_safe_default_codec,
|
||||
resolve_vcodec,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -60,11 +58,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
return_uint8: bool = False,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
"""
|
||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||
@@ -180,15 +177,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||
camera_encoder_config (VideoEncoderConfig | None, optional): Video encoder settings for cameras
|
||||
(codec, quality, etc.). Defaults to
|
||||
:class:`~lerobot.datasets.video_utils.VideoEncoderConfig` defaults when ``None``.
|
||||
encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the
|
||||
codec decide.
|
||||
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
|
||||
'libsvtav1', 'auto', or hardware-specific codecs like 'h264_videotoolbox', 'h264_nvenc'.
|
||||
Defaults to 'libsvtav1'. Use 'auto' to auto-detect the best available hardware encoder.
|
||||
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
|
||||
instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False.
|
||||
encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using
|
||||
streaming encoding. Defaults to 30 (~1s at 30fps).
|
||||
encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the
|
||||
codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for
|
||||
libsvtav1 and 'threads' for h264/hevc.
|
||||
|
||||
Note:
|
||||
Write-mode parameters (``streaming_encoding``, ``batch_encoding_size``) passed to
|
||||
@@ -204,13 +202,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
self._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self._return_uint8 = return_uint8
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
self._camera_encoder_config = camera_encoder_config
|
||||
self._depth_encoder_config = depth_encoder_config
|
||||
self._vcodec = resolve_vcodec(vcodec)
|
||||
self._encoder_threads = encoder_threads
|
||||
|
||||
if self._requested_root is not None:
|
||||
@@ -253,23 +248,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
seed_depth_feature_info(self.meta.features, self._depth_encoder_config)
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(self.meta.video_keys) > 0:
|
||||
streaming_enc = self._build_streaming_encoder(
|
||||
self.meta.fps,
|
||||
self._camera_encoder_config,
|
||||
self._encoder_threads,
|
||||
encoder_queue_maxsize,
|
||||
depth_encoder_config=self._depth_encoder_config,
|
||||
depth_keys=self.meta.depth_keys,
|
||||
self.meta.fps, self._vcodec, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
self.writer = DatasetWriter(
|
||||
meta=self.meta,
|
||||
root=self.root,
|
||||
camera_encoder_config=self._camera_encoder_config,
|
||||
depth_encoder_config=self._depth_encoder_config,
|
||||
encoder_threads=self._encoder_threads,
|
||||
vcodec=self._vcodec,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
initial_frames=self.meta.total_frames,
|
||||
@@ -310,20 +298,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
@staticmethod
|
||||
def _build_streaming_encoder(
|
||||
fps: int,
|
||||
camera_encoder_config: VideoEncoderConfig,
|
||||
encoder_threads: int | None,
|
||||
vcodec: str,
|
||||
encoder_queue_maxsize: int,
|
||||
*,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
depth_keys: list[str] | None = None,
|
||||
encoder_threads: int | None,
|
||||
) -> StreamingVideoEncoder:
|
||||
return StreamingVideoEncoder(
|
||||
fps=fps,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
encoder_threads=encoder_threads,
|
||||
vcodec=vcodec,
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=None,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
depth_encoder_config=depth_encoder_config,
|
||||
depth_keys=depth_keys,
|
||||
encoder_threads=encoder_threads,
|
||||
)
|
||||
|
||||
# ── Metadata properties ───────────────────────────────────────────
|
||||
@@ -638,8 +625,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
metadata_buffer_size: int = 10,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -670,23 +656,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: Video decoding backend (used when reading back).
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos. ``1`` means encode immediately.
|
||||
camera_encoder_config: Video encoder settings for cameras; defaults
|
||||
match :class:`~lerobot.datasets.video_utils.VideoEncoderConfig`
|
||||
when ``None``.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
vcodec: Video codec for encoding. Options include ``'libsvtav1'``,
|
||||
``'h264'``, ``'hevc'``, ``'auto'``.
|
||||
metadata_buffer_size: Number of episode metadata records to buffer
|
||||
before flushing to parquet.
|
||||
streaming_encoding: If ``True``, encode video frames in real-time
|
||||
during capture instead of writing images first.
|
||||
encoder_queue_maxsize: Max buffered frames per camera when using
|
||||
streaming encoding.
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
|
||||
Returns:
|
||||
A new :class:`LeRobotDataset` in write mode.
|
||||
"""
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
obj = cls.__new__(cls)
|
||||
obj.meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
@@ -707,32 +690,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend()
|
||||
obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
obj._return_uint8 = False
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._camera_encoder_config = camera_encoder_config
|
||||
obj._depth_encoder_config = depth_encoder_config
|
||||
obj._vcodec = vcodec
|
||||
obj._encoder_threads = encoder_threads
|
||||
seed_depth_feature_info(obj.meta.features, depth_encoder_config)
|
||||
|
||||
# Reader is lazily created on first access (write-only mode)
|
||||
obj.reader = None
|
||||
|
||||
# Create writer
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
fps,
|
||||
camera_encoder_config,
|
||||
encoder_threads,
|
||||
encoder_queue_maxsize,
|
||||
depth_encoder_config=depth_encoder_config,
|
||||
depth_keys=obj.meta.depth_keys,
|
||||
)
|
||||
streaming_enc = cls._build_streaming_encoder(fps, vcodec, encoder_queue_maxsize, encoder_threads)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
depth_encoder_config=depth_encoder_config,
|
||||
vcodec=vcodec,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -755,13 +729,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
force_cache_sync: bool = False,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
) -> "LeRobotDataset":
|
||||
"""Resume recording on an existing dataset.
|
||||
|
||||
@@ -784,16 +757,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: Video decoding backend for reading back data.
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos.
|
||||
camera_encoder_config: Video encoder settings for cameras; defaults
|
||||
match :class:`~lerobot.datasets.video_utils.VideoEncoderConfig`
|
||||
when ``None``.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
vcodec: Video codec for encoding.
|
||||
image_writer_processes: Subprocesses for async image writing.
|
||||
image_writer_threads: Threads for async image writing.
|
||||
streaming_encoding: If ``True``, encode video in real-time during
|
||||
capture.
|
||||
encoder_queue_maxsize: Max buffered frames per camera for streaming.
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
|
||||
Returns:
|
||||
A :class:`LeRobotDataset` in write mode, ready to append episodes.
|
||||
@@ -804,6 +774,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"Writing into the revision-safe Hub snapshot cache (used when root=None) would corrupt "
|
||||
"the shared cache. Please provide a local directory path."
|
||||
)
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj._requested_root = Path(root)
|
||||
@@ -812,9 +783,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
obj._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
obj._return_uint8 = False
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._vcodec = vcodec
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
if obj._requested_root is not None:
|
||||
obj._requested_root.mkdir(exist_ok=True, parents=True)
|
||||
@@ -823,33 +796,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.meta = LeRobotDatasetMetadata(
|
||||
obj.repo_id, obj._requested_root, obj.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
obj._camera_encoder_config = camera_encoder_config
|
||||
obj._depth_encoder_config = depth_encoder_config
|
||||
obj._encoder_threads = encoder_threads
|
||||
obj.root = obj.meta.root
|
||||
seed_depth_feature_info(obj.meta.features, depth_encoder_config)
|
||||
|
||||
# Reader is lazily created on first access (write-only mode)
|
||||
obj.reader = None
|
||||
|
||||
# Create writer for appending
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
obj.meta.fps,
|
||||
camera_encoder_config,
|
||||
encoder_threads,
|
||||
encoder_queue_maxsize,
|
||||
depth_encoder_config=depth_encoder_config,
|
||||
depth_keys=obj.meta.depth_keys,
|
||||
obj.meta.fps, vcodec, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
depth_encoder_config=depth_encoder_config,
|
||||
vcodec=vcodec,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
|
||||
@@ -123,7 +123,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return self._datasets[0].meta.info.fps
|
||||
return self._datasets[0].meta.info["fps"]
|
||||
|
||||
@property
|
||||
def video(self) -> bool:
|
||||
@@ -133,7 +133,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return len(self._datasets[0].meta.video_keys) > 0
|
||||
return self._datasets[0].meta.info.get("video", False)
|
||||
|
||||
@property
|
||||
def features(self) -> datasets.Features:
|
||||
|
||||
@@ -1,311 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyAV-based compatibility checks for :class:`VideoEncoderConfig`.
|
||||
|
||||
Centralises all :mod:`av` introspection of the bundled FFmpeg build.
|
||||
Checks degrade to a no-op when the target codec isn't available locally.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.depth_utils import (
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEFAULT_DEPTH_SHIFT,
|
||||
DEFAULT_DEPTH_USE_LOG,
|
||||
quantize_depth,
|
||||
dequantize_depth,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.datasets.video_utils import VideoEncoderConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Pixel formats supported by the depth encode/decode helpers below. Both are
|
||||
# 16-bit-word formats that carry 12 significant bits per sample, matching the
|
||||
# ``DEPTH_QMAX = 4095`` quantization range.
|
||||
DEPTH_PIX_FMTS: tuple[str, ...] = ("yuv420p12le", "gray12le")
|
||||
|
||||
# Neutral chroma for 12-bit YUV (the midpoint of [0, 4095]). Filling the U/V
|
||||
# planes with this value keeps the encoder from spending bits on chroma noise
|
||||
# when only the Y plane carries information.
|
||||
_NEUTRAL_CHROMA_12BIT: int = 2048
|
||||
|
||||
FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
|
||||
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
|
||||
|
||||
|
||||
def _write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_value: int | None = None) -> None:
|
||||
"""Copy ``src`` into a uint16 plane respecting FFmpeg line padding."""
|
||||
height, width = src.shape
|
||||
stride_u16 = plane.line_size // np.dtype(np.uint16).itemsize
|
||||
dst = np.frombuffer(plane, dtype=np.uint16).reshape(height, stride_u16)
|
||||
if fill_value is not None:
|
||||
dst.fill(fill_value)
|
||||
dst[:, :width] = src
|
||||
|
||||
|
||||
def encode_depth_frame_pyav(
|
||||
depth: np.ndarray | torch.Tensor,
|
||||
*,
|
||||
pix_fmt: str = "yuv420p12le",
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
input_unit: Literal["auto", "m", "mm"] = "auto",
|
||||
) -> av.VideoFrame:
|
||||
"""Quantize depth and pack it into a 12-bit PyAV video frame.
|
||||
|
||||
Args:
|
||||
depth: Depth frame to encode (H, W). Unit handling follows
|
||||
:func:`lerobot.datasets.depth_utils.quantize_depth`.
|
||||
pix_fmt: Target pixel format. Must be one of :data:`DEPTH_PIX_FMTS`.
|
||||
depth_min, depth_max, shift, use_log, input_unit: Forwarded to
|
||||
:func:`quantize_depth`.
|
||||
|
||||
Returns:
|
||||
An :class:`av.VideoFrame` in ``pix_fmt`` with quantized depth in the
|
||||
luminance plane.
|
||||
"""
|
||||
if pix_fmt not in DEPTH_PIX_FMTS:
|
||||
raise ValueError(f"Unsupported depth pix_fmt={pix_fmt!r}; expected one of {DEPTH_PIX_FMTS}")
|
||||
|
||||
quantized_depth = quantize_depth(
|
||||
depth,
|
||||
depth_min=depth_min,
|
||||
depth_max=depth_max,
|
||||
shift=shift,
|
||||
use_log=use_log,
|
||||
input_unit=input_unit,
|
||||
)
|
||||
if quantized_depth.ndim != 2:
|
||||
raise ValueError(f"depth must be a 2D frame; got shape {quantized_depth.shape}")
|
||||
|
||||
quantized_depth = np.ascontiguousarray(quantized_depth, dtype=np.uint16)
|
||||
height, width = quantized_depth.shape
|
||||
|
||||
if pix_fmt == "gray12le":
|
||||
frame = av.VideoFrame(width=width, height=height, format="gray12le")
|
||||
_write_u16_plane(frame.planes[0], quantized_depth)
|
||||
return frame
|
||||
|
||||
if height % 2 != 0 or width % 2 != 0:
|
||||
raise ValueError("yuv420p12le requires even H and W")
|
||||
|
||||
frame = av.VideoFrame(width=width, height=height, format="yuv420p12le")
|
||||
_write_u16_plane(frame.planes[0], quantized_depth)
|
||||
neutral_chroma = np.full((height // 2, width // 2), _NEUTRAL_CHROMA_12BIT, dtype=np.uint16)
|
||||
_write_u16_plane(frame.planes[1], neutral_chroma, fill_value=_NEUTRAL_CHROMA_12BIT)
|
||||
_write_u16_plane(frame.planes[2], neutral_chroma, fill_value=_NEUTRAL_CHROMA_12BIT)
|
||||
return frame
|
||||
|
||||
|
||||
def decode_depth_frame_pyav(
|
||||
frame: av.VideoFrame | list[av.VideoFrame],
|
||||
*,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
return_quantized: bool = False,
|
||||
output_unit: Literal["m", "mm"] = "m",
|
||||
) -> np.ndarray:
|
||||
"""Decode one or many depth video frames to quantized or metric depth.
|
||||
|
||||
Args:
|
||||
frame: A single depth frame or a list of depth frames.
|
||||
depth_min, depth_max, shift, use_log: Forwarded to
|
||||
:func:`dequantize_depth`.
|
||||
return_quantized: If ``True``, return raw 12-bit quanta as ``uint16``.
|
||||
output_unit: Unit for dequantized output (``"m"`` or ``"mm"``).
|
||||
|
||||
Returns:
|
||||
``(H, W)`` array for a single frame, or ``(N, H, W)`` for a list.
|
||||
"""
|
||||
frames = frame if isinstance(frame, list) else [frame]
|
||||
quantized = np.stack([f.reformat(format="gray12le").to_ndarray() for f in frames]).astype(np.uint16, copy=False)
|
||||
if return_quantized:
|
||||
return quantized[0] if len(frames) == 1 else quantized
|
||||
|
||||
decoded = dequantize_depth(
|
||||
quantized,
|
||||
depth_min=depth_min,
|
||||
depth_max=depth_max,
|
||||
shift=shift,
|
||||
use_log=use_log,
|
||||
output_unit=output_unit,
|
||||
)
|
||||
return decoded[0] if len(frames) == 1 else decoded
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_codec(vcodec: str) -> av.codec.Codec | None:
|
||||
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
|
||||
try:
|
||||
return av.codec.Codec(vcodec, "w")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _get_codec_video_formats(vcodec: str) -> dict[str, av.option.Option]:
|
||||
"""Private-option name → PyAV ``Option`` for *vcodec* (empty if unavailable)."""
|
||||
codec = get_codec(vcodec)
|
||||
if codec is None:
|
||||
return {}
|
||||
return {opt.name: opt for opt in codec.descriptor.options}
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _get_codec_video_formats(vcodec: str) -> tuple[str, ...]:
|
||||
"""Pixel formats accepted by *vcodec* in PyAV's preferred order (empty if unknown)."""
|
||||
codec = get_codec(vcodec)
|
||||
if codec is None:
|
||||
return ()
|
||||
return tuple(fmt.name for fmt in (codec.video_formats or []))
|
||||
|
||||
|
||||
def detect_available_encoders_pyav(encoders: list[str] | str) -> list[str]:
|
||||
"""Return the subset of *encoders* available as video encoders in the local FFmpeg build.
|
||||
|
||||
Each name is probed directly via :func:`get_codec`; input order is preserved.
|
||||
"""
|
||||
if isinstance(encoders, str):
|
||||
encoders = [encoders]
|
||||
|
||||
available: list[str] = []
|
||||
for name in encoders:
|
||||
codec = get_codec(name)
|
||||
if codec is not None and codec.type == "video":
|
||||
available.append(name)
|
||||
else:
|
||||
logger.debug("encoder '%s' not available as video encoder", name)
|
||||
return available
|
||||
|
||||
|
||||
def _check_option_value(vcodec: str, label: str, value: Any, opt: av.option.Option) -> None:
|
||||
"""Range-check numeric *value* and choice-check string *value* against *opt*."""
|
||||
type_name = opt.type.name
|
||||
if type_name in FFMPEG_NUMERIC_OPTION_TYPES:
|
||||
if isinstance(value, bool):
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
)
|
||||
elif isinstance(value, str):
|
||||
try:
|
||||
num_val = float(value)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
) from e
|
||||
elif isinstance(value, (float, int)):
|
||||
num_val = value
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
)
|
||||
|
||||
# Check integer type compatibility
|
||||
if type_name in FFMPEG_INTEGER_OPTION_TYPES and not num_val.is_integer():
|
||||
raise ValueError(
|
||||
f"{label}={num_val!r} must be an integer for codec {vcodec!r} "
|
||||
f"(FFmpeg option {opt.name!r} is {type_name}); float values are not allowed."
|
||||
)
|
||||
|
||||
# Check numeric range compatibility
|
||||
lo, hi = float(opt.min), float(opt.max)
|
||||
if lo < hi and not (lo <= num_val <= hi):
|
||||
raise ValueError(
|
||||
f"{label}={num_val} is out of range for codec {vcodec!r}; must be in [{lo}, {hi}]"
|
||||
)
|
||||
|
||||
elif type_name == "STRING":
|
||||
if isinstance(value, bool):
|
||||
raise ValueError(f"{label}={value!r} is not a valid string value for codec {vcodec!r}.")
|
||||
if isinstance(value, str):
|
||||
str_val = value
|
||||
elif isinstance(value, (int, float)):
|
||||
str_val = str(value)
|
||||
else:
|
||||
raise ValueError(f"{label}={value!r} has unsupported type for STRING option on codec {vcodec!r}")
|
||||
|
||||
# Check string choice compatibility
|
||||
choices = [c.name for c in (opt.choices or [])]
|
||||
if choices and str_val not in choices:
|
||||
raise ValueError(
|
||||
f"{label}={str_val!r} is not a supported choice for codec "
|
||||
f"{vcodec!r}; valid choices: {choices}"
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
|
||||
formats = _get_codec_video_formats(vcodec)
|
||||
if formats and pix_fmt not in formats:
|
||||
raise ValueError(
|
||||
f"pix_fmt={pix_fmt!r} is not supported by codec {vcodec!r}; "
|
||||
f"supported pixel formats: {list(formats)}"
|
||||
)
|
||||
|
||||
|
||||
def _check_codec_options(vcodec: str, codec_options: dict[str, Any], config: VideoEncoderConfig) -> None:
|
||||
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
|
||||
supported_options = _get_codec_options_by_name(vcodec)
|
||||
for key, value in codec_options.items():
|
||||
# GOP size is not a codec-specific option, it has to be validated separately.
|
||||
if key == "g":
|
||||
if isinstance(value, bool) or not isinstance(value, int) or value < 1:
|
||||
raise ValueError(f"g={value!r} must be a positive integer for codec {vcodec!r}")
|
||||
continue
|
||||
if key not in supported_options:
|
||||
continue
|
||||
opt = supported_options[key]
|
||||
label = f"extra_options[{key!r}]" if key in config.extra_options else key
|
||||
_check_option_value(vcodec, label, value, opt)
|
||||
|
||||
|
||||
def check_video_encoder_config_pyav(config: VideoEncoderConfig) -> None:
|
||||
"""Verify *config* is compatible with the bundled FFmpeg build.
|
||||
|
||||
Checks pixel format, abstract tuning-field compatibility, and each merged
|
||||
encoder option from :meth:`~lerobot.datasets.video_utils.VideoEncoderConfig.get_codec_options`
|
||||
against PyAV (including numeric ``extra_options`` present in that dict).
|
||||
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
|
||||
|
||||
Raises:
|
||||
ValueError: on the first incompatibility encountered.
|
||||
"""
|
||||
vcodec = config.vcodec
|
||||
options = _get_codec_options_by_name(vcodec)
|
||||
if not options:
|
||||
logger.warning(
|
||||
"Codec %r is not available in the bundled FFmpeg build; ",
|
||||
vcodec,
|
||||
)
|
||||
return
|
||||
_check_pixel_format(config.vcodec, config.pix_fmt)
|
||||
_check_codec_options(config.vcodec, config.get_codec_options(), config)
|
||||
@@ -434,7 +434,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
def _make_padding_camera_frame(self, camera_key: str):
|
||||
"""Variable-shape padding frame for given camera keys, given in (H, W, C)"""
|
||||
return torch.zeros(self.meta.info.features[camera_key]["shape"]).permute(-1, 0, 1)
|
||||
return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1)
|
||||
|
||||
def _get_video_frame_padding_mask(
|
||||
self,
|
||||
|
||||
@@ -14,11 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -72,9 +70,6 @@ class ForwardCompatibilityError(CompatibilityError):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
|
||||
@@ -93,133 +88,12 @@ DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||
# Depth maps live alongside images on disk but use TIFF instead of PNG: PNG
|
||||
# cannot natively round-trip float32, and several common loaders silently
|
||||
# downcast 16-bit grayscale.
|
||||
DEFAULT_DEPTH_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.tiff"
|
||||
|
||||
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
||||
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetInfo:
|
||||
"""Typed representation of the ``meta/info.json`` file for a LeRobot dataset.
|
||||
|
||||
Replaces the previously untyped ``dict`` returned by ``load_info()`` and
|
||||
created by ``create_empty_dataset_info()``. Using a dataclass provides
|
||||
explicit field definitions, IDE auto-completion, and validation at
|
||||
construction time.
|
||||
"""
|
||||
|
||||
codebase_version: str
|
||||
fps: int
|
||||
features: dict[str, dict]
|
||||
|
||||
# Episode / frame counters — start at zero for new datasets
|
||||
total_episodes: int = 0
|
||||
total_frames: int = 0
|
||||
total_tasks: int = 0
|
||||
|
||||
# Storage settings
|
||||
chunks_size: int = field(default=DEFAULT_CHUNK_SIZE)
|
||||
data_files_size_in_mb: int = field(default=DEFAULT_DATA_FILE_SIZE_IN_MB)
|
||||
video_files_size_in_mb: int = field(default=DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
|
||||
# File path templates
|
||||
data_path: str = field(default=DEFAULT_DATA_PATH)
|
||||
video_path: str | None = field(default=DEFAULT_VIDEO_PATH)
|
||||
|
||||
# Optional metadata
|
||||
robot_type: str | None = None
|
||||
splits: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Coerce feature shapes from list to tuple — JSON deserialisation
|
||||
# returns lists, but the rest of the codebase expects tuples.
|
||||
for ft in self.features.values():
|
||||
if isinstance(ft.get("shape"), list):
|
||||
ft["shape"] = tuple(ft["shape"])
|
||||
|
||||
if self.fps <= 0:
|
||||
raise ValueError(f"fps must be positive, got {self.fps}")
|
||||
if self.chunks_size <= 0:
|
||||
raise ValueError(f"chunks_size must be positive, got {self.chunks_size}")
|
||||
if self.data_files_size_in_mb <= 0:
|
||||
raise ValueError(f"data_files_size_in_mb must be positive, got {self.data_files_size_in_mb}")
|
||||
if self.video_files_size_in_mb <= 0:
|
||||
raise ValueError(f"video_files_size_in_mb must be positive, got {self.video_files_size_in_mb}")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Return a JSON-serialisable dict.
|
||||
|
||||
Converts tuple shapes back to lists so ``json.dump`` can handle them.
|
||||
"""
|
||||
d = dataclasses.asdict(self)
|
||||
for ft in d["features"].values():
|
||||
if isinstance(ft.get("shape"), tuple):
|
||||
ft["shape"] = list(ft["shape"])
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "DatasetInfo":
|
||||
"""Construct from a raw dict (e.g. loaded directly from JSON).
|
||||
|
||||
Unknown keys are ignored for forward compatibility with datasets that
|
||||
carry additional fields (e.g. ``total_videos`` from v2.x). A warning is
|
||||
logged when such fields are present.
|
||||
"""
|
||||
known = {f.name for f in dataclasses.fields(cls)}
|
||||
unknown = sorted(k for k in data if k not in known)
|
||||
if unknown:
|
||||
logger.warning(f"Unknown fields in DatasetInfo: {unknown}. These will be ignored.")
|
||||
return cls(**{k: v for k, v in data.items() if k in known})
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Temporary dict-style compatibility layer
|
||||
# Allows existing ``info["key"]`` call-sites to keep working without changes.
|
||||
# Once all callers have been migrated to attribute access, remove these.
|
||||
# ---------------------------------------------------------------------------
|
||||
def __getitem__(self, key: str):
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
f"Accessing DatasetInfo with dict-style syntax info['{key}'] is deprecated. "
|
||||
f"Use attribute access info.{key} instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
try:
|
||||
return getattr(self, key)
|
||||
except AttributeError as err:
|
||||
raise KeyError(key) from err
|
||||
|
||||
def __setitem__(self, key: str, value) -> None:
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
f"Setting DatasetInfo with dict-style syntax info['{key}'] = ... is deprecated. "
|
||||
f"Use attribute assignment info.{key} = ... instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if not hasattr(self, key):
|
||||
raise KeyError(f"DatasetInfo has no field '{key}'")
|
||||
setattr(self, key, value)
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
"""Check if a field exists (dict-like interface)."""
|
||||
return hasattr(self, key)
|
||||
|
||||
def get(self, key: str, default=None):
|
||||
"""Get attribute value with default fallback (dict-like interface)."""
|
||||
try:
|
||||
return getattr(self, key)
|
||||
except AttributeError:
|
||||
return default
|
||||
|
||||
|
||||
def has_legacy_hub_download_metadata(root: Path) -> bool:
|
||||
"""Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.
|
||||
|
||||
@@ -420,7 +294,7 @@ def create_branch(repo_id: str, *, branch: str, repo_type: str | None = None) ->
|
||||
|
||||
def create_lerobot_dataset_card(
|
||||
tags: list | None = None,
|
||||
dataset_info: DatasetInfo | None = None,
|
||||
dataset_info: dict | None = None,
|
||||
**kwargs,
|
||||
) -> DatasetCard:
|
||||
"""Create a `DatasetCard` for a LeRobot dataset.
|
||||
@@ -431,7 +305,7 @@ def create_lerobot_dataset_card(
|
||||
|
||||
Args:
|
||||
tags (list | None): A list of tags to add to the dataset card.
|
||||
dataset_info (DatasetInfo | None): The dataset's info object, which will
|
||||
dataset_info (dict | None): The dataset's info dictionary, which will
|
||||
be displayed on the card.
|
||||
**kwargs: Additional keyword arguments to populate the card template.
|
||||
|
||||
@@ -444,7 +318,7 @@ def create_lerobot_dataset_card(
|
||||
card_tags += tags
|
||||
if dataset_info:
|
||||
dataset_structure = "[meta/info.json](meta/info.json):\n"
|
||||
dataset_structure += f"```json\n{json.dumps(dataset_info.to_dict(), indent=4)}\n```\n"
|
||||
dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n"
|
||||
kwargs = {**kwargs, "dataset_structure": dataset_structure}
|
||||
card_data = DatasetCardData(
|
||||
license=kwargs.get("license"),
|
||||
|
||||
@@ -17,13 +17,12 @@ import contextlib
|
||||
import glob
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import queue
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import warnings
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from dataclasses import dataclass, field
|
||||
from fractions import Fraction
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
@@ -38,23 +37,7 @@ import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.datasets.pyav_utils import (
|
||||
check_video_encoder_config_pyav,
|
||||
depth_to_video_frame,
|
||||
detect_available_encoders_pyav,
|
||||
decode_depth_frame,
|
||||
encode_depth_frame_pyav,
|
||||
decode_depth_frame_pyav,
|
||||
)
|
||||
from lerobot.datasets.depth_utils import (
|
||||
quantize_depth,
|
||||
dequantize_depth,
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_SHIFT,
|
||||
DEFAULT_DEPTH_USE_LOG,
|
||||
)
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
from lerobot.utils.import_utils import get_safe_default_codec
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -69,226 +52,70 @@ HW_ENCODERS = [
|
||||
"h264_qsv", # Intel Quick Sync
|
||||
]
|
||||
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "ffv1", "auto"} | set(HW_ENCODERS)
|
||||
|
||||
LIBSVTAV1_DEFAULT_PRESET: int = 12
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoEncoderConfig:
|
||||
"""Video encoder configuration.
|
||||
def _get_codec_options(
|
||||
vcodec: str,
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
preset: int | None = None,
|
||||
) -> dict:
|
||||
"""Build codec-specific options dict for video encoding."""
|
||||
options = {}
|
||||
|
||||
Attributes:
|
||||
vcodec: FFmpeg encoder name. ``"auto"`` is resolved during
|
||||
construction (HW encoder if available, else ``libsvtav1``).
|
||||
pix_fmt: Pixel format (e.g. ``"yuv420p"``).
|
||||
g: GOP size (keyframe interval).
|
||||
crf: Quality level — mapped to the native quality parameter of the
|
||||
codec (``crf`` for software, ``qp`` for NVENC/VAAPI,
|
||||
``q:v`` for VideoToolbox, ``global_quality`` for QSV).
|
||||
preset: Speed/quality preset. Accepted type is per-codec.
|
||||
fast_decode: Fast-decode tuning. For ``libsvtav1`` this is a level (0-2)
|
||||
embedded in ``svtav1-params``. For ``h264`` and ``hevc`` non-zero values
|
||||
set ``tune=fastdecode``. Ignored for other codecs.
|
||||
video_backend: Python library driving FFmpeg for encoding. Only ``"pyav"``
|
||||
is currently supported.
|
||||
extra_options: Free-form dictionary of additional FFmpeg options
|
||||
(e.g. ``{"tune": "film", "profile:v": "high", "bf": 2}``).
|
||||
"""
|
||||
# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
|
||||
if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS):
|
||||
options["g"] = str(g)
|
||||
|
||||
vcodec: str = "libsvtav1"
|
||||
pix_fmt: str = "yuv420p"
|
||||
g: int | None = 2
|
||||
crf: int | None = 30
|
||||
preset: int | str | None = None
|
||||
fast_decode: int = 0
|
||||
# TODO(CarolinePascal): add torchcodec support + find a way to unify the
|
||||
# two backends (encoding and decoding).
|
||||
video_backend: str = "pyav"
|
||||
extra_options: dict[str, Any] = field(default_factory=dict)
|
||||
# Quality control (codec-specific parameter names)
|
||||
if crf is not None:
|
||||
if vcodec in ("h264", "hevc", "libsvtav1"):
|
||||
options["crf"] = str(crf)
|
||||
elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
|
||||
quality = max(1, min(100, int(100 - crf * 2)))
|
||||
options["q:v"] = str(quality)
|
||||
elif vcodec in ("h264_nvenc", "hevc_nvenc"):
|
||||
options["rc"] = "constqp"
|
||||
options["qp"] = str(crf)
|
||||
elif vcodec in ("h264_vaapi",):
|
||||
options["qp"] = str(crf)
|
||||
elif vcodec in ("h264_qsv",):
|
||||
options["global_quality"] = str(crf)
|
||||
|
||||
# Class-level marker persisted to ``info.json`` (via ``asdict``) so the
|
||||
# reader can tell depth datasets from RGB ones without a separate dispatch
|
||||
# path. ``init=False`` keeps it out of CLI/constructor surface; subclasses
|
||||
# flip the default (see :class:`DepthEncoderConfig`).
|
||||
is_depth_map: bool = field(default=False, init=False)
|
||||
# Preset (only for libsvtav1)
|
||||
if vcodec == "libsvtav1":
|
||||
options["preset"] = str(preset) if preset is not None else "12"
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.resolve_vcodec()
|
||||
|
||||
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
|
||||
if self.preset is None and self.vcodec == "libsvtav1":
|
||||
self.preset = LIBSVTAV1_DEFAULT_PRESET
|
||||
|
||||
self.validate()
|
||||
|
||||
def detect_available_encoders(self, encoders: list[str] | str) -> list[str]:
|
||||
"""Detect available encoders based on the video backend."""
|
||||
if self.video_backend == "pyav":
|
||||
return detect_available_encoders_pyav(encoders)
|
||||
else:
|
||||
return []
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate the video encoder config."""
|
||||
if self.video_backend == "pyav":
|
||||
check_video_encoder_config_pyav(self)
|
||||
|
||||
def resolve_vcodec(self) -> None:
|
||||
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1.
|
||||
|
||||
Any explicitly-requested codec that isn't in the local FFmpeg build is
|
||||
also silently rewritten to ``libsvtav1`` so encoding never hard-fails on
|
||||
a host missing the requested encoder.
|
||||
"""
|
||||
# Backward compatibility: older datasets persist ``vcodec="av1"`` in
|
||||
# ``info.json``. Rewrite to the canonical encoder name *before* the
|
||||
# validation check below so loading those datasets keeps working.
|
||||
if self.vcodec == "av1":
|
||||
self.vcodec = "libsvtav1"
|
||||
|
||||
if self.vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{self.vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if self.vcodec == "auto":
|
||||
available = self.detect_available_encoders(HW_ENCODERS)
|
||||
for encoder in HW_ENCODERS:
|
||||
if encoder in available:
|
||||
logger.info(f"Auto-selected video codec: {encoder}")
|
||||
self.vcodec = encoder
|
||||
return
|
||||
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
self.vcodec = "libsvtav1"
|
||||
|
||||
if self.detect_available_encoders(self.vcodec):
|
||||
logger.info(f"Using video codec: {self.vcodec}")
|
||||
self.vcodec = self.vcodec
|
||||
return
|
||||
raise ValueError(f"Unsupported video codec: {self.vcodec} with video backend {self.video_backend}")
|
||||
|
||||
def get_codec_options(
|
||||
self, encoder_threads: int | None = None, as_strings: bool = False
|
||||
) -> dict[str, str]:
|
||||
"""Translate the tuning fields to codec-specific FFmpeg options.
|
||||
|
||||
``VideoEncoderConfig.extra_options`` are merged last but never override a structured field.
|
||||
|
||||
Args:
|
||||
encoder_threads: Number of encoder threads set globally for all VideoEncoderConfigs.
|
||||
For libsvtav1, this is mapped to ``lp`` via ``svtav1-params``.
|
||||
For h264/hevc, this is mapped to ``threads``.
|
||||
Hardware encoders ignore this parameter.
|
||||
as_strings: If ``True``, casts values to strings.
|
||||
"""
|
||||
opts: dict[str, Any] = {}
|
||||
|
||||
def set_if(key: str, value: Any) -> None:
|
||||
if value is not None:
|
||||
opts[key] = value if not as_strings else str(value)
|
||||
|
||||
# GOP size is not a codec-specific option, so it is always set.
|
||||
set_if("g", self.g)
|
||||
|
||||
if self.vcodec == "libsvtav1":
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
svtav1_parts: list[str] = []
|
||||
if self.fast_decode is not None:
|
||||
svtav1_parts.append(f"fast-decode={max(0, min(2, self.fast_decode))}")
|
||||
if encoder_threads is not None:
|
||||
svtav1_parts.append(f"lp={encoder_threads}")
|
||||
if svtav1_parts:
|
||||
opts["svtav1-params"] = ":".join(svtav1_parts)
|
||||
elif self.vcodec in ("h264", "hevc"):
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
if self.fast_decode:
|
||||
opts["tune"] = "fastdecode"
|
||||
set_if("threads", encoder_threads)
|
||||
elif self.vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
|
||||
if self.crf is not None:
|
||||
opts["q:v"] = max(1, min(100, 100 - self.crf * 2))
|
||||
elif self.vcodec in ("h264_nvenc", "hevc_nvenc"):
|
||||
opts["rc"] = "constqp"
|
||||
set_if("qp", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
elif self.vcodec == "h264_vaapi":
|
||||
set_if("qp", self.crf)
|
||||
elif self.vcodec == "h264_qsv":
|
||||
set_if("global_quality", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
elif self.vcodec == "ffv1":
|
||||
# Lossless intra-frame codec. ``crf``/``preset``/``fast_decode``
|
||||
# are not meaningful.
|
||||
set_if("threads", encoder_threads)
|
||||
else:
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
|
||||
# Extra options are merged last but never override structured fields (values are kept as given).
|
||||
for k, v in self.extra_options.items():
|
||||
if k not in opts:
|
||||
set_if(k, v)
|
||||
|
||||
return opts
|
||||
return options
|
||||
|
||||
|
||||
@dataclass
|
||||
class DepthEncoderConfig(VideoEncoderConfig):
|
||||
"""Encoder configuration for depth-map streams.
|
||||
|
||||
Inherits the full :class:`VideoEncoderConfig` surface (codec, GOP, CRF,
|
||||
preset, ``extra_options``…) and adds the four parameters of the depth
|
||||
quantization pipeline (:func:`quantize_depth`). Inheritance — rather
|
||||
than composition — keeps the CLI flat: ``--dataset.depth_encoder_config.<field>``
|
||||
works identically to its RGB counterpart.
|
||||
|
||||
Defaults flip ``vcodec`` to ``"hevc"`` (Main 12 profile) and ``pix_fmt``
|
||||
to ``"yuv420p12le"``, the most widely available 12-bit pixel format.
|
||||
For archive-grade lossless storage use ``vcodec="ffv1"`` together with
|
||||
``pix_fmt="gray12le"`` (and clear ``crf``/``preset`` to ``None`` since
|
||||
``ffv1`` doesn't expose those tuning knobs).
|
||||
|
||||
The :attr:`is_depth_map` marker is class-fixed to ``True`` (``init=False``,
|
||||
so it's hidden from CLI and constructor args) and is what the reader
|
||||
side keys on to tell depth datasets from RGB ones.
|
||||
|
||||
Attributes:
|
||||
depth_min: Minimum depth in physical units (e.g. metres) represented
|
||||
by quantum ``0``.
|
||||
depth_max: Maximum depth represented by quantum :data:`DEPTH_QMAX`.
|
||||
shift: Pre-log offset for numerical stability near zero.
|
||||
use_log: ``True`` for logarithmic quantization (default; matches
|
||||
sensor error profile), ``False`` for linear.
|
||||
"""
|
||||
|
||||
vcodec: str = "hevc"
|
||||
pix_fmt: str = "yuv420p12le"
|
||||
|
||||
depth_min: float = DEFAULT_DEPTH_MIN
|
||||
depth_max: float = DEFAULT_DEPTH_MAX
|
||||
shift: float = DEFAULT_DEPTH_SHIFT
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG
|
||||
|
||||
# Class invariant — kept out of ``__init__`` (and CLI) but persisted
|
||||
# via ``asdict`` into ``info.json`` for the reader to detect depth.
|
||||
is_depth_map: bool = field(default=True, init=False)
|
||||
|
||||
def quantize(self, depth: torch.Tensor | np.ndarray) -> torch.Tensor:
|
||||
"""Apply :func:`quantize_depth` bound to this config's parameters."""
|
||||
return quantize_depth(depth, self.depth_min, self.depth_max, self.shift, self.use_log)
|
||||
|
||||
def dequantize(self, quantized: torch.Tensor | np.ndarray) -> torch.Tensor:
|
||||
"""Apply :func:`dequantize_depth` bound to this config's parameters."""
|
||||
return dequantize_depth(quantized, self.depth_min, self.depth_max, self.shift, self.use_log)
|
||||
def detect_available_hw_encoders() -> list[str]:
|
||||
"""Probe PyAV/FFmpeg for available hardware video encoders."""
|
||||
available = []
|
||||
for codec_name in HW_ENCODERS:
|
||||
try:
|
||||
av.codec.Codec(codec_name, "w")
|
||||
available.append(codec_name)
|
||||
except Exception: # nosec B110
|
||||
logger.debug("HW encoder '%s' not available", codec_name) # nosec B110
|
||||
return available
|
||||
|
||||
|
||||
def depth_encoder_defaults() -> DepthEncoderConfig:
|
||||
"""Return a :class:`DepthEncoderConfig` with depth-camera defaults."""
|
||||
return DepthEncoderConfig()
|
||||
|
||||
def camera_encoder_defaults() -> VideoEncoderConfig:
|
||||
"""Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
|
||||
return VideoEncoderConfig()
|
||||
def resolve_vcodec(vcodec: str) -> str:
|
||||
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1."""
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if vcodec != "auto":
|
||||
logger.info(f"Using video codec: {vcodec}")
|
||||
return vcodec
|
||||
available = detect_available_hw_encoders()
|
||||
for encoder in HW_ENCODERS:
|
||||
if encoder in available:
|
||||
logger.info(f"Auto-selected video codec: {encoder}")
|
||||
return encoder
|
||||
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
return "libsvtav1"
|
||||
|
||||
|
||||
def decode_video_frames(
|
||||
@@ -315,7 +142,7 @@ def decode_video_frames(
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend is None:
|
||||
backend = get_safe_default_video_backend()
|
||||
backend = get_safe_default_codec()
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
elif backend in ["pyav", "video_reader"]:
|
||||
@@ -569,136 +396,22 @@ def decode_video_frames_torchcodec(
|
||||
return closest_frames
|
||||
|
||||
|
||||
def decode_depth_frames(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
*,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
return_quantized: bool = False,
|
||||
log_loaded_timestamps: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Decode depth-map frames at the requested timestamps using PyAV.
|
||||
|
||||
Mirrors the timestamp-tolerance / closest-frame contract of
|
||||
:func:`decode_video_frames` but operates entirely through PyAV (the
|
||||
``torchvision`` and ``torchcodec`` backends don't currently round-trip
|
||||
12-bit pixel formats reliably).
|
||||
|
||||
Each decoded frame is reformatted to ``gray12le`` so the same path
|
||||
handles ``yuv420p12le`` (HEVC default) and ``gray12le`` (ffv1 archive)
|
||||
sources transparently.
|
||||
|
||||
Args:
|
||||
video_path: Path to a depth video produced with a
|
||||
:class:`DepthEncoderConfig`.
|
||||
timestamps: Frame timestamps to retrieve, in seconds.
|
||||
tolerance_s: Maximum allowed deviation between the queried and the
|
||||
actually-decoded timestamps.
|
||||
depth_min, depth_max, shift, use_log: Parameters used at quantization
|
||||
time. Should match :func:`info_to_depth_kwargs` extracted from
|
||||
``info.json`` for the source dataset.
|
||||
return_quantized: If ``True``, skip the dequantization step and
|
||||
return raw 12-bit ``uint16`` quanta.
|
||||
log_loaded_timestamps: Debug logging.
|
||||
|
||||
Returns:
|
||||
``torch.Tensor`` of shape ``(N, H, W)``:
|
||||
|
||||
* ``dtype=torch.float32`` (metric depth, default)
|
||||
* ``dtype=torch.uint16`` when ``return_quantized=True``.
|
||||
|
||||
Raises:
|
||||
FrameTimestampError: If a query timestamp can't be matched within
|
||||
*tolerance_s*, or if no frames are decoded.
|
||||
"""
|
||||
video_path_str = str(video_path)
|
||||
first_ts = min(timestamps)
|
||||
last_ts = max(timestamps)
|
||||
|
||||
loaded_frames: list[np.ndarray] = []
|
||||
loaded_ts: list[float] = []
|
||||
|
||||
av.logging.set_level(av.logging.WARNING)
|
||||
with av.open(video_path_str, "r") as container:
|
||||
try:
|
||||
stream = container.streams.video[0]
|
||||
except IndexError as e:
|
||||
raise FrameTimestampError(f"No video stream in {video_path_str}") from e
|
||||
|
||||
# Seek to the keyframe at-or-before first_ts (PyAV doesn't do
|
||||
# accurate seek, so we still iterate forward to the requested range).
|
||||
seek_pts = int(first_ts / stream.time_base)
|
||||
container.seek(seek_pts, stream=stream, any_frame=False, backward=True)
|
||||
|
||||
for frame in container.decode(stream):
|
||||
if frame.pts is None:
|
||||
continue
|
||||
current_ts = float(frame.pts * stream.time_base)
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"depth frame loaded at timestamp={current_ts:.4f}")
|
||||
loaded_frames.append(
|
||||
decode_depth_frame(
|
||||
frame,
|
||||
depth_min=depth_min,
|
||||
depth_max=depth_max,
|
||||
shift=shift,
|
||||
use_log=use_log,
|
||||
return_quantized=True,
|
||||
)
|
||||
)
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
break
|
||||
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
if not loaded_frames:
|
||||
raise FrameTimestampError(
|
||||
f"No depth frames decoded from {video_path_str} for timestamps {timestamps}"
|
||||
)
|
||||
|
||||
query_ts = torch.tensor(timestamps)
|
||||
loaded_ts_t = torch.tensor(loaded_ts)
|
||||
dist = torch.cdist(query_ts[:, None], loaded_ts_t[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
if not is_within_tol.all():
|
||||
raise FrameTimestampError(
|
||||
f"One or several query timestamps violate the tolerance "
|
||||
f"({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts_t}"
|
||||
f"\nvideo: {video_path_str}"
|
||||
)
|
||||
|
||||
closest = np.stack([loaded_frames[i] for i in argmin_]) # (N, H, W) uint16
|
||||
quantized = torch.from_numpy(closest)
|
||||
|
||||
if return_quantized:
|
||||
return quantized
|
||||
return dequantize_depth(quantized, depth_min, depth_max, shift, use_log)
|
||||
|
||||
|
||||
def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
fps: int,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
*,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
fast_decode: int = 0,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
preset: int | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
vcodec = camera_encoder_config.vcodec
|
||||
pix_fmt = camera_encoder_config.pix_fmt
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
|
||||
video_path = Path(video_path)
|
||||
imgs_dir = Path(imgs_dir)
|
||||
@@ -709,18 +422,42 @@ def encode_video_frames(
|
||||
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Encoders/pixel formats incompatibility check
|
||||
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
|
||||
logger.warning(
|
||||
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
|
||||
)
|
||||
pix_fmt = "yuv420p"
|
||||
|
||||
# Get input frames
|
||||
template = "frame-" + ("[0-9]" * 6) + ".png"
|
||||
input_list = sorted(
|
||||
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
|
||||
)
|
||||
|
||||
# Define video output frame size (assuming all input frames are the same size)
|
||||
if len(input_list) == 0:
|
||||
raise FileNotFoundError(f"No images found in {imgs_dir}.")
|
||||
with Image.open(input_list[0]) as dummy_image:
|
||||
width, height = dummy_image.size
|
||||
|
||||
video_options = camera_encoder_config.get_codec_options(encoder_threads, as_strings=True)
|
||||
# Define video codec options
|
||||
video_options = _get_codec_options(vcodec, g, crf, preset)
|
||||
|
||||
if fast_decode:
|
||||
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
video_options[key] = value
|
||||
|
||||
if encoder_threads is not None:
|
||||
if vcodec == "libsvtav1":
|
||||
lp_param = f"lp={encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(encoder_threads)
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
@@ -757,10 +494,7 @@ def encode_video_frames(
|
||||
|
||||
|
||||
def concatenate_video_files(
|
||||
input_video_paths: list[Path | str],
|
||||
output_video_path: Path,
|
||||
overwrite: bool = True,
|
||||
compatibility_check: bool = False,
|
||||
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
|
||||
):
|
||||
"""
|
||||
Concatenate multiple video files into a single video file using pyav.
|
||||
@@ -773,7 +507,6 @@ def concatenate_video_files(
|
||||
input_video_paths: Ordered list of input video file paths to concatenate.
|
||||
output_video_path: Path to the output video file.
|
||||
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
|
||||
compatibility_check: Whether to check if the input videos are compatible. Default is False.
|
||||
|
||||
Note:
|
||||
- Creates a temporary directory for intermediate files that is cleaned up after use.
|
||||
@@ -792,22 +525,6 @@ def concatenate_video_files(
|
||||
if len(input_video_paths) == 0:
|
||||
raise FileNotFoundError("No input video paths provided.")
|
||||
|
||||
# This check may be skipped at recording time as videos are encoded with the same encoder config.
|
||||
if compatibility_check:
|
||||
reference_video_info = get_video_info(input_video_paths[0])
|
||||
for input_path in input_video_paths[1:]:
|
||||
video_info = get_video_info(input_path)
|
||||
if (
|
||||
video_info["video.height"] != reference_video_info["video.height"]
|
||||
or video_info["video.width"] != reference_video_info["video.width"]
|
||||
or video_info["video.fps"] != reference_video_info["video.fps"]
|
||||
or video_info["video.codec"] != reference_video_info["video.codec"]
|
||||
or video_info["video.pix_fmt"] != reference_video_info["video.pix_fmt"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"Input video {input_path} is not compatible with the reference video {input_video_paths[0]}."
|
||||
)
|
||||
|
||||
# Create a temporary .ffconcat file to list the input video paths
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
|
||||
tmp_concatenate_file.write("ffconcat version 1.0\n")
|
||||
@@ -874,31 +591,33 @@ class _CameraEncoderThread(threading.Thread):
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
codec_options: dict[str, str],
|
||||
g: int | None,
|
||||
crf: int | None,
|
||||
preset: int | None,
|
||||
frame_queue: queue.Queue,
|
||||
result_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
depth_encoder_config: "DepthEncoderConfig | None" = None,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
super().__init__(daemon=True)
|
||||
self.video_path = video_path
|
||||
self.fps = fps
|
||||
self.vcodec = vcodec
|
||||
self.pix_fmt = pix_fmt
|
||||
self.codec_options = codec_options
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self.frame_queue = frame_queue
|
||||
self.result_queue = result_queue
|
||||
self.stop_event = stop_event
|
||||
self.depth_encoder_config = depth_encoder_config
|
||||
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
def run(self) -> None:
|
||||
from .compute_stats import RunningQuantileStats, auto_downsample_height_width
|
||||
|
||||
container = None
|
||||
output_stream = None
|
||||
is_depth = self.depth_encoder_config is not None
|
||||
stats_tracker = RunningQuantileStats() if not is_depth else None
|
||||
stats_tracker = RunningQuantileStats()
|
||||
frame_count = 0
|
||||
|
||||
try:
|
||||
@@ -916,45 +635,51 @@ class _CameraEncoderThread(threading.Thread):
|
||||
# Sentinel: flush and close
|
||||
break
|
||||
|
||||
# Ensure HWC (RGB or depth) uint8 (RGB only) numpy array
|
||||
# Ensure HWC uint8 numpy array
|
||||
if isinstance(frame_data, np.ndarray):
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
|
||||
# CHW -> HWC
|
||||
frame_data = frame_data.transpose(1, 2, 0)
|
||||
if frame_data.dtype != np.uint8 and not is_depth:
|
||||
if frame_data.dtype != np.uint8:
|
||||
frame_data = (frame_data * 255).astype(np.uint8)
|
||||
|
||||
# Open container on first frame (to get width/height)
|
||||
if container is None:
|
||||
height, width = frame_data.shape[:2]
|
||||
video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset)
|
||||
if self.encoder_threads is not None:
|
||||
if self.vcodec == "libsvtav1":
|
||||
lp_param = f"lp={self.encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(self.encoder_threads)
|
||||
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
container = av.open(str(self.video_path), "w")
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=video_options)
|
||||
output_stream.pix_fmt = self.pix_fmt
|
||||
output_stream.width = width
|
||||
output_stream.height = height
|
||||
output_stream.time_base = Fraction(1, self.fps)
|
||||
|
||||
# Encode frame with explicit timestamps
|
||||
if is_depth:
|
||||
video_frame = encode_depth_frame_pyav(frame_data, pix_fmt=self.pix_fmt, depth_min=self.depth_encoder_config.depth_min, depth_max=self.depth_encoder_config.depth_max, shift=self.depth_encoder_config.shift, use_log=self.depth_encoder_config.use_log)
|
||||
else:
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
video_frame.pts = frame_count
|
||||
video_frame.time_base = Fraction(1, self.fps)
|
||||
packet = output_stream.encode(video_frame)
|
||||
if packet:
|
||||
container.mux(packet)
|
||||
|
||||
if not is_depth:
|
||||
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
|
||||
img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW
|
||||
img_downsampled = auto_downsample_height_width(img_chw)
|
||||
# Reshape CHW to (H*W, C) for per-channel stats
|
||||
channels = img_downsampled.shape[0]
|
||||
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
|
||||
stats_tracker.update(img_for_stats)
|
||||
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
|
||||
img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW
|
||||
img_downsampled = auto_downsample_height_width(img_chw)
|
||||
# Reshape CHW to (H*W, C) for per-channel stats
|
||||
channels = img_downsampled.shape[0]
|
||||
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
|
||||
stats_tracker.update(img_for_stats)
|
||||
|
||||
frame_count += 1
|
||||
|
||||
@@ -969,10 +694,8 @@ class _CameraEncoderThread(threading.Thread):
|
||||
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
# Get stats and put on result queue (depth streams skip stats)
|
||||
if is_depth:
|
||||
self.result_queue.put(("ok", None))
|
||||
elif frame_count >= 2:
|
||||
# Get stats and put on result queue
|
||||
if frame_count >= 2:
|
||||
stats = stats_tracker.get_statistics()
|
||||
self.result_queue.put(("ok", stats))
|
||||
else:
|
||||
@@ -1001,40 +724,22 @@ class StreamingVideoEncoder:
|
||||
def __init__(
|
||||
self,
|
||||
fps: int,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
*,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
preset: int | None = None,
|
||||
queue_maxsize: int = 30,
|
||||
depth_encoder_config: "DepthEncoderConfig | None" = None,
|
||||
depth_keys: list[str] | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
fps: Frames per second for the output videos.
|
||||
camera_encoder_config: Video encoder settings applied to all cameras.
|
||||
When ``None``, :class:`VideoEncoderConfig` defaults are used.
|
||||
encoder_threads: Number of encoder threads (global setting).
|
||||
``None`` lets the codec decide.
|
||||
queue_maxsize: Max frames to buffer per camera before
|
||||
back-pressure drops frames.
|
||||
depth_encoder_config: Optional depth encoder configuration applied
|
||||
to all depth video keys listed in ``depth_keys``.
|
||||
depth_keys: Video keys (matching the dataset feature names) that
|
||||
must be encoded as quantized depth maps using
|
||||
``depth_encoder_config``. Required when ``depth_encoder_config``
|
||||
is provided.
|
||||
"""
|
||||
self.fps = fps
|
||||
self._camera_encoder_config = camera_encoder_config or VideoEncoderConfig()
|
||||
self._encoder_threads = encoder_threads
|
||||
self.vcodec = resolve_vcodec(vcodec)
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self.queue_maxsize = queue_maxsize
|
||||
self._depth_encoder_config = depth_encoder_config
|
||||
self._depth_keys: set[str] = set(depth_keys or [])
|
||||
if self._depth_keys and self._depth_encoder_config is None:
|
||||
raise ValueError(
|
||||
"StreamingVideoEncoder received depth_keys without a depth_encoder_config; "
|
||||
"either pass a DepthEncoderConfig or remove depth_keys."
|
||||
)
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
self._frame_queues: dict[str, queue.Queue] = {}
|
||||
self._result_queues: dict[str, queue.Queue] = {}
|
||||
@@ -1065,28 +770,18 @@ class StreamingVideoEncoder:
|
||||
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
|
||||
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
||||
|
||||
is_depth_key = video_key in self._depth_keys
|
||||
encoder_cfg: VideoEncoderConfig
|
||||
depth_cfg = None
|
||||
if is_depth_key:
|
||||
assert self._depth_encoder_config is not None # guaranteed by __init__
|
||||
encoder_cfg = self._depth_encoder_config
|
||||
depth_cfg = self._depth_encoder_config
|
||||
else:
|
||||
encoder_cfg = self._camera_encoder_config
|
||||
|
||||
vcodec = encoder_cfg.vcodec
|
||||
codec_options = encoder_cfg.get_codec_options(self._encoder_threads)
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=self.fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=encoder_cfg.pix_fmt,
|
||||
codec_options=codec_options,
|
||||
vcodec=self.vcodec,
|
||||
pix_fmt=self.pix_fmt,
|
||||
g=self.g,
|
||||
crf=self.crf,
|
||||
preset=self.preset,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
depth_encoder_config=depth_cfg,
|
||||
encoder_threads=self.encoder_threads,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
@@ -1291,18 +986,8 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
return audio_info
|
||||
|
||||
|
||||
def get_video_info(
|
||||
video_path: Path | str,
|
||||
video_encoder_config: "VideoEncoderConfig | None" = None,
|
||||
) -> dict:
|
||||
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
|
||||
|
||||
Args:
|
||||
video_path: Path to the encoded video file to probe.
|
||||
video_encoder_config: If provided, record the exact encoder settings used to encode this
|
||||
video. Stream-derived values take precedence — encoder fields are only written for keys
|
||||
not already populated from the video file itself.
|
||||
"""
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
|
||||
# Getting video stream information
|
||||
@@ -1319,6 +1004,7 @@ def get_video_info(
|
||||
video_info["video.width"] = video_stream.width
|
||||
video_info["video.codec"] = video_stream.codec.canonical_name
|
||||
video_info["video.pix_fmt"] = video_stream.pix_fmt
|
||||
video_info["video.is_depth_map"] = False
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
video_info["video.fps"] = int(video_stream.base_rate)
|
||||
@@ -1332,67 +1018,9 @@ def get_video_info(
|
||||
# Adding audio stream information
|
||||
video_info.update(**get_audio_info(video_path))
|
||||
|
||||
# Add additional encoder configuration if provided (no override of stream-derived values)
|
||||
# Depth related fields flow naturally through this path.
|
||||
if video_encoder_config is not None:
|
||||
for field_name, field_value in asdict(video_encoder_config).items():
|
||||
video_info.setdefault(f"video.{field_name}", field_value)
|
||||
|
||||
# Fallback case where no encoder config is provided or the video is not a depth map.
|
||||
video_info.setdefault("video.is_depth_map", False)
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
# ─── Depth metadata helpers (reader side) ────────────────────────────
|
||||
|
||||
|
||||
_DEPTH_INFO_KEYS: tuple[str, ...] = (
|
||||
"video.depth_min",
|
||||
"video.depth_max",
|
||||
"video.shift",
|
||||
"video.use_log",
|
||||
)
|
||||
|
||||
|
||||
def seed_depth_feature_info(
|
||||
features: dict[str, dict],
|
||||
depth_encoder_config: "DepthEncoderConfig | None",
|
||||
) -> None:
|
||||
"""Pre-populate per-feature ``video.<field>`` entries from *depth_encoder_config*.
|
||||
|
||||
``update_video_info`` only runs after the first episode video is encoded,
|
||||
so without this seeding step ``features[key]["info"]`` carries no
|
||||
quantization range until then. Consumers that read the dataset feature
|
||||
spec mid-recording (e.g. the rerun visualizer pinning the depth colormap
|
||||
to ``video.depth_min`` / ``video.depth_max``) would otherwise see no
|
||||
range during episode 1 and re-normalize per frame.
|
||||
|
||||
Stream-derived values written later by :func:`get_video_info` /
|
||||
``update_video_info`` win over these seeds (the merge is
|
||||
``{**existing, **stream_info}``), so callers can safely re-run this on
|
||||
a partially-populated info dict.
|
||||
|
||||
No-op when ``depth_encoder_config`` is ``None`` or no feature is flagged
|
||||
as a depth map.
|
||||
"""
|
||||
if depth_encoder_config is None:
|
||||
return
|
||||
encoder_fields = {
|
||||
f"video.{name}": value for name, value in asdict(depth_encoder_config).items()
|
||||
}
|
||||
for ft in features.values():
|
||||
if ft.get("dtype") != "video":
|
||||
continue
|
||||
info = ft.get("info") or {}
|
||||
if not info.get("video.is_depth_map", False):
|
||||
continue
|
||||
# Only fill fields not already set, so explicit user-provided info is preserved.
|
||||
for k, v in encoder_fields.items():
|
||||
info.setdefault(k, v)
|
||||
ft["info"] = info
|
||||
|
||||
|
||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
||||
return 1
|
||||
|
||||
@@ -299,6 +299,7 @@ class HILSerlProcessorConfig:
|
||||
inverse_kinematics: InverseKinematicsConfig | None = None
|
||||
reward_classifier: RewardClassifierConfig | None = None
|
||||
max_gripper_pos: float | None = 100.0
|
||||
gripper_speed_factor: float | None = None
|
||||
|
||||
|
||||
@EnvConfig.register_subclass(name="gym_manipulator")
|
||||
|
||||
@@ -17,13 +17,17 @@ from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterp
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig as GaussianActorConfig
|
||||
from .gaussian_actor.reward_model.configuration_classifier import (
|
||||
RewardClassifierConfig as RewardClassifierConfig,
|
||||
)
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
||||
from .sac.configuration_sac import SACConfig as SACConfig
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .utils import make_robot_action, prepare_observation_for_inference
|
||||
@@ -31,20 +35,22 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
|
||||
|
||||
# NOTE: Policy modeling classes (e.g., SACPolicy) are intentionally NOT re-exported here.
|
||||
# NOTE: Policy modeling classes (e.g., GaussianActorPolicy) are intentionally NOT re-exported here.
|
||||
# They have heavy optional dependencies and are loaded lazily via get_policy_class().
|
||||
# Import directly: ``from lerobot.policies.sac.modeling_sac import SACPolicy``
|
||||
# Import directly: ``from lerobot.policies.gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy``
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"ACTConfig",
|
||||
"DiffusionConfig",
|
||||
"GaussianActorConfig",
|
||||
"GrootConfig",
|
||||
"MultiTaskDiTConfig",
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
"PI05Config",
|
||||
"SACConfig",
|
||||
"RewardClassifierConfig",
|
||||
"SARMConfig",
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
|
||||
@@ -46,12 +46,14 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||
|
||||
from .act.configuration_act import ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
|
||||
from .gaussian_actor.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config
|
||||
from .pi05.configuration_pi05 import PI05Config
|
||||
from .pretrained import PreTrainedPolicy
|
||||
from .sac.configuration_sac import SACConfig
|
||||
from .sarm.configuration_sarm import SARMConfig
|
||||
from .smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from .utils import validate_visual_features_consistency
|
||||
@@ -87,7 +89,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "smolvla", "wall_x".
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "gaussian_actor", "reward_classifier", "smolvla", "wall_x".
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
|
||||
@@ -126,14 +128,22 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
return PI05Policy
|
||||
elif name == "sac":
|
||||
from .sac.modeling_sac import SACPolicy
|
||||
elif name == "gaussian_actor":
|
||||
from .gaussian_actor.modeling_gaussian_actor import GaussianActorPolicy
|
||||
|
||||
return SACPolicy
|
||||
return GaussianActorPolicy
|
||||
elif name == "reward_classifier":
|
||||
from .gaussian_actor.reward_model.modeling_classifier import Classifier
|
||||
|
||||
return Classifier
|
||||
elif name == "smolvla":
|
||||
from .smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "sarm":
|
||||
from .sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
return SARMRewardModel
|
||||
elif name == "groot":
|
||||
from .groot.modeling_groot import GrootPolicy
|
||||
|
||||
@@ -162,8 +172,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
|
||||
"smolvla", "wall_x".
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "gaussian_actor",
|
||||
"smolvla", "reward_classifier", "wall_x".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -186,10 +196,12 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi05":
|
||||
return PI05Config(**kwargs)
|
||||
elif policy_type == "sac":
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "gaussian_actor":
|
||||
return GaussianActorConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif policy_type == "groot":
|
||||
return GrootConfig(**kwargs)
|
||||
elif policy_type == "xvla":
|
||||
@@ -358,10 +370,18 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SACConfig):
|
||||
from .sac.processor_sac import make_sac_pre_post_processors
|
||||
elif isinstance(policy_cfg, GaussianActorConfig):
|
||||
from .gaussian_actor.processor_gaussian_actor import make_gaussian_actor_pre_post_processors
|
||||
|
||||
processors = make_sac_pre_post_processors(
|
||||
processors = make_gaussian_actor_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, RewardClassifierConfig):
|
||||
from .gaussian_actor.reward_model.processor_classifier import make_classifier_processor
|
||||
|
||||
processors = make_classifier_processor(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
@@ -374,6 +394,14 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SARMConfig):
|
||||
from .sarm.processor_sarm import make_sarm_pre_post_processors
|
||||
|
||||
processors = make_sarm_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
elif isinstance(policy_cfg, GrootConfig):
|
||||
from .groot.processor_groot import make_groot_pre_post_processors
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_sac import SACConfig
|
||||
from .modeling_sac import SACPolicy
|
||||
from .processor_sac import make_sac_pre_post_processors
|
||||
from .configuration_gaussian_actor import GaussianActorConfig
|
||||
from .modeling_gaussian_actor import GaussianActorPolicy
|
||||
from .processor_gaussian_actor import make_gaussian_actor_pre_post_processors
|
||||
|
||||
__all__ = ["SACConfig", "SACPolicy", "make_sac_pre_post_processors"]
|
||||
__all__ = ["GaussianActorConfig", "GaussianActorPolicy", "make_gaussian_actor_pre_post_processors"]
|
||||
@@ -75,18 +75,19 @@ class PolicyConfig:
|
||||
init_final: float = 0.05
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("sac")
|
||||
@PreTrainedConfig.register_subclass("gaussian_actor")
|
||||
@dataclass
|
||||
class SACConfig(PreTrainedConfig):
|
||||
"""Soft Actor-Critic (SAC) configuration.
|
||||
class GaussianActorConfig(PreTrainedConfig):
|
||||
"""Gaussian actor configuration.
|
||||
|
||||
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
|
||||
reinforcement learning framework. It learns a policy and a Q-function simultaneously
|
||||
using experience collected from the environment.
|
||||
This configures the policy-side (actor + observation encoder) of a Gaussian
|
||||
policy, as used by SAC and related maximum-entropy continuous-control algorithms.
|
||||
By default the actor output is a tanh-squashed diagonal Gaussian
|
||||
(``TanhMultivariateNormalDiag``); the tanh squashing can be disabled via
|
||||
``policy_kwargs.use_tanh_squash``. The critics, temperature, and Bellman-update
|
||||
logic live on the algorithm side (see ``lerobot.rl.algorithms.sac``).
|
||||
|
||||
This configuration class contains all the parameters needed to define a SAC agent,
|
||||
including network architectures, optimization settings, and algorithm-specific
|
||||
hyperparameters.
|
||||
CLI: ``--policy.type=gaussian_actor``.
|
||||
"""
|
||||
|
||||
# Mapping of feature types to normalization modes
|
||||
@@ -122,7 +123,7 @@ class SACConfig(PreTrainedConfig):
|
||||
device: str = "cpu"
|
||||
# Device to store the model on
|
||||
storage_device: str = "cpu"
|
||||
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
|
||||
# Name of the vision encoder model (Set to "lerobot/resnet10" for hil serl resnet10)
|
||||
vision_encoder_name: str | None = None
|
||||
# Whether to freeze the vision encoder during training
|
||||
freeze_vision_encoder: bool = True
|
||||
@@ -135,78 +136,41 @@ class SACConfig(PreTrainedConfig):
|
||||
# Dimension of the image embedding pooling
|
||||
image_embedding_pooling_dim: int = 8
|
||||
|
||||
# Training parameter
|
||||
# Number of steps for online training
|
||||
online_steps: int = 1000000
|
||||
# Capacity of the online replay buffer
|
||||
online_buffer_capacity: int = 100000
|
||||
# Capacity of the offline replay buffer
|
||||
offline_buffer_capacity: int = 100000
|
||||
# Whether to use asynchronous prefetching for the buffers
|
||||
async_prefetch: bool = False
|
||||
# Number of steps before learning starts
|
||||
online_step_before_learning: int = 100
|
||||
# Frequency of policy updates
|
||||
policy_update_freq: int = 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
# Discount factor for the SAC algorithm
|
||||
discount: float = 0.99
|
||||
# Initial temperature value
|
||||
temperature_init: float = 1.0
|
||||
# Number of critics in the ensemble
|
||||
num_critics: int = 2
|
||||
# Number of subsampled critics for training
|
||||
num_subsample_critics: int | None = None
|
||||
# Learning rate for the critic network
|
||||
critic_lr: float = 3e-4
|
||||
# Learning rate for the actor network
|
||||
actor_lr: float = 3e-4
|
||||
# Learning rate for the temperature parameter
|
||||
temperature_lr: float = 3e-4
|
||||
# Weight for the critic target update
|
||||
critic_target_update_weight: float = 0.005
|
||||
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
|
||||
utd_ratio: int = 1
|
||||
# Encoder architecture
|
||||
# Hidden dimension size for the state encoder
|
||||
state_encoder_hidden_dim: int = 256
|
||||
# Dimension of the latent space
|
||||
latent_dim: int = 256
|
||||
# Target entropy for the SAC algorithm
|
||||
target_entropy: float | None = None
|
||||
# Whether to use backup entropy for the SAC algorithm
|
||||
use_backup_entropy: bool = True
|
||||
# Gradient clipping norm for the SAC algorithm
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
# Network configuration
|
||||
# Configuration for the critic network architecture
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for the actor network architecture
|
||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
# Configuration for the policy parameters
|
||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||
# Configuration for the discrete critic network
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for actor-learner architecture
|
||||
# Online training (TODO(Khalil): relocate to TrainRLServerPipelineConfig)
|
||||
online_steps: int = 1000000
|
||||
online_buffer_capacity: int = 100000
|
||||
offline_buffer_capacity: int = 100000
|
||||
async_prefetch: bool = False
|
||||
online_step_before_learning: int = 100
|
||||
|
||||
# Actor-learner transport (TODO(Khalil): relocate to TrainRLServerPipelineConfig).
|
||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
|
||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||
|
||||
# Optimizations
|
||||
use_torch_compile: bool = True
|
||||
# Network architecture
|
||||
# Actor network
|
||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
# Gaussian head parameters
|
||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||
# Discrete critic
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Any validation specific to SAC configuration
|
||||
|
||||
def get_optimizer_preset(self) -> MultiAdamConfig:
|
||||
return MultiAdamConfig(
|
||||
weight_decay=0.0,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": self.actor_lr},
|
||||
"critic": {"lr": self.critic_lr},
|
||||
"temperature": {"lr": self.temperature_lr},
|
||||
"actor": {"lr": 3e-4},
|
||||
"critic": {"lr": 3e-4},
|
||||
"temperature": {"lr": 3e-4},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -15,16 +15,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from typing import Literal
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||
|
||||
@@ -32,20 +28,20 @@ from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from ..utils import get_device_from_parameters
|
||||
from .configuration_sac import SACConfig, is_image_feature
|
||||
from .configuration_gaussian_actor import GaussianActorConfig, is_image_feature
|
||||
|
||||
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
|
||||
|
||||
|
||||
class SACPolicy(
|
||||
class GaussianActorPolicy(
|
||||
PreTrainedPolicy,
|
||||
):
|
||||
config_class = SACConfig
|
||||
name = "sac"
|
||||
config_class = GaussianActorConfig
|
||||
name = "gaussian_actor"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SACConfig | None = None,
|
||||
config: GaussianActorConfig | None = None,
|
||||
):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
@@ -54,9 +50,8 @@ class SACPolicy(
|
||||
# Determine action dimension and initialize all components
|
||||
continuous_action_dim = config.output_features[ACTION].shape[0]
|
||||
self._init_encoders()
|
||||
self._init_critics(continuous_action_dim)
|
||||
self._init_actor(continuous_action_dim)
|
||||
self._init_temperature()
|
||||
self._init_discrete_critic()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
optim_params = {
|
||||
@@ -65,11 +60,7 @@ class SACPolicy(
|
||||
for n, p in self.actor.named_parameters()
|
||||
if not n.startswith("encoder") or not self.shared_encoder
|
||||
],
|
||||
"critic": self.critic_ensemble.parameters(),
|
||||
"temperature": self.log_alpha,
|
||||
}
|
||||
if self.config.num_discrete_actions is not None:
|
||||
optim_params["discrete_critic"] = self.discrete_critic.parameters()
|
||||
return optim_params
|
||||
|
||||
def reset(self):
|
||||
@@ -79,7 +70,9 @@ class SACPolicy(
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
|
||||
raise NotImplementedError(
|
||||
"GaussianActorPolicy does not support action chunking. It returns single actions!"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -92,360 +85,55 @@ class SACPolicy(
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
discrete_action_value = self.discrete_critic(batch, observations_features)
|
||||
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
|
||||
if self.discrete_critic is not None:
|
||||
discrete_action_value = self.discrete_critic(batch, observations_features)
|
||||
discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True)
|
||||
else:
|
||||
discrete_action = torch.ones(
|
||||
(*actions.shape[:-1], 1), device=actions.device, dtype=actions.dtype
|
||||
)
|
||||
actions = torch.cat([actions, discrete_action], dim=-1)
|
||||
|
||||
return actions
|
||||
|
||||
def critic_forward(
|
||||
self,
|
||||
observations: dict[str, Tensor],
|
||||
actions: Tensor,
|
||||
use_target: bool = False,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
def forward(self, batch: dict[str, Tensor | dict[str, Tensor]]) -> dict[str, Tensor]:
|
||||
"""Actor forward pass: sample actions and return log-probabilities.
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
batch: A flat observation dict, or a training dict containing
|
||||
``"state"`` (observations) and optionally ``"observation_feature"``
|
||||
(pre-computed encoder features).
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
Dict with ``"action"``, ``"log_prob"``, and ``"action_mean"`` tensors.
|
||||
"""
|
||||
observations = batch.get("state", batch)
|
||||
observation_features = batch.get("observation_feature") if isinstance(batch, dict) else None
|
||||
actions, log_probs, means = self.actor(observations, observation_features)
|
||||
return {"action": actions, "log_prob": log_probs, "action_mean": means}
|
||||
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = critics(observations, actions, observation_features)
|
||||
return q_values
|
||||
def load_actor_weights(self, state_dicts: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
from lerobot.utils.transition import move_state_dict_to_device
|
||||
|
||||
def discrete_critic_forward(
|
||||
self, observations, use_target=False, observation_features=None
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass through a discrete critic network
|
||||
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
|
||||
self.actor.load_state_dict(actor_state_dict)
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from the discrete critic network
|
||||
"""
|
||||
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
|
||||
q_values = discrete_critic(observations, observation_features)
|
||||
return q_values
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict[str, Tensor | dict[str, Tensor]],
|
||||
model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic",
|
||||
) -> dict[str, Tensor]:
|
||||
"""Compute the loss for the given model
|
||||
|
||||
Args:
|
||||
batch: Dictionary containing:
|
||||
- action: Action tensor
|
||||
- reward: Reward tensor
|
||||
- state: Observations tensor dict
|
||||
- next_state: Next observations tensor dict
|
||||
- done: Done mask tensor
|
||||
- observation_feature: Optional pre-computed observation features
|
||||
- next_observation_feature: Optional pre-computed next observation features
|
||||
model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature")
|
||||
|
||||
Returns:
|
||||
The computed loss tensor
|
||||
"""
|
||||
# Extract common components from batch
|
||||
actions: Tensor = batch[ACTION]
|
||||
observations: dict[str, Tensor] = batch["state"]
|
||||
observation_features: Tensor = batch.get("observation_feature")
|
||||
|
||||
if model == "critic":
|
||||
# Extract critic-specific components
|
||||
rewards: Tensor = batch["reward"]
|
||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
||||
done: Tensor = batch["done"]
|
||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
||||
|
||||
loss_critic = self.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
if "discrete_critic" in state_dicts and self.discrete_critic is not None:
|
||||
discrete_critic_state_dict = move_state_dict_to_device(
|
||||
state_dicts["discrete_critic"], device=device
|
||||
)
|
||||
|
||||
return {"loss_critic": loss_critic}
|
||||
|
||||
if model == "discrete_critic" and self.config.num_discrete_actions is not None:
|
||||
# Extract critic-specific components
|
||||
rewards: Tensor = batch["reward"]
|
||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
||||
done: Tensor = batch["done"]
|
||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
||||
complementary_info = batch.get("complementary_info")
|
||||
loss_discrete_critic = self.compute_loss_discrete_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
return {"loss_discrete_critic": loss_discrete_critic}
|
||||
if model == "actor":
|
||||
return {
|
||||
"loss_actor": self.compute_loss_actor(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
}
|
||||
|
||||
if model == "temperature":
|
||||
return {
|
||||
"loss_temperature": self.compute_loss_temperature(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
}
|
||||
|
||||
raise ValueError(f"Unknown model type: {model}")
|
||||
|
||||
def update_target_networks(self):
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_param, param in zip(
|
||||
self.critic_target.parameters(),
|
||||
self.critic_ensemble.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
if self.config.num_discrete_actions is not None:
|
||||
for target_param, param in zip(
|
||||
self.discrete_critic_target.parameters(),
|
||||
self.discrete_critic.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
"""Return the current temperature value, always in sync with log_alpha."""
|
||||
return self.log_alpha.exp().item()
|
||||
|
||||
def compute_loss_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features: Tensor | None = None,
|
||||
next_observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(
|
||||
observations=next_observations,
|
||||
actions=next_action_preds,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
# TODO: Get indices before forward pass to avoid unnecessary computation
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (self.temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
if self.config.num_discrete_actions is not None:
|
||||
# NOTE: We only want to keep the continuous action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
|
||||
q_preds = self.critic_forward(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
input=q_preds,
|
||||
target=td_target_duplicate,
|
||||
reduction="none",
|
||||
).mean(dim=1)
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def compute_loss_discrete_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features=None,
|
||||
next_observation_features=None,
|
||||
complementary_info=None,
|
||||
):
|
||||
# NOTE: We only want to keep the discrete action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
||||
actions_discrete = torch.round(actions_discrete)
|
||||
actions_discrete = actions_discrete.long()
|
||||
|
||||
discrete_penalties: Tensor | None = None
|
||||
if complementary_info is not None:
|
||||
discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty")
|
||||
|
||||
with torch.no_grad():
|
||||
# For DQN, select actions using online network, evaluate with target network
|
||||
next_discrete_qs = self.discrete_critic_forward(
|
||||
next_observations, use_target=False, observation_features=next_observation_features
|
||||
)
|
||||
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
|
||||
|
||||
# Get target Q-values from target network
|
||||
target_next_discrete_qs = self.discrete_critic_forward(
|
||||
observations=next_observations,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for best actions
|
||||
target_next_discrete_q = torch.gather(
|
||||
target_next_discrete_qs, dim=1, index=best_next_discrete_action
|
||||
).squeeze(-1)
|
||||
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_discrete = rewards
|
||||
if discrete_penalties is not None:
|
||||
rewards_discrete = rewards + discrete_penalties
|
||||
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
|
||||
|
||||
# Get predicted Q-values for current observations
|
||||
predicted_discrete_qs = self.discrete_critic_forward(
|
||||
observations=observations, use_target=False, observation_features=observation_features
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for taken actions
|
||||
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
|
||||
|
||||
# Compute MSE loss between predicted and target Q-values
|
||||
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
|
||||
return discrete_critic_loss
|
||||
|
||||
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.actor(observations, observation_features)
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
|
||||
return temperature_loss
|
||||
|
||||
def compute_loss_actor(
|
||||
self,
|
||||
observations,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||
|
||||
q_preds = self.critic_forward(
|
||||
observations=observations,
|
||||
actions=actions_pi,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
self.discrete_critic.load_state_dict(discrete_critic_state_dict)
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize shared or separate encoders for actor and critic."""
|
||||
self.shared_encoder = self.config.shared_encoder
|
||||
self.encoder_critic = SACObservationEncoder(self.config)
|
||||
self.encoder_critic = GaussianActorObservationEncoder(self.config)
|
||||
self.encoder_actor = (
|
||||
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
|
||||
self.encoder_critic if self.shared_encoder else GaussianActorObservationEncoder(self.config)
|
||||
)
|
||||
|
||||
def _init_critics(self, continuous_action_dim):
|
||||
"""Build critic ensemble, targets, and optional discrete critic."""
|
||||
heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
if self.config.use_torch_compile:
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
self._init_discrete_critics()
|
||||
|
||||
def _init_discrete_critics(self):
|
||||
"""Build discrete discrete critic ensemble and target networks."""
|
||||
self.discrete_critic = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
self.discrete_critic_target = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
|
||||
# TODO: (maractingi, azouitine) Compile the discrete critic
|
||||
self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict())
|
||||
|
||||
def _init_actor(self, continuous_action_dim):
|
||||
"""Initialize policy actor network and default target entropy."""
|
||||
"""Initialize policy actor network."""
|
||||
# NOTE: The actor select only the continuous action part
|
||||
self.actor = Policy(
|
||||
encoder=self.encoder_actor,
|
||||
@@ -455,21 +143,25 @@ class SACPolicy(
|
||||
**asdict(self.config.policy_kwargs),
|
||||
)
|
||||
|
||||
self.target_entropy = self.config.target_entropy
|
||||
if self.target_entropy is None:
|
||||
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||
self.target_entropy = -np.prod(dim) / 2
|
||||
def _init_discrete_critic(self) -> None:
|
||||
"""Initialize discrete critic network."""
|
||||
if self.config.num_discrete_actions is None:
|
||||
self.discrete_critic = None
|
||||
return
|
||||
|
||||
def _init_temperature(self) -> None:
|
||||
"""Set up temperature parameter (log_alpha)."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
# TODO(Khalil): Compile the discrete critic
|
||||
self.discrete_critic = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
class GaussianActorObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: SACConfig) -> None:
|
||||
def __init__(self, config: GaussianActorConfig) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self._init_image_layers()
|
||||
@@ -677,84 +369,6 @@ class MLP(nn.Module):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class CriticHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dims: list[int],
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: float | None = None,
|
||||
init_final: float | None = None,
|
||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.net = MLP(
|
||||
input_dim=input_dim,
|
||||
hidden_dims=hidden_dims,
|
||||
activations=activations,
|
||||
activate_final=activate_final,
|
||||
dropout_rate=dropout_rate,
|
||||
final_activation=final_activation,
|
||||
)
|
||||
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.output_layer(self.net(x))
|
||||
|
||||
|
||||
class CriticEnsemble(nn.Module):
|
||||
"""
|
||||
CriticEnsemble wraps multiple CriticHead modules into an ensemble.
|
||||
|
||||
Args:
|
||||
encoder (SACObservationEncoder): encoder for observations.
|
||||
ensemble (List[CriticHead]): list of critic heads.
|
||||
init_final (float | None): optional initializer scale for final layers.
|
||||
|
||||
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: SACObservationEncoder,
|
||||
ensemble: list[CriticHead],
|
||||
init_final: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.init_final = init_final
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
actions: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
|
||||
# Loop through critics and collect outputs
|
||||
q_values = []
|
||||
for critic in self.critics:
|
||||
q_values.append(critic(inputs))
|
||||
|
||||
# Stack outputs to match expected shape [num_critics, batch_size]
|
||||
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
|
||||
return q_values
|
||||
|
||||
|
||||
class DiscreteCritic(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -800,7 +414,7 @@ class DiscreteCritic(nn.Module):
|
||||
class Policy(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: SACObservationEncoder,
|
||||
encoder: GaussianActorObservationEncoder,
|
||||
network: nn.Module,
|
||||
action_dim: int,
|
||||
std_min: float = -5,
|
||||
@@ -811,7 +425,7 @@ class Policy(nn.Module):
|
||||
encoder_is_shared: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder: SACObservationEncoder = encoder
|
||||
self.encoder: GaussianActorObservationEncoder = encoder
|
||||
self.network = network
|
||||
self.action_dim = action_dim
|
||||
self.std_min = std_min
|
||||
@@ -885,7 +499,7 @@ class Policy(nn.Module):
|
||||
|
||||
|
||||
class DefaultImageEncoder(nn.Module):
|
||||
def __init__(self, config: SACConfig):
|
||||
def __init__(self, config: GaussianActorConfig):
|
||||
super().__init__()
|
||||
image_key = next(key for key in config.input_features if is_image_feature(key))
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
@@ -931,12 +545,12 @@ def freeze_image_encoder(image_encoder: nn.Module):
|
||||
|
||||
|
||||
class PretrainedImageEncoder(nn.Module):
|
||||
def __init__(self, config: SACConfig):
|
||||
def __init__(self, config: GaussianActorConfig):
|
||||
super().__init__()
|
||||
|
||||
self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
|
||||
|
||||
def _load_pretrained_vision_encoder(self, config: SACConfig):
|
||||
def _load_pretrained_vision_encoder(self, config: GaussianActorConfig):
|
||||
"""Set up CNN encoder"""
|
||||
from transformers import AutoModel
|
||||
|
||||
@@ -32,18 +32,18 @@ from lerobot.processor import (
|
||||
)
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
from .configuration_sac import SACConfig
|
||||
from .configuration_gaussian_actor import GaussianActorConfig
|
||||
|
||||
|
||||
def make_sac_pre_post_processors(
|
||||
config: SACConfig,
|
||||
def make_gaussian_actor_pre_post_processors(
|
||||
config: GaussianActorConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Constructs pre-processor and post-processor pipelines for the SAC policy.
|
||||
Constructs pre-processor and post-processor pipelines for the Gaussian actor policy.
|
||||
|
||||
The pre-processing pipeline prepares input data for the model by:
|
||||
1. Renaming features to match pretrained configurations.
|
||||
@@ -56,7 +56,7 @@ def make_sac_pre_post_processors(
|
||||
2. Unnormalizing the output features to their original scale.
|
||||
|
||||
Args:
|
||||
config: The configuration object for the SAC policy.
|
||||
config: The configuration object for the tanh-Gaussian policy.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
|
||||
Returns:
|
||||
@@ -1,3 +1,5 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -13,15 +15,14 @@
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs import NormalizationMode
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.configs import NormalizationMode, PreTrainedConfig
|
||||
from lerobot.optim import AdamWConfig, LRSchedulerConfig, OptimizerConfig
|
||||
from lerobot.utils.constants import OBS_IMAGE
|
||||
|
||||
|
||||
@RewardModelConfig.register_subclass(name="reward_classifier")
|
||||
@PreTrainedConfig.register_subclass(name="reward_classifier")
|
||||
@dataclass
|
||||
class RewardClassifierConfig(RewardModelConfig):
|
||||
class RewardClassifierConfig(PreTrainedConfig):
|
||||
"""Configuration for the Reward Classifier model."""
|
||||
|
||||
name: str = "reward_classifier"
|
||||
@@ -30,7 +31,7 @@ class RewardClassifierConfig(RewardModelConfig):
|
||||
latent_dim: int = 256
|
||||
image_embedding_pooling_dim: int = 8
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading.
|
||||
model_name: str = "lerobot/resnet10"
|
||||
device: str = "cpu"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
num_cameras: int = 2
|
||||
@@ -1,3 +1,5 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -17,10 +19,11 @@ import logging
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.utils.constants import OBS_IMAGE, REWARD
|
||||
|
||||
from ...pretrained import PreTrainedPolicy
|
||||
from .configuration_classifier import RewardClassifierConfig
|
||||
|
||||
|
||||
class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
@@ -96,7 +99,7 @@ class SpatialLearnedEmbeddings(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
class Classifier(PreTrainedRewardModel):
|
||||
class Classifier(PreTrainedPolicy):
|
||||
"""Image classifier built on top of a pre-trained encoder."""
|
||||
|
||||
name = "reward_classifier"
|
||||
@@ -105,6 +108,7 @@ class Classifier(PreTrainedRewardModel):
|
||||
def __init__(
|
||||
self,
|
||||
config: RewardClassifierConfig,
|
||||
**kwargs,
|
||||
):
|
||||
from transformers import AutoModel
|
||||
|
||||
@@ -232,16 +236,6 @@ class Classifier(PreTrainedRewardModel):
|
||||
|
||||
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
|
||||
|
||||
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Returns 1.0 for success, 0.0 for failure based on image observations."""
|
||||
images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
||||
output = self.predict(images)
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
return (output.probabilities > 0.5).float()
|
||||
else:
|
||||
return torch.argmax(output.probabilities, dim=1).float()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
"""Standard forward pass for training compatible with train.py."""
|
||||
# Extract images and labels
|
||||
@@ -285,3 +279,28 @@ class Classifier(PreTrainedRewardModel):
|
||||
return (probs > threshold).float()
|
||||
else:
|
||||
return torch.argmax(self.predict(images).probabilities, dim=1)
|
||||
|
||||
def get_optim_params(self):
|
||||
"""Return optimizer parameters for the policy."""
|
||||
return self.parameters()
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||
The reward classifier is not an actor and does not select actions.
|
||||
"""
|
||||
raise NotImplementedError("Reward classifiers do not select actions")
|
||||
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||
The reward classifier is not an actor and does not produce action chunks.
|
||||
"""
|
||||
raise NotImplementedError("Reward classifiers do not predict action chunks")
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
This method is required by PreTrainedPolicy but not used for reward classifiers.
|
||||
The reward classifier is not an actor and does not select actions.
|
||||
"""
|
||||
pass
|
||||
@@ -1,3 +1,5 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -25,7 +27,8 @@ from lerobot.processor import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
from .configuration_classifier import RewardClassifierConfig
|
||||
|
||||
|
||||
def make_classifier_processor(
|
||||
@@ -49,6 +52,8 @@ def make_classifier_processor(
|
||||
Args:
|
||||
config: The configuration object for the RewardClassifier.
|
||||
dataset_stats: A dictionary of statistics for normalization.
|
||||
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
|
||||
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
|
||||
|
||||
Returns:
|
||||
A tuple containing the configured pre-processor and post-processor pipelines.
|
||||
1
src/lerobot/policies/sarm/README.md
Symbolic link
1
src/lerobot/policies/sarm/README.md
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../docs/source/policy_sarm_README.md
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,6 +14,5 @@
|
||||
|
||||
from .configuration_sarm import SARMConfig
|
||||
from .modeling_sarm import SARMRewardModel
|
||||
from .processor_sarm import make_sarm_pre_post_processors
|
||||
|
||||
__all__ = ["SARMConfig", "SARMRewardModel", "make_sarm_pre_post_processors"]
|
||||
__all__ = ["SARMConfig", "SARMRewardModel"]
|
||||
@@ -25,18 +25,18 @@ need ~num_frames/30 queries instead of one per frame (~30x speedup).
|
||||
|
||||
Usage:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4
|
||||
|
||||
# Faster computation with stride (compute every 5 frames, interpolate the rest)
|
||||
python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--stride 5
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
@@ -58,9 +58,10 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
|
||||
from lerobot.rewards.sarm.processor_sarm import make_sarm_pre_post_processors
|
||||
from lerobot.rewards.sarm.sarm_utils import normalize_stage_tau
|
||||
|
||||
from .modeling_sarm import SARMRewardModel
|
||||
from .processor_sarm import make_sarm_pre_post_processors
|
||||
from .sarm_utils import normalize_stage_tau
|
||||
|
||||
|
||||
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
|
||||
@@ -712,12 +713,12 @@ def main():
|
||||
epilog="""
|
||||
Examples:
|
||||
# Full RA-BC computation with visualizations
|
||||
python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4
|
||||
|
||||
# Visualize predictions only (no RA-BC computation)
|
||||
python src/lerobot/rewards/sarm/compute_rabc_weights.py \\
|
||||
python src/lerobot/policies/sarm/compute_rabc_weights.py \\
|
||||
--dataset-repo-id lerobot/aloha_sim_insertion_human \\
|
||||
--reward-model-path <USER>/sarm_single_uni4 \\
|
||||
--visualize-only \\
|
||||
@@ -1,3 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
@@ -20,15 +22,14 @@ Paper: https://arxiv.org/abs/2509.25358
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
@RewardModelConfig.register_subclass("sarm")
|
||||
@PreTrainedConfig.register_subclass("sarm")
|
||||
@dataclass
|
||||
class SARMConfig(RewardModelConfig):
|
||||
class SARMConfig(PreTrainedConfig):
|
||||
"""Configuration class for SARM (Stage-Aware Reward Modeling).
|
||||
|
||||
Supports three annotation modes:
|
||||
@@ -109,6 +110,7 @@ class SARMConfig(RewardModelConfig):
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.annotation_mode not in ["single_stage", "dense_only", "dual"]:
|
||||
raise ValueError(
|
||||
f"annotation_mode must be 'single_stage', 'dense_only', or 'dual', got {self.annotation_mode}"
|
||||
@@ -1,3 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Qianzhong Chen, Justin Yu, Mac Schwager, Pieter Abbeel, Yide Shentu, Philipp Wu
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
@@ -32,13 +34,14 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.rewards.sarm.sarm_utils import (
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
|
||||
from ..pretrained import PreTrainedPolicy
|
||||
from .configuration_sarm import SARMConfig
|
||||
from .sarm_utils import (
|
||||
normalize_stage_tau,
|
||||
pad_state_to_max_dim,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
|
||||
|
||||
class StageTransformer(nn.Module):
|
||||
@@ -350,7 +353,7 @@ def gen_stage_emb(num_classes: int, targets: torch.Tensor) -> torch.Tensor:
|
||||
return stage_onehot
|
||||
|
||||
|
||||
class SARMRewardModel(PreTrainedRewardModel):
|
||||
class SARMRewardModel(PreTrainedPolicy):
|
||||
"""
|
||||
SARM Reward Model for stage-aware task completion rewards.
|
||||
|
||||
@@ -468,23 +471,6 @@ class SARMRewardModel(PreTrainedRewardModel):
|
||||
self.subtask_model.to(device)
|
||||
return self
|
||||
|
||||
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Compute dense progress reward in [0, 1] from batch.
|
||||
|
||||
Expects batch to contain:
|
||||
- "observation_features" or video embeddings: (B, T, 512)
|
||||
- "language_embedding" or text embeddings: (B, 512)
|
||||
- optionally "observation.state": (B, T, state_dim)
|
||||
"""
|
||||
text_emb = batch.get("language_embedding", batch.get("text_features"))
|
||||
video_emb = batch.get("observation_features", batch.get("video_features"))
|
||||
state = batch.get("observation.state", batch.get("state_features"))
|
||||
|
||||
rewards = self.calculate_rewards(text_emb, video_emb, state)
|
||||
if isinstance(rewards, np.ndarray):
|
||||
rewards = torch.from_numpy(rewards).float()
|
||||
return rewards
|
||||
|
||||
@torch.no_grad()
|
||||
def calculate_rewards(
|
||||
self,
|
||||
@@ -645,9 +631,17 @@ class SARMRewardModel(PreTrainedRewardModel):
|
||||
return self.parameters()
|
||||
|
||||
def reset(self):
|
||||
"""SARM has no episode-level state to reset."""
|
||||
"""Required by PreTrainedPolicy but not used for reward models."""
|
||||
pass
|
||||
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Required by PreTrainedPolicy but not used for reward models."""
|
||||
raise NotImplementedError("SARM model does not predict action chunks")
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Required by PreTrainedPolicy but not used for SARM."""
|
||||
raise NotImplementedError("SARM model does not select actions")
|
||||
|
||||
def _train_step(
|
||||
self,
|
||||
img_emb: torch.Tensor, # (B, N, T, D)
|
||||
@@ -1,3 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -58,15 +60,16 @@ from lerobot.processor import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.rewards.sarm.sarm_utils import (
|
||||
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
from .configuration_sarm import SARMConfig
|
||||
from .sarm_utils import (
|
||||
apply_rewind_augmentation,
|
||||
compute_absolute_indices,
|
||||
find_stage_and_tau,
|
||||
pad_state_to_max_dim,
|
||||
)
|
||||
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
class SARMEncodingProcessorStep(ProcessorStep):
|
||||
@@ -1,3 +1,5 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -61,6 +61,7 @@ from .hil_processor import (
|
||||
RewardClassifierProcessorStep,
|
||||
TimeLimitProcessorStep,
|
||||
)
|
||||
from .leader_follower_processor import LeaderFollowerProcessor
|
||||
from .newline_task_processor import NewLineTaskProcessorStep
|
||||
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats
|
||||
from .observation_processor import VanillaObservationProcessorStep
|
||||
@@ -122,6 +123,7 @@ __all__ = [
|
||||
"ImageCropResizeProcessorStep",
|
||||
"InfoProcessorStep",
|
||||
"InterventionActionProcessorStep",
|
||||
"LeaderFollowerProcessor",
|
||||
"make_default_processors",
|
||||
"make_default_teleop_action_processor",
|
||||
"make_default_robot_action_processor",
|
||||
|
||||
@@ -38,6 +38,7 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep):
|
||||
"""
|
||||
|
||||
use_gripper: bool = True
|
||||
use_rotation: bool = False
|
||||
|
||||
def action(self, action: PolicyAction) -> RobotAction:
|
||||
if not isinstance(action, PolicyAction):
|
||||
@@ -52,7 +53,13 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep):
|
||||
"delta_y": action[1].item(),
|
||||
"delta_z": action[2].item(),
|
||||
}
|
||||
if self.use_gripper:
|
||||
if self.use_rotation:
|
||||
delta_action["delta_wx"] = action[3].item()
|
||||
delta_action["delta_wy"] = action[4].item()
|
||||
delta_action["delta_wz"] = action[5].item()
|
||||
if self.use_gripper:
|
||||
delta_action["gripper"] = action[6].item()
|
||||
elif self.use_gripper:
|
||||
delta_action["gripper"] = action[3].item()
|
||||
return delta_action
|
||||
|
||||
@@ -64,6 +71,12 @@ class MapTensorToDeltaActionDictStep(ActionProcessorStep):
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
|
||||
if self.use_rotation:
|
||||
for axis in ["wx", "wy", "wz"]:
|
||||
features[PipelineFeatureType.ACTION][f"delta_{axis}"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
|
||||
if self.use_gripper:
|
||||
features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
@@ -90,6 +103,8 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
|
||||
# Scale factors for delta movements
|
||||
position_scale: float = 1.0
|
||||
noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise
|
||||
use_rotation: bool = False
|
||||
rotation_scale: float = 1.0
|
||||
|
||||
def action(self, action: RobotAction) -> RobotAction:
|
||||
# NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy
|
||||
@@ -97,23 +112,34 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
|
||||
delta_x = action.pop("delta_x")
|
||||
delta_y = action.pop("delta_y")
|
||||
delta_z = action.pop("delta_z")
|
||||
if self.use_rotation:
|
||||
delta_wx = action.pop("delta_wx")
|
||||
delta_wy = action.pop("delta_wy")
|
||||
delta_wz = action.pop("delta_wz")
|
||||
else:
|
||||
delta_wx = 0.0
|
||||
delta_wy = 0.0
|
||||
delta_wz = 0.0
|
||||
gripper = action.pop("gripper")
|
||||
|
||||
# Determine if the teleoperator is actively providing input
|
||||
# Consider enabled if any significant movement delta is detected
|
||||
position_magnitude = (delta_x**2 + delta_y**2 + delta_z**2) ** 0.5 # Use Euclidean norm for position
|
||||
enabled = position_magnitude > self.noise_threshold # Small threshold to avoid noise
|
||||
rotation_magnitude = (
|
||||
delta_wx**2 + delta_wy**2 + delta_wz**2
|
||||
) ** 0.5 # TODO use proper magnitud for rotation
|
||||
enabled = (
|
||||
position_magnitude > self.noise_threshold or rotation_magnitude > self.noise_threshold
|
||||
) # Small threshold to avoid noise
|
||||
|
||||
# Scale the deltas appropriately
|
||||
scaled_delta_x = delta_x * self.position_scale
|
||||
scaled_delta_y = delta_y * self.position_scale
|
||||
scaled_delta_z = delta_z * self.position_scale
|
||||
|
||||
# For gamepad/keyboard, we don't have rotation input, so set to 0
|
||||
# These could be extended in the future for more sophisticated teleoperators
|
||||
target_wx = 0.0
|
||||
target_wy = 0.0
|
||||
target_wz = 0.0
|
||||
target_wx = delta_wx * self.rotation_scale
|
||||
target_wy = delta_wy * self.rotation_scale
|
||||
target_wz = delta_wz * self.rotation_scale
|
||||
|
||||
# Update action with robot target format
|
||||
action = {
|
||||
@@ -132,9 +158,15 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
for axis in ["x", "y", "z", "gripper"]:
|
||||
for axis in ["x", "y", "z"]:
|
||||
features[PipelineFeatureType.ACTION].pop(f"delta_{axis}", None)
|
||||
|
||||
if self.use_rotation:
|
||||
for axis in ["wx", "wy", "wz"]:
|
||||
features[PipelineFeatureType.ACTION].pop(f"delta_{axis}", None)
|
||||
|
||||
features[PipelineFeatureType.ACTION].pop("delta_gripper", None)
|
||||
|
||||
for feat in ["enabled", "target_x", "target_y", "target_z", "target_wx", "target_wy", "target_wz"]:
|
||||
features[PipelineFeatureType.ACTION][f"{feat}"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
|
||||
@@ -321,6 +321,7 @@ class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
This step normalizes the `transition` object by:
|
||||
1. Copying `teleop_action` from `info` to `complementary_data`.
|
||||
2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key).
|
||||
3. Copying `discrete_penalty` from `info` to `complementary_data`.
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
@@ -330,6 +331,9 @@ class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
if TELEOP_ACTION_KEY in info:
|
||||
complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY]
|
||||
|
||||
if DISCRETE_PENALTY_KEY in info:
|
||||
complementary_data[DISCRETE_PENALTY_KEY] = info[DISCRETE_PENALTY_KEY]
|
||||
|
||||
if "is_intervention" in info:
|
||||
info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"]
|
||||
|
||||
@@ -348,18 +352,24 @@ class GymHILAdapterProcessorStep(ProcessorStep):
|
||||
@ProcessorStepRegistry.register("gripper_penalty_processor")
|
||||
class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
"""
|
||||
Applies a penalty for inefficient gripper usage.
|
||||
Applies a small per-transition cost on the discrete gripper action.
|
||||
|
||||
This step penalizes actions that attempt to close an already closed gripper or
|
||||
open an already open one, based on position thresholds.
|
||||
Fires only when the commanded action would actually transition the gripper
|
||||
from one extreme to the other (close-while-open or open-while-closed).
|
||||
This discourages gripper oscillation while leaving "stay" and saturating-further
|
||||
commands unpenalized.
|
||||
|
||||
Attributes:
|
||||
penalty: The negative reward value to apply.
|
||||
max_gripper_pos: The maximum position value for the gripper, used for normalization.
|
||||
open_threshold: Normalized state below which the gripper is considered "open".
|
||||
closed_threshold: Normalized state above which the gripper is considered "closed".
|
||||
"""
|
||||
|
||||
penalty: float = -0.01
|
||||
penalty: float = -0.02
|
||||
max_gripper_pos: float = 30.0
|
||||
open_threshold: float = 0.1
|
||||
closed_threshold: float = 0.9
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
@@ -391,9 +401,13 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
gripper_state_normalized = current_gripper_pos / self.max_gripper_pos
|
||||
|
||||
# Calculate penalty boolean as in original
|
||||
gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or (
|
||||
gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5
|
||||
)
|
||||
# - currently open AND target is closed -> close transition
|
||||
# - currently closed AND target is open -> open transition
|
||||
is_open = gripper_state_normalized < self.open_threshold
|
||||
is_closed = gripper_state_normalized > self.closed_threshold
|
||||
cmd_close = gripper_action_normalized > self.closed_threshold
|
||||
cmd_open = gripper_action_normalized < self.open_threshold
|
||||
gripper_penalty_bool = (is_open and cmd_close) or (is_closed and cmd_open)
|
||||
|
||||
gripper_penalty = self.penalty * int(gripper_penalty_bool)
|
||||
|
||||
@@ -409,11 +423,14 @@ class GripperPenaltyProcessorStep(ProcessorStep):
|
||||
Returns the configuration of the step for serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the penalty value and max gripper position.
|
||||
A dictionary containing the penalty value, max gripper position,
|
||||
and the open/closed thresholds.
|
||||
"""
|
||||
return {
|
||||
"penalty": self.penalty,
|
||||
"max_gripper_pos": self.max_gripper_pos,
|
||||
"open_threshold": self.open_threshold,
|
||||
"closed_threshold": self.closed_threshold,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
@@ -444,6 +461,7 @@ class InterventionActionProcessorStep(ProcessorStep):
|
||||
|
||||
use_gripper: bool = False
|
||||
terminate_on_success: bool = True
|
||||
use_rotation: bool = False
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""
|
||||
@@ -480,6 +498,14 @@ class InterventionActionProcessorStep(ProcessorStep):
|
||||
teleop_action.get("delta_y", 0.0),
|
||||
teleop_action.get("delta_z", 0.0),
|
||||
]
|
||||
if self.use_rotation:
|
||||
action_list.extend(
|
||||
[
|
||||
teleop_action.get("delta_wx", 0.0),
|
||||
teleop_action.get("delta_wy", 0.0),
|
||||
teleop_action.get("delta_wz", 0.0),
|
||||
]
|
||||
)
|
||||
if self.use_gripper:
|
||||
action_list.append(teleop_action.get(GRIPPER_KEY, 1.0))
|
||||
elif isinstance(teleop_action, np.ndarray):
|
||||
@@ -557,7 +583,7 @@ class RewardClassifierProcessorStep(ProcessorStep):
|
||||
def __post_init__(self):
|
||||
"""Initializes the reward classifier model after the dataclass is created."""
|
||||
if self.pretrained_path is not None:
|
||||
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||
from lerobot.policies.gaussian_actor.reward_model.modeling_classifier import Classifier
|
||||
|
||||
self.reward_classifier = Classifier.from_pretrained(self.pretrained_path)
|
||||
self.reward_classifier.to(self.device)
|
||||
|
||||
243
src/lerobot/processor/leader_follower_processor.py
Normal file
243
src/lerobot/processor/leader_follower_processor.py
Normal file
@@ -0,0 +1,243 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.robots import Robot
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.utils.rotation import Rotation
|
||||
|
||||
from .pipeline import ProcessorStep
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("leader_follower_processor")
|
||||
@dataclass
|
||||
class LeaderFollowerProcessor(ProcessorStep):
|
||||
"""
|
||||
Processor for leader-follower teleoperation mode.
|
||||
|
||||
This processor:
|
||||
1. Sends follower positions to leader arm when not intervening
|
||||
2. Computes EE delta actions from leader when intervening
|
||||
3. Handles teleop events from the leader device
|
||||
"""
|
||||
|
||||
leader_device: Teleoperator
|
||||
motor_names: list[str]
|
||||
robot: Robot
|
||||
kinematics: RobotKinematics
|
||||
end_effector_step_sizes: np.ndarray | None = None
|
||||
use_gripper: bool = True
|
||||
# prev_leader_gripper: float | None = None
|
||||
max_gripper_pos: float = 100.0
|
||||
use_ik_solution: bool = False
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Process transition with leader-follower logic."""
|
||||
# Get current follower position from complementary data
|
||||
# raw_joint_pos = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get("raw_joint_positions")
|
||||
raw_joint_pos = transition.get(TransitionKey.OBSERVATION)
|
||||
if raw_joint_pos is not None:
|
||||
# Send follower position to leader (for follow mode)
|
||||
# follower_action = {
|
||||
# f"{motor}.pos": float(raw_joint_pos[motor])
|
||||
# for motor in self.motor_names
|
||||
# }
|
||||
self.leader_device.send_action(raw_joint_pos)
|
||||
|
||||
# Only compute EE action if intervention is active
|
||||
# (AddTeleopEventsAsInfo already added IS_INTERVENTION to info)
|
||||
info = transition.get(TransitionKey.INFO, {})
|
||||
if info.get(TeleopEvents.IS_INTERVENTION, False):
|
||||
# Get leader joint positions from teleop_action
|
||||
# (AddTeleopActionAsComplimentaryData already got the action)
|
||||
complementary = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
teleop_action = complementary.get("teleop_action", {})
|
||||
|
||||
if isinstance(teleop_action, dict) and raw_joint_pos is not None:
|
||||
leader_pos = np.array([teleop_action[f"{motor}.pos"] for motor in self.motor_names])
|
||||
|
||||
leader_ee = self.kinematics.forward_kinematics(leader_pos)
|
||||
|
||||
if self.use_ik_solution and "IK_solution" in transition.get(TransitionKey.COMPLEMENTARY_DATA):
|
||||
follower_pos = transition.get(TransitionKey.COMPLEMENTARY_DATA)["IK_solution"]
|
||||
else:
|
||||
follower_pos = np.array([raw_joint_pos[f"{motor}.pos"] for motor in self.motor_names])
|
||||
|
||||
follower_ee = self.kinematics.forward_kinematics(follower_pos)
|
||||
|
||||
# follower_gripper_pos = raw_joint_pos["gripper.pos"]
|
||||
follower_gripper_pos = follower_pos[-1] # assuming gripper is the last motor
|
||||
|
||||
leader_ee_pos = leader_ee[:3, 3]
|
||||
leader_ee_rvec = Rotation.from_matrix(leader_ee[:3, :3]).as_rotvec()
|
||||
leader_gripper_pos = np.clip(
|
||||
teleop_action["gripper.pos"], -self.max_gripper_pos, self.max_gripper_pos
|
||||
)
|
||||
|
||||
follower_ee_pos = follower_ee[:3, 3]
|
||||
# follower_ee_rvec = Rotation.from_matrix(follower_ee[:3, :3]).as_rotvec()
|
||||
|
||||
delta_pos = leader_ee_pos - follower_ee_pos
|
||||
|
||||
# For rotation: compute relative rotation from follower to leader
|
||||
# R_leader = R_follower * R_delta => R_delta = R_follower^T * R_leader
|
||||
r_delta = follower_ee[:3, :3].T @ leader_ee[:3, :3]
|
||||
delta_rvec = Rotation.from_matrix(r_delta).as_rotvec()
|
||||
|
||||
delta_gripper = leader_gripper_pos - follower_gripper_pos
|
||||
|
||||
desired = np.eye(4, dtype=float)
|
||||
desired[:3, :3] = follower_ee[:3, :3] @ r_delta
|
||||
desired[:3, 3] = follower_ee[:3, 3] + delta_pos
|
||||
|
||||
pos = desired[:3, 3]
|
||||
tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
|
||||
|
||||
assert np.allclose(pos, leader_ee_pos), "Position delta computation error"
|
||||
assert np.allclose(tw, leader_ee_rvec), "Orientation delta computation error"
|
||||
assert np.isclose(follower_gripper_pos + delta_gripper, leader_gripper_pos), (
|
||||
"Gripper delta computation error"
|
||||
)
|
||||
|
||||
# Normalize the action to the range [-1, 1]
|
||||
delta_pos = delta_pos / np.array(
|
||||
[
|
||||
self.end_effector_step_sizes["x"],
|
||||
self.end_effector_step_sizes["y"],
|
||||
self.end_effector_step_sizes["z"],
|
||||
]
|
||||
)
|
||||
delta_rvec = delta_rvec / np.array(
|
||||
[
|
||||
self.end_effector_step_sizes["wx"],
|
||||
self.end_effector_step_sizes["wy"],
|
||||
self.end_effector_step_sizes["wz"],
|
||||
]
|
||||
)
|
||||
max_normalized_pos = max(
|
||||
abs(delta_pos[0]),
|
||||
abs(delta_pos[1]),
|
||||
abs(delta_pos[2]),
|
||||
)
|
||||
|
||||
normalized_rot = max(abs(delta_rvec[0]), abs(delta_rvec[1]), abs(delta_rvec[2]))
|
||||
|
||||
max_normalized = max(max_normalized_pos, normalized_rot)
|
||||
|
||||
if max_normalized > 1.0:
|
||||
# Scale proportionally
|
||||
delta_pos = delta_pos / max_normalized
|
||||
delta_rvec = delta_rvec / max_normalized
|
||||
|
||||
intervention_action = np.array(
|
||||
[
|
||||
delta_pos[0],
|
||||
delta_pos[1],
|
||||
delta_pos[2],
|
||||
delta_rvec[0],
|
||||
delta_rvec[1],
|
||||
delta_rvec[2],
|
||||
np.clip(delta_gripper, -self.max_gripper_pos, self.max_gripper_pos)
|
||||
/ self.max_gripper_pos,
|
||||
],
|
||||
dtype=float,
|
||||
)
|
||||
|
||||
# # Extract leader positions from teleop action dict
|
||||
# # leader_pos = np.array([teleop_action.get(f"{motor}.pos", 0) for motor in self.motor_names])
|
||||
# # follower_pos = np.array([raw_joint_pos[f"{motor}.pos"] for motor in self.motor_names])
|
||||
|
||||
# teleop_action = self.leader_device.bus.sync_read("Present_Position")
|
||||
# raw_joint_pos = self.robot.bus.sync_read("Present_Position")
|
||||
# leader_pos = np.array([teleop_action.get(f"{motor}", 0) for motor in self.motor_names])
|
||||
# follower_pos = np.array([raw_joint_pos[f"{motor}"] for motor in self.motor_names])
|
||||
|
||||
# # Compute EE positions
|
||||
# leader_ee_fi = self.kinematics.forward_kinematics(leader_pos)
|
||||
# leader_ee_pos = leader_ee_fi[:3, 3]
|
||||
# # leader_ee_rot = Rotation.from_matrix(leader_ee_fi[:3, :3]).as_rotvec()
|
||||
# leader_ee = np.concat([leader_ee_pos, [0,0,0]])
|
||||
|
||||
# if "IK_solution" in transition.get(TransitionKey.COMPLEMENTARY_DATA):
|
||||
# follower_ee = transition.get(TransitionKey.COMPLEMENTARY_DATA)["IK_solution"]
|
||||
# else:
|
||||
# follower_pos = np.array([raw_joint_pos[f"{motor}.pos"] for motor in self.motor_names])
|
||||
# follower_ee_fi = self.kinematics.forward_kinematics(follower_pos)
|
||||
# follower_ee_pos = follower_ee_fi[:3, 3]
|
||||
# # follower_ee_rot = Rotation.from_matrix(follower_ee_fi[:3, :3]).as_rotvec()
|
||||
# follower_ee = np.concat([follower_ee_pos, [0,0,0]])
|
||||
|
||||
# # Compute normalized EE delta
|
||||
# if self.end_effector_step_sizes is not None:
|
||||
# ee_delta = np.clip(
|
||||
# leader_ee - follower_ee,
|
||||
# -self.end_effector_step_sizes,
|
||||
# self.end_effector_step_sizes
|
||||
# )
|
||||
# ee_delta_normalized = ee_delta / self.end_effector_step_sizes
|
||||
# else:
|
||||
# ee_delta_normalized = leader_ee - follower_ee
|
||||
|
||||
# # Handle gripper
|
||||
# if self.use_gripper and len(leader_pos) > 3:
|
||||
# if self.prev_leader_gripper is None:
|
||||
# self.prev_leader_gripper = np.clip(
|
||||
# leader_pos[-1], 0, self.max_gripper_pos
|
||||
# )
|
||||
|
||||
# leader_gripper = leader_pos[-1]
|
||||
# gripper_delta = leader_gripper - self.prev_leader_gripper
|
||||
# normalized_delta = gripper_delta / self.max_gripper_pos
|
||||
|
||||
# # Quantize gripper action
|
||||
# if normalized_delta >= 0.3:
|
||||
# gripper_action = 2
|
||||
# elif normalized_delta <= -0.1:
|
||||
# gripper_action = 0
|
||||
# else:
|
||||
# gripper_action = 1
|
||||
|
||||
# self.prev_leader_gripper = leader_gripper
|
||||
|
||||
# # Create intervention action
|
||||
# intervention_action = np.append(ee_delta_normalized, gripper_action)
|
||||
# else:
|
||||
# intervention_action = ee_delta_normalized
|
||||
|
||||
# # Override teleop_action with computed EE action
|
||||
complementary["teleop_action"] = torch.from_numpy(intervention_action).float()
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary # type: ignore[misc]
|
||||
|
||||
return transition
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset leader-follower state."""
|
||||
# self.prev_leader_gripper = None
|
||||
if hasattr(self.leader_device, "reset"):
|
||||
self.leader_device.reset()
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
@@ -134,6 +134,15 @@ class _NormalizationMixin:
|
||||
if self.dtype is None:
|
||||
self.dtype = torch.float32
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
self._reshape_visual_stats()
|
||||
|
||||
def _reshape_visual_stats(self) -> None:
|
||||
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
|
||||
for key, feature in self.features.items():
|
||||
if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
|
||||
for stat_name, stat_tensor in self._tensor_stats[key].items():
|
||||
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
|
||||
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
|
||||
|
||||
def to(
|
||||
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
|
||||
@@ -152,6 +161,7 @@ class _NormalizationMixin:
|
||||
if dtype is not None:
|
||||
self.dtype = dtype
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
self._reshape_visual_stats()
|
||||
return self
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
@@ -201,6 +211,7 @@ class _NormalizationMixin:
|
||||
# Don't load from state_dict, keep the explicitly provided stats
|
||||
# But ensure _tensor_stats is properly initialized
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
|
||||
self._reshape_visual_stats()
|
||||
return
|
||||
|
||||
# Normal behavior: load stats from state_dict
|
||||
@@ -211,6 +222,7 @@ class _NormalizationMixin:
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
|
||||
dtype=torch.float32, device=self.device
|
||||
)
|
||||
self._reshape_visual_stats()
|
||||
|
||||
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
|
||||
# and other functions that rely on self.stats
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||
from .factory import (
|
||||
get_reward_model_class as get_reward_model_class,
|
||||
make_reward_model as make_reward_model,
|
||||
make_reward_model_config as make_reward_model_config,
|
||||
make_reward_pre_post_processors as make_reward_pre_post_processors,
|
||||
)
|
||||
from .pretrained import PreTrainedRewardModel as PreTrainedRewardModel
|
||||
from .sarm.configuration_sarm import SARMConfig as SARMConfig
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"RewardClassifierConfig",
|
||||
"SARMConfig",
|
||||
# Base class
|
||||
"PreTrainedRewardModel",
|
||||
# Factory functions
|
||||
"get_reward_model_class",
|
||||
"make_reward_model",
|
||||
"make_reward_model_config",
|
||||
"make_reward_pre_post_processors",
|
||||
]
|
||||
@@ -1,238 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from lerobot.rewards.classifier.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.rewards.sarm.configuration_sarm import SARMConfig
|
||||
|
||||
|
||||
def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
"""
|
||||
Retrieves a reward model class by its registered name.
|
||||
|
||||
This function uses dynamic imports to avoid loading all reward model classes into
|
||||
memory at once, improving startup time and reducing dependencies.
|
||||
|
||||
Args:
|
||||
name: The name of the reward model. Supported names are "reward_classifier",
|
||||
"sarm".
|
||||
|
||||
Returns:
|
||||
The reward model class corresponding to the given name.
|
||||
|
||||
Raises:
|
||||
ValueError: If the reward model name is not recognized.
|
||||
"""
|
||||
if name == "reward_classifier":
|
||||
from lerobot.rewards.classifier.modeling_classifier import Classifier
|
||||
|
||||
return Classifier
|
||||
elif name == "sarm":
|
||||
from lerobot.rewards.sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
return SARMRewardModel
|
||||
else:
|
||||
try:
|
||||
return _get_reward_model_cls_from_name(name=name)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Reward model type '{name}' is not available.") from e
|
||||
|
||||
|
||||
def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
"""
|
||||
Instantiates a reward model configuration object based on the reward type.
|
||||
|
||||
This factory function simplifies the creation of reward model configuration objects
|
||||
by mapping a string identifier to the corresponding config class.
|
||||
|
||||
Args:
|
||||
reward_type: The type of the reward model. Supported types include
|
||||
"reward_classifier", "sarm".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
An instance of a `RewardModelConfig` subclass.
|
||||
|
||||
Raises:
|
||||
ValueError: If the `reward_type` is not recognized.
|
||||
"""
|
||||
if reward_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif reward_type == "sarm":
|
||||
return SARMConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = RewardModelConfig.get_choice_class(reward_type)
|
||||
return config_cls(**kwargs)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Reward model type '{reward_type}' is not available.") from e
|
||||
|
||||
|
||||
def make_reward_model(cfg: RewardModelConfig, **kwargs) -> PreTrainedRewardModel:
|
||||
"""
|
||||
Instantiate a reward model from its configuration.
|
||||
|
||||
Args:
|
||||
cfg: The configuration for the reward model to be created. If
|
||||
`cfg.pretrained_path` is set, the model will be loaded with weights
|
||||
from that path.
|
||||
**kwargs: Additional keyword arguments forwarded to the model constructor
|
||||
(e.g., ``dataset_stats``, ``dataset_meta``).
|
||||
|
||||
Returns:
|
||||
An instantiated and device-placed reward model.
|
||||
"""
|
||||
reward_cls = get_reward_model_class(cfg.type)
|
||||
|
||||
kwargs["config"] = cfg
|
||||
|
||||
if cfg.pretrained_path:
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
reward_model = reward_cls.from_pretrained(**kwargs)
|
||||
else:
|
||||
reward_model = reward_cls(**kwargs)
|
||||
|
||||
reward_model.to(cfg.device)
|
||||
assert isinstance(reward_model, torch.nn.Module)
|
||||
|
||||
return reward_model
|
||||
|
||||
|
||||
def make_reward_pre_post_processors(
|
||||
reward_cfg: RewardModelConfig,
|
||||
**kwargs,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Create pre- and post-processor pipelines for a given reward model.
|
||||
|
||||
Each reward model type has a dedicated factory function for its processors.
|
||||
|
||||
Args:
|
||||
reward_cfg: The configuration of the reward model for which to create processors.
|
||||
**kwargs: Additional keyword arguments passed to the processor factory
|
||||
(e.g., ``dataset_stats``, ``dataset_meta``).
|
||||
|
||||
Returns:
|
||||
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
|
||||
|
||||
Raises:
|
||||
ValueError: If a processor factory is not implemented for the given reward
|
||||
model configuration type.
|
||||
"""
|
||||
# Create a new processor based on reward model type
|
||||
if isinstance(reward_cfg, RewardClassifierConfig):
|
||||
from lerobot.rewards.classifier.processor_classifier import make_classifier_processor
|
||||
|
||||
return make_classifier_processor(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(reward_cfg, SARMConfig):
|
||||
from lerobot.rewards.sarm.processor_sarm import make_sarm_pre_post_processors
|
||||
|
||||
return make_sarm_pre_post_processors(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_reward_model_config(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Processor for reward model type '{reward_cfg.type}' is not implemented."
|
||||
) from e
|
||||
return processors
|
||||
|
||||
|
||||
def _get_reward_model_cls_from_name(name: str) -> type[PreTrainedRewardModel]:
|
||||
"""Get reward model class from its registered name using dynamic imports.
|
||||
|
||||
This is used as a helper function to import reward models from 3rd party lerobot
|
||||
plugins.
|
||||
|
||||
Args:
|
||||
name: The name of the reward model.
|
||||
|
||||
Returns:
|
||||
The reward model class corresponding to the given name.
|
||||
"""
|
||||
if name not in RewardModelConfig.get_known_choices():
|
||||
raise ValueError(
|
||||
f"Unknown reward model name '{name}'. "
|
||||
f"Available reward models: {RewardModelConfig.get_known_choices()}"
|
||||
)
|
||||
|
||||
config_cls = RewardModelConfig.get_choice_class(name)
|
||||
config_cls_name = config_cls.__name__
|
||||
|
||||
model_name = config_cls_name.removesuffix("Config")
|
||||
if model_name == config_cls_name:
|
||||
raise ValueError(
|
||||
f"The config class name '{config_cls_name}' does not follow the expected naming convention. "
|
||||
f"Make sure it ends with 'Config'!"
|
||||
)
|
||||
|
||||
cls_name = model_name + "RewardModel"
|
||||
module_path = config_cls.__module__.replace("configuration_", "modeling_")
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
reward_cls = getattr(module, cls_name)
|
||||
return reward_cls
|
||||
|
||||
|
||||
def _make_processors_from_reward_model_config(
|
||||
config: RewardModelConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[Any, Any]:
|
||||
"""Create pre- and post-processors from a reward model configuration using dynamic imports.
|
||||
|
||||
This is used as a helper function to import processor factories from 3rd party
|
||||
lerobot reward model plugins.
|
||||
|
||||
Args:
|
||||
config: The reward model configuration object.
|
||||
dataset_stats: Dataset statistics for normalization.
|
||||
|
||||
Returns:
|
||||
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
|
||||
"""
|
||||
reward_type = config.type
|
||||
function_name = f"make_{reward_type}_pre_post_processors"
|
||||
module_path = config.__class__.__module__.replace("configuration_", "processor_")
|
||||
logging.debug(
|
||||
f"Instantiating reward pre/post processors using function '{function_name}' "
|
||||
f"from module '{module_path}'"
|
||||
)
|
||||
module = importlib.import_module(module_path)
|
||||
function = getattr(module, function_name)
|
||||
return function(config, dataset_stats=dataset_stats)
|
||||
@@ -1,244 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import builtins
|
||||
import logging
|
||||
import os
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
import packaging
|
||||
import safetensors
|
||||
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
T = TypeVar("T", bound="PreTrainedRewardModel")
|
||||
|
||||
|
||||
class PreTrainedRewardModel(nn.Module, HubMixin, abc.ABC):
|
||||
"""Base class for reward models."""
|
||||
|
||||
config_class: None
|
||||
name: None
|
||||
|
||||
def __init__(self, config: RewardModelConfig, *inputs, **kwargs):
|
||||
super().__init__()
|
||||
if not isinstance(config, RewardModelConfig):
|
||||
raise ValueError(
|
||||
f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
|
||||
"`RewardModelConfig`. To create a model from a pretrained model use "
|
||||
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
||||
)
|
||||
self.config = config
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if not getattr(cls, "config_class", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'config_class'")
|
||||
if not getattr(cls, "name", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'name'")
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
self.config._save_pretrained(save_directory)
|
||||
model_to_save = self.module if hasattr(self, "module") else self
|
||||
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: RewardModelConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool = False,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
"""
|
||||
The reward model is set in evaluation mode by default using `reward.eval()` (dropout modules are
|
||||
deactivated). To train it, you should first set it back in training mode with `reward.train()`.
|
||||
"""
|
||||
if config is None:
|
||||
config = RewardModelConfig.from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
model_id = str(pretrained_name_or_path)
|
||||
instance = cls(config, **kwargs)
|
||||
if os.path.isdir(model_id):
|
||||
print("Loading weights from local directory")
|
||||
model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
|
||||
reward = cls._load_as_safetensor(instance, model_file, config.device or "cpu", strict)
|
||||
else:
|
||||
try:
|
||||
model_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=SAFETENSORS_SINGLE_FILE,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
reward = cls._load_as_safetensor(instance, model_file, config.device or "cpu", strict)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
reward.to(config.device)
|
||||
reward.eval()
|
||||
return reward
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||
# Create base kwargs
|
||||
kwargs = {"strict": strict}
|
||||
|
||||
# Add device parameter for newer versions that support it
|
||||
if packaging.version.parse(safetensors.__version__) >= packaging.version.parse("0.4.3"):
|
||||
kwargs["device"] = map_location
|
||||
|
||||
# Load the model with appropriate kwargs
|
||||
missing_keys, unexpected_keys = load_model_as_safetensor(model, model_file, **kwargs)
|
||||
if missing_keys:
|
||||
logging.warning(f"Missing key(s) when loading model: {missing_keys}")
|
||||
if unexpected_keys:
|
||||
logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")
|
||||
|
||||
# For older versions, manually move to device if needed
|
||||
if "device" not in kwargs and map_location != "cpu":
|
||||
logging.warning(
|
||||
"Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
|
||||
" This means that the model is loaded on 'cpu' first and then copied to the device."
|
||||
" This leads to a slower loading time."
|
||||
" Please update safetensors to version 0.4.3 or above for improved performance."
|
||||
)
|
||||
model.to(map_location)
|
||||
return model
|
||||
|
||||
def get_optim_params(self):
|
||||
"""
|
||||
Returns the reward-model-specific parameters dict to be passed on to the optimizer.
|
||||
"""
|
||||
return self.parameters()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset any internal state."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Compute a scalar reward signal for a batch of observations.
|
||||
|
||||
Args:
|
||||
batch: Dictionary containing at minimum observation tensors.
|
||||
May also contain "action", "next_observation.*", etc.
|
||||
|
||||
Returns:
|
||||
Tensor of shape ``(batch_size,)`` with reward values.
|
||||
"""
|
||||
...
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
|
||||
"""Training forward pass — override for trainable reward models."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} is not trainable. Only use compute_reward() for inference."
|
||||
)
|
||||
|
||||
@property
|
||||
def is_trainable(self) -> bool:
|
||||
"""Whether this reward model can be trained via ``lerobot-train``.
|
||||
|
||||
Trainable reward models override :meth:`forward`; zero-shot models
|
||||
inherit the base implementation that raises ``NotImplementedError``.
|
||||
"""
|
||||
return type(self).forward is not PreTrainedRewardModel.forward
|
||||
|
||||
def push_model_to_hub(self, cfg: "TrainPipelineConfig"):
|
||||
api = HfApi()
|
||||
repo_id = api.create_repo(
|
||||
repo_id=self.config.repo_id, private=self.config.private, exist_ok=True
|
||||
).repo_id
|
||||
|
||||
# Push the files to the repo in a single commit
|
||||
with TemporaryDirectory(ignore_cleanup_errors=True) as tmp:
|
||||
saved_path = Path(tmp) / repo_id
|
||||
|
||||
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
|
||||
|
||||
card = self.generate_model_card(
|
||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags
|
||||
)
|
||||
card.save(str(saved_path / "README.md"))
|
||||
|
||||
cfg.save_pretrained(saved_path) # Calls _save_pretrained and stores train config
|
||||
|
||||
commit_info = api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
folder_path=saved_path,
|
||||
commit_message="Upload reward model weights, train config and readme",
|
||||
allow_patterns=["*.safetensors", "*.json", "*.yaml", "*.md"],
|
||||
ignore_patterns=["*.tmp", "*.log"],
|
||||
)
|
||||
|
||||
logging.info(f"Model pushed to {commit_info.repo_url.url}")
|
||||
|
||||
def generate_model_card(
|
||||
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
|
||||
) -> ModelCard:
|
||||
card_data = ModelCardData(
|
||||
license=license or "apache-2.0",
|
||||
library_name="lerobot",
|
||||
pipeline_tag="robotics",
|
||||
tags=list(set(tags or []).union({"robotics", "lerobot", "reward-model", model_type})),
|
||||
model_name=model_type,
|
||||
datasets=dataset_repo_id,
|
||||
)
|
||||
|
||||
template_card = (
|
||||
files("lerobot.templates")
|
||||
.joinpath("lerobot_rewardmodel_modelcard_template.md")
|
||||
.read_text(encoding="utf-8")
|
||||
)
|
||||
card = ModelCard.from_template(card_data, template_str=template_card)
|
||||
card.validate()
|
||||
return card
|
||||
@@ -12,23 +12,33 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Reinforcement learning modules.
|
||||
"""Reinforcement learning modules.
|
||||
|
||||
Requires: ``pip install 'lerobot[hilserl]'``
|
||||
|
||||
Available modules (import directly)::
|
||||
|
||||
from lerobot.rl.actor import ...
|
||||
from lerobot.rl.learner import ...
|
||||
from lerobot.rl.learner_service import ...
|
||||
from lerobot.rl.buffer import ...
|
||||
from lerobot.rl.eval_policy import ...
|
||||
from lerobot.rl.gym_manipulator import ...
|
||||
Distributed actor / learner entry points (``actor``, ``learner``,
|
||||
``learner_service``) require ``pip install 'lerobot[hilserl]'``. Algorithms,
|
||||
buffer, data sources and trainer are gRPC-free and usable standalone.
|
||||
"""
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
from .algorithms.base import RLAlgorithm as RLAlgorithm
|
||||
from .algorithms.configs import RLAlgorithmConfig as RLAlgorithmConfig, TrainingStats as TrainingStats
|
||||
from .algorithms.factory import (
|
||||
make_algorithm as make_algorithm,
|
||||
make_algorithm_config as make_algorithm_config,
|
||||
)
|
||||
from .algorithms.sac.configuration_sac import SACAlgorithmConfig as SACAlgorithmConfig
|
||||
from .buffer import ReplayBuffer as ReplayBuffer
|
||||
from .data_sources import DataMixer as DataMixer, OnlineOfflineMixer as OnlineOfflineMixer
|
||||
from .trainer import RLTrainer as RLTrainer
|
||||
|
||||
require_package("grpcio", extra="hilserl", import_name="grpc")
|
||||
|
||||
__all__: list[str] = []
|
||||
__all__ = [
|
||||
"RLAlgorithm",
|
||||
"RLAlgorithmConfig",
|
||||
"TrainingStats",
|
||||
"make_algorithm",
|
||||
"make_algorithm_config",
|
||||
"SACAlgorithmConfig",
|
||||
"RLTrainer",
|
||||
"ReplayBuffer",
|
||||
"DataMixer",
|
||||
"OnlineOfflineMixer",
|
||||
]
|
||||
|
||||
@@ -51,17 +51,19 @@ import os
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from queue import Empty
|
||||
from typing import Any
|
||||
|
||||
import grpc
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.multiprocessing import Event, Queue
|
||||
from torch.multiprocessing import Queue
|
||||
|
||||
from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.rl.queue import get_last_item_from_queue
|
||||
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
|
||||
from lerobot.robots import so_follower # noqa: F401
|
||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
@@ -74,14 +76,12 @@ from lerobot.transport.utils import (
|
||||
send_bytes_in_chunks,
|
||||
transitions_to_bytes,
|
||||
)
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.transition import (
|
||||
Transition,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
)
|
||||
from lerobot.utils.utils import (
|
||||
@@ -90,12 +90,11 @@ from lerobot.utils.utils import (
|
||||
)
|
||||
|
||||
from .gym_manipulator import (
|
||||
create_transition,
|
||||
make_processors,
|
||||
make_robot_env,
|
||||
reset_and_build_transition,
|
||||
step_env_and_process_transition,
|
||||
)
|
||||
from .queue import get_last_item_from_queue
|
||||
|
||||
# Main entry point
|
||||
|
||||
@@ -212,7 +211,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig):
|
||||
|
||||
def act_with_policy(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
shutdown_event: any, # Event,
|
||||
shutdown_event: Any, # Event
|
||||
parameters_queue: Queue,
|
||||
transitions_queue: Queue,
|
||||
interactions_queue: Queue,
|
||||
@@ -252,22 +251,21 @@ def act_with_policy(
|
||||
logging.info("make_policy")
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy instance
|
||||
### To avoid sending a policy object through the port, we create a policy instance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
policy: SACPolicy = make_policy(
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy = policy.eval()
|
||||
policy = policy.to(device).eval()
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
dataset_stats=cfg.policy.dataset_stats,
|
||||
)
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
transition = reset_and_build_transition(online_env, env_processor, action_processor)
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
@@ -291,8 +289,17 @@ def act_with_policy(
|
||||
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
# Extract observation from transition for policy
|
||||
action = policy.select_action(batch=observation)
|
||||
normalized_observation = preprocessor.process_observation(observation)
|
||||
action = policy.select_action(batch=normalized_observation)
|
||||
# Unnormalize only the continuous part.
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
continuous_action = postprocessor.process_action(action[..., :-1])
|
||||
discrete_action = action[..., -1:].to(
|
||||
device=continuous_action.device, dtype=continuous_action.dtype
|
||||
)
|
||||
action = torch.cat([continuous_action, discrete_action], dim=-1)
|
||||
else:
|
||||
action = postprocessor.process_action(action)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
@@ -326,7 +333,8 @@ def act_with_policy(
|
||||
|
||||
# Check for intervention from transition info
|
||||
intervention_info = new_transition[TransitionKey.INFO]
|
||||
if intervention_info.get(TeleopEvents.IS_INTERVENTION, False):
|
||||
is_intervention = bool(intervention_info.get(TeleopEvents.IS_INTERVENTION, False))
|
||||
if is_intervention:
|
||||
episode_intervention = True
|
||||
episode_intervention_steps += 1
|
||||
|
||||
@@ -334,6 +342,7 @@ def act_with_policy(
|
||||
"discrete_penalty": torch.tensor(
|
||||
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
|
||||
),
|
||||
TeleopEvents.IS_INTERVENTION.value: is_intervention,
|
||||
}
|
||||
# Create transition for learner (convert to old format)
|
||||
list_transition_to_send_to_learner.append(
|
||||
@@ -390,14 +399,7 @@ def act_with_policy(
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
# Reset environment and processors
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
transition = reset_and_build_transition(online_env, env_processor, action_processor)
|
||||
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
@@ -409,7 +411,7 @@ def act_with_policy(
|
||||
|
||||
def establish_learner_connection(
|
||||
stub: services_pb2_grpc.LearnerServiceStub,
|
||||
shutdown_event: Event, # type: ignore
|
||||
shutdown_event: Any, # Event
|
||||
attempts: int = 30,
|
||||
):
|
||||
"""Establish a connection with the learner.
|
||||
@@ -461,7 +463,7 @@ def learner_service_client(
|
||||
def receive_policy(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
parameters_queue: Queue,
|
||||
shutdown_event: Event, # type: ignore
|
||||
shutdown_event: Any, # Event
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
):
|
||||
@@ -513,7 +515,7 @@ def receive_policy(
|
||||
def send_transitions(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
transitions_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
shutdown_event: Any, # Event
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> services_pb2.Empty:
|
||||
@@ -563,7 +565,7 @@ def send_transitions(
|
||||
def send_interactions(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
interactions_queue: Queue,
|
||||
shutdown_event: Event, # type: ignore
|
||||
shutdown_event: Any, # Event
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> services_pb2.Empty:
|
||||
@@ -613,7 +615,11 @@ def send_interactions(
|
||||
logging.info("[ACTOR] Interactions process stopped")
|
||||
|
||||
|
||||
def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: float) -> services_pb2.Empty: # type: ignore
|
||||
def transitions_stream(
|
||||
shutdown_event: Any, # Event
|
||||
transitions_queue: Queue,
|
||||
timeout: float,
|
||||
) -> services_pb2.Empty:
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = transitions_queue.get(block=True, timeout=timeout)
|
||||
@@ -629,9 +635,9 @@ def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout:
|
||||
|
||||
|
||||
def interactions_stream(
|
||||
shutdown_event: Event,
|
||||
shutdown_event: Any, # Event
|
||||
interactions_queue: Queue,
|
||||
timeout: float, # type: ignore
|
||||
timeout: float,
|
||||
) -> services_pb2.Empty:
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
@@ -652,7 +658,7 @@ def interactions_stream(
|
||||
# Policy functions
|
||||
|
||||
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
|
||||
def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue, device):
|
||||
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
|
||||
if bytes_state_dict is not None:
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
@@ -667,18 +673,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
|
||||
# - Send critic's encoder state when shared_encoder=True
|
||||
# - Skip encoder params entirely when freeze_vision_encoder=True
|
||||
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
|
||||
|
||||
# Load actor state dict
|
||||
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
|
||||
policy.actor.load_state_dict(actor_state_dict)
|
||||
|
||||
# Load discrete critic if present
|
||||
if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts:
|
||||
discrete_critic_state_dict = move_state_dict_to_device(
|
||||
state_dicts["discrete_critic"], device=device
|
||||
)
|
||||
policy.discrete_critic.load_state_dict(discrete_critic_state_dict)
|
||||
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
|
||||
policy.load_actor_weights(state_dicts, device=device)
|
||||
|
||||
|
||||
# Utilities functions
|
||||
|
||||
20
src/lerobot/rl/algorithms/__init__.py
Normal file
20
src/lerobot/rl/algorithms/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .sac import SACAlgorithm as SACAlgorithm, SACAlgorithmConfig as SACAlgorithmConfig
|
||||
|
||||
__all__ = [
|
||||
"SACAlgorithm",
|
||||
"SACAlgorithmConfig",
|
||||
]
|
||||
106
src/lerobot/rl/algorithms/base.py
Normal file
106
src/lerobot/rl/algorithms/base.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections.abc import Iterator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rl.data_sources.data_mixer import DataMixer
|
||||
|
||||
BatchType = dict[str, Any]
|
||||
|
||||
|
||||
class RLAlgorithm(abc.ABC):
|
||||
"""Base for all RL algorithms."""
|
||||
|
||||
config_class: type[RLAlgorithmConfig] | None = None
|
||||
name: str | None = None
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if not getattr(cls, "config_class", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'config_class'")
|
||||
if not getattr(cls, "name", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'name'")
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""One complete training step.
|
||||
|
||||
The algorithm calls ``next(batch_iterator)`` as many times as it
|
||||
needs (e.g. ``utd_ratio`` times for SAC) to obtain fresh batches.
|
||||
The iterator is owned by the trainer; the algorithm just consumes
|
||||
from it.
|
||||
"""
|
||||
...
|
||||
|
||||
def configure_data_iterator(
|
||||
self,
|
||||
data_mixer: DataMixer,
|
||||
batch_size: int,
|
||||
*,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
) -> Iterator[BatchType]:
|
||||
"""Create the data iterator this algorithm needs.
|
||||
|
||||
The default implementation uses the standard ``data_mixer.get_iterator()``.
|
||||
Algorithms that need specialised sampling should override this method.
|
||||
"""
|
||||
return data_mixer.get_iterator(
|
||||
batch_size=batch_size,
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
|
||||
def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]:
|
||||
"""Create, store, and return the optimizers needed for training.
|
||||
|
||||
Called on the **learner** side after construction. Subclasses must
|
||||
override this with algorithm-specific optimizer setup.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_optimizers(self) -> dict[str, Optimizer]:
|
||||
"""Return optimizers for checkpointing / external scheduling."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def optimization_step(self) -> int:
|
||||
"""Current learner optimization step.
|
||||
|
||||
Part of the stable contract for checkpoint/resume. Algorithms can
|
||||
either use this default storage or override for custom behavior.
|
||||
"""
|
||||
return getattr(self, "_optimization_step", 0)
|
||||
|
||||
@optimization_step.setter
|
||||
def optimization_step(self, value: int) -> None:
|
||||
self._optimization_step = int(value)
|
||||
|
||||
def get_weights(self) -> dict[str, Any]:
|
||||
"""Policy state-dict to push to actors."""
|
||||
return {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
"""Load policy state-dict received from the learner."""
|
||||
76
src/lerobot/rl/algorithms/configs.py
Normal file
76
src/lerobot/rl/algorithms/configs.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rl.algorithms.base import RLAlgorithm
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingStats:
|
||||
"""Returned by ``algorithm.update()`` for logging and checkpointing."""
|
||||
|
||||
losses: dict[str, float] = field(default_factory=dict)
|
||||
grad_norms: dict[str, float] = field(default_factory=dict)
|
||||
extra: dict[str, float] = field(default_factory=dict)
|
||||
|
||||
def to_log_dict(self) -> dict[str, float]:
|
||||
"""Flatten all stats into a single dict for logging."""
|
||||
|
||||
d: dict[str, float] = {}
|
||||
for name, val in self.losses.items():
|
||||
d[name] = val
|
||||
for name, val in self.grad_norms.items():
|
||||
d[f"{name}_grad_norm"] = val
|
||||
for name, val in self.extra.items():
|
||||
d[name] = val
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLAlgorithmConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
"""Registry for algorithm configs."""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Registered name of this algorithm config (e.g. ``"sac"``)."""
|
||||
choice_name = self.get_choice_name(self.__class__)
|
||||
if not isinstance(choice_name, str):
|
||||
raise TypeError(f"Expected string from get_choice_name, got {type(choice_name)}")
|
||||
return choice_name
|
||||
|
||||
@abc.abstractmethod
|
||||
def build_algorithm(self, policy: torch.nn.Module) -> RLAlgorithm:
|
||||
"""Construct the :class:`RLAlgorithm` for this config.
|
||||
|
||||
Must be overridden by every registered config subclass.
|
||||
"""
|
||||
raise NotImplementedError(f"{type(self).__name__} must implement build_algorithm()")
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def from_policy_config(cls, policy_cfg: Any) -> RLAlgorithmConfig:
|
||||
"""Build an algorithm config from a policy config.
|
||||
|
||||
Must be overridden by every registered config subclass.
|
||||
"""
|
||||
raise NotImplementedError(f"{cls.__name__} must implement from_policy_config()")
|
||||
47
src/lerobot/rl/algorithms/factory.py
Normal file
47
src/lerobot/rl/algorithms/factory.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.rl.algorithms.base import RLAlgorithm
|
||||
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
|
||||
|
||||
|
||||
def make_algorithm_config(algorithm_type: str, **kwargs) -> RLAlgorithmConfig:
|
||||
"""Instantiate an `RLAlgorithmConfig` from its registered type name.
|
||||
|
||||
Args:
|
||||
algorithm_type: Registry key of the algorithm (e.g. ``"sac"``).
|
||||
**kwargs: Keyword arguments forwarded to the config class constructor.
|
||||
|
||||
Returns:
|
||||
An instance of the matching ``RLAlgorithmConfig`` subclass.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``algorithm_type`` is not registered.
|
||||
"""
|
||||
try:
|
||||
cls = RLAlgorithmConfig.get_choice_class(algorithm_type)
|
||||
except KeyError as err:
|
||||
raise ValueError(
|
||||
f"Algorithm type '{algorithm_type}' is not registered. "
|
||||
f"Available: {list(RLAlgorithmConfig.get_known_choices().keys())}"
|
||||
) from err
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
def make_algorithm(cfg: RLAlgorithmConfig, policy: torch.nn.Module) -> RLAlgorithm:
|
||||
return cfg.build_algorithm(policy)
|
||||
18
src/lerobot/rl/algorithms/sac/__init__.py
Normal file
18
src/lerobot/rl/algorithms/sac/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.rl.algorithms.sac.configuration_sac import SACAlgorithmConfig
|
||||
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
|
||||
|
||||
__all__ = ["SACAlgorithm", "SACAlgorithmConfig"]
|
||||
90
src/lerobot/rl/algorithms/sac/configuration_sac.py
Normal file
90
src/lerobot/rl/algorithms/sac/configuration_sac.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.gaussian_actor.configuration_gaussian_actor import (
|
||||
CriticNetworkConfig,
|
||||
GaussianActorConfig,
|
||||
)
|
||||
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
|
||||
|
||||
|
||||
@RLAlgorithmConfig.register_subclass("sac")
|
||||
@dataclass
|
||||
class SACAlgorithmConfig(RLAlgorithmConfig):
|
||||
"""SAC algorithm hyperparameters."""
|
||||
|
||||
# Optimizer learning rates
|
||||
actor_lr: float = 3e-4
|
||||
critic_lr: float = 3e-4
|
||||
temperature_lr: float = 3e-4
|
||||
|
||||
# Bellman update
|
||||
discount: float = 0.99
|
||||
use_backup_entropy: bool = True
|
||||
critic_target_update_weight: float = 0.005
|
||||
|
||||
# Critic ensemble
|
||||
num_critics: int = 2
|
||||
num_subsample_critics: int | None = None
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
|
||||
# Temperature / entropy
|
||||
temperature_init: float = 1.0
|
||||
# Target entropy for automatic temperature tuning. If ``None``, defaults to
|
||||
# ``-|A|/2`` where ``|A|`` is the total action dimension (continuous + 1 if
|
||||
# there is a discrete action head).
|
||||
target_entropy: float | None = None
|
||||
|
||||
# Update loop
|
||||
utd_ratio: int = 1
|
||||
policy_update_freq: int = 1
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
# Optimizations
|
||||
# torch.compile is currently disabled by default
|
||||
use_torch_compile: bool = False
|
||||
|
||||
# Policy config
|
||||
policy_config: GaussianActorConfig | None = None
|
||||
|
||||
@classmethod
|
||||
def from_policy_config(cls, policy_cfg: GaussianActorConfig) -> SACAlgorithmConfig:
|
||||
"""Build an algorithm config with default hyperparameters for a given policy."""
|
||||
return cls(
|
||||
policy_config=policy_cfg,
|
||||
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
|
||||
)
|
||||
|
||||
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
|
||||
if self.policy_config is None:
|
||||
raise ValueError(
|
||||
"SACAlgorithmConfig.policy_config is None. "
|
||||
"It must be populated (typically by TrainRLServerPipelineConfig.validate) "
|
||||
"before calling build_algorithm()."
|
||||
)
|
||||
|
||||
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
|
||||
|
||||
return SACAlgorithm(policy=policy, config=self)
|
||||
595
src/lerobot/rl/algorithms/sac/sac_algorithm.py
Normal file
595
src/lerobot/rl/algorithms/sac/sac_algorithm.py
Normal file
@@ -0,0 +1,595 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import asdict
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.policies.gaussian_actor.modeling_gaussian_actor import (
|
||||
DISCRETE_DIMENSION_INDEX,
|
||||
MLP,
|
||||
DiscreteCritic,
|
||||
GaussianActorObservationEncoder,
|
||||
GaussianActorPolicy,
|
||||
orthogonal_init,
|
||||
)
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.rl.algorithms.base import BatchType, RLAlgorithm
|
||||
from lerobot.rl.algorithms.configs import TrainingStats
|
||||
from lerobot.rl.algorithms.sac.configuration_sac import SACAlgorithmConfig
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.transition import move_state_dict_to_device
|
||||
|
||||
|
||||
class SACAlgorithm(RLAlgorithm):
|
||||
"""Soft Actor-Critic. Owns critics, targets, temperature, and loss computation."""
|
||||
|
||||
config_class = SACAlgorithmConfig
|
||||
name = "sac"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: GaussianActorPolicy,
|
||||
config: SACAlgorithmConfig,
|
||||
):
|
||||
self.config = config
|
||||
self.policy_config = config.policy_config
|
||||
self.policy = policy
|
||||
self.optimizers: dict[str, Optimizer] = {}
|
||||
self._optimization_step: int = 0
|
||||
|
||||
action_dim = self.policy.config.output_features[ACTION].shape[0]
|
||||
self._init_critics(action_dim)
|
||||
self._init_temperature(action_dim)
|
||||
|
||||
self._device = torch.device(self.policy.config.device)
|
||||
self._move_to_device()
|
||||
|
||||
def _init_critics(self, action_dim) -> None:
|
||||
"""Build critic ensemble, targets."""
|
||||
encoder = self.policy.encoder_critic
|
||||
|
||||
heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder.output_dim + action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(encoder=encoder, ensemble=heads)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=encoder.output_dim + action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(encoder=encoder, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
# TODO(Khalil): Investigate and fix torch.compile
|
||||
# NOTE: torch.compile is disabled, policy does not converge when enabled.
|
||||
if self.config.use_torch_compile:
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
self.discrete_critic_target = None
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
self.discrete_critic_target = self._init_discrete_critic_target(encoder)
|
||||
|
||||
def _init_discrete_critic_target(self, encoder: GaussianActorObservationEncoder) -> DiscreteCritic:
|
||||
"""Build target discrete critic (main network is owned by the policy)."""
|
||||
discrete_critic_target = DiscreteCritic(
|
||||
encoder=encoder,
|
||||
input_dim=encoder.output_dim,
|
||||
output_dim=self.policy_config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
# TODO(Khalil): Compile the discrete critic
|
||||
discrete_critic_target.load_state_dict(self.policy.discrete_critic.state_dict())
|
||||
return discrete_critic_target
|
||||
|
||||
def _init_temperature(self, continuous_action_dim: int) -> None:
|
||||
"""Set up temperature parameter (log_alpha) and target entropy."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
|
||||
self.target_entropy = self.config.target_entropy
|
||||
if self.target_entropy is None:
|
||||
total_action_dim = continuous_action_dim + (
|
||||
1 if self.policy_config.num_discrete_actions is not None else 0
|
||||
)
|
||||
self.target_entropy = -total_action_dim / 2
|
||||
|
||||
def _move_to_device(self) -> None:
|
||||
self.policy.to(self._device)
|
||||
self.critic_ensemble.to(self._device)
|
||||
self.critic_target.to(self._device)
|
||||
self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device))
|
||||
if self.discrete_critic_target is not None:
|
||||
self.discrete_critic_target.to(self._device)
|
||||
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
"""Return the current temperature value, always in sync with log_alpha."""
|
||||
return self.log_alpha.exp().item()
|
||||
|
||||
def _critic_forward(
|
||||
self,
|
||||
observations: dict[str, Tensor],
|
||||
actions: Tensor,
|
||||
use_target: bool = False,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
"""
|
||||
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = critics(observations, actions, observation_features)
|
||||
return q_values
|
||||
|
||||
def _discrete_critic_forward(
|
||||
self, observations, use_target=False, observation_features=None
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass through a discrete critic network
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from the discrete critic network
|
||||
"""
|
||||
discrete_critic = self.discrete_critic_target if use_target else self.policy.discrete_critic
|
||||
q_values = discrete_critic(observations, observation_features)
|
||||
return q_values
|
||||
|
||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
clip = self.config.grad_clip_norm
|
||||
|
||||
for _ in range(self.config.utd_ratio - 1):
|
||||
batch = next(batch_iterator)
|
||||
fb = self._prepare_forward_batch(batch, include_complementary_info=True)
|
||||
|
||||
loss_critic = self._compute_loss_critic(fb)
|
||||
self.optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip)
|
||||
self.optimizers["critic"].step()
|
||||
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
loss_dc = self._compute_loss_discrete_critic(fb)
|
||||
self.optimizers["discrete_critic"].zero_grad()
|
||||
loss_dc.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.policy.discrete_critic.parameters(), max_norm=clip)
|
||||
self.optimizers["discrete_critic"].step()
|
||||
|
||||
self._update_target_networks()
|
||||
|
||||
batch = next(batch_iterator)
|
||||
fb = self._prepare_forward_batch(batch, include_complementary_info=False)
|
||||
|
||||
loss_critic = self._compute_loss_critic(fb)
|
||||
self.optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
critic_grad = torch.nn.utils.clip_grad_norm_(self.critic_ensemble.parameters(), max_norm=clip).item()
|
||||
self.optimizers["critic"].step()
|
||||
|
||||
stats = TrainingStats(
|
||||
losses={"loss_critic": loss_critic.item()},
|
||||
grad_norms={"critic": critic_grad},
|
||||
)
|
||||
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
loss_dc = self._compute_loss_discrete_critic(fb)
|
||||
self.optimizers["discrete_critic"].zero_grad()
|
||||
loss_dc.backward()
|
||||
dc_grad = torch.nn.utils.clip_grad_norm_(
|
||||
self.policy.discrete_critic.parameters(), max_norm=clip
|
||||
).item()
|
||||
self.optimizers["discrete_critic"].step()
|
||||
stats.losses["loss_discrete_critic"] = loss_dc.item()
|
||||
stats.grad_norms["discrete_critic"] = dc_grad
|
||||
|
||||
if self._optimization_step % self.config.policy_update_freq == 0:
|
||||
for _ in range(self.config.policy_update_freq):
|
||||
loss_actor = self._compute_loss_actor(fb)
|
||||
self.optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
actor_grad = torch.nn.utils.clip_grad_norm_(
|
||||
self.policy.actor.parameters(), max_norm=clip
|
||||
).item()
|
||||
self.optimizers["actor"].step()
|
||||
|
||||
loss_temp = self._compute_loss_temperature(fb)
|
||||
self.optimizers["temperature"].zero_grad()
|
||||
loss_temp.backward()
|
||||
temp_grad = torch.nn.utils.clip_grad_norm_([self.log_alpha], max_norm=clip).item()
|
||||
self.optimizers["temperature"].step()
|
||||
|
||||
stats.losses["loss_actor"] = loss_actor.item()
|
||||
stats.losses["loss_temperature"] = loss_temp.item()
|
||||
stats.grad_norms["actor"] = actor_grad
|
||||
stats.grad_norms["temperature"] = temp_grad
|
||||
stats.extra["temperature"] = self.temperature
|
||||
|
||||
self._update_target_networks()
|
||||
self._optimization_step += 1
|
||||
return stats
|
||||
|
||||
def _compute_loss_critic(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
observation_features = batch.get("observation_feature")
|
||||
next_observation_features = batch.get("next_observation_feature")
|
||||
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.policy.actor(
|
||||
next_observations, next_observation_features
|
||||
)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self._critic_forward(
|
||||
observations=next_observations,
|
||||
actions=next_action_preds,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
# TODO: Get indices before forward pass to avoid unnecessary computation
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (self.temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
# NOTE: We only want to keep the continuous action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
|
||||
q_preds = self._critic_forward(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
input=q_preds,
|
||||
target=td_target_duplicate,
|
||||
reduction="none",
|
||||
).mean(dim=1)
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def _compute_loss_discrete_critic(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
observation_features = batch.get("observation_feature")
|
||||
next_observation_features = batch.get("next_observation_feature")
|
||||
complementary_info = batch.get("complementary_info")
|
||||
|
||||
# NOTE: We only want to keep the discrete action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
||||
actions_discrete = torch.round(actions_discrete)
|
||||
actions_discrete = actions_discrete.long()
|
||||
|
||||
discrete_penalties: Tensor | None = None
|
||||
if complementary_info is not None:
|
||||
discrete_penalties = complementary_info.get("discrete_penalty")
|
||||
|
||||
with torch.no_grad():
|
||||
# For DQN, select actions using online network, evaluate with target network
|
||||
next_discrete_qs = self._discrete_critic_forward(
|
||||
next_observations, use_target=False, observation_features=next_observation_features
|
||||
)
|
||||
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
|
||||
|
||||
# Get target Q-values from target network
|
||||
target_next_discrete_qs = self._discrete_critic_forward(
|
||||
observations=next_observations,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for best actions
|
||||
target_next_discrete_q = torch.gather(
|
||||
target_next_discrete_qs, dim=1, index=best_next_discrete_action
|
||||
).squeeze(-1)
|
||||
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_discrete = rewards
|
||||
if discrete_penalties is not None:
|
||||
rewards_discrete = rewards + discrete_penalties
|
||||
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
|
||||
|
||||
# Get predicted Q-values for current observations
|
||||
predicted_discrete_qs = self._discrete_critic_forward(
|
||||
observations=observations, use_target=False, observation_features=observation_features
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for taken actions
|
||||
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
|
||||
|
||||
# Compute MSE loss between predicted and target Q-values
|
||||
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
|
||||
return discrete_critic_loss
|
||||
|
||||
def _compute_loss_actor(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
observation_features = batch.get("observation_feature")
|
||||
|
||||
actions_pi, log_probs, _ = self.policy.actor(observations, observation_features)
|
||||
|
||||
q_preds = self._critic_forward(
|
||||
observations=observations,
|
||||
actions=actions_pi,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
def _compute_loss_temperature(self, batch: dict[str, Any]) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
observations = batch["state"]
|
||||
observation_features = batch.get("observation_feature")
|
||||
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.policy.actor(observations, observation_features)
|
||||
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
|
||||
return temperature_loss
|
||||
|
||||
def _update_target_networks(self) -> None:
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_p, p in zip(
|
||||
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=True
|
||||
):
|
||||
target_p.data.copy_(
|
||||
p.data * self.config.critic_target_update_weight
|
||||
+ target_p.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
for target_p, p in zip(
|
||||
self.discrete_critic_target.parameters(),
|
||||
self.policy.discrete_critic.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_p.data.copy_(
|
||||
p.data * self.config.critic_target_update_weight
|
||||
+ target_p.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
def _prepare_forward_batch(
|
||||
self, batch: BatchType, *, include_complementary_info: bool = True
|
||||
) -> dict[str, Any]:
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
observation_features, next_observation_features = self.get_observation_features(
|
||||
observations, next_observations
|
||||
)
|
||||
forward_batch: dict[str, Any] = {
|
||||
ACTION: batch[ACTION],
|
||||
"reward": batch["reward"],
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
"done": batch["done"],
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
}
|
||||
if include_complementary_info and "complementary_info" in batch:
|
||||
forward_batch["complementary_info"] = batch["complementary_info"]
|
||||
return forward_batch
|
||||
|
||||
def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]:
|
||||
"""
|
||||
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
|
||||
|
||||
This function sets up Adam optimizers for:
|
||||
- The **actor network**, ensuring that only relevant parameters are optimized.
|
||||
- The **critic ensemble**, which evaluates the value function.
|
||||
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
|
||||
|
||||
It also initializes a learning rate scheduler, though currently, it is set to `None`.
|
||||
|
||||
NOTE:
|
||||
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
|
||||
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object containing hyperparameters.
|
||||
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping component names ("actor", "critic", "temperature")
|
||||
to their respective Adam optimizers.
|
||||
"""
|
||||
actor_params = self.policy.get_optim_params()["actor"]
|
||||
self.optimizers = {
|
||||
"actor": torch.optim.Adam(actor_params, lr=self.config.actor_lr),
|
||||
"critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr),
|
||||
"temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr),
|
||||
}
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
self.optimizers["discrete_critic"] = torch.optim.Adam(
|
||||
self.policy.discrete_critic.parameters(), lr=self.config.critic_lr
|
||||
)
|
||||
return self.optimizers
|
||||
|
||||
def get_optimizers(self) -> dict[str, Optimizer]:
|
||||
return self.optimizers
|
||||
|
||||
def get_weights(self) -> dict[str, Any]:
|
||||
"""Send actor + discrete-critic state dicts."""
|
||||
state_dicts: dict[str, Any] = {
|
||||
"policy": move_state_dict_to_device(self.policy.actor.state_dict(), device="cpu"),
|
||||
}
|
||||
if self.policy_config.num_discrete_actions is not None:
|
||||
state_dicts["discrete_critic"] = move_state_dict_to_device(
|
||||
self.policy.discrete_critic.state_dict(), device="cpu"
|
||||
)
|
||||
return state_dicts
|
||||
|
||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
"""Load actor + discrete-critic weights into the policy."""
|
||||
self.policy.load_actor_weights(weights, device=device)
|
||||
|
||||
def get_observation_features(
|
||||
self, observations: Tensor, next_observations: Tensor
|
||||
) -> tuple[Tensor | None, Tensor | None]:
|
||||
"""
|
||||
Get observation features from the policy encoder. It act as cache for the observation features.
|
||||
when the encoder is frozen, the observation features are not updated.
|
||||
We can save compute by caching the observation features.
|
||||
|
||||
Args:
|
||||
policy: The policy model
|
||||
observations: The current observations
|
||||
next_observations: The next observations
|
||||
|
||||
Returns:
|
||||
tuple: observation_features, next_observation_features
|
||||
"""
|
||||
|
||||
if self.policy.config.vision_encoder_name is None or not self.policy.config.freeze_vision_encoder:
|
||||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = self.policy.actor.encoder.get_cached_image_features(observations)
|
||||
next_observation_features = self.policy.actor.encoder.get_cached_image_features(next_observations)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
|
||||
class CriticHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dims: list[int],
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: float | None = None,
|
||||
init_final: float | None = None,
|
||||
final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.net = MLP(
|
||||
input_dim=input_dim,
|
||||
hidden_dims=hidden_dims,
|
||||
activations=activations,
|
||||
activate_final=activate_final,
|
||||
dropout_rate=dropout_rate,
|
||||
final_activation=final_activation,
|
||||
)
|
||||
self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.output_layer(self.net(x))
|
||||
|
||||
|
||||
class CriticEnsemble(nn.Module):
|
||||
"""
|
||||
CriticEnsemble wraps multiple CriticHead modules into an ensemble.
|
||||
|
||||
Args:
|
||||
encoder (GaussianActorObservationEncoder): encoder for observations.
|
||||
ensemble (List[CriticHead]): list of critic heads.
|
||||
init_final (float | None): optional initializer scale for final layers.
|
||||
|
||||
Forward returns a tensor of shape (num_critics, batch_size) containing Q-values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: GaussianActorObservationEncoder,
|
||||
ensemble: list[CriticHead],
|
||||
init_final: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.init_final = init_final
|
||||
self.critics = nn.ModuleList(ensemble)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
actions: torch.Tensor,
|
||||
observation_features: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
# Move each tensor in observations to device
|
||||
observations = {k: v.to(device) for k, v in observations.items()}
|
||||
|
||||
obs_enc = self.encoder(observations, cache=observation_features)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
|
||||
# Loop through critics and collect outputs
|
||||
q_values = []
|
||||
for critic in self.critics:
|
||||
q_values.append(critic(inputs))
|
||||
|
||||
# Stack outputs to match expected shape [num_critics, batch_size]
|
||||
q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0)
|
||||
return q_values
|
||||
@@ -97,8 +97,8 @@ class ReplayBuffer:
|
||||
Args:
|
||||
capacity (int): Maximum number of transitions to store in the buffer.
|
||||
device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
|
||||
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
|
||||
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
|
||||
state_keys (list[str]): The list of keys that appear in `state` and `next_state`.
|
||||
image_augmentation_function (Callable | None): A function that takes a batch of images
|
||||
and returns a batch of augmented images. If None, a default augmentation function is used.
|
||||
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
|
||||
storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored.
|
||||
@@ -634,7 +634,7 @@ class ReplayBuffer:
|
||||
If None, you must handle or define default keys.
|
||||
|
||||
Returns:
|
||||
transitions (List[Transition]):
|
||||
transitions (list[Transition]):
|
||||
A list of Transition dictionaries with the same length as `dataset`.
|
||||
"""
|
||||
if state_keys is None:
|
||||
|
||||
@@ -176,11 +176,11 @@ def convert_lerobot_dataset_to_cropped_lerobot_dataset(
|
||||
|
||||
Args:
|
||||
original_dataset (LeRobotDataset): The source dataset.
|
||||
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
|
||||
crop_params_dict (dict[str, Tuple[int, int, int, int]]):
|
||||
A dictionary mapping observation keys to crop parameters (top, left, height, width).
|
||||
new_repo_id (str): Repository id for the new dataset.
|
||||
new_dataset_root (str): The root directory where the new dataset will be written.
|
||||
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
|
||||
resize_size (tuple[int, int], optional): The target size (height, width) after cropping.
|
||||
Defaults to (128, 128).
|
||||
|
||||
Returns:
|
||||
@@ -193,15 +193,15 @@ def convert_lerobot_dataset_to_cropped_lerobot_dataset(
|
||||
fps=int(original_dataset.fps),
|
||||
root=new_dataset_root,
|
||||
robot_type=original_dataset.meta.robot_type,
|
||||
features=original_dataset.meta.info.features,
|
||||
features=original_dataset.meta.info["features"],
|
||||
use_videos=len(original_dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
# Update the metadata for every image key that will be cropped:
|
||||
# (Here we simply set the shape to be the final resize_size.)
|
||||
for key in crop_params_dict:
|
||||
if key in new_dataset.meta.info.features:
|
||||
new_dataset.meta.info.features[key]["shape"] = (3, *resize_size)
|
||||
if key in new_dataset.meta.info["features"]:
|
||||
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
|
||||
|
||||
# TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset
|
||||
prev_episode_index = 0
|
||||
|
||||
17
src/lerobot/rl/data_sources/__init__.py
Normal file
17
src/lerobot/rl/data_sources/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .data_mixer import BatchType, DataMixer, OnlineOfflineMixer
|
||||
|
||||
__all__ = ["BatchType", "DataMixer", "OnlineOfflineMixer"]
|
||||
96
src/lerobot/rl/data_sources/data_mixer.py
Normal file
96
src/lerobot/rl/data_sources/data_mixer.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
|
||||
from lerobot.rl.algorithms.base import BatchType
|
||||
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
|
||||
|
||||
|
||||
class DataMixer(abc.ABC):
|
||||
"""Abstract interface for all data mixing strategies."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def sample(self, batch_size: int) -> BatchType:
|
||||
"""Draw one batch of ``batch_size`` transitions."""
|
||||
...
|
||||
|
||||
def get_iterator(
|
||||
self,
|
||||
batch_size: int,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
"""Infinite iterator that yields batches."""
|
||||
while True:
|
||||
yield self.sample(batch_size)
|
||||
|
||||
|
||||
class OnlineOfflineMixer(DataMixer):
|
||||
"""Mixes transitions from an online and an offline replay buffer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
online_buffer: ReplayBuffer,
|
||||
offline_buffer: ReplayBuffer | None = None,
|
||||
online_ratio: float = 1.0,
|
||||
):
|
||||
if not 0.0 <= online_ratio <= 1.0:
|
||||
raise ValueError(f"online_ratio must be in [0, 1], got {online_ratio}")
|
||||
self.online_buffer = online_buffer
|
||||
self.offline_buffer = offline_buffer
|
||||
self.online_ratio = online_ratio
|
||||
|
||||
def sample(self, batch_size: int) -> BatchType:
|
||||
if self.offline_buffer is None:
|
||||
return self.online_buffer.sample(batch_size)
|
||||
|
||||
n_online = max(1, int(batch_size * self.online_ratio))
|
||||
n_offline = batch_size - n_online
|
||||
|
||||
online_batch = self.online_buffer.sample(n_online)
|
||||
offline_batch = self.offline_buffer.sample(n_offline)
|
||||
return concatenate_batch_transitions(online_batch, offline_batch)
|
||||
|
||||
def get_iterator(
|
||||
self,
|
||||
batch_size: int,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
"""Yield batches by composing buffer async iterators."""
|
||||
|
||||
n_online = max(1, int(batch_size * self.online_ratio))
|
||||
|
||||
online_iter = self.online_buffer.get_iterator(
|
||||
batch_size=n_online,
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
|
||||
if self.offline_buffer is None:
|
||||
yield from online_iter
|
||||
return
|
||||
|
||||
n_offline = batch_size - n_online
|
||||
offline_iter = self.offline_buffer.get_iterator(
|
||||
batch_size=n_offline,
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
|
||||
while True:
|
||||
yield concatenate_batch_transitions(next(online_iter), next(offline_iter))
|
||||
@@ -17,9 +17,9 @@ import logging
|
||||
|
||||
from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.policies import make_policy
|
||||
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
RobotConfig,
|
||||
make_robot_from_config,
|
||||
|
||||
@@ -383,10 +383,21 @@ def make_processors(
|
||||
GymHILAdapterProcessorStep(),
|
||||
Numpy2TorchActionProcessorStep(),
|
||||
VanillaObservationProcessorStep(),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=device),
|
||||
]
|
||||
|
||||
# Add time limit processor if reset config exists
|
||||
if cfg.processor.reset is not None:
|
||||
env_pipeline_steps.append(
|
||||
TimeLimitProcessorStep(max_episode_steps=int(cfg.processor.reset.control_time_s * cfg.fps))
|
||||
)
|
||||
|
||||
env_pipeline_steps.extend(
|
||||
[
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=device),
|
||||
]
|
||||
)
|
||||
|
||||
return DataProcessorPipeline(
|
||||
steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
|
||||
), DataProcessorPipeline(
|
||||
@@ -551,8 +562,19 @@ def step_env_and_process_transition(
|
||||
terminated = terminated or processed_action_transition[TransitionKey.DONE]
|
||||
truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
|
||||
complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
|
||||
|
||||
if hasattr(env, "get_raw_joint_positions"):
|
||||
raw_joint_positions = env.get_raw_joint_positions()
|
||||
if raw_joint_positions is not None:
|
||||
complementary_data["raw_joint_positions"] = raw_joint_positions
|
||||
|
||||
# Merge env and action-processor info: env wins for str keys, action-processor
|
||||
# wins for `TeleopEvents` enum keys
|
||||
action_info = processed_action_transition[TransitionKey.INFO]
|
||||
new_info = info.copy()
|
||||
new_info.update(processed_action_transition[TransitionKey.INFO])
|
||||
for key, value in action_info.items():
|
||||
if isinstance(key, TeleopEvents):
|
||||
new_info[key] = value
|
||||
|
||||
new_transition = create_transition(
|
||||
observation=obs,
|
||||
@@ -568,6 +590,24 @@ def step_env_and_process_transition(
|
||||
return new_transition
|
||||
|
||||
|
||||
def reset_and_build_transition(
|
||||
env: gym.Env,
|
||||
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
|
||||
action_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
|
||||
) -> EnvTransition:
|
||||
"""Reset env + processors and return the first env-processed transition."""
|
||||
obs, info = env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
complementary_data: dict[str, Any] = {}
|
||||
if hasattr(env, "get_raw_joint_positions"):
|
||||
raw_joint_positions = env.get_raw_joint_positions()
|
||||
if raw_joint_positions is not None:
|
||||
complementary_data["raw_joint_positions"] = raw_joint_positions
|
||||
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
||||
return env_processor(data=transition)
|
||||
|
||||
|
||||
def control_loop(
|
||||
env: gym.Env,
|
||||
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
|
||||
@@ -593,17 +633,7 @@ def control_loop(
|
||||
print("- When not intervening, robot will stay still")
|
||||
print("- Press Ctrl+C to exit")
|
||||
|
||||
# Reset environment and processors
|
||||
obs, info = env.reset()
|
||||
complementary_data = (
|
||||
{"raw_joint_positions": info.pop("raw_joint_positions")} if "raw_joint_positions" in info else {}
|
||||
)
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
||||
transition = env_processor(data=transition)
|
||||
transition = reset_and_build_transition(env, env_processor, action_processor)
|
||||
|
||||
# Determine if gripper is used
|
||||
use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True
|
||||
@@ -659,79 +689,81 @@ def control_loop(
|
||||
episode_step = 0
|
||||
episode_start_time = time.perf_counter()
|
||||
|
||||
while episode_idx < cfg.dataset.num_episodes_to_record:
|
||||
step_start_time = time.perf_counter()
|
||||
try:
|
||||
while episode_idx < cfg.dataset.num_episodes_to_record:
|
||||
step_start_time = time.perf_counter()
|
||||
|
||||
# Create a neutral action (no movement)
|
||||
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
|
||||
if use_gripper:
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay
|
||||
|
||||
# Use the new step function
|
||||
transition = step_env_and_process_transition(
|
||||
env=env,
|
||||
transition=transition,
|
||||
action=neutral_action,
|
||||
env_processor=env_processor,
|
||||
action_processor=action_processor,
|
||||
)
|
||||
terminated = transition.get(TransitionKey.DONE, False)
|
||||
truncated = transition.get(TransitionKey.TRUNCATED, False)
|
||||
|
||||
if cfg.mode == "record":
|
||||
observations = {
|
||||
k: v.squeeze(0).cpu()
|
||||
for k, v in transition[TransitionKey.OBSERVATION].items()
|
||||
if isinstance(v, torch.Tensor)
|
||||
}
|
||||
# Use teleop_action if available, otherwise use the action from the transition
|
||||
action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get(
|
||||
"teleop_action", transition[TransitionKey.ACTION]
|
||||
)
|
||||
frame = {
|
||||
**observations,
|
||||
ACTION: action_to_record.cpu(),
|
||||
REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
|
||||
DONE: np.array([terminated or truncated], dtype=bool),
|
||||
}
|
||||
# Create a neutral action (no movement)
|
||||
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
|
||||
if use_gripper:
|
||||
discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)
|
||||
frame["complementary_info.discrete_penalty"] = np.array([discrete_penalty], dtype=np.float32)
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
|
||||
|
||||
if dataset is not None:
|
||||
frame["task"] = cfg.dataset.task
|
||||
dataset.add_frame(frame)
|
||||
|
||||
episode_step += 1
|
||||
|
||||
# Handle episode termination
|
||||
if terminated or truncated:
|
||||
episode_time = time.perf_counter() - episode_start_time
|
||||
logging.info(
|
||||
f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
|
||||
transition = step_env_and_process_transition(
|
||||
env=env,
|
||||
transition=transition,
|
||||
action=neutral_action,
|
||||
env_processor=env_processor,
|
||||
action_processor=action_processor,
|
||||
)
|
||||
episode_step = 0
|
||||
episode_idx += 1
|
||||
terminated = transition.get(TransitionKey.DONE, False)
|
||||
truncated = transition.get(TransitionKey.TRUNCATED, False)
|
||||
|
||||
if dataset is not None:
|
||||
if transition[TransitionKey.INFO].get(TeleopEvents.RERECORD_EPISODE, False):
|
||||
logging.info(f"Re-recording episode {episode_idx}")
|
||||
dataset.clear_episode_buffer()
|
||||
episode_idx -= 1
|
||||
else:
|
||||
logging.info(f"Saving episode {episode_idx}")
|
||||
dataset.save_episode()
|
||||
if cfg.mode == "record":
|
||||
observations = {
|
||||
k: v.squeeze(0).cpu()
|
||||
for k, v in transition[TransitionKey.OBSERVATION].items()
|
||||
if isinstance(v, torch.Tensor)
|
||||
}
|
||||
action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get(
|
||||
"teleop_action", transition[TransitionKey.ACTION]
|
||||
)
|
||||
frame = {
|
||||
**observations,
|
||||
ACTION: action_to_record.cpu(),
|
||||
REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
|
||||
DONE: np.array([terminated or truncated], dtype=bool),
|
||||
}
|
||||
if use_gripper:
|
||||
discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get(
|
||||
"discrete_penalty", 0.0
|
||||
)
|
||||
frame["complementary_info.discrete_penalty"] = np.array(
|
||||
[discrete_penalty], dtype=np.float32
|
||||
)
|
||||
|
||||
# Reset for new episode
|
||||
obs, info = env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
if dataset is not None:
|
||||
frame["task"] = cfg.dataset.task
|
||||
dataset.add_frame(frame)
|
||||
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
episode_step += 1
|
||||
|
||||
# Maintain fps timing
|
||||
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
|
||||
# Handle episode termination
|
||||
if terminated or truncated:
|
||||
episode_time = time.perf_counter() - episode_start_time
|
||||
logging.info(
|
||||
f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
|
||||
)
|
||||
episode_step = 0
|
||||
episode_idx += 1
|
||||
|
||||
if dataset is not None:
|
||||
if transition[TransitionKey.INFO].get(TeleopEvents.RERECORD_EPISODE, False):
|
||||
logging.info(f"Re-recording episode {episode_idx}")
|
||||
dataset.clear_episode_buffer()
|
||||
episode_idx -= 1
|
||||
else:
|
||||
logging.info(f"Saving episode {episode_idx}")
|
||||
dataset.save_episode()
|
||||
|
||||
# Reset for new episode
|
||||
transition = reset_and_build_transition(env, env_processor, action_processor)
|
||||
|
||||
# Maintain fps timing
|
||||
precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0))
|
||||
finally:
|
||||
if dataset is not None and dataset.writer is not None and dataset.writer.image_writer is not None:
|
||||
logging.info("Waiting for image writer to finish...")
|
||||
dataset.writer.image_writer.stop()
|
||||
|
||||
if dataset is not None and cfg.dataset.push_to_hub:
|
||||
logging.info("Finalizing dataset before pushing to hub")
|
||||
|
||||
@@ -51,6 +51,7 @@ import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import grpc
|
||||
import torch
|
||||
@@ -68,10 +69,14 @@ from lerobot.common.train_utils import (
|
||||
)
|
||||
from lerobot.common.wandb_utils import WandBLogger
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.datasets import LeRobotDataset, make_dataset
|
||||
from lerobot.policies import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies import make_policy, make_pre_post_processors
|
||||
from lerobot.rl.algorithms.base import RLAlgorithm
|
||||
from lerobot.rl.algorithms.factory import make_algorithm
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.data_sources import OnlineOfflineMixer
|
||||
from lerobot.rl.train_rl import TrainRLServerPipelineConfig
|
||||
from lerobot.rl.trainer import RLTrainer
|
||||
from lerobot.robots import so_follower # noqa: F401
|
||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
@@ -92,13 +97,11 @@ from lerobot.utils.constants import (
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
|
||||
from lerobot.utils.utils import (
|
||||
format_big_number,
|
||||
init_logging,
|
||||
)
|
||||
|
||||
from .buffer import ReplayBuffer, concatenate_batch_transitions
|
||||
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
|
||||
|
||||
|
||||
@@ -179,7 +182,7 @@ def train(cfg: TrainRLServerPipelineConfig, job_name: str | None = None):
|
||||
def start_learner_threads(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
wandb_logger: WandBLogger | None,
|
||||
shutdown_event: any, # Event,
|
||||
shutdown_event: Any, # Event
|
||||
) -> None:
|
||||
"""
|
||||
Start the learner threads for training.
|
||||
@@ -253,7 +256,7 @@ def start_learner_threads(
|
||||
def add_actor_information_and_train(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
wandb_logger: WandBLogger | None,
|
||||
shutdown_event: any, # Event,
|
||||
shutdown_event: Any, # Event
|
||||
transition_queue: Queue,
|
||||
interaction_message_queue: Queue,
|
||||
parameters_queue: Queue,
|
||||
@@ -266,8 +269,8 @@ def add_actor_information_and_train(
|
||||
- Transfers transitions from the actor to the replay buffer.
|
||||
- Logs received interaction messages.
|
||||
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
|
||||
- Samples batches from the replay buffer and performs multiple critic updates.
|
||||
- Periodically updates the actor, critic, and temperature optimizers.
|
||||
- Delegates training updates to an ``RLAlgorithm``.
|
||||
- Periodically pushes updated weights to actors.
|
||||
- Logs training statistics, including loss values and optimization frequency.
|
||||
|
||||
NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
|
||||
@@ -286,17 +289,13 @@ def add_actor_information_and_train(
|
||||
# of 7%
|
||||
device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
|
||||
storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
|
||||
clip_grad_norm_value = cfg.policy.grad_clip_norm
|
||||
online_step_before_learning = cfg.policy.online_step_before_learning
|
||||
utd_ratio = cfg.policy.utd_ratio
|
||||
fps = cfg.env.fps
|
||||
log_freq = cfg.log_freq
|
||||
save_freq = cfg.save_freq
|
||||
policy_update_freq = cfg.policy.policy_update_freq
|
||||
policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
|
||||
saving_checkpoint = cfg.save_checkpoint
|
||||
online_steps = cfg.policy.online_steps
|
||||
async_prefetch = cfg.policy.async_prefetch
|
||||
|
||||
# Initialize logging for multiprocessing
|
||||
if not use_threads(cfg):
|
||||
@@ -308,7 +307,7 @@ def add_actor_information_and_train(
|
||||
|
||||
logging.info("Initializing policy")
|
||||
|
||||
policy: SACPolicy = make_policy(
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
@@ -317,15 +316,17 @@ def add_actor_information_and_train(
|
||||
|
||||
policy.train()
|
||||
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
algorithm = make_algorithm(cfg=cfg.algorithm, policy=policy)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
dataset_stats=cfg.policy.dataset_stats,
|
||||
)
|
||||
|
||||
# Push initial policy weights to actors
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, algorithm=algorithm)
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
|
||||
|
||||
# If we are resuming, we need to load the training state
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
|
||||
|
||||
log_training_info(cfg=cfg, policy=policy)
|
||||
|
||||
replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
|
||||
@@ -338,21 +339,35 @@ def add_actor_information_and_train(
|
||||
device=device,
|
||||
storage_device=storage_device,
|
||||
)
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
# DataMixer: online-only or online/offline 50-50 mix
|
||||
data_mixer = OnlineOfflineMixer(
|
||||
online_buffer=replay_buffer,
|
||||
offline_buffer=offline_replay_buffer,
|
||||
online_ratio=cfg.online_ratio,
|
||||
)
|
||||
# RLTrainer owns the iterator, preprocessor, and creates optimizers.
|
||||
trainer = RLTrainer(
|
||||
algorithm=algorithm,
|
||||
data_mixer=data_mixer,
|
||||
batch_size=batch_size,
|
||||
preprocessor=preprocessor,
|
||||
)
|
||||
|
||||
# If we are resuming, we need to load the training state
|
||||
optimizers = algorithm.get_optimizers()
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
|
||||
|
||||
logging.info("Starting learner thread")
|
||||
interaction_message = None
|
||||
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
|
||||
algorithm.optimization_step = optimization_step
|
||||
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
|
||||
|
||||
dataset_repo_id = None
|
||||
if cfg.dataset is not None:
|
||||
dataset_repo_id = cfg.dataset.repo_id
|
||||
|
||||
# Initialize iterators
|
||||
online_iterator = None
|
||||
offline_iterator = None
|
||||
|
||||
# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
|
||||
while True:
|
||||
# Exit the training loop if shutdown is requested
|
||||
@@ -365,7 +380,6 @@ def add_actor_information_and_train(
|
||||
transition_queue=transition_queue,
|
||||
replay_buffer=replay_buffer,
|
||||
offline_replay_buffer=offline_replay_buffer,
|
||||
device=device,
|
||||
dataset_repo_id=dataset_repo_id,
|
||||
shutdown_event=shutdown_event,
|
||||
)
|
||||
@@ -382,180 +396,20 @@ def add_actor_information_and_train(
|
||||
if len(replay_buffer) < online_step_before_learning:
|
||||
continue
|
||||
|
||||
if online_iterator is None:
|
||||
online_iterator = replay_buffer.get_iterator(
|
||||
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
|
||||
)
|
||||
|
||||
if offline_replay_buffer is not None and offline_iterator is None:
|
||||
offline_iterator = offline_replay_buffer.get_iterator(
|
||||
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
|
||||
)
|
||||
|
||||
time_for_one_optimization_step = time.time()
|
||||
for _ in range(utd_ratio - 1):
|
||||
# Sample from the iterators
|
||||
batch = next(online_iterator)
|
||||
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = next(offline_iterator)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy=policy, observations=observations, next_observations=next_observations
|
||||
)
|
||||
|
||||
# Create a batch dictionary with all required elements for the forward method
|
||||
forward_batch = {
|
||||
ACTION: actions,
|
||||
"reward": rewards,
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
"done": done,
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
"complementary_info": batch["complementary_info"],
|
||||
}
|
||||
|
||||
# Use the forward method for critic loss
|
||||
critic_output = policy.forward(forward_batch, model="critic")
|
||||
|
||||
# Main critic optimization
|
||||
loss_critic = critic_output["loss_critic"]
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
|
||||
)
|
||||
optimizers["critic"].step()
|
||||
|
||||
# Discrete critic optimization (if available)
|
||||
if policy.config.num_discrete_actions is not None:
|
||||
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
|
||||
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
|
||||
optimizers["discrete_critic"].zero_grad()
|
||||
loss_discrete_critic.backward()
|
||||
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
)
|
||||
optimizers["discrete_critic"].step()
|
||||
|
||||
# Update target networks (main and discrete)
|
||||
policy.update_target_networks()
|
||||
|
||||
# Sample for the last update in the UTD ratio
|
||||
batch = next(online_iterator)
|
||||
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = next(offline_iterator)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy=policy, observations=observations, next_observations=next_observations
|
||||
)
|
||||
|
||||
# Create a batch dictionary with all required elements for the forward method
|
||||
forward_batch = {
|
||||
ACTION: actions,
|
||||
"reward": rewards,
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
"done": done,
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
}
|
||||
|
||||
critic_output = policy.forward(forward_batch, model="critic")
|
||||
|
||||
loss_critic = critic_output["loss_critic"]
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["critic"].step()
|
||||
|
||||
# Initialize training info dictionary
|
||||
training_infos = {
|
||||
"loss_critic": loss_critic.item(),
|
||||
"critic_grad_norm": critic_grad_norm,
|
||||
}
|
||||
|
||||
# Discrete critic optimization (if available)
|
||||
if policy.config.num_discrete_actions is not None:
|
||||
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
|
||||
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
|
||||
optimizers["discrete_critic"].zero_grad()
|
||||
loss_discrete_critic.backward()
|
||||
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["discrete_critic"].step()
|
||||
|
||||
# Add discrete critic info to training info
|
||||
training_infos["loss_discrete_critic"] = loss_discrete_critic.item()
|
||||
training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm
|
||||
|
||||
# Actor and temperature optimization (at specified frequency)
|
||||
if optimization_step % policy_update_freq == 0:
|
||||
for _ in range(policy_update_freq):
|
||||
# Actor optimization
|
||||
actor_output = policy.forward(forward_batch, model="actor")
|
||||
loss_actor = actor_output["loss_actor"]
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["actor"].step()
|
||||
|
||||
# Add actor info to training info
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
training_infos["actor_grad_norm"] = actor_grad_norm
|
||||
|
||||
# Temperature optimization
|
||||
temperature_output = policy.forward(forward_batch, model="temperature")
|
||||
loss_temperature = temperature_output["loss_temperature"]
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
# Add temperature info to training info
|
||||
training_infos["loss_temperature"] = loss_temperature.item()
|
||||
training_infos["temperature_grad_norm"] = temp_grad_norm
|
||||
training_infos["temperature"] = policy.temperature
|
||||
# One training step (trainer owns data_mixer iterator; algorithm owns UTD loop)
|
||||
stats = trainer.training_step()
|
||||
|
||||
# Push policy to actors if needed
|
||||
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, algorithm=algorithm)
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
# Update target networks (main and discrete)
|
||||
policy.update_target_networks()
|
||||
training_infos = stats.to_log_dict()
|
||||
|
||||
# Log training metrics at specified intervals
|
||||
optimization_step = algorithm.optimization_step
|
||||
if optimization_step % log_freq == 0:
|
||||
training_infos["replay_buffer_size"] = len(replay_buffer)
|
||||
if offline_replay_buffer is not None:
|
||||
@@ -583,7 +437,6 @@ def add_actor_information_and_train(
|
||||
custom_step_key="Optimization step",
|
||||
)
|
||||
|
||||
optimization_step += 1
|
||||
if optimization_step % log_freq == 0:
|
||||
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
|
||||
|
||||
@@ -600,6 +453,8 @@ def add_actor_information_and_train(
|
||||
offline_replay_buffer=offline_replay_buffer,
|
||||
dataset_repo_id=dataset_repo_id,
|
||||
fps=fps,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
|
||||
|
||||
@@ -607,7 +462,7 @@ def start_learner(
|
||||
parameters_queue: Queue,
|
||||
transition_queue: Queue,
|
||||
interaction_message_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
shutdown_event: Any, # Event
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
):
|
||||
"""
|
||||
@@ -684,6 +539,8 @@ def save_training_checkpoint(
|
||||
offline_replay_buffer: ReplayBuffer | None = None,
|
||||
dataset_repo_id: str | None = None,
|
||||
fps: int = 30,
|
||||
preprocessor=None,
|
||||
postprocessor=None,
|
||||
) -> None:
|
||||
"""
|
||||
Save training checkpoint and associated data.
|
||||
@@ -707,6 +564,8 @@ def save_training_checkpoint(
|
||||
offline_replay_buffer: Optional offline replay buffer to save
|
||||
dataset_repo_id: Repository ID for dataset
|
||||
fps: Frames per second for dataset
|
||||
preprocessor: Optional preprocessor pipeline to save
|
||||
postprocessor: Optional postprocessor pipeline to save
|
||||
"""
|
||||
logging.info(f"Checkpoint policy after step {optimization_step}")
|
||||
_num_digits = max(6, len(str(online_steps)))
|
||||
@@ -723,6 +582,8 @@ def save_training_checkpoint(
|
||||
policy=policy,
|
||||
optimizer=optimizers,
|
||||
scheduler=None,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
|
||||
# Save interaction step manually
|
||||
@@ -760,58 +621,6 @@ def save_training_checkpoint(
|
||||
logging.info("Resume training")
|
||||
|
||||
|
||||
def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.Module):
|
||||
"""
|
||||
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
|
||||
|
||||
This function sets up Adam optimizers for:
|
||||
- The **actor network**, ensuring that only relevant parameters are optimized.
|
||||
- The **critic ensemble**, which evaluates the value function.
|
||||
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
|
||||
|
||||
It also initializes a learning rate scheduler, though currently, it is set to `None`.
|
||||
|
||||
NOTE:
|
||||
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
|
||||
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object containing hyperparameters.
|
||||
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
|
||||
A tuple containing:
|
||||
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
|
||||
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
|
||||
|
||||
"""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
params=[
|
||||
p
|
||||
for n, p in policy.actor.named_parameters()
|
||||
if not policy.config.shared_encoder or not n.startswith("encoder")
|
||||
],
|
||||
lr=cfg.policy.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
|
||||
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizer_discrete_critic = torch.optim.Adam(
|
||||
params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
|
||||
lr_scheduler = None
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizers["discrete_critic"] = optimizer_discrete_critic
|
||||
return optimizers, lr_scheduler
|
||||
|
||||
|
||||
# Training setup functions
|
||||
|
||||
|
||||
@@ -1016,33 +825,6 @@ def initialize_offline_replay_buffer(
|
||||
# Utilities/Helpers functions
|
||||
|
||||
|
||||
def get_observation_features(
|
||||
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Get observation features from the policy encoder. It act as cache for the observation features.
|
||||
when the encoder is frozen, the observation features are not updated.
|
||||
We can save compute by caching the observation features.
|
||||
|
||||
Args:
|
||||
policy: The policy model
|
||||
observations: The current observations
|
||||
next_observations: The next observations
|
||||
|
||||
Returns:
|
||||
tuple: observation_features, next_observation_features
|
||||
"""
|
||||
|
||||
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
|
||||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = policy.actor.encoder.get_cached_image_features(observations)
|
||||
next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
|
||||
def use_threads(cfg: TrainRLServerPipelineConfig) -> bool:
|
||||
return cfg.policy.concurrency.learner == "threads"
|
||||
|
||||
@@ -1093,19 +875,11 @@ def check_nan_in_transition(
|
||||
return nan_detected
|
||||
|
||||
|
||||
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
|
||||
def push_actor_policy_to_queue(parameters_queue: Queue, algorithm: RLAlgorithm) -> None:
|
||||
logging.debug("[LEARNER] Pushing actor policy to the queue")
|
||||
|
||||
# Create a dictionary to hold all the state dicts
|
||||
state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")}
|
||||
|
||||
# Add discrete critic if it exists
|
||||
if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None:
|
||||
state_dicts["discrete_critic"] = move_state_dict_to_device(
|
||||
policy.discrete_critic.state_dict(), device="cpu"
|
||||
)
|
||||
logging.debug("[LEARNER] Including discrete critic in state dict push")
|
||||
|
||||
state_dicts = algorithm.get_weights()
|
||||
state_bytes = state_to_bytes(state_dicts)
|
||||
parameters_queue.put(state_bytes)
|
||||
|
||||
@@ -1129,9 +903,8 @@ def process_transitions(
|
||||
transition_queue: Queue,
|
||||
replay_buffer: ReplayBuffer,
|
||||
offline_replay_buffer: ReplayBuffer,
|
||||
device: str,
|
||||
dataset_repo_id: str | None,
|
||||
shutdown_event: any,
|
||||
shutdown_event: Any, # Event
|
||||
):
|
||||
"""Process all available transitions from the queue.
|
||||
|
||||
@@ -1139,7 +912,6 @@ def process_transitions(
|
||||
transition_queue: Queue for receiving transitions from the actor
|
||||
replay_buffer: Replay buffer to add transitions to
|
||||
offline_replay_buffer: Offline replay buffer to add transitions to
|
||||
device: Device to move transitions to
|
||||
dataset_repo_id: Repository ID for dataset
|
||||
shutdown_event: Event to signal shutdown
|
||||
"""
|
||||
@@ -1148,8 +920,6 @@ def process_transitions(
|
||||
transition_list = bytes_to_transitions(buffer=transition_list)
|
||||
|
||||
for transition in transition_list:
|
||||
transition = move_transition_to_device(transition=transition, device=device)
|
||||
|
||||
# Skip transitions with NaN values
|
||||
if check_nan_in_transition(
|
||||
observations=transition["state"],
|
||||
@@ -1163,7 +933,7 @@ def process_transitions(
|
||||
|
||||
# Add to offline buffer if it's an intervention
|
||||
if dataset_repo_id is not None and transition.get("complementary_info", {}).get(
|
||||
TeleopEvents.IS_INTERVENTION
|
||||
TeleopEvents.IS_INTERVENTION.value
|
||||
):
|
||||
offline_replay_buffer.add(**transition)
|
||||
|
||||
@@ -1172,7 +942,7 @@ def process_interaction_messages(
|
||||
interaction_message_queue: Queue,
|
||||
interaction_step_shift: int,
|
||||
wandb_logger: WandBLogger | None,
|
||||
shutdown_event: any,
|
||||
shutdown_event: Any, # Event
|
||||
) -> dict | None:
|
||||
"""Process all available interaction messages from the queue.
|
||||
|
||||
|
||||
49
src/lerobot/rl/train_rl.py
Normal file
49
src/lerobot/rl/train_rl.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Top-level pipeline config for distributed RL training (actor / learner)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.rl.algorithms.configs import RLAlgorithmConfig
|
||||
from lerobot.rl.algorithms.factory import make_algorithm_config
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithmConfig # noqa: F401
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
# NOTE: In RL, we don't need an offline dataset
|
||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||
|
||||
# Algorithm config.
|
||||
algorithm: RLAlgorithmConfig | None = None
|
||||
|
||||
# Data mixer strategy name. Currently supports "online_offline".
|
||||
mixer: str = "online_offline"
|
||||
# Fraction sampled from online replay when using OnlineOfflineMixer.
|
||||
online_ratio: float = 0.5
|
||||
|
||||
def validate(self) -> None:
|
||||
super().validate()
|
||||
|
||||
if self.algorithm is None:
|
||||
self.algorithm = make_algorithm_config("sac")
|
||||
|
||||
if getattr(self.algorithm, "policy_config", None) is None:
|
||||
self.algorithm.policy_config = self.policy
|
||||
99
src/lerobot/rl/trainer.py
Normal file
99
src/lerobot/rl/trainer.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from lerobot.rl.algorithms.base import BatchType, RLAlgorithm
|
||||
from lerobot.rl.algorithms.configs import TrainingStats
|
||||
from lerobot.rl.data_sources.data_mixer import DataMixer
|
||||
|
||||
|
||||
class RLTrainer:
|
||||
"""Unified training step orchestrator.
|
||||
|
||||
Holds the algorithm, a DataMixer, and an optional preprocessor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
algorithm: RLAlgorithm,
|
||||
data_mixer: DataMixer,
|
||||
batch_size: int,
|
||||
*,
|
||||
preprocessor: Any | None = None,
|
||||
):
|
||||
self.algorithm = algorithm
|
||||
self.data_mixer = data_mixer
|
||||
self.batch_size = batch_size
|
||||
self._preprocessor = preprocessor
|
||||
|
||||
self._iterator: Iterator[BatchType] | None = None
|
||||
|
||||
self.algorithm.make_optimizers_and_scheduler()
|
||||
|
||||
def _build_data_iterator(self) -> Iterator[BatchType]:
|
||||
"""Create a fresh algorithm-configured iterator (optionally preprocessed)."""
|
||||
raw = self.algorithm.configure_data_iterator(
|
||||
data_mixer=self.data_mixer,
|
||||
batch_size=self.batch_size,
|
||||
)
|
||||
if self._preprocessor is not None:
|
||||
return _PreprocessedIterator(raw, self._preprocessor)
|
||||
return raw
|
||||
|
||||
def reset_data_iterator(self) -> None:
|
||||
"""Discard the current iterator so it will be rebuilt lazily next step."""
|
||||
self._iterator = None
|
||||
|
||||
def set_data_mixer(self, data_mixer: DataMixer, *, reset: bool = True) -> None:
|
||||
"""Swap the active data mixer, optionally resetting the iterator."""
|
||||
self.data_mixer = data_mixer
|
||||
if reset:
|
||||
self.reset_data_iterator()
|
||||
|
||||
def training_step(self) -> TrainingStats:
|
||||
"""Run one training step (algorithm-agnostic)."""
|
||||
if self._iterator is None:
|
||||
self._iterator = self._build_data_iterator()
|
||||
return self.algorithm.update(self._iterator)
|
||||
|
||||
|
||||
def preprocess_rl_batch(preprocessor: Any, batch: BatchType) -> BatchType:
|
||||
"""Apply policy preprocessing to RL observations only."""
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
batch["state"] = preprocessor.process_observation(observations)
|
||||
batch["next_state"] = preprocessor.process_observation(next_observations)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class _PreprocessedIterator:
|
||||
"""Iterator wrapper that preprocesses each sampled RL batch."""
|
||||
|
||||
__slots__ = ("_raw", "_preprocessor")
|
||||
|
||||
def __init__(self, raw_iterator: Iterator[BatchType], preprocessor: Any) -> None:
|
||||
self._raw = raw_iterator
|
||||
self._preprocessor = preprocessor
|
||||
|
||||
def __iter__(self) -> _PreprocessedIterator:
|
||||
return self
|
||||
|
||||
def __next__(self) -> BatchType:
|
||||
batch = next(self._raw)
|
||||
return preprocess_rl_batch(self._preprocessor, batch)
|
||||
@@ -18,6 +18,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.model import RobotKinematics
|
||||
@@ -31,6 +32,7 @@ from lerobot.processor import (
|
||||
RobotObservation,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
from lerobot.utils.rotation import Rotation
|
||||
|
||||
|
||||
@@ -126,9 +128,18 @@ class EEReferenceAndDelta(RobotActionProcessorStep):
|
||||
],
|
||||
dtype=float,
|
||||
)
|
||||
r_abs = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
|
||||
delta_r = np.array(
|
||||
[
|
||||
wx * self.end_effector_step_sizes.get("wx", 1),
|
||||
wy * self.end_effector_step_sizes.get("wy", 1),
|
||||
wz * self.end_effector_step_sizes.get("wz", 1),
|
||||
],
|
||||
dtype=float,
|
||||
)
|
||||
|
||||
r_mat = Rotation.from_rotvec(delta_r).as_matrix()
|
||||
desired = np.eye(4, dtype=float)
|
||||
desired[:3, :3] = ref[:3, :3] @ r_abs
|
||||
desired[:3, :3] = ref[:3, :3] @ r_mat
|
||||
desired[:3, 3] = ref[:3, 3] + delta_p
|
||||
|
||||
self._command_when_disabled = desired.copy()
|
||||
@@ -353,13 +364,16 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
|
||||
speed_factor: A scaling factor to convert the normalized velocity command to a position change.
|
||||
clip_min: The minimum allowed gripper joint position.
|
||||
clip_max: The maximum allowed gripper joint position.
|
||||
discrete_gripper: If True, treat the input action as discrete (0: open, 1: close, 2: stay).
|
||||
discrete_gripper: If True, interpret the input as a discrete class index
|
||||
{0 = close, 1 = stay, 2 = open}, matching `GamepadTeleop.GripperAction`.
|
||||
"""
|
||||
|
||||
speed_factor: float = 20.0
|
||||
clip_min: float = 0.0
|
||||
clip_max: float = 100.0
|
||||
discrete_gripper: bool = False
|
||||
scale_velocity: bool = False
|
||||
use_ik_solution: bool = False
|
||||
|
||||
def action(self, action: RobotAction) -> RobotAction:
|
||||
observation = self.transition.get(TransitionKey.OBSERVATION).copy()
|
||||
@@ -369,18 +383,21 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
|
||||
if observation is None:
|
||||
raise ValueError("Joints observation is require for computing robot kinematics")
|
||||
|
||||
q_raw = np.array(
|
||||
[float(v) for k, v in observation.items() if isinstance(k, str) and k.endswith(".pos")],
|
||||
dtype=float,
|
||||
)
|
||||
if self.use_ik_solution and "IK_solution" in self.transition.get(TransitionKey.COMPLEMENTARY_DATA):
|
||||
q_raw = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)["IK_solution"]
|
||||
else:
|
||||
q_raw = np.array(
|
||||
[float(v) for k, v in observation.items() if isinstance(k, str) and k.endswith(".pos")],
|
||||
dtype=float,
|
||||
)
|
||||
if q_raw is None:
|
||||
raise ValueError("Joints observation is require for computing robot kinematics")
|
||||
|
||||
if self.discrete_gripper:
|
||||
# Discrete gripper actions are in [0, 1, 2]
|
||||
# 0: open, 1: close, 2: stay
|
||||
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
|
||||
gripper_vel = (gripper_vel - 1) * self.clip_max
|
||||
if self.discrete_gripper or self.scale_velocity:
|
||||
# Map discrete command {0=close, 1=stay, 2=open} -> signed velocity.
|
||||
# Negation accounts for SO100 sign (joint position increases on close).
|
||||
# 0 -> +clip_max (close), 1 -> 0 (stay), 2 -> -clip_max (open)
|
||||
gripper_vel = -(gripper_vel - 1) * self.clip_max
|
||||
|
||||
# Compute desired gripper position
|
||||
delta = gripper_vel * float(self.speed_factor)
|
||||
@@ -578,6 +595,7 @@ class InverseKinematicsRLStep(ProcessorStep):
|
||||
|
||||
# Compute inverse kinematics
|
||||
q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des)
|
||||
q_target[-1] = gripper_pos # Set gripper position
|
||||
self.q_curr = q_target
|
||||
|
||||
# TODO: This is sentitive to order of motor_names = q_target mapping
|
||||
@@ -609,3 +627,50 @@ class InverseKinematicsRLStep(ProcessorStep):
|
||||
def reset(self):
|
||||
"""Resets the initial guess for the IK solver."""
|
||||
self.q_curr = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register("ee_observation")
|
||||
class EEObservationStep(ObservationProcessorStep):
|
||||
use_rotation: bool = False
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
ee_pose_list = [
|
||||
observation["ee.x"],
|
||||
observation["ee.y"],
|
||||
observation["ee.z"],
|
||||
]
|
||||
if self.use_rotation:
|
||||
ee_pose_list.extend(
|
||||
[
|
||||
observation["ee.wx"],
|
||||
observation["ee.wy"],
|
||||
observation["ee.wz"],
|
||||
]
|
||||
)
|
||||
# gripper_pos = action.pop("ee.gripper_pos")
|
||||
ee_pose = torch.tensor(ee_pose_list, dtype=torch.float32).unsqueeze(0)
|
||||
|
||||
current_state = observation.get(OBS_STATE)
|
||||
if current_state is None:
|
||||
return observation
|
||||
|
||||
extended_state = torch.cat([current_state, ee_pose], dim=-1)
|
||||
|
||||
# Create new observation dict
|
||||
new_observation = dict(observation)
|
||||
new_observation[OBS_STATE] = extended_state
|
||||
|
||||
return new_observation
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
if OBS_STATE in features[PipelineFeatureType.OBSERVATION]:
|
||||
original_feature = features[PipelineFeatureType.OBSERVATION][OBS_STATE]
|
||||
new_shape = (original_feature.shape[0] + 3,) + original_feature.shape[1:]
|
||||
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_STATE] = PolicyFeature(
|
||||
type=original_feature.type, shape=new_shape
|
||||
)
|
||||
return features
|
||||
|
||||
@@ -68,16 +68,9 @@ class SOFollower(Robot):
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
features: dict[str, tuple] = {}
|
||||
for cam in self.cameras:
|
||||
cam_cfg = self.config.cameras[cam]
|
||||
features[cam] = (cam_cfg.height, cam_cfg.width, 3)
|
||||
# Cameras with a depth stream (e.g. RealSense with use_depth=True) also
|
||||
# emit a 2D depth feature; hw_to_dataset_features routes 2D shapes to
|
||||
# ``observation.depth.<bare>`` with the depth-map marker.
|
||||
if getattr(cam_cfg, "use_depth", False):
|
||||
features[f"{cam}_depth"] = (cam_cfg.height, cam_cfg.width)
|
||||
return features
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
@@ -175,6 +168,12 @@ class SOFollower(Robot):
|
||||
self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout
|
||||
self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded
|
||||
|
||||
# Set Goal_Position = Present_Position while torque is still disabled so
|
||||
# that when torque is re-enabled at the end of this block the motors have
|
||||
# zero positional error and do not snap to a stale register value.
|
||||
present = self.bus.sync_read("Present_Position")
|
||||
self.bus.sync_write("Goal_Position", present)
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
for motor in reversed(self.bus.motors):
|
||||
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
|
||||
@@ -197,14 +196,6 @@ class SOFollower(Robot):
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
# Cameras with a depth stream populate a sibling ``<cam>_depth`` key
|
||||
# (consumed by hw_to_dataset_features / build_dataset_frame).
|
||||
if getattr(self.config.cameras[cam_key], "use_depth", False):
|
||||
start = time.perf_counter()
|
||||
obs_dict[f"{cam_key}_depth"] = cam.read_latest_depth()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key} depth: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
|
||||
@@ -27,7 +27,7 @@ from threading import Event
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType
|
||||
from lerobot.configs import FeatureType, PreTrainedConfig
|
||||
from lerobot.datasets import (
|
||||
LeRobotDataset,
|
||||
aggregate_pipeline_dataset_features,
|
||||
@@ -178,26 +178,33 @@ def build_rollout_context(
|
||||
policy_config = cfg.policy
|
||||
policy_class = get_policy_class(policy_config.type)
|
||||
|
||||
if hasattr(policy_config, "compile_model"):
|
||||
policy_config.compile_model = cfg.use_torch_compile
|
||||
full_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
for attr in ("device", "use_amp"):
|
||||
if hasattr(cfg.policy, attr) and hasattr(full_config, attr):
|
||||
cli_val = getattr(cfg.policy, attr)
|
||||
if cli_val is not None:
|
||||
setattr(full_config, attr, cli_val)
|
||||
|
||||
if policy_config.type == "vqbet" and cfg.device == "mps":
|
||||
if hasattr(full_config, "compile_model"):
|
||||
full_config.compile_model = cfg.use_torch_compile
|
||||
|
||||
if full_config.type == "vqbet" and cfg.device == "mps":
|
||||
raise NotImplementedError(
|
||||
"Current implementation of VQBeT does not support `mps` backend. "
|
||||
"Please use `cpu` or `cuda` backend."
|
||||
)
|
||||
|
||||
if policy_config.use_peft:
|
||||
if full_config.use_peft:
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
peft_path = policy_config.pretrained_path
|
||||
peft_path = cfg.policy.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_path)
|
||||
policy = policy_class.from_pretrained(
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=policy_config
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=full_config
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
|
||||
else:
|
||||
policy = policy_class.from_pretrained(policy_config.pretrained_path, config=policy_config)
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=full_config)
|
||||
|
||||
if is_rtc:
|
||||
policy.config.rtc_config = cfg.inference.rtc
|
||||
@@ -308,9 +315,7 @@ def build_rollout_context(
|
||||
# Validate visual features if no rename_map is active
|
||||
rename_map = cfg.rename_map
|
||||
if not rename_map:
|
||||
expected_visuals = {
|
||||
k for k, v in policy_config.input_features.items() if v.type == FeatureType.VISUAL
|
||||
}
|
||||
expected_visuals = {k for k, v in full_config.input_features.items() if v.type == FeatureType.VISUAL}
|
||||
provided_visuals = {
|
||||
f"observation.images.{k}" for k, v in robot.observation_features.items() if isinstance(v, tuple)
|
||||
}
|
||||
|
||||
@@ -70,7 +70,6 @@ from lerobot.datasets.io_utils import (
|
||||
get_parquet_file_size_in_mb,
|
||||
get_parquet_num_frames,
|
||||
load_info,
|
||||
load_json,
|
||||
write_episodes,
|
||||
write_info,
|
||||
write_stats,
|
||||
@@ -82,11 +81,9 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
INFO_PATH,
|
||||
LEGACY_EPISODES_PATH,
|
||||
LEGACY_EPISODES_STATS_PATH,
|
||||
LEGACY_TASKS_PATH,
|
||||
DatasetInfo,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
@@ -168,7 +165,7 @@ def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
||||
def validate_local_dataset_version(local_path: Path) -> None:
|
||||
"""Validate that the local dataset has the expected v2.1 version."""
|
||||
info = load_info(local_path)
|
||||
dataset_version = info.codebase_version or "unknown"
|
||||
dataset_version = info.get("codebase_version", "unknown")
|
||||
if dataset_version != V21:
|
||||
raise ValueError(
|
||||
f"Local dataset has codebase version '{dataset_version}', expected '{V21}'. "
|
||||
@@ -259,14 +256,14 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
|
||||
|
||||
def get_video_keys(root):
|
||||
info = load_info(root)
|
||||
features = info.features
|
||||
features = info["features"]
|
||||
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
|
||||
return video_keys
|
||||
|
||||
|
||||
def get_image_keys(root):
|
||||
info = load_info(root)
|
||||
features = info.features
|
||||
features = info["features"]
|
||||
image_keys = [key for key, ft in features.items() if ft["dtype"] == "image"]
|
||||
return image_keys
|
||||
|
||||
@@ -437,8 +434,7 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_
|
||||
|
||||
|
||||
def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb):
|
||||
# Load as raw dict to remove legacy v2.1 fields before constructing DatasetInfo.
|
||||
info = load_json(root / INFO_PATH)
|
||||
info = load_info(root)
|
||||
info["codebase_version"] = V30
|
||||
del info["total_chunks"]
|
||||
del info["total_videos"]
|
||||
@@ -453,9 +449,7 @@ def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb):
|
||||
# already has fps in video_info
|
||||
continue
|
||||
info["features"][key]["fps"] = info["fps"]
|
||||
# Convert raw dict to typed DatasetInfo before writing
|
||||
dataset_info = DatasetInfo.from_dict(info)
|
||||
write_info(dataset_info, new_root)
|
||||
write_info(info, new_root)
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
|
||||
@@ -49,14 +49,6 @@ Delete episodes and save to a new dataset at a specific path and with a new repo
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
Delete episodes and re-encode video segments with h264:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]" \
|
||||
--operation.camera_encoder_config.vcodec h264 \
|
||||
--operation.camera_encoder_config.crf 23
|
||||
|
||||
Split dataset by fractions (pusht_train, pusht_val):
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
@@ -82,14 +74,6 @@ Split into more than two splits:
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}'
|
||||
|
||||
Split dataset and re-encode video segments with h264:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.8, "val": 0.2}' \
|
||||
--operation.camera_encoder_config.vcodec h264 \
|
||||
--operation.camera_encoder_config.crf 23
|
||||
|
||||
Merge multiple datasets:
|
||||
lerobot-edit-dataset \
|
||||
--new_repo_id lerobot/pusht_merged \
|
||||
@@ -203,7 +187,7 @@ import abc
|
||||
import logging
|
||||
import shutil
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
@@ -211,8 +195,6 @@ import draccus
|
||||
from lerobot.configs import parser
|
||||
from lerobot.datasets import (
|
||||
LeRobotDataset,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
convert_image_to_video_dataset,
|
||||
delete_episodes,
|
||||
merge_datasets,
|
||||
@@ -236,14 +218,12 @@ class OperationConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@dataclass
|
||||
class DeleteEpisodesConfig(OperationConfig):
|
||||
episode_indices: list[int] | None = None
|
||||
camera_encoder_config: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("split")
|
||||
@dataclass
|
||||
class SplitConfig(OperationConfig):
|
||||
splits: dict[str, float | list[int]] | None = None
|
||||
camera_encoder_config: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("merge")
|
||||
@@ -270,7 +250,11 @@ class ModifyTasksConfig(OperationConfig):
|
||||
@dataclass
|
||||
class ConvertImageToVideoConfig(OperationConfig):
|
||||
output_dir: str | None = None
|
||||
camera_encoder_config: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
vcodec: str = "libsvtav1"
|
||||
pix_fmt: str = "yuv420p"
|
||||
g: int = 2
|
||||
crf: int = 30
|
||||
fast_decode: int = 0
|
||||
episode_indices: list[int] | None = None
|
||||
num_workers: int = 4
|
||||
max_episodes_per_batch: int | None = None
|
||||
@@ -372,7 +356,6 @@ def handle_delete_episodes(cfg: EditDatasetConfig) -> None:
|
||||
episode_indices=cfg.operation.episode_indices,
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
camera_encoder_config=cfg.operation.camera_encoder_config,
|
||||
)
|
||||
|
||||
logging.info(f"Dataset saved to {output_dir}")
|
||||
@@ -404,7 +387,6 @@ def handle_split(cfg: EditDatasetConfig) -> None:
|
||||
dataset,
|
||||
splits=cfg.operation.splits,
|
||||
output_dir=cfg.new_root,
|
||||
camera_encoder_config=cfg.operation.camera_encoder_config,
|
||||
)
|
||||
|
||||
for split_name, split_ds in split_datasets.items():
|
||||
@@ -575,8 +557,11 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
dataset=dataset,
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
camera_encoder_config=getattr(cfg.operation, "camera_encoder_config", None)
|
||||
or camera_encoder_defaults(),
|
||||
vcodec=getattr(cfg.operation, "vcodec", "libsvtav1"),
|
||||
pix_fmt=getattr(cfg.operation, "pix_fmt", "yuv420p"),
|
||||
g=getattr(cfg.operation, "g", 2),
|
||||
crf=getattr(cfg.operation, "crf", 30),
|
||||
fast_decode=getattr(cfg.operation, "fast_decode", 0),
|
||||
episode_indices=getattr(cfg.operation, "episode_indices", None),
|
||||
num_workers=getattr(cfg.operation, "num_workers", 4),
|
||||
max_episodes_per_batch=getattr(cfg.operation, "max_episodes_per_batch", None),
|
||||
|
||||
@@ -63,27 +63,6 @@ lerobot-record \\
|
||||
--dataset.streaming_encoding=true \\
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
|
||||
Example recording with custom video encoding parameters:
|
||||
```shell
|
||||
lerobot-record \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \\
|
||||
--robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \\
|
||||
--robot.id=black \\
|
||||
--teleop.type=so100_leader \\
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \\
|
||||
--teleop.id=blue \\
|
||||
--dataset.repo_id=<my_username>/<my_dataset_name> \\
|
||||
--dataset.num_episodes=2 \\
|
||||
--dataset.single_task="Grab the cube" \\
|
||||
--dataset.streaming_encoding=true \\
|
||||
--dataset.encoder_threads=2 \\
|
||||
--dataset.camera_encoder_config.vcodec=h264 \\
|
||||
--dataset.camera_encoder_config.preset=fast \\
|
||||
--dataset.camera_encoder_config.extra_options={"tune": "film", "profile:v": "high", "bf": 2} \\
|
||||
--display_data=true
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -104,12 +83,10 @@ from lerobot.common.control_utils import (
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.dataset import DatasetRecordConfig
|
||||
from lerobot.datasets import (
|
||||
DepthEncoderConfig,
|
||||
LeRobotDataset,
|
||||
VideoEncodingManager,
|
||||
aggregate_pipeline_dataset_features,
|
||||
create_initial_features,
|
||||
depth_encoder_defaults,
|
||||
safe_stop_image_writer,
|
||||
)
|
||||
from lerobot.processor import (
|
||||
@@ -328,10 +305,7 @@ def record_loop(
|
||||
|
||||
if display_data:
|
||||
log_rerun_data(
|
||||
observation=obs_processed,
|
||||
action=action_values,
|
||||
compress_images=display_compressed_images,
|
||||
features=dataset.features if dataset is not None else None,
|
||||
observation=obs_processed, action=action_values, compress_images=display_compressed_images
|
||||
)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
@@ -403,11 +377,10 @@ def record(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
camera_encoder_config=cfg.dataset.camera_encoder_config,
|
||||
depth_encoder_config=cfg.dataset.depth_encoder_config,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
|
||||
if num_cameras > 0
|
||||
@@ -433,11 +406,10 @@ def record(
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
camera_encoder_config=cfg.dataset.camera_encoder_config,
|
||||
depth_encoder_config=cfg.dataset.depth_encoder_config,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
@@ -448,7 +420,7 @@ def record(
|
||||
|
||||
if not cfg.dataset.streaming_encoding:
|
||||
logging.info(
|
||||
"Streaming encoding is disabled. If you have capable hardware, consider enabling it for way faster episode saving. --dataset.streaming_encoding=true --dataset.encoder_threads=2 # --dataset.camera_encoder_config.vcodec=auto. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding"
|
||||
"Streaming encoding is disabled. If you have capable hardware, consider enabling it for way faster episode saving. --dataset.streaming_encoding=true --dataset.encoder_threads=2 # --dataset.vcodec=auto. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding"
|
||||
)
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
|
||||
@@ -47,7 +47,6 @@ from lerobot.datasets import EpisodeAwareSampler, make_dataset
|
||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.rewards import make_reward_pre_post_processors
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
@@ -71,8 +70,8 @@ def update_policy(
|
||||
accelerator: "Accelerator",
|
||||
lr_scheduler=None,
|
||||
lock=None,
|
||||
sample_weighter=None,
|
||||
) -> tuple[MetricsTracker, dict | None]:
|
||||
rabc_weights_provider=None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
"""
|
||||
Performs a single training step to update the policy's weights.
|
||||
|
||||
@@ -88,7 +87,7 @@ def update_policy(
|
||||
accelerator: The Accelerator instance for distributed training and mixed precision.
|
||||
lr_scheduler: An optional learning rate scheduler.
|
||||
lock: An optional lock for thread-safe optimizer updates.
|
||||
sample_weighter: Optional SampleWeighter instance for per-sample loss weighting.
|
||||
rabc_weights_provider: Optional RABCWeights instance for sample weighting.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
@@ -98,31 +97,27 @@ def update_policy(
|
||||
start_time = time.perf_counter()
|
||||
policy.train()
|
||||
|
||||
# Compute sample weights if a weighter is provided
|
||||
sample_weights = None
|
||||
weight_stats = None
|
||||
if sample_weighter is not None:
|
||||
sample_weights, weight_stats = sample_weighter.compute_batch_weights(batch)
|
||||
# Get RA-BC weights if enabled
|
||||
rabc_batch_weights = None
|
||||
rabc_batch_stats = None
|
||||
if rabc_weights_provider is not None:
|
||||
rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
|
||||
|
||||
# Let accelerator handle mixed precision
|
||||
with accelerator.autocast():
|
||||
if sample_weights is not None:
|
||||
# Use per-sample loss for weighted training
|
||||
# Note: Policies supporting sample weighting must implement forward(batch, reduction="none")
|
||||
# Use per-sample loss when RA-BC is enabled for proper weighting
|
||||
if rabc_batch_weights is not None:
|
||||
# Get per-sample losses
|
||||
per_sample_loss, output_dict = policy.forward(batch, reduction="none")
|
||||
|
||||
# Weighted loss: each sample's contribution is scaled by its weight.
|
||||
# We divide by weight sum (not batch size) so that if some weights are zero,
|
||||
# the remaining samples contribute proportionally more, preserving gradient scale.
|
||||
# Weights are pre-normalized to sum to batch_size for stable training dynamics.
|
||||
# Apply RA-BC weights: L_RA-BC = Σ(w_i * l_i) / (Σw_i + ε)
|
||||
# rabc_batch_weights is already normalized to sum to batch_size
|
||||
epsilon = 1e-6
|
||||
loss = (per_sample_loss * sample_weights).sum() / (sample_weights.sum() + epsilon)
|
||||
|
||||
# Log weighting statistics
|
||||
if output_dict is None:
|
||||
output_dict = {}
|
||||
for key, value in weight_stats.items():
|
||||
output_dict[f"sample_weight_{key}"] = value
|
||||
loss = (per_sample_loss * rabc_batch_weights).sum() / (rabc_batch_weights.sum() + epsilon)
|
||||
# Log raw mean weight (before normalization) - this is the meaningful metric
|
||||
output_dict["rabc_mean_weight"] = rabc_batch_stats["raw_mean_weight"]
|
||||
output_dict["rabc_num_zero_weight"] = rabc_batch_stats["num_zero_weight"]
|
||||
output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"]
|
||||
else:
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
@@ -193,8 +188,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
|
||||
force_cpu = cfg.trainable_config.device == "cpu"
|
||||
# Force the device to be CPU when policy.device is set to CPU.
|
||||
force_cpu = cfg.policy.device == "cpu"
|
||||
accelerator = Accelerator(
|
||||
step_scheduler_with_optimizer=False,
|
||||
kwargs_handlers=[ddp_kwargs],
|
||||
@@ -250,44 +245,26 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
logging.info("Creating env")
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
if is_main_process:
|
||||
logging.info("Creating reward model")
|
||||
from lerobot.rewards import make_reward_model
|
||||
|
||||
policy = make_reward_model(
|
||||
cfg=cfg.reward_model,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
dataset_meta=dataset.meta,
|
||||
)
|
||||
if not policy.is_trainable:
|
||||
raise ValueError(
|
||||
f"Reward model '{policy.name}' is zero-shot and cannot be trained via lerobot-train. "
|
||||
"Use it directly for inference via compute_reward() (e.g. offline precompute)."
|
||||
)
|
||||
else:
|
||||
if is_main_process:
|
||||
logging.info("Creating policy")
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
rename_map=cfg.rename_map,
|
||||
)
|
||||
if is_main_process:
|
||||
logging.info("Creating policy")
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
rename_map=cfg.rename_map,
|
||||
)
|
||||
|
||||
if cfg.peft is not None:
|
||||
if cfg.is_reward_model_training:
|
||||
raise ValueError("PEFT is only supported for policy training. ")
|
||||
logging.info("Using PEFT! Wrapping model.")
|
||||
# Convert CLI peft config to dict for overrides
|
||||
peft_cli_overrides = dataclasses.asdict(cfg.peft)
|
||||
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
|
||||
|
||||
# Wait for all processes to finish model creation before continuing
|
||||
# Wait for all processes to finish policy creation before continuing
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
active_cfg = cfg.trainable_config
|
||||
processor_pretrained_path = active_cfg.pretrained_path
|
||||
processor_pretrained_path = cfg.policy.pretrained_path
|
||||
if (
|
||||
getattr(active_cfg, "use_relative_actions", False)
|
||||
getattr(cfg.policy, "use_relative_actions", False)
|
||||
and processor_pretrained_path is not None
|
||||
and not cfg.resume
|
||||
):
|
||||
@@ -297,15 +274,18 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
)
|
||||
processor_pretrained_path = None
|
||||
|
||||
# Create processors - only provide dataset_stats if not resuming from saved processors
|
||||
processor_kwargs = {}
|
||||
postprocessor_kwargs = {}
|
||||
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
|
||||
# Only provide dataset_stats when not resuming from saved processor state
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
# For SARM, always provide dataset_meta for progress normalization
|
||||
if cfg.policy.type == "sarm":
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
if not cfg.is_reward_model_training and processor_pretrained_path is not None:
|
||||
if processor_pretrained_path is not None:
|
||||
processor_kwargs["preprocessor_overrides"] = {
|
||||
"device_processor": {"device": device.type},
|
||||
"normalizer_processor": {
|
||||
@@ -325,36 +305,38 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
},
|
||||
}
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
preprocessor, postprocessor = make_reward_pre_post_processors(
|
||||
cfg.reward_model,
|
||||
**processor_kwargs,
|
||||
)
|
||||
else:
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
|
||||
# Create sample weighter if configured (e.g., for RA-BC training)
|
||||
sample_weighter = None
|
||||
if cfg.sample_weighting is not None:
|
||||
from lerobot.utils.sample_weighting import make_sample_weighter
|
||||
# Load precomputed SARM progress for RA-BC if enabled
|
||||
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
|
||||
rabc_weights = None
|
||||
if cfg.use_rabc:
|
||||
from lerobot.utils.rabc import RABCWeights
|
||||
|
||||
if is_main_process:
|
||||
logging.info(f"Creating sample weighter: {cfg.sample_weighting.type}")
|
||||
sample_weighter = make_sample_weighter(
|
||||
cfg.sample_weighting,
|
||||
policy,
|
||||
device,
|
||||
dataset_root=cfg.dataset.root,
|
||||
dataset_repo_id=cfg.dataset.repo_id,
|
||||
# Get chunk_size from policy config
|
||||
chunk_size = getattr(policy.config, "chunk_size", None)
|
||||
if chunk_size is None:
|
||||
raise ValueError("Chunk size is not found in policy config")
|
||||
|
||||
head_mode = getattr(cfg, "rabc_head_mode", "sparse")
|
||||
logging.info(f"Loading SARM progress for RA-BC from {cfg.rabc_progress_path}")
|
||||
logging.info(f"Using chunk_size={chunk_size} from policy config, head_mode={head_mode}")
|
||||
rabc_weights = RABCWeights(
|
||||
progress_path=cfg.rabc_progress_path,
|
||||
chunk_size=chunk_size,
|
||||
head_mode=head_mode,
|
||||
kappa=getattr(cfg, "rabc_kappa", 0.01),
|
||||
epsilon=getattr(cfg, "rabc_epsilon", 1e-6),
|
||||
device=device,
|
||||
)
|
||||
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
@@ -383,13 +365,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
@@ -466,7 +448,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
accelerator=accelerator,
|
||||
lr_scheduler=lr_scheduler,
|
||||
sample_weighter=sample_weighter,
|
||||
rabc_weights_provider=rabc_weights,
|
||||
)
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
@@ -485,10 +467,16 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
# Log sample weighting statistics if enabled
|
||||
if sample_weighter is not None:
|
||||
weighter_stats = sample_weighter.get_stats()
|
||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||
# Log RA-BC statistics if enabled
|
||||
if rabc_weights is not None:
|
||||
rabc_stats = rabc_weights.get_stats()
|
||||
wandb_log_dict.update(
|
||||
{
|
||||
"rabc_delta_mean": rabc_stats["delta_mean"],
|
||||
"rabc_delta_std": rabc_stats["delta_std"],
|
||||
"rabc_num_frames": rabc_stats["num_frames"],
|
||||
}
|
||||
)
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
@@ -570,15 +558,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if is_main_process:
|
||||
logging.info("End of training")
|
||||
|
||||
if getattr(active_cfg, "push_to_hub", False):
|
||||
unwrapped_model = accelerator.unwrap_model(policy)
|
||||
# PEFT only applies when training a policy — reward models use the plain path.
|
||||
if not cfg.is_reward_model_training and cfg.policy.use_peft:
|
||||
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
|
||||
if cfg.policy.push_to_hub:
|
||||
unwrapped_policy = accelerator.unwrap_model(policy)
|
||||
if cfg.policy.use_peft:
|
||||
unwrapped_policy.push_model_to_hub(cfg, peft_model=unwrapped_policy)
|
||||
else:
|
||||
unwrapped_model.push_model_to_hub(cfg)
|
||||
preprocessor.push_to_hub(active_cfg.repo_id)
|
||||
postprocessor.push_to_hub(active_cfg.repo_id)
|
||||
unwrapped_policy.push_model_to_hub(cfg)
|
||||
preprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
postprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
|
||||
# Properly clean up the distributed process group
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -104,11 +104,14 @@ class KeyboardTeleop(Teleoperator):
|
||||
|
||||
def _on_press(self, key):
|
||||
if hasattr(key, "char"):
|
||||
self.event_queue.put((key.char, True))
|
||||
key = key.char
|
||||
self.event_queue.put((key, True))
|
||||
|
||||
def _on_release(self, key):
|
||||
if hasattr(key, "char"):
|
||||
self.event_queue.put((key.char, False))
|
||||
key = key.char
|
||||
self.event_queue.put((key, False))
|
||||
|
||||
if key == keyboard.Key.esc:
|
||||
logging.info("ESC pressed, disconnecting.")
|
||||
self.disconnect()
|
||||
@@ -204,8 +207,6 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
|
||||
# this is useful for retrieving other events like interventions for RL, episode success, etc.
|
||||
self.misc_keys_queue.put(key)
|
||||
|
||||
self.current_pressed.clear()
|
||||
|
||||
action_dict = {
|
||||
"delta_x": delta_x,
|
||||
"delta_y": delta_y,
|
||||
@@ -256,6 +257,8 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
|
||||
]
|
||||
is_intervention = any(self.current_pressed.get(key, False) for key in movement_keys)
|
||||
|
||||
self.current_pressed.clear()
|
||||
|
||||
# Check for episode control commands from misc_keys_queue
|
||||
terminate_episode = False
|
||||
success = False
|
||||
|
||||
@@ -20,6 +20,7 @@ from .config_so_leader import (
|
||||
SOLeaderConfig,
|
||||
SOLeaderTeleopConfig,
|
||||
)
|
||||
from .so101_leader_follower import SO101LeaderFollower
|
||||
from .so_leader import SO100Leader, SO101Leader, SOLeader
|
||||
|
||||
__all__ = [
|
||||
@@ -27,6 +28,7 @@ __all__ = [
|
||||
"SO100LeaderConfig",
|
||||
"SO101Leader",
|
||||
"SO101LeaderConfig",
|
||||
"SO101LeaderFollower",
|
||||
"SOLeader",
|
||||
"SOLeaderConfig",
|
||||
"SOLeaderTeleopConfig",
|
||||
|
||||
@@ -29,6 +29,11 @@ class SOLeaderConfig:
|
||||
# Whether to use degrees for angles
|
||||
use_degrees: bool = True
|
||||
|
||||
# Enable leader-follower mode where leader can both lead and follow
|
||||
leader_follower_mode: bool = False
|
||||
|
||||
use_gripper: bool = True
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("so101_leader")
|
||||
@TeleoperatorConfig.register_subclass("so100_leader")
|
||||
|
||||
261
src/lerobot/teleoperators/so_leader/so101_leader_follower.py
Normal file
261
src/lerobot/teleoperators/so_leader/so101_leader_follower.py
Normal file
@@ -0,0 +1,261 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import deque
|
||||
from threading import Event, Thread
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.teleoperators.so_leader.so_leader import SOLeader as SO101Leader
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
PYNPUT_AVAILABLE = True
|
||||
try:
|
||||
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
|
||||
logging.info("No DISPLAY set. Skipping pynput import.")
|
||||
raise ImportError("pynput blocked intentionally due to no display.")
|
||||
|
||||
from pynput import keyboard
|
||||
except ImportError:
|
||||
keyboard = None
|
||||
PYNPUT_AVAILABLE = False
|
||||
except Exception as e:
|
||||
keyboard = None
|
||||
PYNPUT_AVAILABLE = False
|
||||
logging.info(f"Could not import pynput: {e}")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SO101LeaderFollower(SO101Leader):
|
||||
"""
|
||||
Extended SO101 Leader that can both lead (human control) and follow (mimic follower).
|
||||
|
||||
This class adds leader-follower functionality where:
|
||||
- In follow mode: The leader arm mimics the follower's position (torque enabled)
|
||||
- In lead mode: Human controls the leader (torque disabled) and provides actions
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
# Leader-follower state
|
||||
self.is_intervening = False
|
||||
# Initialize as False because configure() disables torque at connect time;
|
||||
# send_action() will re-enable it on the first call when not intervening.
|
||||
self.leader_torque_enabled = False
|
||||
|
||||
# Tracking error for automatic intervention detection
|
||||
self.leader_tracking_error_queue = deque(maxlen=4)
|
||||
|
||||
# Keyboard event handling
|
||||
self.keyboard_events = {
|
||||
"intervention": False,
|
||||
"success": False,
|
||||
"failure": False,
|
||||
"rerecord": False,
|
||||
}
|
||||
self.keyboard_thread = None
|
||||
self.stop_event = Event()
|
||||
|
||||
# Store last follower position for action computation
|
||||
self.last_follower_pos = None
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict:
|
||||
if self.config.use_gripper:
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (7,),
|
||||
"names": {
|
||||
"delta_x": 0,
|
||||
"delta_y": 1,
|
||||
"delta_z": 2,
|
||||
"delta_wx": 3,
|
||||
"delta_wy": 4,
|
||||
"delta_wz": 5,
|
||||
"gripper": 6,
|
||||
},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": {
|
||||
"delta_x": 0,
|
||||
"delta_y": 1,
|
||||
"delta_z": 2,
|
||||
"delta_wx": 3,
|
||||
"delta_wy": 4,
|
||||
"delta_wz": 5,
|
||||
},
|
||||
}
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""Connect and configure for leader-follower mode."""
|
||||
super().connect(calibrate)
|
||||
|
||||
# Configure for leader-follower mode with lower gains
|
||||
# Lower gains allow manual intervention without injury risk
|
||||
# self.bus.sync_write("Torque_Enable", 1)
|
||||
for motor in self.bus.motors:
|
||||
self.bus.write("P_Coefficient", motor, 16)
|
||||
self.bus.write("I_Coefficient", motor, 0)
|
||||
self.bus.write("D_Coefficient", motor, 16)
|
||||
|
||||
# Start keyboard listener
|
||||
self._start_keyboard_listener()
|
||||
|
||||
print("- Leader-Follower Mode:")
|
||||
print(" - Press SPACE to toggle intervention (leader control)")
|
||||
print(" - When not intervening, leader follows follower position")
|
||||
print(" - When intervening, follower follows leader in end-effector space")
|
||||
print(" - Press 's' to mark episode as success")
|
||||
print(" - Press ESC to end episode as failure")
|
||||
print(" - Press 'r' to re-record episode")
|
||||
|
||||
def _start_keyboard_listener(self):
|
||||
"""Start keyboard listener thread for intervention control."""
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if key == keyboard.Key.space:
|
||||
self.keyboard_events["intervention"] = not self.keyboard_events["intervention"]
|
||||
self.is_intervening = self.keyboard_events["intervention"]
|
||||
state = "INTERVENTION MODE" if self.is_intervening else "FOLLOWING MODE"
|
||||
logger.info(f"Toggled to {state}")
|
||||
elif key == keyboard.Key.esc:
|
||||
self.keyboard_events["failure"] = True
|
||||
elif hasattr(key, "char"):
|
||||
if key.char == "s":
|
||||
self.keyboard_events["success"] = True
|
||||
elif key.char == "r":
|
||||
self.keyboard_events["rerecord"] = True
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling key press: {e}")
|
||||
|
||||
def listen():
|
||||
with keyboard.Listener(on_press=on_press) as listener:
|
||||
while not self.stop_event.is_set():
|
||||
time.sleep(0.1)
|
||||
listener.stop()
|
||||
|
||||
self.keyboard_thread = Thread(target=listen, daemon=True)
|
||||
self.keyboard_thread.start()
|
||||
|
||||
def send_action(self, action: dict[str, float]) -> None:
|
||||
"""
|
||||
Send position commands to leader arm (follow mode).
|
||||
|
||||
Args:
|
||||
action: Dictionary of motor positions to command
|
||||
"""
|
||||
# Store follower position for later use
|
||||
self.last_follower_pos = np.array([action.get(f"{motor}.pos", 0) for motor in self.bus.motors])
|
||||
|
||||
if not self.is_intervening:
|
||||
# Follow mode: enable torque and track follower
|
||||
if not self.leader_torque_enabled:
|
||||
self.bus.sync_write("Torque_Enable", 1)
|
||||
self.leader_torque_enabled = True
|
||||
|
||||
# Send follower positions to leader
|
||||
goal_pos = {motor: action[f"{motor}.pos"] for motor in self.bus.motors}
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
|
||||
# Track error for automatic intervention detection
|
||||
current_pos = self.bus.sync_read("Present_Position")
|
||||
current_array = np.array([current_pos[motor] for motor in self.bus.motors])
|
||||
error = np.linalg.norm(self.last_follower_pos[:-1] - current_array[:-1])
|
||||
self.leader_tracking_error_queue.append(error)
|
||||
|
||||
def get_action(self) -> dict[str, float]:
|
||||
"""
|
||||
Get action from leader arm.
|
||||
|
||||
In follow mode: Returns neutral/current positions
|
||||
In lead mode: Returns actual leader positions for follower to track
|
||||
"""
|
||||
start = time.perf_counter()
|
||||
|
||||
if self.is_intervening:
|
||||
# Lead mode: disable torque if needed and return leader positions
|
||||
if self.leader_torque_enabled:
|
||||
self.bus.sync_write("Torque_Enable", 0)
|
||||
self.leader_torque_enabled = False
|
||||
|
||||
# Get current leader position
|
||||
action = self.bus.sync_read("Present_Position")
|
||||
action = {f"{motor}.pos": val for motor, val in action.items()}
|
||||
|
||||
# Track error
|
||||
if self.last_follower_pos is not None:
|
||||
current_array = np.array([action[f"{motor}.pos"] for motor in self.bus.motors])
|
||||
error = np.linalg.norm(self.last_follower_pos[:-1] - current_array[:-1])
|
||||
self.leader_tracking_error_queue.append(error)
|
||||
else:
|
||||
# Follow mode: return current/neutral positions
|
||||
action = self.bus.sync_read("Present_Position")
|
||||
action = {f"{motor}.pos": val for motor, val in action.items()}
|
||||
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return action
|
||||
|
||||
def get_teleop_events(self) -> dict[TeleopEvents, bool]:
|
||||
"""Get current keyboard events."""
|
||||
events = {}
|
||||
|
||||
# Map keyboard events to TeleopEvents
|
||||
if self.keyboard_events["success"]:
|
||||
events[TeleopEvents.SUCCESS] = True
|
||||
self.keyboard_events["success"] = False
|
||||
if self.keyboard_events["failure"]:
|
||||
events[TeleopEvents.FAILURE] = True
|
||||
events[TeleopEvents.TERMINATE_EPISODE] = True
|
||||
self.keyboard_events["failure"] = False
|
||||
if self.keyboard_events["rerecord"]:
|
||||
events[TeleopEvents.RERECORD_EPISODE] = True
|
||||
events[TeleopEvents.TERMINATE_EPISODE] = True
|
||||
self.keyboard_events["rerecord"] = False
|
||||
|
||||
# Always report intervention state
|
||||
events[TeleopEvents.IS_INTERVENTION] = self.is_intervening
|
||||
|
||||
return events
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect and cleanup."""
|
||||
self.stop_event.set()
|
||||
if self.keyboard_thread:
|
||||
self.keyboard_thread.join(timeout=1.0)
|
||||
super().disconnect()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset leader-follower state."""
|
||||
self.is_intervening = False
|
||||
self.leader_torque_enabled = True
|
||||
self.leader_tracking_error_queue.clear()
|
||||
self.keyboard_events = {
|
||||
"intervention": False,
|
||||
"success": False,
|
||||
"failure": False,
|
||||
"rerecord": False,
|
||||
}
|
||||
@@ -52,9 +52,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
|
||||
|
||||
return SO100Leader(config)
|
||||
elif config.type == "so101_leader":
|
||||
from .so_leader import SO101Leader
|
||||
from .so_leader import SO101LeaderFollower
|
||||
|
||||
return SO101Leader(config)
|
||||
if getattr(config, "leader_follower_mode", False):
|
||||
return SO101LeaderFollower(config)
|
||||
elif config.type == "mock_teleop":
|
||||
from tests.mocks.mock_teleop import MockTeleop
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user