mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
Compare commits
9 Commits
016799dfa1
...
chore/colo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6021554770 | ||
|
|
e99c55af4b | ||
|
|
408e0ca763 | ||
|
|
ce24063efd | ||
|
|
82934719db | ||
|
|
401a217597 | ||
|
|
40094b0464 | ||
|
|
fdbfc015a2 | ||
|
|
d656da8ccc |
12
.github/workflows/stale.yml
vendored
12
.github/workflows/stale.yml
vendored
@@ -24,14 +24,14 @@ on:
|
||||
|
||||
env:
|
||||
CLOSE_ISSUE_MESSAGE: >
|
||||
This issue was closed because it has been stalled for 14 days with no activity.
|
||||
This issue was closed because it has been stalled for 30 days with no activity.
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
CLOSE_PR_MESSAGE: >
|
||||
This PR was closed because it has been stalled for 21 days with no activity.
|
||||
This PR was closed because it has been stalled for 30 days with no activity.
|
||||
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
|
||||
WARN_ISSUE_MESSAGE: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity (6 months). It will be closed if no further activity occurs.
|
||||
recent activity (1 year). It will be closed if no further activity occurs.
|
||||
Any change, comment or update to this issue will reset this count.
|
||||
Thank you for your contributions.
|
||||
WARN_PR_MESSAGE: >
|
||||
@@ -59,10 +59,10 @@ jobs:
|
||||
stale-pr-label: stale
|
||||
exempt-issue-labels: never-stale
|
||||
exempt-pr-labels: never-stale
|
||||
days-before-issue-stale: 180
|
||||
days-before-issue-close: 14
|
||||
days-before-issue-stale: 365
|
||||
days-before-issue-close: 30
|
||||
days-before-pr-stale: 365
|
||||
days-before-pr-close: 21
|
||||
days-before-pr-close: 30
|
||||
delete-branch: true
|
||||
close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }}
|
||||
close-pr-message: ${{ env.CLOSE_PR_MESSAGE }}
|
||||
|
||||
@@ -39,7 +39,6 @@ from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.video_utils import (
|
||||
VideoEncoderConfig,
|
||||
decode_video_frames,
|
||||
encode_video_frames,
|
||||
)
|
||||
@@ -252,13 +251,10 @@ def benchmark_encoding_decoding(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
camera_encoder_config=VideoEncoderConfig(
|
||||
vcodec=encoding_cfg["vcodec"],
|
||||
pix_fmt=encoding_cfg["pix_fmt"],
|
||||
g=encoding_cfg.get("g"),
|
||||
crf=encoding_cfg.get("crf"),
|
||||
preset=encoding_cfg.get("preset"),
|
||||
),
|
||||
vcodec=encoding_cfg["vcodec"],
|
||||
pix_fmt=encoding_cfg["pix_fmt"],
|
||||
g=encoding_cfg.get("g"),
|
||||
crf=encoding_cfg.get("crf"),
|
||||
# fast_decode=encoding_cfg.get("fastdecode"),
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -18,9 +18,8 @@
|
||||
# docker build -f docker/Dockerfile.internal -t lerobot-internal .
|
||||
|
||||
# Configure the base image for CI with GPU access
|
||||
# TODO(Steven): Bump these versions
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
ARG OS_VERSION=22.04
|
||||
ARG CUDA_VERSION=12.6.3
|
||||
ARG OS_VERSION=24.04
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}
|
||||
|
||||
# Define Python version argument
|
||||
@@ -36,16 +35,13 @@ ENV DEBIAN_FRONTEND=noninteractive \
|
||||
|
||||
# Install Python, system dependencies, and uv (as root)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
software-properties-common build-essential git curl \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
build-essential git curl \
|
||||
libglib2.0-0 libgl1 libegl1 ffmpeg \
|
||||
libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
|
||||
cmake pkg-config ninja-build \
|
||||
&& add-apt-repository -y ppa:deadsnakes/ppa \
|
||||
&& apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
python${PYTHON_VERSION} \
|
||||
python${PYTHON_VERSION}-venv \
|
||||
python${PYTHON_VERSION}-dev \
|
||||
python${PYTHON_VERSION} \
|
||||
python${PYTHON_VERSION}-venv \
|
||||
python${PYTHON_VERSION}-dev \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& mv /root/.local/bin/uv /usr/local/bin/uv \
|
||||
&& useradd --create-home --shell /bin/bash user_lerobot \
|
||||
|
||||
@@ -47,6 +47,8 @@
|
||||
title: π₀-FAST (Pi0Fast)
|
||||
- local: pi05
|
||||
title: π₀.₅ (Pi05)
|
||||
- local: eo1
|
||||
title: EO-1
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
|
||||
@@ -90,6 +90,6 @@ lerobot-record \
|
||||
--dataset.single_task="Your task description" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
--policy.path=${HF_USER}/act_policy
|
||||
```
|
||||
|
||||
@@ -194,7 +194,7 @@ lerobot-record \
|
||||
--dataset.single_task="Navigate around obstacles" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
|
||||
168
docs/source/eo1.mdx
Normal file
168
docs/source/eo1.mdx
Normal file
@@ -0,0 +1,168 @@
|
||||
# EO-1
|
||||
|
||||
EO-1 is a **Vision-Language-Action policy for robot control**. The LeRobot implementation integrates EO-1 with the standard LeRobot training, evaluation, processor interface.
|
||||
|
||||
## Model Overview
|
||||
|
||||
EO-1 uses a Qwen2.5-VL backbone for vision-language understanding and adds a continuous flow-matching action head for robot control. The policy formats each robot-control sample as a multimodal conversation: camera images are passed to Qwen2.5-VL, the robot state is represented with EO-1 state tokens, and the future action chunk is represented with EO-1 action tokens.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/HaomingSong/lerobot-documentation-images/resolve/main/lerobot/eo_pipeline.png"
|
||||
alt="An overview of EO-1"
|
||||
width="85%"
|
||||
/>
|
||||
|
||||
During training, EO-1 learns to denoise continuous action chunks at the action-token positions. During inference, it samples an action chunk, returns continuous actions, and executes `n_action_steps` from the chunk before sampling again.
|
||||
|
||||
### What the LeRobot Integration Covers
|
||||
|
||||
- Standard `policy.type=eo1` configuration through LeRobot
|
||||
- Qwen2.5-VL image and text preprocessing through policy processors
|
||||
- Continuous flow-matching action prediction
|
||||
- Checkpoint save/load through LeRobot policy APIs
|
||||
- Training with `lerobot-train` and evaluation with `lerobot-eval`
|
||||
|
||||
The broader EO-1 project also includes interleaved vision-text-action pretraining and multimodal reasoning workflows. This page focuses on the LeRobot robot-control policy path.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
1. Install LeRobot by following the [Installation Guide](./installation).
|
||||
2. Install EO-1 dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install -e ".[eo1]"
|
||||
```
|
||||
|
||||
3. If you want to train or evaluate on LIBERO, install the LIBERO dependencies too:
|
||||
|
||||
```bash
|
||||
pip install -e ".[eo1,libero]"
|
||||
```
|
||||
|
||||
EO-1 can use the standard PyTorch scaled-dot-product attention backend through `policy.attn_implementation=sdpa`. If your environment has a compatible `flash_attn` installation, you can request `policy.attn_implementation=flash_attention_2`.
|
||||
|
||||
## Data Requirements
|
||||
|
||||
EO-1 expects a LeRobot dataset with:
|
||||
|
||||
- At least one visual observation, for example `observation.images.image`
|
||||
- `observation.state`
|
||||
- `action`
|
||||
- A language task instruction through the dataset `task` field
|
||||
|
||||
If your dataset uses different observation names, use `rename_map` to align them with the names expected by your training or evaluation setup.
|
||||
|
||||
## Usage
|
||||
|
||||
To use EO-1 in a LeRobot configuration, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=eo1
|
||||
```
|
||||
|
||||
By default, a new EO-1 policy initializes its backbone from:
|
||||
|
||||
```python
|
||||
policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct
|
||||
```
|
||||
|
||||
Once a LeRobot-format EO-1 checkpoint is available, load it with:
|
||||
|
||||
```python
|
||||
policy.path=your-org/your-eo1-checkpoint
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Training Command Example
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_org/your_dataset \
|
||||
--policy.type=eo1 \
|
||||
--policy.vlm_base=Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.attn_implementation=sdpa \
|
||||
--policy.gradient_checkpointing=false \
|
||||
--output_dir=./outputs/eo1_training \
|
||||
--job_name=eo1_training \
|
||||
--steps=300000 \
|
||||
--batch_size=16 \
|
||||
--policy.device=cuda
|
||||
```
|
||||
|
||||
### Key Training Parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| -------------------------------------- | ----------------------------- | ----------------------------------------------------------------------- |
|
||||
| `policy.vlm_base` | `Qwen/Qwen2.5-VL-3B-Instruct` | Qwen2.5-VL checkpoint used to initialize a new policy |
|
||||
| `policy.dtype` | `auto` | Backbone dtype request: `auto`, `bfloat16`, or `float32` |
|
||||
| `policy.attn_implementation` | `None` | Optional Qwen attention backend, such as `sdpa` |
|
||||
| `policy.gradient_checkpointing` | `false` | Reduces memory usage during training |
|
||||
| `policy.chunk_size` | `8` | Number of future actions predicted per chunk |
|
||||
| `policy.n_action_steps` | `8` | Number of actions consumed from a sampled chunk |
|
||||
| `policy.num_denoise_steps` | `10` | Number of flow-matching denoising steps used during sampling |
|
||||
| `policy.max_state_dim` | `32` | State padding dimension |
|
||||
| `policy.max_action_dim` | `32` | Action padding dimension |
|
||||
| `policy.force_fp32_autocast` | `true` | Keeps the flow head in fp32 even when the backbone uses mixed precision |
|
||||
| `policy.supervise_padding_action_dims` | `true` | Controls whether padded action dimensions are supervised |
|
||||
| `policy.supervise_padding_actions` | `true` | Controls whether padded future action rows are supervised |
|
||||
|
||||
## Evaluation
|
||||
|
||||
EO-1 can be evaluated through `lerobot-eval` once you have a LeRobot-format checkpoint:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=your-org/your-eo1-checkpoint \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=20
|
||||
```
|
||||
|
||||
For datasets or environments whose camera names differ from the checkpoint configuration, pass a `rename_map`:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=your-org/your-eo1-checkpoint \
|
||||
--env.type=libero \
|
||||
--env.task=libero_object \
|
||||
--rename_map='{"observation.images.image2":"observation.images.wrist_image"}'
|
||||
```
|
||||
|
||||
## Configuration Notes
|
||||
|
||||
### Image Processing
|
||||
|
||||
EO-1 uses the Qwen2.5-VL processor. The `policy.image_min_pixels` and `policy.image_max_pixels` settings control the image resizing bounds before the visual tokens are passed into the backbone.
|
||||
|
||||
### State and Action Dimensions
|
||||
|
||||
The policy pads state and action vectors to `policy.max_state_dim` and `policy.max_action_dim` before the EO-1 flow head. Predictions are cropped back to the original action dimension before being returned by the policy.
|
||||
|
||||
### Attention Backend
|
||||
|
||||
Use `policy.attn_implementation=sdpa` for a portable setup. Use `flash_attention_2` only when `flash_attn` is installed and compatible with your environment.
|
||||
|
||||
## References
|
||||
|
||||
- [EO-1 project](https://github.com/EO-Robotics/EO1)
|
||||
- [EO-1 paper](https://arxiv.org/abs/2508.21112)
|
||||
- [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{eo1,
|
||||
title={EO-1: Interleaved Vision-Text-Action Pretraining for General Robot Control},
|
||||
author={Delin Qu and Haoming Song and Qizhi Chen and Zhaoqing Chen and Xianqiang Gao and Xinyi Ye and Qi Lv and Modi Shi and Guanghui Ren and Cheng Ruan and Maoqing Yao and Haoran Yang and Jiacheng Bao and Bin Zhao and Dong Wang},
|
||||
journal={arXiv preprint},
|
||||
year={2025},
|
||||
url={https://arxiv.org/abs/2508.21112}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This LeRobot integration follows the **Apache 2.0 License** used by LeRobot. Check the upstream EO-1 model and dataset pages for the licenses of released EO-1 checkpoints and data.
|
||||
@@ -123,7 +123,7 @@ lerobot-record \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
--policy.path=<user>/groot-bimanual \ # your trained model
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10
|
||||
|
||||
@@ -232,7 +232,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
@@ -278,6 +278,6 @@ lerobot-record \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
@@ -193,7 +193,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
</hfoption>
|
||||
|
||||
@@ -43,7 +43,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
|
||||
|
||||
@@ -161,7 +161,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
@@ -203,7 +203,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
|
||||
@@ -108,7 +108,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# --dataset.vcodec=auto \
|
||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/ttyACM0 \
|
||||
|
||||
@@ -14,22 +14,12 @@ This makes `save_episode()` near-instant (the video is already encoded by the ti
|
||||
|
||||
## 2. Tuning Parameters
|
||||
|
||||
All encoding parameters are grouped under `camera_encoder_config` (a `VideoEncoderConfig` dataclass), accessible from the CLI via `--dataset.camera_encoder_config.<field>`.
|
||||
|
||||
| Parameter | CLI Flag | Type | Default | Description |
|
||||
| ----------------------- | --------------------------------------------- | ------------- | ------------- | ------------------------------------------------------------------- |
|
||||
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
|
||||
| `vcodec` | `--dataset.camera_encoder_config.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
|
||||
| `pix_fmt` | `--dataset.camera_encoder_config.pix_fmt` | `str` | `"yuv420p"` | Pixel format |
|
||||
| `g` | `--dataset.camera_encoder_config.g` | `int \| None` | `2` | GOP size (keyframe interval) |
|
||||
| `crf` | `--dataset.camera_encoder_config.crf` | `int \| None` | `30` | Quality level (mapped to codec-specific parameter) |
|
||||
| `preset` | `--dataset.camera_encoder_config.preset` | `int \| None` | `12` | Speed preset (libsvtav1 only, 0 = slowest … 13 = fastest) |
|
||||
| `fast_decode` | `--dataset.camera_encoder_config.fast_decode` | `int` | `0` | Fast-decode tuning level |
|
||||
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance (global). `None` lets the codec decide |
|
||||
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
|
||||
|
||||
> [!TIP]
|
||||
> Not all parameters apply to every codec. `VideoEncoderConfig` will warn at startup if you set a parameter that your chosen codec ignores (e.g. `preset` with `h264_nvenc`).
|
||||
| Parameter | CLI Flag | Type | Default | Description |
|
||||
| ----------------------- | --------------------------------- | ------------- | ------------- | ----------------------------------------------------------------- |
|
||||
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
|
||||
| `vcodec` | `--dataset.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
|
||||
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance. `None` will leave the vcoded decide |
|
||||
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
|
||||
|
||||
## 3. Performance Considerations
|
||||
|
||||
@@ -50,7 +40,7 @@ Streaming encoding means the CPU is encoding video **during** the capture loop,
|
||||
|
||||
### `encoder_threads` Tuning
|
||||
|
||||
This parameter (`--dataset.encoder_threads`) controls how many threads each encoder instance uses internally:
|
||||
This parameter controls how many threads each encoder instance uses internally:
|
||||
|
||||
- **Higher values** (e.g., 4-5): Faster encoding, but uses more CPU cores per camera. Good for high-end systems with many cores.
|
||||
- **Lower values** (e.g., 1-2): Less CPU per camera, freeing cores for capture and visualization. Good for low-res images and capable CPUs.
|
||||
@@ -92,15 +82,15 @@ Use HW encoding when:
|
||||
|
||||
### Available HW Encoders
|
||||
|
||||
| Encoder | Platform | Hardware | CLI Value |
|
||||
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ---------------------------------------------------------- |
|
||||
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder_config.vcodec=h264_videotoolbox` |
|
||||
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder_config.vcodec=hevc_videotoolbox` |
|
||||
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder_config.vcodec=h264_nvenc` |
|
||||
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder_config.vcodec=hevc_nvenc` |
|
||||
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.camera_encoder_config.vcodec=h264_vaapi` |
|
||||
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.camera_encoder_config.vcodec=h264_qsv` |
|
||||
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.camera_encoder_config.vcodec=auto` |
|
||||
| Encoder | Platform | Hardware | CLI Value |
|
||||
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------ |
|
||||
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=h264_videotoolbox` |
|
||||
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=hevc_videotoolbox` |
|
||||
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=h264_nvenc` |
|
||||
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=hevc_nvenc` |
|
||||
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.vcodec=h264_vaapi` |
|
||||
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.vcodec=h264_qsv` |
|
||||
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.vcodec=auto` |
|
||||
|
||||
> [!NOTE]
|
||||
> In order to use the HW accelerated encoders you might need to upgrade your GPU drivers.
|
||||
@@ -110,15 +100,15 @@ Use HW encoding when:
|
||||
|
||||
## 5. Troubleshooting
|
||||
|
||||
| Symptom | Likely Cause | Fix |
|
||||
| ------------------------------------------------------------------ | -------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.camera_encoder_config.vcodec=auto`) |
|
||||
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.camera_encoder_config.vcodec=auto`). |
|
||||
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
|
||||
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
|
||||
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
|
||||
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.camera_encoder_config.vcodec=auto` |
|
||||
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
|
||||
| Symptom | Likely Cause | Fix |
|
||||
| ------------------------------------------------------------------ | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.vcodec=auto`) |
|
||||
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.vcodec=auto`). |
|
||||
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
|
||||
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
|
||||
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
|
||||
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.vcodec=auto` |
|
||||
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
|
||||
|
||||
## 6. Recommended Configurations
|
||||
|
||||
@@ -156,10 +146,10 @@ On very constrained systems, streaming encoding may compete too heavily with the
|
||||
# 2camsx 640x480x3 @30fps: Requires some tuning.
|
||||
|
||||
# Use H.264, disable streaming, consider batching encoding
|
||||
lerobot-record --dataset.camera_encoder_config.vcodec=h264 --dataset.streaming_encoding=false ...
|
||||
lerobot-record --dataset.vcodec=h264 --dataset.streaming_encoding=false ...
|
||||
```
|
||||
|
||||
## 7. Closing note
|
||||
|
||||
Performance ultimately depends on your exact setup — frames-per-second, resolution, CPU cores and load, available memory, episode length, and the encoder you choose. Always test with your target workload, be mindful about your CPU & system capabilities and tune `encoder_threads`, `encoder_queue_maxsize`, and
|
||||
`camera_encoder_config.vcodec` reasonably. That said, a common practical configuration (for many applications) is three cameras at 640×480x3 @30fps; this usually runs fine with the default streaming video encoding settings in modern systems. Always verify your recorded dataset is healthy by comparing the video duration to the CLI episode duration and confirming the row count equals FPS × CLI duration.
|
||||
`vcodec` reasonably. That said, a common practical configuration (for many applications) is three cameras at 640×480x3 @30fps; this usually runs fine with the default streaming video encoding settings in modern systems. Always verify your recorded dataset is healthy by comparing the video duration to the CLI episode duration and confirming the row count equals FPS × CLI duration.
|
||||
|
||||
@@ -117,10 +117,10 @@ lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.camera_encoder_config.vcodec libsvtav1 \
|
||||
--operation.camera_encoder_config.pix_fmt yuv420p \
|
||||
--operation.camera_encoder_config.g 2 \
|
||||
--operation.camera_encoder_config.crf 30
|
||||
--operation.vcodec libsvtav1 \
|
||||
--operation.pix_fmt yuv420p \
|
||||
--operation.g 2 \
|
||||
--operation.crf 30
|
||||
|
||||
# Convert only specific episodes
|
||||
lerobot-edit-dataset \
|
||||
@@ -147,14 +147,11 @@ lerobot-edit-dataset \
|
||||
**Parameters:**
|
||||
|
||||
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
|
||||
- `camera_encoder_config`: Video encoder settings — all sub-fields accessible via `--operation.camera_encoder_config.<field>`:
|
||||
- `vcodec`: Video codec — `h264`, `hevc`, `libsvtav1`, `auto`, or hardware codecs (default: `libsvtav1`)
|
||||
- `pix_fmt`: Pixel format — `yuv420p`, `yuv444p` (default: `yuv420p`)
|
||||
- `g`: GOP size — lower values give better quality but larger files (default: 2)
|
||||
- `crf`: Quality level — lower is better, 0 is lossless (default: 30)
|
||||
- `preset`: Speed preset, libsvtav1 only (default: 12)
|
||||
- `fast_decode`: Fast-decode tuning (default: 0)
|
||||
- `encoder_threads`: Threads per encoder instance — global setting, separate from `camera_encoder_config` (default: None)
|
||||
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
|
||||
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
|
||||
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
|
||||
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
|
||||
- `fast_decode`: Fast decode tuning option (default: 0)
|
||||
- `episode_indices`: List of specific episodes to convert (default: all episodes)
|
||||
- `num_workers`: Number of parallel workers for processing (default: 4)
|
||||
|
||||
|
||||
@@ -128,7 +128,7 @@ dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
|
||||
av-dep = ["av>=15.0.0,<16.0.0"]
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||
transformers-dep = ["transformers==5.3.0"] # TODO(Steven): https://github.com/huggingface/lerobot/pull/3249
|
||||
transformers-dep = ["transformers>=5.4.0,<5.6.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
@@ -194,6 +194,7 @@ groot = [
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.transforms import ImageTransformsConfig
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
from lerobot.utils.import_utils import get_safe_default_codec
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -34,7 +34,7 @@ class DatasetConfig:
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_video_backend)
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
|
||||
@@ -40,19 +40,10 @@ from .io_utils import load_episodes, write_stats
|
||||
from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from .pyav_utils import (
|
||||
check_video_encoder_config_pyav,
|
||||
detect_available_encoders_pyav,
|
||||
get_codec,
|
||||
)
|
||||
from .sampler import EpisodeAwareSampler
|
||||
from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
from .video_utils import (
|
||||
VideoEncoderConfig,
|
||||
VideoEncodingManager,
|
||||
camera_encoder_defaults,
|
||||
)
|
||||
from .video_utils import VideoEncodingManager
|
||||
|
||||
# NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.)
|
||||
# and legacy migration constants are intentionally NOT re-exported here.
|
||||
@@ -67,20 +58,15 @@ __all__ = [
|
||||
"LeRobotDatasetMetadata",
|
||||
"MultiLeRobotDataset",
|
||||
"StreamingLeRobotDataset",
|
||||
"VideoEncoderConfig",
|
||||
"VideoEncodingManager",
|
||||
"camera_encoder_defaults",
|
||||
"add_features",
|
||||
"aggregate_datasets",
|
||||
"aggregate_pipeline_dataset_features",
|
||||
"aggregate_stats",
|
||||
"check_video_encoder_config_pyav",
|
||||
"convert_image_to_video_dataset",
|
||||
"create_initial_features",
|
||||
"create_lerobot_dataset_card",
|
||||
"delete_episodes",
|
||||
"detect_available_encoders_pyav",
|
||||
"get_codec",
|
||||
"get_feature_stats",
|
||||
"load_episodes",
|
||||
"make_dataset",
|
||||
|
||||
@@ -332,6 +332,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
videos_idx: Dictionary tracking video chunk and file indices.
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
|
||||
Returns:
|
||||
dict: Updated videos_idx with current chunk and file indices.
|
||||
"""
|
||||
@@ -416,7 +417,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
concatenate_video_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
compatibility_check=True,
|
||||
)
|
||||
# Update duration of this destination file
|
||||
dst_file_durations[dst_key] = current_dst_duration + src_duration
|
||||
|
||||
@@ -48,7 +48,7 @@ from .utils import (
|
||||
is_valid_version,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from .video_utils import VideoEncoderConfig, get_video_info
|
||||
from .video_utils import get_video_info
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
|
||||
@@ -510,23 +510,10 @@ class LeRobotDatasetMetadata:
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
|
||||
write_stats(self.stats, self.root)
|
||||
|
||||
def update_video_info(
|
||||
self,
|
||||
video_key: str | None = None,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
) -> None:
|
||||
"""Populate per-feature video info in ``info.json``.
|
||||
|
||||
def update_video_info(self, video_key: str | None = None) -> None:
|
||||
"""
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
|
||||
Args:
|
||||
video_key: If provided, only update this video key. Otherwise update
|
||||
all video keys in the dataset.
|
||||
camera_encoder_config: Encoder configuration used to produce the
|
||||
videos. When provided, its fields are recorded as
|
||||
``video.<field>`` entries alongside the stream-derived
|
||||
``video.*`` entries (see :func:`get_video_info`).
|
||||
"""
|
||||
if video_key is not None and video_key not in self.video_keys:
|
||||
raise ValueError(f"Video key {video_key} not found in dataset")
|
||||
@@ -535,9 +522,7 @@ class LeRobotDatasetMetadata:
|
||||
for key in video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info.features[key]["info"] = get_video_info(
|
||||
video_path, camera_encoder_config=camera_encoder_config
|
||||
)
|
||||
self.info.features[key]["info"] = get_video_info(video_path)
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
|
||||
@@ -62,7 +62,7 @@ from .utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from .video_utils import VideoEncoderConfig, encode_video_frames, get_video_info
|
||||
from .video_utils import encode_video_frames, get_video_info
|
||||
|
||||
|
||||
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
|
||||
@@ -92,7 +92,6 @@ def delete_episodes(
|
||||
episode_indices: list[int],
|
||||
output_dir: str | Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Delete episodes from a LeRobotDataset and create a new dataset.
|
||||
|
||||
@@ -101,7 +100,6 @@ def delete_episodes(
|
||||
episode_indices: List of episode indices to delete.
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
camera_encoder_config: Video encoder settings used when re-encoding video segments (default: :class:`VideoEncoderConfig()`).
|
||||
"""
|
||||
if not episode_indices:
|
||||
raise ValueError("No episodes to delete")
|
||||
@@ -134,7 +132,7 @@ def delete_episodes(
|
||||
|
||||
video_metadata = None
|
||||
if dataset.meta.video_keys:
|
||||
video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping, camera_encoder_config)
|
||||
video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping)
|
||||
|
||||
data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping)
|
||||
|
||||
@@ -156,7 +154,6 @@ def split_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
splits: dict[str, float | list[int]],
|
||||
output_dir: str | Path | None = None,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
) -> dict[str, LeRobotDataset]:
|
||||
"""Split a LeRobotDataset into multiple smaller datasets.
|
||||
|
||||
@@ -165,7 +162,6 @@ def split_dataset(
|
||||
splits: Either a dict mapping split names to episode indices, or a dict mapping
|
||||
split names to fractions (must sum to <= 1.0).
|
||||
output_dir: Root directory where the split datasets will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
camera_encoder_config: Video encoder settings used when re-encoding video segments (default: :class:`VideoEncoderConfig()`).
|
||||
|
||||
Examples:
|
||||
Split by specific episodes
|
||||
@@ -226,9 +222,7 @@ def split_dataset(
|
||||
|
||||
video_metadata = None
|
||||
if dataset.meta.video_keys:
|
||||
video_metadata = _copy_and_reindex_videos(
|
||||
dataset, new_meta, episode_mapping, camera_encoder_config
|
||||
)
|
||||
video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping)
|
||||
|
||||
data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping)
|
||||
|
||||
@@ -584,7 +578,8 @@ def _keep_episodes_from_video_with_av(
|
||||
output_path: Path,
|
||||
episodes_to_keep: list[tuple[int, int]],
|
||||
fps: float,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
) -> None:
|
||||
"""Keep only specified episodes from a video file using PyAV.
|
||||
|
||||
@@ -598,10 +593,9 @@ def _keep_episodes_from_video_with_av(
|
||||
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
|
||||
is inclusive and end_frame is exclusive.
|
||||
fps: Frame rate of the video.
|
||||
camera_encoder_config: Video encoder settings (default: :class:`VideoEncoderConfig()`).
|
||||
vcodec: Video codec to use for encoding.
|
||||
pix_fmt: Pixel format for output video.
|
||||
"""
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
from fractions import Fraction
|
||||
|
||||
import av
|
||||
@@ -625,12 +619,12 @@ def _keep_episodes_from_video_with_av(
|
||||
|
||||
# Convert fps to Fraction for PyAV compatibility.
|
||||
fps_fraction = Fraction(fps).limit_denominator(1000)
|
||||
v_out = out.add_stream(camera_encoder_config.vcodec, rate=fps_fraction)
|
||||
v_out = out.add_stream(vcodec, rate=fps_fraction)
|
||||
|
||||
# PyAV type stubs don't distinguish video streams from audio/subtitle streams.
|
||||
v_out.width = v_in.codec_context.width
|
||||
v_out.height = v_in.codec_context.height
|
||||
v_out.pix_fmt = camera_encoder_config.pix_fmt
|
||||
v_out.pix_fmt = pix_fmt
|
||||
|
||||
# Set time_base to match the frame rate for proper timestamp handling.
|
||||
v_out.time_base = Fraction(1, int(fps))
|
||||
@@ -693,7 +687,8 @@ def _copy_and_reindex_videos(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_mapping: dict[int, int],
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
) -> dict[int, dict]:
|
||||
"""Copy and filter video files, only re-encoding files with deleted episodes.
|
||||
|
||||
@@ -705,13 +700,10 @@ def _copy_and_reindex_videos(
|
||||
src_dataset: Source dataset to copy from
|
||||
dst_meta: Destination metadata object
|
||||
episode_mapping: Mapping from old episode indices to new indices
|
||||
camera_encoder_config: Video encoder settings used when re-encoding segments (default: :class:`VideoEncoderConfig()`).
|
||||
|
||||
Returns:
|
||||
dict mapping episode index to its video metadata (chunk_index, file_index, timestamps)
|
||||
"""
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
if src_dataset.meta.episodes is None:
|
||||
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
|
||||
|
||||
@@ -800,7 +792,8 @@ def _copy_and_reindex_videos(
|
||||
dst_video_path,
|
||||
episodes_to_keep_ranges,
|
||||
src_dataset.meta.fps,
|
||||
camera_encoder_config,
|
||||
vcodec,
|
||||
pix_fmt,
|
||||
)
|
||||
|
||||
cumulative_ts = 0.0
|
||||
@@ -1271,7 +1264,11 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: list[int],
|
||||
temp_dir: Path,
|
||||
fps: int,
|
||||
camera_encoder_config: VideoEncoderConfig,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int,
|
||||
crf: int,
|
||||
fast_decode: int,
|
||||
num_calibration_frames: int = 30,
|
||||
) -> float:
|
||||
"""Estimate MB per frame by encoding a small calibration sample.
|
||||
@@ -1285,7 +1282,11 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: List of episode indices being processed.
|
||||
temp_dir: Temporary directory for calibration files.
|
||||
fps: Frames per second for video encoding.
|
||||
camera_encoder_config: Video encoder settings used for calibration encoding.
|
||||
vcodec: Video codec (libsvtav1, h264, hevc).
|
||||
pix_fmt: Pixel format (yuv420p, etc.).
|
||||
g: GOP size (group of pictures).
|
||||
crf: Constant Rate Factor (quality).
|
||||
fast_decode: Fast decode tuning parameter.
|
||||
num_calibration_frames: Number of frames to use for calibration (default: 30).
|
||||
|
||||
Returns:
|
||||
@@ -1321,7 +1322,11 @@ def _estimate_frame_size_via_calibration(
|
||||
imgs_dir=calibration_dir,
|
||||
video_path=calibration_video_path,
|
||||
fps=fps,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1639,7 +1644,11 @@ def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int = 2,
|
||||
crf: int = 30,
|
||||
fast_decode: int = 0,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
max_episodes_per_batch: int | None = None,
|
||||
@@ -1654,7 +1663,11 @@ def convert_image_to_video_dataset(
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
camera_encoder_config: Video encoder settings (default: :class:`VideoEncoderConfig()`).
|
||||
vcodec: Video codec (default: libsvtav1)
|
||||
pix_fmt: Pixel format (default: yuv420p)
|
||||
g: Group of pictures size (default: 2)
|
||||
crf: Constant rate factor (default: 30)
|
||||
fast_decode: Fast decode tuning (default: 0)
|
||||
episode_indices: List of episode indices to convert (None = all episodes)
|
||||
num_workers: Number of threads for parallel processing (default: 4)
|
||||
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
|
||||
@@ -1663,9 +1676,6 @@ def convert_image_to_video_dataset(
|
||||
Returns:
|
||||
New LeRobotDataset with images encoded as videos
|
||||
"""
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
@@ -1689,10 +1699,7 @@ def convert_image_to_video_dataset(
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(
|
||||
f"Video codec: {camera_encoder_config.vcodec}, pixel format: {camera_encoder_config.pix_fmt}, "
|
||||
f"GOP: {camera_encoder_config.g}, CRF: {camera_encoder_config.crf}"
|
||||
)
|
||||
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
@@ -1762,7 +1769,11 @@ def convert_image_to_video_dataset(
|
||||
episode_indices=episode_indices,
|
||||
temp_dir=temp_dir,
|
||||
fps=fps,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
)
|
||||
|
||||
logging.info(f"Processing camera: {img_key}")
|
||||
@@ -1804,7 +1815,11 @@ def convert_image_to_video_dataset(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1850,9 +1865,7 @@ def convert_image_to_video_dataset(
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(
|
||||
video_path, camera_encoder_config=camera_encoder_config
|
||||
)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(video_path)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
|
||||
@@ -52,7 +52,6 @@ from .utils import (
|
||||
)
|
||||
from .video_utils import (
|
||||
StreamingVideoEncoder,
|
||||
VideoEncoderConfig,
|
||||
concatenate_video_files,
|
||||
encode_video_frames,
|
||||
get_video_duration_in_s,
|
||||
@@ -66,19 +65,14 @@ def _encode_video_worker(
|
||||
episode_index: int,
|
||||
root: Path,
|
||||
fps: int,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
encoder_threads: int | None = None,
|
||||
) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(
|
||||
img_dir,
|
||||
temp_path,
|
||||
fps,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
|
||||
)
|
||||
shutil.rmtree(img_dir)
|
||||
return temp_path
|
||||
@@ -95,21 +89,20 @@ class DatasetWriter:
|
||||
self,
|
||||
meta: LeRobotDatasetMetadata,
|
||||
root: Path,
|
||||
camera_encoder_config: VideoEncoderConfig,
|
||||
vcodec: str,
|
||||
encoder_threads: int | None,
|
||||
batch_encoding_size: int,
|
||||
streaming_encoder: StreamingVideoEncoder | None = None,
|
||||
initial_frames: int = 0,
|
||||
):
|
||||
"""Initialize the writer with metadata, codec, and encoder config.
|
||||
"""Initialize the writer with metadata, codec, and encoding config.
|
||||
|
||||
Args:
|
||||
meta: Dataset metadata instance (used for feature schema, chunk
|
||||
settings, and episode persistence).
|
||||
root: Local dataset root directory.
|
||||
camera_encoder_config: Video encoder settings applied to all cameras.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
vcodec: Video codec for encoding (e.g. ``'libsvtav1'``, ``'h264'``).
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos.
|
||||
streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder`
|
||||
@@ -118,7 +111,7 @@ class DatasetWriter:
|
||||
"""
|
||||
self._meta = meta
|
||||
self._root = root
|
||||
self._camera_encoder_config = camera_encoder_config
|
||||
self._vcodec = vcodec
|
||||
self._encoder_threads = encoder_threads
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._streaming_encoder = streaming_encoder
|
||||
@@ -291,7 +284,7 @@ class DatasetWriter:
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._camera_encoder_config,
|
||||
self._vcodec,
|
||||
self._encoder_threads,
|
||||
): video_key
|
||||
for video_key in self._meta.video_keys
|
||||
@@ -502,7 +495,7 @@ class DatasetWriter:
|
||||
|
||||
# Update video info (only needed when first episode is encoded)
|
||||
if episode_index == 0:
|
||||
self._meta.update_video_info(video_key, camera_encoder_config=self._camera_encoder_config)
|
||||
self._meta.update_video_info(video_key)
|
||||
write_info(self._meta.info, self._meta.root)
|
||||
|
||||
metadata = {
|
||||
@@ -571,12 +564,7 @@ class DatasetWriter:
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
|
||||
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
|
||||
return _encode_video_worker(
|
||||
video_key,
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._camera_encoder_config,
|
||||
self._encoder_threads,
|
||||
video_key, episode_index, self._root, self._meta.fps, self._vcodec, self._encoder_threads
|
||||
)
|
||||
|
||||
def close_writer(self) -> None:
|
||||
|
||||
@@ -36,8 +36,8 @@ from .utils import (
|
||||
)
|
||||
from .video_utils import (
|
||||
StreamingVideoEncoder,
|
||||
VideoEncoderConfig,
|
||||
get_safe_default_video_backend,
|
||||
get_safe_default_codec,
|
||||
resolve_vcodec,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -58,10 +58,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
return_uint8: bool = False,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
"""
|
||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||
@@ -177,15 +177,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||
camera_encoder_config (VideoEncoderConfig | None, optional): Video encoder settings for cameras
|
||||
(codec, quality, etc.). Defaults to
|
||||
:class:`~lerobot.datasets.video_utils.VideoEncoderConfig` defaults when ``None``.
|
||||
encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the
|
||||
codec decide.
|
||||
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
|
||||
'libsvtav1', 'auto', or hardware-specific codecs like 'h264_videotoolbox', 'h264_nvenc'.
|
||||
Defaults to 'libsvtav1'. Use 'auto' to auto-detect the best available hardware encoder.
|
||||
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
|
||||
instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False.
|
||||
encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using
|
||||
streaming encoding. Defaults to 30 (~1s at 30fps).
|
||||
encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the
|
||||
codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for
|
||||
libsvtav1 and 'threads' for h264/hevc.
|
||||
|
||||
Note:
|
||||
Write-mode parameters (``streaming_encoding``, ``batch_encoding_size``) passed to
|
||||
@@ -201,12 +202,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
self._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self._return_uint8 = return_uint8
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
self._camera_encoder_config = camera_encoder_config
|
||||
self._vcodec = resolve_vcodec(vcodec)
|
||||
self._encoder_threads = encoder_threads
|
||||
|
||||
if self._requested_root is not None:
|
||||
@@ -252,16 +251,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(self.meta.video_keys) > 0:
|
||||
streaming_enc = self._build_streaming_encoder(
|
||||
self.meta.fps,
|
||||
self._camera_encoder_config,
|
||||
self._encoder_threads,
|
||||
encoder_queue_maxsize,
|
||||
self.meta.fps, self._vcodec, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
self.writer = DatasetWriter(
|
||||
meta=self.meta,
|
||||
root=self.root,
|
||||
camera_encoder_config=self._camera_encoder_config,
|
||||
encoder_threads=self._encoder_threads,
|
||||
vcodec=self._vcodec,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
initial_frames=self.meta.total_frames,
|
||||
@@ -302,15 +298,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
@staticmethod
|
||||
def _build_streaming_encoder(
|
||||
fps: int,
|
||||
camera_encoder_config: VideoEncoderConfig,
|
||||
encoder_threads: int | None,
|
||||
vcodec: str,
|
||||
encoder_queue_maxsize: int,
|
||||
encoder_threads: int | None,
|
||||
) -> StreamingVideoEncoder:
|
||||
return StreamingVideoEncoder(
|
||||
fps=fps,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
encoder_threads=encoder_threads,
|
||||
vcodec=vcodec,
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=None,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
encoder_threads=encoder_threads,
|
||||
)
|
||||
|
||||
# ── Metadata properties ───────────────────────────────────────────
|
||||
@@ -625,7 +625,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
metadata_buffer_size: int = 10,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -656,23 +656,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: Video decoding backend (used when reading back).
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos. ``1`` means encode immediately.
|
||||
camera_encoder_config: Video encoder settings for cameras; defaults
|
||||
match :class:`~lerobot.datasets.video_utils.VideoEncoderConfig`
|
||||
when ``None``.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
vcodec: Video codec for encoding. Options include ``'libsvtav1'``,
|
||||
``'h264'``, ``'hevc'``, ``'auto'``.
|
||||
metadata_buffer_size: Number of episode metadata records to buffer
|
||||
before flushing to parquet.
|
||||
streaming_encoding: If ``True``, encode video frames in real-time
|
||||
during capture instead of writing images first.
|
||||
encoder_queue_maxsize: Max buffered frames per camera when using
|
||||
streaming encoding.
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
|
||||
Returns:
|
||||
A new :class:`LeRobotDataset` in write mode.
|
||||
"""
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
obj = cls.__new__(cls)
|
||||
obj.meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
@@ -693,24 +690,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend()
|
||||
obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
obj._return_uint8 = False
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._camera_encoder_config = camera_encoder_config
|
||||
obj._vcodec = vcodec
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
# Reader is lazily created on first access (write-only mode)
|
||||
obj.reader = None
|
||||
|
||||
# Create writer
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
fps, camera_encoder_config, encoder_threads, encoder_queue_maxsize
|
||||
)
|
||||
streaming_enc = cls._build_streaming_encoder(fps, vcodec, encoder_queue_maxsize, encoder_threads)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
vcodec=vcodec,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -733,12 +729,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
force_cache_sync: bool = False,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
) -> "LeRobotDataset":
|
||||
"""Resume recording on an existing dataset.
|
||||
|
||||
@@ -761,16 +757,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: Video decoding backend for reading back data.
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos.
|
||||
camera_encoder_config: Video encoder settings for cameras; defaults
|
||||
match :class:`~lerobot.datasets.video_utils.VideoEncoderConfig`
|
||||
when ``None``.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
vcodec: Video codec for encoding.
|
||||
image_writer_processes: Subprocesses for async image writing.
|
||||
image_writer_threads: Threads for async image writing.
|
||||
streaming_encoding: If ``True``, encode video in real-time during
|
||||
capture.
|
||||
encoder_queue_maxsize: Max buffered frames per camera for streaming.
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
|
||||
Returns:
|
||||
A :class:`LeRobotDataset` in write mode, ready to append episodes.
|
||||
@@ -781,6 +774,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"Writing into the revision-safe Hub snapshot cache (used when root=None) would corrupt "
|
||||
"the shared cache. Please provide a local directory path."
|
||||
)
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj._requested_root = Path(root)
|
||||
@@ -789,9 +783,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
obj._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
obj._return_uint8 = False
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._vcodec = vcodec
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
if obj._requested_root is not None:
|
||||
obj._requested_root.mkdir(exist_ok=True, parents=True)
|
||||
@@ -800,25 +796,21 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.meta = LeRobotDatasetMetadata(
|
||||
obj.repo_id, obj._requested_root, obj.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
obj._camera_encoder_config = camera_encoder_config
|
||||
obj._encoder_threads = encoder_threads
|
||||
obj.root = obj.meta.root
|
||||
|
||||
# Reader is lazily created on first access (write-only mode)
|
||||
obj.reader = None
|
||||
|
||||
# Create writer for appending
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
obj.meta.fps, camera_encoder_config, encoder_threads, encoder_queue_maxsize
|
||||
obj.meta.fps, vcodec, encoder_queue_maxsize, encoder_threads
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
vcodec=vcodec,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
|
||||
@@ -1,186 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyAV-based compatibility checks for :class:`VideoEncoderConfig`.
|
||||
|
||||
Centralises all :mod:`av` introspection of the bundled FFmpeg build.
|
||||
Checks degrade to a no-op when the target codec isn't available locally.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import av
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.datasets.video_utils import VideoEncoderConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
|
||||
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_codec(vcodec: str) -> av.codec.Codec | None:
|
||||
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
|
||||
try:
|
||||
return av.codec.Codec(vcodec, "w")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _get_codec_options_by_name(vcodec: str) -> dict[str, av.option.Option]:
|
||||
"""Private-option name → PyAV ``Option`` for *vcodec* (empty if unavailable)."""
|
||||
codec = get_codec(vcodec)
|
||||
if codec is None:
|
||||
return {}
|
||||
return {opt.name: opt for opt in codec.descriptor.options}
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _get_codec_video_formats(vcodec: str) -> tuple[str, ...]:
|
||||
"""Pixel formats accepted by *vcodec* in PyAV's preferred order (empty if unknown)."""
|
||||
codec = get_codec(vcodec)
|
||||
if codec is None:
|
||||
return ()
|
||||
return tuple(fmt.name for fmt in (codec.video_formats or []))
|
||||
|
||||
|
||||
def detect_available_encoders_pyav(encoders: list[str] | str) -> list[str]:
|
||||
"""Return the subset of *encoders* available as video encoders in the local FFmpeg build.
|
||||
|
||||
Each name is probed directly via :func:`get_codec`; input order is preserved.
|
||||
"""
|
||||
if isinstance(encoders, str):
|
||||
encoders = [encoders]
|
||||
|
||||
available: list[str] = []
|
||||
for name in encoders:
|
||||
codec = get_codec(name)
|
||||
if codec is not None and codec.type == "video":
|
||||
available.append(name)
|
||||
else:
|
||||
logger.debug("encoder '%s' not available as video encoder", name)
|
||||
return available
|
||||
|
||||
|
||||
def _check_option_value(vcodec: str, label: str, value: Any, opt: av.option.Option) -> None:
|
||||
"""Range-check numeric *value* and choice-check string *value* against *opt*."""
|
||||
type_name = opt.type.name
|
||||
if type_name in FFMPEG_NUMERIC_OPTION_TYPES:
|
||||
if isinstance(value, bool):
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
)
|
||||
elif isinstance(value, str):
|
||||
try:
|
||||
num_val = float(value)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
) from e
|
||||
elif isinstance(value, (float, int)):
|
||||
num_val = value
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
)
|
||||
|
||||
# Check integer type compatibility
|
||||
if type_name in FFMPEG_INTEGER_OPTION_TYPES and not num_val.is_integer():
|
||||
raise ValueError(
|
||||
f"{label}={num_val!r} must be an integer for codec {vcodec!r} "
|
||||
f"(FFmpeg option {opt.name!r} is {type_name}); float values are not allowed."
|
||||
)
|
||||
|
||||
# Check numeric range compatibility
|
||||
lo, hi = float(opt.min), float(opt.max)
|
||||
if lo < hi and not (lo <= num_val <= hi):
|
||||
raise ValueError(
|
||||
f"{label}={num_val} is out of range for codec {vcodec!r}; must be in [{lo}, {hi}]"
|
||||
)
|
||||
|
||||
elif type_name == "STRING":
|
||||
if isinstance(value, bool):
|
||||
raise ValueError(f"{label}={value!r} is not a valid string value for codec {vcodec!r}.")
|
||||
if isinstance(value, str):
|
||||
str_val = value
|
||||
elif isinstance(value, (int, float)):
|
||||
str_val = str(value)
|
||||
else:
|
||||
raise ValueError(f"{label}={value!r} has unsupported type for STRING option on codec {vcodec!r}")
|
||||
|
||||
# Check string choice compatibility
|
||||
choices = [c.name for c in (opt.choices or [])]
|
||||
if choices and str_val not in choices:
|
||||
raise ValueError(
|
||||
f"{label}={str_val!r} is not a supported choice for codec "
|
||||
f"{vcodec!r}; valid choices: {choices}"
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
|
||||
formats = _get_codec_video_formats(vcodec)
|
||||
if formats and pix_fmt not in formats:
|
||||
raise ValueError(
|
||||
f"pix_fmt={pix_fmt!r} is not supported by codec {vcodec!r}; "
|
||||
f"supported pixel formats: {list(formats)}"
|
||||
)
|
||||
|
||||
|
||||
def _check_codec_options(vcodec: str, codec_options: dict[str, Any], config: VideoEncoderConfig) -> None:
|
||||
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
|
||||
supported_options = _get_codec_options_by_name(vcodec)
|
||||
for key, value in codec_options.items():
|
||||
# GOP size is not a codec-specific option, it has to be validated separately.
|
||||
if key == "g":
|
||||
if isinstance(value, bool) or not isinstance(value, int) or value < 1:
|
||||
raise ValueError(f"g={value!r} must be a positive integer for codec {vcodec!r}")
|
||||
continue
|
||||
if key not in supported_options:
|
||||
continue
|
||||
opt = supported_options[key]
|
||||
label = f"extra_options[{key!r}]" if key in config.extra_options else key
|
||||
_check_option_value(vcodec, label, value, opt)
|
||||
|
||||
|
||||
def check_video_encoder_config_pyav(config: VideoEncoderConfig) -> None:
|
||||
"""Verify *config* is compatible with the bundled FFmpeg build.
|
||||
|
||||
Checks pixel format, abstract tuning-field compatibility, and each merged
|
||||
encoder option from :meth:`~lerobot.datasets.video_utils.VideoEncoderConfig.get_codec_options`
|
||||
against PyAV (including numeric ``extra_options`` present in that dict).
|
||||
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
|
||||
|
||||
Raises:
|
||||
ValueError: on the first incompatibility encountered.
|
||||
"""
|
||||
vcodec = config.vcodec
|
||||
options = _get_codec_options_by_name(vcodec)
|
||||
if not options:
|
||||
logger.warning(
|
||||
"Codec %r is not available in the bundled FFmpeg build; ",
|
||||
vcodec,
|
||||
)
|
||||
return
|
||||
_check_pixel_format(config.vcodec, config.pix_fmt)
|
||||
_check_codec_options(config.vcodec, config.get_codec_options(), config)
|
||||
@@ -22,7 +22,7 @@ import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import warnings
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from dataclasses import dataclass, field
|
||||
from fractions import Fraction
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
@@ -37,11 +37,7 @@ import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.datasets.pyav_utils import (
|
||||
check_video_encoder_config_pyav,
|
||||
detect_available_encoders_pyav,
|
||||
)
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
from lerobot.utils.import_utils import get_safe_default_codec
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -58,154 +54,68 @@ HW_ENCODERS = [
|
||||
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS)
|
||||
|
||||
LIBSVTAV1_DEFAULT_PRESET: int = 12
|
||||
|
||||
def _get_codec_options(
|
||||
vcodec: str,
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
preset: int | None = None,
|
||||
) -> dict:
|
||||
"""Build codec-specific options dict for video encoding."""
|
||||
options = {}
|
||||
|
||||
# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
|
||||
if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS):
|
||||
options["g"] = str(g)
|
||||
|
||||
# Quality control (codec-specific parameter names)
|
||||
if crf is not None:
|
||||
if vcodec in ("h264", "hevc", "libsvtav1"):
|
||||
options["crf"] = str(crf)
|
||||
elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
|
||||
quality = max(1, min(100, int(100 - crf * 2)))
|
||||
options["q:v"] = str(quality)
|
||||
elif vcodec in ("h264_nvenc", "hevc_nvenc"):
|
||||
options["rc"] = "constqp"
|
||||
options["qp"] = str(crf)
|
||||
elif vcodec in ("h264_vaapi",):
|
||||
options["qp"] = str(crf)
|
||||
elif vcodec in ("h264_qsv",):
|
||||
options["global_quality"] = str(crf)
|
||||
|
||||
# Preset (only for libsvtav1)
|
||||
if vcodec == "libsvtav1":
|
||||
options["preset"] = str(preset) if preset is not None else "12"
|
||||
|
||||
return options
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoEncoderConfig:
|
||||
"""Video encoder configuration.
|
||||
|
||||
Attributes:
|
||||
vcodec: FFmpeg encoder name. ``"auto"`` is resolved during
|
||||
construction (HW encoder if available, else ``libsvtav1``).
|
||||
pix_fmt: Pixel format (e.g. ``"yuv420p"``).
|
||||
g: GOP size (keyframe interval).
|
||||
crf: Quality level — mapped to the native quality parameter of the
|
||||
codec (``crf`` for software, ``qp`` for NVENC/VAAPI,
|
||||
``q:v`` for VideoToolbox, ``global_quality`` for QSV).
|
||||
preset: Speed/quality preset. Accepted type is per-codec.
|
||||
fast_decode: Fast-decode tuning. For ``libsvtav1`` this is a level (0-2)
|
||||
embedded in ``svtav1-params``. For ``h264`` and ``hevc`` non-zero values
|
||||
set ``tune=fastdecode``. Ignored for other codecs.
|
||||
video_backend: Python library driving FFmpeg for encoding. Only ``"pyav"``
|
||||
is currently supported.
|
||||
extra_options: Free-form dictionary of additional FFmpeg options
|
||||
(e.g. ``{"tune": "film", "profile:v": "high", "bf": 2}``).
|
||||
"""
|
||||
|
||||
vcodec: str = "libsvtav1"
|
||||
pix_fmt: str = "yuv420p"
|
||||
g: int | None = 2
|
||||
crf: int | None = 30
|
||||
preset: int | str | None = None
|
||||
fast_decode: int = 0
|
||||
# TODO(CarolinePascal): add torchcodec support + find a way to unify the
|
||||
# two backends (encoding and decoding).
|
||||
video_backend: str = "pyav"
|
||||
extra_options: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.resolve_vcodec()
|
||||
|
||||
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
|
||||
if self.preset is None and self.vcodec == "libsvtav1":
|
||||
self.preset = LIBSVTAV1_DEFAULT_PRESET
|
||||
|
||||
self.validate()
|
||||
|
||||
def detect_available_encoders(self, encoders: list[str] | str) -> list[str]:
|
||||
"""Detect available encoders based on the video backend."""
|
||||
if self.video_backend == "pyav":
|
||||
return detect_available_encoders_pyav(encoders)
|
||||
else:
|
||||
return []
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate the video encoder config."""
|
||||
if self.video_backend == "pyav":
|
||||
check_video_encoder_config_pyav(self)
|
||||
|
||||
def resolve_vcodec(self) -> None:
|
||||
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1.
|
||||
|
||||
Any explicitly-requested codec that isn't in the local FFmpeg build is
|
||||
also silently rewritten to ``libsvtav1`` so encoding never hard-fails on
|
||||
a host missing the requested encoder.
|
||||
"""
|
||||
if self.vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{self.vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if self.vcodec == "auto":
|
||||
available = self.detect_available_encoders(HW_ENCODERS)
|
||||
for encoder in HW_ENCODERS:
|
||||
if encoder in available:
|
||||
logger.info(f"Auto-selected video codec: {encoder}")
|
||||
self.vcodec = encoder
|
||||
return
|
||||
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
self.vcodec = "libsvtav1"
|
||||
|
||||
if self.detect_available_encoders(self.vcodec):
|
||||
logger.info(f"Using video codec: {self.vcodec}")
|
||||
self.vcodec = self.vcodec
|
||||
return
|
||||
raise ValueError(f"Unsupported video codec: {self.vcodec} with video backend {self.video_backend}")
|
||||
|
||||
def get_codec_options(
|
||||
self, encoder_threads: int | None = None, as_strings: bool = False
|
||||
) -> dict[str, str]:
|
||||
"""Translate the tuning fields to codec-specific FFmpeg options.
|
||||
|
||||
``VideoEncoderConfig.extra_options`` are merged last but never override a structured field.
|
||||
|
||||
Args:
|
||||
encoder_threads: Number of encoder threads set globally for all VideoEncoderConfigs.
|
||||
For libsvtav1, this is mapped to ``lp`` via ``svtav1-params``.
|
||||
For h264/hevc, this is mapped to ``threads``.
|
||||
Hardware encoders ignore this parameter.
|
||||
as_strings: If ``True``, casts values to strings.
|
||||
"""
|
||||
opts: dict[str, Any] = {}
|
||||
|
||||
def set_if(key: str, value: Any) -> None:
|
||||
if value is not None:
|
||||
opts[key] = value if not as_strings else str(value)
|
||||
|
||||
# GOP size is not a codec-specific option, so it is always set.
|
||||
set_if("g", self.g)
|
||||
|
||||
if self.vcodec == "libsvtav1":
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
svtav1_parts: list[str] = []
|
||||
if self.fast_decode is not None:
|
||||
svtav1_parts.append(f"fast-decode={max(0, min(2, self.fast_decode))}")
|
||||
if encoder_threads is not None:
|
||||
svtav1_parts.append(f"lp={encoder_threads}")
|
||||
if svtav1_parts:
|
||||
opts["svtav1-params"] = ":".join(svtav1_parts)
|
||||
elif self.vcodec in ("h264", "hevc"):
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
if self.fast_decode:
|
||||
opts["tune"] = "fastdecode"
|
||||
set_if("threads", encoder_threads)
|
||||
elif self.vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
|
||||
if self.crf is not None:
|
||||
opts["q:v"] = max(1, min(100, 100 - self.crf * 2))
|
||||
elif self.vcodec in ("h264_nvenc", "hevc_nvenc"):
|
||||
opts["rc"] = "constqp"
|
||||
set_if("qp", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
elif self.vcodec == "h264_vaapi":
|
||||
set_if("qp", self.crf)
|
||||
elif self.vcodec == "h264_qsv":
|
||||
set_if("global_quality", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
else:
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
|
||||
# Extra options are merged last but never override structured fields (values are kept as given).
|
||||
for k, v in self.extra_options.items():
|
||||
if k not in opts:
|
||||
set_if(k, v)
|
||||
|
||||
return opts
|
||||
def detect_available_hw_encoders() -> list[str]:
|
||||
"""Probe PyAV/FFmpeg for available hardware video encoders."""
|
||||
available = []
|
||||
for codec_name in HW_ENCODERS:
|
||||
try:
|
||||
av.codec.Codec(codec_name, "w")
|
||||
available.append(codec_name)
|
||||
except Exception: # nosec B110
|
||||
logger.debug("HW encoder '%s' not available", codec_name) # nosec B110
|
||||
return available
|
||||
|
||||
|
||||
def camera_encoder_defaults() -> VideoEncoderConfig:
|
||||
"""Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
|
||||
return VideoEncoderConfig()
|
||||
def resolve_vcodec(vcodec: str) -> str:
|
||||
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1."""
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if vcodec != "auto":
|
||||
logger.info(f"Using video codec: {vcodec}")
|
||||
return vcodec
|
||||
available = detect_available_hw_encoders()
|
||||
for encoder in HW_ENCODERS:
|
||||
if encoder in available:
|
||||
logger.info(f"Auto-selected video codec: {encoder}")
|
||||
return encoder
|
||||
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
return "libsvtav1"
|
||||
|
||||
|
||||
def decode_video_frames(
|
||||
@@ -232,7 +142,7 @@ def decode_video_frames(
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend is None:
|
||||
backend = get_safe_default_video_backend()
|
||||
backend = get_safe_default_codec()
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
elif backend in ["pyav", "video_reader"]:
|
||||
@@ -490,17 +400,18 @@ def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
fps: int,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
*,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
fast_decode: int = 0,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
preset: int | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
vcodec = camera_encoder_config.vcodec
|
||||
pix_fmt = camera_encoder_config.pix_fmt
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
|
||||
video_path = Path(video_path)
|
||||
imgs_dir = Path(imgs_dir)
|
||||
@@ -511,18 +422,42 @@ def encode_video_frames(
|
||||
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Encoders/pixel formats incompatibility check
|
||||
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
|
||||
logger.warning(
|
||||
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
|
||||
)
|
||||
pix_fmt = "yuv420p"
|
||||
|
||||
# Get input frames
|
||||
template = "frame-" + ("[0-9]" * 6) + ".png"
|
||||
input_list = sorted(
|
||||
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
|
||||
)
|
||||
|
||||
# Define video output frame size (assuming all input frames are the same size)
|
||||
if len(input_list) == 0:
|
||||
raise FileNotFoundError(f"No images found in {imgs_dir}.")
|
||||
with Image.open(input_list[0]) as dummy_image:
|
||||
width, height = dummy_image.size
|
||||
|
||||
video_options = camera_encoder_config.get_codec_options(encoder_threads, as_strings=True)
|
||||
# Define video codec options
|
||||
video_options = _get_codec_options(vcodec, g, crf, preset)
|
||||
|
||||
if fast_decode:
|
||||
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
video_options[key] = value
|
||||
|
||||
if encoder_threads is not None:
|
||||
if vcodec == "libsvtav1":
|
||||
lp_param = f"lp={encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(encoder_threads)
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
@@ -559,10 +494,7 @@ def encode_video_frames(
|
||||
|
||||
|
||||
def concatenate_video_files(
|
||||
input_video_paths: list[Path | str],
|
||||
output_video_path: Path,
|
||||
overwrite: bool = True,
|
||||
compatibility_check: bool = False,
|
||||
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
|
||||
):
|
||||
"""
|
||||
Concatenate multiple video files into a single video file using pyav.
|
||||
@@ -575,7 +507,6 @@ def concatenate_video_files(
|
||||
input_video_paths: Ordered list of input video file paths to concatenate.
|
||||
output_video_path: Path to the output video file.
|
||||
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
|
||||
compatibility_check: Whether to check if the input videos are compatible. Default is False.
|
||||
|
||||
Note:
|
||||
- Creates a temporary directory for intermediate files that is cleaned up after use.
|
||||
@@ -594,22 +525,6 @@ def concatenate_video_files(
|
||||
if len(input_video_paths) == 0:
|
||||
raise FileNotFoundError("No input video paths provided.")
|
||||
|
||||
# This check may be skipped at recording time as videos are encoded with the same encoder config.
|
||||
if compatibility_check:
|
||||
reference_video_info = get_video_info(input_video_paths[0])
|
||||
for input_path in input_video_paths[1:]:
|
||||
video_info = get_video_info(input_path)
|
||||
if (
|
||||
video_info["video.height"] != reference_video_info["video.height"]
|
||||
or video_info["video.width"] != reference_video_info["video.width"]
|
||||
or video_info["video.fps"] != reference_video_info["video.fps"]
|
||||
or video_info["video.codec"] != reference_video_info["video.codec"]
|
||||
or video_info["video.pix_fmt"] != reference_video_info["video.pix_fmt"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"Input video {input_path} is not compatible with the reference video {input_video_paths[0]}."
|
||||
)
|
||||
|
||||
# Create a temporary .ffconcat file to list the input video paths
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
|
||||
tmp_concatenate_file.write("ffconcat version 1.0\n")
|
||||
@@ -676,20 +591,26 @@ class _CameraEncoderThread(threading.Thread):
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
codec_options: dict[str, str],
|
||||
g: int | None,
|
||||
crf: int | None,
|
||||
preset: int | None,
|
||||
frame_queue: queue.Queue,
|
||||
result_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
super().__init__(daemon=True)
|
||||
self.video_path = video_path
|
||||
self.fps = fps
|
||||
self.vcodec = vcodec
|
||||
self.pix_fmt = pix_fmt
|
||||
self.codec_options = codec_options
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self.frame_queue = frame_queue
|
||||
self.result_queue = result_queue
|
||||
self.stop_event = stop_event
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
def run(self) -> None:
|
||||
from .compute_stats import RunningQuantileStats, auto_downsample_height_width
|
||||
@@ -725,9 +646,19 @@ class _CameraEncoderThread(threading.Thread):
|
||||
# Open container on first frame (to get width/height)
|
||||
if container is None:
|
||||
height, width = frame_data.shape[:2]
|
||||
video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset)
|
||||
if self.encoder_threads is not None:
|
||||
if self.vcodec == "libsvtav1":
|
||||
lp_param = f"lp={self.encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(self.encoder_threads)
|
||||
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
container = av.open(str(self.video_path), "w")
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=video_options)
|
||||
output_stream.pix_fmt = self.pix_fmt
|
||||
output_stream.width = width
|
||||
output_stream.height = height
|
||||
@@ -793,25 +724,22 @@ class StreamingVideoEncoder:
|
||||
def __init__(
|
||||
self,
|
||||
fps: int,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
*,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
preset: int | None = None,
|
||||
queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
fps: Frames per second for the output videos.
|
||||
camera_encoder_config: Video encoder settings applied to all cameras.
|
||||
When ``None``, :class:`VideoEncoderConfig` defaults are used.
|
||||
encoder_threads: Number of encoder threads (global setting).
|
||||
``None`` lets the codec decide.
|
||||
queue_maxsize: Max frames to buffer per camera before
|
||||
back-pressure drops frames.
|
||||
"""
|
||||
self.fps = fps
|
||||
self._camera_encoder_config = camera_encoder_config or VideoEncoderConfig()
|
||||
self._encoder_threads = encoder_threads
|
||||
self.vcodec = resolve_vcodec(vcodec)
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self.queue_maxsize = queue_maxsize
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
self._frame_queues: dict[str, queue.Queue] = {}
|
||||
self._result_queues: dict[str, queue.Queue] = {}
|
||||
@@ -842,19 +770,18 @@ class StreamingVideoEncoder:
|
||||
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
|
||||
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
||||
|
||||
vcodec = self._camera_encoder_config.vcodec
|
||||
codec_options = self._camera_encoder_config.get_codec_options(
|
||||
self._encoder_threads, as_strings=True
|
||||
)
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=self.fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=self._camera_encoder_config.pix_fmt,
|
||||
codec_options=codec_options,
|
||||
vcodec=self.vcodec,
|
||||
pix_fmt=self.pix_fmt,
|
||||
g=self.g,
|
||||
crf=self.crf,
|
||||
preset=self.preset,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
encoder_threads=self.encoder_threads,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
@@ -1059,18 +986,8 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
return audio_info
|
||||
|
||||
|
||||
def get_video_info(
|
||||
video_path: Path | str,
|
||||
camera_encoder_config: "VideoEncoderConfig | None" = None,
|
||||
) -> dict:
|
||||
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
|
||||
|
||||
Args:
|
||||
video_path: Path to the encoded video file to probe.
|
||||
camera_encoder_config: If provided, record the exact encoder settings used to encode this
|
||||
video. Stream-derived values take precedence — encoder fields are only written for keys
|
||||
not already populated from the video file itself.
|
||||
"""
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
|
||||
# Getting video stream information
|
||||
@@ -1101,11 +1018,6 @@ def get_video_info(
|
||||
# Adding audio stream information
|
||||
video_info.update(**get_audio_info(video_path))
|
||||
|
||||
# Add additional encoder configuration if provided
|
||||
if camera_encoder_config is not None:
|
||||
for field_name, field_value in asdict(camera_encoder_config).items():
|
||||
video_info.setdefault(f"video.{field_name}", field_value)
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from lerobot.utils.action_interpolator import ActionInterpolator as ActionInterp
|
||||
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config as EO1Config
|
||||
from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||
@@ -41,6 +42,7 @@ __all__ = [
|
||||
"DiffusionConfig",
|
||||
"GrootConfig",
|
||||
"MultiTaskDiTConfig",
|
||||
"EO1Config",
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
"PI05Config",
|
||||
|
||||
1
src/lerobot/policies/eo1/README.md
Symbolic link
1
src/lerobot/policies/eo1/README.md
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../docs/source/eo1.mdx
|
||||
7
src/lerobot/policies/eo1/__init__.py
Normal file
7
src/lerobot/policies/eo1/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from .configuration_eo1 import EO1Config
|
||||
from .modeling_eo1 import EO1Policy
|
||||
from .processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
__all__ = ["EO1Config", "EO1Policy", "make_eo1_pre_post_processors"]
|
||||
193
src/lerobot/policies/eo1/configuration_eo1.py
Normal file
193
src/lerobot/policies/eo1/configuration_eo1.py
Normal file
@@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
||||
Qwen2_5_VLConfig,
|
||||
Qwen2_5_VLTextConfig,
|
||||
Qwen2_5_VLVisionConfig,
|
||||
)
|
||||
else:
|
||||
Qwen2_5_VLConfig = None
|
||||
Qwen2_5_VLTextConfig = None
|
||||
Qwen2_5_VLVisionConfig = None
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("eo1")
|
||||
@dataclass
|
||||
class EO1Config(PreTrainedConfig):
|
||||
"""Configuration for native EO1 policy integration in LeRobot."""
|
||||
|
||||
vlm_base: str = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
vlm_config: dict | None = None
|
||||
|
||||
# Vision processor settings.
|
||||
image_min_pixels: int | None = 64 * 28 * 28
|
||||
image_max_pixels: int | None = 128 * 28 * 28
|
||||
use_fast_processor: bool = False
|
||||
|
||||
# Execution and action horizon.
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 8
|
||||
n_action_steps: int = 8
|
||||
|
||||
# State/action padding to match EO1 flow head dimensionality.
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 32
|
||||
|
||||
# Flow matching sampling.
|
||||
num_denoise_steps: int = 10
|
||||
num_action_layers: int = 2
|
||||
action_act: str = "linear"
|
||||
time_sampling_beta_alpha: float = 1.5
|
||||
time_sampling_beta_beta: float = 1.0
|
||||
time_sampling_scale: float = 0.999
|
||||
time_sampling_offset: float = 0.001
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
supervise_padding_action_dims: bool = True
|
||||
supervise_padding_actions: bool = True
|
||||
|
||||
# Policy-level dtype request for the Qwen backbone.
|
||||
# - "auto": follow the backbone config/checkpoint default dtype. For Qwen2.5-VL this resolves to bf16.
|
||||
# The EO1 flow-matching head still keeps its own parameters in fp32.
|
||||
# - "bfloat16": force the backbone to initialize/load in bf16 regardless of the saved config default.
|
||||
# - "float32": force the backbone to initialize/load in fp32 for maximum numerical conservatism.
|
||||
dtype: str = "auto" # Options: "auto", "bfloat16", "float32"
|
||||
force_fp32_autocast: bool = True
|
||||
|
||||
# Optional attention backend request passed through to the Qwen backbone.
|
||||
# Common values: None, "eager", "sdpa", "flash_attention_2".
|
||||
attn_implementation: str | None = None
|
||||
|
||||
# Training settings.
|
||||
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
)
|
||||
|
||||
# Optimizer settings aligned with EO1/experiments/2_libero/train.sh and EO1 TrainPipelineConfig defaults.
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.1
|
||||
optimizer_grad_clip_norm: float = 1.0
|
||||
|
||||
# Scheduler settings aligned with EO1 train.sh: cosine schedule with warmup_ratio=0.03.
|
||||
# Note: These will auto-scale if --steps < scheduler_decay_steps
|
||||
# For example, --steps=3000 will scale warmup to 100 and decay to 3000
|
||||
scheduler_warmup_steps: int = 900 # 0.03 * 30_000 long-run steps
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 0.0
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
# Populate the serialized backbone config only when the caller did not provide one.
|
||||
if self.vlm_config is None:
|
||||
require_package("transformers", extra="eo1")
|
||||
self.vlm_config = Qwen2_5_VLConfig.from_pretrained(self.vlm_base).to_dict()
|
||||
|
||||
@property
|
||||
def vlm_backbone_config(self) -> Qwen2_5_VLConfig:
|
||||
require_package("transformers", extra="eo1")
|
||||
config_dict = deepcopy(self.vlm_config)
|
||||
if self.attn_implementation is not None:
|
||||
config_dict["attn_implementation"] = self.attn_implementation
|
||||
return Qwen2_5_VLConfig(**config_dict)
|
||||
|
||||
@property
|
||||
def text_config(self) -> Qwen2_5_VLTextConfig:
|
||||
return self.vlm_backbone_config.text_config
|
||||
|
||||
@property
|
||||
def vision_config(self) -> Qwen2_5_VLVisionConfig:
|
||||
return self.vlm_backbone_config.vision_config
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up EO1 input and output features."""
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"EO1 policy requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
|
||||
if OBS_STATE not in self.input_features:
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(self.max_state_dim,),
|
||||
)
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.max_action_dim,),
|
||||
)
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
620
src/lerobot/policies/eo1/modeling_eo1.py
Normal file
620
src/lerobot/policies/eo1/modeling_eo1.py
Normal file
@@ -0,0 +1,620 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
import torch.utils.checkpoint
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||
from transformers.utils import torch_compilable_check
|
||||
else:
|
||||
ACT2FN = None
|
||||
Qwen2_5_VLForConditionalGeneration = None
|
||||
torch_compilable_check = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Pad the last dimension of a vector to new_dim with zeros.
|
||||
|
||||
Can be (batch_size x sequence_length x features_dimension)
|
||||
or (batch_size x features_dimension)
|
||||
"""
|
||||
if vector.shape[-1] >= new_dim:
|
||||
return vector
|
||||
return F.pad(vector, (0, new_dim - vector.shape[-1]))
|
||||
|
||||
|
||||
class EO1Policy(PreTrainedPolicy):
|
||||
"""EO1 policy wrapper for LeRobot robot-only training/evaluation."""
|
||||
|
||||
config_class = EO1Config
|
||||
name = "eo1"
|
||||
|
||||
def __init__(self, config: EO1Config, **kwargs):
|
||||
require_package("transformers", extra="eo1")
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
if config.pretrained_path is None:
|
||||
# Initialize from pretrained VLM
|
||||
vlm_backbone = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
config.vlm_base,
|
||||
dtype=config.dtype,
|
||||
attn_implementation=config.attn_implementation,
|
||||
)
|
||||
else:
|
||||
vlm_backbone = Qwen2_5_VLForConditionalGeneration._from_config(
|
||||
config.vlm_backbone_config,
|
||||
dtype=config.vlm_backbone_config.dtype if config.dtype == "auto" else config.dtype,
|
||||
)
|
||||
|
||||
self.model = EO1VisionFlowMatchingModel(config, vlm_backbone)
|
||||
if config.gradient_checkpointing:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
|
||||
self.model.to(config.device)
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self._action_queue = deque(maxlen=self.config.n_action_steps)
|
||||
|
||||
@staticmethod
|
||||
def _get_model_inputs(batch: dict[str, Tensor], excluded_keys: set[str]) -> dict[str, Tensor]:
|
||||
return {key: value for key, value in batch.items() if key not in excluded_keys}
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
state = self.prepare_state(batch[OBS_STATE])
|
||||
actions = self.prepare_action(batch[ACTION])
|
||||
model_inputs = self._get_model_inputs(batch, {OBS_STATE, ACTION})
|
||||
loss = self.model(states=state, action=actions, **model_inputs)
|
||||
|
||||
loss_dict = {"loss": loss.item()}
|
||||
return loss, loss_dict
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
states = self.prepare_state(batch[OBS_STATE])
|
||||
model_inputs = self._get_model_inputs(batch, {OBS_STATE})
|
||||
actions = self.model.sample_actions(states=states, **model_inputs).to(torch.float32)
|
||||
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
return actions[:, :, :original_action_dim]
|
||||
|
||||
def prepare_state(self, state: Tensor) -> Tensor:
|
||||
return pad_vector(state, self.config.max_state_dim)
|
||||
|
||||
def prepare_action(self, action: Tensor) -> Tensor:
|
||||
return pad_vector(action, self.config.max_action_dim)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "mps" and target_dtype == torch.float64:
|
||||
return torch.float32
|
||||
if device_type == "cpu":
|
||||
# CPU doesn't support bfloat16, use float32 instead
|
||||
if target_dtype == torch.bfloat16:
|
||||
return torch.float32
|
||||
if target_dtype == torch.float64:
|
||||
return torch.float64
|
||||
return target_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy)
|
||||
time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
period = min_period * (max_period / min_period) ** fraction
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
|
||||
alpha_t = torch.tensor(alpha, dtype=torch.float32)
|
||||
beta_t = torch.tensor(beta, dtype=torch.float32)
|
||||
dist = torch.distributions.Beta(alpha_t, beta_t)
|
||||
return dist.sample((bsize,)).to(device)
|
||||
|
||||
|
||||
class EO1VisionActionProjector(torch.nn.Sequential):
|
||||
"""This block implements the multi-layer perceptron (MLP) module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_layers: int = 2,
|
||||
activation_layer: str = "linear",
|
||||
bias: bool = True,
|
||||
device: Any = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
):
|
||||
layers = []
|
||||
in_dim = in_channels
|
||||
hidden_channels = [in_dim] * (num_layers - 1) + [out_channels]
|
||||
for hidden_dim in hidden_channels[:-1]:
|
||||
layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device))
|
||||
layers.append(ACT2FN[activation_layer])
|
||||
in_dim = hidden_dim
|
||||
layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias, dtype=dtype, device=device))
|
||||
super().__init__(*layers)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self[0].weight.dtype
|
||||
|
||||
|
||||
class EO1VisionFlowMatchingModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: EO1Config,
|
||||
vlm_backbone: Qwen2_5_VLForConditionalGeneration | None = None,
|
||||
):
|
||||
require_package("transformers", extra="eo1")
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
# Preserve the backbone dtype selected at construction time so Qwen's fp32 rotary buffers stay intact.
|
||||
self.vlm_backbone = vlm_backbone
|
||||
self.hidden_size = self.vlm_backbone.config.text_config.hidden_size
|
||||
max_state_dim = config.max_state_dim
|
||||
max_action_dim = config.max_action_dim
|
||||
self.state_proj = nn.Linear(max_state_dim, self.hidden_size, dtype=torch.float32)
|
||||
self.action_in_proj = nn.Linear(max_action_dim, self.hidden_size, dtype=torch.float32)
|
||||
self.action_out_proj = EO1VisionActionProjector(
|
||||
self.hidden_size,
|
||||
max_action_dim,
|
||||
config.num_action_layers,
|
||||
config.action_act,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
self.action_time_mlp_in = nn.Linear(self.hidden_size * 2, self.hidden_size, dtype=torch.float32)
|
||||
self.action_time_mlp_out = nn.Linear(self.hidden_size, self.hidden_size, dtype=torch.float32)
|
||||
self.gradient_checkpointing_enabled = False
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.vlm_backbone.get_input_embeddings()
|
||||
|
||||
def flow_head_autocast_context(self):
|
||||
if self.config.force_fp32_autocast:
|
||||
return torch.autocast(
|
||||
device_type=self.state_proj.weight.device.type,
|
||||
enabled=False,
|
||||
)
|
||||
return contextlib.nullcontext()
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for the Qwen2.5-VL backbone."""
|
||||
self.gradient_checkpointing_enabled = True
|
||||
self.vlm_backbone.gradient_checkpointing_enable(
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}
|
||||
)
|
||||
logger.info("Enabled gradient checkpointing for EO1VisionFlowMatchingModel")
|
||||
|
||||
def gradient_checkpointing_disable(self):
|
||||
"""Disable gradient checkpointing for the Qwen2.5-VL backbone."""
|
||||
self.gradient_checkpointing_enabled = False
|
||||
self.vlm_backbone.gradient_checkpointing_disable()
|
||||
logger.info("Disabled gradient checkpointing for EO1VisionFlowMatchingModel")
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Apply manual gradient checkpointing to EO1 flow-head computations when training."""
|
||||
if self.gradient_checkpointing_enabled and self.training and torch.is_grad_enabled():
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
noise = torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
return noise
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(
|
||||
self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device
|
||||
)
|
||||
time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def get_placeholder_mask(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None,
|
||||
inputs_embeds: torch.FloatTensor | None,
|
||||
state_features: torch.FloatTensor | None = None,
|
||||
action_features: torch.FloatTensor | None = None,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
) -> tuple[torch.BoolTensor, torch.BoolTensor]:
|
||||
"""Return EO1 state/action placeholder masks, following Qwen's multimodal mask style."""
|
||||
if input_ids is None:
|
||||
special_state_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(state_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_state_mask = special_state_mask.all(-1)
|
||||
special_action_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(action_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_action_mask = special_action_mask.all(-1)
|
||||
else:
|
||||
special_state_mask = input_ids == state_token_id
|
||||
special_action_mask = input_ids == action_token_id
|
||||
|
||||
n_state_tokens = special_state_mask.sum()
|
||||
special_state_mask = (
|
||||
special_state_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
)
|
||||
if state_features is not None:
|
||||
torch_compilable_check(
|
||||
inputs_embeds[special_state_mask].numel() == state_features.numel(),
|
||||
f"State features and state tokens do not match, tokens: {n_state_tokens}, features: {state_features.shape[0]}",
|
||||
)
|
||||
|
||||
n_action_tokens = special_action_mask.sum()
|
||||
special_action_mask = (
|
||||
special_action_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
)
|
||||
if action_features is not None:
|
||||
torch_compilable_check(
|
||||
inputs_embeds[special_action_mask].numel() == action_features.numel(),
|
||||
f"Action features and action tokens do not match, tokens: {n_action_tokens}, features: {action_features.shape[0]}",
|
||||
)
|
||||
|
||||
return special_state_mask, special_action_mask
|
||||
|
||||
def embed_prefix(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
states: torch.Tensor,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
) -> torch.FloatTensor:
|
||||
"""Embed the EO1 prefix tokens before native Qwen injects multimodal features."""
|
||||
|
||||
# Get the input embeddings for the input IDs
|
||||
def input_embed_func(input_ids: torch.LongTensor) -> torch.FloatTensor:
|
||||
return self.get_input_embeddings()(input_ids)
|
||||
|
||||
inputs_embeds = self._apply_checkpoint(input_embed_func, input_ids)
|
||||
|
||||
# Project the states to the hidden size
|
||||
def state_proj_func(states: torch.Tensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
states = states.to(dtype=self.state_proj.weight.dtype)
|
||||
return self.state_proj(states)
|
||||
|
||||
state_embs = self._apply_checkpoint(state_proj_func, states)
|
||||
state_mask, _ = self.get_placeholder_mask(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
state_features=state_embs,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
state_embs = state_embs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(state_mask, state_embs)
|
||||
return inputs_embeds
|
||||
|
||||
def embed_suffix(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
noisy_actions: torch.Tensor,
|
||||
) -> torch.FloatTensor:
|
||||
"""Embed the suffix"""
|
||||
|
||||
def action_proj_func(noisy_actions: torch.Tensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
noisy_actions = noisy_actions.to(dtype=self.action_in_proj.weight.dtype)
|
||||
return self.action_in_proj(noisy_actions)
|
||||
|
||||
action_embs = self._apply_checkpoint(action_proj_func, noisy_actions)
|
||||
time_embs = create_sinusoidal_pos_embedding(
|
||||
timestep,
|
||||
self.hidden_size,
|
||||
min_period=self.config.min_period,
|
||||
max_period=self.config.max_period,
|
||||
device=action_embs.device,
|
||||
)
|
||||
time_embs = time_embs.to(dtype=action_embs.dtype)
|
||||
time_embs = time_embs[:, None, :].expand_as(action_embs)
|
||||
action_time_embs = torch.cat([action_embs, time_embs], dim=2)
|
||||
|
||||
def mlp_func(action_time_embs: torch.Tensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
action_time_embs = action_time_embs.to(dtype=self.action_time_mlp_in.weight.dtype)
|
||||
action_time_embs = self.action_time_mlp_in(action_time_embs)
|
||||
action_time_embs = F.silu(action_time_embs)
|
||||
return self.action_time_mlp_out(action_time_embs)
|
||||
|
||||
action_time_embs = self._apply_checkpoint(mlp_func, action_time_embs)
|
||||
return action_time_embs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.LongTensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
image_grid_thw: torch.LongTensor | None = None,
|
||||
mm_token_type_ids: torch.IntTensor | None = None,
|
||||
states: torch.FloatTensor | None = None,
|
||||
action: torch.FloatTensor | None = None,
|
||||
action_is_pad: torch.BoolTensor | None = None,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
"""Run the EO1 training forward pass and compute the flow-matching loss."""
|
||||
|
||||
# 1. Build the EO1 prefix with state placeholders resolved.
|
||||
inputs_embeds = self.embed_prefix(
|
||||
input_ids,
|
||||
states=states,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
|
||||
# 2. Sample the diffusion target and replace the action placeholders.
|
||||
time = self.sample_time(action.shape[0], inputs_embeds.device)
|
||||
noise = self.sample_noise(action.shape, inputs_embeds.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * action
|
||||
u_t = noise - action
|
||||
action_time_embs = self.embed_suffix(time, x_t)
|
||||
_, action_mask = self.get_placeholder_mask(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
action_features=action_time_embs,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
action_time_embs = action_time_embs.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(action_mask, action_time_embs)
|
||||
|
||||
# 3. Optionally drop padded action tokens from backbone attention.
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||
|
||||
if not self.config.supervise_padding_actions:
|
||||
action_is_pad = action_is_pad.to(device=inputs_embeds.device, dtype=torch.bool)
|
||||
action_token_mask = action_mask[..., 0]
|
||||
action_padding_mask = torch.zeros_like(action_token_mask)
|
||||
action_padding_mask = action_padding_mask.masked_scatter(
|
||||
action_token_mask,
|
||||
action_is_pad.reshape(-1),
|
||||
)
|
||||
attention_mask = attention_mask.masked_fill(action_padding_mask, 0)
|
||||
|
||||
# 4. Run the Qwen backbone on the fused EO1 sequence.
|
||||
def vlm_forward_func(
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
inputs_embeds: torch.FloatTensor,
|
||||
pixel_values: torch.Tensor | None,
|
||||
image_grid_thw: torch.LongTensor | None,
|
||||
mm_token_type_ids: torch.IntTensor | None,
|
||||
) -> torch.FloatTensor:
|
||||
outputs = self.vlm_backbone.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
mm_token_type_ids=mm_token_type_ids,
|
||||
use_cache=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
)
|
||||
return outputs.last_hidden_state
|
||||
|
||||
hidden_states = self._apply_checkpoint(
|
||||
vlm_forward_func,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
pixel_values,
|
||||
image_grid_thw,
|
||||
mm_token_type_ids,
|
||||
)
|
||||
action_hidden_states = hidden_states[action_mask[..., 0]]
|
||||
|
||||
# 5. Project the action-token hidden states back to the flow target space.
|
||||
def action_out_proj_func(action_hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
with self.flow_head_autocast_context():
|
||||
action_hidden_states = action_hidden_states.to(dtype=self.action_out_proj.dtype)
|
||||
return self.action_out_proj(action_hidden_states)
|
||||
|
||||
v_t = self._apply_checkpoint(action_out_proj_func, action_hidden_states)
|
||||
v_t = v_t.reshape(u_t.shape).to(dtype=u_t.dtype)
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
# 6. Apply the configured supervision mask and reduce the loss.
|
||||
if not self.config.supervise_padding_action_dims:
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
losses = losses[..., :original_action_dim]
|
||||
|
||||
if not self.config.supervise_padding_actions:
|
||||
losses = losses[~action_is_pad]
|
||||
|
||||
return losses.mean()
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_actions(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
pixel_values: torch.Tensor | None = None,
|
||||
image_grid_thw: torch.LongTensor | None = None,
|
||||
mm_token_type_ids: torch.IntTensor | None = None,
|
||||
states: torch.Tensor | None = None,
|
||||
*,
|
||||
state_token_id: int,
|
||||
action_token_id: int,
|
||||
**kwargs,
|
||||
) -> Tensor:
|
||||
"""Sample actions from the model."""
|
||||
if states is None:
|
||||
raise ValueError("states are required for EO1 action sampling.")
|
||||
if mm_token_type_ids is None:
|
||||
raise ValueError("mm_token_type_ids are required for EO1 action sampling.")
|
||||
|
||||
# 1. Resolve the left-padded rollout prompt and locate the action span.
|
||||
chunk_size = self.config.chunk_size
|
||||
|
||||
inputs_embeds = self.embed_prefix(
|
||||
input_ids,
|
||||
states=states,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
).clone()
|
||||
_, action_placeholder_mask = self.get_placeholder_mask(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
state_token_id=state_token_id,
|
||||
action_token_id=action_token_id,
|
||||
)
|
||||
action_mask = action_placeholder_mask[..., 0]
|
||||
token_counts = action_mask.sum(dim=1)
|
||||
if not torch.all(token_counts == chunk_size):
|
||||
raise ValueError(
|
||||
f"Each sample must contain exactly {chunk_size} action tokens, got {token_counts.tolist()}."
|
||||
)
|
||||
if action_mask.ne(action_mask[:1]).any():
|
||||
raise ValueError(
|
||||
"Batch inference expects all samples to share the same action token mask after left padding."
|
||||
)
|
||||
act_start = int(action_mask[0].to(torch.int64).argmax().item())
|
||||
act_end = act_start + self.config.chunk_size
|
||||
if not torch.all(action_mask[:, act_start:act_end]):
|
||||
raise ValueError("Action tokens must form a contiguous chunk of length chunk_size.")
|
||||
act_slice = slice(act_start, act_end)
|
||||
|
||||
# 2. Encode the fixed prefix once and cache its KV state.
|
||||
batch_size = input_ids.shape[0]
|
||||
device = inputs_embeds.device
|
||||
attention_mask = attention_mask.to(device)
|
||||
mm_token_type_ids = mm_token_type_ids.to(device)
|
||||
position_ids, _ = self.vlm_backbone.model.get_rope_index(
|
||||
input_ids,
|
||||
image_grid_thw=image_grid_thw,
|
||||
attention_mask=attention_mask,
|
||||
mm_token_type_ids=mm_token_type_ids,
|
||||
)
|
||||
position_ids = position_ids.to(device)
|
||||
|
||||
outputs = self.vlm_backbone.model(
|
||||
input_ids=input_ids[:, :act_start],
|
||||
attention_mask=attention_mask[:, :act_start],
|
||||
position_ids=position_ids[..., :act_start],
|
||||
inputs_embeds=inputs_embeds[:, :act_start],
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
mm_token_type_ids=mm_token_type_ids[:, :act_start],
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
x_t = self.sample_noise(
|
||||
(batch_size, chunk_size, self.config.max_action_dim),
|
||||
device,
|
||||
).to(dtype=self.action_in_proj.weight.dtype)
|
||||
dt = -1.0 / self.config.num_denoise_steps
|
||||
past_key_values = outputs.past_key_values
|
||||
|
||||
# 3. Denoise only the action chunk while keeping the prefix cache invariant.
|
||||
for step in range(self.config.num_denoise_steps):
|
||||
time = torch.full(
|
||||
(batch_size,),
|
||||
1.0 + step * dt,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
action_time_embs = self.embed_suffix(time, x_t)
|
||||
inputs_embeds[:, act_slice] = action_time_embs.to(inputs_embeds.dtype)
|
||||
|
||||
# Keep the prefix KV cache invariant across denoising steps.
|
||||
past_key_values.crop(act_start)
|
||||
outputs = self.vlm_backbone.model(
|
||||
attention_mask=attention_mask[:, :act_end],
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds[:, act_slice],
|
||||
position_ids=position_ids[..., act_slice],
|
||||
use_cache=True,
|
||||
return_dict=True,
|
||||
)
|
||||
with self.flow_head_autocast_context():
|
||||
hidden_states = outputs.last_hidden_state[:, :chunk_size]
|
||||
hidden_states = hidden_states.to(dtype=self.action_out_proj.dtype)
|
||||
v_t = self.action_out_proj(hidden_states)
|
||||
|
||||
x_t += dt * v_t.reshape(x_t.shape)
|
||||
|
||||
return x_t
|
||||
282
src/lerobot/policies/eo1/processor_eo1.py
Normal file
282
src/lerobot/policies/eo1/processor_eo1.py
Normal file
@@ -0,0 +1,282 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
|
||||
else:
|
||||
Qwen2_5_VLProcessor = None
|
||||
|
||||
SYSTEM_MESSAGE = "You are a helpful physical assistant."
|
||||
|
||||
# EO-1 special tokens
|
||||
ACTION_START_TOKEN = "<|action_start|>" # nosec B105
|
||||
DEFAULT_ACTION_TOKEN = "<|action_pad|>" # nosec B105
|
||||
ACTION_END_TOKEN = "<|action_end|>" # nosec B105
|
||||
STATE_START_TOKEN = "<|state_start|>" # nosec B105
|
||||
DEFAULT_STATE_TOKEN = "<|state_pad|>" # nosec B105
|
||||
STATE_END_TOKEN = "<|state_end|>" # nosec B105
|
||||
TASK_VLA_TOKEN = "<|vla|>" # nosec B105
|
||||
|
||||
EO1_SPECIAL_TOKENS = [
|
||||
ACTION_START_TOKEN,
|
||||
DEFAULT_ACTION_TOKEN,
|
||||
ACTION_END_TOKEN,
|
||||
STATE_START_TOKEN,
|
||||
DEFAULT_STATE_TOKEN,
|
||||
STATE_END_TOKEN,
|
||||
TASK_VLA_TOKEN,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="eo1_conversation_template_processor")
|
||||
class EO1ConversationTemplateStep(ComplementaryDataProcessorStep):
|
||||
input_features: dict[str, PolicyFeature] | dict[str, dict[str, Any]]
|
||||
chunk_size: int
|
||||
|
||||
_image_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
# Robust JSON deserialization handling (guard empty maps).
|
||||
if self.input_features:
|
||||
first_val = next(iter(self.input_features.values()))
|
||||
if isinstance(first_val, dict):
|
||||
reconstructed = {}
|
||||
for key, ft_dict in self.input_features.items():
|
||||
reconstructed[key] = PolicyFeature(
|
||||
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||
)
|
||||
self.input_features = reconstructed
|
||||
|
||||
self._image_keys = [
|
||||
key for key, value in self.input_features.items() if value.type == FeatureType.VISUAL
|
||||
]
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
tasks = complementary_data.get("task")
|
||||
if tasks is None:
|
||||
raise ValueError("Task is required for EO1ConversationTemplateStep.")
|
||||
|
||||
observation = self.transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is None:
|
||||
raise ValueError("Observation is required for EO1ConversationTemplateStep.")
|
||||
|
||||
if OBS_STATE in observation and observation[OBS_STATE].shape[0] != len(tasks):
|
||||
raise ValueError("Batch size mismatch between observation.state and task list.")
|
||||
|
||||
# LeRobot visual observations reach in processor as float32 tensors in [0, 1].
|
||||
# Convert to uint8 in [0, 255] to meet the input requirement of Qwen2.5-VL-3B-Instruct.
|
||||
images = {
|
||||
key: observation[key].clamp(0, 1).mul(255.0).round().to(torch.uint8) for key in self._image_keys
|
||||
}
|
||||
messages = []
|
||||
for i in range(len(tasks)):
|
||||
content = [
|
||||
*[{"type": "image", "image": images[key][i]} for key in self._image_keys],
|
||||
{
|
||||
"type": "text",
|
||||
"text": (
|
||||
f"{STATE_START_TOKEN}{DEFAULT_STATE_TOKEN}{STATE_END_TOKEN}{tasks[i]}{TASK_VLA_TOKEN}"
|
||||
),
|
||||
},
|
||||
]
|
||||
messages.append(
|
||||
[
|
||||
{"role": "system", "content": [{"type": "text", "text": SYSTEM_MESSAGE}]},
|
||||
{"role": "user", "content": content},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"{ACTION_START_TOKEN}{DEFAULT_ACTION_TOKEN * self.chunk_size}{ACTION_END_TOKEN}",
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
complementary_data["messages"] = messages
|
||||
|
||||
return complementary_data
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step only materializes EO1-specific message objects in complementary_data.
|
||||
PipelineFeatureType tracks only ACTION and OBSERVATION, so there is no static
|
||||
feature contract change to record here.
|
||||
"""
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"input_features": {
|
||||
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.input_features.items()
|
||||
},
|
||||
"chunk_size": self.chunk_size,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="eo1_qwen_processor")
|
||||
class EO1QwenProcessorStep(ComplementaryDataProcessorStep):
|
||||
processor_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
image_min_pixels: int | None = 64 * 28 * 28
|
||||
image_max_pixels: int | None = 128 * 28 * 28
|
||||
use_fast_processor: bool = False
|
||||
|
||||
_processor: Qwen2_5_VLProcessor | None = field(default=None, init=False, repr=False)
|
||||
_state_token_id: int | None = field(default=None, init=False, repr=False)
|
||||
_action_token_id: int | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
require_package("transformers", extra="eo1")
|
||||
self._processor = Qwen2_5_VLProcessor.from_pretrained(
|
||||
self.processor_name,
|
||||
use_fast=self.use_fast_processor,
|
||||
)
|
||||
self._processor.tokenizer.add_tokens(EO1_SPECIAL_TOKENS, special_tokens=True)
|
||||
self._state_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_STATE_TOKEN)
|
||||
self._action_token_id = self._processor.tokenizer.convert_tokens_to_ids(DEFAULT_ACTION_TOKEN)
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
messages = complementary_data.pop("messages", None)
|
||||
if messages is None:
|
||||
raise ValueError("Messages are required for EO1QwenProcessorStep.")
|
||||
|
||||
# Rollout batches use left padding so action spans stay aligned across samples.
|
||||
# Supervised batches use right padding to match standard training collation.
|
||||
padding_side = "right" if self.transition.get(TransitionKey.ACTION) is not None else "left"
|
||||
|
||||
inputs = self._processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
padding=True,
|
||||
padding_side=padding_side,
|
||||
min_pixels=self.image_min_pixels,
|
||||
max_pixels=self.image_max_pixels,
|
||||
add_generation_prompt=False,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
complementary_data["input_ids"] = inputs["input_ids"]
|
||||
complementary_data["pixel_values"] = inputs["pixel_values"]
|
||||
complementary_data["image_grid_thw"] = inputs["image_grid_thw"]
|
||||
complementary_data["attention_mask"] = inputs["attention_mask"]
|
||||
complementary_data["mm_token_type_ids"] = inputs["mm_token_type_ids"]
|
||||
complementary_data["state_token_id"] = self._state_token_id
|
||||
complementary_data["action_token_id"] = self._action_token_id
|
||||
|
||||
return complementary_data
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"processor_name": self.processor_name,
|
||||
"image_min_pixels": self.image_min_pixels,
|
||||
"image_max_pixels": self.image_max_pixels,
|
||||
"use_fast_processor": self.use_fast_processor,
|
||||
}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
This step only converts the messages to the model input format.
|
||||
"""
|
||||
return features
|
||||
|
||||
|
||||
def make_eo1_pre_post_processors(
|
||||
config: EO1Config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Build pre/post processor pipelines for EO1."""
|
||||
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
EO1ConversationTemplateStep(input_features=config.input_features, chunk_size=config.chunk_size),
|
||||
EO1QwenProcessorStep(
|
||||
processor_name=config.vlm_base,
|
||||
image_min_pixels=config.image_min_pixels,
|
||||
image_max_pixels=config.image_max_pixels,
|
||||
use_fast_processor=config.use_fast_processor,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -46,6 +46,7 @@ from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||
|
||||
from .act.configuration_act import ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig
|
||||
from .eo1.configuration_eo1 import EO1Config
|
||||
from .groot.configuration_groot import GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config
|
||||
@@ -146,6 +147,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from .wall_x.modeling_wall_x import WallXPolicy
|
||||
|
||||
return WallXPolicy
|
||||
elif name == "eo1":
|
||||
from .eo1.modeling_eo1 import EO1Policy
|
||||
|
||||
return EO1Policy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -196,6 +201,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return XVLAConfig(**kwargs)
|
||||
elif policy_type == "wall_x":
|
||||
return WallXConfig(**kwargs)
|
||||
elif policy_type == "eo1":
|
||||
return EO1Config(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -399,6 +406,13 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
elif isinstance(policy_cfg, EO1Config):
|
||||
from .eo1.processor_eo1 import make_eo1_pre_post_processors
|
||||
|
||||
processors = make_eo1_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
@@ -514,7 +528,7 @@ def make_policy(
|
||||
|
||||
logging.info("Loading policy's PEFT adapter.")
|
||||
|
||||
peft_pretrained_path = cfg.pretrained_path
|
||||
peft_pretrained_path = str(cfg.pretrained_path)
|
||||
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
||||
|
||||
kwargs["pretrained_name_or_path"] = peft_config.base_model_name_or_path
|
||||
@@ -527,7 +541,9 @@ def make_policy(
|
||||
)
|
||||
|
||||
policy = policy_cls.from_pretrained(**kwargs)
|
||||
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
|
||||
policy = PeftModel.from_pretrained(
|
||||
policy, peft_pretrained_path, config=peft_config, is_trainable=True
|
||||
)
|
||||
|
||||
else:
|
||||
# Make a fresh policy.
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
@@ -109,7 +109,6 @@ class MultiEmbodimentActionEncoder(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlowmatchingActionHeadConfig(PretrainedConfig):
|
||||
"""NOTE: N1.5 uses XEmbFlowmatchingPolicyHeadConfig as action head"""
|
||||
|
||||
|
||||
@@ -444,13 +444,13 @@ class PaliGemmaWithExpertModel(
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
features = image_outputs.pooler_output
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -666,8 +666,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Process language tokens
|
||||
def lang_embed_func(lang_tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
return lang_emb * math.sqrt(lang_emb_dim)
|
||||
return lang_emb
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens)
|
||||
embs.append(lang_emb)
|
||||
@@ -748,16 +747,8 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
return embs, pad_masks, att_masks, adarms_cond
|
||||
|
||||
def forward(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
||||
) -> Tensor:
|
||||
def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss."""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
@@ -1292,8 +1283,11 @@ class PI0Policy(PreTrainedPolicy):
|
||||
state = self.prepare_state(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
# Compute loss
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -728,14 +728,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
return embs, pad_masks, att_masks, adarms_cond
|
||||
|
||||
def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
|
||||
def forward(self, images, img_masks, tokens, masks, actions, noise, time) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss."""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
@@ -1262,8 +1256,11 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
# Compute loss (no separate state needed for PI05)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions, noise, time)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict, Unpack
|
||||
@@ -261,13 +260,15 @@ class PI0FastPaliGemma(nn.Module):
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
image_outputs = self.paligemma.model.get_image_features(image)
|
||||
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
|
||||
features = image_outputs.pooler_output
|
||||
norm = 2048**0.5
|
||||
features = features / norm * norm
|
||||
if features.dtype != out_dtype:
|
||||
features = features.to(out_dtype)
|
||||
return features
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.model.language_model.embed_tokens(tokens)
|
||||
return self.paligemma.model.language_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -417,8 +418,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Process language instruction tokens
|
||||
def lang_embed_func(tokens):
|
||||
lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens)
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
return lang_emb * math.sqrt(lang_emb_dim)
|
||||
return lang_emb
|
||||
|
||||
lang_emb = self._apply_checkpoint(lang_embed_func, tokens)
|
||||
embs.append(lang_emb)
|
||||
@@ -432,8 +432,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
def fast_action_embed_func(fast_action_tokens):
|
||||
fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens)
|
||||
fast_emb_dim = fast_emb.shape[-1]
|
||||
return fast_emb * math.sqrt(fast_emb_dim)
|
||||
return fast_emb
|
||||
|
||||
fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens)
|
||||
embs.append(fast_action_emb)
|
||||
@@ -666,7 +665,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
if t < max_decoding_steps - 1:
|
||||
# embed the newly generated token
|
||||
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
||||
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
|
||||
if prefix_embs.dtype == torch.bfloat16:
|
||||
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
@@ -771,7 +769,6 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
# Embed the single previous token
|
||||
# We use embed_language_tokens directly to avoid overhead of full prefix embedding
|
||||
next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token)
|
||||
next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1])
|
||||
if prefix_embs.dtype == torch.bfloat16:
|
||||
next_token_emb = next_token_emb.to(dtype=torch.bfloat16)
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from transformers.utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_flash_attn_greater_or_equal,
|
||||
is_torchdynamo_compiling,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
@@ -890,7 +890,7 @@ class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -45,7 +45,7 @@ from transformers.utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
is_flash_attn_greater_or_equal,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
@@ -909,7 +909,7 @@ class Florence2FlashAttention2(Florence2Attention):
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal("2.1.0")
|
||||
|
||||
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
||||
|
||||
@@ -54,6 +54,7 @@ class BiOpenArmFollower(Robot):
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.left_arm_config.port,
|
||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
|
||||
max_relative_target=config.left_arm_config.max_relative_target,
|
||||
cameras=left_cameras,
|
||||
side=config.left_arm_config.side,
|
||||
@@ -72,6 +73,7 @@ class BiOpenArmFollower(Robot):
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.right_arm_config.port,
|
||||
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
|
||||
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
|
||||
max_relative_target=config.right_arm_config.max_relative_target,
|
||||
cameras=right_cameras,
|
||||
side=config.right_arm_config.side,
|
||||
|
||||
@@ -66,6 +66,10 @@ class OpenArmFollowerConfigBase:
|
||||
# Whether to disable torque when disconnecting
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# When True, expose `.vel` and `.torque` per motor in observation features.
|
||||
# Default False for compatibility with the position-only openarm_mini teleoperator.
|
||||
use_velocity_and_torque: bool = False
|
||||
|
||||
# Safety limit for relative target positions
|
||||
# Set to a positive scalar for all motors, or a dict mapping motor names to limits
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
|
||||
@@ -93,8 +93,9 @@ class OpenArmFollower(Robot):
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus.motors:
|
||||
features[f"{motor}.pos"] = float
|
||||
features[f"{motor}.vel"] = float # Add this
|
||||
features[f"{motor}.torque"] = float # Add this
|
||||
if self.config.use_velocity_and_torque:
|
||||
features[f"{motor}.vel"] = float
|
||||
features[f"{motor}.torque"] = float
|
||||
return features
|
||||
|
||||
@property
|
||||
@@ -235,8 +236,9 @@ class OpenArmFollower(Robot):
|
||||
for motor in self.bus.motors:
|
||||
state = states.get(motor, {})
|
||||
obs_dict[f"{motor}.pos"] = state.get("position", 0.0)
|
||||
obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0)
|
||||
obs_dict[f"{motor}.torque"] = state.get("torque", 0.0)
|
||||
if self.config.use_velocity_and_torque:
|
||||
obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0)
|
||||
obs_dict[f"{motor}.torque"] = state.get("torque", 0.0)
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
|
||||
@@ -23,6 +23,7 @@ from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
from ..context import RolloutContext
|
||||
from .core import RolloutStrategy, send_next_action
|
||||
from .display import BaseDisplay
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,6 +39,8 @@ class BaseStrategy(RolloutStrategy):
|
||||
"""Initialise the inference engine."""
|
||||
self._init_engine(ctx)
|
||||
logger.info("Base strategy ready")
|
||||
self._display = BaseDisplay(duration=ctx.runtime.cfg.duration)
|
||||
self._display.show_banner()
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
"""Run the autonomous control loop until shutdown or duration expires."""
|
||||
@@ -72,9 +75,7 @@ class BaseStrategy(RolloutStrategy):
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
)
|
||||
self._warn_slow_loop(dt, control_interval, cfg.fps)
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
"""Disconnect hardware and stop inference."""
|
||||
|
||||
@@ -33,6 +33,7 @@ from ..inference import InferenceEngine
|
||||
if TYPE_CHECKING:
|
||||
from ..configs import RolloutStrategyConfig
|
||||
from ..context import HardwareContext, ProcessorContext, RolloutContext, RuntimeContext
|
||||
from .display import RolloutStatusDisplay
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,6 +52,17 @@ class RolloutStrategy(abc.ABC):
|
||||
self._interpolator: ActionInterpolator | None = None
|
||||
self._warmup_flushed: bool = False
|
||||
self._cached_obs_processed: dict | None = None
|
||||
self._display: RolloutStatusDisplay | None = None
|
||||
|
||||
def _warn_slow_loop(self, dt: float, control_interval: float, fps: float) -> None:
|
||||
"""Warn when the control loop runs slower than the target FPS."""
|
||||
if dt > control_interval:
|
||||
logger.warning(
|
||||
"Control loop running slower (%.1f Hz) than target (%.0f Hz). "
|
||||
"Possible causes: camera FPS not keeping up, slow policy inference, CPU starvation.",
|
||||
1 / dt,
|
||||
fps,
|
||||
)
|
||||
|
||||
def _init_engine(self, ctx: RolloutContext) -> None:
|
||||
"""Attach the inference engine and action interpolator, then start the backend.
|
||||
|
||||
@@ -33,12 +33,13 @@ Recording modes:
|
||||
``record_autonomous=False``: Only correction windows are recorded.
|
||||
Each correction (start to stop) becomes one episode.
|
||||
|
||||
Teleoperator expectations:
|
||||
The user is responsible for keeping the leader arm aligned with the
|
||||
follower arm at the moment a correction begins. Programmatic motor
|
||||
handover (``enable_torque`` / ``disable_torque`` / ``write_goal_positions``)
|
||||
is intentionally not invoked here — see the TODO in
|
||||
:func:`DAggerStrategy._apply_transition` for the open design decision.
|
||||
Teleoperator handover:
|
||||
On AUTONOMOUS → PAUSED, actuated teleops (those with non-empty
|
||||
``feedback_features``, e.g. SO-101, OpenArmMini) are smoothly driven to
|
||||
the follower's last position via ``send_feedback`` so the operator takes
|
||||
over without a jerk. Non-actuated teleops cannot be driven,
|
||||
so on PAUSED → CORRECTING the follower is instead slid to the teleop's
|
||||
current pose before the correction begins.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -70,6 +71,7 @@ from ..configs import DAggerKeyboardConfig, DAggerPedalConfig, DAggerStrategyCon
|
||||
from ..context import RolloutContext
|
||||
from ..robot_wrapper import ThreadSafeRobot
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
from .display import DAggerDisplay
|
||||
|
||||
PYNPUT_AVAILABLE = _pynput_available
|
||||
keyboard = None
|
||||
@@ -175,17 +177,27 @@ class DAggerEvents:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# TODO(Steven): re-enable programmatic teleop alignment once we decide whether
|
||||
# to enforce motor-control methods on every Teleoperator. Until then the user
|
||||
# is responsible for moving the leader arm to the follower's pose at the moment
|
||||
# a correction begins.
|
||||
def _teleop_smooth_move_to(
|
||||
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50
|
||||
) -> None:
|
||||
"""Smoothly move teleop to target position via linear interpolation.
|
||||
def _teleop_supports_feedback(teleop: Teleoperator) -> bool:
|
||||
"""Return True when the teleop can receive position feedback (is actuated).
|
||||
TODO(Maxime): See if it is possible to unify this interface across teleops instead of duck-typing.
|
||||
"""
|
||||
return (
|
||||
bool(teleop.feedback_features)
|
||||
and hasattr(teleop, "disable_torque")
|
||||
and hasattr(teleop, "enable_torque")
|
||||
)
|
||||
|
||||
Requires the teleoperator to support motor control methods
|
||||
(``enable_torque``, ``write_goal_positions``, ``get_action``).
|
||||
|
||||
def _teleop_smooth_move_to(
|
||||
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move an actuated teleop to ``target_pos`` via linear interpolation.
|
||||
|
||||
Requires the teleoperator to support feedback
|
||||
(i.e. have non-empty ``feedback_features`` and implement ``disable_torque`` / ``enable_torque``).
|
||||
|
||||
TODO(Maxime): This blocks up to ``duration_s`` seconds, during this time
|
||||
the follower robot doesn't receive new actions, this could be an issue on LeKiwi.
|
||||
"""
|
||||
teleop.enable_torque()
|
||||
current = teleop.get_action()
|
||||
@@ -193,13 +205,28 @@ def _teleop_smooth_move_to(
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {}
|
||||
for k in current:
|
||||
if k in target_pos:
|
||||
interp[k] = current[k] * (1 - t) + target_pos[k] * t
|
||||
else:
|
||||
interp[k] = current[k]
|
||||
teleop.write_goal_positions(interp)
|
||||
interp = {
|
||||
k: current[k] * (1 - t) + target_pos[k] * t if k in target_pos else current[k] for k in current
|
||||
}
|
||||
teleop.send_feedback(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def _follower_smooth_move_to(
|
||||
robot: ThreadSafeRobot, current: dict, target: dict, duration_s: float = 1.0, fps: int = 30
|
||||
) -> None:
|
||||
"""Smoothly move the follower robot from ``current`` to ``target`` action.
|
||||
|
||||
Used when the teleop is non-actuated: instead of driving the leader arm
|
||||
to the follower, we bring the follower to the teleop's current pose.
|
||||
Both ``current`` and ``target`` must be in robot-action key space.
|
||||
"""
|
||||
steps = max(int(duration_s * fps), 1)
|
||||
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp = {k: current[k] * (1 - t) + target[k] * t if k in target else current[k] for k in current}
|
||||
robot.send_action(interp)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
@@ -260,7 +287,7 @@ def _init_dagger_keyboard(events: DAggerEvents, cfg: DAggerKeyboardConfig):
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
logger.info(
|
||||
logger.debug(
|
||||
"DAgger keyboard listener started (pause_resume='%s', correction='%s', upload='%s', ESC=stop)",
|
||||
cfg.pause_resume,
|
||||
cfg.correction,
|
||||
@@ -344,6 +371,28 @@ class DAggerStrategy(RolloutStrategy):
|
||||
self._episode_duration_s,
|
||||
)
|
||||
|
||||
if self.config.input_device == "keyboard":
|
||||
kb = self.config.keyboard
|
||||
pause_key, correction_key, upload_key = (
|
||||
kb.pause_resume.upper(),
|
||||
kb.correction.upper(),
|
||||
kb.upload.upper(),
|
||||
)
|
||||
else:
|
||||
pb = self.config.pedal
|
||||
pause_key, correction_key, upload_key = pb.pause_resume, pb.correction, pb.upload
|
||||
|
||||
self._display = DAggerDisplay(
|
||||
record_autonomous=self.config.record_autonomous,
|
||||
num_episodes=self.config.num_episodes,
|
||||
episode_duration_s=self._episode_duration_s,
|
||||
input_device=self.config.input_device,
|
||||
pause_key=pause_key,
|
||||
correction_key=correction_key,
|
||||
upload_key=upload_key,
|
||||
)
|
||||
self._display.show_banner()
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
"""Run DAgger episodes with human-in-the-loop intervention."""
|
||||
if self.config.record_autonomous:
|
||||
@@ -415,10 +464,8 @@ class DAggerStrategy(RolloutStrategy):
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
events.reset()
|
||||
# TODO(Steven): re-enable once Teleoperator motor-control methods are
|
||||
# standardised; until then the user pre-aligns the leader by hand.
|
||||
# teleop.disable_torque()
|
||||
engine.resume()
|
||||
self._display.show_state(DAggerPhase.AUTONOMOUS)
|
||||
|
||||
last_action: dict[str, Any] | None = None
|
||||
record_tick = 0
|
||||
@@ -441,8 +488,17 @@ class DAggerStrategy(RolloutStrategy):
|
||||
transition = events.consume_transition()
|
||||
if transition is not None:
|
||||
old_phase, new_phase = transition
|
||||
self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
|
||||
last_action = None
|
||||
self._apply_transition(
|
||||
old_phase,
|
||||
new_phase,
|
||||
engine,
|
||||
interpolator,
|
||||
ctx,
|
||||
last_action,
|
||||
)
|
||||
self._display.show_state(new_phase)
|
||||
if new_phase == DAggerPhase.AUTONOMOUS:
|
||||
last_action = None
|
||||
|
||||
phase = events.phase
|
||||
obs = robot.get_observation()
|
||||
@@ -525,16 +581,11 @@ class DAggerStrategy(RolloutStrategy):
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
)
|
||||
self._warn_slow_loop(dt, control_interval, cfg.fps)
|
||||
|
||||
finally:
|
||||
logger.info("DAgger continuous control loop ended — pausing engine")
|
||||
engine.pause()
|
||||
# TODO(Steven): re-enable once Teleoperator motor-control methods
|
||||
# are standardised across all teleop implementations.
|
||||
# teleop.disable_torque()
|
||||
with contextlib.suppress(Exception):
|
||||
with self._episode_lock:
|
||||
dataset.save_episode()
|
||||
@@ -570,10 +621,8 @@ class DAggerStrategy(RolloutStrategy):
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
events.reset()
|
||||
# TODO(Steven): re-enable once Teleoperator motor-control methods are
|
||||
# standardised; until then the user pre-aligns the leader by hand.
|
||||
# teleop.disable_torque()
|
||||
engine.resume()
|
||||
self._display.show_state(DAggerPhase.AUTONOMOUS)
|
||||
|
||||
last_action: dict[str, Any] | None = None
|
||||
start_time = time.perf_counter()
|
||||
@@ -600,8 +649,17 @@ class DAggerStrategy(RolloutStrategy):
|
||||
transition = events.consume_transition()
|
||||
if transition is not None:
|
||||
old_phase, new_phase = transition
|
||||
self._apply_transition(old_phase, new_phase, engine, interpolator, robot, teleop)
|
||||
last_action = None
|
||||
self._apply_transition(
|
||||
old_phase,
|
||||
new_phase,
|
||||
engine,
|
||||
interpolator,
|
||||
ctx,
|
||||
last_action,
|
||||
)
|
||||
self._display.show_state(new_phase)
|
||||
if new_phase == DAggerPhase.AUTONOMOUS:
|
||||
last_action = None
|
||||
|
||||
# Correction ended -> save episode (blocking if not streaming)
|
||||
if old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
||||
@@ -672,16 +730,11 @@ class DAggerStrategy(RolloutStrategy):
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
)
|
||||
self._warn_slow_loop(dt, control_interval, cfg.fps)
|
||||
|
||||
finally:
|
||||
logger.info("DAgger corrections-only loop ended — pausing engine")
|
||||
engine.pause()
|
||||
# TODO(Steven): re-enable once Teleoperator motor-control methods
|
||||
# are standardised across all teleop implementations.
|
||||
# teleop.disable_torque()
|
||||
with contextlib.suppress(Exception):
|
||||
with self._episode_lock:
|
||||
dataset.save_episode()
|
||||
@@ -698,36 +751,71 @@ class DAggerStrategy(RolloutStrategy):
|
||||
new_phase: DAggerPhase,
|
||||
engine,
|
||||
interpolator,
|
||||
robot: ThreadSafeRobot,
|
||||
teleop: Teleoperator,
|
||||
ctx: RolloutContext,
|
||||
prev_action: dict | None,
|
||||
) -> None:
|
||||
"""Execute side-effects for a validated phase transition."""
|
||||
"""Execute side-effects for a validated phase transition, including smooth handovers.
|
||||
|
||||
AUTONOMOUS -> PAUSED (actuated teleop):
|
||||
Pause the engine, then drive the leader arm to the follower's last
|
||||
commanded position so the operator takes over without a jerk.
|
||||
|
||||
PAUSED -> CORRECTING (non-actuated teleop):
|
||||
Slide the follower to the teleop's current pose so the robot meets
|
||||
the operator's hand rather than jumping to it on the first frame.
|
||||
|
||||
CORRECTING -> PAUSED (actuated teleop):
|
||||
Re-enable torque to hold position after correction.
|
||||
This will be potentially useful if cancelling the correction recording
|
||||
|
||||
PAUSED -> AUTONOMOUS:
|
||||
Reset and resume the inference engine.
|
||||
"""
|
||||
teleop = ctx.hardware.teleop
|
||||
robot = ctx.hardware.robot_wrapper
|
||||
|
||||
logger.info("Phase transition: %s -> %s", old_phase.value, new_phase.value)
|
||||
if old_phase == DAggerPhase.AUTONOMOUS and new_phase == DAggerPhase.PAUSED:
|
||||
logger.info("Pausing engine — robot holds position")
|
||||
logger.info("Pausing engine - robot holds position")
|
||||
engine.pause()
|
||||
obs = robot.get_observation()
|
||||
_robot_pos = {
|
||||
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
|
||||
}
|
||||
# TODO(Steven): once Teleoperator motor-control methods are
|
||||
# standardised, drive the leader to the follower's pose here so the
|
||||
# operator does not need to pre-align the arm by hand. Until then
|
||||
# the user is responsible for the alignment.
|
||||
# _teleop_smooth_move_to(teleop, _robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
elif new_phase == DAggerPhase.CORRECTING:
|
||||
logger.info("Entering correction mode — human teleop control")
|
||||
# TODO(Steven): re-enable once Teleoperator motor-control methods
|
||||
# are standardised across all teleop implementations.
|
||||
# teleop.disable_torque()
|
||||
if _teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
# TODO(Maxime): prev_action is in robot action key space (output of robot_action_processor).
|
||||
# send_feedback expects teleop feedback key space. For homogeneous setups (e.g. SO-101
|
||||
# leader + SO-101 follower) the keys are identical so this works. If the processor pipeline
|
||||
# does non-trivial key renaming (e.g. a rename_map on action keys), the interpolation in
|
||||
# _teleop_smooth_move_to silently no-ops and the arm doesn't move.
|
||||
logger.info("Smooth handover: moving leader arm to follower position")
|
||||
_teleop_smooth_move_to(teleop, prev_action)
|
||||
|
||||
elif old_phase == DAggerPhase.PAUSED and new_phase == DAggerPhase.CORRECTING:
|
||||
logger.info("Entering correction mode - human teleop control")
|
||||
if not _teleop_supports_feedback(teleop) and prev_action is not None:
|
||||
logger.info("Smooth handover: sliding follower to teleop position")
|
||||
obs = robot.get_observation()
|
||||
teleop_action = teleop.get_action()
|
||||
processed = ctx.processors.teleop_action_processor((teleop_action, obs))
|
||||
target = ctx.processors.robot_action_processor((processed, obs))
|
||||
_follower_smooth_move_to(robot, prev_action, target)
|
||||
|
||||
# unlock the teleop for human control
|
||||
if _teleop_supports_feedback(teleop):
|
||||
teleop.disable_torque()
|
||||
|
||||
elif old_phase == DAggerPhase.CORRECTING and new_phase == DAggerPhase.PAUSED:
|
||||
if _teleop_supports_feedback(teleop):
|
||||
teleop.enable_torque()
|
||||
|
||||
elif new_phase == DAggerPhase.AUTONOMOUS:
|
||||
logger.info("Resuming autonomous mode — resetting engine and interpolator")
|
||||
logger.info("Resuming autonomous mode - resetting engine and interpolator")
|
||||
interpolator.reset()
|
||||
engine.reset()
|
||||
engine.resume()
|
||||
|
||||
# release teleop before resuming the policy
|
||||
if _teleop_supports_feedback(teleop):
|
||||
teleop.disable_torque()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Background push (shared by both modes)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
263
src/lerobot/rollout/strategies/display.py
Normal file
263
src/lerobot/rollout/strategies/display.py
Normal file
@@ -0,0 +1,263 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Console status display for rollout strategies.
|
||||
|
||||
One subclass per strategy — static states/controls are declared as class
|
||||
constants; runtime-dependent values are passed to ``__init__``.
|
||||
|
||||
In each strategy's ``setup()``:
|
||||
|
||||
self._display = DAggerDisplay(
|
||||
record_autonomous=self.config.record_autonomous,
|
||||
num_episodes=self.config.num_episodes,
|
||||
episode_duration_s=self._episode_duration_s,
|
||||
input_device=self.config.input_device,
|
||||
pause_key="SPACE",
|
||||
correction_key="TAB",
|
||||
upload_key="ENTER",
|
||||
)
|
||||
self._display.show_banner()
|
||||
|
||||
On each state transition:
|
||||
|
||||
self._display.show_state("correcting")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
def _supports_color() -> bool:
|
||||
return hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
|
||||
|
||||
|
||||
class _C:
|
||||
"""ANSI escape codes."""
|
||||
|
||||
RESET = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
DIM = "\033[2m"
|
||||
GREEN = "\033[1;92m"
|
||||
YELLOW = "\033[1;93m"
|
||||
RED = "\033[1;91m"
|
||||
CYAN = "\033[1;96m"
|
||||
WHITE = "\033[1;97m"
|
||||
GRAY = "\033[2;37m"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StateConfig:
|
||||
"""One named rollout state.
|
||||
|
||||
``key`` must match the string passed to ``RolloutStatusDisplay.show_state()``.
|
||||
"""
|
||||
|
||||
key: str
|
||||
emoji: str
|
||||
label: str
|
||||
description: str
|
||||
color: str = _C.WHITE
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlConfig:
|
||||
"""One keyboard/pedal binding shown in the startup banner."""
|
||||
|
||||
key: str
|
||||
description: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Base display class
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RolloutStatusDisplay:
|
||||
"""Unified console status display. Subclass once per strategy."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
strategy: str,
|
||||
states: list[StateConfig],
|
||||
controls: list[ControlConfig],
|
||||
info: list[str] | None = None,
|
||||
) -> None:
|
||||
self.strategy = strategy
|
||||
self._states = {s.key: s for s in states}
|
||||
self._controls = controls
|
||||
self._info = info or []
|
||||
self._use_color = _supports_color()
|
||||
|
||||
def _c(self, code: str, text: str) -> str:
|
||||
if not self._use_color:
|
||||
return text
|
||||
return f"{code}{text}{_C.RESET}"
|
||||
|
||||
def show_banner(self) -> None:
|
||||
"""Print startup banner: strategy name, states, controls, config info."""
|
||||
width = 62
|
||||
sep = self._c(_C.BOLD, "═" * width)
|
||||
|
||||
print(f"\n{sep}")
|
||||
print(self._c(_C.BOLD, f" lerobot-rollout │ {self.strategy}"))
|
||||
|
||||
if self._states:
|
||||
print()
|
||||
for state in self._states.values():
|
||||
label = self._c(state.color, f"{state.label:<14}")
|
||||
desc = self._c(_C.GRAY, state.description)
|
||||
print(f" {state.emoji} {label} {desc}")
|
||||
|
||||
if self._controls:
|
||||
print()
|
||||
key_width = max(len(c.key) for c in self._controls)
|
||||
for ctrl in self._controls:
|
||||
key_str = self._c(_C.CYAN, f"[{ctrl.key:<{key_width}}]")
|
||||
print(f" {key_str} {ctrl.description}")
|
||||
|
||||
if self._info:
|
||||
print()
|
||||
for item in self._info:
|
||||
print(f" {item}")
|
||||
|
||||
print(f"{sep}\n")
|
||||
|
||||
def show_state(self, state_key: str | enum.Enum) -> None:
|
||||
"""Print the current state and available controls - call this on every transition."""
|
||||
key = state_key.value if isinstance(state_key, enum.Enum) else state_key
|
||||
state = self._states.get(key)
|
||||
if state is None:
|
||||
return
|
||||
label = self._c(state.color, f"{state.label:<14}")
|
||||
desc = self._c(_C.GRAY, state.description)
|
||||
print(f"\n {state.emoji} {label} {desc}\n")
|
||||
|
||||
if self._controls:
|
||||
key_width = max(len(c.key) for c in self._controls)
|
||||
for ctrl in self._controls:
|
||||
key_str = self._c(_C.CYAN, f"[{ctrl.key:<{key_width}}]")
|
||||
print(f" {key_str} {ctrl.description}")
|
||||
print()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# One display subclass per strategy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BaseDisplay(RolloutStatusDisplay):
|
||||
"""Status display for the base (eval-only, no recording) strategy."""
|
||||
|
||||
_STATES = [StateConfig("running", "🟢", "RUNNING", "autonomous rollout — no recording", _C.GREEN)]
|
||||
_CONTROLS = [ControlConfig("Ctrl+C", "stop session")]
|
||||
|
||||
def __init__(self, duration: float = 0) -> None:
|
||||
info = ["No recording — evaluation only."]
|
||||
if duration > 0:
|
||||
info.append(f"Duration: {duration:.0f}s")
|
||||
super().__init__("base", self._STATES, self._CONTROLS, info)
|
||||
|
||||
|
||||
class SentryDisplay(RolloutStatusDisplay):
|
||||
"""Status display for the sentry (continuous autonomous recording) strategy."""
|
||||
|
||||
_STATES = [StateConfig("recording", "🟢", "RECORDING", "continuous autonomous recording", _C.GREEN)]
|
||||
_CONTROLS = [ControlConfig("Ctrl+C", "stop session")]
|
||||
|
||||
def __init__(self, episode_duration_s: float, upload_every_n_episodes: int) -> None:
|
||||
info = [
|
||||
f"Episode rotation: ~{episode_duration_s:.0f}s | "
|
||||
f"Upload every {upload_every_n_episodes} episodes",
|
||||
]
|
||||
super().__init__("sentry", self._STATES, self._CONTROLS, info)
|
||||
|
||||
|
||||
class HighlightDisplay(RolloutStatusDisplay):
|
||||
"""Status display for the highlight (ring-buffer on-demand save) strategy."""
|
||||
|
||||
def __init__(self, ring_buffer_seconds: float, save_key: str, push_key: str) -> None:
|
||||
states = [
|
||||
StateConfig(
|
||||
"buffering",
|
||||
"⚪",
|
||||
"BUFFERING",
|
||||
f"ring buffer active — last {ring_buffer_seconds:.0f}s captured",
|
||||
_C.WHITE,
|
||||
),
|
||||
StateConfig("recording", "🔴", "RECORDING", "live recording — press [s] to save episode", _C.RED),
|
||||
]
|
||||
controls = [
|
||||
ControlConfig(save_key, "BUFFERING ↔ RECORDING start recording / save episode"),
|
||||
ControlConfig(push_key, "push dataset to Hub (background)"),
|
||||
ControlConfig("ESC", "stop session"),
|
||||
]
|
||||
super().__init__("highlight", states, controls)
|
||||
|
||||
|
||||
class DAggerDisplay(RolloutStatusDisplay):
|
||||
"""Status display for the dagger (human-in-the-loop) strategy."""
|
||||
|
||||
_PAUSED_STATE = StateConfig("paused", "🟡", "PAUSED", "holding last position — awaiting input", _C.YELLOW)
|
||||
_CORRECTING_STATE = StateConfig(
|
||||
"correcting", "🔴", "CORRECTING", "human teleop active — recording correction", _C.RED
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
record_autonomous: bool,
|
||||
num_episodes: int,
|
||||
episode_duration_s: float,
|
||||
input_device: str,
|
||||
pause_key: str,
|
||||
correction_key: str,
|
||||
upload_key: str,
|
||||
) -> None:
|
||||
mode = "continuous recording" if record_autonomous else "corrections only"
|
||||
auto_desc = "policy running — recording" if record_autonomous else "policy running — no recording"
|
||||
states = [
|
||||
StateConfig("autonomous", "🟢", "AUTONOMOUS", auto_desc, _C.GREEN),
|
||||
self._PAUSED_STATE,
|
||||
self._CORRECTING_STATE,
|
||||
]
|
||||
controls = [
|
||||
ControlConfig(pause_key, "AUTONOMOUS ↔ PAUSED pause / resume policy"),
|
||||
ControlConfig(correction_key, "PAUSED ↔ CORRECTING start / stop correction"),
|
||||
ControlConfig(upload_key, "push dataset to Hub"),
|
||||
ControlConfig("ESC", "stop session"),
|
||||
]
|
||||
info = [f"Target: {num_episodes} episodes | Input: {input_device}"]
|
||||
if record_autonomous:
|
||||
info.append(f"Episode rotation: ~{episode_duration_s:.0f}s")
|
||||
super().__init__(f"dagger [{mode}]", states, controls, info)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dagger_display = DAggerDisplay(
|
||||
record_autonomous=False,
|
||||
num_episodes=20,
|
||||
episode_duration_s=30,
|
||||
input_device="keyboard",
|
||||
pause_key="SPACE",
|
||||
correction_key="TAB",
|
||||
upload_key="ENTER",
|
||||
)
|
||||
dagger_display.show_banner()
|
||||
dagger_display.show_state("paused")
|
||||
dagger_display.show_state("correcting")
|
||||
dagger_display.show_state("paused")
|
||||
dagger_display.show_state("autonomous")
|
||||
@@ -17,6 +17,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import enum
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -36,6 +37,7 @@ from ..configs import HighlightStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from ..ring_buffer import RolloutRingBuffer
|
||||
from .core import RolloutStrategy, safe_push_to_hub, send_next_action
|
||||
from .display import HighlightDisplay
|
||||
|
||||
PYNPUT_AVAILABLE = _pynput_available
|
||||
keyboard = None
|
||||
@@ -53,6 +55,13 @@ if PYNPUT_AVAILABLE:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HighlightPhase(enum.Enum):
|
||||
"""Observable phases of a Highlight session."""
|
||||
|
||||
BUFFERING = "buffering" # Ring buffer accumulating frames, not recording
|
||||
RECORDING = "recording" # Live recording active
|
||||
|
||||
|
||||
class HighlightStrategy(RolloutStrategy):
|
||||
"""Autonomous rollout with on-demand recording via ring buffer.
|
||||
|
||||
@@ -105,6 +114,13 @@ class HighlightStrategy(RolloutStrategy):
|
||||
self.config.save_key,
|
||||
self.config.push_key,
|
||||
)
|
||||
self._display = HighlightDisplay(
|
||||
ring_buffer_seconds=self.config.ring_buffer_seconds,
|
||||
save_key=self.config.save_key,
|
||||
push_key=self.config.push_key,
|
||||
)
|
||||
self._display.show_banner()
|
||||
self._display.show_state(HighlightPhase.BUFFERING)
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
"""Run the autonomous loop, buffering frames and recording on demand."""
|
||||
@@ -162,6 +178,7 @@ class HighlightStrategy(RolloutStrategy):
|
||||
for buffered_frame in ring.drain():
|
||||
dataset.add_frame(buffered_frame)
|
||||
self._recording_live.set()
|
||||
self._display.show_state(HighlightPhase.RECORDING)
|
||||
else:
|
||||
dataset.add_frame(frame)
|
||||
with self._episode_lock:
|
||||
@@ -172,6 +189,7 @@ class HighlightStrategy(RolloutStrategy):
|
||||
play_sounds,
|
||||
)
|
||||
self._recording_live.clear()
|
||||
self._display.show_state(HighlightPhase.BUFFERING)
|
||||
continue # frame already consumed — skip ring.append
|
||||
|
||||
if self._push_requested.is_set():
|
||||
@@ -188,9 +206,7 @@ class HighlightStrategy(RolloutStrategy):
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
)
|
||||
self._warn_slow_loop(dt, control_interval, cfg.fps)
|
||||
|
||||
finally:
|
||||
logger.info("Highlight control loop ended")
|
||||
@@ -255,7 +271,7 @@ class HighlightStrategy(RolloutStrategy):
|
||||
|
||||
self._listener = keyboard.Listener(on_press=on_press)
|
||||
self._listener.start()
|
||||
logger.info("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key)
|
||||
logger.debug("Keyboard listener started (save='%s', push='%s', ESC=stop)", save_key, push_key)
|
||||
except ImportError:
|
||||
logger.warning("pynput not available — keyboard listener disabled")
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ from lerobot.utils.utils import log_say
|
||||
from ..configs import SentryStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from .core import RolloutStrategy, estimate_max_episode_seconds, safe_push_to_hub, send_next_action
|
||||
from .display import SentryDisplay
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -79,6 +80,11 @@ class SentryStrategy(RolloutStrategy):
|
||||
self._episode_duration_s,
|
||||
self.config.upload_every_n_episodes,
|
||||
)
|
||||
self._display = SentryDisplay(
|
||||
episode_duration_s=self._episode_duration_s,
|
||||
upload_every_n_episodes=self.config.upload_every_n_episodes,
|
||||
)
|
||||
self._display.show_banner()
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
"""Run the continuous recording loop with automatic episode rotation."""
|
||||
@@ -160,9 +166,7 @@ class SentryStrategy(RolloutStrategy):
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
)
|
||||
self._warn_slow_loop(dt, control_interval, cfg.fps)
|
||||
|
||||
finally:
|
||||
logger.info("Sentry control loop ended — saving final episode")
|
||||
|
||||
@@ -49,14 +49,6 @@ Delete episodes and save to a new dataset at a specific path and with a new repo
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
Delete episodes and re-encode video segments with h264:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]" \
|
||||
--operation.camera_encoder_config.vcodec h264 \
|
||||
--operation.camera_encoder_config.crf 23
|
||||
|
||||
Split dataset by fractions (pusht_train, pusht_val):
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
@@ -82,14 +74,6 @@ Split into more than two splits:
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}'
|
||||
|
||||
Split dataset and re-encode video segments with h264:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.8, "val": 0.2}' \
|
||||
--operation.camera_encoder_config.vcodec h264 \
|
||||
--operation.camera_encoder_config.crf 23
|
||||
|
||||
Merge multiple datasets:
|
||||
lerobot-edit-dataset \
|
||||
--new_repo_id lerobot/pusht_merged \
|
||||
@@ -203,7 +187,7 @@ import abc
|
||||
import logging
|
||||
import shutil
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
@@ -211,8 +195,6 @@ import draccus
|
||||
from lerobot.configs import parser
|
||||
from lerobot.datasets import (
|
||||
LeRobotDataset,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
convert_image_to_video_dataset,
|
||||
delete_episodes,
|
||||
merge_datasets,
|
||||
@@ -236,14 +218,12 @@ class OperationConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@dataclass
|
||||
class DeleteEpisodesConfig(OperationConfig):
|
||||
episode_indices: list[int] | None = None
|
||||
camera_encoder_config: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("split")
|
||||
@dataclass
|
||||
class SplitConfig(OperationConfig):
|
||||
splits: dict[str, float | list[int]] | None = None
|
||||
camera_encoder_config: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("merge")
|
||||
@@ -270,7 +250,11 @@ class ModifyTasksConfig(OperationConfig):
|
||||
@dataclass
|
||||
class ConvertImageToVideoConfig(OperationConfig):
|
||||
output_dir: str | None = None
|
||||
camera_encoder_config: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
vcodec: str = "libsvtav1"
|
||||
pix_fmt: str = "yuv420p"
|
||||
g: int = 2
|
||||
crf: int = 30
|
||||
fast_decode: int = 0
|
||||
episode_indices: list[int] | None = None
|
||||
num_workers: int = 4
|
||||
max_episodes_per_batch: int | None = None
|
||||
@@ -372,7 +356,6 @@ def handle_delete_episodes(cfg: EditDatasetConfig) -> None:
|
||||
episode_indices=cfg.operation.episode_indices,
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
camera_encoder_config=cfg.operation.camera_encoder_config,
|
||||
)
|
||||
|
||||
logging.info(f"Dataset saved to {output_dir}")
|
||||
@@ -404,7 +387,6 @@ def handle_split(cfg: EditDatasetConfig) -> None:
|
||||
dataset,
|
||||
splits=cfg.operation.splits,
|
||||
output_dir=cfg.new_root,
|
||||
camera_encoder_config=cfg.operation.camera_encoder_config,
|
||||
)
|
||||
|
||||
for split_name, split_ds in split_datasets.items():
|
||||
@@ -575,8 +557,11 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
dataset=dataset,
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
camera_encoder_config=getattr(cfg.operation, "camera_encoder_config", None)
|
||||
or camera_encoder_defaults(),
|
||||
vcodec=getattr(cfg.operation, "vcodec", "libsvtav1"),
|
||||
pix_fmt=getattr(cfg.operation, "pix_fmt", "yuv420p"),
|
||||
g=getattr(cfg.operation, "g", 2),
|
||||
crf=getattr(cfg.operation, "crf", 30),
|
||||
fast_decode=getattr(cfg.operation, "fast_decode", 0),
|
||||
episode_indices=getattr(cfg.operation, "episode_indices", None),
|
||||
num_workers=getattr(cfg.operation, "num_workers", 4),
|
||||
max_episodes_per_batch=getattr(cfg.operation, "max_episodes_per_batch", None),
|
||||
|
||||
@@ -63,27 +63,6 @@ lerobot-record \\
|
||||
--dataset.streaming_encoding=true \\
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
|
||||
Example recording with custom video encoding parameters:
|
||||
```shell
|
||||
lerobot-record \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \\
|
||||
--robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \\
|
||||
--robot.id=black \\
|
||||
--teleop.type=so100_leader \\
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \\
|
||||
--teleop.id=blue \\
|
||||
--dataset.repo_id=<my_username>/<my_dataset_name> \\
|
||||
--dataset.num_episodes=2 \\
|
||||
--dataset.single_task="Grab the cube" \\
|
||||
--dataset.streaming_encoding=true \\
|
||||
--dataset.encoder_threads=2 \\
|
||||
--dataset.camera_encoder_config.vcodec=h264 \\
|
||||
--dataset.camera_encoder_config.preset=fast \\
|
||||
--dataset.camera_encoder_config.extra_options={"tune": "film", "profile:v": "high", "bf": 2} \\
|
||||
--display_data=true
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -398,10 +377,10 @@ def record(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
camera_encoder_config=cfg.dataset.camera_encoder_config,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
|
||||
if num_cameras > 0
|
||||
@@ -427,10 +406,10 @@ def record(
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
camera_encoder_config=cfg.dataset.camera_encoder_config,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
@@ -441,7 +420,7 @@ def record(
|
||||
|
||||
if not cfg.dataset.streaming_encoding:
|
||||
logging.info(
|
||||
"Streaming encoding is disabled. If you have capable hardware, consider enabling it for way faster episode saving. --dataset.streaming_encoding=true --dataset.encoder_threads=2 # --dataset.camera_encoder_config.vcodec=auto. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding"
|
||||
"Streaming encoding is disabled. If you have capable hardware, consider enabling it for way faster episode saving. --dataset.streaming_encoding=true --dataset.encoder_threads=2 # --dataset.vcodec=auto. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding"
|
||||
)
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
|
||||
@@ -277,9 +277,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if cfg.peft is not None:
|
||||
if cfg.is_reward_model_training:
|
||||
raise ValueError("PEFT is only supported for policy training. ")
|
||||
logging.info("Using PEFT! Wrapping model.")
|
||||
peft_cli_overrides = dataclasses.asdict(cfg.peft)
|
||||
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
|
||||
from peft import PeftModel
|
||||
|
||||
if isinstance(policy, PeftModel):
|
||||
logging.info("PEFT adapter already loaded from checkpoint, skipping wrap_with_peft.")
|
||||
else:
|
||||
logging.info("Using PEFT! Wrapping model.")
|
||||
peft_cli_overrides = dataclasses.asdict(cfg.peft)
|
||||
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
|
||||
|
||||
# Wait for all processes to finish model creation before continuing
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -49,6 +49,7 @@ class BiOpenArmLeader(Teleoperator):
|
||||
can_data_bitrate=config.left_arm_config.can_data_bitrate,
|
||||
motor_config=config.left_arm_config.motor_config,
|
||||
manual_control=config.left_arm_config.manual_control,
|
||||
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
|
||||
position_kd=config.left_arm_config.position_kd,
|
||||
position_kp=config.left_arm_config.position_kp,
|
||||
)
|
||||
@@ -63,6 +64,7 @@ class BiOpenArmLeader(Teleoperator):
|
||||
can_data_bitrate=config.right_arm_config.can_data_bitrate,
|
||||
motor_config=config.right_arm_config.motor_config,
|
||||
manual_control=config.right_arm_config.manual_control,
|
||||
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
|
||||
position_kd=config.right_arm_config.position_kd,
|
||||
position_kp=config.right_arm_config.position_kp,
|
||||
)
|
||||
|
||||
@@ -60,6 +60,10 @@ class OpenArmLeaderConfigBase:
|
||||
# When enabled, motors have torque disabled for manual movement
|
||||
manual_control: bool = True
|
||||
|
||||
# When True, expose `.vel` and `.torque` per motor in action features.
|
||||
# Default False for compatibility with the position-only openarm_mini teleoperator.
|
||||
use_velocity_and_torque: bool = False
|
||||
|
||||
# TODO(Steven, Pepijn): Not used ... ?
|
||||
# MIT control parameters (used when manual_control=False for torque control)
|
||||
# List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
|
||||
|
||||
@@ -70,8 +70,9 @@ class OpenArmLeader(Teleoperator):
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus.motors:
|
||||
features[f"{motor}.pos"] = float
|
||||
features[f"{motor}.vel"] = float
|
||||
features[f"{motor}.torque"] = float
|
||||
if self.config.use_velocity_and_torque:
|
||||
features[f"{motor}.vel"] = float
|
||||
features[f"{motor}.torque"] = float
|
||||
return features
|
||||
|
||||
@property
|
||||
@@ -201,8 +202,9 @@ class OpenArmLeader(Teleoperator):
|
||||
for motor in self.bus.motors:
|
||||
state = states.get(motor, {})
|
||||
action_dict[f"{motor}.pos"] = state.get("position")
|
||||
action_dict[f"{motor}.vel"] = state.get("velocity")
|
||||
action_dict[f"{motor}.torque"] = state.get("torque")
|
||||
if self.config.use_velocity_and_torque:
|
||||
action_dict[f"{motor}.vel"] = state.get("velocity")
|
||||
action_dict[f"{motor}.torque"] = state.get("torque")
|
||||
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
|
||||
|
||||
@@ -112,7 +112,7 @@ class OpenArmMini(Teleoperator):
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
return self.action_features
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
@@ -348,8 +348,9 @@ class OpenArmMini(Teleoperator):
|
||||
if left_goals:
|
||||
self.bus_left.sync_write("Goal_Position", left_goals)
|
||||
|
||||
@check_if_not_connected
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError("Feedback is not yet implemented for OpenArm Mini.")
|
||||
self.write_goal_positions(feedback)
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
|
||||
@@ -59,7 +59,7 @@ class SOLeader(Teleoperator):
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
return self.action_features
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
@@ -130,6 +130,12 @@ class SOLeader(Teleoperator):
|
||||
for motor in self.bus.motors:
|
||||
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
def enable_torque(self) -> None:
|
||||
self.bus.enable_torque()
|
||||
|
||||
def disable_torque(self) -> None:
|
||||
self.bus.disable_torque()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
for motor in reversed(self.bus.motors):
|
||||
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
|
||||
@@ -145,9 +151,11 @@ class SOLeader(Teleoperator):
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return action
|
||||
|
||||
@check_if_not_connected
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# TODO: Implement force feedback
|
||||
raise NotImplementedError
|
||||
goals = {k.removesuffix(".pos"): v for k, v in feedback.items() if k.endswith(".pos")}
|
||||
if goals:
|
||||
self.bus.sync_write("Goal_Position", goals)
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
|
||||
@@ -69,7 +69,7 @@ def is_package_available(
|
||||
return package_exists
|
||||
|
||||
|
||||
def get_safe_default_video_backend():
|
||||
def get_safe_default_codec():
|
||||
logger = logging.getLogger(__name__)
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
return "torchcodec"
|
||||
|
||||
@@ -20,7 +20,7 @@ import pytest
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.datasets.dataset_reader import DatasetReader
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
from lerobot.utils.import_utils import get_safe_default_codec
|
||||
|
||||
# ── Loading ──────────────────────────────────────────────────────────
|
||||
|
||||
@@ -35,7 +35,7 @@ def test_try_load_returns_true_when_data_exists(tmp_path, lerobot_dataset_factor
|
||||
root=dataset.root,
|
||||
episodes=None,
|
||||
tolerance_s=1e-4,
|
||||
video_backend=get_safe_default_video_backend(),
|
||||
video_backend=get_safe_default_codec(),
|
||||
delta_timestamps=None,
|
||||
image_transforms=None,
|
||||
)
|
||||
@@ -58,7 +58,7 @@ def test_try_load_returns_false_when_no_data(tmp_path):
|
||||
root=meta.root,
|
||||
episodes=None,
|
||||
tolerance_s=1e-4,
|
||||
video_backend=get_safe_default_video_backend(),
|
||||
video_backend=get_safe_default_codec(),
|
||||
delta_timestamps=None,
|
||||
image_transforms=None,
|
||||
)
|
||||
|
||||
@@ -25,7 +25,6 @@ pytest.importorskip("datasets", reason="datasets is required (install lerobot[da
|
||||
|
||||
from lerobot.datasets.dataset_tools import (
|
||||
add_features,
|
||||
convert_image_to_video_dataset,
|
||||
delete_episodes,
|
||||
merge_datasets,
|
||||
modify_features,
|
||||
@@ -33,7 +32,7 @@ from lerobot.datasets.dataset_tools import (
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from lerobot.datasets.video_utils import VideoEncoderConfig
|
||||
from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -1247,12 +1246,10 @@ def test_convert_image_to_video_dataset(tmp_path):
|
||||
dataset=source_dataset,
|
||||
output_dir=output_dir,
|
||||
repo_id="lerobot/pusht_video",
|
||||
camera_encoder_config=VideoEncoderConfig(
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
),
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
episode_indices=[0, 1],
|
||||
num_workers=2,
|
||||
)
|
||||
|
||||
@@ -28,7 +28,6 @@ pytest.importorskip("datasets", reason="datasets is required (install lerobot[da
|
||||
from lerobot.datasets.dataset_writer import _encode_video_worker
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
|
||||
from lerobot.datasets.video_utils import VideoEncoderConfig
|
||||
from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID
|
||||
|
||||
SIMPLE_FEATURES = {
|
||||
@@ -53,8 +52,8 @@ def _make_frame(features: dict, task: str = "Dummy task") -> dict:
|
||||
# ── Existing encode_video_worker tests ───────────────────────────────
|
||||
|
||||
|
||||
def test_encode_video_worker_forwards_camera_encoder_config(tmp_path):
|
||||
"""_encode_video_worker forwards camera_encoder_config to encode_video_frames."""
|
||||
def test_encode_video_worker_forwards_vcodec(tmp_path):
|
||||
"""_encode_video_worker correctly forwards the vcodec parameter."""
|
||||
video_key = "observation.images.laptop"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
|
||||
img_dir = tmp_path / Path(fpath).parent
|
||||
@@ -69,21 +68,13 @@ def test_encode_video_worker_forwards_camera_encoder_config(tmp_path):
|
||||
Path(video_path).touch()
|
||||
|
||||
with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
|
||||
_encode_video_worker(
|
||||
video_key,
|
||||
0,
|
||||
tmp_path,
|
||||
fps=30,
|
||||
camera_encoder_config=VideoEncoderConfig(vcodec="h264", preset=None),
|
||||
encoder_threads=4,
|
||||
)
|
||||
_encode_video_worker(video_key, 0, tmp_path, fps=30, vcodec="h264")
|
||||
|
||||
assert captured_kwargs["camera_encoder_config"].vcodec == "h264"
|
||||
assert captured_kwargs["encoder_threads"] == 4
|
||||
assert captured_kwargs["vcodec"] == "h264"
|
||||
|
||||
|
||||
def test_encode_video_worker_default_camera_encoder_config(tmp_path):
|
||||
"""_encode_video_worker passes None camera_encoder_config which encode_video_frames defaults."""
|
||||
def test_encode_video_worker_default_vcodec(tmp_path):
|
||||
"""_encode_video_worker uses libsvtav1 as the default codec."""
|
||||
video_key = "observation.images.laptop"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
|
||||
img_dir = tmp_path / Path(fpath).parent
|
||||
@@ -100,8 +91,7 @@ def test_encode_video_worker_default_camera_encoder_config(tmp_path):
|
||||
with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
|
||||
_encode_video_worker(video_key, 0, tmp_path, fps=30)
|
||||
|
||||
assert captured_kwargs["camera_encoder_config"] is None
|
||||
assert captured_kwargs["encoder_threads"] is None
|
||||
assert captured_kwargs["vcodec"] == "libsvtav1"
|
||||
|
||||
|
||||
# ── add_frame contracts ──────────────────────────────────────────────
|
||||
|
||||
@@ -43,7 +43,7 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
create_branch,
|
||||
)
|
||||
from lerobot.datasets.video_utils import VALID_VIDEO_CODECS, VideoEncoderConfig
|
||||
from lerobot.datasets.video_utils import VALID_VIDEO_CODECS
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.robots import make_robot_from_config
|
||||
@@ -1470,9 +1470,17 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
|
||||
|
||||
|
||||
def test_lerobot_dataset_vcodec_validation():
|
||||
"""Invalid vcodec in encoder config is rejected at construction time."""
|
||||
"""Test that LeRobotDataset validates the vcodec parameter."""
|
||||
# Test that invalid vcodec raises ValueError
|
||||
with pytest.raises(ValueError, match="Invalid vcodec"):
|
||||
VideoEncoderConfig(vcodec="invalid_codec")
|
||||
LeRobotDataset.__new__(LeRobotDataset) # bypass __init__ to test validation directly
|
||||
# Actually test via create since it's easier
|
||||
LeRobotDataset.create(
|
||||
repo_id="test/invalid_codec",
|
||||
fps=30,
|
||||
features={"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}},
|
||||
vcodec="invalid_codec",
|
||||
)
|
||||
|
||||
|
||||
def test_valid_video_codecs_constant():
|
||||
|
||||
@@ -14,10 +14,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for streaming video encoding."""
|
||||
"""Tests for streaming video encoding and hardware-accelerated encoding."""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -26,20 +27,112 @@ pytest.importorskip("av", reason="av is required (install lerobot[dataset])")
|
||||
|
||||
import av # noqa: E402
|
||||
|
||||
from lerobot.datasets.pyav_utils import get_codec
|
||||
from lerobot.datasets.video_utils import (
|
||||
VALID_VIDEO_CODECS,
|
||||
StreamingVideoEncoder,
|
||||
VideoEncoderConfig,
|
||||
_CameraEncoderThread,
|
||||
_get_codec_options,
|
||||
detect_available_hw_encoders,
|
||||
resolve_vcodec,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
# Cross-codec validation tests only fire when the target codec is present
|
||||
# in the local FFmpeg build; on other platforms validate() is a no-op.
|
||||
_has_videotoolbox = get_codec("h264_videotoolbox") is not None
|
||||
_videotoolbox_only = pytest.mark.skipif(
|
||||
not _has_videotoolbox, reason="h264_videotoolbox not in local FFmpeg build"
|
||||
)
|
||||
# ─── _get_codec_options tests ───
|
||||
|
||||
|
||||
class TestGetCodecOptions:
|
||||
def test_libsvtav1_defaults(self):
|
||||
opts = _get_codec_options("libsvtav1")
|
||||
assert opts["g"] == "2"
|
||||
assert opts["crf"] == "30"
|
||||
assert opts["preset"] == "12"
|
||||
|
||||
def test_libsvtav1_custom_preset(self):
|
||||
opts = _get_codec_options("libsvtav1", preset=8)
|
||||
assert opts["preset"] == "8"
|
||||
|
||||
def test_h264_options(self):
|
||||
opts = _get_codec_options("h264", g=10, crf=23)
|
||||
assert opts["g"] == "10"
|
||||
assert opts["crf"] == "23"
|
||||
assert "preset" not in opts
|
||||
|
||||
def test_videotoolbox_options(self):
|
||||
opts = _get_codec_options("h264_videotoolbox", g=2, crf=30)
|
||||
assert opts["g"] == "2"
|
||||
# CRF 30 maps to quality = max(1, min(100, 100 - 30*2)) = 40
|
||||
assert opts["q:v"] == "40"
|
||||
assert "crf" not in opts
|
||||
|
||||
def test_nvenc_options(self):
|
||||
opts = _get_codec_options("h264_nvenc", g=2, crf=25)
|
||||
assert opts["rc"] == "constqp"
|
||||
assert opts["qp"] == "25"
|
||||
assert "crf" not in opts
|
||||
# NVENC doesn't support g
|
||||
assert "g" not in opts
|
||||
|
||||
def test_vaapi_options(self):
|
||||
opts = _get_codec_options("h264_vaapi", crf=28)
|
||||
assert opts["qp"] == "28"
|
||||
|
||||
def test_qsv_options(self):
|
||||
opts = _get_codec_options("h264_qsv", crf=25)
|
||||
assert opts["global_quality"] == "25"
|
||||
|
||||
def test_no_g_no_crf(self):
|
||||
opts = _get_codec_options("h264", g=None, crf=None)
|
||||
assert "g" not in opts
|
||||
assert "crf" not in opts
|
||||
|
||||
|
||||
# ─── HW encoder detection tests ───
|
||||
|
||||
|
||||
class TestHWEncoderDetection:
|
||||
def test_detect_available_hw_encoders_returns_list(self):
|
||||
result = detect_available_hw_encoders()
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_detect_available_hw_encoders_only_valid(self):
|
||||
from lerobot.datasets.video_utils import HW_ENCODERS
|
||||
|
||||
result = detect_available_hw_encoders()
|
||||
for encoder in result:
|
||||
assert encoder in HW_ENCODERS
|
||||
|
||||
def test_resolve_vcodec_passthrough(self):
|
||||
assert resolve_vcodec("libsvtav1") == "libsvtav1"
|
||||
assert resolve_vcodec("h264") == "h264"
|
||||
|
||||
def test_resolve_vcodec_auto_fallback(self):
|
||||
"""When no HW encoders are available, auto should fall back to libsvtav1."""
|
||||
with patch("lerobot.datasets.video_utils.detect_available_hw_encoders", return_value=[]):
|
||||
assert resolve_vcodec("auto") == "libsvtav1"
|
||||
|
||||
def test_resolve_vcodec_auto_picks_hw(self):
|
||||
"""When a HW encoder is available, auto should pick it."""
|
||||
with patch(
|
||||
"lerobot.datasets.video_utils.detect_available_hw_encoders",
|
||||
return_value=["h264_videotoolbox"],
|
||||
):
|
||||
assert resolve_vcodec("auto") == "h264_videotoolbox"
|
||||
|
||||
def test_resolve_vcodec_auto_returns_valid(self):
|
||||
"""Test that resolve_vcodec('auto') returns a known valid codec."""
|
||||
result = resolve_vcodec("auto")
|
||||
assert result in VALID_VIDEO_CODECS
|
||||
|
||||
def test_hw_encoder_names_accepted_in_validation(self):
|
||||
"""Test that HW encoder names pass validation in VALID_VIDEO_CODECS."""
|
||||
assert "auto" in VALID_VIDEO_CODECS
|
||||
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "h264_nvenc" in VALID_VIDEO_CODECS
|
||||
|
||||
def test_resolve_vcodec_invalid_raises(self):
|
||||
"""Test that resolve_vcodec raises ValueError for invalid codecs."""
|
||||
with pytest.raises(ValueError, match="Invalid vcodec"):
|
||||
resolve_vcodec("not_a_real_codec")
|
||||
|
||||
|
||||
# ─── _CameraEncoderThread tests ───
|
||||
@@ -57,13 +150,14 @@ class TestCameraEncoderThread:
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
stop_event = threading.Event()
|
||||
|
||||
enc_cfg = VideoEncoderConfig(vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13)
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=enc_cfg.vcodec,
|
||||
pix_fmt=enc_cfg.pix_fmt,
|
||||
codec_options=enc_cfg.get_codec_options(),
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=13,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
@@ -108,13 +202,14 @@ class TestCameraEncoderThread:
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
stop_event = threading.Event()
|
||||
|
||||
enc_cfg = VideoEncoderConfig(vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13)
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=enc_cfg.vcodec,
|
||||
pix_fmt=enc_cfg.pix_fmt,
|
||||
codec_options=enc_cfg.get_codec_options(),
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=13,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
@@ -142,13 +237,14 @@ class TestCameraEncoderThread:
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
stop_event = threading.Event()
|
||||
|
||||
enc_cfg = VideoEncoderConfig(vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13)
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=enc_cfg.vcodec,
|
||||
pix_fmt=enc_cfg.pix_fmt,
|
||||
codec_options=enc_cfg.get_codec_options(),
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=13,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
@@ -170,20 +266,11 @@ class TestCameraEncoderThread:
|
||||
|
||||
|
||||
class TestStreamingVideoEncoder:
|
||||
def _make_encoder_config(self, **kwargs):
|
||||
"""Helper to build a VideoEncoderConfig."""
|
||||
return VideoEncoderConfig(**kwargs)
|
||||
|
||||
def test_single_camera_episode(self, tmp_path):
|
||||
"""Test encoding a single camera episode."""
|
||||
video_keys = [f"{OBS_IMAGES}.laptop"]
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30,
|
||||
camera_encoder_config=self._make_encoder_config(
|
||||
vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13
|
||||
),
|
||||
)
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13)
|
||||
|
||||
video_keys = [f"{OBS_IMAGES}.laptop"]
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
num_frames = 20
|
||||
@@ -208,13 +295,9 @@ class TestStreamingVideoEncoder:
|
||||
|
||||
def test_multi_camera_episode(self, tmp_path):
|
||||
"""Test encoding multiple cameras simultaneously."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30)
|
||||
|
||||
video_keys = [f"{OBS_IMAGES}.laptop", f"{OBS_IMAGES}.phone"]
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30,
|
||||
camera_encoder_config=self._make_encoder_config(
|
||||
vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30
|
||||
),
|
||||
)
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
num_frames = 15
|
||||
@@ -236,13 +319,8 @@ class TestStreamingVideoEncoder:
|
||||
|
||||
def test_sequential_episodes(self, tmp_path):
|
||||
"""Test that multiple sequential episodes work correctly."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30)
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30,
|
||||
camera_encoder_config=self._make_encoder_config(
|
||||
vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30
|
||||
),
|
||||
)
|
||||
|
||||
for ep in range(3):
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
@@ -264,13 +342,8 @@ class TestStreamingVideoEncoder:
|
||||
|
||||
def test_cancel_episode(self, tmp_path):
|
||||
"""Test that canceling an episode cleans up properly."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30)
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30,
|
||||
camera_encoder_config=self._make_encoder_config(
|
||||
vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30
|
||||
),
|
||||
)
|
||||
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
@@ -292,33 +365,28 @@ class TestStreamingVideoEncoder:
|
||||
|
||||
def test_feed_without_start_raises(self, tmp_path):
|
||||
"""Test that feeding frames without starting an episode raises."""
|
||||
encoder = StreamingVideoEncoder(fps=30)
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
with pytest.raises(RuntimeError, match="No active episode"):
|
||||
encoder.feed_frame("cam", np.zeros((64, 96, 3), dtype=np.uint8))
|
||||
encoder.close()
|
||||
|
||||
def test_finish_without_start_raises(self, tmp_path):
|
||||
"""Test that finishing without starting raises."""
|
||||
encoder = StreamingVideoEncoder(fps=30)
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
with pytest.raises(RuntimeError, match="No active episode"):
|
||||
encoder.finish_episode()
|
||||
encoder.close()
|
||||
|
||||
def test_close_is_idempotent(self, tmp_path):
|
||||
"""Test that close() can be called multiple times safely."""
|
||||
encoder = StreamingVideoEncoder(fps=30)
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
encoder.close()
|
||||
encoder.close() # Should not raise
|
||||
|
||||
def test_video_duration_matches_frame_count(self, tmp_path):
|
||||
"""Test that encoded video duration matches num_frames / fps."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13)
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30,
|
||||
camera_encoder_config=self._make_encoder_config(
|
||||
vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13
|
||||
),
|
||||
)
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
num_frames = 90 # 3 seconds at 30fps
|
||||
@@ -349,13 +417,9 @@ class TestStreamingVideoEncoder:
|
||||
|
||||
def test_multi_camera_start_episode_called_once(self, tmp_path):
|
||||
"""Test that with multiple cameras, no frames are lost due to double start_episode."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30)
|
||||
|
||||
video_keys = [f"{OBS_IMAGES}.cam1", f"{OBS_IMAGES}.cam2"]
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30,
|
||||
camera_encoder_config=self._make_encoder_config(
|
||||
vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30
|
||||
),
|
||||
)
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
num_frames = 30
|
||||
@@ -382,24 +446,17 @@ class TestStreamingVideoEncoder:
|
||||
|
||||
def test_encoder_threads_passed_to_thread(self, tmp_path):
|
||||
"""Test that encoder_threads is stored and passed through to encoder threads."""
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
cfg = VideoEncoderConfig(
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
)
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30,
|
||||
camera_encoder_config=cfg,
|
||||
encoder_threads=2,
|
||||
fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, encoder_threads=2
|
||||
)
|
||||
assert encoder._encoder_threads == 2
|
||||
assert encoder.encoder_threads == 2
|
||||
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
# Verify codec options include thread tuning for libsvtav1 (lp=…)
|
||||
# Verify the thread received the encoder_threads value
|
||||
thread = encoder._threads[f"{OBS_IMAGES}.cam"]
|
||||
assert "svtav1-params" in thread.codec_options or "threads" in thread.codec_options
|
||||
assert thread.encoder_threads == 2
|
||||
|
||||
# Feed some frames and finish to ensure it works end-to-end
|
||||
num_frames = 10
|
||||
@@ -421,20 +478,16 @@ class TestStreamingVideoEncoder:
|
||||
|
||||
def test_encoder_threads_none_by_default(self, tmp_path):
|
||||
"""Test that encoder_threads defaults to None (codec auto-detect)."""
|
||||
encoder = StreamingVideoEncoder(fps=30)
|
||||
assert encoder._encoder_threads is None
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
assert encoder.encoder_threads is None
|
||||
encoder.close()
|
||||
|
||||
def test_graceful_frame_dropping(self, tmp_path):
|
||||
"""Test that full queue drops frames instead of crashing."""
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30,
|
||||
camera_encoder_config=self._make_encoder_config(
|
||||
vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13
|
||||
),
|
||||
queue_maxsize=1,
|
||||
fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13, queue_maxsize=1
|
||||
)
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
# Feed many frames quickly - with queue_maxsize=1, some will be dropped
|
||||
|
||||
@@ -1,569 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for ``lerobot.datasets.video_utils`` encoding functions and ``VideoEncoderConfig`` config class."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("av", reason="av is required (install lerobot[dataset])")
|
||||
|
||||
import av # noqa: E402
|
||||
|
||||
from lerobot.datasets.image_writer import write_image
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pyav_utils import get_codec
|
||||
from lerobot.datasets.utils import INFO_PATH
|
||||
from lerobot.datasets.video_utils import (
|
||||
VALID_VIDEO_CODECS,
|
||||
VideoEncoderConfig,
|
||||
concatenate_video_files,
|
||||
encode_video_frames,
|
||||
get_video_info,
|
||||
)
|
||||
|
||||
|
||||
# Per-codec skip markers — validation tests only fire when the codec is available
|
||||
def _require_encoder(vcodec: str) -> pytest.MarkDecorator:
|
||||
"""Skip the test if ``vcodec`` is not available in the local FFmpeg build."""
|
||||
return pytest.mark.skipif(get_codec(vcodec) is None, reason=f"{vcodec!r} not in local FFmpeg build")
|
||||
|
||||
|
||||
require_libsvtav1 = _require_encoder("libsvtav1")
|
||||
require_h264 = _require_encoder("h264")
|
||||
require_videotoolbox = _require_encoder("h264_videotoolbox")
|
||||
require_nvenc = _require_encoder("h264_nvenc")
|
||||
require_vaapi = _require_encoder("h264_vaapi")
|
||||
require_qsv = _require_encoder("h264_qsv")
|
||||
|
||||
|
||||
# ─── VideoEncoderConfig / codec options ──────────────────────────────
|
||||
|
||||
|
||||
class TestCodecOptions:
|
||||
@require_libsvtav1
|
||||
def test_libsvtav1_defaults(self):
|
||||
cfg = VideoEncoderConfig()
|
||||
opts = cfg.get_codec_options()
|
||||
assert opts["g"] == 2
|
||||
assert opts["crf"] == 30
|
||||
assert opts["preset"] == 12
|
||||
|
||||
@require_libsvtav1
|
||||
def test_libsvtav1_custom_preset(self):
|
||||
cfg = VideoEncoderConfig(preset=8)
|
||||
assert cfg.get_codec_options()["preset"] == 8
|
||||
|
||||
@require_h264
|
||||
def test_h264_options(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264", g=10, crf=23, preset=None)
|
||||
opts = cfg.get_codec_options()
|
||||
assert opts["g"] == 10
|
||||
assert opts["crf"] == 23
|
||||
assert "preset" not in opts
|
||||
|
||||
@require_videotoolbox
|
||||
def test_videotoolbox_options(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264_videotoolbox", g=2, crf=30, preset=None)
|
||||
opts = cfg.get_codec_options()
|
||||
assert opts["g"] == 2
|
||||
assert opts["q:v"] == 40
|
||||
assert "crf" not in opts
|
||||
|
||||
@_require_encoder("h264_nvenc")
|
||||
def test_nvenc_options(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264_nvenc", g=2, crf=25, preset=None)
|
||||
opts = cfg.get_codec_options()
|
||||
assert opts["rc"] == "constqp"
|
||||
assert opts["qp"] == 25
|
||||
assert "crf" not in opts
|
||||
assert "g" not in opts
|
||||
|
||||
@_require_encoder("h264_vaapi")
|
||||
def test_vaapi_options(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264_vaapi", crf=28, preset=None)
|
||||
assert cfg.get_codec_options()["qp"] == 28
|
||||
|
||||
@_require_encoder("h264_qsv")
|
||||
def test_qsv_options(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264_qsv", crf=25, preset=None)
|
||||
assert cfg.get_codec_options()["global_quality"] == 25
|
||||
|
||||
@require_h264
|
||||
def test_no_g_no_crf(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264", g=None, crf=None, preset=None)
|
||||
opts = cfg.get_codec_options()
|
||||
assert "g" not in opts
|
||||
assert "crf" not in opts
|
||||
|
||||
@require_libsvtav1
|
||||
def test_encoder_threads_libsvtav1(self):
|
||||
cfg = VideoEncoderConfig(fast_decode=0)
|
||||
opts = cfg.get_codec_options(encoder_threads=4)
|
||||
assert "lp=4" in opts.get("svtav1-params", "")
|
||||
|
||||
@require_h264
|
||||
def test_encoder_threads_h264(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264", preset=None)
|
||||
assert cfg.get_codec_options(encoder_threads=2)["threads"] == 2
|
||||
|
||||
@require_libsvtav1
|
||||
def test_fast_decode_libsvtav1(self):
|
||||
cfg = VideoEncoderConfig(fast_decode=1)
|
||||
opts = cfg.get_codec_options()
|
||||
assert "fast-decode=1" in opts.get("svtav1-params", "")
|
||||
|
||||
@require_libsvtav1
|
||||
def test_libsvtav1_fast_decode_clamped_to_svt_range(self):
|
||||
"""Out-of-range fast_decode is clamped to [0, 2] in svtav1-params (SVT-AV1 FastDecode)."""
|
||||
cfg = VideoEncoderConfig(fast_decode=100)
|
||||
assert "fast-decode=2" in cfg.get_codec_options().get("svtav1-params", "")
|
||||
cfg_neg = VideoEncoderConfig(fast_decode=-5)
|
||||
assert "fast-decode=0" in cfg_neg.get_codec_options().get("svtav1-params", "")
|
||||
|
||||
@require_h264
|
||||
def test_fast_decode_h264(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264", fast_decode=1, preset=None)
|
||||
assert cfg.get_codec_options()["tune"] == "fastdecode"
|
||||
|
||||
@require_libsvtav1
|
||||
def test_pix_fmt_unsupported_raises(self):
|
||||
"""Passing an unsupported pix_fmt is a hard error."""
|
||||
with pytest.raises(ValueError, match="pix_fmt"):
|
||||
VideoEncoderConfig(pix_fmt="yuv444p") # libsvtav1 only supports yuv420p variants
|
||||
|
||||
@require_libsvtav1
|
||||
@require_h264
|
||||
def test_preset_default_behaviour(self):
|
||||
"""Empty constructor picks preset=12 (libsvtav1 path); other codecs stay None."""
|
||||
assert VideoEncoderConfig().preset == 12
|
||||
assert VideoEncoderConfig(vcodec="libsvtav1").preset == 12
|
||||
assert VideoEncoderConfig(vcodec="h264").preset is None
|
||||
assert VideoEncoderConfig(vcodec="h264", preset=None).preset is None
|
||||
|
||||
@require_h264
|
||||
def test_preset_string_on_h264(self):
|
||||
"""h264 accepts string presets and forwards them to FFmpeg."""
|
||||
cfg = VideoEncoderConfig(vcodec="h264", preset="slow")
|
||||
assert cfg.get_codec_options()["preset"] == "slow"
|
||||
|
||||
@require_videotoolbox
|
||||
def test_preset_on_videotoolbox_not_set(self):
|
||||
"""videotoolbox has no preset option at all."""
|
||||
cfg = VideoEncoderConfig(vcodec="h264_videotoolbox", preset="slow")
|
||||
assert "preset" not in cfg.get_codec_options()
|
||||
|
||||
@require_libsvtav1
|
||||
def test_libsvtav1_preset_out_of_range_raises(self):
|
||||
"""libsvtav1 preset must sit in [-2, 13] as exposed by PyAV."""
|
||||
with pytest.raises(ValueError, match="out of range"):
|
||||
VideoEncoderConfig(vcodec="libsvtav1", preset=100)
|
||||
with pytest.raises(ValueError, match="out of range"):
|
||||
VideoEncoderConfig(vcodec="libsvtav1", preset=-3)
|
||||
|
||||
@require_libsvtav1
|
||||
def test_libsvtav1_crf_out_of_range_raises(self):
|
||||
"""libsvtav1 crf must sit in [0, 63]."""
|
||||
with pytest.raises(ValueError, match="crf.*out of range"):
|
||||
VideoEncoderConfig(vcodec="libsvtav1", crf=64)
|
||||
|
||||
@require_libsvtav1
|
||||
def test_libsvtav1_crf_rejects_python_float(self):
|
||||
"""libsvtav1 exposes ``crf`` as an INT AVOption; Python float must not pass validation."""
|
||||
with pytest.raises(ValueError, match="float values are not allowed"):
|
||||
VideoEncoderConfig(vcodec="libsvtav1", crf=2.5)
|
||||
|
||||
@require_libsvtav1
|
||||
def test_libsvtav1_extra_crf_rejects_fractional_string(self):
|
||||
"""INT options reject fractional values even when supplied only via ``extra_options``."""
|
||||
with pytest.raises(ValueError, match="float values are not allowed"):
|
||||
VideoEncoderConfig(
|
||||
vcodec="libsvtav1",
|
||||
crf=None,
|
||||
extra_options={"crf": "2.5"},
|
||||
)
|
||||
|
||||
@require_libsvtav1
|
||||
def test_libsvtav1_extra_crf_rejects_float(self):
|
||||
with pytest.raises(ValueError, match="float values are not allowed"):
|
||||
VideoEncoderConfig(
|
||||
vcodec="libsvtav1",
|
||||
crf=None,
|
||||
extra_options={"crf": 2.5},
|
||||
)
|
||||
|
||||
@require_h264
|
||||
def test_h264_crf_accepts_float_and_int(self):
|
||||
"""x264 exposes crf as a FLOAT option, so both int and float are accepted."""
|
||||
assert VideoEncoderConfig(vcodec="h264", crf=23).get_codec_options()["crf"] == 23
|
||||
assert VideoEncoderConfig(vcodec="h264", crf=23.5).get_codec_options()["crf"] == 23.5
|
||||
|
||||
@require_libsvtav1
|
||||
def test_validate_is_rerunnable(self):
|
||||
"""After mutating a field, validate() re-checks and surfaces new issues."""
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1")
|
||||
cfg.preset = 100 # now out of range
|
||||
with pytest.raises(ValueError, match="out of range"):
|
||||
cfg.validate()
|
||||
|
||||
|
||||
class TestExtraOptions:
|
||||
@require_libsvtav1
|
||||
def test_default_is_empty_dict(self):
|
||||
cfg = VideoEncoderConfig()
|
||||
assert cfg.extra_options == {}
|
||||
|
||||
@require_libsvtav1
|
||||
def test_unknown_key_passes_through(self):
|
||||
"""Keys not published as AVOptions are forwarded to FFmpeg."""
|
||||
cfg = VideoEncoderConfig(extra_options={"totally_made_up_option": "value"})
|
||||
assert cfg.extra_options == {"totally_made_up_option": "value"}
|
||||
|
||||
@require_libsvtav1
|
||||
def test_numeric_value_in_range_ok(self):
|
||||
"""libsvtav1 exposes ``qp`` as INT in [0, 63]."""
|
||||
cfg = VideoEncoderConfig(extra_options={"qp": 30})
|
||||
assert cfg.extra_options == {"qp": 30}
|
||||
|
||||
@require_libsvtav1
|
||||
def test_numeric_out_of_range_raises(self):
|
||||
with pytest.raises(ValueError, match=r"extra_options\['qp'\].*out of range"):
|
||||
VideoEncoderConfig(extra_options={"qp": 999})
|
||||
|
||||
@require_libsvtav1
|
||||
def test_numeric_string_accepted_in_range(self):
|
||||
"""Numeric strings are accepted for numeric options (mirrors FFmpeg)."""
|
||||
cfg = VideoEncoderConfig(extra_options={"qp": "18"})
|
||||
assert cfg.extra_options == {"qp": "18"}
|
||||
|
||||
@require_libsvtav1
|
||||
def test_numeric_string_out_of_range_raises(self):
|
||||
with pytest.raises(ValueError, match=r"extra_options\['qp'\].*out of range"):
|
||||
VideoEncoderConfig(extra_options={"qp": "999"})
|
||||
|
||||
@require_libsvtav1
|
||||
def test_non_numeric_string_on_numeric_option_raises(self):
|
||||
with pytest.raises(ValueError, match=r"extra_options\['qp'\].*not numeric"):
|
||||
VideoEncoderConfig(extra_options={"qp": "medium"})
|
||||
|
||||
@require_libsvtav1
|
||||
def test_bool_on_numeric_option_raises(self):
|
||||
"""``bool`` is explicitly rejected for numeric options."""
|
||||
with pytest.raises(ValueError, match=r"extra_options\['qp'\].*not numeric"):
|
||||
VideoEncoderConfig(extra_options={"qp": True})
|
||||
|
||||
@require_h264
|
||||
def test_string_option_passes_through_unchecked(self):
|
||||
"""String-typed AVOptions are NOT enum-checked (too many accept freeform)."""
|
||||
cfg = VideoEncoderConfig(vcodec="h264", preset=None, extra_options={"tune": "some-future-tune"})
|
||||
assert cfg.extra_options == {"tune": "some-future-tune"}
|
||||
|
||||
@require_libsvtav1
|
||||
def test_merged_into_codec_options_and_stringified(self):
|
||||
"""Typed merge by default; ``as_strings=True`` matches FFmpeg option dict."""
|
||||
cfg = VideoEncoderConfig(extra_options={"qp": 20})
|
||||
opts = cfg.get_codec_options()
|
||||
assert opts["qp"] == 20
|
||||
assert isinstance(opts["qp"], int)
|
||||
assert cfg.get_codec_options(as_strings=True)["qp"] == "20"
|
||||
|
||||
@require_libsvtav1
|
||||
def test_structured_fields_win_on_collision(self):
|
||||
"""A colliding extra_options key is discarded; the structured field wins."""
|
||||
cfg = VideoEncoderConfig(crf=30, extra_options={"crf": 18})
|
||||
assert cfg.get_codec_options()["crf"] == 30
|
||||
|
||||
|
||||
class TestEncoderDetection:
|
||||
@require_h264
|
||||
def test_explicit_codec_kept_when_available(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264")
|
||||
assert cfg.vcodec == "h264"
|
||||
|
||||
@require_videotoolbox
|
||||
def test_auto_picks_videotoolbox_when_available(self):
|
||||
"""``h264_videotoolbox`` sits at the top of ``HW_ENCODERS`` so it wins when present."""
|
||||
cfg = VideoEncoderConfig(vcodec="auto")
|
||||
assert cfg.vcodec == "h264_videotoolbox"
|
||||
|
||||
def test_invalid_codec_raises(self):
|
||||
with pytest.raises(ValueError, match="Invalid vcodec"):
|
||||
VideoEncoderConfig(vcodec="not_a_real_codec")
|
||||
|
||||
def test_hw_encoder_names_listed_as_valid(self):
|
||||
assert "auto" in VALID_VIDEO_CODECS
|
||||
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "h264_nvenc" in VALID_VIDEO_CODECS
|
||||
|
||||
|
||||
ARTIFACTS = Path(__file__).parent.parent / "fixtures" / "artifacts" / "videos"
|
||||
|
||||
# Default video feature set used by persistence tests.
|
||||
VIDEO_FEATURES = {
|
||||
"observation.images.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]},
|
||||
}
|
||||
VIDEO_KEY = "observation.images.cam"
|
||||
|
||||
|
||||
def _write_frames(imgs_dir: Path, num_frames: int = 4, height: int = 64, width: int = 96) -> None:
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i in range(num_frames):
|
||||
arr = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
||||
write_image(arr, imgs_dir / f"frame-{i:06d}.png")
|
||||
|
||||
|
||||
def _encode_video(
|
||||
path: Path, num_frames: int = 4, fps: int = 30, cfg: VideoEncoderConfig | None = None
|
||||
) -> Path:
|
||||
imgs_dir = path.parent / f"imgs_{path.stem}"
|
||||
_write_frames(imgs_dir, num_frames=num_frames)
|
||||
encode_video_frames(imgs_dir, path, fps=fps, camera_encoder_config=cfg, overwrite=True)
|
||||
return path
|
||||
|
||||
|
||||
def _read_feature_info(dataset: LeRobotDataset) -> dict:
|
||||
info = json.loads((dataset.root / INFO_PATH).read_text())
|
||||
return info["features"][VIDEO_KEY]["info"]
|
||||
|
||||
|
||||
def _add_frames(dataset: LeRobotDataset, num_frames: int) -> None:
|
||||
shape = dataset.meta.features[VIDEO_KEY]["shape"]
|
||||
for _ in range(num_frames):
|
||||
dataset.add_frame(
|
||||
{
|
||||
VIDEO_KEY: np.random.randint(0, 256, shape, dtype=np.uint8),
|
||||
"action": np.zeros(2, dtype=np.float32),
|
||||
"task": "test",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestGetVideoInfo:
|
||||
def test_returns_all_stream_fields(self):
|
||||
info = get_video_info(ARTIFACTS / "clip_4frames.mp4")
|
||||
|
||||
assert info["video.height"] == 64
|
||||
assert info["video.width"] == 96
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
assert info["video.fps"] == 30
|
||||
assert info["video.channels"] == 3
|
||||
assert info["video.is_depth_map"] is False
|
||||
assert info["has_audio"] is False
|
||||
assert "video.g" not in info
|
||||
assert "video.crf" not in info
|
||||
assert "video.preset" not in info
|
||||
|
||||
@require_libsvtav1
|
||||
def test_merges_encoder_config_as_video_prefixed_entries(self):
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
|
||||
|
||||
info = get_video_info(ARTIFACTS / "clip_4frames.mp4", camera_encoder_config=cfg)
|
||||
|
||||
assert info["video.g"] == 2
|
||||
assert info["video.crf"] == 30
|
||||
assert info["video.preset"] == 12
|
||||
assert info["video.fast_decode"] == 0
|
||||
assert info["video.video_backend"] == "pyav"
|
||||
assert info["video.extra_options"] == {}
|
||||
|
||||
@require_libsvtav1
|
||||
def test_stream_derived_keys_take_precedence_over_config(self):
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
|
||||
info = get_video_info(ARTIFACTS / "clip_4frames.mp4", camera_encoder_config=cfg)
|
||||
|
||||
assert info["video.codec"] # populated from stream, not from config's vcodec
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
|
||||
|
||||
class TestEncodeVideoFrames:
|
||||
@require_libsvtav1
|
||||
def test_produces_readable_mp4(self, tmp_path):
|
||||
video_path = _encode_video(tmp_path / "out.mp4")
|
||||
|
||||
assert video_path.exists()
|
||||
info = get_video_info(video_path)
|
||||
assert info["video.height"] == 64
|
||||
assert info["video.width"] == 96
|
||||
|
||||
@require_libsvtav1
|
||||
def test_frame_count_and_duration_match_input(self, tmp_path):
|
||||
num_frames = 10
|
||||
fps = 30
|
||||
video_path = _encode_video(tmp_path / "out.mp4", num_frames=num_frames, fps=fps)
|
||||
|
||||
with av.open(str(video_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
actual_frames = sum(1 for _ in container.decode(stream))
|
||||
duration = (
|
||||
float(stream.duration * stream.time_base)
|
||||
if stream.duration is not None
|
||||
else float(container.duration / av.time_base)
|
||||
)
|
||||
|
||||
assert actual_frames == num_frames
|
||||
assert abs(duration - num_frames / fps) < 0.1
|
||||
|
||||
def test_overwrite_false_skips_existing_file(self, tmp_path):
|
||||
imgs_dir = tmp_path / "imgs"
|
||||
_write_frames(imgs_dir)
|
||||
video_path = tmp_path / "out.mp4"
|
||||
sentinel = b"pre-existing content"
|
||||
video_path.write_bytes(sentinel)
|
||||
|
||||
encode_video_frames(imgs_dir, video_path, fps=30, overwrite=False)
|
||||
|
||||
assert video_path.read_bytes() == sentinel
|
||||
|
||||
@require_libsvtav1
|
||||
def test_overwrite_true_replaces_existing_file(self, tmp_path):
|
||||
imgs_dir = tmp_path / "imgs"
|
||||
_write_frames(imgs_dir)
|
||||
video_path = tmp_path / "out.mp4"
|
||||
video_path.write_bytes(b"stale content")
|
||||
|
||||
encode_video_frames(imgs_dir, video_path, fps=30, overwrite=True)
|
||||
|
||||
info = get_video_info(video_path)
|
||||
assert info["video.height"] == 64
|
||||
|
||||
@require_libsvtav1
|
||||
def test_custom_encoder_config_fields_stored_in_info(self, tmp_path):
|
||||
"""All stream-derived and encoder config fields are present after encoding."""
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=4, crf=25, preset=10)
|
||||
video_path = _encode_video(tmp_path / "out.mp4", num_frames=4, fps=30, cfg=cfg)
|
||||
|
||||
info = get_video_info(video_path, camera_encoder_config=cfg)
|
||||
|
||||
# Stream-derived
|
||||
assert info["video.height"] == 64
|
||||
assert info["video.width"] == 96
|
||||
assert info["video.channels"] == 3
|
||||
assert info["video.codec"] == "av1"
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
assert info["video.fps"] == 30
|
||||
assert info["video.is_depth_map"] is False
|
||||
assert info["has_audio"] is False
|
||||
# Encoder config
|
||||
assert info["video.g"] == 4
|
||||
assert info["video.crf"] == 25
|
||||
assert info["video.preset"] == 10
|
||||
assert info["video.fast_decode"] == 0
|
||||
assert info["video.video_backend"] == "pyav"
|
||||
assert info["video.extra_options"] == {}
|
||||
|
||||
|
||||
class TestConcatenateVideoFiles:
|
||||
def test_two_clips_frame_count(self, tmp_path):
|
||||
"""Output frame count equals the sum of the two input frame counts."""
|
||||
out = tmp_path / "out.mp4"
|
||||
concatenate_video_files([ARTIFACTS / "clip_6frames.mp4", ARTIFACTS / "clip_4frames.mp4"], out)
|
||||
|
||||
with av.open(str(out)) as container:
|
||||
total = sum(1 for _ in container.decode(video=0))
|
||||
assert total == 10
|
||||
|
||||
def test_three_clips_frame_count(self, tmp_path):
|
||||
out = tmp_path / "out.mp4"
|
||||
clip = ARTIFACTS / "clip_5frames.mp4"
|
||||
concatenate_video_files([clip, clip, clip], out)
|
||||
|
||||
with av.open(str(out)) as container:
|
||||
total = sum(1 for _ in container.decode(video=0))
|
||||
assert total == 15
|
||||
|
||||
@require_libsvtav1
|
||||
def test_geometry_preserved(self, tmp_path):
|
||||
"""Output resolution, fps, codec and pixel format must match the inputs."""
|
||||
out = tmp_path / "out.mp4"
|
||||
concatenate_video_files([ARTIFACTS / "clip_4frames.mp4", ARTIFACTS / "clip_4frames.mp4"], out)
|
||||
|
||||
info = get_video_info(out)
|
||||
assert info["video.height"] == 64
|
||||
assert info["video.width"] == 96
|
||||
assert info["video.fps"] == 30
|
||||
assert info["video.codec"] == "av1"
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
|
||||
def test_compatibility_check_raises_on_different_codec(self, tmp_path):
|
||||
with pytest.raises(ValueError):
|
||||
concatenate_video_files(
|
||||
[ARTIFACTS / "clip_4frames.mp4", ARTIFACTS / "clip_h264.mp4"],
|
||||
tmp_path / "out.mp4",
|
||||
compatibility_check=True,
|
||||
)
|
||||
|
||||
def test_compatibility_check_raises_on_different_resolution(self, tmp_path):
|
||||
with pytest.raises(ValueError):
|
||||
concatenate_video_files(
|
||||
[ARTIFACTS / "clip_4frames.mp4", ARTIFACTS / "clip_32x48.mp4"],
|
||||
tmp_path / "out.mp4",
|
||||
compatibility_check=True,
|
||||
)
|
||||
|
||||
|
||||
class TestEncoderConfigPersistence:
|
||||
"""Encoder config must be stored as ``video.<field>`` entries in
|
||||
``info["features"][key]["info"]`` when the first episode is saved.
|
||||
"""
|
||||
|
||||
@require_libsvtav1
|
||||
def test_first_episode_save_persists_encoder_config(self, tmp_path, empty_lerobot_dataset_factory):
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", features=VIDEO_FEATURES, use_videos=True, camera_encoder_config=cfg
|
||||
)
|
||||
|
||||
_add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
info = _read_feature_info(dataset)
|
||||
|
||||
assert info["video.height"] == 64
|
||||
assert info["video.width"] == 96
|
||||
assert info["video.fps"] == 30
|
||||
assert info["video.g"] == 2
|
||||
assert info["video.crf"] == 30
|
||||
assert info["video.preset"] == 12
|
||||
assert info["video.fast_decode"] == 0
|
||||
assert info["video.video_backend"] == "pyav"
|
||||
assert info["video.extra_options"] == {}
|
||||
|
||||
@require_libsvtav1
|
||||
def test_second_episode_does_not_overwrite_encoder_fields(self, tmp_path, empty_lerobot_dataset_factory):
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", features=VIDEO_FEATURES, use_videos=True, camera_encoder_config=cfg
|
||||
)
|
||||
|
||||
_add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
first_info = dict(_read_feature_info(dataset))
|
||||
|
||||
_add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert _read_feature_info(dataset) == first_info
|
||||
BIN
tests/fixtures/artifacts/videos/clip_32x48.mp4
LFS
vendored
BIN
tests/fixtures/artifacts/videos/clip_32x48.mp4
LFS
vendored
Binary file not shown.
BIN
tests/fixtures/artifacts/videos/clip_4frames.mp4
LFS
vendored
BIN
tests/fixtures/artifacts/videos/clip_4frames.mp4
LFS
vendored
Binary file not shown.
BIN
tests/fixtures/artifacts/videos/clip_5frames.mp4
LFS
vendored
BIN
tests/fixtures/artifacts/videos/clip_5frames.mp4
LFS
vendored
Binary file not shown.
BIN
tests/fixtures/artifacts/videos/clip_6frames.mp4
LFS
vendored
BIN
tests/fixtures/artifacts/videos/clip_6frames.mp4
LFS
vendored
Binary file not shown.
BIN
tests/fixtures/artifacts/videos/clip_h264.mp4
LFS
vendored
BIN
tests/fixtures/artifacts/videos/clip_h264.mp4
LFS
vendored
Binary file not shown.
186
tests/policies/eo1/test_eo1.py
Normal file
186
tests/policies/eo1/test_eo1.py
Normal file
@@ -0,0 +1,186 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Smoke tests for EO1's public LeRobot policy interface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.eo1.modeling_eo1 import EO1Policy
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
HIDDEN_SIZE = 8
|
||||
STATE_DIM = 4
|
||||
ACTION_DIM = 3
|
||||
CHUNK_SIZE = 3
|
||||
N_ACTION_STEPS = 2
|
||||
MAX_ACTION_DIM = 6
|
||||
STATE_TOKEN_ID = 5
|
||||
ACTION_TOKEN_ID = 6
|
||||
|
||||
|
||||
class DummyVLMBackbone(nn.Module):
|
||||
def __init__(self, hidden_size: int, vocab_size: int = 64):
|
||||
super().__init__()
|
||||
self.embedding = nn.Embedding(vocab_size, hidden_size)
|
||||
self.config = SimpleNamespace(text_config=SimpleNamespace(hidden_size=hidden_size))
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embedding
|
||||
|
||||
def get_rope_index(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
image_grid_thw: torch.Tensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
mm_token_type_ids: torch.Tensor | None = None,
|
||||
):
|
||||
batch_size, seq_len = input_ids.shape
|
||||
if attention_mask is None:
|
||||
text_positions = torch.arange(seq_len, device=input_ids.device).expand(batch_size, -1)
|
||||
else:
|
||||
text_positions = attention_mask.long().cumsum(-1) - 1
|
||||
text_positions = text_positions.masked_fill(attention_mask == 0, 0)
|
||||
position_ids = text_positions.view(1, batch_size, seq_len).expand(3, batch_size, seq_len)
|
||||
rope_deltas = torch.zeros(batch_size, 1, dtype=torch.long, device=input_ids.device)
|
||||
return position_ids, rope_deltas
|
||||
|
||||
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
|
||||
return gradient_checkpointing_kwargs
|
||||
|
||||
def gradient_checkpointing_disable(self):
|
||||
return None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embedding(input_ids)
|
||||
return SimpleNamespace(
|
||||
last_hidden_state=inputs_embeds,
|
||||
past_key_values=SimpleNamespace(crop=lambda prefix_len: None),
|
||||
)
|
||||
|
||||
|
||||
def make_eo1_config():
|
||||
from lerobot.policies.eo1.configuration_eo1 import EO1Config
|
||||
|
||||
return EO1Config(
|
||||
device="cpu",
|
||||
dtype="float32",
|
||||
vlm_base="dummy-qwen",
|
||||
vlm_config={},
|
||||
chunk_size=CHUNK_SIZE,
|
||||
n_action_steps=N_ACTION_STEPS,
|
||||
max_state_dim=STATE_DIM,
|
||||
max_action_dim=MAX_ACTION_DIM,
|
||||
num_denoise_steps=2,
|
||||
input_features={
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(STATE_DIM,)),
|
||||
"observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
|
||||
},
|
||||
output_features={
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,)),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def make_policy_batch(include_action: bool) -> dict[str, torch.Tensor | int]:
|
||||
batch_size = 1
|
||||
seq_len = CHUNK_SIZE + 4
|
||||
input_ids = torch.tensor(
|
||||
[[11, STATE_TOKEN_ID, 12, ACTION_TOKEN_ID, ACTION_TOKEN_ID, ACTION_TOKEN_ID, 13]],
|
||||
dtype=torch.long,
|
||||
)
|
||||
assert input_ids.shape == (batch_size, seq_len)
|
||||
|
||||
batch: dict[str, torch.Tensor | int] = {
|
||||
OBS_STATE: torch.randn(batch_size, STATE_DIM, dtype=torch.float32),
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
|
||||
"pixel_values": torch.zeros(batch_size, 3, 4, 4, dtype=torch.float32),
|
||||
"image_grid_thw": torch.tensor([[1, 2, 2]], dtype=torch.long),
|
||||
"mm_token_type_ids": torch.zeros(batch_size, seq_len, dtype=torch.int32),
|
||||
"state_token_id": STATE_TOKEN_ID,
|
||||
"action_token_id": ACTION_TOKEN_ID,
|
||||
}
|
||||
if include_action:
|
||||
batch[ACTION] = torch.randn(batch_size, CHUNK_SIZE, ACTION_DIM, dtype=torch.float32)
|
||||
return batch
|
||||
|
||||
|
||||
def test_lerobot_eo1_forward_pass(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"lerobot.policies.eo1.modeling_eo1.Qwen2_5_VLForConditionalGeneration.from_pretrained",
|
||||
lambda *args, **kwargs: DummyVLMBackbone(HIDDEN_SIZE),
|
||||
)
|
||||
policy = EO1Policy(make_eo1_config())
|
||||
|
||||
loss, metrics = policy.forward(make_policy_batch(include_action=True))
|
||||
|
||||
assert loss.ndim == 0
|
||||
assert torch.isfinite(loss)
|
||||
assert metrics["loss"] == pytest.approx(loss.item())
|
||||
|
||||
|
||||
def test_lerobot_eo1_inference(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"lerobot.policies.eo1.modeling_eo1.Qwen2_5_VLForConditionalGeneration.from_pretrained",
|
||||
lambda *args, **kwargs: DummyVLMBackbone(HIDDEN_SIZE),
|
||||
)
|
||||
policy = EO1Policy(make_eo1_config())
|
||||
|
||||
sample_calls = {"count": 0}
|
||||
fixed_chunk = torch.tensor(
|
||||
[
|
||||
[
|
||||
[0.1, 0.2, 0.3, 9.0, 9.0, 9.0],
|
||||
[1.1, 1.2, 1.3, 9.0, 9.0, 9.0],
|
||||
[2.1, 2.2, 2.3, 9.0, 9.0, 9.0],
|
||||
]
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
def fake_sample_actions(**kwargs):
|
||||
sample_calls["count"] += 1
|
||||
return fixed_chunk
|
||||
|
||||
monkeypatch.setattr(policy.model, "sample_actions", fake_sample_actions)
|
||||
|
||||
batch = make_policy_batch(include_action=False)
|
||||
action_0 = policy.select_action(batch)
|
||||
action_1 = policy.select_action(batch)
|
||||
|
||||
torch.testing.assert_close(action_0, fixed_chunk[:, 0, :ACTION_DIM])
|
||||
torch.testing.assert_close(action_1, fixed_chunk[:, 1, :ACTION_DIM])
|
||||
assert sample_calls["count"] == 1
|
||||
158
uv.lock
generated
158
uv.lock
generated
@@ -1,5 +1,5 @@
|
||||
version = 1
|
||||
revision = 2
|
||||
revision = 3
|
||||
requires-python = ">=3.12"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.15' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'",
|
||||
@@ -972,10 +972,10 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "cuda-pathfinder"
|
||||
version = "1.5.3"
|
||||
version = "1.5.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d3/d6/ac63065d33dd700fee7ebd7d287332401b54e31b9346e142f871e1f0b116/cuda_pathfinder-1.5.3-py3-none-any.whl", hash = "sha256:dff021123aedbb4117cc7ec81717bbfe198fb4e8b5f1ee57e0e084fec5c8577d", size = 49991, upload-time = "2026-04-14T20:09:27.037Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/11/d0/c177e29701cf1d3008d7d2b16b5fc626592ce13bd535f8795c5f57187e0e/cuda_pathfinder-1.5.4-py3-none-any.whl", hash = "sha256:9563d3175ce1828531acf4b94e1c1c7d67208c347ca002493e2654878b26f4b7", size = 51657, upload-time = "2026-04-27T22:42:07.712Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -989,7 +989,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "datasets"
|
||||
version = "4.8.4"
|
||||
version = "4.8.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "dill" },
|
||||
@@ -1007,9 +1007,9 @@ dependencies = [
|
||||
{ name = "tqdm" },
|
||||
{ name = "xxhash" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/22/22/73e46ac7a8c25e7ef0b3bd6f10da3465021d90219a32eb0b4d2afea4c56e/datasets-4.8.4.tar.gz", hash = "sha256:a1429ed853275ce7943a01c6d2e25475b4501eb758934362106a280470df3a52", size = 604382, upload-time = "2026-03-23T14:21:17.987Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/66/34/14cd8e76f907f7d4dca2334cfeec9f81d30fd15c25a015f99aaea694eaed/datasets-4.8.5.tar.gz", hash = "sha256:0f0c1c3d56ffff2c93b2f4c63c95bac94f3d7e8621aea2a2a576275233bba772", size = 605649, upload-time = "2026-04-27T15:43:57.384Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b0/e5/247d094108e42ac26363ab8dc57f168840cf7c05774b40ffeb0d78868fcc/datasets-4.8.4-py3-none-any.whl", hash = "sha256:cdc8bee4698e549d78bf1fed6aea2eebc760b22b084f07e6fc020c6577a6ce6d", size = 526991, upload-time = "2026-03-23T14:21:15.89Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/65/99/00f3196036501b53032c4b1ab8337a0b978dee832ed276dae3815df4e8b5/datasets-4.8.5-py3-none-any.whl", hash = "sha256:5079900781719c0e063a8efdd2cd95a31ad0c63209178669cd23cf1b926149ff", size = 528973, upload-time = "2026-04-27T15:43:53.702Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1563,14 +1563,14 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "gitpython"
|
||||
version = "3.1.47"
|
||||
version = "3.1.49"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "gitdb" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c1/bd/50db468e9b1310529a19fce651b3b0e753b5c07954d486cba31bbee9a5d5/gitpython-3.1.47.tar.gz", hash = "sha256:dba27f922bd2b42cb54c87a8ab3cb6beb6bf07f3d564e21ac848913a05a8a3cd", size = 216978, upload-time = "2026-04-22T02:44:44.059Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e1/63/210aaa302d6a0a78daa67c5c15bbac2cad361722841278b0209b6da20855/gitpython-3.1.49.tar.gz", hash = "sha256:42f9399c9eb33fc581014bedd76049dfbaf6375aa2a5754575966387280315e1", size = 219367, upload-time = "2026-04-29T00:31:20.478Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f2/c5/a1bc0996af85757903cf2bf444a7824e68e0035ce63fb41d6f76f9def68b/gitpython-3.1.47-py3-none-any.whl", hash = "sha256:489f590edfd6d20571b2c0e72c6a6ac6915ee8b8cd04572330e3842207a78905", size = 209547, upload-time = "2026-04-22T02:44:41.271Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/6f/b842bfa6f21d6f87c57f9abf7194225e55279d96d869775e19e9f7236fc5/gitpython-3.1.49-py3-none-any.whl", hash = "sha256:024b0422d7f84d15cd794844e029ffebd4c5d42a7eb9b936b458697ef550a02c", size = 212190, upload-time = "2026-04-29T00:31:18.412Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1930,7 +1930,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "1.12.0"
|
||||
version = "1.13.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "filelock" },
|
||||
@@ -1943,9 +1943,9 @@ dependencies = [
|
||||
{ name = "typer" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/56/52/1b54cb569509c725a32c1315261ac9fd0e6b91bbbf74d86fca10d3376164/huggingface_hub-1.12.0.tar.gz", hash = "sha256:7c3fe85e24b652334e5d456d7a812cd9a071e75630fac4365d9165ab5e4a34b6", size = 763091, upload-time = "2026-04-24T13:32:08.674Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/89/ff/ec7ed2eb43bd7ce8bb2233d109cc235c3e807ffe5e469dc09db261fac05e/huggingface_hub-1.13.0.tar.gz", hash = "sha256:f6df2dac5abe82ce2fe05873d10d5ff47bc677d616a2f521f4ee26db9415d9d0", size = 781788, upload-time = "2026-04-30T11:57:33.858Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7e/2b/ef03ddb96bd1123503c2bd6932001020292deea649e9bf4caa2cb65a85bf/huggingface_hub-1.12.0-py3-none-any.whl", hash = "sha256:d74939969585ee35748bd66de09baf84099d461bda7287cd9043bfb99b0e424d", size = 646806, upload-time = "2026-04-24T13:32:06.717Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/93/db/4b1cdae9460ae1f3ca020cd767f013430ce23eb1d9c890ae3a0609b38d26/huggingface_hub-1.13.0-py3-none-any.whl", hash = "sha256:e942cb50d6a08dd5306688b1ac05bda157fd2fcc88b63dae405f7bd0d3234005", size = 660643, upload-time = "2026-04-30T11:57:31.802Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2131,14 +2131,14 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "jedi"
|
||||
version = "0.19.2"
|
||||
version = "0.20.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "parso" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287, upload-time = "2024-11-11T01:41:42.873Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/46/b7/a3635f6a2d7cf5b5dd98064fc1d5fbbafcb25477bcea204a3a92145d158b/jedi-0.20.0.tar.gz", hash = "sha256:c3f4ccbd276696f4b19c54618d4fb18f9fc24b0aef02acf704b23f487daa1011", size = 3119416, upload-time = "2026-05-01T23:38:47.814Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278, upload-time = "2024-11-11T01:41:40.175Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9a/93/242e2eab5fe682ffcb8b0084bde703a41d51e17ee0f3a31ff0d9d813620a/jedi-0.20.0-py2.py3-none-any.whl", hash = "sha256:7bdd9c2634f56713299976f4cbd59cb3fa92165cc5e05ea811fb253480728b67", size = 4884812, upload-time = "2026-05-01T23:38:43.919Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2321,7 +2321,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "jupyter-server"
|
||||
version = "2.17.0"
|
||||
version = "2.18.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
@@ -2343,9 +2343,9 @@ dependencies = [
|
||||
{ name = "traitlets" },
|
||||
{ name = "websocket-client" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/5b/ac/e040ec363d7b6b1f11304cc9f209dac4517ece5d5e01821366b924a64a50/jupyter_server-2.17.0.tar.gz", hash = "sha256:c38ea898566964c888b4772ae1ed58eca84592e88251d2cfc4d171f81f7e99d5", size = 731949, upload-time = "2025-08-21T14:42:54.042Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/33/b0/666586d557a71a58cd9960b154fb9aee0ed81dd62a50371195ab95731909/jupyter_server-2.18.1.tar.gz", hash = "sha256:f62be526369b791625e03bd658070563c1a4e9a0a2f439ea1f9dbacea5f7191a", size = 752024, upload-time = "2026-05-05T09:17:51.101Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/92/80/a24767e6ca280f5a49525d987bf3e4d7552bf67c8be07e8ccf20271f8568/jupyter_server-2.17.0-py3-none-any.whl", hash = "sha256:e8cb9c7db4251f51ed307e329b81b72ccf2056ff82d50524debde1ee1870e13f", size = 388221, upload-time = "2025-08-21T14:42:52.034Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a4/45/bfe3779fd06714a379128f2c4eaf7c99414f0eb081f9f34c135f6b3d511c/jupyter_server-2.18.1-py3-none-any.whl", hash = "sha256:db0374d52a975f88a92a7f20de44e08ef5be9763ba7e99630baf16c46ac8dbf0", size = 391844, upload-time = "2026-05-05T09:17:48.521Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2363,7 +2363,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "jupyterlab"
|
||||
version = "4.5.6"
|
||||
version = "4.5.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "async-lru" },
|
||||
@@ -2380,9 +2380,9 @@ dependencies = [
|
||||
{ name = "tornado" },
|
||||
{ name = "traitlets" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ac/d5/730628e03fff2e8a8e8ccdaedde1489ab1309f9a4fa2536248884e30b7c7/jupyterlab-4.5.6.tar.gz", hash = "sha256:642fe2cfe7f0f5922a8a558ba7a0d246c7bc133b708dfe43f7b3a826d163cf42", size = 23970670, upload-time = "2026-03-11T14:17:04.531Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/2b/22/8440ec827762146e7cdecf04335bd348795899d29dc6ae82238707353a2c/jupyterlab-4.5.7.tar.gz", hash = "sha256:55a9822c4754da305f41e113452c68383e214dcf96de760146af89ce5d5117b0", size = 23992763, upload-time = "2026-04-29T16:43:51.328Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e1/1b/dad6fdcc658ed7af26fdf3841e7394072c9549a8b896c381ab49dd11e2d9/jupyterlab-4.5.6-py3-none-any.whl", hash = "sha256:d6b3dac883aa4d9993348e0f8e95b24624f75099aed64eab6a4351a9cdd1e580", size = 12447124, upload-time = "2026-03-11T14:17:00.229Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3d/aa/537b8f7d80e799af19af35fb3ddfc970b951088a13c57dd9387dcfbb7f61/jupyterlab-4.5.7-py3-none-any.whl", hash = "sha256:fba4cb0e2c44a52859669d8c98b45de029d5e515f8407bf8534d2a8fc5f0964d", size = 12450123, upload-time = "2026-04-29T16:43:46.639Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2723,6 +2723,10 @@ dynamixel = [
|
||||
{ name = "dynamixel-sdk" },
|
||||
{ name = "pyserial" },
|
||||
]
|
||||
eo1 = [
|
||||
{ name = "qwen-vl-utils" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
evaluation = [
|
||||
{ name = "av" },
|
||||
]
|
||||
@@ -3029,6 +3033,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["pyserial-dep"], marker = "extra == 'unitree-g1'" },
|
||||
{ name = "lerobot", extras = ["pyzmq-dep"], marker = "extra == 'lekiwi'" },
|
||||
{ name = "lerobot", extras = ["pyzmq-dep"], marker = "extra == 'unitree-g1'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'eo1'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'sarm'" },
|
||||
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'wallx'" },
|
||||
{ name = "lerobot", extras = ["reachy2"], marker = "extra == 'all'" },
|
||||
@@ -3043,6 +3048,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["smolvla"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["test"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["training"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'eo1'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'hilserl'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'libero'" },
|
||||
@@ -3109,10 +3115,10 @@ requires-dist = [
|
||||
{ name = "torchdiffeq", marker = "extra == 'wallx'", specifier = ">=0.2.4,<0.3.0" },
|
||||
{ name = "torchvision", specifier = ">=0.22.0,<0.26.0" },
|
||||
{ name = "tqdm", specifier = ">=4.66.0,<5.0.0" },
|
||||
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = "==5.3.0" },
|
||||
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" },
|
||||
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
|
||||
]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "eo1", "hilserl", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
|
||||
|
||||
[[package]]
|
||||
name = "librt"
|
||||
@@ -3486,11 +3492,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "mistune"
|
||||
version = "3.2.0"
|
||||
version = "3.2.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/9d/55/d01f0c4b45ade6536c51170b9043db8b2ec6ddf4a35c7ea3f5f559ac935b/mistune-3.2.0.tar.gz", hash = "sha256:708487c8a8cdd99c9d90eb3ed4c3ed961246ff78ac82f03418f5183ab70e398a", size = 95467, upload-time = "2025-12-23T11:36:34.994Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ca/84/620cc3f7e3adf6f5067e10f4dbae71295d8f9e16d5d3f9ef97c40f2f592c/mistune-3.2.1.tar.gz", hash = "sha256:7c8e5501d38bac1582e067e46c8343f17d57ea1aaa735823f3aba1fd59c88a28", size = 98003, upload-time = "2026-05-03T14:33:22.312Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9b/f7/4a5e785ec9fbd65146a27b6b70b6cdc161a66f2024e4b04ac06a67f5578b/mistune-3.2.0-py3-none-any.whl", hash = "sha256:febdc629a3c78616b94393c6580551e0e34cc289987ec6c35ed3f4be42d0eee1", size = 53598, upload-time = "2025-12-23T11:36:33.211Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/7f/a946aa4f8752b37102b41e64dca18a1976ac705c3a0d1dfe74d820a02552/mistune-3.2.1-py3-none-any.whl", hash = "sha256:78cdb0ba5e938053ccf63651b352508d2efa9411dc8810bfb05f2dc5140c0048", size = 53749, upload-time = "2026-05-03T14:33:20.551Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3855,7 +3861,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "notebook"
|
||||
version = "7.5.5"
|
||||
version = "7.5.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "jupyter-server" },
|
||||
@@ -3864,9 +3870,9 @@ dependencies = [
|
||||
{ name = "notebook-shim" },
|
||||
{ name = "tornado" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/1f/6d/41052c48d6f6349ca0a7c4d1f6a78464de135e6d18f5829ba2510e62184c/notebook-7.5.5.tar.gz", hash = "sha256:dc0bfab0f2372c8278c457423d3256c34154ac2cc76bf20e9925260c461013c3", size = 14169167, upload-time = "2026-03-11T16:32:51.922Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/2a/c2/cf59bd2e6f2c8b976b52477e3e53bf6f97bc714ed046a51821afb428eaee/notebook-7.5.6.tar.gz", hash = "sha256:621174aade80108f0020b0f00738000b215f75fa3cd90771ad7aa0f24536a4e1", size = 14170814, upload-time = "2026-04-30T11:46:26.613Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/aa/cbd1deb9f07446241e88f8d5fecccd95b249bca0b4e5482214a4d1714c49/notebook-7.5.5-py3-none-any.whl", hash = "sha256:a7c14dbeefa6592e87f72290ca982e0c10f5bbf3786be2a600fda9da2764a2b8", size = 14578929, upload-time = "2026-03-11T16:32:48.021Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/d6/1fd0646b9bbd9efbb0b8ae21b2325fbef515769a5621c03e31d8eb8da587/notebook-7.5.6-py3-none-any.whl", hash = "sha256:4dde3f8fb55fa8fb7946d58c6e869ce9baf46d00fc070664f62604569d0faca0", size = 14581730, upload-time = "2026-04-30T11:46:22.342Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4130,7 +4136,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "onnxruntime"
|
||||
version = "1.25.0"
|
||||
version = "1.25.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "flatbuffers" },
|
||||
@@ -4139,25 +4145,25 @@ dependencies = [
|
||||
{ name = "protobuf" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7a/69/f98c6bda4c34ac382b70c36033a989ceffd1caf5afba47bd2ef26535850f/onnxruntime-1.25.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:8ecd3362de3fb496fb3e2d055a95d5acab611cf759a27609c6d99704c9d8f184", size = 17742518, upload-time = "2026-04-22T17:20:34.444Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5a/c6/19c5bfbc60396791e975652f982bcff9ff4b27947c8e2bf0064ac5d5727b/onnxruntime-1.25.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c99238d20bfa80ac68c7b03c2c936d389189ae40997f78a30d151570d7e18bf", size = 15841110, upload-time = "2026-04-22T17:19:31.284Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a9/1b/d681878f227513917d8620e4ea504af5eb3313fc01f8aea7b19a976c65db/onnxruntime-1.25.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:be93baa694ef8e5831fcb7b542da21f502b122918b5b9612d9f02972e043ee01", size = 17996146, upload-time = "2026-04-22T17:19:53.792Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/55/fe/ec98e416bd75063dea1e493661c7c939e18660ee41d6a63d7221e5657f48/onnxruntime-1.25.0-cp312-cp312-win_amd64.whl", hash = "sha256:9596040c1f7d247bbfab5d4db1e7651c790235e48e460c7d445ec81687d5a182", size = 12872370, upload-time = "2026-04-22T17:20:22.856Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f7/86/9a1ac7c8a8eba7967935d4c109fc956d8f9ba61cba61d9368315bb27d0bc/onnxruntime-1.25.0-cp312-cp312-win_arm64.whl", hash = "sha256:463aed7f5e4a3ca5a476db7e9bba9164fa26921ef34c37e59b28c4c61e55f266", size = 12600072, upload-time = "2026-04-22T17:20:11.523Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c1/5f/3b916a303f43e9c7eed3a705ea69f6867233c161dede30f4df21538c6693/onnxruntime-1.25.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:1b3d76cf770afba76859f270679c9ad0b017b9357eb5892e91926943e05ca82c", size = 17743247, upload-time = "2026-04-22T17:19:45.206Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d5/b3/9e45ba86ed39ab688578f21dd39ed4b575726205596891870a1a8b4d5ca9/onnxruntime-1.25.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cddb565dfd630550a8817b3d5493ffcfa0fec273b545b2816f2fce53384e1151", size = 15841442, upload-time = "2026-04-22T17:19:34.209Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d2/c4/810809e3b411fd66958bdd7285a63acf948988ab4189e1fd860a2f999db3/onnxruntime-1.25.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ade74e651e28b39e6bfd6f576cb9b8a4edfa0916234145154dc891bd55331c22", size = 17993660, upload-time = "2026-04-22T17:19:56.719Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/42/3d/b736cda9c71b3df022ca6bbcb991d14b7723c068dbebe826af9102e79777/onnxruntime-1.25.0-cp313-cp313-win_amd64.whl", hash = "sha256:9196c32c039c37ce8362cbee0aa3a704679be5f2b6fb3e849fea927c98fe1e5b", size = 12871906, upload-time = "2026-04-22T17:20:25.705Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0d/1f/d7bb87cdbb839b356133e9f9e3851fc0c3130dd1c360640c9ce948e3e083/onnxruntime-1.25.0-cp313-cp313-win_arm64.whl", hash = "sha256:b3e52dc2208dec6f61ef118dff04610927e9a18d99e019a828799b23cc9cdea4", size = 12599753, upload-time = "2026-04-22T17:20:14.661Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/3c/edb0d825a65beed40a3de8a51521d49d433aa767f8d00e633cd2602024c0/onnxruntime-1.25.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:de8548d8fe8fd58ca841178051d535d6f378efae14a4b4eb336617d80540fb41", size = 15852628, upload-time = "2026-04-22T17:19:36.886Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/55/51/7a660b4d087f27b273ff725f744880e7664f64a9331bfb1eae91ed2a9f0a/onnxruntime-1.25.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4edec672d09e34b9e83ad09c44454ce97627388f32858b1d59fe01d091ff54b5", size = 17997241, upload-time = "2026-04-22T17:19:59.653Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/78/be/5254acb849f414c8fb2643fe21f2c2ef8089fab18569f24775ccb8ee182d/onnxruntime-1.25.0-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:38f27febd2ff034a600a8bdbea34b1f7c961a2dab6bcb5351e70548fea456161", size = 17744932, upload-time = "2026-04-22T17:19:48.097Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/49/98/c2593aaa392e278a41bec35a00298aa5f22bb382483ad02ca451a556b2a2/onnxruntime-1.25.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6e0ae389ed1647f11c1b501ba1cef1e2c7453002f626136ace214c9c46153ee4", size = 15842603, upload-time = "2026-04-22T17:19:39.879Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/08/b6/07e924b8a47adc9ce2f92a7ef71a6fb709981b1ebd08179f61cbce6ff9b3/onnxruntime-1.25.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ca32d38173c0f58699ca9dc9e867de74d2c2ab7d1c2d969f862ee8633370b77", size = 17994808, upload-time = "2026-04-22T17:20:02.462Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/56/31/e0147d87acfd06992a9bf45ffc070fd3ab49ff9a1f12de9fb403f2fc0b97/onnxruntime-1.25.0-cp314-cp314-win_amd64.whl", hash = "sha256:a2829e29621db7a4bcd457e6d0f3e4f541fb274c7127e7d2e1a5b46c70572672", size = 13183697, upload-time = "2026-04-22T17:20:28.658Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/18/29/b1d5b91d04ae80768ed8e38639ab2fcc92750a67fddc30ad6b700f244113/onnxruntime-1.25.0-cp314-cp314-win_arm64.whl", hash = "sha256:2bed9b35568b3ecf8ab34dc832d37216e47947e86508a0fd6b75e4c19d7ba907", size = 12933438, upload-time = "2026-04-22T17:20:17.223Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/56/f4/cfd47f88da545ea57c1f2a4b5886d455ec64f53b723b1a448fc44ed757e9/onnxruntime-1.25.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:00548a16e8f0d52cb1c67ef50177e5e2be848ccffc6db60010ee37faaccbbb6f", size = 15853591, upload-time = "2026-04-22T17:19:42.325Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/89/de/8b406be6ea4f2c254f9bc850cbe8038064c7768a94cdf7785420b3652ea7/onnxruntime-1.25.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a071a0740388e0ffad081c583761f37837b113bde3d03dc70790ed6cf4f4de0b", size = 17996166, upload-time = "2026-04-22T17:20:05.873Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c0/52/8b2a10e8dedf5d486332bc2b3bca0b1ed8049c0b9e4a5cced95413aadfdd/onnxruntime-1.25.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:66e52f7a30d1f780a34aa84d68a0a04d382d9f5b141884ecbf45b7566b9fbde9", size = 17770987, upload-time = "2026-04-27T22:00:47.985Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3f/87/a424d2867477c42ef8c60172709281120797f7b0f1fd33cc36b24329c825/onnxruntime-1.25.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a5f41779f044d1ff75593df5c10a4d311bc82563687796d5218e2685b8f9da25", size = 15871829, upload-time = "2026-04-27T21:59:39.088Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d4/55/7819e64c515f17c86005447ede8122b974ca851255a94125e2119376f0f8/onnxruntime-1.25.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:905409e9eb2ef87f8226e073f56e71faf731c3e480ebd34952cf953730e4a4ff", size = 18024586, upload-time = "2026-04-27T22:00:05.359Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/89/36/b4f3eb5e95c66389aafd490950b5255e87c9333742cf90516eb50898e1dc/onnxruntime-1.25.1-cp312-cp312-win_amd64.whl", hash = "sha256:d4097b75b77486bb45835a8ed25b9a67976040ec6c258aeabae6aadfbdd1201c", size = 12905112, upload-time = "2026-04-27T22:00:36.478Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/38/fa/e5c43397632a399f542663ed3e3e37763ee203ba845b10b266cd2ede8925/onnxruntime-1.25.1-cp312-cp312-win_arm64.whl", hash = "sha256:b6c7aa5cae606d5c90a392679fac074b60f80025a2e83e1e90fdf882bd2a97f0", size = 12634433, upload-time = "2026-04-27T22:00:25.918Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d2/ee/db3ac55ef770347a926ac0f1317df0ab42c8bc604350833b30c7356bf936/onnxruntime-1.25.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e9d9b3b1694196bc3c5bc66f760a237a5e27d7688aaa2e2c9c0f66abd0486699", size = 17770761, upload-time = "2026-04-27T21:59:54.853Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dc/9a/33225481a94a59906fce44e27ab12fc3bddd2aaecdc6160bd73341ca1aba/onnxruntime-1.25.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:311d29b943e46a55ca72ca1ea48d7815c993122bfc359f68215fddeb9583fff4", size = 15871542, upload-time = "2026-04-27T21:59:41.881Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8b/09/f20aac60f6fcf840543be54d4e9252cfeb7e8c2bb6d22477aaeb180e763e/onnxruntime-1.25.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:98016a038b31160db23208706139fa3b99cd60bc1c5ffdade77aafd6a37a92ad", size = 18036960, upload-time = "2026-04-27T22:00:10.739Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/50/83/47964ac7e2f7e2f9e83c69ec466642c6835466252cc2ef0561eafeb56b66/onnxruntime-1.25.1-cp313-cp313-win_amd64.whl", hash = "sha256:08717d6eee2820807ba60b1b17032af207bd7aaca5b6c4abaee71f83feae877b", size = 12904886, upload-time = "2026-04-27T22:00:39.878Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d4/6c/a6c5aea47dc95fca7728f8a5af67c184ec9e7d4e7882125c7062e4bba8dd/onnxruntime-1.25.1-cp313-cp313-win_arm64.whl", hash = "sha256:84f8963d70e00167bae273ab7e80e9795bfc5eb94f6b23236a99c5c11af00844", size = 12634117, upload-time = "2026-04-27T22:00:29.15Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/8a/3b65e7911eec86c125e3d6f43d690a6f68671500543c0390ecd6eb59b771/onnxruntime-1.25.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:03e800b3a4b48d9f3a2d23aacc4fa95486a3b406b14e51d1a9b8b6981d9adf9c", size = 15882935, upload-time = "2026-04-27T21:59:44.912Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3c/bb/410a760694f8ae7bbfc5fa81ccbeb7da241e6d520ee02a333a439cf462a2/onnxruntime-1.25.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fd83ef5c10cfc051a1cb465db692d57b996a1bc75a2a97b161398e29cdbc47ff", size = 18021727, upload-time = "2026-04-27T22:00:13.846Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/aa/04530bd38e31e26970fa1212346d76cf81705dc16a8ee5e6f4fb24634c11/onnxruntime-1.25.1-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:395eb662c437fa2407f44266e4778b75bff261b17c2a6fef042421f9069f871d", size = 17773721, upload-time = "2026-04-27T21:59:59.24Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ef/7f/ec79ab5cece6a688c944a7fa214a8511d548b9d5142a15d1a3d730b705f1/onnxruntime-1.25.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9ae85395f41b291ae3e61780ec5092640181d369ef6c268aa8141c478b509e69", size = 15875954, upload-time = "2026-04-27T21:59:49.394Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/67/fe/20428215d822099ea2c1e3cf35c295cf1a58f467bf18b6c607597a39c18a/onnxruntime-1.25.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:828e1b12710fbedb6dfab5e7bae6f11563617cddf3c2e7e8d84c64de566a4a3a", size = 18038703, upload-time = "2026-04-27T22:00:16.199Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5a/b1/b15db965e6a68bc47ca7eb584de4e6b3d2d2f484d46cc57f715b596f6528/onnxruntime-1.25.1-cp314-cp314-win_amd64.whl", hash = "sha256:2affc9d2fd9ab013b9c9637464e649a0cca870d57ae18bfef74180eee65c3369", size = 13218513, upload-time = "2026-04-27T22:00:42.506Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5a/f9/25cd2d1b29cdc8140eee4afbb6fb930b69125526632b1d579bc747975306/onnxruntime-1.25.1-cp314-cp314-win_arm64.whl", hash = "sha256:3387d75d1a815b4b2495b4e47a05ef1b3bcb64a817ddc68587e0bfcb9702bcf6", size = 12969835, upload-time = "2026-04-27T22:00:31.504Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8d/0e/6c507d1e65b2421fb44e241cbba577c7276792279485024fb1752b43f5c5/onnxruntime-1.25.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:06280b06604660595037f783c6d24bc70cbe5c6093975f194cd1482e77d450de", size = 15883298, upload-time = "2026-04-27T21:59:51.991Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/df/4e/1c9df57496409dc86b320bd38f29ad7a34b7115e4f35b8fca44a827568a7/onnxruntime-1.25.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7e79fd5ce7db10ebcc24e020e2ed0159476e69e2326b9b7828e5aadcf6184212", size = 18021249, upload-time = "2026-04-27T22:00:18.954Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4272,11 +4278,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "parso"
|
||||
version = "0.8.6"
|
||||
version = "0.8.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/81/76/a1e769043c0c0c9fe391b702539d594731a4362334cdf4dc25d0c09761e7/parso-0.8.6.tar.gz", hash = "sha256:2b9a0332696df97d454fa67b81618fd69c35a7b90327cbe6ba5c92d2c68a7bfd", size = 401621, upload-time = "2026-02-09T15:45:24.425Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/30/4b/90c937815137d43ce71ba043cd3566221e9df6b9c805f24b5d138c9d40a7/parso-0.8.7.tar.gz", hash = "sha256:eaaac4c9fdd5e9e8852dc778d2d7405897ec510f2a298071453e5e3a07914bb1", size = 401824, upload-time = "2026-05-01T23:13:02.138Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b6/61/fae042894f4296ec49e3f193aff5d7c18440da9e48102c3315e1bc4519a7/parso-0.8.6-py2.py3-none-any.whl", hash = "sha256:2c549f800b70a5c4952197248825584cb00f033b29c692671d3bf08bf380baff", size = 106894, upload-time = "2026-02-09T15:45:21.391Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/99/5d/8268b644392ee874ee82a635cd0df1773de230bde356c38de28e298392cc/parso-0.8.7-py2.py3-none-any.whl", hash = "sha256:a8926eb2a1b915486941fdbd31e86a4baf88fe8c210f25f2f35ecec5b574ca1c", size = 107025, upload-time = "2026-05-01T23:12:58.867Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4837,14 +4843,14 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pyngrok"
|
||||
version = "8.1.0"
|
||||
version = "8.1.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pyyaml" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6b/4b/f1372b66985d76177ea492a5f82643b628a958ebdabc99350bb24c643d5b/pyngrok-8.1.0.tar.gz", hash = "sha256:18ea24460f5d74bf5c80feabd55ccf40e9552235ed103f967ec2ef99b57940c6", size = 45000, upload-time = "2026-04-27T02:54:44.771Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/8d/ae/6664934258773db4666e65730c43b4b06730f78d49861a9a04ebcf4742ff/pyngrok-8.1.2.tar.gz", hash = "sha256:3b5383ec7dc4646ac0d046435eb58c6cd1cbc9acad70e6dee012b05dc25b070a", size = 45078, upload-time = "2026-04-29T15:16:53.969Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2f/50/d6f73708d6d358184c1227bbd5cf12dfa78af229fe81553bbd0e9a1ad073/pyngrok-8.1.0-py3-none-any.whl", hash = "sha256:8ba8497a2c7ac6b2f41a66f8102b815da30bfbb63321a70a0a4fa3a51b03e79b", size = 25882, upload-time = "2026-04-27T02:54:43.298Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/07/5c/7733776a6a9704bffee19d203f9be80f25f78a5011a9863971be60fe2763/pyngrok-8.1.2-py3-none-any.whl", hash = "sha256:849e9f55706288b00eb28f4ae8ea16b05e52c609f80ef4a88ca23b385f2f9178", size = 25935, upload-time = "2026-04-29T15:16:52.088Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5121,11 +5127,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pytz"
|
||||
version = "2026.1.post1"
|
||||
version = "2026.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/56/db/b8721d71d945e6a8ac63c0fc900b2067181dbb50805958d4d4661cf7d277/pytz-2026.1.post1.tar.gz", hash = "sha256:3378dde6a0c3d26719182142c56e60c7f9af7e968076f31aae569d72a0358ee1", size = 321088, upload-time = "2026-03-03T07:47:50.683Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ff/46/dd499ec9038423421951e4fad73051febaa13d2df82b4064f87af8b8c0c3/pytz-2026.2.tar.gz", hash = "sha256:0e60b47b29f21574376f218fe21abc009894a2321ea16c6754f3cad6eb7cdd6a", size = 320861, upload-time = "2026-05-04T01:35:29.667Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/10/99/781fe0c827be2742bcc775efefccb3b048a3a9c6ce9aec0cbf4a101677e5/pytz-2026.1.post1-py2.py3-none-any.whl", hash = "sha256:f2fd16142fda348286a75e1a524be810bb05d444e5a081f37f7affc635035f7a", size = 510489, upload-time = "2026-03-03T07:47:49.167Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ec/dd/96da98f892250475bdf2328112d7468abdd4acc7b902b6af23f4ed958ea0/pytz-2026.2-py2.py3-none-any.whl", hash = "sha256:04156e608bee23d3792fd45c94ae47fae1036688e75032eea2e3bf0323d1f126", size = 510141, upload-time = "2026-05-04T01:35:27.408Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5769,15 +5775,15 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "sentry-sdk"
|
||||
version = "2.58.0"
|
||||
version = "2.59.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "certifi" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/26/b3/fb8291170d0e844173164709fc0fa0c221ed75a5da740c8746f2a83b4eb1/sentry_sdk-2.58.0.tar.gz", hash = "sha256:c1144d947352d54e5b7daa63596d9f848adf684989c06c4f5a659f0c85a18f6f", size = 438764, upload-time = "2026-04-13T17:23:26.265Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/65/e0/9bf5e5fc7442b10880f3ec0eff0ef4208b84a099606f343ec4f5445227fb/sentry_sdk-2.59.0.tar.gz", hash = "sha256:cd265808ef8bf3f3edf69b527c0a0b2b6b1322762679e55b8987db2e9584aec1", size = 447331, upload-time = "2026-05-04T12:19:06.538Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/eb/d875669993b762556ae8b2efd86219943b4c0864d22204d622a9aee3052b/sentry_sdk-2.58.0-py2.py3-none-any.whl", hash = "sha256:688d1c704ddecf382ea3326f21a67453d4caa95592d722b7c780a36a9d23109e", size = 460919, upload-time = "2026-04-13T17:23:24.675Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bf/00/b8cc413748fb6383d1582e7cda51314f99743351c462a92dc690d5b5853b/sentry_sdk-2.59.0-py2.py3-none-any.whl", hash = "sha256:abcf65ee9a9d9cdebf9ad369782408ecca9c1c792686ef06ba34f5ab233527fe", size = 468432, upload-time = "2026-05-04T12:19:04.741Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6012,14 +6018,14 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "tifffile"
|
||||
version = "2026.4.11"
|
||||
version = "2026.5.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "numpy" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d7/4a/e687f5957fead200faad58dbf9c9431a2bbb118040e96f5fb8a55f7ebc50/tifffile-2026.4.11.tar.gz", hash = "sha256:17758ff0c0d4db385792a083ad3ca51fcb0f4d942642f4d8f8bc1287fdcf17bc", size = 394956, upload-time = "2026-04-12T01:57:28.793Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6c/3e/695c7ab56be57814e369c1f38bc3f64b9dea0a83e867d00c0c9d613a9929/tifffile-2026.5.2.tar.gz", hash = "sha256:21b10227ede8493814a34676774797f721f487e36cb0530e7c3bd882caa87f5a", size = 429140, upload-time = "2026-05-02T20:19:31.497Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3f/9f/74f110b4271ded519c7add4341cbabc824de26817ff1c345b3109df9e99c/tifffile-2026.4.11-py3-none-any.whl", hash = "sha256:9b94ffeddb39e97601af646345e8808f885773de01b299e480ed6d3a41509ec9", size = 248227, upload-time = "2026-04-12T01:57:26.969Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/af/ce4df3ca29122d219c45d3e86e5ff9a9df03b8cf31afd76817b662c803a3/tifffile-2026.5.2-py3-none-any.whl", hash = "sha256:5129b53b826e768a5b1af26b765eeea75c2d0a227d2d12849617e0737588e105", size = 266420, upload-time = "2026-05-02T20:19:29.814Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6243,7 +6249,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "transformers"
|
||||
version = "5.3.0"
|
||||
version = "5.5.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "huggingface-hub" },
|
||||
@@ -6256,9 +6262,9 @@ dependencies = [
|
||||
{ name = "tqdm" },
|
||||
{ name = "typer" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/fc/1a/70e830d53ecc96ce69cfa8de38f163712d2b43ac52fbd743f39f56025c31/transformers-5.3.0.tar.gz", hash = "sha256:009555b364029da9e2946d41f1c5de9f15e6b1df46b189b7293f33a161b9c557", size = 8830831, upload-time = "2026-03-04T17:41:46.119Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a5/1e/1e244ab2ab50a863e6b52cc55761910567fa532b69a6740f6e99c5fdbd98/transformers-5.5.4.tar.gz", hash = "sha256:2e67cadba81fc7608cc07c4dd54f524820bc3d95b1cabd0ef3db7733c4f8b82e", size = 8227649, upload-time = "2026-04-13T16:55:55.181Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/88/ae8320064e32679a5429a2c9ebbc05c2bf32cefb6e076f9b07f6d685a9b4/transformers-5.3.0-py3-none-any.whl", hash = "sha256:50ac8c89c3c7033444fb3f9f53138096b997ebb70d4b5e50a2e810bf12d3d29a", size = 10661827, upload-time = "2026-03-04T17:41:42.722Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/29/fb/162a66789c65e5afa3b051309240c26bf37fbc8fea285b4546ae747995a2/transformers-5.5.4-py3-none-any.whl", hash = "sha256:0bd6281b82966fe5a7a16f553ea517a9db1dee6284d7cb224dfd88fc0dd1c167", size = 10236696, upload-time = "2026-04-13T16:55:51.497Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6287,7 +6293,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "typer"
|
||||
version = "0.25.0"
|
||||
version = "0.25.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "annotated-doc" },
|
||||
@@ -6295,9 +6301,9 @@ dependencies = [
|
||||
{ name = "rich" },
|
||||
{ name = "shellingham" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7b/27/ede8cec7596e0041ba7e7b80b47d132562f56ff454313a16f6084e555c9f/typer-0.25.0.tar.gz", hash = "sha256:123eaf9f19bb40fd268310e12a542c0c6b4fab9c98d9d23342a01ff95e3ce930", size = 120150, upload-time = "2026-04-26T08:46:14.767Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e4/51/9aed62104cea109b820bbd6c14245af756112017d309da813ef107d42e7e/typer-0.25.1.tar.gz", hash = "sha256:9616eb8853a09ffeabab1698952f33c6f29ffdbceb4eaeecf571880e8d7664cc", size = 122276, upload-time = "2026-04-30T19:32:16.964Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9a/72/193d4e586ec5a4db834a36bbeb47641a62f951f114ffd0fe5b1b46e8d56f/typer-0.25.0-py3-none-any.whl", hash = "sha256:ac01b48823d3db9a83c9e164338057eadbb1c9957a2a6b4eeb486669c560b5dc", size = 55993, upload-time = "2026-04-26T08:46:15.889Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3f/f9/2b3ff4e56e5fa7debfaf9eb135d0da96f3e9a1d5b27222223c7296336e5f/typer-0.25.1-py3-none-any.whl", hash = "sha256:75caa44ed46a03fb2dab8808753ffacdbfea88495e74c85a28c5eefcf5f39c89", size = 58409, upload-time = "2026-04-30T19:32:18.271Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6428,7 +6434,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "virtualenv"
|
||||
version = "21.2.4"
|
||||
version = "21.3.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "distlib" },
|
||||
@@ -6436,9 +6442,9 @@ dependencies = [
|
||||
{ name = "platformdirs" },
|
||||
{ name = "python-discovery" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0c/98/3a7e644e19cb26133488caff231be390579860bbbb3da35913c49a1d0a46/virtualenv-21.2.4.tar.gz", hash = "sha256:b294ef68192638004d72524ce7ef303e9d0cf5a44c95ce2e54a7500a6381cada", size = 5850742, upload-time = "2026-04-14T22:15:31.438Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ec/0d/915c02c94d207b85580eb09bffab54438a709e7288524094fe781da526c2/virtualenv-21.3.1.tar.gz", hash = "sha256:c2305bc1fddeec40699b8370d13f8d431b0701f00ce895061ce493aeded4426b", size = 7613791, upload-time = "2026-05-05T01:34:31.402Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/27/8d/edd0bd910ff803c308ee9a6b7778621af0d10252219ad9f19ef4d4982a61/virtualenv-21.2.4-py3-none-any.whl", hash = "sha256:29d21e941795206138d0f22f4e45ff7050e5da6c6472299fb7103318763861ac", size = 5831232, upload-time = "2026-04-14T22:15:29.342Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b1/4f/f71e641e504111a5a74e3a20bc52d01bd86788b22699dd3fee1c63253cf6/virtualenv-21.3.1-py3-none-any.whl", hash = "sha256:d1a71cf58f2f9228fff23a1f6ec15d39785c6b32e03658d104974247145edd35", size = 7594539, upload-time = "2026-05-05T01:34:28.98Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6542,11 +6548,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "wcwidth"
|
||||
version = "0.6.0"
|
||||
version = "0.7.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/35/a2/8e3becb46433538a38726c948d3399905a4c7cabd0df578ede5dc51f0ec2/wcwidth-0.6.0.tar.gz", hash = "sha256:cdc4e4262d6ef9a1a57e018384cbeb1208d8abbc64176027e2c2455c81313159", size = 159684, upload-time = "2026-02-06T19:19:40.919Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/2c/ee/afaf0f85a9a18fe47a67f1e4422ed6cf1fe642f0ae0a2f81166231303c52/wcwidth-0.7.0.tar.gz", hash = "sha256:90e3a7ea092341c44b99562e75d09e4d5160fe7a3974c6fb842a101a95e7eed0", size = 182132, upload-time = "2026-05-02T16:04:12.653Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/68/5a/199c59e0a824a3db2b89c5d2dade7ab5f9624dbf6448dc291b46d5ec94d3/wcwidth-0.6.0-py3-none-any.whl", hash = "sha256:1a3a1e510b553315f8e146c54764f4fb6264ffad731b3d78088cdb1478ffbdad", size = 94189, upload-time = "2026-02-06T19:19:39.646Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/41/52/e465037f5375f43533d1a80b6923955201596a99142ed524d77b571a1418/wcwidth-0.7.0-py3-none-any.whl", hash = "sha256:5d69154c429a82910e241c738cd0e2976fac8a2dd47a1a805f4afed1c0f136f2", size = 110825, upload-time = "2026-05-02T16:04:11.033Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Reference in New Issue
Block a user