diff --git a/.github/workflows/fast_tests.yml b/.github/workflows/fast_tests.yml
index d78bdd21b..b6680db73 100644
--- a/.github/workflows/fast_tests.yml
+++ b/.github/workflows/fast_tests.yml
@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# This workflow handles fast testing.
+# This workflow validates each optional-dependency tier in isolation.
+# Each tier installs a different extra and runs the full test suite.
+# Tests that require an extra not installed in the current tier are
+# skipped automatically via pytest.importorskip guards.
name: Fast Tests
on:
@@ -54,8 +57,9 @@ concurrency:
cancel-in-progress: true
jobs:
- # This job runs pytests with the default dependencies.
- # It runs everytime we commit to a PR or push to main
+ # This job runs pytests in isolated dependency tiers.
+ # Each tier installs a different extra and runs the full suite;
+ # tests gated behind other extras skip automatically.
fast-pytest-tests:
name: Fast Pytest Tests
runs-on: ubuntu-latest
@@ -89,8 +93,9 @@ jobs:
version: ${{ env.UV_VERSION }}
python-version: ${{ env.PYTHON_VERSION }}
- - name: Install lerobot with test extras
- run: uv sync --locked --extra "test"
+ # ── Tier 1: Base ──────────────────────────────────────
+ - name: "Tier 1 — Install: base"
+ run: uv sync --locked --extra test
- name: Login to Hugging Face
if: env.HF_USER_TOKEN != ''
@@ -98,5 +103,26 @@ jobs:
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
uv run hf auth whoami
- - name: Run pytest
+ - name: "Tier 1 — Test: base"
+ run: uv run pytest tests -vv --maxfail=10
+
+ # ── Tier 2: Dataset ──────────────────────────────────
+ - name: "Tier 2 — Install: dataset"
+ run: uv sync --locked --extra test --extra dataset
+
+ - name: "Tier 2 — Test: dataset"
+ run: uv run pytest tests -vv --maxfail=10
+
+ # ── Tier 3: Hardware ─────────────────────────────────
+ - name: "Tier 3 — Install: hardware"
+ run: uv sync --locked --extra test --extra hardware
+
+ - name: "Tier 3 — Test: hardware"
+ run: uv run pytest tests -vv --maxfail=10
+
+ # ── Tier 4: Viz ──────────────────────────────────────
+ - name: "Tier 4 — Install: viz"
+ run: uv sync --locked --extra test --extra viz
+
+ - name: "Tier 4 — Test: viz"
run: uv run pytest tests -vv --maxfail=10
diff --git a/docs/source/adding_benchmarks.mdx b/docs/source/adding_benchmarks.mdx
index 3a024f026..6e9d23bdf 100644
--- a/docs/source/adding_benchmarks.mdx
+++ b/docs/source/adding_benchmarks.mdx
@@ -216,7 +216,7 @@ class MyBenchmarkEnvConfig(EnvConfig):
def get_env_processors(self):
"""Override if your benchmark needs observation/action transforms."""
- from lerobot.processor.pipeline import PolicyProcessorPipeline
+ from lerobot.processor import PolicyProcessorPipeline
from lerobot.processor.env_processor import MyBenchmarkProcessorStep
return (
PolicyProcessorPipeline(steps=[MyBenchmarkProcessorStep()]),
diff --git a/docs/source/async.mdx b/docs/source/async.mdx
index a46408a0d..7b1efae97 100644
--- a/docs/source/async.mdx
+++ b/docs/source/async.mdx
@@ -170,7 +170,7 @@ python -m lerobot.async_inference.robot_client \
```python
import threading
from lerobot.robots.so_follower import SO100FollowerConfig
-from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
+from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.async_inference.configs import RobotClientConfig
from lerobot.async_inference.robot_client import RobotClient
from lerobot.async_inference.helpers import visualize_action_queue_size
diff --git a/docs/source/backwardcomp.mdx b/docs/source/backwardcomp.mdx
index 3366c8ab9..a83ee2e2e 100644
--- a/docs/source/backwardcomp.mdx
+++ b/docs/source/backwardcomp.mdx
@@ -41,7 +41,7 @@ The script:
```python
# New usage pattern (after migration)
-from lerobot.policies.factory import make_policy, make_pre_post_processors
+from lerobot.policies import make_policy, make_pre_post_processors
# Load model and processors separately
policy = make_policy(config, ds_meta=dataset.meta)
diff --git a/docs/source/bring_your_own_policies.mdx b/docs/source/bring_your_own_policies.mdx
index 38c32aa71..57ecc2fb2 100644
--- a/docs/source/bring_your_own_policies.mdx
+++ b/docs/source/bring_your_own_policies.mdx
@@ -47,9 +47,9 @@ Here is a template to get you started, customize the parameters and methods as n
```python
# configuration_my_custom_policy.py
from dataclasses import dataclass, field
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.optim.optimizers import AdamWConfig
-from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
+from lerobot.configs import PreTrainedConfig
+from lerobot.optim import AdamWConfig
+from lerobot.optim import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("my_custom_policy")
@dataclass
@@ -120,7 +120,7 @@ import torch
import torch.nn as nn
from typing import Any
-from lerobot.policies.pretrained import PreTrainedPolicy
+from lerobot.policies import PreTrainedPolicy
from lerobot.utils.constants import ACTION
from .configuration_my_custom_policy import MyCustomPolicyConfig
diff --git a/docs/source/cameras.mdx b/docs/source/cameras.mdx
index 8af0f5ae5..2dc2859dd 100644
--- a/docs/source/cameras.mdx
+++ b/docs/source/cameras.mdx
@@ -79,9 +79,8 @@ The following examples show how to use the camera API to configure and capture f
```python
-from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
-from lerobot.cameras.opencv.camera_opencv import OpenCVCamera
-from lerobot.cameras.configs import ColorMode, Cv2Rotation
+from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig
+from lerobot.cameras import ColorMode, Cv2Rotation
# Construct an `OpenCVCameraConfig` with your desired FPS, resolution, color mode, and rotation.
config = OpenCVCameraConfig(
@@ -126,9 +125,8 @@ with OpenCVCamera(config) as camera:
```python
-from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig
-from lerobot.cameras.realsense.camera_realsense import RealSenseCamera
-from lerobot.cameras.configs import ColorMode, Cv2Rotation
+from lerobot.cameras.realsense import RealSenseCamera, RealSenseCameraConfig
+from lerobot.cameras import ColorMode, Cv2Rotation
# Create a `RealSenseCameraConfig` specifying your camera’s serial number and enabling depth.
config = RealSenseCameraConfig(
diff --git a/docs/source/dataset_subtask.mdx b/docs/source/dataset_subtask.mdx
index beb5d80bd..6264aca22 100644
--- a/docs/source/dataset_subtask.mdx
+++ b/docs/source/dataset_subtask.mdx
@@ -95,7 +95,7 @@ After completing your annotation:
When you load a dataset with subtask annotations, the subtask information is automatically available:
```python
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets import LeRobotDataset
# Load a dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
@@ -133,11 +133,10 @@ if has_subtasks:
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
```python
-from lerobot.processor.tokenizer_processor import TokenizerProcessor
-from lerobot.processor.pipeline import ProcessorPipeline
+from lerobot.processor import TokenizerProcessorStep
-# Create a tokenizer processor
-tokenizer_processor = TokenizerProcessor(
+# Create a tokenizer processor step
+tokenizer_processor = TokenizerProcessorStep(
tokenizer_name_or_path="google/paligemma-3b-pt-224",
padding="max_length",
max_length=64,
@@ -158,7 +157,7 @@ When subtasks are available in the batch, the tokenizer processor adds:
```python
import torch
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets import LeRobotDataset
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
@@ -182,7 +181,7 @@ for batch in dataloader:
Try loading a dataset with subtask annotations:
```python
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets import LeRobotDataset
# Example dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
diff --git a/docs/source/earthrover_mini_plus.mdx b/docs/source/earthrover_mini_plus.mdx
index 884e84d8c..a87bd325b 100644
--- a/docs/source/earthrover_mini_plus.mdx
+++ b/docs/source/earthrover_mini_plus.mdx
@@ -66,10 +66,10 @@ The SDK gives you:
Follow our [Installation Guide](./installation) to install LeRobot.
-In addition to the base installation, install the EarthRover Mini dependencies:
+In addition to the base installation, install the EarthRover Mini with hardware dependencies:
```bash
-pip install -e .
+pip install -e ".[hardware]"
```
## How It Works
diff --git a/docs/source/env_processor.mdx b/docs/source/env_processor.mdx
index 290af3b34..8bfafdfb9 100644
--- a/docs/source/env_processor.mdx
+++ b/docs/source/env_processor.mdx
@@ -173,8 +173,8 @@ observation = {
The `make_env_pre_post_processors` function follows the same pattern as `make_pre_post_processors` for policies:
```python
-from lerobot.envs.factory import make_env_pre_post_processors
-from lerobot.envs.configs import LiberoEnv, PushtEnv
+from lerobot.envs import make_env_pre_post_processors, PushtEnv
+from lerobot.envs.configs import LiberoEnv
# For LIBERO: Returns LiberoProcessorStep in preprocessor
libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
@@ -257,7 +257,7 @@ def eval_main(cfg: EvalPipelineConfig):
The `LiberoProcessorStep` demonstrates a real-world environment processor:
```python
-from lerobot.processor.pipeline import ObservationProcessorStep
+from lerobot.processor import ObservationProcessorStep
@dataclass
@ProcessorStepRegistry.register(name="libero_processor")
diff --git a/docs/source/envhub.mdx b/docs/source/envhub.mdx
index 36c08a8b3..47f5567a8 100644
--- a/docs/source/envhub.mdx
+++ b/docs/source/envhub.mdx
@@ -34,7 +34,7 @@ Finally, your environment must implement the standard `gym.vector.VectorEnv` int
Loading an environment from the Hub is as simple as:
```python
-from lerobot.envs.factory import make_env
+from lerobot.envs import make_env
# Load a hub environment (requires explicit consent to run remote code)
env = make_env("lerobot/cartpole-env", trust_remote_code=True)
@@ -191,7 +191,7 @@ api.upload_folder(
### Basic Usage
```python
-from lerobot.envs.factory import make_env
+from lerobot.envs import make_env
# Load from the hub
envs_dict = make_env(
@@ -314,7 +314,7 @@ env = make_env("trusted-org/verified-env@a1b2c3d4", trust_remote_code=True)
Here's a complete example using the reference CartPole environment:
```python
-from lerobot.envs.factory import make_env
+from lerobot.envs import make_env
import numpy as np
# Load the environment
diff --git a/docs/source/envhub_isaaclab_arena.mdx b/docs/source/envhub_isaaclab_arena.mdx
index 828d51bad..b934240d6 100644
--- a/docs/source/envhub_isaaclab_arena.mdx
+++ b/docs/source/envhub_isaaclab_arena.mdx
@@ -58,10 +58,10 @@ pip install -e .
cd ..
-# 5. Install LeRobot
+# 5. Install LeRobot (evaluation extra for env/policy evaluation)
git clone https://github.com/huggingface/lerobot.git
cd lerobot
-pip install -e .
+pip install -e ".[evaluation]"
cd ..
@@ -262,7 +262,7 @@ def main(cfg: EvalPipelineConfig):
"""Run random action rollout for IsaacLab Arena environment."""
logging.info(pformat(asdict(cfg)))
- from lerobot.envs.factory import make_env
+ from lerobot.envs import make_env
env_dict = make_env(
cfg.env,
diff --git a/docs/source/envhub_leisaac.mdx b/docs/source/envhub_leisaac.mdx
index 2537700a5..91bb6a871 100644
--- a/docs/source/envhub_leisaac.mdx
+++ b/docs/source/envhub_leisaac.mdx
@@ -74,7 +74,7 @@ EnvHub exposes every LeIsaac-supported task in a uniform interface. The examples
# envhub_random_action.py
import torch
-from lerobot.envs.factory import make_env
+from lerobot.envs import make_env
# Load from the hub
envs_dict = make_env("LightwheelAI/leisaac_env:envs/so101_pick_orange.py", n_envs=1, trust_remote_code=True)
@@ -142,7 +142,7 @@ from lerobot.teleoperators import ( # noqa: F401
)
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import init_logging
-from lerobot.envs.factory import make_env
+from lerobot.envs import make_env
@dataclass
@@ -282,7 +282,7 @@ Note: when working with `bi_so101_fold_cloth`, call `initialize()` immediately a
```python
import torch
-from lerobot.envs.factory import make_env
+from lerobot.envs import make_env
# Load from the hub
envs_dict = make_env("LightwheelAI/leisaac_env:envs/bi_so101_fold_cloth.py", n_envs=1, trust_remote_code=True)
diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx
index 8e50a2aec..d03e35d8d 100644
--- a/docs/source/il_robots.mdx
+++ b/docs/source/il_robots.mdx
@@ -58,8 +58,8 @@ lerobot-teleoperate \
```python
-from lerobot.teleoperators.so_leader import SO101LeaderConfig, SO101Leader
-from lerobot.robots.so_follower import SO101FollowerConfig, SO101Follower
+from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
+from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem58760431541",
@@ -116,9 +116,9 @@ lerobot-teleoperate \
```python
-from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
-from lerobot.teleoperators.koch_leader import KochLeaderConfig, KochLeader
-from lerobot.robots.koch_follower import KochFollowerConfig, KochFollower
+from lerobot.cameras.opencv import OpenCVCameraConfig
+from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
+from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
camera_config = {
"front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30)
@@ -195,13 +195,12 @@ lerobot-record \
```python
-from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.datasets.utils import hw_to_dataset_features
+from lerobot.cameras.opencv import OpenCVCameraConfig
+from lerobot.datasets import LeRobotDataset
+from lerobot.utils.feature_utils import hw_to_dataset_features
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
-from lerobot.teleoperators.so_leader.config_so100_leader import SO100LeaderConfig
-from lerobot.teleoperators.so_leader.so100_leader import SO100Leader
-from lerobot.utils.control_utils import init_keyboard_listener
+from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
+from lerobot.common.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
from lerobot.scripts.lerobot_record import record_loop
@@ -410,9 +409,8 @@ lerobot-replay \
```python
import time
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.robots.so_follower.config_so100_follower import SO100FollowerConfig
-from lerobot.robots.so_follower.so100_follower import SO100Follower
+from lerobot.datasets import LeRobotDataset
+from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
@@ -532,15 +530,14 @@ lerobot-record \
```python
-from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.datasets.utils import hw_to_dataset_features
-from lerobot.policies.act.modeling_act import ACTPolicy
-from lerobot.policies.factory import make_pre_post_processors
-from lerobot.robots.so_follower.config_so100_follower import SO100FollowerConfig
-from lerobot.robots.so_follower.so100_follower import SO100Follower
+from lerobot.cameras.opencv import OpenCVCameraConfig
+from lerobot.datasets import LeRobotDataset
+from lerobot.utils.feature_utils import hw_to_dataset_features
+from lerobot.policies.act import ACTPolicy
+from lerobot.policies import make_pre_post_processors
+from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.scripts.lerobot_record import record_loop
-from lerobot.utils.control_utils import init_keyboard_listener
+from lerobot.common.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx
index a988523b5..1d772fc97 100644
--- a/docs/source/installation.mdx
+++ b/docs/source/installation.mdx
@@ -116,6 +116,8 @@ brew install ffmpeg
## Step 3: Install LeRobot 🤗
+The base `lerobot` install is intentionally **lightweight** — it includes only core ML dependencies (PyTorch, torchvision, numpy, opencv, einops, draccus, huggingface-hub, gymnasium, safetensors). Heavier dependencies are gated behind optional extras so you only install what you need.
+
### From Source
First, clone the repository and navigate into the directory:
@@ -131,12 +133,16 @@ Then, install the library in editable mode. This is useful if you plan to contri
```bash
-pip install -e .
+pip install -e ".[core_scripts]" # For robot workflows (recording, replaying, calibrate)
+pip install -e ".[training]" # For training policies
+pip install -e ".[all]" # Everything (all policies, envs, hardware, dev tools)
```
```bash
-uv pip install -e .
+uv pip install -e ".[core_scripts]" # For robot workflows (recording, replaying, calibrate)
+uv pip install -e ".[training]" # For training policies
+uv pip install -e ".[all]" # Everything (all policies, envs, hardware, dev tools)
```
@@ -162,26 +168,48 @@ uv pip install lerobot
-_This installs only the default dependencies._
+_This installs only the core ML dependencies. You will need to add extras for most workflows._
-**Extra Features:**
-To install additional functionality, use one of the following (If you are using `uv`, replace `pip install` with `uv pip install` in the commands below.):
+**Feature Extras:**
+LeRobot provides **feature-scoped extras** that map to common workflows. If you are using `uv`, replace `pip install` with `uv pip install` in the commands below.
+
+| Extra | What it adds | Typical use case |
+| ---------- | ------------------------------------------- | ----------------------------------- |
+| `dataset` | `datasets`, `av`, `torchcodec`, `jsonlines` | Loading & creating datasets |
+| `training` | `dataset` + `accelerate`, `wandb` | Training policies |
+| `hardware` | `pynput`, `pyserial`, `deepdiff` | Connecting to real robots |
+| `viz` | `rerun-sdk` | Visualization during recording/eval |
+
+**Composite Extras** combine feature extras for common CLI scripts:
+
+| Extra | Includes | Typical use case |
+| -------------- | ------------------------------ | ------------------------------------------------------- |
+| `core_scripts` | `dataset` + `hardware` + `viz` | `lerobot-record`, `lerobot-replay`, `lerobot-calibrate` |
+| `evaluation` | `av` | `lerobot-eval` (add policy + env extras as needed) |
+| `dataset_viz` | `dataset` + `viz` | `lerobot-dataset-viz`, `lerobot-imgtransform-viz` |
```bash
-pip install 'lerobot[all]' # All available features
-pip install 'lerobot[aloha,pusht]' # Specific features (Aloha & Pusht)
-pip install 'lerobot[feetech]' # Feetech motor support
+pip install 'lerobot[core_scripts]' # Record, replay, calibrate
+pip install 'lerobot[training]' # Train policies
+pip install 'lerobot[core_scripts,training]' # Record + train
+pip install 'lerobot[all]' # Everything
```
-_Replace `[...]` with your desired features._
+**Policy, environment, and hardware extras** are still available for specific dependencies:
-**Available Tags:**
-For a full list of optional dependencies, see:
-https://pypi.org/project/lerobot/
+```bash
+pip install 'lerobot[pi]' # Pi0/Pi0.5/Pi0-FAST policy deps
+pip install 'lerobot[smolvla]' # SmolVLA policy deps
+pip install 'lerobot[diffusion]' # Diffusion policy deps (diffusers)
+pip install 'lerobot[aloha,pusht]' # Simulation environments
+pip install 'lerobot[feetech]' # Feetech motor support
+```
+
+_Multiple extras can be combined (e.g., `.[core_scripts,pi,pusht]`). For a full list of available extras, refer to `pyproject.toml`._
### Troubleshooting
-If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
+If you encounter build errors, you may need to install additional system dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
To install these for Linux run:
```bash
@@ -196,8 +224,8 @@ LeRobot provides optional extras for specific functionalities. Multiple extras c
### Simulations
-Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht))
-Example:
+Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht)).
+These automatically include the `dataset` extra.
```bash
pip install -e ".[aloha]" # or "[pusht]" for example
@@ -213,7 +241,7 @@ pip install -e ".[feetech]" # or "[dynamixel]" for example
### Experiment Tracking
-To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
+Weights and Biases is included in the `training` extra. To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with:
```bash
wandb login
diff --git a/docs/source/introduction_processors.mdx b/docs/source/introduction_processors.mdx
index 6f3768615..4395e889b 100644
--- a/docs/source/introduction_processors.mdx
+++ b/docs/source/introduction_processors.mdx
@@ -19,10 +19,10 @@ This means that your favorite policy can be used like this:
```python
import torch
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.policies.factory import make_pre_post_processors
+from lerobot.datasets import LeRobotDataset
+from lerobot.policies import make_pre_post_processors
from lerobot.policies.your_policy import YourPolicy
-from lerobot.processor.pipeline import RobotProcessorPipeline, PolicyProcessorPipeline
+from lerobot.processor import RobotProcessorPipeline, PolicyProcessorPipeline
dataset = LeRobotDataset("hf_user/dataset", episodes=[0])
sample = dataset[10]
@@ -260,7 +260,7 @@ Since processor pipelines can add new features (like velocity fields), change te
These functions work together by starting with robot hardware specifications (`create_initial_features()`) then simulating the entire pipeline transformation (`aggregate_pipeline_dataset_features()`) to compute the final feature dictionary that gets passed to `LeRobotDataset.create()`, ensuring perfect alignment between what processors output and what datasets expect to store.
```python
-from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
+from lerobot.datasets import aggregate_pipeline_dataset_features
# Start with robot's raw features
initial_features = create_initial_features(
diff --git a/docs/source/lerobot-dataset-v3.mdx b/docs/source/lerobot-dataset-v3.mdx
index 235a355bd..8ab4a5d40 100644
--- a/docs/source/lerobot-dataset-v3.mdx
+++ b/docs/source/lerobot-dataset-v3.mdx
@@ -89,7 +89,7 @@ A core v3 principle is **decoupling storage from the user API**: data is stored
```python
import torch
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets import LeRobotDataset
repo_id = "yaak-ai/L2D-v3"
@@ -135,7 +135,7 @@ for batch in data_loader:
Use `StreamingLeRobotDataset` to iterate directly from the Hub without local copies. This allows to stream large datasets without the need to downloading them onto disk or loading them onto memory, and is a key feature of the new dataset format.
```python
-from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
+from lerobot.datasets import StreamingLeRobotDataset
repo_id = "yaak-ai/L2D-v3"
dataset = StreamingLeRobotDataset(repo_id) # streams directly from the Hub
@@ -167,8 +167,8 @@ Currently, transforms are applied during **training time only**, not during reco
Use the `image_transforms` parameter when loading a dataset for training:
```python
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig, ImageTransformConfig
+from lerobot.datasets import LeRobotDataset
+from lerobot.transforms import ImageTransforms, ImageTransformsConfig, ImageTransformConfig
# Option 1: Use default transform configuration (disabled by default)
transforms_config = ImageTransformsConfig(
@@ -290,7 +290,7 @@ python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id= list[float]:
diff --git a/examples/tutorial/act/act_using_example.py b/examples/tutorial/act/act_using_example.py
index 15254d8eb..6a8f73287 100644
--- a/examples/tutorial/act/act_using_example.py
+++ b/examples/tutorial/act/act_using_example.py
@@ -1,9 +1,9 @@
import torch
-from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
-from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
-from lerobot.policies.act.modeling_act import ACTPolicy
-from lerobot.policies.factory import make_pre_post_processors
+from lerobot.cameras.opencv import OpenCVCameraConfig
+from lerobot.datasets import LeRobotDatasetMetadata
+from lerobot.policies import make_pre_post_processors
+from lerobot.policies.act import ACTPolicy
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
diff --git a/examples/tutorial/async-inf/robot_client.py b/examples/tutorial/async-inf/robot_client.py
index db6ead3fe..ac2331f38 100644
--- a/examples/tutorial/async-inf/robot_client.py
+++ b/examples/tutorial/async-inf/robot_client.py
@@ -3,7 +3,7 @@ import threading
from lerobot.async_inference.configs import RobotClientConfig
from lerobot.async_inference.helpers import visualize_action_queue_size
from lerobot.async_inference.robot_client import RobotClient
-from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
+from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.robots.so_follower import SO100FollowerConfig
diff --git a/examples/tutorial/diffusion/diffusion_training_example.py b/examples/tutorial/diffusion/diffusion_training_example.py
index dc6ca68a3..5cca15923 100644
--- a/examples/tutorial/diffusion/diffusion_training_example.py
+++ b/examples/tutorial/diffusion/diffusion_training_example.py
@@ -4,13 +4,11 @@ from pathlib import Path
import torch
-from lerobot.configs.types import FeatureType
-from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
-from lerobot.datasets.feature_utils import dataset_to_policy_features
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
-from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
-from lerobot.policies.factory import make_pre_post_processors
+from lerobot.configs import FeatureType
+from lerobot.datasets import LeRobotDataset, LeRobotDatasetMetadata
+from lerobot.policies import make_pre_post_processors
+from lerobot.policies.diffusion import DiffusionConfig, DiffusionPolicy
+from lerobot.utils.feature_utils import dataset_to_policy_features
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
diff --git a/examples/tutorial/diffusion/diffusion_using_example.py b/examples/tutorial/diffusion/diffusion_using_example.py
index 9b31cf359..8f9150ad6 100644
--- a/examples/tutorial/diffusion/diffusion_using_example.py
+++ b/examples/tutorial/diffusion/diffusion_using_example.py
@@ -1,9 +1,9 @@
import torch
-from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
-from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
-from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
-from lerobot.policies.factory import make_pre_post_processors
+from lerobot.cameras.opencv import OpenCVCameraConfig
+from lerobot.datasets import LeRobotDatasetMetadata
+from lerobot.policies import make_pre_post_processors
+from lerobot.policies.diffusion import DiffusionPolicy
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
diff --git a/examples/tutorial/pi0/using_pi0_example.py b/examples/tutorial/pi0/using_pi0_example.py
index d8cf9dbff..66f6309c2 100644
--- a/examples/tutorial/pi0/using_pi0_example.py
+++ b/examples/tutorial/pi0/using_pi0_example.py
@@ -1,11 +1,11 @@
import torch
-from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
-from lerobot.datasets.feature_utils import hw_to_dataset_features
-from lerobot.policies.factory import make_pre_post_processors
-from lerobot.policies.pi0.modeling_pi0 import PI0Policy
+from lerobot.cameras.opencv import OpenCVCameraConfig
+from lerobot.policies import make_pre_post_processors
+from lerobot.policies.pi0 import PI0Policy
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
+from lerobot.utils.feature_utils import hw_to_dataset_features
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
diff --git a/examples/tutorial/rl/hilserl_example.py b/examples/tutorial/rl/hilserl_example.py
index d367a01ce..8a08d6d56 100644
--- a/examples/tutorial/rl/hilserl_example.py
+++ b/examples/tutorial/rl/hilserl_example.py
@@ -6,17 +6,17 @@ from queue import Empty, Full
import torch
import torch.optim as optim
-from lerobot.datasets.feature_utils import hw_to_dataset_features
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets import LeRobotDataset
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
-from lerobot.policies.sac.configuration_sac import SACConfig
+from lerobot.policies import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.gym_manipulator import make_robot_env
from lerobot.robots.so_follower import SO100FollowerConfig
+from lerobot.teleoperators import TeleopEvents
from lerobot.teleoperators.so_leader import SO100LeaderConfig
-from lerobot.teleoperators.utils import TeleopEvents
+from lerobot.utils.feature_utils import hw_to_dataset_features
LOG_EVERY = 10
SEND_EVERY = 10
diff --git a/examples/tutorial/rl/reward_classifier_example.py b/examples/tutorial/rl/reward_classifier_example.py
index 4af6b899c..b386bf4db 100644
--- a/examples/tutorial/rl/reward_classifier_example.py
+++ b/examples/tutorial/rl/reward_classifier_example.py
@@ -1,8 +1,7 @@
import torch
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.policies.factory import make_policy, make_pre_post_processors
-from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
+from lerobot.datasets import LeRobotDataset
+from lerobot.policies import RewardClassifierConfig, make_policy, make_pre_post_processors
def main():
diff --git a/examples/tutorial/smolvla/using_smolvla_example.py b/examples/tutorial/smolvla/using_smolvla_example.py
index b99126efa..f59603db7 100644
--- a/examples/tutorial/smolvla/using_smolvla_example.py
+++ b/examples/tutorial/smolvla/using_smolvla_example.py
@@ -1,11 +1,11 @@
import torch
-from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
-from lerobot.datasets.feature_utils import hw_to_dataset_features
-from lerobot.policies.factory import make_pre_post_processors
-from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
+from lerobot.cameras.opencv import OpenCVCameraConfig
+from lerobot.policies import make_pre_post_processors
+from lerobot.policies.smolvla import SmolVLAPolicy
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
+from lerobot.utils.feature_utils import hw_to_dataset_features
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
diff --git a/pyproject.toml b/pyproject.toml
index 79409a200..2f12840e7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -58,45 +58,74 @@ classifiers = [
keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artificial intelligence"]
dependencies = [
-
- # Hugging Face dependencies
- "datasets>=4.0.0,<5.0.0",
- "diffusers>=0.27.2,<0.36.0",
- "huggingface-hub>=1.0.0,<2.0.0",
- "accelerate>=1.10.0,<2.0.0",
-
- # Core dependencies
- "numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
- "setuptools>=71.0.0,<81.0.0",
- "cmake>=3.29.0.1,<4.2.0",
- "packaging>=24.2,<26.0",
-
+ # Core ML
"torch>=2.7,<2.11.0",
- "torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10).
"torchvision>=0.22.0,<0.26.0",
-
- "einops>=0.8.0,<0.9.0",
+ "numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless.
"opencv-python-headless>=4.9.0,<4.14.0",
- "av>=15.0.0,<16.0.0",
- "jsonlines>=4.0.0,<5.0.0",
- "pynput>=1.7.8,<1.9.0",
- "pyserial>=3.5,<4.0",
+ "Pillow>=10.0.0,<13.0.0",
+ "einops>=0.8.0,<0.9.0",
- "wandb>=0.24.0,<0.25.0",
+ # Config & Hub
"draccus==0.10.0", # TODO: Relax version constraint
- "gymnasium>=1.1.1,<2.0.0",
- "rerun-sdk>=0.24.0,<0.27.0",
+ "huggingface-hub>=1.0.0,<2.0.0",
+ "requests>=2.32.0,<3.0.0",
- # Support dependencies
- "deepdiff>=7.0.1,<9.0.0",
- "imageio[ffmpeg]>=2.34.0,<3.0.0",
+ # Environments
+ # NOTE: gymnasium is used in lerobot.envs (lerobot-train, lerobot-eval), policies/factory,
+ # and robots/unitree. Moving it to an optional extra would require import guards across many
+ # tightly-coupled modules. Candidate for a future refactor to decouple envs from the core.
+ "gymnasium>=1.1.1,<2.0.0",
+
+ # Serialization & checkpointing
+ "safetensors>=0.4.3,<1.0.0",
+
+ # Lightweight utilities
+ "packaging>=24.2,<26.0",
"termcolor>=2.4.0,<4.0.0",
+ "tqdm>=4.66.0,<5.0.0",
+
+ # Build tools (required by opencv-python-headless on some platforms)
+ "cmake>=3.29.0.1,<4.2.0",
+ "setuptools>=71.0.0,<81.0.0",
]
# Optional dependencies
[project.optional-dependencies]
+# ── Feature-scoped extras ──────────────────────────────────
+dataset = [
+ "datasets>=4.0.0,<5.0.0",
+ "pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets
+ "pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets
+ "lerobot[av-dep]",
+ "torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10).
+ "jsonlines>=4.0.0,<5.0.0",
+]
+training = [
+ "lerobot[dataset]",
+ "accelerate>=1.10.0,<2.0.0",
+ "wandb>=0.24.0,<0.25.0",
+]
+hardware = [
+ "pynput>=1.7.8,<1.9.0",
+ "pyserial>=3.5,<4.0",
+ "deepdiff>=7.0.1,<9.0.0",
+]
+viz = [
+ "rerun-sdk>=0.24.0,<0.27.0",
+]
+# ── User-facing composite extras (map to CLI scripts) ─────
+# lerobot-record, lerobot-replay, lerobot-calibrate, lerobot-teleoperate, etc.
+core_scripts = ["lerobot[dataset]", "lerobot[hardware]", "lerobot[viz]"]
+# lerobot-eval -- base evaluation framework. You also need the policy's extra (e.g., lerobot[pi])
+# and the environment's extra (e.g., lerobot[pusht]) if evaluating in simulation.
+evaluation = ["lerobot[av-dep]"]
+# lerobot-dataset-viz, lerobot-imgtransform-viz
+dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
+
# Common
+av-dep = ["av>=15.0.0,<16.0.0"]
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.9.17"]
transformers-dep = ["transformers==5.3.0"] # TODO(Steven): https://github.com/huggingface/lerobot/pull/3249
@@ -104,6 +133,7 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
+diffusers-dep = ["diffusers>=0.27.2,<0.36.0"]
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster.
@@ -136,28 +166,28 @@ intelrealsense = [
phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0", "lerobot[scipy-dep]"]
# Policies
+diffusion = ["lerobot[diffusers-dep]"]
wallx = [
"lerobot[transformers-dep]",
- "lerobot[peft]",
+ "lerobot[peft-dep]",
"lerobot[scipy-dep]",
"torchdiffeq>=0.2.4,<0.3.0",
"lerobot[qwen-vl-utils-dep]",
]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
-smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
-multi_task_dit = ["lerobot[transformers-dep]"]
+smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
+multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"]
groot = [
"lerobot[transformers-dep]",
- "lerobot[peft]",
+ "lerobot[peft-dep]",
+ "lerobot[diffusers-dep]",
"dm-tree>=0.1.8,<1.0.0",
"timm>=1.0.0,<1.1.0",
- "safetensors>=0.4.3,<1.0.0",
- "Pillow>=10.0.0,<13.0.0",
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
-sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
+sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
xvla = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
@@ -166,31 +196,42 @@ async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
# Development
-dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
+dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1"]
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# Simulation
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
-aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
-pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
-libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
-metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]
+aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
+pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
+libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
+metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
# All
all = [
+ # Feature-scoped extras
+ "lerobot[dataset]",
+ "lerobot[training]",
+ "lerobot[hardware]",
+ "lerobot[viz]",
# NOTE(resolver hint): scipy is pulled in transitively via lerobot[scipy-dep] through
# multiple extras (aloha, metaworld, pi, wallx, phone). Listing it explicitly
# helps pip's resolver converge by constraining scipy early, before it encounters
# the loose scipy requirements from transitive deps like dm-control and metaworld.
"scipy>=1.14.0,<2.0.0",
"lerobot[dynamixel]",
+ "lerobot[feetech]",
+ "lerobot[damiao]",
+ "lerobot[robstride]",
"lerobot[gamepad]",
"lerobot[hopejr]",
"lerobot[lekiwi]",
+ "lerobot[openarms]",
"lerobot[reachy2]",
"lerobot[kinematics]",
"lerobot[intelrealsense]",
+ "lerobot[diffusion]",
+ "lerobot[multi_task_dit]",
"lerobot[wallx]",
"lerobot[pi]",
"lerobot[smolvla]",
@@ -267,7 +308,9 @@ ignore = [
]
[tool.ruff.lint.per-file-ignores]
-"__init__.py" = ["F401", "F403"]
+"__init__.py" = ["F401", "F403", "E402"]
+# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect
+"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
[tool.ruff.lint.isort]
diff --git a/src/lerobot/__init__.py b/src/lerobot/__init__.py
index eec574296..df43e7172 100644
--- a/src/lerobot/__init__.py
+++ b/src/lerobot/__init__.py
@@ -13,188 +13,39 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
"""
-This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library.
-We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.
+LeRobot -- PyTorch library for real-world robotics.
-Example:
- ```python
- import lerobot
- print(lerobot.available_envs)
- print(lerobot.available_tasks_per_env)
- print(lerobot.available_datasets)
- print(lerobot.available_datasets_per_env)
- print(lerobot.available_real_world_datasets)
- print(lerobot.available_policies)
- print(lerobot.available_policies_per_env)
- print(lerobot.available_robots)
- print(lerobot.available_cameras)
- print(lerobot.available_motors)
- ```
+Provides datasets, pretrained policies, and tools for training, evaluation,
+data collection, and robot control. Integrates with Hugging Face Hub for
+model and dataset sharing.
-When implementing a new dataset loadable with LeRobotDataset follow these steps:
-- Update `available_datasets_per_env` in `lerobot/__init__.py`
+The base install is intentionally lightweight. Feature-specific dependencies
+are gated behind optional extras::
-When implementing a new environment (e.g. `gym_aloha`), follow these steps:
-- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
-
-When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
-- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py`
-- Set the required `name` class attribute.
-- Update variables in `tests/test_available.py` by importing your new Policy class
+ pip install 'lerobot[dataset]' # dataset loading & creation
+ pip install 'lerobot[training]' # training loop + wandb
+ pip install 'lerobot[hardware]' # real robot control
+ pip install 'lerobot[core_scripts]' # dataset + hardware + viz (record, replay, calibrate, etc.)
+ pip install 'lerobot[all]' # everything
"""
-import itertools
+from lerobot.__version__ import __version__
-from lerobot.__version__ import __version__ # noqa: F401
-
-# TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
-# refers to a yaml file AND a modeling name. Same for `available_envs` which refers to
-# a yaml file AND a environment name. The difference should be more obvious.
-available_tasks_per_env = {
- "aloha": [
- "AlohaInsertion-v0",
- "AlohaTransferCube-v0",
+# Maps optional extras to the CLI entry-points they unlock.
+available_extras: dict[str, list[str]] = {
+ "dataset": ["lerobot-dataset-viz", "lerobot-imgtransform-viz", "lerobot-edit-dataset"],
+ "training": ["lerobot-train"],
+ "hardware": [
+ "lerobot-calibrate",
+ "lerobot-find-port",
+ "lerobot-find-cameras",
+ "lerobot-find-joint-limits",
+ "lerobot-setup-motors",
],
- "pusht": ["PushT-v0"],
-}
-available_envs = list(available_tasks_per_env.keys())
-
-available_datasets_per_env = {
- "aloha": [
- "lerobot/aloha_sim_insertion_human",
- "lerobot/aloha_sim_insertion_scripted",
- "lerobot/aloha_sim_transfer_cube_human",
- "lerobot/aloha_sim_transfer_cube_scripted",
- "lerobot/aloha_sim_insertion_human_image",
- "lerobot/aloha_sim_insertion_scripted_image",
- "lerobot/aloha_sim_transfer_cube_human_image",
- "lerobot/aloha_sim_transfer_cube_scripted_image",
- ],
- # TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly
- # coupled with tests.
- "pusht": ["lerobot/pusht", "lerobot/pusht_image"],
+ "core_scripts": ["lerobot-record", "lerobot-replay", "lerobot-teleoperate"],
+ "evaluation": ["lerobot-eval"],
}
-available_real_world_datasets = [
- "lerobot/aloha_mobile_cabinet",
- "lerobot/aloha_mobile_chair",
- "lerobot/aloha_mobile_elevator",
- "lerobot/aloha_mobile_shrimp",
- "lerobot/aloha_mobile_wash_pan",
- "lerobot/aloha_mobile_wipe_wine",
- "lerobot/aloha_static_battery",
- "lerobot/aloha_static_candy",
- "lerobot/aloha_static_coffee",
- "lerobot/aloha_static_coffee_new",
- "lerobot/aloha_static_cups_open",
- "lerobot/aloha_static_fork_pick_up",
- "lerobot/aloha_static_pingpong_test",
- "lerobot/aloha_static_pro_pencil",
- "lerobot/aloha_static_screw_driver",
- "lerobot/aloha_static_tape",
- "lerobot/aloha_static_thread_velcro",
- "lerobot/aloha_static_towel",
- "lerobot/aloha_static_vinh_cup",
- "lerobot/aloha_static_vinh_cup_left",
- "lerobot/aloha_static_ziploc_slide",
- "lerobot/umi_cup_in_the_wild",
- "lerobot/unitreeh1_fold_clothes",
- "lerobot/unitreeh1_rearrange_objects",
- "lerobot/unitreeh1_two_robot_greeting",
- "lerobot/unitreeh1_warehouse",
- "lerobot/nyu_rot_dataset",
- "lerobot/utokyo_saytap",
- "lerobot/imperialcollege_sawyer_wrist_cam",
- "lerobot/utokyo_xarm_bimanual",
- "lerobot/tokyo_u_lsmo",
- "lerobot/utokyo_pr2_opening_fridge",
- "lerobot/cmu_franka_exploration_dataset",
- "lerobot/cmu_stretch",
- "lerobot/asu_table_top",
- "lerobot/utokyo_pr2_tabletop_manipulation",
- "lerobot/utokyo_xarm_pick_and_place",
- "lerobot/ucsd_kitchen_dataset",
- "lerobot/austin_buds_dataset",
- "lerobot/dlr_sara_grid_clamp",
- "lerobot/conq_hose_manipulation",
- "lerobot/columbia_cairlab_pusht_real",
- "lerobot/dlr_sara_pour",
- "lerobot/dlr_edan_shared_control",
- "lerobot/ucsd_pick_and_place_dataset",
- "lerobot/berkeley_cable_routing",
- "lerobot/nyu_franka_play_dataset",
- "lerobot/austin_sirius_dataset",
- "lerobot/cmu_play_fusion",
- "lerobot/berkeley_gnm_sac_son",
- "lerobot/nyu_door_opening_surprising_effectiveness",
- "lerobot/berkeley_fanuc_manipulation",
- "lerobot/jaco_play",
- "lerobot/viola",
- "lerobot/kaist_nonprehensile",
- "lerobot/berkeley_mvp",
- "lerobot/uiuc_d3field",
- "lerobot/berkeley_gnm_recon",
- "lerobot/austin_sailor_dataset",
- "lerobot/utaustin_mutex",
- "lerobot/roboturk",
- "lerobot/stanford_hydra_dataset",
- "lerobot/berkeley_autolab_ur5",
- "lerobot/stanford_robocook",
- "lerobot/toto",
- "lerobot/fmb",
- "lerobot/droid_100",
- "lerobot/berkeley_rpt",
- "lerobot/stanford_kuka_multimodal_dataset",
- "lerobot/iamlab_cmu_pickup_insert",
- "lerobot/taco_play",
- "lerobot/berkeley_gnm_cory_hall",
- "lerobot/usc_cloth_sim",
-]
-
-available_datasets = sorted(
- set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets))
-)
-
-# lists all available policies from `lerobot/policies`
-available_policies = ["act", "diffusion", "tdmpc", "vqbet"]
-
-# lists all available robots from `lerobot/robots`
-available_robots = [
- "koch",
- "koch_bimanual",
- "aloha",
- "so100",
- "so101",
-]
-
-# lists all available cameras from `lerobot/cameras`
-available_cameras = [
- "opencv",
- "intelrealsense",
-]
-
-# lists all available motors from `lerobot/motors`
-available_motors = [
- "dynamixel",
- "feetech",
-]
-
-# keys and values refer to yaml files
-available_policies_per_env = {
- "aloha": ["act"],
- "pusht": ["diffusion", "vqbet"],
- "koch_real": ["act_koch_real"],
- "aloha_real": ["act_aloha_real"],
-}
-
-env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
-env_dataset_pairs = [
- (env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
-]
-env_dataset_policy_triplets = [
- (env, dataset, policy)
- for env, datasets in available_datasets_per_env.items()
- for dataset in datasets
- for policy in available_policies_per_env[env]
-]
+__all__ = ["__version__", "available_extras"]
diff --git a/src/lerobot/async_inference/__init__.py b/src/lerobot/async_inference/__init__.py
new file mode 100644
index 000000000..8d7a22584
--- /dev/null
+++ b/src/lerobot/async_inference/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Async inference server/client.
+
+Requires: ``pip install 'lerobot[async]'``
+
+Available modules (import directly)::
+
+ from lerobot.async_inference.policy_server import ...
+ from lerobot.async_inference.robot_client import ...
+"""
+
+from lerobot.utils.import_utils import require_package
+
+require_package("grpcio", extra="async", import_name="grpc")
+
+__all__: list[str] = []
diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py
index 9dd44eb44..4931c68c5 100644
--- a/src/lerobot/async_inference/helpers.py
+++ b/src/lerobot/async_inference/helpers.py
@@ -22,8 +22,7 @@ from typing import Any
import torch
-from lerobot.configs.types import PolicyFeature
-from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features
+from lerobot.configs import PolicyFeature
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
from lerobot.policies import ( # noqa: F401
@@ -36,6 +35,7 @@ from lerobot.policies import ( # noqa: F401
)
from lerobot.robots.robot import Robot
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
+from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.utils.utils import init_logging
Action = torch.Tensor
diff --git a/src/lerobot/async_inference/policy_server.py b/src/lerobot/async_inference/policy_server.py
index 3f63929df..787d39abf 100644
--- a/src/lerobot/async_inference/policy_server.py
+++ b/src/lerobot/async_inference/policy_server.py
@@ -38,7 +38,7 @@ import draccus
import grpc
import torch
-from lerobot.policies.factory import get_policy_class, make_pre_post_processors
+from lerobot.policies import get_policy_class, make_pre_post_processors
from lerobot.processor import PolicyProcessorPipeline
from lerobot.transport import (
services_pb2, # type: ignore
diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py
index 0ee70a0e6..a250a08fb 100644
--- a/src/lerobot/async_inference/robot_client.py
+++ b/src/lerobot/async_inference/robot_client.py
@@ -47,8 +47,8 @@ import draccus
import grpc
import torch
-from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
-from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
+from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
+from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
diff --git a/src/lerobot/cameras/__init__.py b/src/lerobot/cameras/__init__.py
index cbf1f11bf..3598d58aa 100644
--- a/src/lerobot/cameras/__init__.py
+++ b/src/lerobot/cameras/__init__.py
@@ -15,3 +15,9 @@
from .camera import Camera
from .configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation
from .utils import make_cameras_from_configs
+
+# NOTE: Camera submodule configs and implementations (OpenCVCameraConfig, RealSenseCamera, etc.)
+# are intentionally NOT re-exported here to avoid pulling backend-specific dependencies.
+# Import from submodules: ``from lerobot.cameras.opencv import OpenCVCameraConfig``
+
+__all__ = ["Camera", "CameraConfig", "ColorMode", "Cv2Backends", "Cv2Rotation", "make_cameras_from_configs"]
diff --git a/src/lerobot/cameras/reachy2_camera/__init__.py b/src/lerobot/cameras/reachy2_camera/__init__.py
index 72e45f32a..d979a7db5 100644
--- a/src/lerobot/cameras/reachy2_camera/__init__.py
+++ b/src/lerobot/cameras/reachy2_camera/__init__.py
@@ -14,3 +14,5 @@
from .configuration_reachy2_camera import Reachy2CameraConfig
from .reachy2_camera import Reachy2Camera
+
+__all__ = ["Reachy2Camera", "Reachy2CameraConfig"]
diff --git a/src/lerobot/cameras/realsense/__init__.py b/src/lerobot/cameras/realsense/__init__.py
index 67f2f4000..eb20c9973 100644
--- a/src/lerobot/cameras/realsense/__init__.py
+++ b/src/lerobot/cameras/realsense/__init__.py
@@ -14,3 +14,5 @@
from .camera_realsense import RealSenseCamera
from .configuration_realsense import RealSenseCameraConfig
+
+__all__ = ["RealSenseCamera", "RealSenseCameraConfig"]
diff --git a/src/lerobot/cameras/zmq/image_server.py b/src/lerobot/cameras/zmq/image_server.py
index 8222b9fee..b8b6f8e74 100644
--- a/src/lerobot/cameras/zmq/image_server.py
+++ b/src/lerobot/cameras/zmq/image_server.py
@@ -31,8 +31,8 @@ import cv2
import numpy as np
import zmq
-from lerobot.cameras.configs import ColorMode
-from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig
+from ..configs import ColorMode
+from ..opencv import OpenCVCamera, OpenCVCameraConfig
logger = logging.getLogger(__name__)
diff --git a/src/lerobot/common/__init__.py b/src/lerobot/common/__init__.py
new file mode 100644
index 000000000..782ef5b77
--- /dev/null
+++ b/src/lerobot/common/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Cross-cutting modules that bridge multiple lerobot packages.
+
+Unlike ``lerobot.utils`` (which must remain dependency-free), modules here
+are allowed to import from ``lerobot.policies``, ``lerobot.processor``,
+``lerobot.configs``, etc. They are deliberately NOT re-exported from the
+top-level ``lerobot`` package.
+
+Available modules (import directly)::
+
+ from lerobot.common.control_utils import predict_action, ...
+ from lerobot.common.train_utils import save_checkpoint, ...
+ from lerobot.common.wandb_utils import WandBLogger, ...
+"""
+
+__all__: list[str] = []
diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/common/control_utils.py
similarity index 95%
rename from src/lerobot/utils/control_utils.py
rename to src/lerobot/common/control_utils.py
index 94cd82fa1..530955078 100644
--- a/src/lerobot/utils/control_utils.py
+++ b/src/lerobot/common/control_utils.py
@@ -12,26 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations
+
########################################################################################
# Utilities
########################################################################################
-
-
import logging
import traceback
from contextlib import nullcontext
from copy import copy
from functools import cache
-from typing import Any
+from typing import TYPE_CHECKING, Any
import numpy as np
import torch
-from deepdiff import DeepDiff
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.datasets.utils import DEFAULT_FEATURES
-from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.policies.utils import prepare_observation_for_inference
+from lerobot.policies import PreTrainedPolicy, prepare_observation_for_inference
+
+if TYPE_CHECKING:
+ from lerobot.datasets import LeRobotDataset
from lerobot.processor import PolicyProcessorPipeline
from lerobot.robots import Robot
from lerobot.types import PolicyAction
@@ -218,6 +217,13 @@ def sanity_check_dataset_robot_compatibility(
Raises:
ValueError: If any of the checked metadata fields do not match.
"""
+ from lerobot.utils.import_utils import require_package
+
+ require_package("deepdiff", extra="hardware")
+ from deepdiff import DeepDiff
+
+ from lerobot.utils.constants import DEFAULT_FEATURES
+
fields = [
("robot_type", dataset.meta.robot_type, robot.robot_type),
("fps", dataset.fps, fps),
diff --git a/src/lerobot/utils/train_utils.py b/src/lerobot/common/train_utils.py
similarity index 95%
rename from src/lerobot/utils/train_utils.py
rename to src/lerobot/common/train_utils.py
index 02f6aebb3..3e96e1330 100644
--- a/src/lerobot/utils/train_utils.py
+++ b/src/lerobot/common/train_utils.py
@@ -19,10 +19,13 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from lerobot.configs.train import TrainPipelineConfig
-from lerobot.datasets.io_utils import load_json, write_json
-from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state
-from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state
-from lerobot.policies.pretrained import PreTrainedPolicy
+from lerobot.optim import (
+ load_optimizer_state,
+ load_scheduler_state,
+ save_optimizer_state,
+ save_scheduler_state,
+)
+from lerobot.policies import PreTrainedPolicy
from lerobot.processor import PolicyProcessorPipeline
from lerobot.utils.constants import (
CHECKPOINTS_DIR,
@@ -31,6 +34,7 @@ from lerobot.utils.constants import (
TRAINING_STATE_DIR,
TRAINING_STEP,
)
+from lerobot.utils.io_utils import load_json, write_json
from lerobot.utils.random_utils import load_rng_state, save_rng_state
diff --git a/src/lerobot/rl/wandb_utils.py b/src/lerobot/common/wandb_utils.py
similarity index 100%
rename from src/lerobot/rl/wandb_utils.py
rename to src/lerobot/common/wandb_utils.py
diff --git a/src/lerobot/configs/__init__.py b/src/lerobot/configs/__init__.py
new file mode 100644
index 000000000..3ddaec1af
--- /dev/null
+++ b/src/lerobot/configs/__init__.py
@@ -0,0 +1,47 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Public API for lerobot configuration types and base config classes.
+
+NOTE: TrainPipelineConfig, EvalPipelineConfig, and TrainRLServerPipelineConfig
+are intentionally NOT re-exported here to avoid circular dependencies
+(they import lerobot.envs and lerobot.policies at module level).
+Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
+"""
+
+from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
+from .policies import PreTrainedConfig
+from .types import (
+ FeatureType,
+ NormalizationMode,
+ PipelineFeatureType,
+ PolicyFeature,
+ RTCAttentionSchedule,
+)
+
+__all__ = [
+ # Types
+ "FeatureType",
+ "NormalizationMode",
+ "PipelineFeatureType",
+ "PolicyFeature",
+ "RTCAttentionSchedule",
+ # Config classes
+ "DatasetConfig",
+ "EvalConfig",
+ "PeftConfig",
+ "PreTrainedConfig",
+ "WandBConfig",
+]
diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py
index d6ad665bf..b05e96fde 100644
--- a/src/lerobot/configs/default.py
+++ b/src/lerobot/configs/default.py
@@ -16,8 +16,8 @@
from dataclasses import dataclass, field
-from lerobot.datasets.transforms import ImageTransformsConfig
-from lerobot.datasets.video_utils import get_safe_default_codec
+from lerobot.transforms import ImageTransformsConfig
+from lerobot.utils.import_utils import get_safe_default_codec
@dataclass
diff --git a/src/lerobot/configs/eval.py b/src/lerobot/configs/eval.py
index da8bee6b2..d1cebd27f 100644
--- a/src/lerobot/configs/eval.py
+++ b/src/lerobot/configs/eval.py
@@ -19,8 +19,9 @@ from pathlib import Path
from lerobot import envs, policies # noqa: F401
from lerobot.configs import parser
-from lerobot.configs.default import EvalConfig
-from lerobot.configs.policies import PreTrainedConfig
+
+from .default import EvalConfig
+from .policies import PreTrainedConfig
logger = getLogger(__name__)
diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py
index ce567b8f5..91701af6d 100644
--- a/src/lerobot/configs/policies.py
+++ b/src/lerobot/configs/policies.py
@@ -26,13 +26,13 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
-from lerobot.configs.types import FeatureType, PolicyFeature
-from lerobot.optim.optimizers import OptimizerConfig
-from lerobot.optim.schedulers import LRSchedulerConfig
+from lerobot.optim import LRSchedulerConfig, OptimizerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.device_utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.utils.hub import HubMixin
+from .types import FeatureType, PolicyFeature
+
T = TypeVar("T", bound="PreTrainedConfig")
logger = getLogger(__name__)
diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py
index 8b8aedb26..d754a0847 100644
--- a/src/lerobot/configs/train.py
+++ b/src/lerobot/configs/train.py
@@ -24,12 +24,12 @@ from huggingface_hub.errors import HfHubHTTPError
from lerobot import envs
from lerobot.configs import parser
-from lerobot.configs.default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.optim import OptimizerConfig
-from lerobot.optim.schedulers import LRSchedulerConfig
+from lerobot.optim import LRSchedulerConfig, OptimizerConfig
from lerobot.utils.hub import HubMixin
+from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
+from .policies import PreTrainedConfig
+
TRAIN_CONFIG_NAME = "train_config.json"
diff --git a/src/lerobot/data_processing/__init__.py b/src/lerobot/data_processing/__init__.py
index 2f76d5676..cd55d46fc 100644
--- a/src/lerobot/data_processing/__init__.py
+++ b/src/lerobot/data_processing/__init__.py
@@ -11,3 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+"""
+Data processing utilities (annotation tools, dataset transformations).
+
+Available sub-modules (import directly)::
+
+ from lerobot.data_processing.sarm_annotations import ...
+"""
+
+__all__: list[str] = []
diff --git a/src/lerobot/data_processing/sarm_annotations/__init__.py b/src/lerobot/data_processing/sarm_annotations/__init__.py
index 2f76d5676..cd4c38f33 100644
--- a/src/lerobot/data_processing/sarm_annotations/__init__.py
+++ b/src/lerobot/data_processing/sarm_annotations/__init__.py
@@ -11,3 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+"""
+SARM subtask annotation tools.
+
+Available modules (import directly)::
+
+ from lerobot.data_processing.sarm_annotations.subtask_annotation import ...
+"""
+
+__all__: list[str] = []
diff --git a/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py b/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py
index 8f3a65e39..b26257d44 100644
--- a/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py
+++ b/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py
@@ -76,7 +76,7 @@ import torch
from pydantic import BaseModel, Field
from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets import LeRobotDataset
# Pydantic Models for SARM Subtask Annotation
@@ -746,8 +746,7 @@ def save_annotations_to_dataset(
dataset_path: Path, annotations: dict[int, SubtaskAnnotation], fps: int, prefix: str = "sparse"
):
"""Save annotations to LeRobot dataset parquet format."""
- from lerobot.datasets.io_utils import load_episodes
- from lerobot.datasets.utils import DEFAULT_EPISODES_PATH
+ from lerobot.datasets import DEFAULT_EPISODES_PATH, load_episodes
episodes_dataset = load_episodes(dataset_path)
if not episodes_dataset or len(episodes_dataset) == 0:
@@ -841,7 +840,7 @@ def generate_auto_sparse_annotations(
def load_annotations_from_dataset(dataset_path: Path, prefix: str = "sparse") -> dict[int, SubtaskAnnotation]:
"""Load annotations from LeRobot dataset parquet files."""
- from lerobot.datasets.io_utils import load_episodes
+ from lerobot.datasets import load_episodes
episodes_dataset = load_episodes(dataset_path)
if not episodes_dataset or len(episodes_dataset) == 0:
diff --git a/src/lerobot/datasets/__init__.py b/src/lerobot/datasets/__init__.py
index 42c4ab810..6c42959a5 100644
--- a/src/lerobot/datasets/__init__.py
+++ b/src/lerobot/datasets/__init__.py
@@ -15,19 +15,68 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.datasets.multi_dataset import MultiLeRobotDataset
-from lerobot.datasets.sampler import EpisodeAwareSampler
-from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
-from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig
+from lerobot.utils.import_utils import require_package
+
+require_package("datasets", extra="dataset")
+require_package("av", extra="dataset")
+
+from .aggregate import aggregate_datasets
+from .compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
+from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
+from .dataset_tools import (
+ add_features,
+ convert_image_to_video_dataset,
+ delete_episodes,
+ merge_datasets,
+ modify_features,
+ modify_tasks,
+ recompute_stats,
+ remove_feature,
+ split_dataset,
+)
+from .factory import make_dataset, resolve_delta_timestamps
+from .image_writer import safe_stop_image_writer
+from .io_utils import load_episodes, write_stats
+from .lerobot_dataset import LeRobotDataset
+from .multi_dataset import MultiLeRobotDataset
+from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
+from .sampler import EpisodeAwareSampler
+from .streaming_dataset import StreamingLeRobotDataset
+from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
+from .video_utils import VideoEncodingManager
+
+# NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.)
+# and legacy migration constants are intentionally NOT re-exported here.
+# Import directly: ``from lerobot.datasets.io_utils import ...``
__all__ = [
+ "CODEBASE_VERSION",
+ "DEFAULT_EPISODES_PATH",
+ "DEFAULT_QUANTILES",
"EpisodeAwareSampler",
- "ImageTransforms",
- "ImageTransformsConfig",
"LeRobotDataset",
"LeRobotDatasetMetadata",
"MultiLeRobotDataset",
"StreamingLeRobotDataset",
+ "VideoEncodingManager",
+ "add_features",
+ "aggregate_datasets",
+ "aggregate_pipeline_dataset_features",
+ "aggregate_stats",
+ "convert_image_to_video_dataset",
+ "create_initial_features",
+ "create_lerobot_dataset_card",
+ "delete_episodes",
+ "get_feature_stats",
+ "load_episodes",
+ "make_dataset",
+ "merge_datasets",
+ "modify_features",
+ "modify_tasks",
+ "recompute_stats",
+ "remove_feature",
+ "resolve_delta_timestamps",
+ "safe_stop_image_writer",
+ "split_dataset",
+ "write_stats",
]
diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py
index 66f055f04..0da1da964 100644
--- a/src/lerobot/datasets/aggregate.py
+++ b/src/lerobot/datasets/aggregate.py
@@ -23,10 +23,10 @@ import datasets
import pandas as pd
import tqdm
-from lerobot.datasets.compute_stats import aggregate_stats
-from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
-from lerobot.datasets.feature_utils import get_hf_features_from_features
-from lerobot.datasets.io_utils import (
+from .compute_stats import aggregate_stats
+from .dataset_metadata import LeRobotDatasetMetadata
+from .feature_utils import get_hf_features_from_features
+from .io_utils import (
get_file_size_in_mb,
get_parquet_file_size_in_mb,
to_parquet_with_hf_images,
@@ -34,7 +34,7 @@ from lerobot.datasets.io_utils import (
write_stats,
write_tasks,
)
-from lerobot.datasets.utils import (
+from .utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
@@ -43,7 +43,7 @@ from lerobot.datasets.utils import (
DEFAULT_VIDEO_PATH,
update_chunk_file_indices,
)
-from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
+from .video_utils import concatenate_video_files, get_video_duration_in_s
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
diff --git a/src/lerobot/datasets/compute_stats.py b/src/lerobot/datasets/compute_stats.py
index 03eefe40e..f489c84a7 100644
--- a/src/lerobot/datasets/compute_stats.py
+++ b/src/lerobot/datasets/compute_stats.py
@@ -19,9 +19,11 @@ import logging
import numpy as np
-from lerobot.datasets.io_utils import load_image_as_numpy
+from lerobot.processor import RelativeActionsProcessorStep
from lerobot.utils.constants import ACTION, OBS_STATE
+from .io_utils import load_image_as_numpy
+
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
@@ -696,8 +698,6 @@ def compute_relative_action_stats(
ValueError: If the dataset has fewer frames than ``chunk_size``.
RuntimeError: If no valid (single-episode) chunks are found.
"""
- from lerobot.processor.relative_action_processor import RelativeActionsProcessorStep
-
if exclude_joints is None:
exclude_joints = []
diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py
index d79f4bfba..8bf67fa39 100644
--- a/src/lerobot/datasets/dataset_metadata.py
+++ b/src/lerobot/datasets/dataset_metadata.py
@@ -23,9 +23,13 @@ import pyarrow as pa
import pyarrow.parquet as pq
from huggingface_hub import snapshot_download
-from lerobot.datasets.compute_stats import aggregate_stats
-from lerobot.datasets.feature_utils import _validate_feature_names, create_empty_dataset_info
-from lerobot.datasets.io_utils import (
+from lerobot.utils.constants import DEFAULT_FEATURES, HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
+from lerobot.utils.feature_utils import _validate_feature_names
+from lerobot.utils.utils import flatten_dict
+
+from .compute_stats import aggregate_stats
+from .feature_utils import create_empty_dataset_info
+from .io_utils import (
get_file_size_in_mb,
load_episodes,
load_info,
@@ -37,19 +41,16 @@ from lerobot.datasets.io_utils import (
write_stats,
write_tasks,
)
-from lerobot.datasets.utils import (
+from .utils import (
DEFAULT_EPISODES_PATH,
- DEFAULT_FEATURES,
INFO_PATH,
check_version_compatibility,
- flatten_dict,
get_safe_version,
has_legacy_hub_download_metadata,
is_valid_version,
update_chunk_file_indices,
)
-from lerobot.datasets.video_utils import get_video_info
-from lerobot.utils.constants import HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
+from .video_utils import get_video_info
CODEBASE_VERSION = "v3.0"
diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py
index 3720a5084..fc7ce36ce 100644
--- a/src/lerobot/datasets/dataset_reader.py
+++ b/src/lerobot/datasets/dataset_reader.py
@@ -21,17 +21,17 @@ from pathlib import Path
import datasets
import torch
-from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
-from lerobot.datasets.feature_utils import (
+from .dataset_metadata import LeRobotDatasetMetadata
+from .feature_utils import (
check_delta_timestamps,
get_delta_indices,
get_hf_features_from_features,
)
-from lerobot.datasets.io_utils import (
+from .io_utils import (
hf_transform_to_torch,
load_nested_dataset,
)
-from lerobot.datasets.video_utils import decode_video_frames
+from .video_utils import decode_video_frames
class DatasetReader:
diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py
index 16bf24822..cbf4e5c49 100644
--- a/src/lerobot/datasets/dataset_tools.py
+++ b/src/lerobot/datasets/dataset_tools.py
@@ -36,22 +36,25 @@ import pyarrow.parquet as pq
import torch
from tqdm import tqdm
-from lerobot.datasets.aggregate import aggregate_datasets
-from lerobot.datasets.compute_stats import (
+from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
+from lerobot.utils.utils import flatten_dict
+
+from .aggregate import aggregate_datasets
+from .compute_stats import (
aggregate_stats,
compute_episode_stats,
compute_relative_action_stats,
)
-from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
-from lerobot.datasets.io_utils import (
+from .dataset_metadata import LeRobotDatasetMetadata
+from .io_utils import (
get_parquet_file_size_in_mb,
load_episodes,
write_info,
write_stats,
write_tasks,
)
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.datasets.utils import (
+from .lerobot_dataset import LeRobotDataset
+from .utils import (
DATA_DIR,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
@@ -59,8 +62,7 @@ from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
update_chunk_file_indices,
)
-from lerobot.datasets.video_utils import encode_video_frames, get_video_info
-from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
+from .video_utils import encode_video_frames, get_video_info
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
@@ -829,8 +831,6 @@ def _copy_and_reindex_episodes_metadata(
data_metadata: Dict mapping new episode index to its data file metadata
video_metadata: Optional dict mapping new episode index to its video metadata
"""
- from lerobot.datasets.utils import flatten_dict
-
if src_dataset.meta.episodes is None:
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
@@ -922,8 +922,8 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -
This ensures images are properly embedded and the file can be loaded correctly by HF datasets.
"""
- from lerobot.datasets.feature_utils import get_hf_features_from_features
- from lerobot.datasets.io_utils import embed_images
+ from .feature_utils import get_hf_features_from_features
+ from .io_utils import embed_images
hf_features = get_hf_features_from_features(meta.features)
ep_dataset = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=hf_features, split="train")
@@ -1367,7 +1367,7 @@ def _copy_data_without_images(
episode_indices: Episodes to include
img_keys: Image keys to remove
"""
- from lerobot.datasets.utils import DATA_DIR
+ from .utils import DATA_DIR
data_dir = src_dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
diff --git a/src/lerobot/datasets/dataset_writer.py b/src/lerobot/datasets/dataset_writer.py
index 787ecd337..60ec9e348 100644
--- a/src/lerobot/datasets/dataset_writer.py
+++ b/src/lerobot/datasets/dataset_writer.py
@@ -31,26 +31,26 @@ import PIL.Image
import pyarrow.parquet as pq
import torch
-from lerobot.datasets.compute_stats import compute_episode_stats
-from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
-from lerobot.datasets.feature_utils import (
+from .compute_stats import compute_episode_stats
+from .dataset_metadata import LeRobotDatasetMetadata
+from .feature_utils import (
get_hf_features_from_features,
validate_episode_buffer,
validate_frame,
)
-from lerobot.datasets.image_writer import AsyncImageWriter, write_image
-from lerobot.datasets.io_utils import (
+from .image_writer import AsyncImageWriter, write_image
+from .io_utils import (
embed_images,
get_file_size_in_mb,
load_episodes,
write_info,
)
-from lerobot.datasets.utils import (
+from .utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_IMAGE_PATH,
update_chunk_file_indices,
)
-from lerobot.datasets.video_utils import (
+from .video_utils import (
StreamingVideoEncoder,
concatenate_video_files,
encode_video_frames,
diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py
index 76ece8961..040cba5cb 100644
--- a/src/lerobot/datasets/factory.py
+++ b/src/lerobot/datasets/factory.py
@@ -18,19 +18,15 @@ from pprint import pformat
import torch
-from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
-from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.datasets.multi_dataset import MultiLeRobotDataset
-from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
-from lerobot.datasets.transforms import ImageTransforms
-from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
+from lerobot.transforms import ImageTransforms
+from lerobot.utils.constants import ACTION, IMAGENET_STATS, OBS_PREFIX, REWARD
-IMAGENET_STATS = {
- "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
- "std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
-}
+from .dataset_metadata import LeRobotDatasetMetadata
+from .lerobot_dataset import LeRobotDataset
+from .multi_dataset import MultiLeRobotDataset
+from .streaming_dataset import StreamingLeRobotDataset
def resolve_delta_timestamps(
diff --git a/src/lerobot/datasets/feature_utils.py b/src/lerobot/datasets/feature_utils.py
index 46154d92a..b05dbf2cc 100644
--- a/src/lerobot/datasets/feature_utils.py
+++ b/src/lerobot/datasets/feature_utils.py
@@ -14,23 +14,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pprint import pformat
-from typing import Any
import datasets
import numpy as np
from PIL import Image as PILImage
-from lerobot.configs.types import FeatureType, PolicyFeature
-from lerobot.datasets.utils import (
+from lerobot.utils.constants import DEFAULT_FEATURES
+from lerobot.utils.utils import is_valid_numpy_dtype_string
+
+from .utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
- DEFAULT_FEATURES,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
)
-from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR
-from lerobot.utils.utils import is_valid_numpy_dtype_string
def get_hf_features_from_features(features: dict) -> datasets.Features:
@@ -71,199 +69,6 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
return datasets.Features(hf_features)
-def _validate_feature_names(features: dict[str, dict]) -> None:
- """Validate that feature names do not contain invalid characters.
-
- Args:
- features (dict): The LeRobot features dictionary.
-
- Raises:
- ValueError: If any feature name contains '/'.
- """
- invalid_features = {name: ft for name, ft in features.items() if "/" in name}
- if invalid_features:
- raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.")
-
-
-def hw_to_dataset_features(
- hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
-) -> dict[str, dict]:
- """Convert hardware-specific features to a LeRobot dataset feature dictionary.
-
- This function takes a dictionary describing hardware outputs (like joint states
- or camera image shapes) and formats it into the standard LeRobot feature
- specification.
-
- Args:
- hw_features (dict): Dictionary mapping feature names to their type (float for
- joints) or shape (tuple for images).
- prefix (str): The prefix to add to the feature keys (e.g., "observation"
- or "action").
- use_video (bool): If True, image features are marked as "video", otherwise "image".
-
- Returns:
- dict: A LeRobot features dictionary.
- """
- features = {}
- joint_fts = {
- key: ftype
- for key, ftype in hw_features.items()
- if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
- }
- cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
-
- if joint_fts and prefix == ACTION:
- features[prefix] = {
- "dtype": "float32",
- "shape": (len(joint_fts),),
- "names": list(joint_fts),
- }
-
- if joint_fts and prefix == OBS_STR:
- features[f"{prefix}.state"] = {
- "dtype": "float32",
- "shape": (len(joint_fts),),
- "names": list(joint_fts),
- }
-
- for key, shape in cam_fts.items():
- features[f"{prefix}.images.{key}"] = {
- "dtype": "video" if use_video else "image",
- "shape": shape,
- "names": ["height", "width", "channels"],
- }
-
- _validate_feature_names(features)
- return features
-
-
-def build_dataset_frame(
- ds_features: dict[str, dict], values: dict[str, Any], prefix: str
-) -> dict[str, np.ndarray]:
- """Construct a single data frame from raw values based on dataset features.
-
- A "frame" is a dictionary containing all the data for a single timestep,
- formatted as numpy arrays according to the feature specification.
-
- Args:
- ds_features (dict): The LeRobot dataset features dictionary.
- values (dict): A dictionary of raw values from the hardware/environment.
- prefix (str): The prefix to filter features by (e.g., "observation"
- or "action").
-
- Returns:
- dict: A dictionary representing a single frame of data.
- """
- frame = {}
- for key, ft in ds_features.items():
- if key in DEFAULT_FEATURES or not key.startswith(prefix):
- continue
- elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
- frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
- elif ft["dtype"] in ["image", "video"]:
- frame[key] = values[key.removeprefix(f"{prefix}.images.")]
-
- return frame
-
-
-def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
- """Convert dataset features to policy features.
-
- This function transforms the dataset's feature specification into a format
- that a policy can use, classifying features by type (e.g., visual, state,
- action) and ensuring correct shapes (e.g., channel-first for images).
-
- Args:
- features (dict): The LeRobot dataset features dictionary.
-
- Returns:
- dict: A dictionary mapping feature keys to `PolicyFeature` objects.
-
- Raises:
- ValueError: If an image feature does not have a 3D shape.
- """
- # TODO(aliberts): Implement "type" in dataset features and simplify this
- policy_features = {}
- for key, ft in features.items():
- shape = ft["shape"]
- if ft["dtype"] in ["image", "video"]:
- type = FeatureType.VISUAL
- if len(shape) != 3:
- raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
-
- names = ft["names"]
- # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
- if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
- shape = (shape[2], shape[0], shape[1])
- elif key == OBS_ENV_STATE:
- type = FeatureType.ENV
- elif key.startswith(OBS_STR):
- type = FeatureType.STATE
- elif key.startswith(ACTION):
- type = FeatureType.ACTION
- else:
- continue
-
- policy_features[key] = PolicyFeature(
- type=type,
- shape=shape,
- )
-
- return policy_features
-
-
-def combine_feature_dicts(*dicts: dict) -> dict:
- """Merge LeRobot grouped feature dicts.
-
- - For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
- - For others (e.g. `observation.images.*`), the last one wins (if they are identical).
-
- Args:
- *dicts: A variable number of LeRobot feature dictionaries to merge.
-
- Returns:
- dict: A single merged feature dictionary.
-
- Raises:
- ValueError: If there's a dtype mismatch for a feature being merged.
- """
- out: dict = {}
- for d in dicts:
- for key, value in d.items():
- if not isinstance(value, dict):
- out[key] = value
- continue
-
- dtype = value.get("dtype")
- shape = value.get("shape")
- is_vector = (
- dtype not in ("image", "video", "string")
- and isinstance(shape, tuple)
- and len(shape) == 1
- and "names" in value
- )
-
- if is_vector:
- # Initialize or retrieve the accumulating dict for this feature key
- target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})
- # Ensure consistent data types across merged entries
- if "dtype" in target and dtype != target["dtype"]:
- raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")
-
- # Merge feature names: append only new ones to preserve order without duplicates
- seen = set(target["names"])
- for n in value["names"]:
- if n not in seen:
- target["names"].append(n)
- seen.add(n)
- # Recompute the shape to reflect the updated number of features
- target["shape"] = (len(target["names"]),)
- else:
- # For images/videos and non-1D entries: override with the latest definition
- out[key] = value
- return out
-
-
def create_empty_dataset_info(
codebase_version: str,
fps: int,
diff --git a/src/lerobot/datasets/io_utils.py b/src/lerobot/datasets/io_utils.py
index cee6cfba8..2ee859e97 100644
--- a/src/lerobot/datasets/io_utils.py
+++ b/src/lerobot/datasets/io_utils.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
from pathlib import Path
from typing import Any
@@ -29,7 +28,10 @@ from datasets.table import embed_table_storage
from PIL import Image as PILImage
from torchvision import transforms
-from lerobot.datasets.utils import (
+from lerobot.utils.io_utils import load_json, write_json
+from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict
+
+from .utils import (
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_EPISODES_PATH,
DEFAULT_SUBTASKS_PATH,
@@ -37,11 +39,8 @@ from lerobot.datasets.utils import (
EPISODES_DIR,
INFO_PATH,
STATS_PATH,
- flatten_dict,
serialize_dict,
- unflatten_dict,
)
-from lerobot.utils.utils import SuppressProgressBars
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
@@ -116,33 +115,6 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
return dataset
-def load_json(fpath: Path) -> Any:
- """Load data from a JSON file.
-
- Args:
- fpath (Path): Path to the JSON file.
-
- Returns:
- Any: The data loaded from the JSON file.
- """
- with open(fpath) as f:
- return json.load(f)
-
-
-def write_json(data: dict, fpath: Path) -> None:
- """Write data to a JSON file.
-
- Creates parent directories if they don't exist.
-
- Args:
- data (dict): The dictionary to write.
- fpath (Path): The path to the output JSON file.
- """
- fpath.parent.mkdir(exist_ok=True, parents=True)
- with open(fpath, "w") as f:
- json.dump(data, f, indent=4, ensure_ascii=False)
-
-
def write_info(info: dict, local_dir: Path) -> None:
write_json(info, local_dir / INFO_PATH)
diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py
index 2f0154cda..7cda5d677 100644
--- a/src/lerobot/datasets/lerobot_dataset.py
+++ b/src/lerobot/datasets/lerobot_dataset.py
@@ -24,20 +24,21 @@ import torch.utils
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError
-from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
-from lerobot.datasets.dataset_reader import DatasetReader
-from lerobot.datasets.dataset_writer import DatasetWriter
-from lerobot.datasets.utils import (
+from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
+
+from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
+from .dataset_reader import DatasetReader
+from .dataset_writer import DatasetWriter
+from .utils import (
create_lerobot_dataset_card,
get_safe_version,
is_valid_version,
)
-from lerobot.datasets.video_utils import (
+from .video_utils import (
StreamingVideoEncoder,
get_safe_default_codec,
resolve_vcodec,
)
-from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE
logger = logging.getLogger(__name__)
diff --git a/src/lerobot/datasets/multi_dataset.py b/src/lerobot/datasets/multi_dataset.py
index 092443077..b4b7a941d 100644
--- a/src/lerobot/datasets/multi_dataset.py
+++ b/src/lerobot/datasets/multi_dataset.py
@@ -21,12 +21,13 @@ import datasets
import torch
import torch.utils
-from lerobot.datasets.compute_stats import aggregate_stats
-from lerobot.datasets.feature_utils import get_hf_features_from_features
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.datasets.video_utils import VideoFrame
from lerobot.utils.constants import HF_LEROBOT_HOME
+from .compute_stats import aggregate_stats
+from .feature_utils import get_hf_features_from_features
+from .lerobot_dataset import LeRobotDataset
+from .video_utils import VideoFrame
+
logger = logging.getLogger(__name__)
diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py
index 96779fdc6..cf02a52ac 100644
--- a/src/lerobot/datasets/pipeline_features.py
+++ b/src/lerobot/datasets/pipeline_features.py
@@ -16,11 +16,11 @@ import re
from collections.abc import Sequence
from typing import Any
-from lerobot.configs.types import PipelineFeatureType
-from lerobot.datasets.feature_utils import hw_to_dataset_features
+from lerobot.configs import PipelineFeatureType
from lerobot.processor import DataProcessorPipeline
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
+from lerobot.utils.feature_utils import hw_to_dataset_features
def create_initial_features(
diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py
index 1767cc79d..f47d71367 100644
--- a/src/lerobot/datasets/streaming_dataset.py
+++ b/src/lerobot/datasets/streaming_dataset.py
@@ -22,20 +22,21 @@ import numpy as np
import torch
from datasets import load_dataset
-from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
-from lerobot.datasets.feature_utils import get_delta_indices
-from lerobot.datasets.io_utils import item_to_torch
-from lerobot.datasets.utils import (
+from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
+
+from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
+from .feature_utils import get_delta_indices
+from .io_utils import item_to_torch
+from .utils import (
check_version_compatibility,
find_float_index,
is_float_in_list,
safe_shard,
)
-from lerobot.datasets.video_utils import (
+from .video_utils import (
VideoDecoderCache,
decode_video_frames_torchcodec,
)
-from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
class LookBackError(Exception):
diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py
index 36e7934ed..c6815e0f5 100644
--- a/src/lerobot/datasets/utils.py
+++ b/src/lerobot/datasets/utils.py
@@ -17,9 +17,7 @@ import contextlib
import importlib.resources
import json
import logging
-from collections.abc import Iterator
from pathlib import Path
-from typing import Any
import datasets
import numpy as np
@@ -28,6 +26,8 @@ import torch
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
+from lerobot.utils.utils import flatten_dict, unflatten_dict
+
V30_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.
@@ -93,14 +93,6 @@ LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
-DEFAULT_FEATURES = {
- "timestamp": {"dtype": "float32", "shape": (1,), "names": None},
- "frame_index": {"dtype": "int64", "shape": (1,), "names": None},
- "episode_index": {"dtype": "int64", "shape": (1,), "names": None},
- "index": {"dtype": "int64", "shape": (1,), "names": None},
- "task_index": {"dtype": "int64", "shape": (1,), "names": None},
-}
-
def has_legacy_hub_download_metadata(root: Path) -> bool:
"""Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.
@@ -123,59 +115,6 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -
return chunk_idx, file_idx
-def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
- """Flatten a nested dictionary by joining keys with a separator.
-
- Example:
- >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
- >>> print(flatten_dict(dct))
- {'a/b': 1, 'a/c/d': 2, 'e': 3}
-
- Args:
- d (dict): The dictionary to flatten.
- parent_key (str): The base key to prepend to the keys in this level.
- sep (str): The separator to use between keys.
-
- Returns:
- dict: A flattened dictionary.
- """
- items = []
- for k, v in d.items():
- new_key = f"{parent_key}{sep}{k}" if parent_key else k
- if isinstance(v, dict):
- items.extend(flatten_dict(v, new_key, sep=sep).items())
- else:
- items.append((new_key, v))
- return dict(items)
-
-
-def unflatten_dict(d: dict, sep: str = "/") -> dict:
- """Unflatten a dictionary with delimited keys into a nested dictionary.
-
- Example:
- >>> flat_dct = {"a/b": 1, "a/c/d": 2, "e": 3}
- >>> print(unflatten_dict(flat_dct))
- {'a': {'b': 1, 'c': {'d': 2}}, 'e': 3}
-
- Args:
- d (dict): A dictionary with flattened keys.
- sep (str): The separator used in the keys.
-
- Returns:
- dict: A nested dictionary.
- """
- outdict = {}
- for key, value in d.items():
- parts = key.split(sep)
- d = outdict
- for part in parts[:-1]:
- if part not in d:
- d[part] = {}
- d = d[part]
- d[parts[-1]] = value
- return outdict
-
-
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
"""Serialize a dictionary containing tensors or numpy arrays to be JSON-compatible.
@@ -332,27 +271,6 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
raise ForwardCompatibilityError(repo_id, min(upper_versions))
-def cycle(iterable: Any) -> Iterator[Any]:
- """Create a dataloader-safe cyclical iterator.
-
- This is an equivalent of `itertools.cycle` but is safe for use with
- PyTorch DataLoaders with multiple workers.
- See https://github.com/pytorch/pytorch/issues/23900 for details.
-
- Args:
- iterable: The iterable to cycle over.
-
- Yields:
- Items from the iterable, restarting from the beginning when exhausted.
- """
- iterator = iter(iterable)
- while True:
- try:
- yield next(iterator)
- except StopIteration:
- iterator = iter(iterable)
-
-
def create_branch(repo_id: str, *, branch: str, repo_type: str | None = None) -> None:
"""Create a branch on an existing Hugging Face repo.
diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py
index 59c8c7d3e..cabe592d0 100644
--- a/src/lerobot/datasets/video_utils.py
+++ b/src/lerobot/datasets/video_utils.py
@@ -37,6 +37,8 @@ import torchvision
from datasets.features.features import register_feature
from PIL import Image
+from lerobot.utils.import_utils import get_safe_default_codec
+
logger = logging.getLogger(__name__)
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build.
@@ -116,16 +118,6 @@ def resolve_vcodec(vcodec: str) -> str:
return "libsvtav1"
-def get_safe_default_codec():
- if importlib.util.find_spec("torchcodec"):
- return "torchcodec"
- else:
- logger.warning(
- "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
- )
- return "pyav"
-
-
def decode_video_frames(
video_path: Path | str,
timestamps: list[float],
@@ -271,7 +263,10 @@ class VideoDecoderCache:
if importlib.util.find_spec("torchcodec"):
from torchcodec.decoders import VideoDecoder
else:
- raise ImportError("torchcodec is required but not available.")
+ raise ImportError(
+ "'torchcodec' is required but not installed. "
+ "Install it with: pip install 'lerobot[dataset]' (or uv pip install 'lerobot[dataset]')"
+ )
video_path = str(video_path)
@@ -606,7 +601,7 @@ class _CameraEncoderThread(threading.Thread):
self.encoder_threads = encoder_threads
def run(self) -> None:
- from lerobot.datasets.compute_stats import RunningQuantileStats, auto_downsample_height_width
+ from .compute_stats import RunningQuantileStats, auto_downsample_height_width
container = None
output_stream = None
diff --git a/src/lerobot/envs/__init__.py b/src/lerobot/envs/__init__.py
index 183c12325..277fd04f4 100644
--- a/src/lerobot/envs/__init__.py
+++ b/src/lerobot/envs/__init__.py
@@ -12,4 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .configs import AlohaEnv, EnvConfig, HubEnvConfig, PushtEnv # noqa: F401
+# NOTE: gymnasium is currently a core dependency but is a candidate for moving to an
+# optional extra in the future. When that transition happens, uncomment the guard below
+# and update the extra name to the one that will contain gymnasium.
+# from lerobot.utils.import_utils import require_package
+# require_package("gymnasium", extra="", import_name="gymnasium")
+
+from .configs import AlohaEnv, EnvConfig, HILSerlRobotEnvConfig, HubEnvConfig, PushtEnv
+from .factory import make_env, make_env_config, make_env_pre_post_processors
+from .utils import check_env_attributes_and_types, close_envs, env_to_policy_features, preprocess_observation
+
+__all__ = [
+ "AlohaEnv",
+ "EnvConfig",
+ "HILSerlRobotEnvConfig",
+ "HubEnvConfig",
+ "PushtEnv",
+ "check_env_attributes_and_types",
+ "close_envs",
+ "env_to_policy_features",
+ "make_env",
+ "make_env_config",
+ "make_env_pre_post_processors",
+ "preprocess_observation",
+]
diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py
index af5bda33f..2a7c52d45 100644
--- a/src/lerobot/envs/configs.py
+++ b/src/lerobot/envs/configs.py
@@ -23,7 +23,8 @@ import draccus
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
-from lerobot.configs.types import FeatureType, PolicyFeature
+from lerobot.configs import FeatureType, PolicyFeature
+from lerobot.processor import IsaaclabArenaProcessorStep, LiberoProcessorStep, PolicyProcessorPipeline
from lerobot.robots import RobotConfig
from lerobot.teleoperators.config import TeleoperatorConfig
from lerobot.utils.constants import (
@@ -124,8 +125,6 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
def get_env_processors(self):
"""Return (preprocessor, postprocessor) for this env. Default: identity."""
- from lerobot.processor.pipeline import PolicyProcessorPipeline
-
return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[])
@@ -418,7 +417,7 @@ class LiberoEnv(EnvConfig):
return kwargs
def create_envs(self, n_envs: int, use_async_envs: bool = False):
- from lerobot.envs.libero import create_libero_envs
+ from .libero import create_libero_envs
if self.task is None:
raise ValueError("LiberoEnv requires a task to be specified")
@@ -436,9 +435,6 @@ class LiberoEnv(EnvConfig):
)
def get_env_processors(self):
- from lerobot.processor.env_processor import LiberoProcessorStep
- from lerobot.processor.pipeline import PolicyProcessorPipeline
-
return (
PolicyProcessorPipeline(steps=[LiberoProcessorStep()]),
PolicyProcessorPipeline(steps=[]),
@@ -487,7 +483,7 @@ class MetaworldEnv(EnvConfig):
}
def create_envs(self, n_envs: int, use_async_envs: bool = False):
- from lerobot.envs.metaworld import create_metaworld_envs
+ from .metaworld import create_metaworld_envs
if self.task is None:
raise ValueError("MetaWorld requires a task to be specified")
@@ -568,9 +564,6 @@ class IsaaclabArenaEnv(HubEnvConfig):
return {}
def get_env_processors(self):
- from lerobot.processor.env_processor import IsaaclabArenaProcessorStep
- from lerobot.processor.pipeline import PolicyProcessorPipeline
-
state_keys = tuple(k.strip() for k in (self.state_keys or "").split(",") if k.strip())
camera_keys = tuple(k.strip() for k in (self.camera_keys or "").split(",") if k.strip())
if not state_keys and not camera_keys:
diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py
index 40d5425cc..317cf2e6f 100644
--- a/src/lerobot/envs/factory.py
+++ b/src/lerobot/envs/factory.py
@@ -19,8 +19,8 @@ from typing import Any
import gymnasium as gym
-from lerobot.envs.configs import EnvConfig, HubEnvConfig
-from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
+from .configs import EnvConfig, HubEnvConfig
+from .utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py
index 1b814db52..ec90d0ffd 100644
--- a/src/lerobot/envs/libero.py
+++ b/src/lerobot/envs/libero.py
@@ -29,9 +29,10 @@ from gymnasium import spaces
from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
-from lerobot.envs.utils import _LazyAsyncVectorEnv
from lerobot.types import RobotObservation
+from .utils import _LazyAsyncVectorEnv
+
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
"""Normalize camera_name into a non-empty list of strings."""
diff --git a/src/lerobot/envs/metaworld.py b/src/lerobot/envs/metaworld.py
index 49c775957..1dc513a68 100644
--- a/src/lerobot/envs/metaworld.py
+++ b/src/lerobot/envs/metaworld.py
@@ -25,9 +25,10 @@ import metaworld.policies as policies
import numpy as np
from gymnasium import spaces
-from lerobot.envs.utils import _LazyAsyncVectorEnv
from lerobot.types import RobotObservation
+from .utils import _LazyAsyncVectorEnv
+
# ---- Load configuration data from the external JSON file ----
CONFIG_PATH = Path(__file__).parent / "metaworld_config.json"
try:
diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py
index ff5f53735..b0d834a05 100644
--- a/src/lerobot/envs/utils.py
+++ b/src/lerobot/envs/utils.py
@@ -27,11 +27,12 @@ import torch
from huggingface_hub import hf_hub_download, snapshot_download
from torch import Tensor
-from lerobot.configs.types import FeatureType, PolicyFeature
-from lerobot.envs.configs import EnvConfig
+from lerobot.configs import FeatureType, PolicyFeature
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.utils import get_channel_first_image_shape
+from .configs import EnvConfig
+
def _convert_nested_dict(d):
result = {}
diff --git a/src/lerobot/model/__init__.py b/src/lerobot/model/__init__.py
new file mode 100644
index 000000000..2f82e5053
--- /dev/null
+++ b/src/lerobot/model/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Kinematics utilities for robot modeling.
+
+from .kinematics import RobotKinematics as RobotKinematics
+
+__all__ = ["RobotKinematics"]
diff --git a/src/lerobot/motors/__init__.py b/src/lerobot/motors/__init__.py
index 5df80d5ba..63ac14a38 100644
--- a/src/lerobot/motors/__init__.py
+++ b/src/lerobot/motors/__init__.py
@@ -19,3 +19,5 @@ from .motors_bus import (
MotorCalibration,
MotorNormMode,
)
+
+__all__ = ["Motor", "MotorCalibration", "MotorNormMode"]
diff --git a/src/lerobot/motors/damiao/__init__.py b/src/lerobot/motors/damiao/__init__.py
index 8240138cf..5a98fa4d2 100644
--- a/src/lerobot/motors/damiao/__init__.py
+++ b/src/lerobot/motors/damiao/__init__.py
@@ -15,4 +15,6 @@
# limitations under the License.
from .damiao import DamiaoMotorsBus
-from .tables import *
+from .tables import * # noqa: F403 — hardware constant tables
+
+__all__ = ["DamiaoMotorsBus"]
diff --git a/src/lerobot/motors/dynamixel/__init__.py b/src/lerobot/motors/dynamixel/__init__.py
index 425f8538a..01fcadf4f 100644
--- a/src/lerobot/motors/dynamixel/__init__.py
+++ b/src/lerobot/motors/dynamixel/__init__.py
@@ -15,4 +15,6 @@
# limitations under the License.
from .dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode
-from .tables import *
+from .tables import * # noqa: F403 — hardware constant tables
+
+__all__ = ["DriveMode", "DynamixelMotorsBus", "OperatingMode", "TorqueMode"]
diff --git a/src/lerobot/motors/dynamixel/dynamixel.py b/src/lerobot/motors/dynamixel/dynamixel.py
index bca455dc5..4502bd668 100644
--- a/src/lerobot/motors/dynamixel/dynamixel.py
+++ b/src/lerobot/motors/dynamixel/dynamixel.py
@@ -21,6 +21,9 @@
import logging
from copy import deepcopy
from enum import Enum
+from typing import TYPE_CHECKING
+
+from lerobot.utils.import_utils import _dynamixel_sdk_available, require_package
from ..encoding_utils import decode_twos_complement, encode_twos_complement
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
@@ -33,6 +36,11 @@ from .tables import (
MODEL_RESOLUTION,
)
+if TYPE_CHECKING or _dynamixel_sdk_available:
+ import dynamixel_sdk as dxl
+else:
+ dxl = None
+
PROTOCOL_VERSION = 2.0
DEFAULT_BAUDRATE = 1_000_000
DEFAULT_TIMEOUT_MS = 1000
@@ -82,23 +90,6 @@ class TorqueMode(Enum):
DISABLED = 0
-def _split_into_byte_chunks(value: int, length: int) -> list[int]:
- import dynamixel_sdk as dxl
-
- if length == 1:
- data = [value]
- elif length == 2:
- data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)]
- elif length == 4:
- data = [
- dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
- dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
- dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)),
- dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)),
- ]
- return data
-
-
class DynamixelMotorsBus(SerialMotorsBus):
"""
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
@@ -123,9 +114,8 @@ class DynamixelMotorsBus(SerialMotorsBus):
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
):
+ require_package("dynamixel-sdk", extra="dynamixel", import_name="dynamixel_sdk")
super().__init__(port, motors, calibration)
- import dynamixel_sdk as dxl
-
self.port_handler = dxl.PortHandler(self.port)
self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION)
self.sync_reader = dxl.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0)
@@ -244,7 +234,18 @@ class DynamixelMotorsBus(SerialMotorsBus):
return half_turn_homings
def _split_into_byte_chunks(self, value: int, length: int) -> list[int]:
- return _split_into_byte_chunks(value, length)
+ if length == 1:
+ data = [value]
+ elif length == 2:
+ data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)]
+ elif length == 4:
+ data = [
+ dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
+ dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
+ dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)),
+ dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)),
+ ]
+ return data
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
for n_try in range(1 + num_retry):
diff --git a/src/lerobot/motors/feetech/__init__.py b/src/lerobot/motors/feetech/__init__.py
index 75da2d221..6c06d8b95 100644
--- a/src/lerobot/motors/feetech/__init__.py
+++ b/src/lerobot/motors/feetech/__init__.py
@@ -15,4 +15,6 @@
# limitations under the License.
from .feetech import DriveMode, FeetechMotorsBus, OperatingMode, TorqueMode
-from .tables import *
+from .tables import * # noqa: F403 — hardware constant tables
+
+__all__ = ["DriveMode", "FeetechMotorsBus", "OperatingMode", "TorqueMode"]
diff --git a/src/lerobot/motors/feetech/feetech.py b/src/lerobot/motors/feetech/feetech.py
index 58a65310d..9b1e0fb7e 100644
--- a/src/lerobot/motors/feetech/feetech.py
+++ b/src/lerobot/motors/feetech/feetech.py
@@ -16,6 +16,9 @@ import logging
from copy import deepcopy
from enum import Enum
from pprint import pformat
+from typing import TYPE_CHECKING
+
+from lerobot.utils.import_utils import _feetech_sdk_available, require_package
from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address
@@ -32,6 +35,11 @@ from .tables import (
SCAN_BAUDRATES,
)
+if TYPE_CHECKING or _feetech_sdk_available:
+ import scservo_sdk as scs
+else:
+ scs = None
+
DEFAULT_PROTOCOL_VERSION = 0
DEFAULT_BAUDRATE = 1_000_000
DEFAULT_TIMEOUT_MS = 1000
@@ -65,23 +73,6 @@ class TorqueMode(Enum):
DISABLED = 0
-def _split_into_byte_chunks(value: int, length: int) -> list[int]:
- import scservo_sdk as scs
-
- if length == 1:
- data = [value]
- elif length == 2:
- data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)]
- elif length == 4:
- data = [
- scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
- scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
- scs.SCS_LOBYTE(scs.SCS_HIWORD(value)),
- scs.SCS_HIBYTE(scs.SCS_HIWORD(value)),
- ]
- return data
-
-
def patch_setPacketTimeout(self, packet_length): # noqa: N802
"""
HACK: This patches the PortHandler behavior to set the correct packet timeouts.
@@ -119,11 +110,10 @@ class FeetechMotorsBus(SerialMotorsBus):
calibration: dict[str, MotorCalibration] | None = None,
protocol_version: int = DEFAULT_PROTOCOL_VERSION,
):
+ require_package("feetech-servo-sdk", extra="feetech", import_name="scservo_sdk")
super().__init__(port, motors, calibration)
self.protocol_version = protocol_version
self._assert_same_protocol()
- import scservo_sdk as scs
-
self.port_handler = scs.PortHandler(self.port)
# HACK: monkeypatch
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign]
@@ -195,8 +185,6 @@ class FeetechMotorsBus(SerialMotorsBus):
raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.")
def _find_single_motor_p1(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]:
- import scservo_sdk as scs
-
model = self.motors[motor].model
search_baudrates = (
[initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model]
@@ -329,11 +317,20 @@ class FeetechMotorsBus(SerialMotorsBus):
return ids_values
def _split_into_byte_chunks(self, value: int, length: int) -> list[int]:
- return _split_into_byte_chunks(value, length)
+ if length == 1:
+ data = [value]
+ elif length == 2:
+ data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)]
+ elif length == 4:
+ data = [
+ scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
+ scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
+ scs.SCS_LOBYTE(scs.SCS_HIWORD(value)),
+ scs.SCS_HIBYTE(scs.SCS_HIWORD(value)),
+ ]
+ return data
def _broadcast_ping(self) -> tuple[dict[int, int], int]:
- import scservo_sdk as scs
-
data_list: dict[int, int] = {}
status_length = 6
diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py
index 509f5e95f..209489bb9 100644
--- a/src/lerobot/motors/motors_bus.py
+++ b/src/lerobot/motors/motors_bus.py
@@ -29,12 +29,22 @@ from dataclasses import dataclass
from enum import Enum
from functools import cached_property
from pprint import pformat
-from typing import Protocol
+from typing import TYPE_CHECKING, Protocol
-import serial
-from deepdiff import DeepDiff
from tqdm import tqdm
+from lerobot.utils.import_utils import _deepdiff_available, _serial_available, require_package
+
+if TYPE_CHECKING or _serial_available:
+ import serial
+else:
+ serial = None # type: ignore[assignment]
+
+if TYPE_CHECKING or _deepdiff_available:
+ from deepdiff import DeepDiff
+else:
+ DeepDiff = None # type: ignore[assignment, misc]
+
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.utils import enter_pressed, move_cursor_up
@@ -346,6 +356,8 @@ class SerialMotorsBus(MotorsBusBase):
motors: dict[str, Motor],
calibration: dict[str, MotorCalibration] | None = None,
):
+ require_package("pyserial", extra="hardware", import_name="serial")
+ require_package("deepdiff", extra="hardware")
super().__init__(port, motors, calibration)
self.port_handler: PortHandler
diff --git a/src/lerobot/motors/robstride/__init__.py b/src/lerobot/motors/robstride/__init__.py
index 7933ac6fa..4729b3968 100644
--- a/src/lerobot/motors/robstride/__init__.py
+++ b/src/lerobot/motors/robstride/__init__.py
@@ -15,4 +15,6 @@
# limitations under the License.
from .robstride import RobstrideMotorsBus
-from .tables import *
+from .tables import * # noqa: F403 — hardware constant tables
+
+__all__ = ["RobstrideMotorsBus"]
diff --git a/src/lerobot/optim/__init__.py b/src/lerobot/optim/__init__.py
index de2c4c996..46676027b 100644
--- a/src/lerobot/optim/__init__.py
+++ b/src/lerobot/optim/__init__.py
@@ -12,4 +12,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .optimizers import OptimizerConfig as OptimizerConfig
+from .optimizers import (
+ AdamConfig as AdamConfig,
+ AdamWConfig as AdamWConfig,
+ MultiAdamConfig as MultiAdamConfig,
+ OptimizerConfig as OptimizerConfig,
+ SGDConfig as SGDConfig,
+ XVLAAdamWConfig as XVLAAdamWConfig,
+ load_optimizer_state,
+ save_optimizer_state,
+)
+from .schedulers import (
+ CosineDecayWithWarmupSchedulerConfig as CosineDecayWithWarmupSchedulerConfig,
+ DiffuserSchedulerConfig as DiffuserSchedulerConfig,
+ LRSchedulerConfig as LRSchedulerConfig,
+ VQBeTSchedulerConfig as VQBeTSchedulerConfig,
+ load_scheduler_state,
+ save_scheduler_state,
+)
+
+# NOTE: make_optimizer_and_scheduler is intentionally NOT re-exported here
+# to avoid circular dependencies (it imports lerobot.configs.train and lerobot.policies).
+# Import directly: ``from lerobot.optim.factory import make_optimizer_and_scheduler``
+
+__all__ = [
+ # Optimizer configs
+ "AdamConfig",
+ "AdamWConfig",
+ "MultiAdamConfig",
+ "OptimizerConfig",
+ "SGDConfig",
+ "XVLAAdamWConfig",
+ # Scheduler configs
+ "CosineDecayWithWarmupSchedulerConfig",
+ "DiffuserSchedulerConfig",
+ "LRSchedulerConfig",
+ "VQBeTSchedulerConfig",
+ # State management
+ "load_optimizer_state",
+ "load_scheduler_state",
+ "save_optimizer_state",
+ "save_scheduler_state",
+]
diff --git a/src/lerobot/optim/factory.py b/src/lerobot/optim/factory.py
index 699289993..ce519e0b2 100644
--- a/src/lerobot/optim/factory.py
+++ b/src/lerobot/optim/factory.py
@@ -19,7 +19,7 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from lerobot.configs.train import TrainPipelineConfig
-from lerobot.policies.pretrained import PreTrainedPolicy
+from lerobot.policies import PreTrainedPolicy
def make_optimizer_and_scheduler(
diff --git a/src/lerobot/optim/optimizers.py b/src/lerobot/optim/optimizers.py
index e2e3d8937..0bdd7a37e 100644
--- a/src/lerobot/optim/optimizers.py
+++ b/src/lerobot/optim/optimizers.py
@@ -23,13 +23,12 @@ import draccus
import torch
from safetensors.torch import load_file, save_file
-from lerobot.datasets.io_utils import write_json
-from lerobot.datasets.utils import flatten_dict, unflatten_dict
from lerobot.utils.constants import (
OPTIMIZER_PARAM_GROUPS,
OPTIMIZER_STATE,
)
-from lerobot.utils.io_utils import deserialize_json_into_object
+from lerobot.utils.io_utils import deserialize_json_into_object, write_json
+from lerobot.utils.utils import flatten_dict, unflatten_dict
# Type alias for parameters accepted by optimizer build() methods.
# This matches PyTorch's optimizer signature while also supporting:
diff --git a/src/lerobot/optim/schedulers.py b/src/lerobot/optim/schedulers.py
index 19c3fd7bd..914edd2db 100644
--- a/src/lerobot/optim/schedulers.py
+++ b/src/lerobot/optim/schedulers.py
@@ -23,9 +23,8 @@ import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
-from lerobot.datasets.io_utils import write_json
from lerobot.utils.constants import SCHEDULER_STATE
-from lerobot.utils.io_utils import deserialize_json_into_object
+from lerobot.utils.io_utils import deserialize_json_into_object, write_json
@dataclass
@@ -48,6 +47,9 @@ class DiffuserSchedulerConfig(LRSchedulerConfig):
num_warmup_steps: int | None = None
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
+ from lerobot.utils.import_utils import require_package
+
+ require_package("diffusers", extra="diffusion")
from diffusers.optimization import get_scheduler
kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py
index 55ce09cf9..e138a84d9 100644
--- a/src/lerobot/policies/__init__.py
+++ b/src/lerobot/policies/__init__.py
@@ -14,30 +14,55 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
+from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors
from .groot.configuration_groot import GrootConfig as GrootConfig
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
from .pi05.configuration_pi05 import PI05Config as PI05Config
+from .pretrained import PreTrainedPolicy as PreTrainedPolicy
+from .rtc import ActionInterpolator as ActionInterpolator
+from .sac.configuration_sac import SACConfig as SACConfig
+from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
+from .sarm.configuration_sarm import SARMConfig as SARMConfig
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
-from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
+from .utils import make_robot_action, prepare_observation_for_inference
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
+# NOTE: Policy modeling classes (e.g., SACPolicy) are intentionally NOT re-exported here.
+# They have heavy optional dependencies and are loaded lazily via get_policy_class().
+# Import directly: ``from lerobot.policies.sac.modeling_sac import SACPolicy``
+
__all__ = [
+ # Configuration classes
"ACTConfig",
"DiffusionConfig",
+ "GrootConfig",
"MultiTaskDiTConfig",
"PI0Config",
- "PI05Config",
"PI0FastConfig",
- "SmolVLAConfig",
+ "PI05Config",
+ "RewardClassifierConfig",
+ "SACConfig",
"SARMConfig",
+ "SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
- "GrootConfig",
- "XVLAConfig",
"WallXConfig",
+ "XVLAConfig",
+ # Base class
+ "PreTrainedPolicy",
+ # RTC utilities
+ "ActionInterpolator",
+ # Utility functions
+ "make_robot_action",
+ "prepare_observation_for_inference",
+ # Factory functions
+ "get_policy_class",
+ "make_policy",
+ "make_policy_config",
+ "make_pre_post_processors",
]
diff --git a/src/lerobot/policies/act/__init__.py b/src/lerobot/policies/act/__init__.py
new file mode 100644
index 000000000..44f15189f
--- /dev/null
+++ b/src/lerobot/policies/act/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_act import ACTConfig
+from .modeling_act import ACTPolicy
+from .processor_act import make_act_pre_post_processors
+
+__all__ = ["ACTConfig", "ACTPolicy", "make_act_pre_post_processors"]
diff --git a/src/lerobot/policies/act/configuration_act.py b/src/lerobot/policies/act/configuration_act.py
index bd89185fd..b5c3d68f1 100644
--- a/src/lerobot/policies/act/configuration_act.py
+++ b/src/lerobot/policies/act/configuration_act.py
@@ -15,9 +15,8 @@
# limitations under the License.
from dataclasses import dataclass, field
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import NormalizationMode
-from lerobot.optim.optimizers import AdamWConfig
+from lerobot.configs import NormalizationMode, PreTrainedConfig
+from lerobot.optim import AdamWConfig
@PreTrainedConfig.register_subclass("act")
diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py
index a5c48eb3d..0120258ee 100644
--- a/src/lerobot/policies/act/modeling_act.py
+++ b/src/lerobot/policies/act/modeling_act.py
@@ -33,10 +33,11 @@ from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
-from lerobot.policies.act.configuration_act import ACTConfig
-from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
+from ..pretrained import PreTrainedPolicy
+from .configuration_act import ACTConfig
+
class ACTPolicy(PreTrainedPolicy):
"""
diff --git a/src/lerobot/policies/act/processor_act.py b/src/lerobot/policies/act/processor_act.py
index 727b18cef..d87ade900 100644
--- a/src/lerobot/policies/act/processor_act.py
+++ b/src/lerobot/policies/act/processor_act.py
@@ -17,7 +17,6 @@ from typing import Any
import torch
-from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
@@ -26,10 +25,13 @@ from lerobot.processor import (
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
+ policy_action_to_transition,
+ transition_to_policy_action,
)
-from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+from .configuration_act import ACTConfig
+
def make_act_pre_post_processors(
config: ACTConfig,
diff --git a/src/lerobot/policies/diffusion/__init__.py b/src/lerobot/policies/diffusion/__init__.py
new file mode 100644
index 000000000..4f6ee820a
--- /dev/null
+++ b/src/lerobot/policies/diffusion/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_diffusion import DiffusionConfig
+from .modeling_diffusion import DiffusionPolicy
+from .processor_diffusion import make_diffusion_pre_post_processors
+
+__all__ = ["DiffusionConfig", "DiffusionPolicy", "make_diffusion_pre_post_processors"]
diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py
index 91b3df214..8e3d4bf19 100644
--- a/src/lerobot/policies/diffusion/configuration_diffusion.py
+++ b/src/lerobot/policies/diffusion/configuration_diffusion.py
@@ -16,10 +16,8 @@
# limitations under the License.
from dataclasses import dataclass, field
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import NormalizationMode
-from lerobot.optim.optimizers import AdamConfig
-from lerobot.optim.schedulers import DiffuserSchedulerConfig
+from lerobot.configs import NormalizationMode, PreTrainedConfig
+from lerobot.optim import AdamConfig, DiffuserSchedulerConfig
@PreTrainedConfig.register_subclass("diffusion")
diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py
index aa8d5dd14..5b3b97571 100644
--- a/src/lerobot/policies/diffusion/modeling_diffusion.py
+++ b/src/lerobot/policies/diffusion/modeling_diffusion.py
@@ -29,19 +29,18 @@ import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
-from diffusers.schedulers.scheduling_ddim import DDIMScheduler
-from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor, nn
-from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
-from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.policies.utils import (
+from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
+
+from ..pretrained import PreTrainedPolicy
+from ..utils import (
get_device_from_parameters,
get_dtype_from_parameters,
get_output_shape,
populate_queues,
)
-from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
+from .configuration_diffusion import DiffusionConfig
class DiffusionPolicy(PreTrainedPolicy):
@@ -151,11 +150,17 @@ class DiffusionPolicy(PreTrainedPolicy):
return loss, None
-def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
+def _make_noise_scheduler(name: str, **kwargs: dict):
"""
Factory for noise scheduler instances of the requested type. All kwargs are passed
to the scheduler.
"""
+ from lerobot.utils.import_utils import require_package
+
+ require_package("diffusers", extra="diffusion")
+ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
+ from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+
if name == "DDPM":
return DDPMScheduler(**kwargs)
elif name == "DDIM":
diff --git a/src/lerobot/policies/diffusion/processor_diffusion.py b/src/lerobot/policies/diffusion/processor_diffusion.py
index a7799be64..c4bc17680 100644
--- a/src/lerobot/policies/diffusion/processor_diffusion.py
+++ b/src/lerobot/policies/diffusion/processor_diffusion.py
@@ -18,7 +18,6 @@ from typing import Any
import torch
-from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
@@ -27,10 +26,13 @@ from lerobot.processor import (
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
+ policy_action_to_transition,
+ transition_to_policy_action,
)
-from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+from .configuration_diffusion import DiffusionConfig
+
def make_diffusion_pre_post_processors(
config: DiffusionConfig,
diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py
index 501dd7af1..611a6e9bc 100644
--- a/src/lerobot/policies/factory.py
+++ b/src/lerobot/policies/factory.py
@@ -18,34 +18,19 @@ from __future__ import annotations
import importlib
import logging
-from typing import Any, TypedDict, Unpack
+from typing import TYPE_CHECKING, Any, TypedDict, Unpack
import torch
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import FeatureType
-from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
-from lerobot.datasets.feature_utils import dataset_to_policy_features
-from lerobot.envs.configs import EnvConfig
-from lerobot.envs.utils import env_to_policy_features
-from lerobot.policies.act.configuration_act import ACTConfig
-from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
-from lerobot.policies.groot.configuration_groot import GrootConfig
-from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
-from lerobot.policies.pi0.configuration_pi0 import PI0Config
-from lerobot.policies.pi05.configuration_pi05 import PI05Config
-from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.policies.sac.configuration_sac import SACConfig
-from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
-from lerobot.policies.sarm.configuration_sarm import SARMConfig
-from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
-from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
-from lerobot.policies.utils import validate_visual_features_consistency
-from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
-from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
-from lerobot.policies.xvla.configuration_xvla import XVLAConfig
-from lerobot.processor import PolicyProcessorPipeline
-from lerobot.processor.converters import (
+if TYPE_CHECKING:
+ from lerobot.datasets import LeRobotDatasetMetadata
+
+from lerobot.configs import FeatureType, PreTrainedConfig
+from lerobot.envs import EnvConfig, env_to_policy_features
+from lerobot.processor import (
+ AbsoluteActionsProcessorStep,
+ PolicyProcessorPipeline,
+ RelativeActionsProcessorStep,
batch_to_transition,
policy_action_to_transition,
transition_to_batch,
@@ -57,6 +42,24 @@ from lerobot.utils.constants import (
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
+from lerobot.utils.feature_utils import dataset_to_policy_features
+
+from .act.configuration_act import ACTConfig
+from .diffusion.configuration_diffusion import DiffusionConfig
+from .groot.configuration_groot import GrootConfig
+from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
+from .pi0.configuration_pi0 import PI0Config
+from .pi05.configuration_pi05 import PI05Config
+from .pretrained import PreTrainedPolicy
+from .sac.configuration_sac import SACConfig
+from .sac.reward_model.configuration_classifier import RewardClassifierConfig
+from .sarm.configuration_sarm import SARMConfig
+from .smolvla.configuration_smolvla import SmolVLAConfig
+from .tdmpc.configuration_tdmpc import TDMPCConfig
+from .utils import validate_visual_features_consistency
+from .vqbet.configuration_vqbet import VQBeTConfig
+from .wall_x.configuration_wall_x import WallXConfig
+from .xvla.configuration_xvla import XVLAConfig
def _reconnect_relative_absolute_steps(
@@ -69,11 +72,6 @@ def _reconnect_relative_absolute_steps(
the RelativeActionsProcessorStep so it can read the cached state at inference time.
That reference is not serializable, so we re-establish it here after loading.
"""
- from lerobot.processor.relative_action_processor import (
- AbsoluteActionsProcessorStep,
- RelativeActionsProcessorStep,
- )
-
relative_step = next((s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep)), None)
if relative_step is None:
return
@@ -99,63 +97,63 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
NotImplementedError: If the policy name is not recognized.
"""
if name == "tdmpc":
- from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
+ from .tdmpc.modeling_tdmpc import TDMPCPolicy
return TDMPCPolicy
elif name == "diffusion":
- from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
+ from .diffusion.modeling_diffusion import DiffusionPolicy
return DiffusionPolicy
elif name == "act":
- from lerobot.policies.act.modeling_act import ACTPolicy
+ from .act.modeling_act import ACTPolicy
return ACTPolicy
elif name == "multi_task_dit":
- from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
+ from .multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
return MultiTaskDiTPolicy
elif name == "vqbet":
- from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
+ from .vqbet.modeling_vqbet import VQBeTPolicy
return VQBeTPolicy
elif name == "pi0":
- from lerobot.policies.pi0.modeling_pi0 import PI0Policy
+ from .pi0.modeling_pi0 import PI0Policy
return PI0Policy
elif name == "pi0_fast":
- from lerobot.policies.pi0_fast.modeling_pi0_fast import PI0FastPolicy
+ from .pi0_fast.modeling_pi0_fast import PI0FastPolicy
return PI0FastPolicy
elif name == "pi05":
- from lerobot.policies.pi05.modeling_pi05 import PI05Policy
+ from .pi05.modeling_pi05 import PI05Policy
return PI05Policy
elif name == "sac":
- from lerobot.policies.sac.modeling_sac import SACPolicy
+ from .sac.modeling_sac import SACPolicy
return SACPolicy
elif name == "reward_classifier":
- from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
+ from .sac.reward_model.modeling_classifier import Classifier
return Classifier
elif name == "smolvla":
- from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
+ from .smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "sarm":
- from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
+ from .sarm.modeling_sarm import SARMRewardModel
return SARMRewardModel
elif name == "groot":
- from lerobot.policies.groot.modeling_groot import GrootPolicy
+ from .groot.modeling_groot import GrootPolicy
return GrootPolicy
elif name == "xvla":
- from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
+ from .xvla.modeling_xvla import XVLAPolicy
return XVLAPolicy
elif name == "wall_x":
- from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy
+ from .wall_x.modeling_wall_x import WallXPolicy
return WallXPolicy
else:
@@ -315,7 +313,7 @@ def make_pre_post_processors(
# Create a new processor based on policy type
if isinstance(policy_cfg, TDMPCConfig):
- from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors
+ from .tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors
processors = make_tdmpc_pre_post_processors(
config=policy_cfg,
@@ -323,7 +321,7 @@ def make_pre_post_processors(
)
elif isinstance(policy_cfg, DiffusionConfig):
- from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors
+ from .diffusion.processor_diffusion import make_diffusion_pre_post_processors
processors = make_diffusion_pre_post_processors(
config=policy_cfg,
@@ -331,7 +329,7 @@ def make_pre_post_processors(
)
elif isinstance(policy_cfg, ACTConfig):
- from lerobot.policies.act.processor_act import make_act_pre_post_processors
+ from .act.processor_act import make_act_pre_post_processors
processors = make_act_pre_post_processors(
config=policy_cfg,
@@ -339,7 +337,7 @@ def make_pre_post_processors(
)
elif isinstance(policy_cfg, MultiTaskDiTConfig):
- from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
+ from .multi_task_dit.processor_multi_task_dit import (
make_multi_task_dit_pre_post_processors,
)
@@ -349,7 +347,7 @@ def make_pre_post_processors(
)
elif isinstance(policy_cfg, VQBeTConfig):
- from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
+ from .vqbet.processor_vqbet import make_vqbet_pre_post_processors
processors = make_vqbet_pre_post_processors(
config=policy_cfg,
@@ -357,7 +355,7 @@ def make_pre_post_processors(
)
elif isinstance(policy_cfg, PI0Config):
- from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors
+ from .pi0.processor_pi0 import make_pi0_pre_post_processors
processors = make_pi0_pre_post_processors(
config=policy_cfg,
@@ -365,7 +363,7 @@ def make_pre_post_processors(
)
elif isinstance(policy_cfg, PI05Config):
- from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors
+ from .pi05.processor_pi05 import make_pi05_pre_post_processors
processors = make_pi05_pre_post_processors(
config=policy_cfg,
@@ -373,7 +371,7 @@ def make_pre_post_processors(
)
elif isinstance(policy_cfg, SACConfig):
- from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
+ from .sac.processor_sac import make_sac_pre_post_processors
processors = make_sac_pre_post_processors(
config=policy_cfg,
@@ -381,7 +379,7 @@ def make_pre_post_processors(
)
elif isinstance(policy_cfg, RewardClassifierConfig):
- from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
+ from .sac.reward_model.processor_classifier import make_classifier_processor
processors = make_classifier_processor(
config=policy_cfg,
@@ -389,7 +387,7 @@ def make_pre_post_processors(
)
elif isinstance(policy_cfg, SmolVLAConfig):
- from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors
+ from .smolvla.processor_smolvla import make_smolvla_pre_post_processors
processors = make_smolvla_pre_post_processors(
config=policy_cfg,
@@ -397,7 +395,7 @@ def make_pre_post_processors(
)
elif isinstance(policy_cfg, SARMConfig):
- from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
+ from .sarm.processor_sarm import make_sarm_pre_post_processors
processors = make_sarm_pre_post_processors(
config=policy_cfg,
@@ -405,7 +403,7 @@ def make_pre_post_processors(
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(policy_cfg, GrootConfig):
- from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
+ from .groot.processor_groot import make_groot_pre_post_processors
processors = make_groot_pre_post_processors(
config=policy_cfg,
@@ -413,7 +411,7 @@ def make_pre_post_processors(
)
elif isinstance(policy_cfg, XVLAConfig):
- from lerobot.policies.xvla.processor_xvla import (
+ from .xvla.processor_xvla import (
make_xvla_pre_post_processors,
)
@@ -423,7 +421,7 @@ def make_pre_post_processors(
)
elif isinstance(policy_cfg, WallXConfig):
- from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors
+ from .wall_x.processor_wall_x import make_wall_x_pre_post_processors
processors = make_wall_x_pre_post_processors(
config=policy_cfg,
diff --git a/src/lerobot/policies/groot/action_head/__init__.py b/src/lerobot/policies/groot/action_head/__init__.py
index 3159bfe65..63ffc39e6 100644
--- a/src/lerobot/policies/groot/action_head/__init__.py
+++ b/src/lerobot/policies/groot/action_head/__init__.py
@@ -12,3 +12,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+__all__: list[str] = []
diff --git a/src/lerobot/policies/groot/action_head/cross_attention_dit.py b/src/lerobot/policies/groot/action_head/cross_attention_dit.py
index 40f7ba603..a4cd1a0b7 100755
--- a/src/lerobot/policies/groot/action_head/cross_attention_dit.py
+++ b/src/lerobot/policies/groot/action_head/cross_attention_dit.py
@@ -14,21 +14,37 @@
# limitations under the License.
+from typing import TYPE_CHECKING
+
import torch
import torch.nn.functional as F # noqa: N812
-from diffusers import ConfigMixin, ModelMixin
-from diffusers.configuration_utils import register_to_config
-from diffusers.models.attention import Attention, FeedForward
-from diffusers.models.embeddings import (
- SinusoidalPositionalEmbedding,
- TimestepEmbedding,
- Timesteps,
-)
from torch import nn
+from lerobot.utils.import_utils import _diffusers_available, require_package
+
+if TYPE_CHECKING or _diffusers_available:
+ from diffusers import ConfigMixin, ModelMixin
+ from diffusers.configuration_utils import register_to_config
+ from diffusers.models.attention import Attention, FeedForward
+ from diffusers.models.embeddings import (
+ SinusoidalPositionalEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+ )
+else:
+ ConfigMixin = object
+ ModelMixin = nn.Module
+ register_to_config = lambda fn: fn # noqa: E731
+ Attention = None
+ FeedForward = None
+ SinusoidalPositionalEmbedding = None
+ TimestepEmbedding = None
+ Timesteps = None
+
class TimestepEncoder(nn.Module):
def __init__(self, embedding_dim, compute_dtype=torch.float32):
+ require_package("diffusers", extra="groot")
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
@@ -88,6 +104,7 @@ class BasicTransformerBlock(nn.Module):
ff_bias: bool = True,
attention_out_bias: bool = True,
):
+ require_package("diffusers", extra="groot")
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
diff --git a/src/lerobot/policies/groot/action_head/flow_matching_action_head.py b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py
index bfc456ba0..4fda21ca5 100644
--- a/src/lerobot/policies/groot/action_head/flow_matching_action_head.py
+++ b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py
@@ -31,11 +31,10 @@ else:
PretrainedConfig = object
BatchFeature = None
-from lerobot.policies.groot.action_head.action_encoder import (
+from .action_encoder import (
SinusoidalPositionalEncoding,
swish,
)
-
from .cross_attention_dit import DiT, SelfAttentionTransformer
diff --git a/src/lerobot/policies/groot/configuration_groot.py b/src/lerobot/policies/groot/configuration_groot.py
index 4f3d78222..17cb631d7 100644
--- a/src/lerobot/policies/groot/configuration_groot.py
+++ b/src/lerobot/policies/groot/configuration_groot.py
@@ -16,10 +16,8 @@
from dataclasses import dataclass, field
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
-from lerobot.optim.optimizers import AdamWConfig
-from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
+from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
+from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
diff --git a/src/lerobot/policies/groot/groot_n1.py b/src/lerobot/policies/groot/groot_n1.py
index 06ff5a04d..fc753839a 100644
--- a/src/lerobot/policies/groot/groot_n1.py
+++ b/src/lerobot/policies/groot/groot_n1.py
@@ -41,12 +41,13 @@ try:
except ImportError:
tree = None
-from lerobot.policies.groot.action_head.flow_matching_action_head import (
+from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME
+
+from .action_head.flow_matching_action_head import (
FlowmatchingActionHead,
FlowmatchingActionHeadConfig,
)
-from lerobot.policies.groot.utils import ensure_eagle_cache_ready
-from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME
+from .utils import ensure_eagle_cache_ready
DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve())
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py
index 9a479b8f9..4b612bca4 100644
--- a/src/lerobot/policies/groot/modeling_groot.py
+++ b/src/lerobot/policies/groot/modeling_groot.py
@@ -41,12 +41,13 @@ from typing import TypeVar
import torch
from torch import Tensor
-from lerobot.configs.types import FeatureType, PolicyFeature
-from lerobot.policies.groot.configuration_groot import GrootConfig
-from lerobot.policies.groot.groot_n1 import GR00TN15
-from lerobot.policies.pretrained import PreTrainedPolicy
+from lerobot.configs import FeatureType, PolicyFeature
from lerobot.utils.constants import ACTION, OBS_IMAGES
+from ..pretrained import PreTrainedPolicy
+from .configuration_groot import GrootConfig
+from .groot_n1 import GR00TN15
+
T = TypeVar("T", bound="GrootPolicy")
diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py
index 8bf9dabca..3367de711 100644
--- a/src/lerobot/policies/groot/processor_groot.py
+++ b/src/lerobot/policies/groot/processor_groot.py
@@ -30,12 +30,11 @@ else:
AutoProcessor = None
ProcessorMixin = object
-from lerobot.configs.types import (
+from lerobot.configs import (
FeatureType,
NormalizationMode,
PolicyFeature,
)
-from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
@@ -44,8 +43,6 @@ from lerobot.processor import (
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
-)
-from lerobot.processor.converters import (
policy_action_to_transition,
transition_to_policy_action,
)
@@ -60,6 +57,8 @@ from lerobot.utils.constants import (
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
+from .configuration_groot import GrootConfig
+
# Defaults for Eagle processor locations
DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5"
diff --git a/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py
index 061230687..33be3113f 100644
--- a/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py
+++ b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py
@@ -17,10 +17,8 @@
import logging
from dataclasses import dataclass, field
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import NormalizationMode
-from lerobot.optim.optimizers import AdamConfig
-from lerobot.optim.schedulers import DiffuserSchedulerConfig
+from lerobot.configs import NormalizationMode, PreTrainedConfig
+from lerobot.optim import AdamConfig, DiffuserSchedulerConfig
@PreTrainedConfig.register_subclass("multi_task_dit")
diff --git a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py
index 4fee851e0..8e5d1e3cb 100644
--- a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py
+++ b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py
@@ -34,21 +34,18 @@ import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
import torchvision
-from diffusers.schedulers.scheduling_ddim import DDIMScheduler
-from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor
-from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.utils.import_utils import _transformers_available
+from .configuration_multi_task_dit import MultiTaskDiTConfig
+
# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
from transformers import CLIPTextModel, CLIPVisionModel
else:
CLIPTextModel = None
CLIPVisionModel = None
-from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import (
ACTION,
OBS_IMAGES,
@@ -57,6 +54,9 @@ from lerobot.utils.constants import (
OBS_STATE,
)
+from ..pretrained import PreTrainedPolicy
+from ..utils import populate_queues
+
# -- Policy --
@@ -643,6 +643,12 @@ class DiffusionObjective(nn.Module):
"prediction_type": config.prediction_type,
}
+ from lerobot.utils.import_utils import require_package
+
+ require_package("diffusers", extra="multi_task_dit")
+ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
+ from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+
if config.noise_scheduler_type == "DDPM":
self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
elif config.noise_scheduler_type == "DDIM":
diff --git a/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py
index fc94599c2..5f5b9994e 100644
--- a/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py
+++ b/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py
@@ -18,7 +18,6 @@ from typing import Any
import torch
-from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
@@ -28,10 +27,13 @@ from lerobot.processor import (
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
+ policy_action_to_transition,
+ transition_to_policy_action,
)
-from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+from .configuration_multi_task_dit import MultiTaskDiTConfig
+
def make_multi_task_dit_pre_post_processors(
config: MultiTaskDiTConfig,
diff --git a/src/lerobot/policies/pi0/configuration_pi0.py b/src/lerobot/policies/pi0/configuration_pi0.py
index cf4b636a3..a06315f07 100644
--- a/src/lerobot/policies/pi0/configuration_pi0.py
+++ b/src/lerobot/policies/pi0/configuration_pi0.py
@@ -16,13 +16,12 @@
from dataclasses import dataclass, field
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
-from lerobot.optim.optimizers import AdamWConfig
-from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
-from lerobot.policies.rtc.configuration_rtc import RTCConfig
+from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
+from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
+from ..rtc.configuration_rtc import RTCConfig
+
DEFAULT_IMAGE_SIZE = 224
diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py
index aebf32964..22e4e6a26 100644
--- a/src/lerobot/policies/pi0/modeling_pi0.py
+++ b/src/lerobot/policies/pi0/modeling_pi0.py
@@ -33,7 +33,7 @@ if TYPE_CHECKING or _transformers_available:
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
- from lerobot.policies.pi_gemma import (
+ from ..pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaForCausalLM,
_gated_residual,
@@ -48,10 +48,7 @@ else:
PaliGemmaForConditionalGenerationWithPiGemma = None
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
-from lerobot.policies.pretrained import PreTrainedPolicy, T
-from lerobot.policies.rtc.modeling_rtc import RTCProcessor
+from lerobot.configs import PreTrainedConfig
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -60,6 +57,10 @@ from lerobot.utils.constants import (
OPENPI_ATTENTION_MASK_VALUE,
)
+from ..pretrained import PreTrainedPolicy, T
+from ..rtc.modeling_rtc import RTCProcessor
+from .configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
+
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
diff --git a/src/lerobot/policies/pi0/processor_pi0.py b/src/lerobot/policies/pi0/processor_pi0.py
index 0302876a1..ad861f85b 100644
--- a/src/lerobot/policies/pi0/processor_pi0.py
+++ b/src/lerobot/policies/pi0/processor_pi0.py
@@ -18,8 +18,7 @@ from typing import Any
import torch
-from lerobot.configs.types import PipelineFeatureType, PolicyFeature
-from lerobot.policies.pi0.configuration_pi0 import PI0Config
+from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.processor import (
AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep,
@@ -34,10 +33,13 @@ from lerobot.processor import (
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
+ policy_action_to_transition,
+ transition_to_policy_action,
)
-from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+from .configuration_pi0 import PI0Config
+
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
class Pi0NewLineProcessor(ComplementaryDataProcessorStep):
diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py
index 6760be0a2..124e85cc9 100644
--- a/src/lerobot/policies/pi05/configuration_pi05.py
+++ b/src/lerobot/policies/pi05/configuration_pi05.py
@@ -16,13 +16,12 @@
from dataclasses import dataclass, field
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
-from lerobot.optim.optimizers import AdamWConfig
-from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
-from lerobot.policies.rtc.configuration_rtc import RTCConfig
+from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
+from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
+from ..rtc.configuration_rtc import RTCConfig
+
DEFAULT_IMAGE_SIZE = 224
diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py
index 96c4002f2..a44817a74 100644
--- a/src/lerobot/policies/pi05/modeling_pi05.py
+++ b/src/lerobot/policies/pi05/modeling_pi05.py
@@ -33,7 +33,7 @@ if TYPE_CHECKING or _transformers_available:
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
- from lerobot.policies.pi_gemma import (
+ from ..pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaForCausalLM,
_gated_residual,
@@ -46,10 +46,7 @@ else:
_gated_residual = None
layernorm_forward = None
PaliGemmaForConditionalGenerationWithPiGemma = None
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
-from lerobot.policies.pretrained import PreTrainedPolicy, T
-from lerobot.policies.rtc.modeling_rtc import RTCProcessor
+from lerobot.configs import PreTrainedConfig
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -57,6 +54,10 @@ from lerobot.utils.constants import (
OPENPI_ATTENTION_MASK_VALUE,
)
+from ..pretrained import PreTrainedPolicy, T
+from ..rtc.modeling_rtc import RTCProcessor
+from .configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
+
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
diff --git a/src/lerobot/policies/pi05/processor_pi05.py b/src/lerobot/policies/pi05/processor_pi05.py
index cb616af87..2d015b24f 100644
--- a/src/lerobot/policies/pi05/processor_pi05.py
+++ b/src/lerobot/policies/pi05/processor_pi05.py
@@ -21,8 +21,7 @@ from typing import Any
import numpy as np
import torch
-from lerobot.configs.types import PipelineFeatureType, PolicyFeature
-from lerobot.policies.pi05.configuration_pi05 import PI05Config
+from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.processor import (
AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep,
@@ -36,8 +35,9 @@ from lerobot.processor import (
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
+ policy_action_to_transition,
+ transition_to_policy_action,
)
-from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_STATE,
@@ -45,6 +45,8 @@ from lerobot.utils.constants import (
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
+from .configuration_pi05 import PI05Config
+
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
@dataclass
diff --git a/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py b/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py
index 6a645fae1..e5c6851f4 100644
--- a/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py
+++ b/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py
@@ -16,13 +16,12 @@
from dataclasses import dataclass, field
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
-from lerobot.optim.optimizers import AdamWConfig
-from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
-from lerobot.policies.rtc.configuration_rtc import RTCConfig
+from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
+from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
+from ..rtc.configuration_rtc import RTCConfig
+
DEFAULT_IMAGE_SIZE = 224
diff --git a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py
index 1bcf9794c..e86b8ad27 100644
--- a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py
+++ b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py
@@ -38,7 +38,7 @@ if TYPE_CHECKING or _transformers_available:
from transformers import AutoTokenizer
from transformers.models.auto import CONFIG_MAPPING
- from lerobot.policies.pi_gemma import (
+ from ..pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaModel,
)
@@ -48,10 +48,7 @@ else:
PiGemmaModel = None
PaliGemmaForConditionalGenerationWithPiGemma = None
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
-from lerobot.policies.pretrained import PreTrainedPolicy, T
-from lerobot.policies.rtc.modeling_rtc import RTCProcessor
+from lerobot.configs import PreTrainedConfig
from lerobot.utils.constants import (
ACTION,
ACTION_TOKEN_MASK,
@@ -61,6 +58,10 @@ from lerobot.utils.constants import (
OPENPI_ATTENTION_MASK_VALUE,
)
+from ..pretrained import PreTrainedPolicy, T
+from ..rtc.modeling_rtc import RTCProcessor
+from .configuration_pi0_fast import PI0FastConfig
+
class ActionSelectKwargs(TypedDict, total=False):
temperature: float | None
diff --git a/src/lerobot/policies/pi0_fast/processor_pi0_fast.py b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py
index c4a510615..60a519786 100644
--- a/src/lerobot/policies/pi0_fast/processor_pi0_fast.py
+++ b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py
@@ -21,8 +21,7 @@ from typing import Any
import numpy as np
import torch
-from lerobot.configs.types import PipelineFeatureType, PolicyFeature
-from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
+from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.processor import (
AbsoluteActionsProcessorStep,
ActionTokenizerProcessorStep,
@@ -37,8 +36,9 @@ from lerobot.processor import (
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
+ policy_action_to_transition,
+ transition_to_policy_action,
)
-from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_STATE,
@@ -46,6 +46,8 @@ from lerobot.utils.constants import (
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
+from .configuration_pi0_fast import PI0FastConfig
+
@ProcessorStepRegistry.register(name="pi0_fast_prepare_state_tokenizer_processor_step")
@dataclass
diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py
index 70efeba6f..724f920f3 100644
--- a/src/lerobot/policies/pretrained.py
+++ b/src/lerobot/policies/pretrained.py
@@ -29,11 +29,12 @@ from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
from torch import Tensor, nn
-from lerobot.configs.policies import PreTrainedConfig
+from lerobot.configs import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
-from lerobot.policies.utils import log_model_loading_keys
from lerobot.utils.hub import HubMixin
+from .utils import log_model_loading_keys
+
T = TypeVar("T", bound="PreTrainedPolicy")
diff --git a/src/lerobot/policies/rtc/__init__.py b/src/lerobot/policies/rtc/__init__.py
index ac7b72ef7..7a29dcac0 100644
--- a/src/lerobot/policies/rtc/__init__.py
+++ b/src/lerobot/policies/rtc/__init__.py
@@ -14,11 +14,11 @@
"""Real-Time Chunking (RTC) utilities for action-chunking policies."""
-from lerobot.policies.rtc.action_interpolator import ActionInterpolator
-from lerobot.policies.rtc.action_queue import ActionQueue
-from lerobot.policies.rtc.configuration_rtc import RTCConfig
-from lerobot.policies.rtc.latency_tracker import LatencyTracker
-from lerobot.policies.rtc.modeling_rtc import RTCProcessor
+from .action_interpolator import ActionInterpolator
+from .action_queue import ActionQueue
+from .configuration_rtc import RTCConfig
+from .latency_tracker import LatencyTracker
+from .modeling_rtc import RTCProcessor
__all__ = [
"ActionInterpolator",
diff --git a/src/lerobot/policies/rtc/action_queue.py b/src/lerobot/policies/rtc/action_queue.py
index 3c20d6d21..dbbdc41df 100644
--- a/src/lerobot/policies/rtc/action_queue.py
+++ b/src/lerobot/policies/rtc/action_queue.py
@@ -27,7 +27,7 @@ from threading import Lock
import torch
from torch import Tensor
-from lerobot.policies.rtc.configuration_rtc import RTCConfig
+from .configuration_rtc import RTCConfig
logger = logging.getLogger(__name__)
diff --git a/src/lerobot/policies/rtc/configuration_rtc.py b/src/lerobot/policies/rtc/configuration_rtc.py
index 70a8dfb09..c70fe3de0 100644
--- a/src/lerobot/policies/rtc/configuration_rtc.py
+++ b/src/lerobot/policies/rtc/configuration_rtc.py
@@ -23,7 +23,7 @@ Based on:
from dataclasses import dataclass
-from lerobot.configs.types import RTCAttentionSchedule
+from lerobot.configs import RTCAttentionSchedule
@dataclass
diff --git a/src/lerobot/policies/rtc/modeling_rtc.py b/src/lerobot/policies/rtc/modeling_rtc.py
index 280905adf..c1aeed328 100644
--- a/src/lerobot/policies/rtc/modeling_rtc.py
+++ b/src/lerobot/policies/rtc/modeling_rtc.py
@@ -27,9 +27,10 @@ import math
import torch
from torch import Tensor
-from lerobot.configs.types import RTCAttentionSchedule
-from lerobot.policies.rtc.configuration_rtc import RTCConfig
-from lerobot.policies.rtc.debug_tracker import Tracker
+from lerobot.configs import RTCAttentionSchedule
+
+from .configuration_rtc import RTCConfig
+from .debug_tracker import Tracker
logger = logging.getLogger(__name__)
diff --git a/src/lerobot/policies/sac/__init__.py b/src/lerobot/policies/sac/__init__.py
new file mode 100644
index 000000000..cf5f149f3
--- /dev/null
+++ b/src/lerobot/policies/sac/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_sac import SACConfig
+from .modeling_sac import SACPolicy
+from .processor_sac import make_sac_pre_post_processors
+
+__all__ = ["SACConfig", "SACPolicy", "make_sac_pre_post_processors"]
diff --git a/src/lerobot/policies/sac/configuration_sac.py b/src/lerobot/policies/sac/configuration_sac.py
index ada12330c..db0a77672 100644
--- a/src/lerobot/policies/sac/configuration_sac.py
+++ b/src/lerobot/policies/sac/configuration_sac.py
@@ -17,9 +17,8 @@
from dataclasses import dataclass, field
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import NormalizationMode
-from lerobot.optim.optimizers import MultiAdamConfig
+from lerobot.configs import NormalizationMode, PreTrainedConfig
+from lerobot.optim import MultiAdamConfig
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py
index d5dd71a48..cc7030ce2 100644
--- a/src/lerobot/policies/sac/modeling_sac.py
+++ b/src/lerobot/policies/sac/modeling_sac.py
@@ -28,11 +28,12 @@ import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
-from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature
-from lerobot.policies.utils import get_device_from_parameters
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
+from ..pretrained import PreTrainedPolicy
+from ..utils import get_device_from_parameters
+from .configuration_sac import SACConfig, is_image_feature
+
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
diff --git a/src/lerobot/policies/sac/processor_sac.py b/src/lerobot/policies/sac/processor_sac.py
index cf90e3cb4..3409307c2 100644
--- a/src/lerobot/policies/sac/processor_sac.py
+++ b/src/lerobot/policies/sac/processor_sac.py
@@ -19,7 +19,6 @@ from typing import Any
import torch
-from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
@@ -28,10 +27,13 @@ from lerobot.processor import (
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
+ policy_action_to_transition,
+ transition_to_policy_action,
)
-from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+from .configuration_sac import SACConfig
+
def make_sac_pre_post_processors(
config: SACConfig,
diff --git a/src/lerobot/policies/sac/reward_model/__init__.py b/src/lerobot/policies/sac/reward_model/__init__.py
new file mode 100644
index 000000000..1504a9947
--- /dev/null
+++ b/src/lerobot/policies/sac/reward_model/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_classifier import RewardClassifierConfig
+from .modeling_classifier import Classifier
+from .processor_classifier import make_classifier_processor
+
+__all__ = ["RewardClassifierConfig", "Classifier", "make_classifier_processor"]
diff --git a/src/lerobot/policies/sac/reward_model/configuration_classifier.py b/src/lerobot/policies/sac/reward_model/configuration_classifier.py
index 879e3c1af..3a5bfa424 100644
--- a/src/lerobot/policies/sac/reward_model/configuration_classifier.py
+++ b/src/lerobot/policies/sac/reward_model/configuration_classifier.py
@@ -15,10 +15,8 @@
# limitations under the License.
from dataclasses import dataclass, field
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import NormalizationMode
-from lerobot.optim.optimizers import AdamWConfig, OptimizerConfig
-from lerobot.optim.schedulers import LRSchedulerConfig
+from lerobot.configs import NormalizationMode, PreTrainedConfig
+from lerobot.optim import AdamWConfig, LRSchedulerConfig, OptimizerConfig
from lerobot.utils.constants import OBS_IMAGE
diff --git a/src/lerobot/policies/sac/reward_model/modeling_classifier.py b/src/lerobot/policies/sac/reward_model/modeling_classifier.py
index dba6a174b..c8b7efe58 100644
--- a/src/lerobot/policies/sac/reward_model/modeling_classifier.py
+++ b/src/lerobot/policies/sac/reward_model/modeling_classifier.py
@@ -19,10 +19,11 @@ import logging
import torch
from torch import Tensor, nn
-from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.utils.constants import OBS_IMAGE, REWARD
+from ...pretrained import PreTrainedPolicy
+from .configuration_classifier import RewardClassifierConfig
+
class ClassifierOutput:
"""Wrapper for classifier outputs with additional metadata."""
diff --git a/src/lerobot/policies/sac/reward_model/processor_classifier.py b/src/lerobot/policies/sac/reward_model/processor_classifier.py
index c2a34eab2..1f7a66e58 100644
--- a/src/lerobot/policies/sac/reward_model/processor_classifier.py
+++ b/src/lerobot/policies/sac/reward_model/processor_classifier.py
@@ -18,15 +18,17 @@ from typing import Any
import torch
-from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.processor import (
DeviceProcessorStep,
IdentityProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
+ policy_action_to_transition,
+ transition_to_policy_action,
)
-from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
+
+from .configuration_classifier import RewardClassifierConfig
def make_classifier_processor(
diff --git a/src/lerobot/policies/sarm/__init__.py b/src/lerobot/policies/sarm/__init__.py
new file mode 100644
index 000000000..b164c87ef
--- /dev/null
+++ b/src/lerobot/policies/sarm/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_sarm import SARMConfig
+from .modeling_sarm import SARMRewardModel
+
+__all__ = ["SARMConfig", "SARMRewardModel"]
diff --git a/src/lerobot/policies/sarm/compute_rabc_weights.py b/src/lerobot/policies/sarm/compute_rabc_weights.py
index 485c1096b..07d0780b5 100644
--- a/src/lerobot/policies/sarm/compute_rabc_weights.py
+++ b/src/lerobot/policies/sarm/compute_rabc_weights.py
@@ -57,10 +57,11 @@ import pyarrow.parquet as pq
import torch
from tqdm import tqdm
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
-from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
-from lerobot.policies.sarm.sarm_utils import normalize_stage_tau
+from lerobot.datasets import LeRobotDataset
+
+from .modeling_sarm import SARMRewardModel
+from .processor_sarm import make_sarm_pre_post_processors
+from .sarm_utils import normalize_stage_tau
def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None:
diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py
index 673422fe2..fc8daa055 100644
--- a/src/lerobot/policies/sarm/configuration_sarm.py
+++ b/src/lerobot/policies/sarm/configuration_sarm.py
@@ -22,10 +22,8 @@ Paper: https://arxiv.org/abs/2509.25358
from dataclasses import dataclass, field
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
-from lerobot.optim.optimizers import AdamWConfig
-from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
+from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
+from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py
index 6051d90f8..710554e4b 100644
--- a/src/lerobot/policies/sarm/modeling_sarm.py
+++ b/src/lerobot/policies/sarm/modeling_sarm.py
@@ -34,13 +34,14 @@ import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
-from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.policies.sarm.configuration_sarm import SARMConfig
-from lerobot.policies.sarm.sarm_utils import (
+from lerobot.utils.constants import OBS_STR
+
+from ..pretrained import PreTrainedPolicy
+from .configuration_sarm import SARMConfig
+from .sarm_utils import (
normalize_stage_tau,
pad_state_to_max_dim,
)
-from lerobot.utils.constants import OBS_STR
class StageTransformer(nn.Module):
diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py
index f377a7ffa..e939b3485 100644
--- a/src/lerobot/policies/sarm/processor_sarm.py
+++ b/src/lerobot/policies/sarm/processor_sarm.py
@@ -16,41 +16,60 @@
"""SARM Processor for encoding images/text and generating stage+tau targets."""
+from __future__ import annotations
+
import random
-from typing import Any
+from typing import TYPE_CHECKING, Any
import numpy as np
-import pandas as pd
import torch
-from faker import Faker
from PIL import Image
-from transformers import CLIPModel, CLIPProcessor
-from lerobot.configs.types import FeatureType, PolicyFeature
-from lerobot.policies.sarm.configuration_sarm import SARMConfig
-from lerobot.policies.sarm.sarm_utils import (
+from lerobot.utils.import_utils import (
+ _faker_available,
+ _pandas_available,
+ _transformers_available,
+ require_package,
+)
+
+if TYPE_CHECKING or _transformers_available:
+ from transformers import CLIPModel, CLIPProcessor
+else:
+ CLIPModel = None # type: ignore[assignment, misc]
+ CLIPProcessor = None # type: ignore[assignment, misc]
+
+if TYPE_CHECKING or _pandas_available:
+ import pandas as pd
+else:
+ pd = None # type: ignore[assignment]
+
+if TYPE_CHECKING or _faker_available:
+ from faker import Faker
+else:
+ Faker = None # type: ignore[assignment, misc]
+
+from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
+from lerobot.processor import (
+ AddBatchDimensionProcessorStep,
+ DeviceProcessorStep,
+ NormalizerProcessorStep,
+ PolicyProcessorPipeline,
+ ProcessorStep,
+ RenameObservationsProcessorStep,
+ from_tensor_to_numpy,
+ policy_action_to_transition,
+ transition_to_policy_action,
+)
+from lerobot.types import EnvTransition, PolicyAction, TransitionKey
+from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+
+from .configuration_sarm import SARMConfig
+from .sarm_utils import (
apply_rewind_augmentation,
compute_absolute_indices,
find_stage_and_tau,
pad_state_to_max_dim,
)
-from lerobot.processor import (
- AddBatchDimensionProcessorStep,
- DeviceProcessorStep,
- NormalizerProcessorStep,
- PolicyAction,
- PolicyProcessorPipeline,
- ProcessorStep,
- RenameObservationsProcessorStep,
-)
-from lerobot.processor.converters import (
- from_tensor_to_numpy,
- policy_action_to_transition,
- transition_to_policy_action,
-)
-from lerobot.processor.pipeline import PipelineFeatureType
-from lerobot.types import EnvTransition, TransitionKey
-from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
class SARMEncodingProcessorStep(ProcessorStep):
@@ -63,6 +82,9 @@ class SARMEncodingProcessorStep(ProcessorStep):
dataset_meta=None,
dataset_stats: dict | None = None,
):
+ require_package("transformers", extra="sarm")
+ require_package("faker", extra="sarm")
+ require_package("pandas", extra="dataset")
super().__init__()
self.config = config
self.image_key = image_key or config.image_key
diff --git a/src/lerobot/policies/smolvla/__init__.py b/src/lerobot/policies/smolvla/__init__.py
new file mode 100644
index 000000000..690f15860
--- /dev/null
+++ b/src/lerobot/policies/smolvla/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_smolvla import SmolVLAConfig
+from .modeling_smolvla import SmolVLAPolicy
+from .processor_smolvla import make_smolvla_pre_post_processors
+
+__all__ = ["SmolVLAConfig", "SmolVLAPolicy", "make_smolvla_pre_post_processors"]
diff --git a/src/lerobot/policies/smolvla/configuration_smolvla.py b/src/lerobot/policies/smolvla/configuration_smolvla.py
index 5007abbb4..6d5288db3 100644
--- a/src/lerobot/policies/smolvla/configuration_smolvla.py
+++ b/src/lerobot/policies/smolvla/configuration_smolvla.py
@@ -14,15 +14,12 @@
from dataclasses import dataclass, field
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
-from lerobot.optim.optimizers import AdamWConfig
-from lerobot.optim.schedulers import (
- CosineDecayWithWarmupSchedulerConfig,
-)
-from lerobot.policies.rtc.configuration_rtc import RTCConfig
+from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
+from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import OBS_IMAGES
+from ..rtc.configuration_rtc import RTCConfig
+
@PreTrainedConfig.register_subclass("smolvla")
@dataclass
diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py
index 7110ba7d2..ee3ff4db9 100644
--- a/src/lerobot/policies/smolvla/modeling_smolvla.py
+++ b/src/lerobot/policies/smolvla/modeling_smolvla.py
@@ -60,16 +60,17 @@ import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
-from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.policies.rtc.modeling_rtc import RTCProcessor
-from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
-from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
-from lerobot.policies.utils import (
- populate_queues,
-)
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
from lerobot.utils.device_utils import get_safe_dtype
+from ..pretrained import PreTrainedPolicy
+from ..rtc.modeling_rtc import RTCProcessor
+from ..utils import (
+ populate_queues,
+)
+from .configuration_smolvla import SmolVLAConfig
+from .smolvlm_with_expert import SmolVLMWithExpertModel
+
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
diff --git a/src/lerobot/policies/smolvla/processor_smolvla.py b/src/lerobot/policies/smolvla/processor_smolvla.py
index 3fc130aa1..8d6c8aca4 100644
--- a/src/lerobot/policies/smolvla/processor_smolvla.py
+++ b/src/lerobot/policies/smolvla/processor_smolvla.py
@@ -18,23 +18,23 @@ from typing import Any
import torch
-from lerobot.configs.types import PipelineFeatureType, PolicyFeature
-from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
- ComplementaryDataProcessorStep,
DeviceProcessorStep,
+ NewLineTaskProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
- ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
+ policy_action_to_transition,
+ transition_to_policy_action,
)
-from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+from .configuration_smolvla import SmolVLAConfig
+
def make_smolvla_pre_post_processors(
config: SmolVLAConfig,
@@ -69,7 +69,7 @@ def make_smolvla_pre_post_processors(
input_steps = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
- SmolVLANewLineProcessor(),
+ NewLineTaskProcessorStep(),
TokenizerProcessorStep(
tokenizer_name=config.vlm_model_name,
padding=config.pad_language_to,
@@ -101,41 +101,3 @@ def make_smolvla_pre_post_processors(
to_output=transition_to_policy_action,
),
)
-
-
-@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
-class SmolVLANewLineProcessor(ComplementaryDataProcessorStep):
- """
- A processor step that ensures the 'task' description ends with a newline character.
-
- This step is necessary for certain tokenizers (e.g., PaliGemma) that expect a
- newline at the end of the prompt. It handles both single string tasks and lists
- of string tasks.
- """
-
- def complementary_data(self, complementary_data):
- if "task" not in complementary_data:
- return complementary_data
-
- task = complementary_data["task"]
- if task is None:
- return complementary_data
-
- new_complementary_data = dict(complementary_data)
-
- # Handle both string and list of strings
- if isinstance(task, str):
- # Single string: add newline if not present
- if not task.endswith("\n"):
- new_complementary_data["task"] = f"{task}\n"
- elif isinstance(task, list) and all(isinstance(t, str) for t in task):
- # List of strings: add newline to each if not present
- new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
- # If task is neither string nor list of strings, leave unchanged
-
- return new_complementary_data
-
- def transform_features(
- self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
- ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
- return features
diff --git a/src/lerobot/policies/smolvla/smolvlm_with_expert.py b/src/lerobot/policies/smolvla/smolvlm_with_expert.py
index caca41dab..ea806f185 100644
--- a/src/lerobot/policies/smolvla/smolvlm_with_expert.py
+++ b/src/lerobot/policies/smolvla/smolvlm_with_expert.py
@@ -13,16 +13,27 @@
# limitations under the License.
import copy
+from typing import TYPE_CHECKING
import torch
from torch import nn
-from transformers import (
- AutoConfig,
- AutoModel,
- AutoModelForImageTextToText,
- AutoProcessor,
- SmolVLMForConditionalGeneration,
-)
+
+from lerobot.utils.import_utils import _transformers_available, require_package
+
+if TYPE_CHECKING or _transformers_available:
+ from transformers import (
+ AutoConfig,
+ AutoModel,
+ AutoModelForImageTextToText,
+ AutoProcessor,
+ SmolVLMForConditionalGeneration,
+ )
+else:
+ AutoConfig = None
+ AutoModel = None
+ AutoModelForImageTextToText = None
+ AutoProcessor = None
+ SmolVLMForConditionalGeneration = None
def apply_rope(x, positions, max_wavelength=10_000):
@@ -73,6 +84,7 @@ class SmolVLMWithExpertModel(nn.Module):
device: str = "auto",
):
super().__init__()
+ require_package("transformers", extra="smolvla")
if load_vlm_weights:
print(f"Loading {model_id} weights ...")
self.vlm = AutoModelForImageTextToText.from_pretrained(
diff --git a/src/lerobot/policies/tdmpc/__init__.py b/src/lerobot/policies/tdmpc/__init__.py
new file mode 100644
index 000000000..5663e23c4
--- /dev/null
+++ b/src/lerobot/policies/tdmpc/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_tdmpc import TDMPCConfig
+from .modeling_tdmpc import TDMPCPolicy
+from .processor_tdmpc import make_tdmpc_pre_post_processors
+
+__all__ = ["TDMPCConfig", "TDMPCPolicy", "make_tdmpc_pre_post_processors"]
diff --git a/src/lerobot/policies/tdmpc/configuration_tdmpc.py b/src/lerobot/policies/tdmpc/configuration_tdmpc.py
index 3ec493472..bb8a2cf96 100644
--- a/src/lerobot/policies/tdmpc/configuration_tdmpc.py
+++ b/src/lerobot/policies/tdmpc/configuration_tdmpc.py
@@ -16,9 +16,8 @@
# limitations under the License.
from dataclasses import dataclass, field
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import NormalizationMode
-from lerobot.optim.optimizers import AdamConfig
+from lerobot.configs import NormalizationMode, PreTrainedConfig
+from lerobot.optim import AdamConfig
@PreTrainedConfig.register_subclass("tdmpc")
diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py
index f83c82e21..a50bb9670 100644
--- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py
+++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py
@@ -35,11 +35,12 @@ import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
-from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
-from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_PREFIX, OBS_STATE, OBS_STR, REWARD
+from ..pretrained import PreTrainedPolicy
+from ..utils import get_device_from_parameters, get_output_shape, populate_queues
+from .configuration_tdmpc import TDMPCConfig
+
class TDMPCPolicy(PreTrainedPolicy):
"""Implementation of TD-MPC learning + inference.
diff --git a/src/lerobot/policies/tdmpc/processor_tdmpc.py b/src/lerobot/policies/tdmpc/processor_tdmpc.py
index 9b6f97e50..7afe956dc 100644
--- a/src/lerobot/policies/tdmpc/processor_tdmpc.py
+++ b/src/lerobot/policies/tdmpc/processor_tdmpc.py
@@ -18,7 +18,6 @@ from typing import Any
import torch
-from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
@@ -27,10 +26,13 @@ from lerobot.processor import (
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
+ policy_action_to_transition,
+ transition_to_policy_action,
)
-from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+from .configuration_tdmpc import TDMPCConfig
+
def make_tdmpc_pre_post_processors(
config: TDMPCConfig,
diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py
index 82ab51005..c37127813 100644
--- a/src/lerobot/policies/utils.py
+++ b/src/lerobot/policies/utils.py
@@ -21,11 +21,10 @@ import numpy as np
import torch
from torch import nn
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import FeatureType, PolicyFeature
-from lerobot.datasets.feature_utils import build_dataset_frame
+from lerobot.configs import FeatureType, PolicyFeature, PreTrainedConfig
from lerobot.types import PolicyAction, RobotAction, RobotObservation
from lerobot.utils.constants import ACTION, OBS_STR
+from lerobot.utils.feature_utils import build_dataset_frame
def populate_queues(
diff --git a/src/lerobot/policies/vqbet/__init__.py b/src/lerobot/policies/vqbet/__init__.py
new file mode 100644
index 000000000..842dd5d0b
--- /dev/null
+++ b/src/lerobot/policies/vqbet/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_vqbet import VQBeTConfig
+from .modeling_vqbet import VQBeTPolicy
+from .processor_vqbet import make_vqbet_pre_post_processors
+
+__all__ = ["VQBeTConfig", "VQBeTPolicy", "make_vqbet_pre_post_processors"]
diff --git a/src/lerobot/policies/vqbet/configuration_vqbet.py b/src/lerobot/policies/vqbet/configuration_vqbet.py
index 32906e528..d02745321 100644
--- a/src/lerobot/policies/vqbet/configuration_vqbet.py
+++ b/src/lerobot/policies/vqbet/configuration_vqbet.py
@@ -18,10 +18,8 @@
from dataclasses import dataclass, field
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import NormalizationMode
-from lerobot.optim.optimizers import AdamConfig
-from lerobot.optim.schedulers import VQBeTSchedulerConfig
+from lerobot.configs import NormalizationMode, PreTrainedConfig
+from lerobot.optim import AdamConfig, VQBeTSchedulerConfig
@PreTrainedConfig.register_subclass("vqbet")
diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py
index 6d3976b79..153f7fe3c 100644
--- a/src/lerobot/policies/vqbet/modeling_vqbet.py
+++ b/src/lerobot/policies/vqbet/modeling_vqbet.py
@@ -27,12 +27,13 @@ import torch.nn.functional as F # noqa: N812
import torchvision
from torch import Tensor, nn
-from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
-from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
-from lerobot.policies.vqbet.vqbet_utils import GPT, ResidualVQ
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
+from ..pretrained import PreTrainedPolicy
+from ..utils import get_device_from_parameters, get_output_shape, populate_queues
+from .configuration_vqbet import VQBeTConfig
+from .vqbet_utils import GPT, ResidualVQ
+
# ruff: noqa: N806
diff --git a/src/lerobot/policies/vqbet/processor_vqbet.py b/src/lerobot/policies/vqbet/processor_vqbet.py
index 1e19ff779..f7b6a061e 100644
--- a/src/lerobot/policies/vqbet/processor_vqbet.py
+++ b/src/lerobot/policies/vqbet/processor_vqbet.py
@@ -19,7 +19,6 @@ from typing import Any
import torch
-from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
@@ -28,10 +27,13 @@ from lerobot.processor import (
PolicyProcessorPipeline,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
+ policy_action_to_transition,
+ transition_to_policy_action,
)
-from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+from .configuration_vqbet import VQBeTConfig
+
def make_vqbet_pre_post_processors(
config: VQBeTConfig,
diff --git a/src/lerobot/policies/vqbet/vqbet_utils.py b/src/lerobot/policies/vqbet/vqbet_utils.py
index 7b13577f6..f8bfcb06a 100644
--- a/src/lerobot/policies/vqbet/vqbet_utils.py
+++ b/src/lerobot/policies/vqbet/vqbet_utils.py
@@ -30,7 +30,7 @@ from torch import einsum, nn
from torch.cuda.amp import autocast
from torch.optim import Optimizer
-from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
+from .configuration_vqbet import VQBeTConfig
# ruff: noqa: N806
diff --git a/src/lerobot/policies/wall_x/__init__.py b/src/lerobot/policies/wall_x/__init__.py
index d80c27bda..16fd2c8ab 100644
--- a/src/lerobot/policies/wall_x/__init__.py
+++ b/src/lerobot/policies/wall_x/__init__.py
@@ -15,5 +15,7 @@
# limitations under the License.
from .configuration_wall_x import WallXConfig
+from .modeling_wall_x import WallXPolicy
+from .processor_wall_x import make_wall_x_pre_post_processors
__all__ = ["WallXConfig", "WallXPolicy", "make_wall_x_pre_post_processors"]
diff --git a/src/lerobot/policies/wall_x/configuration_wall_x.py b/src/lerobot/policies/wall_x/configuration_wall_x.py
index 5269c4e10..70576a46b 100644
--- a/src/lerobot/policies/wall_x/configuration_wall_x.py
+++ b/src/lerobot/policies/wall_x/configuration_wall_x.py
@@ -14,10 +14,8 @@
from dataclasses import dataclass, field
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
-from lerobot.optim.optimizers import AdamWConfig
-from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
+from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
+from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
diff --git a/src/lerobot/policies/wall_x/modeling_wall_x.py b/src/lerobot/policies/wall_x/modeling_wall_x.py
index 84ee05743..bfecf3852 100644
--- a/src/lerobot/policies/wall_x/modeling_wall_x.py
+++ b/src/lerobot/policies/wall_x/modeling_wall_x.py
@@ -34,35 +34,31 @@ lerobot-train \
```
"""
+import logging
import math
from collections import deque
from os import PathLike
-from typing import Any
+from typing import TYPE_CHECKING, Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-from peft import LoraConfig, get_peft_model
from PIL import Image
-from qwen_vl_utils.vision_process import smart_resize
from torch import Tensor
from torch.distributions import Beta
from torch.nn import CrossEntropyLoss
-from torchdiffeq import odeint
-from transformers import AutoProcessor, BatchFeature
-from transformers.cache_utils import (
- StaticCache,
-)
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
- Qwen2_5_VLForConditionalGeneration,
-)
-from transformers.utils import is_torchdynamo_compiling, logging
-from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.policies.utils import populate_queues
-from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
-from lerobot.policies.wall_x.constant import (
+from lerobot.utils.constants import ACTION, OBS_STATE
+from lerobot.utils.import_utils import (
+ _wallx_deps_available,
+ require_package,
+)
+
+from ..pretrained import PreTrainedPolicy
+from ..utils import populate_queues
+from .configuration_wall_x import WallXConfig
+from .constant import (
GENERATE_SUBTASK_RATIO,
IMAGE_FACTOR,
MAX_PIXELS,
@@ -72,21 +68,47 @@ from lerobot.policies.wall_x.constant import (
RESOLUTION,
TOKENIZER_MAX_LENGTH,
)
-from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
-from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import (
- Qwen2_5_VisionTransformerPretrainedModel,
- Qwen2_5_VLACausalLMOutputWithPast,
- Qwen2_5_VLMoEModel,
-)
-from lerobot.policies.wall_x.utils import (
+
+if TYPE_CHECKING or _wallx_deps_available:
+ from peft import LoraConfig, get_peft_model
+ from qwen_vl_utils.vision_process import smart_resize
+ from torchdiffeq import odeint
+ from transformers import AutoProcessor, BatchFeature
+ from transformers.cache_utils import StaticCache
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+ Qwen2_5_VLForConditionalGeneration,
+ )
+ from transformers.utils import is_torchdynamo_compiling
+
+ from .qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig
+ from .qwen_model.qwen2_5_vl_moe import (
+ Qwen2_5_VisionTransformerPretrainedModel,
+ Qwen2_5_VLACausalLMOutputWithPast,
+ Qwen2_5_VLMoEModel,
+ )
+else:
+ LoraConfig = None
+ get_peft_model = None
+ smart_resize = None
+ odeint = None
+ AutoProcessor = None
+ BatchFeature = None
+ StaticCache = None
+ Qwen2_5_VLForConditionalGeneration = None
+ is_torchdynamo_compiling = None
+ Qwen2_5_VLConfig = None
+ Qwen2_5_VisionTransformerPretrainedModel = None
+ Qwen2_5_VLACausalLMOutputWithPast = None
+ Qwen2_5_VLMoEModel = None
+
+from .utils import (
get_wallx_normal_text,
preprocesser_call,
process_grounding_points,
replace_action_token,
)
-from lerobot.utils.constants import ACTION, OBS_STATE
-logger = logging.get_logger(__name__)
+logger = logging.getLogger(__name__)
class SinusoidalPosEmb(nn.Module):
@@ -253,7 +275,13 @@ class ActionHead(nn.Module):
return self.propri_proj(proprioception)
-class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
+# Conditional base: when transformers is unavailable the class still parses
+# (inheriting from nn.Module) but cannot be instantiated—require_package in
+# WallXPolicy.__init__ gives the user a clear error before that happens.
+_Qwen2_5_VLForAction_Base = Qwen2_5_VLForConditionalGeneration if _wallx_deps_available else nn.Module
+
+
+class Qwen2_5_VLMoEForAction(_Qwen2_5_VLForAction_Base):
"""
Qwen2.5 Vision-Language Mixture of Experts model for action processing.
@@ -1708,6 +1736,10 @@ class WallXPolicy(PreTrainedPolicy):
name = "wall_x"
def __init__(self, config: WallXConfig, **kwargs):
+ require_package("transformers", extra="wallx")
+ require_package("peft", extra="wallx")
+ require_package("torchdiffeq", extra="wallx")
+ require_package("qwen-vl-utils", extra="wallx", import_name="qwen_vl_utils")
super().__init__(config)
config.validate_features()
self.config = config
diff --git a/src/lerobot/policies/wall_x/processor_wall_x.py b/src/lerobot/policies/wall_x/processor_wall_x.py
index e4e281541..069cef5d6 100644
--- a/src/lerobot/policies/wall_x/processor_wall_x.py
+++ b/src/lerobot/policies/wall_x/processor_wall_x.py
@@ -18,8 +18,7 @@ from typing import Any
import torch
-from lerobot.configs.types import PipelineFeatureType, PolicyFeature
-from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
+from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.processor import (
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
@@ -30,10 +29,13 @@ from lerobot.processor import (
ProcessorStepRegistry,
RenameObservationsProcessorStep,
UnnormalizerProcessorStep,
+ policy_action_to_transition,
+ transition_to_policy_action,
)
-from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
+from .configuration_wall_x import WallXConfig
+
def make_wall_x_pre_post_processors(
config: WallXConfig,
diff --git a/src/lerobot/policies/wall_x/utils.py b/src/lerobot/policies/wall_x/utils.py
index e08ef69d5..d38a2d509 100644
--- a/src/lerobot/policies/wall_x/utils.py
+++ b/src/lerobot/policies/wall_x/utils.py
@@ -25,15 +25,22 @@ import random
import re
from collections import OrderedDict
from dataclasses import dataclass, field
-from typing import Any
+from typing import TYPE_CHECKING, Any
import torch
-from transformers import BatchFeature
-from lerobot.policies.wall_x.constant import (
+from lerobot.utils.import_utils import _transformers_available
+
+if TYPE_CHECKING or _transformers_available:
+ from transformers import BatchFeature
+else:
+ BatchFeature = None
+
+from lerobot.utils.constants import OBS_IMAGES
+
+from .constant import (
CAMERA_NAME_MAPPING,
)
-from lerobot.utils.constants import OBS_IMAGES
@dataclass
diff --git a/src/lerobot/policies/xvla/__init__.py b/src/lerobot/policies/xvla/__init__.py
index 71b04e76f..58609e91c 100644
--- a/src/lerobot/policies/xvla/__init__.py
+++ b/src/lerobot/policies/xvla/__init__.py
@@ -1,6 +1,15 @@
-# register the processor steps
-from lerobot.policies.xvla.processor_xvla import (
+from .configuration_xvla import XVLAConfig
+from .modeling_xvla import XVLAPolicy
+from .processor_xvla import (
XVLAAddDomainIdProcessorStep,
XVLAImageNetNormalizeProcessorStep,
XVLAImageToFloatProcessorStep,
)
+
+__all__ = [
+ "XVLAConfig",
+ "XVLAPolicy",
+ "XVLAAddDomainIdProcessorStep",
+ "XVLAImageNetNormalizeProcessorStep",
+ "XVLAImageToFloatProcessorStep",
+]
diff --git a/src/lerobot/policies/xvla/configuration_xvla.py b/src/lerobot/policies/xvla/configuration_xvla.py
index 30700b042..614c9a944 100644
--- a/src/lerobot/policies/xvla/configuration_xvla.py
+++ b/src/lerobot/policies/xvla/configuration_xvla.py
@@ -21,10 +21,8 @@ from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
-from lerobot.optim.optimizers import XVLAAdamWConfig
-from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
+from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
+from lerobot.optim import CosineDecayWithWarmupSchedulerConfig, XVLAAdamWConfig
from lerobot.utils.constants import OBS_IMAGES
# Conditional import for type checking and lazy loading
diff --git a/src/lerobot/policies/xvla/modeling_xvla.py b/src/lerobot/policies/xvla/modeling_xvla.py
index 0436ae527..04e923fdd 100644
--- a/src/lerobot/policies/xvla/modeling_xvla.py
+++ b/src/lerobot/policies/xvla/modeling_xvla.py
@@ -23,22 +23,30 @@ import logging
import os
from collections import deque
from pathlib import Path
+from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.policies.pretrained import PreTrainedPolicy, T
-from lerobot.policies.utils import populate_queues
+from lerobot.configs import PreTrainedConfig
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_TOKENS, OBS_STATE
+from lerobot.utils.import_utils import _transformers_available, require_package
+from ..pretrained import PreTrainedPolicy, T
+from ..utils import populate_queues
from .action_hub import build_action_space
-from .configuration_florence2 import Florence2Config
from .configuration_xvla import XVLAConfig
-from .modeling_florence2 import Florence2ForConditionalGeneration
from .soft_transformer import SoftPromptedTransformer
+# Florence2 config and modeling depend on transformers
+if TYPE_CHECKING or _transformers_available:
+ from .configuration_florence2 import Florence2Config
+ from .modeling_florence2 import Florence2ForConditionalGeneration
+else:
+ Florence2Config = None
+ Florence2ForConditionalGeneration = None
+
class XVLAModel(nn.Module):
"""
@@ -274,6 +282,7 @@ class XVLAPolicy(PreTrainedPolicy):
name = "xvla"
def __init__(self, config: XVLAConfig, **kwargs):
+ require_package("transformers", extra="xvla")
super().__init__(config)
config.validate_features()
florence_config = config.get_florence_config()
diff --git a/src/lerobot/policies/xvla/processor_xvla.py b/src/lerobot/policies/xvla/processor_xvla.py
index 0fa9ffe3f..0336ec722 100644
--- a/src/lerobot/policies/xvla/processor_xvla.py
+++ b/src/lerobot/policies/xvla/processor_xvla.py
@@ -20,10 +20,7 @@ from typing import Any
import numpy as np
import torch
-from lerobot.configs.types import PipelineFeatureType, PolicyFeature
-from lerobot.datasets.factory import IMAGENET_STATS
-from lerobot.policies.xvla.configuration_xvla import XVLAConfig
-from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
+from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
@@ -36,10 +33,12 @@ from lerobot.processor import (
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
+ policy_action_to_transition,
+ transition_to_policy_action,
)
-from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import (
+ IMAGENET_STATS,
OBS_IMAGES,
OBS_PREFIX,
OBS_STATE,
@@ -47,6 +46,9 @@ from lerobot.utils.constants import (
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
+from .configuration_xvla import XVLAConfig
+from .utils import rotate6d_to_axis_angle
+
def make_xvla_pre_post_processors(
config: XVLAConfig,
diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py
index 122b3533c..3688a4b8c 100644
--- a/src/lerobot/processor/__init__.py
+++ b/src/lerobot/processor/__init__.py
@@ -27,10 +27,20 @@ from .batch_processor import AddBatchDimensionProcessorStep
from .converters import (
batch_to_transition,
create_transition,
+ from_tensor_to_numpy,
+ identity_transition,
+ observation_to_transition,
+ policy_action_to_transition,
+ robot_action_observation_to_transition,
+ robot_action_to_transition,
transition_to_batch,
+ transition_to_observation,
+ transition_to_policy_action,
+ transition_to_robot_action,
)
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
from .device_processor import DeviceProcessorStep
+from .env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep
from .factory import (
make_default_processors,
make_default_robot_action_processor,
@@ -51,6 +61,7 @@ from .hil_processor import (
RewardClassifierProcessorStep,
TimeLimitProcessorStep,
)
+from .newline_task_processor import NewLineTaskProcessorStep
from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats
from .observation_processor import VanillaObservationProcessorStep
from .pipeline import (
@@ -81,7 +92,7 @@ from .relative_action_processor import (
to_absolute_actions,
to_relative_actions,
)
-from .rename_processor import RenameObservationsProcessorStep
+from .rename_processor import RenameObservationsProcessorStep, rename_stats
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
__all__ = [
@@ -91,6 +102,15 @@ __all__ = [
"ComplementaryDataProcessorStep",
"batch_to_transition",
"create_transition",
+ "from_tensor_to_numpy",
+ "identity_transition",
+ "observation_to_transition",
+ "policy_action_to_transition",
+ "robot_action_observation_to_transition",
+ "robot_action_to_transition",
+ "transition_to_observation",
+ "transition_to_policy_action",
+ "transition_to_robot_action",
"DeviceProcessorStep",
"DoneProcessorStep",
"EnvAction",
@@ -110,6 +130,7 @@ __all__ = [
"RelativeActionsProcessorStep",
"MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep",
+ "NewLineTaskProcessorStep",
"NormalizerProcessorStep",
"Numpy2TorchActionProcessorStep",
"ObservationProcessorStep",
@@ -122,10 +143,13 @@ __all__ = [
"RobotAction",
"RobotActionProcessorStep",
"RobotObservation",
+ "rename_stats",
"RenameObservationsProcessorStep",
"RewardClassifierProcessorStep",
"RewardProcessorStep",
"DataProcessorPipeline",
+ "IsaaclabArenaProcessorStep",
+ "LiberoProcessorStep",
"TimeLimitProcessorStep",
"AddBatchDimensionProcessorStep",
"RobotProcessorPipeline",
diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py
index c904acf84..eb7db255a 100644
--- a/src/lerobot/processor/batch_processor.py
+++ b/src/lerobot/processor/batch_processor.py
@@ -24,7 +24,7 @@ from dataclasses import dataclass, field
from torch import Tensor
-from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.types import EnvTransition, PolicyAction
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
diff --git a/src/lerobot/processor/delta_action_processor.py b/src/lerobot/processor/delta_action_processor.py
index f7f5676ac..86b2feec1 100644
--- a/src/lerobot/processor/delta_action_processor.py
+++ b/src/lerobot/processor/delta_action_processor.py
@@ -16,7 +16,7 @@
from dataclasses import dataclass
-from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
+from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.types import PolicyAction, RobotAction
from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py
index 36c80e58e..1171c7e78 100644
--- a/src/lerobot/processor/device_processor.py
+++ b/src/lerobot/processor/device_processor.py
@@ -24,7 +24,7 @@ from typing import Any
import torch
-from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
from lerobot.utils.device_utils import get_safe_torch_device
diff --git a/src/lerobot/processor/env_processor.py b/src/lerobot/processor/env_processor.py
index a77e066cf..75cbb79de 100644
--- a/src/lerobot/processor/env_processor.py
+++ b/src/lerobot/processor/env_processor.py
@@ -17,7 +17,7 @@ from dataclasses import dataclass
import torch
-from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
+from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
diff --git a/src/lerobot/processor/gym_action_processor.py b/src/lerobot/processor/gym_action_processor.py
index e756ded7f..2ec5f6e64 100644
--- a/src/lerobot/processor/gym_action_processor.py
+++ b/src/lerobot/processor/gym_action_processor.py
@@ -16,8 +16,8 @@
from dataclasses import dataclass
-from lerobot.configs.types import PipelineFeatureType, PolicyFeature
-from lerobot.types import EnvAction, EnvTransition, PolicyAction
+from lerobot.configs import PipelineFeatureType, PolicyFeature
+from lerobot.types import EnvAction, EnvTransition, PolicyAction, TransitionKey
from .converters import to_tensor
from .hil_processor import TELEOP_ACTION_KEY
@@ -75,8 +75,6 @@ class Numpy2TorchActionProcessorStep(ProcessorStep):
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Converts numpy action to torch tensor if action exists, otherwise passes through."""
- from lerobot.types import TransitionKey
-
self._current_transition = transition.copy()
new_transition = self._current_transition
diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py
index 0b8521c2b..c6f98c689 100644
--- a/src/lerobot/processor/hil_processor.py
+++ b/src/lerobot/processor/hil_processor.py
@@ -24,7 +24,7 @@ import numpy as np
import torch
import torchvision.transforms.functional as F # noqa: N812
-from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.teleoperators.utils import TeleopEvents
if TYPE_CHECKING:
diff --git a/src/lerobot/processor/migrate_policy_normalization.py b/src/lerobot/processor/migrate_policy_normalization.py
index 525b7431c..37df4be41 100644
--- a/src/lerobot/processor/migrate_policy_normalization.py
+++ b/src/lerobot/processor/migrate_policy_normalization.py
@@ -57,8 +57,8 @@ import torch
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file as load_safetensors
-from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
-from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors
+from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature
+from lerobot.policies import get_policy_class, make_policy_config, make_pre_post_processors
from lerobot.utils.constants import ACTION
diff --git a/src/lerobot/processor/newline_task_processor.py b/src/lerobot/processor/newline_task_processor.py
new file mode 100644
index 000000000..ea61bdd71
--- /dev/null
+++ b/src/lerobot/processor/newline_task_processor.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python
+
+# Copyright 2025 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from lerobot.configs import PipelineFeatureType, PolicyFeature
+
+from .pipeline import ComplementaryDataProcessorStep, ProcessorStepRegistry
+
+
+# NOTE: The registry name "smolvla_new_line_processor" is kept for backward compatibility
+# with serialized processor configs that reference this name.
+@ProcessorStepRegistry.register(name="smolvla_new_line_processor")
+class NewLineTaskProcessorStep(ComplementaryDataProcessorStep):
+ """
+ A processor step that ensures the 'task' description ends with a newline character.
+
+ This step is necessary for certain tokenizers (e.g., PaliGemma) that expect a
+ newline at the end of the prompt. It handles both single string tasks and lists
+ of string tasks.
+ """
+
+ def complementary_data(self, complementary_data):
+ if "task" not in complementary_data:
+ return complementary_data
+
+ task = complementary_data["task"]
+ if task is None:
+ return complementary_data
+
+ new_complementary_data = dict(complementary_data)
+
+ # Handle both string and list of strings
+ if isinstance(task, str):
+ # Single string: add newline if not present
+ if not task.endswith("\n"):
+ new_complementary_data["task"] = f"{task}\n"
+ elif isinstance(task, list) and all(isinstance(t, str) for t in task):
+ # List of strings: add newline to each if not present
+ new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
+ # If task is neither string nor list of strings, leave unchanged
+
+ return new_complementary_data
+
+ def transform_features(
+ self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+ ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+ return features
diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py
index 8a7a1176a..7516c7b47 100644
--- a/src/lerobot/processor/normalize_processor.py
+++ b/src/lerobot/processor/normalize_processor.py
@@ -19,14 +19,17 @@ from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass, field
-from typing import Any
+from typing import TYPE_CHECKING, Any
import torch
from torch import Tensor
-from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.configs import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
from lerobot.types import EnvTransition, PolicyAction, TransitionKey
+
+if TYPE_CHECKING:
+ from lerobot.datasets import LeRobotDataset
+
from lerobot.utils.constants import ACTION
from .converters import from_tensor_to_numpy, to_tensor
diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py
index d22d8fb96..12d1f82a2 100644
--- a/src/lerobot/processor/observation_processor.py
+++ b/src/lerobot/processor/observation_processor.py
@@ -20,7 +20,7 @@ import numpy as np
import torch
from torch import Tensor
-from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py
index abfb31421..2b949d5cb 100644
--- a/src/lerobot/processor/pipeline.py
+++ b/src/lerobot/processor/pipeline.py
@@ -45,8 +45,9 @@ import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file, save_file
-from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.types import EnvAction, EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
+from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.hub import HubMixin
from .converters import batch_to_transition, create_transition, transition_to_batch
@@ -422,8 +423,6 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
"""
if save_directory is None:
# Use default directory in HF_LEROBOT_HOME
- from lerobot.utils.constants import HF_LEROBOT_HOME
-
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
save_directory = HF_LEROBOT_HOME / "processors" / sanitized_name
diff --git a/src/lerobot/processor/policy_robot_bridge.py b/src/lerobot/processor/policy_robot_bridge.py
index 25887d414..25d622dc2 100644
--- a/src/lerobot/processor/policy_robot_bridge.py
+++ b/src/lerobot/processor/policy_robot_bridge.py
@@ -19,10 +19,12 @@ from typing import Any
import torch
-from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
-from lerobot.processor import ActionProcessorStep, PolicyAction, ProcessorStepRegistry, RobotAction
+from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
+from lerobot.types import PolicyAction, RobotAction
from lerobot.utils.constants import ACTION
+from .pipeline import ActionProcessorStep, ProcessorStepRegistry
+
@dataclass
@ProcessorStepRegistry.register("robot_action_to_policy_action_processor")
diff --git a/src/lerobot/processor/relative_action_processor.py b/src/lerobot/processor/relative_action_processor.py
index e00d26e98..d9f97f2c6 100644
--- a/src/lerobot/processor/relative_action_processor.py
+++ b/src/lerobot/processor/relative_action_processor.py
@@ -19,7 +19,7 @@ from typing import Any
import torch
from torch import Tensor
-from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.constants import OBS_STATE
diff --git a/src/lerobot/processor/rename_processor.py b/src/lerobot/processor/rename_processor.py
index 6cae5921f..5ffec6868 100644
--- a/src/lerobot/processor/rename_processor.py
+++ b/src/lerobot/processor/rename_processor.py
@@ -17,7 +17,7 @@ from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any
-from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+from lerobot.configs import PipelineFeatureType, PolicyFeature
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py
index 0b5305dcf..a808e6127 100644
--- a/src/lerobot/processor/tokenizer_processor.py
+++ b/src/lerobot/processor/tokenizer_processor.py
@@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any
import torch
-from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
+from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.types import EnvTransition, RobotObservation, TransitionKey
from lerobot.utils.constants import (
ACTION_TOKEN_MASK,
diff --git a/src/lerobot/rl/__init__.py b/src/lerobot/rl/__init__.py
new file mode 100644
index 000000000..6a7c750d3
--- /dev/null
+++ b/src/lerobot/rl/__init__.py
@@ -0,0 +1,34 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Reinforcement learning modules.
+
+Requires: ``pip install 'lerobot[hilserl]'``
+
+Available modules (import directly)::
+
+ from lerobot.rl.actor import ...
+ from lerobot.rl.learner import ...
+ from lerobot.rl.learner_service import ...
+ from lerobot.rl.buffer import ...
+ from lerobot.rl.eval_policy import ...
+ from lerobot.rl.gym_manipulator import ...
+"""
+
+from lerobot.utils.import_utils import require_package
+
+require_package("grpcio", extra="hilserl", import_name="grpc")
+
+__all__: list[str] = []
diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py
index 18c0ca1ea..0d785bde3 100644
--- a/src/lerobot/rl/actor.py
+++ b/src/lerobot/rl/actor.py
@@ -60,10 +60,8 @@ from torch.multiprocessing import Event, Queue
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
-from lerobot.policies.factory import make_policy
+from lerobot.policies import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
-from lerobot.rl.process import ProcessSignalHandler
-from lerobot.rl.queue import get_last_item_from_queue
from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents
@@ -96,6 +94,8 @@ from .gym_manipulator import (
make_robot_env,
step_env_and_process_transition,
)
+from .process import ProcessSignalHandler
+from .queue import get_last_item_from_queue
# Main entry point
diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py
index 68954162d..97aaa9caa 100644
--- a/src/lerobot/rl/buffer.py
+++ b/src/lerobot/rl/buffer.py
@@ -23,7 +23,7 @@ import torch
import torch.nn.functional as F # noqa: N812
from tqdm import tqdm
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets import LeRobotDataset
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, REWARD
from lerobot.utils.transition import Transition
diff --git a/src/lerobot/rl/crop_dataset_roi.py b/src/lerobot/rl/crop_dataset_roi.py
index 4345fed3c..b6bde2273 100644
--- a/src/lerobot/rl/crop_dataset_roi.py
+++ b/src/lerobot/rl/crop_dataset_roi.py
@@ -24,7 +24,7 @@ import torch
import torchvision.transforms.functional as F # type: ignore # noqa: N812
from tqdm import tqdm # type: ignore
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets import LeRobotDataset
from lerobot.utils.constants import DONE, REWARD
diff --git a/src/lerobot/rl/eval_policy.py b/src/lerobot/rl/eval_policy.py
index fb2504f2a..4398351c5 100644
--- a/src/lerobot/rl/eval_policy.py
+++ b/src/lerobot/rl/eval_policy.py
@@ -18,8 +18,8 @@ import logging
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.policies.factory import make_policy
+from lerobot.datasets import LeRobotDataset
+from lerobot.policies import make_policy
from lerobot.robots import ( # noqa: F401
RobotConfig,
make_robot_from_config,
diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py
index bd64d205f..b6ff7155a 100644
--- a/src/lerobot/rl/gym_manipulator.py
+++ b/src/lerobot/rl/gym_manipulator.py
@@ -25,9 +25,9 @@ import torch
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.envs.configs import HILSerlRobotEnvConfig
-from lerobot.model.kinematics import RobotKinematics
+from lerobot.datasets import LeRobotDataset
+from lerobot.envs import HILSerlRobotEnvConfig
+from lerobot.model import RobotKinematics
from lerobot.processor import (
AddBatchDimensionProcessorStep,
AddTeleopActionAsComplimentaryDataStep,
@@ -50,8 +50,8 @@ from lerobot.processor import (
TransitionKey,
VanillaObservationProcessorStep,
create_transition,
+ identity_transition,
)
-from lerobot.processor.converters import identity_transition
from lerobot.robots import ( # noqa: F401
RobotConfig,
make_robot_from_config,
diff --git a/src/lerobot/rl/joint_observations_processor.py b/src/lerobot/rl/joint_observations_processor.py
index 2fbcc7c46..dc677e26c 100644
--- a/src/lerobot/rl/joint_observations_processor.py
+++ b/src/lerobot/rl/joint_observations_processor.py
@@ -19,8 +19,8 @@ from typing import Any
import torch
-from lerobot.configs.types import PipelineFeatureType, PolicyFeature
-from lerobot.processor.pipeline import (
+from lerobot.configs import PipelineFeatureType, PolicyFeature
+from lerobot.processor import (
ObservationProcessorStep,
ProcessorStepRegistry,
)
diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py
index 2853fbcb3..073d9a65f 100644
--- a/src/lerobot/rl/learner.py
+++ b/src/lerobot/rl/learner.py
@@ -60,15 +60,18 @@ from torch.multiprocessing import Queue
from torch.optim.optimizer import Optimizer
from lerobot.cameras import opencv # noqa: F401
+from lerobot.common.train_utils import (
+ get_step_checkpoint_dir,
+ load_training_state as utils_load_training_state,
+ save_checkpoint,
+ update_last_checkpoint,
+)
+from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
-from lerobot.datasets.factory import make_dataset
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.policies.factory import make_policy
+from lerobot.datasets import LeRobotDataset, make_dataset
+from lerobot.policies import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
-from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
-from lerobot.rl.process import ProcessSignalHandler
-from lerobot.rl.wandb_utils import WandBLogger
from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents
@@ -88,19 +91,15 @@ from lerobot.utils.constants import (
)
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.random_utils import set_seed
-from lerobot.utils.train_utils import (
- get_step_checkpoint_dir,
- load_training_state as utils_load_training_state,
- save_checkpoint,
- update_last_checkpoint,
-)
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
from lerobot.utils.utils import (
format_big_number,
init_logging,
)
+from .buffer import ReplayBuffer, concatenate_batch_transitions
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
+from .process import ProcessSignalHandler
@parser.wrap()
@@ -152,7 +151,7 @@ def train(cfg: TrainRLServerPipelineConfig, job_name: str | None = None):
# Setup WandB logging if enabled
if cfg.wandb.enable and cfg.wandb.project:
- from lerobot.rl.wandb_utils import WandBLogger
+ from lerobot.common.wandb_utils import WandBLogger
wandb_logger = WandBLogger(cfg)
else:
diff --git a/src/lerobot/rl/learner_service.py b/src/lerobot/rl/learner_service.py
index 7ef38119b..4128cdf55 100644
--- a/src/lerobot/rl/learner_service.py
+++ b/src/lerobot/rl/learner_service.py
@@ -19,10 +19,11 @@ import logging
import time
from multiprocessing import Event, Queue
-from lerobot.rl.queue import get_last_item_from_queue
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
+from .queue import get_last_item_from_queue
+
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
SHUTDOWN_TIMEOUT = 10
diff --git a/src/lerobot/robots/__init__.py b/src/lerobot/robots/__init__.py
index 1dba0f1b0..eb8b06fb8 100644
--- a/src/lerobot/robots/__init__.py
+++ b/src/lerobot/robots/__init__.py
@@ -17,3 +17,5 @@
from .config import RobotConfig
from .robot import Robot
from .utils import make_robot_from_config
+
+__all__ = ["Robot", "RobotConfig", "make_robot_from_config"]
diff --git a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py
index c48ac5934..c27398278 100644
--- a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py
+++ b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py
@@ -17,10 +17,10 @@
import logging
from functools import cached_property
-from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
+from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig
from ..robot import Robot
from .config_bi_openarm_follower import BiOpenArmFollowerConfig
diff --git a/src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py b/src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py
index ef5d70cab..9ed56aeac 100644
--- a/src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py
+++ b/src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py
@@ -17,9 +17,9 @@
from dataclasses import dataclass, field
from lerobot.cameras import CameraConfig
-from lerobot.robots.openarm_follower import OpenArmFollowerConfigBase
from ..config import RobotConfig
+from ..openarm_follower import OpenArmFollowerConfigBase
@RobotConfig.register_subclass("bi_openarm_follower")
diff --git a/src/lerobot/robots/bi_so_follower/__init__.py b/src/lerobot/robots/bi_so_follower/__init__.py
index f631a14db..1d63dcb2c 100644
--- a/src/lerobot/robots/bi_so_follower/__init__.py
+++ b/src/lerobot/robots/bi_so_follower/__init__.py
@@ -16,3 +16,5 @@
from .bi_so_follower import BiSOFollower
from .config_bi_so_follower import BiSOFollowerConfig
+
+__all__ = ["BiSOFollower", "BiSOFollowerConfig"]
diff --git a/src/lerobot/robots/bi_so_follower/bi_so_follower.py b/src/lerobot/robots/bi_so_follower/bi_so_follower.py
index ba1826e29..f592150a6 100644
--- a/src/lerobot/robots/bi_so_follower/bi_so_follower.py
+++ b/src/lerobot/robots/bi_so_follower/bi_so_follower.py
@@ -17,11 +17,11 @@
import logging
from functools import cached_property
-from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..robot import Robot
+from ..so_follower import SOFollower, SOFollowerRobotConfig
from .config_bi_so_follower import BiSOFollowerConfig
logger = logging.getLogger(__name__)
diff --git a/src/lerobot/robots/bi_so_follower/config_bi_so_follower.py b/src/lerobot/robots/bi_so_follower/config_bi_so_follower.py
index dca74fa2d..97afbab4f 100644
--- a/src/lerobot/robots/bi_so_follower/config_bi_so_follower.py
+++ b/src/lerobot/robots/bi_so_follower/config_bi_so_follower.py
@@ -16,9 +16,8 @@
from dataclasses import dataclass
-from lerobot.robots.so_follower import SOFollowerConfig
-
from ..config import RobotConfig
+from ..so_follower import SOFollowerConfig
@RobotConfig.register_subclass("bi_so_follower")
diff --git a/src/lerobot/robots/hope_jr/__init__.py b/src/lerobot/robots/hope_jr/__init__.py
index 26603ebb0..94fcf86e4 100644
--- a/src/lerobot/robots/hope_jr/__init__.py
+++ b/src/lerobot/robots/hope_jr/__init__.py
@@ -17,3 +17,5 @@
from .config_hope_jr import HopeJrArmConfig, HopeJrHandConfig
from .hope_jr_arm import HopeJrArm
from .hope_jr_hand import HopeJrHand
+
+__all__ = ["HopeJrArm", "HopeJrArmConfig", "HopeJrHand", "HopeJrHandConfig"]
diff --git a/src/lerobot/robots/hope_jr/hope_jr_arm.py b/src/lerobot/robots/hope_jr/hope_jr_arm.py
index 7f6492ef0..4918bcae3 100644
--- a/src/lerobot/robots/hope_jr/hope_jr_arm.py
+++ b/src/lerobot/robots/hope_jr/hope_jr_arm.py
@@ -18,7 +18,7 @@ import logging
import time
from functools import cached_property
-from lerobot.cameras.utils import make_cameras_from_configs
+from lerobot.cameras import make_cameras_from_configs
from lerobot.motors import Motor, MotorNormMode
from lerobot.motors.calibration_gui import RangeFinderGUI
from lerobot.motors.feetech import (
diff --git a/src/lerobot/robots/hope_jr/hope_jr_hand.py b/src/lerobot/robots/hope_jr/hope_jr_hand.py
index 784804836..566628724 100644
--- a/src/lerobot/robots/hope_jr/hope_jr_hand.py
+++ b/src/lerobot/robots/hope_jr/hope_jr_hand.py
@@ -18,7 +18,7 @@ import logging
import time
from functools import cached_property
-from lerobot.cameras.utils import make_cameras_from_configs
+from lerobot.cameras import make_cameras_from_configs
from lerobot.motors import Motor, MotorNormMode
from lerobot.motors.calibration_gui import RangeFinderGUI
from lerobot.motors.feetech import (
diff --git a/src/lerobot/robots/koch_follower/__init__.py b/src/lerobot/robots/koch_follower/__init__.py
index 6271c4e55..8f4435924 100644
--- a/src/lerobot/robots/koch_follower/__init__.py
+++ b/src/lerobot/robots/koch_follower/__init__.py
@@ -16,3 +16,5 @@
from .config_koch_follower import KochFollowerConfig
from .koch_follower import KochFollower
+
+__all__ = ["KochFollower", "KochFollowerConfig"]
diff --git a/src/lerobot/robots/koch_follower/koch_follower.py b/src/lerobot/robots/koch_follower/koch_follower.py
index 44e83f6a3..3f40ac738 100644
--- a/src/lerobot/robots/koch_follower/koch_follower.py
+++ b/src/lerobot/robots/koch_follower/koch_follower.py
@@ -18,7 +18,7 @@ import logging
import time
from functools import cached_property
-from lerobot.cameras.utils import make_cameras_from_configs
+from lerobot.cameras import make_cameras_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.dynamixel import (
DynamixelMotorsBus,
diff --git a/src/lerobot/robots/lekiwi/__init__.py b/src/lerobot/robots/lekiwi/__init__.py
index ada2ff368..3d2191242 100644
--- a/src/lerobot/robots/lekiwi/__init__.py
+++ b/src/lerobot/robots/lekiwi/__init__.py
@@ -17,3 +17,5 @@
from .config_lekiwi import LeKiwiClientConfig, LeKiwiConfig
from .lekiwi import LeKiwi
from .lekiwi_client import LeKiwiClient
+
+__all__ = ["LeKiwi", "LeKiwiClient", "LeKiwiClientConfig", "LeKiwiConfig"]
diff --git a/src/lerobot/robots/lekiwi/config_lekiwi.py b/src/lerobot/robots/lekiwi/config_lekiwi.py
index acaf5f0ec..51fa8f03f 100644
--- a/src/lerobot/robots/lekiwi/config_lekiwi.py
+++ b/src/lerobot/robots/lekiwi/config_lekiwi.py
@@ -14,8 +14,8 @@
from dataclasses import dataclass, field
-from lerobot.cameras.configs import CameraConfig, Cv2Rotation
-from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
+from lerobot.cameras import CameraConfig, Cv2Rotation
+from lerobot.cameras.opencv import OpenCVCameraConfig
from ..config import RobotConfig
diff --git a/src/lerobot/robots/lekiwi/lekiwi.py b/src/lerobot/robots/lekiwi/lekiwi.py
index 60fac89e5..b73ebeab9 100644
--- a/src/lerobot/robots/lekiwi/lekiwi.py
+++ b/src/lerobot/robots/lekiwi/lekiwi.py
@@ -22,7 +22,7 @@ from typing import Any
import numpy as np
-from lerobot.cameras.utils import make_cameras_from_configs
+from lerobot.cameras import make_cameras_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import (
FeetechMotorsBus,
diff --git a/src/lerobot/robots/omx_follower/__init__.py b/src/lerobot/robots/omx_follower/__init__.py
index db48dffe9..328ac8d80 100644
--- a/src/lerobot/robots/omx_follower/__init__.py
+++ b/src/lerobot/robots/omx_follower/__init__.py
@@ -19,3 +19,5 @@
from .config_omx_follower import OmxFollowerConfig
from .omx_follower import OmxFollower
+
+__all__ = ["OmxFollower", "OmxFollowerConfig"]
diff --git a/src/lerobot/robots/omx_follower/omx_follower.py b/src/lerobot/robots/omx_follower/omx_follower.py
index 5d161daa2..c30eec97a 100644
--- a/src/lerobot/robots/omx_follower/omx_follower.py
+++ b/src/lerobot/robots/omx_follower/omx_follower.py
@@ -18,7 +18,7 @@ import logging
import time
from functools import cached_property
-from lerobot.cameras.utils import make_cameras_from_configs
+from lerobot.cameras import make_cameras_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.dynamixel import (
DriveMode,
diff --git a/src/lerobot/robots/openarm_follower/openarm_follower.py b/src/lerobot/robots/openarm_follower/openarm_follower.py
index 99e8b920b..4d1765f07 100644
--- a/src/lerobot/robots/openarm_follower/openarm_follower.py
+++ b/src/lerobot/robots/openarm_follower/openarm_follower.py
@@ -19,7 +19,7 @@ import time
from functools import cached_property
from typing import Any
-from lerobot.cameras.utils import make_cameras_from_configs
+from lerobot.cameras import make_cameras_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.damiao import DamiaoMotorsBus
from lerobot.types import RobotAction, RobotObservation
diff --git a/src/lerobot/robots/reachy2/__init__.py b/src/lerobot/robots/reachy2/__init__.py
index 1a38fd03b..b7afd006d 100644
--- a/src/lerobot/robots/reachy2/__init__.py
+++ b/src/lerobot/robots/reachy2/__init__.py
@@ -23,3 +23,13 @@ from .robot_reachy2 import (
REACHY2_VEL,
Reachy2Robot,
)
+
+__all__ = [
+ "REACHY2_ANTENNAS_JOINTS",
+ "REACHY2_L_ARM_JOINTS",
+ "REACHY2_NECK_JOINTS",
+ "REACHY2_R_ARM_JOINTS",
+ "REACHY2_VEL",
+ "Reachy2Robot",
+ "Reachy2RobotConfig",
+]
diff --git a/src/lerobot/robots/reachy2/configuration_reachy2.py b/src/lerobot/robots/reachy2/configuration_reachy2.py
index 63293e675..8cb67a495 100644
--- a/src/lerobot/robots/reachy2/configuration_reachy2.py
+++ b/src/lerobot/robots/reachy2/configuration_reachy2.py
@@ -14,8 +14,7 @@
from dataclasses import dataclass, field
-from lerobot.cameras import CameraConfig
-from lerobot.cameras.configs import ColorMode
+from lerobot.cameras import CameraConfig, ColorMode
from lerobot.cameras.reachy2_camera import Reachy2CameraConfig
from ..config import RobotConfig
diff --git a/src/lerobot/robots/reachy2/robot_reachy2.py b/src/lerobot/robots/reachy2/robot_reachy2.py
index 5227a096a..ef55f71b9 100644
--- a/src/lerobot/robots/reachy2/robot_reachy2.py
+++ b/src/lerobot/robots/reachy2/robot_reachy2.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import time
from typing import TYPE_CHECKING, Any
-from lerobot.cameras.utils import make_cameras_from_configs
+from lerobot.cameras import make_cameras_from_configs
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.import_utils import _reachy2_sdk_available
diff --git a/src/lerobot/robots/so_follower/__init__.py b/src/lerobot/robots/so_follower/__init__.py
index eea2fcbdf..45de205a8 100644
--- a/src/lerobot/robots/so_follower/__init__.py
+++ b/src/lerobot/robots/so_follower/__init__.py
@@ -21,3 +21,13 @@ from .config_so_follower import (
SOFollowerRobotConfig,
)
from .so_follower import SO100Follower, SO101Follower, SOFollower
+
+__all__ = [
+ "SO100Follower",
+ "SO100FollowerConfig",
+ "SO101Follower",
+ "SO101FollowerConfig",
+ "SOFollower",
+ "SOFollowerConfig",
+ "SOFollowerRobotConfig",
+]
diff --git a/src/lerobot/robots/so_follower/robot_kinematic_processor.py b/src/lerobot/robots/so_follower/robot_kinematic_processor.py
index 2aa60e12a..8114fdc2c 100644
--- a/src/lerobot/robots/so_follower/robot_kinematic_processor.py
+++ b/src/lerobot/robots/so_follower/robot_kinematic_processor.py
@@ -19,8 +19,8 @@ from typing import Any
import numpy as np
-from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
-from lerobot.model.kinematics import RobotKinematics
+from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
+from lerobot.model import RobotKinematics
from lerobot.processor import (
EnvTransition,
ObservationProcessorStep,
diff --git a/src/lerobot/robots/so_follower/so_follower.py b/src/lerobot/robots/so_follower/so_follower.py
index ca132d102..0651f566c 100644
--- a/src/lerobot/robots/so_follower/so_follower.py
+++ b/src/lerobot/robots/so_follower/so_follower.py
@@ -18,7 +18,7 @@ import logging
import time
from functools import cached_property
-from lerobot.cameras.utils import make_cameras_from_configs
+from lerobot.cameras import make_cameras_from_configs
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import (
FeetechMotorsBus,
diff --git a/src/lerobot/robots/unitree_g1/gr00t_locomotion.py b/src/lerobot/robots/unitree_g1/gr00t_locomotion.py
index 31166e123..12fe26073 100644
--- a/src/lerobot/robots/unitree_g1/gr00t_locomotion.py
+++ b/src/lerobot/robots/unitree_g1/gr00t_locomotion.py
@@ -21,7 +21,7 @@ import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
-from lerobot.robots.unitree_g1.g1_utils import (
+from .g1_utils import (
REMOTE_AXES,
REMOTE_BUTTONS,
G1_29_JointIndex,
diff --git a/src/lerobot/robots/unitree_g1/holosoma_locomotion.py b/src/lerobot/robots/unitree_g1/holosoma_locomotion.py
index 857bb97bc..3d3bccbdc 100644
--- a/src/lerobot/robots/unitree_g1/holosoma_locomotion.py
+++ b/src/lerobot/robots/unitree_g1/holosoma_locomotion.py
@@ -22,7 +22,7 @@ import onnx
import onnxruntime as ort
from huggingface_hub import hf_hub_download
-from lerobot.robots.unitree_g1.g1_utils import (
+from .g1_utils import (
REMOTE_AXES,
G1_29_JointArmIndex,
G1_29_JointIndex,
diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py
index 9e373c05f..785861a5a 100644
--- a/src/lerobot/robots/unitree_g1/unitree_g1.py
+++ b/src/lerobot/robots/unitree_g1/unitree_g1.py
@@ -25,9 +25,14 @@ from typing import TYPE_CHECKING, Protocol, runtime_checkable
import numpy as np
-from lerobot.cameras.utils import make_cameras_from_configs
-from lerobot.robots.unitree_g1.g1_kinematics import G1_29_ArmIK
-from lerobot.robots.unitree_g1.g1_utils import (
+from lerobot.cameras import make_cameras_from_configs
+from lerobot.types import RobotAction, RobotObservation
+from lerobot.utils.import_utils import _unitree_sdk_available
+
+from ..robot import Robot
+from .config_unitree_g1 import UnitreeG1Config
+from .g1_kinematics import G1_29_ArmIK
+from .g1_utils import (
REMOTE_AXES,
REMOTE_KEYS,
G1_29_JointArmIndex,
@@ -35,11 +40,6 @@ from lerobot.robots.unitree_g1.g1_utils import (
default_remote_input,
make_locomotion_controller,
)
-from lerobot.types import RobotAction, RobotObservation
-from lerobot.utils.import_utils import _unitree_sdk_available
-
-from ..robot import Robot
-from .config_unitree_g1 import UnitreeG1Config
if TYPE_CHECKING or _unitree_sdk_available:
from unitree_sdk2py.core.channel import (
@@ -127,7 +127,7 @@ class UnitreeG1(Robot):
self._ChannelPublisher = _SDKChannelPublisher
self._ChannelSubscriber = _SDKChannelSubscriber
else:
- from lerobot.robots.unitree_g1.unitree_sdk2_socket import (
+ from .unitree_sdk2_socket import (
ChannelFactoryInitialize,
ChannelPublisher,
ChannelSubscriber,
@@ -290,7 +290,7 @@ class UnitreeG1(Robot):
def connect(self, calibrate: bool = True) -> None: # connect to DDS
# Initialize DDS channel and simulation environment
if self.config.is_simulation:
- from lerobot.envs.factory import make_env
+ from lerobot.envs import make_env
self._ChannelFactoryInitialize(0, "lo")
self._env_wrapper = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True)
diff --git a/src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py b/src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py
index 0f1f8f8d6..4f0b787aa 100644
--- a/src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py
+++ b/src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py
@@ -20,7 +20,7 @@ from typing import Any
import zmq
-from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
+from .config_unitree_g1 import UnitreeG1Config
# Module-level ZMQ state mirrors the Unitree SDK's global ChannelFactory Singleton.
# Only one robot connection per process is supported.
diff --git a/src/lerobot/scripts/augment_dataset_quantile_stats.py b/src/lerobot/scripts/augment_dataset_quantile_stats.py
index 4d80c9332..4ee99a541 100644
--- a/src/lerobot/scripts/augment_dataset_quantile_stats.py
+++ b/src/lerobot/scripts/augment_dataset_quantile_stats.py
@@ -44,10 +44,14 @@ from huggingface_hub import HfApi
from requests import HTTPError
from tqdm import tqdm
-from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
-from lerobot.datasets.dataset_metadata import CODEBASE_VERSION
-from lerobot.datasets.io_utils import write_stats
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets import (
+ CODEBASE_VERSION,
+ DEFAULT_QUANTILES,
+ LeRobotDataset,
+ aggregate_stats,
+ get_feature_stats,
+ write_stats,
+)
from lerobot.utils.utils import init_logging
diff --git a/src/lerobot/scripts/convert_dataset_v21_to_v30.py b/src/lerobot/scripts/convert_dataset_v21_to_v30.py
index 2b6dcf732..59e635712 100644
--- a/src/lerobot/scripts/convert_dataset_v21_to_v30.py
+++ b/src/lerobot/scripts/convert_dataset_v21_to_v30.py
@@ -51,6 +51,10 @@ import shutil
from pathlib import Path
from typing import Any
+from lerobot.utils.import_utils import require_package
+
+require_package("jsonlines", extra="dataset")
+
import jsonlines
import pandas as pd
import pyarrow as pa
@@ -59,8 +63,7 @@ from datasets import Dataset, Features, Image
from huggingface_hub import HfApi, snapshot_download
from requests import HTTPError
-from lerobot.datasets.compute_stats import aggregate_stats
-from lerobot.datasets.dataset_metadata import CODEBASE_VERSION
+from lerobot.datasets import CODEBASE_VERSION, LeRobotDataset, aggregate_stats
from lerobot.datasets.io_utils import (
cast_stats_to_numpy,
get_file_size_in_mb,
@@ -72,7 +75,6 @@ from lerobot.datasets.io_utils import (
write_stats,
write_tasks,
)
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
@@ -82,12 +84,11 @@ from lerobot.datasets.utils import (
LEGACY_EPISODES_PATH,
LEGACY_EPISODES_STATS_PATH,
LEGACY_TASKS_PATH,
- flatten_dict,
update_chunk_file_indices,
)
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
from lerobot.utils.constants import HF_LEROBOT_HOME
-from lerobot.utils.utils import init_logging
+from lerobot.utils.utils import flatten_dict, init_logging
V21 = "v2.1"
V30 = "v3.0"
diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py
index 242067978..e68d7438b 100644
--- a/src/lerobot/scripts/lerobot_calibrate.py
+++ b/src/lerobot/scripts/lerobot_calibrate.py
@@ -15,6 +15,8 @@
"""
Helper to recalibrate your device (robot or teleoperator).
+Requires: pip install 'lerobot[hardware]'
+
Example:
```shell
@@ -31,8 +33,8 @@ from pprint import pformat
import draccus
-from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
-from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
+from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
+from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py
index c4b676c67..d07a2767d 100644
--- a/src/lerobot/scripts/lerobot_dataset_viz.py
+++ b/src/lerobot/scripts/lerobot_dataset_viz.py
@@ -15,6 +15,8 @@
# limitations under the License.
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
+Requires: pip install 'lerobot[dataset_viz]' (includes dataset + viz extras)
+
Note: The last frame of the episode doesn't always correspond to a final state.
That's because our datasets are composed of transition from state to state up to
the antepenultimate state associated to the ultimate action to arrive in the final state.
@@ -66,12 +68,11 @@ import time
from pathlib import Path
import numpy as np
-import rerun as rr
import torch
import torch.utils.data
import tqdm
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets import LeRobotDataset
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
from lerobot.utils.utils import init_logging
@@ -117,6 +118,11 @@ def visualize_dataset(
if mode not in ["local", "distant"]:
raise ValueError(mode)
+ from lerobot.utils.import_utils import require_package
+
+ require_package("rerun-sdk", extra="viz", import_name="rerun")
+ import rerun as rr
+
spawn_local_viewer = mode == "local" and not save
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py
index db06f90c6..0cfb34325 100644
--- a/src/lerobot/scripts/lerobot_edit_dataset.py
+++ b/src/lerobot/scripts/lerobot_edit_dataset.py
@@ -17,6 +17,8 @@
"""
Edit LeRobot datasets using various transformation tools.
+Requires: pip install 'lerobot[dataset]'
+
This script allows you to delete episodes, split datasets, merge datasets,
remove features, modify tasks, recompute stats, and convert image datasets to video format.
When new_repo_id is specified, creates a new dataset.
@@ -178,7 +180,8 @@ from pathlib import Path
import draccus
from lerobot.configs import parser
-from lerobot.datasets.dataset_tools import (
+from lerobot.datasets import (
+ LeRobotDataset,
convert_image_to_video_dataset,
delete_episodes,
merge_datasets,
@@ -187,7 +190,6 @@ from lerobot.datasets.dataset_tools import (
remove_feature,
split_dataset,
)
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.utils import init_logging
diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py
index cd912280f..d45483d21 100644
--- a/src/lerobot/scripts/lerobot_eval.py
+++ b/src/lerobot/scripts/lerobot_eval.py
@@ -15,6 +15,9 @@
# limitations under the License.
"""Evaluate a policy on an environment by running rollouts and computing metrics.
+Requires: pip install 'lerobot[evaluation]' plus the policy extra (e.g. lerobot[pi])
+ and the environment extra (e.g. lerobot[pusht]) if evaluating in simulation.
+
Usage examples:
You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/diffusion_pusht)
@@ -71,14 +74,14 @@ from tqdm import trange
from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
-from lerobot.envs.factory import make_env, make_env_pre_post_processors
-from lerobot.envs.utils import (
+from lerobot.envs import (
check_env_attributes_and_types,
close_envs,
+ make_env,
+ make_env_pre_post_processors,
preprocess_observation,
)
-from lerobot.policies.factory import make_policy, make_pre_post_processors
-from lerobot.policies.pretrained import PreTrainedPolicy
+from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.processor import PolicyProcessorPipeline
from lerobot.types import PolicyAction
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
diff --git a/src/lerobot/scripts/lerobot_find_cameras.py b/src/lerobot/scripts/lerobot_find_cameras.py
index 0248a2768..72f4096da 100644
--- a/src/lerobot/scripts/lerobot_find_cameras.py
+++ b/src/lerobot/scripts/lerobot_find_cameras.py
@@ -37,11 +37,9 @@ from typing import Any
import numpy as np
from PIL import Image
-from lerobot.cameras.configs import ColorMode
-from lerobot.cameras.opencv.camera_opencv import OpenCVCamera
-from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
-from lerobot.cameras.realsense.camera_realsense import RealSenseCamera
-from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig
+from lerobot.cameras import ColorMode
+from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig
+from lerobot.cameras.realsense import RealSenseCamera, RealSenseCameraConfig
logger = logging.getLogger(__name__)
diff --git a/src/lerobot/scripts/lerobot_find_joint_limits.py b/src/lerobot/scripts/lerobot_find_joint_limits.py
index bcb93ba12..c4f867631 100644
--- a/src/lerobot/scripts/lerobot_find_joint_limits.py
+++ b/src/lerobot/scripts/lerobot_find_joint_limits.py
@@ -41,7 +41,7 @@ from dataclasses import dataclass
import draccus
import numpy as np
-from lerobot.model.kinematics import RobotKinematics
+from lerobot.model import RobotKinematics
from lerobot.robots import ( # noqa: F401
RobotConfig,
bi_openarm_follower,
diff --git a/src/lerobot/scripts/lerobot_find_port.py b/src/lerobot/scripts/lerobot_find_port.py
index e32b9cb99..93065c473 100644
--- a/src/lerobot/scripts/lerobot_find_port.py
+++ b/src/lerobot/scripts/lerobot_find_port.py
@@ -28,7 +28,10 @@ from pathlib import Path
def find_available_ports():
- from serial.tools import list_ports # Part of pyserial library
+ from lerobot.utils.import_utils import require_package
+
+ require_package("pyserial", extra="hardware", import_name="serial")
+ from serial.tools import list_ports
if platform.system() == "Windows":
# List COM ports using pyserial
diff --git a/src/lerobot/scripts/lerobot_imgtransform_viz.py b/src/lerobot/scripts/lerobot_imgtransform_viz.py
index bc13f0508..7cd4c782d 100644
--- a/src/lerobot/scripts/lerobot_imgtransform_viz.py
+++ b/src/lerobot/scripts/lerobot_imgtransform_viz.py
@@ -35,9 +35,9 @@ from pathlib import Path
import draccus
from torchvision.transforms import ToPILImage
-from lerobot.configs.default import DatasetConfig
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.datasets.transforms import (
+from lerobot.configs import DatasetConfig
+from lerobot.datasets import LeRobotDataset
+from lerobot.transforms import (
ImageTransforms,
ImageTransformsConfig,
make_transform_from_config,
diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py
index c58f8f103..fa92a296d 100644
--- a/src/lerobot/scripts/lerobot_record.py
+++ b/src/lerobot/scripts/lerobot_record.py
@@ -15,6 +15,8 @@
"""
Records a dataset. Actions for the robot can be either generated by teleoperation or by a policy.
+Requires: pip install 'lerobot[core_scripts]' (includes dataset + hardware + viz extras)
+
Example:
```shell
@@ -76,24 +78,33 @@ from typing import Any
import torch
-from lerobot.cameras import ( # noqa: F401
- CameraConfig, # noqa: F401
+from lerobot.cameras import CameraConfig # noqa: F401
+from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
+from lerobot.cameras.reachy2_camera import Reachy2CameraConfig # noqa: F401
+from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
+from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401
+from lerobot.common.control_utils import (
+ init_keyboard_listener,
+ is_headless,
+ predict_action,
+ sanity_check_dataset_name,
+ sanity_check_dataset_robot_compatibility,
+)
+from lerobot.configs import PreTrainedConfig, parser
+from lerobot.datasets import (
+ LeRobotDataset,
+ VideoEncodingManager,
+ aggregate_pipeline_dataset_features,
+ create_initial_features,
+ safe_stop_image_writer,
+)
+from lerobot.policies import (
+ ActionInterpolator,
+ PreTrainedPolicy,
+ make_policy,
+ make_pre_post_processors,
+ make_robot_action,
)
-from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
-from lerobot.cameras.reachy2_camera.configuration_reachy2_camera import Reachy2CameraConfig # noqa: F401
-from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
-from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
-from lerobot.configs import parser
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.datasets.feature_utils import build_dataset_frame, combine_feature_dicts
-from lerobot.datasets.image_writer import safe_stop_image_writer
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
-from lerobot.datasets.video_utils import VideoEncodingManager
-from lerobot.policies.factory import make_policy, make_pre_post_processors
-from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.policies.rtc import ActionInterpolator
-from lerobot.policies.utils import make_robot_action
from lerobot.processor import (
PolicyAction,
PolicyProcessorPipeline,
@@ -101,8 +112,8 @@ from lerobot.processor import (
RobotObservation,
RobotProcessorPipeline,
make_default_processors,
+ rename_stats,
)
-from lerobot.processor.rename_processor import rename_stats
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -133,16 +144,10 @@ from lerobot.teleoperators import ( # noqa: F401
so_leader,
unitree_g1,
)
-from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
+from lerobot.teleoperators.keyboard import KeyboardTeleop
from lerobot.utils.constants import ACTION, OBS_STR
-from lerobot.utils.control_utils import (
- init_keyboard_listener,
- is_headless,
- predict_action,
- sanity_check_dataset_name,
- sanity_check_dataset_robot_compatibility,
-)
from lerobot.utils.device_utils import get_safe_torch_device
+from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import (
diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py
index 09e7d4e8b..41d2926cc 100644
--- a/src/lerobot/scripts/lerobot_replay.py
+++ b/src/lerobot/scripts/lerobot_replay.py
@@ -15,6 +15,8 @@
"""
Replays the actions of an episode from a dataset on a robot.
+Requires: pip install 'lerobot[core_scripts]' (includes dataset + hardware + viz extras)
+
Examples:
```shell
@@ -46,7 +48,7 @@ from pathlib import Path
from pprint import pformat
from lerobot.configs import parser
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets import LeRobotDataset
from lerobot.processor import (
make_default_robot_action_processor,
)
diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py
index f050d572a..76157595e 100644
--- a/src/lerobot/scripts/lerobot_teleoperate.py
+++ b/src/lerobot/scripts/lerobot_teleoperate.py
@@ -15,6 +15,8 @@
"""
Simple script to control a robot from teleoperation.
+Requires: pip install 'lerobot[hardware]'
+
Example:
```shell
@@ -56,11 +58,9 @@ import time
from dataclasses import asdict, dataclass
from pprint import pformat
-import rerun as rr
-
-from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
-from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
-from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
+from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
+from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
+from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.processor import (
RobotAction,
@@ -103,7 +103,7 @@ from lerobot.teleoperators import ( # noqa: F401
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import init_logging, move_cursor_up
-from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
+from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
@dataclass
@@ -240,7 +240,7 @@ def teleoperate(cfg: TeleoperateConfig):
pass
finally:
if cfg.display_data:
- rr.rerun_shutdown()
+ shutdown_rerun()
teleop.disconnect()
robot.disconnect()
diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py
index 0a7212911..a862c640d 100644
--- a/src/lerobot/scripts/lerobot_train.py
+++ b/src/lerobot/scripts/lerobot_train.py
@@ -13,48 +13,53 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""Train a policy.
+
+Requires: pip install 'lerobot[training]' (includes dataset + accelerate + wandb extras)
+"""
+
import dataclasses
import logging
import time
from contextlib import nullcontext
from pprint import pformat
-from typing import Any
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+ from accelerate import Accelerator
import torch
-from accelerate import Accelerator
from termcolor import colored
from torch.optim import Optimizer
from tqdm import tqdm
-from lerobot.configs import parser
-from lerobot.configs.train import TrainPipelineConfig
-from lerobot.datasets.factory import make_dataset
-from lerobot.datasets.sampler import EpisodeAwareSampler
-from lerobot.datasets.utils import cycle
-from lerobot.envs.factory import make_env, make_env_pre_post_processors
-from lerobot.envs.utils import close_envs
-from lerobot.optim.factory import make_optimizer_and_scheduler
-from lerobot.policies.factory import make_policy, make_pre_post_processors
-from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.rl.wandb_utils import WandBLogger
-from lerobot.scripts.lerobot_eval import eval_policy_all
-from lerobot.utils.import_utils import register_third_party_plugins
-from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
-from lerobot.utils.random_utils import set_seed
-from lerobot.utils.train_utils import (
+from lerobot.common.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state,
save_checkpoint,
update_last_checkpoint,
)
+from lerobot.common.wandb_utils import WandBLogger
+from lerobot.configs import parser
+from lerobot.configs.train import TrainPipelineConfig
+from lerobot.datasets import EpisodeAwareSampler, make_dataset
+from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
+from lerobot.optim.factory import make_optimizer_and_scheduler
+from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
+from lerobot.utils.import_utils import register_third_party_plugins
+from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
+from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
+ cycle,
format_big_number,
has_method,
init_logging,
inside_slurm,
)
+from .lerobot_eval import eval_policy_all
+
def update_policy(
train_metrics: MetricsTracker,
@@ -62,7 +67,7 @@ def update_policy(
batch: Any,
optimizer: Optimizer,
grad_clip_norm: float,
- accelerator: Accelerator,
+ accelerator: "Accelerator",
lr_scheduler=None,
lock=None,
rabc_weights_provider=None,
@@ -151,7 +156,7 @@ def update_policy(
@parser.wrap()
-def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
+def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
"""
Main function to train a policy.
@@ -167,6 +172,11 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
cfg: A `TrainPipelineConfig` object containing all training configurations.
accelerator: Optional Accelerator instance. If None, one will be created automatically.
"""
+ from lerobot.utils.import_utils import require_package
+
+ require_package("accelerate", extra="training")
+ from accelerate import Accelerator
+
cfg.validate()
# Create Accelerator if not provided
diff --git a/src/lerobot/scripts/lerobot_train_tokenizer.py b/src/lerobot/scripts/lerobot_train_tokenizer.py
index 35c2b60cd..c821a4d54 100644
--- a/src/lerobot/scripts/lerobot_train_tokenizer.py
+++ b/src/lerobot/scripts/lerobot_train_tokenizer.py
@@ -60,9 +60,8 @@ if TYPE_CHECKING or _transformers_available:
else:
AutoProcessor = None
-from lerobot.configs import parser
-from lerobot.configs.types import NormalizationMode
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.configs import NormalizationMode, parser
+from lerobot.datasets import LeRobotDataset
from lerobot.utils.constants import ACTION, OBS_STATE
diff --git a/src/lerobot/teleoperators/__init__.py b/src/lerobot/teleoperators/__init__.py
index ee508dddb..d66e4b67d 100644
--- a/src/lerobot/teleoperators/__init__.py
+++ b/src/lerobot/teleoperators/__init__.py
@@ -17,3 +17,5 @@
from .config import TeleoperatorConfig
from .teleoperator import Teleoperator
from .utils import TeleopEvents, make_teleoperator_from_config
+
+__all__ = ["Teleoperator", "TeleoperatorConfig", "TeleopEvents", "make_teleoperator_from_config"]
diff --git a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py
index b44f1fbea..624729c02 100644
--- a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py
+++ b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py
@@ -17,11 +17,10 @@
import logging
from functools import cached_property
-from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig
from lerobot.types import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
-from ..openarm_leader import OpenArmLeader
+from ..openarm_leader import OpenArmLeader, OpenArmLeaderConfig
from ..teleoperator import Teleoperator
from .config_bi_openarm_leader import BiOpenArmLeaderConfig
diff --git a/src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py b/src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py
index 39fc90add..f7ec929ed 100644
--- a/src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py
+++ b/src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py
@@ -16,9 +16,8 @@
from dataclasses import dataclass
-from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfigBase
-
from ..config import TeleoperatorConfig
+from ..openarm_leader import OpenArmLeaderConfigBase
@TeleoperatorConfig.register_subclass("bi_openarm_leader")
diff --git a/src/lerobot/teleoperators/bi_so_leader/__init__.py b/src/lerobot/teleoperators/bi_so_leader/__init__.py
index b902270f9..cf78beb0c 100644
--- a/src/lerobot/teleoperators/bi_so_leader/__init__.py
+++ b/src/lerobot/teleoperators/bi_so_leader/__init__.py
@@ -15,3 +15,5 @@
# limitations under the License.
from .bi_so_leader import BiSOLeader, BiSOLeaderConfig
+
+__all__ = ["BiSOLeader", "BiSOLeaderConfig"]
diff --git a/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py
index e84ac6f50..f2e88d20a 100644
--- a/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py
+++ b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py
@@ -17,10 +17,9 @@
import logging
from functools import cached_property
-from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
-from ..so_leader import SOLeader
+from ..so_leader import SOLeader, SOLeaderTeleopConfig
from ..teleoperator import Teleoperator
from .config_bi_so_leader import BiSOLeaderConfig
diff --git a/src/lerobot/teleoperators/bi_so_leader/config_bi_so_leader.py b/src/lerobot/teleoperators/bi_so_leader/config_bi_so_leader.py
index c2f23c617..f477d0f26 100644
--- a/src/lerobot/teleoperators/bi_so_leader/config_bi_so_leader.py
+++ b/src/lerobot/teleoperators/bi_so_leader/config_bi_so_leader.py
@@ -16,9 +16,8 @@
from dataclasses import dataclass
-from lerobot.teleoperators.so_leader import SOLeaderConfig
-
from ..config import TeleoperatorConfig
+from ..so_leader import SOLeaderConfig
@TeleoperatorConfig.register_subclass("bi_so_leader")
diff --git a/src/lerobot/teleoperators/gamepad/__init__.py b/src/lerobot/teleoperators/gamepad/__init__.py
index 6f9f7fbd9..3c2709dc7 100644
--- a/src/lerobot/teleoperators/gamepad/__init__.py
+++ b/src/lerobot/teleoperators/gamepad/__init__.py
@@ -16,3 +16,5 @@
from .configuration_gamepad import GamepadTeleopConfig
from .teleop_gamepad import GamepadTeleop
+
+__all__ = ["GamepadTeleop", "GamepadTeleopConfig"]
diff --git a/src/lerobot/teleoperators/homunculus/__init__.py b/src/lerobot/teleoperators/homunculus/__init__.py
index b3c6c0bf5..ee1544e4c 100644
--- a/src/lerobot/teleoperators/homunculus/__init__.py
+++ b/src/lerobot/teleoperators/homunculus/__init__.py
@@ -18,3 +18,11 @@ from .config_homunculus import HomunculusArmConfig, HomunculusGloveConfig
from .homunculus_arm import HomunculusArm
from .homunculus_glove import HomunculusGlove
from .joints_translation import homunculus_glove_to_hope_jr_hand
+
+__all__ = [
+ "HomunculusArm",
+ "HomunculusArmConfig",
+ "HomunculusGlove",
+ "HomunculusGloveConfig",
+ "homunculus_glove_to_hope_jr_hand",
+]
diff --git a/src/lerobot/teleoperators/homunculus/homunculus_arm.py b/src/lerobot/teleoperators/homunculus/homunculus_arm.py
index 178eed544..225235b59 100644
--- a/src/lerobot/teleoperators/homunculus/homunculus_arm.py
+++ b/src/lerobot/teleoperators/homunculus/homunculus_arm.py
@@ -18,11 +18,16 @@ import logging
import threading
from collections import deque
from pprint import pformat
-
-import serial
+from typing import TYPE_CHECKING
from lerobot.motors.motors_bus import MotorCalibration, MotorNormMode
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
+from lerobot.utils.import_utils import _serial_available, require_package
+
+if TYPE_CHECKING or _serial_available:
+ import serial
+else:
+ serial = None # type: ignore[assignment]
from lerobot.utils.utils import enter_pressed, move_cursor_up
from ..teleoperator import Teleoperator
@@ -40,6 +45,7 @@ class HomunculusArm(Teleoperator):
name = "homunculus_arm"
def __init__(self, config: HomunculusArmConfig):
+ require_package("pyserial", extra="hardware", import_name="serial")
super().__init__(config)
self.config = config
self.serial = serial.Serial(config.port, config.baud_rate, timeout=1)
diff --git a/src/lerobot/teleoperators/homunculus/homunculus_glove.py b/src/lerobot/teleoperators/homunculus/homunculus_glove.py
index c4393d660..655bae726 100644
--- a/src/lerobot/teleoperators/homunculus/homunculus_glove.py
+++ b/src/lerobot/teleoperators/homunculus/homunculus_glove.py
@@ -18,17 +18,22 @@ import logging
import threading
from collections import deque
from pprint import pformat
-
-import serial
+from typing import TYPE_CHECKING
from lerobot.motors import MotorCalibration
from lerobot.motors.motors_bus import MotorNormMode
-from lerobot.teleoperators.homunculus.joints_translation import homunculus_glove_to_hope_jr_hand
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
+from lerobot.utils.import_utils import _serial_available, require_package
+
+if TYPE_CHECKING or _serial_available:
+ import serial
+else:
+ serial = None # type: ignore[assignment]
from lerobot.utils.utils import enter_pressed, move_cursor_up
from ..teleoperator import Teleoperator
from .config_homunculus import HomunculusGloveConfig
+from .joints_translation import homunculus_glove_to_hope_jr_hand
logger = logging.getLogger(__name__)
@@ -66,6 +71,7 @@ class HomunculusGlove(Teleoperator):
name = "homunculus_glove"
def __init__(self, config: HomunculusGloveConfig):
+ require_package("pyserial", extra="hardware", import_name="serial")
super().__init__(config)
self.config = config
self.serial = serial.Serial(config.port, config.baud_rate, timeout=1)
diff --git a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py
index 090aa7fae..0f1c7d7f1 100644
--- a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py
+++ b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py
@@ -23,6 +23,7 @@ from typing import Any
from lerobot.types import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
+from lerobot.utils.import_utils import _pynput_available
from ..teleoperator import Teleoperator
from ..utils import TeleopEvents
@@ -32,20 +33,18 @@ from .configuration_keyboard import (
KeyboardTeleopConfig,
)
-PYNPUT_AVAILABLE = True
-try:
- if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
- logging.info("No DISPLAY set. Skipping pynput import.")
- raise ImportError("pynput blocked intentionally due to no display.")
-
- from pynput import keyboard
-except ImportError:
- keyboard = None
- PYNPUT_AVAILABLE = False
-except Exception as e:
- keyboard = None
- PYNPUT_AVAILABLE = False
- logging.info(f"Could not import pynput: {e}")
+PYNPUT_AVAILABLE = _pynput_available
+keyboard = None
+if PYNPUT_AVAILABLE:
+ try:
+ if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
+ logging.info("No DISPLAY set. Skipping pynput import.")
+ PYNPUT_AVAILABLE = False
+ else:
+ from pynput import keyboard
+ except Exception as e:
+ PYNPUT_AVAILABLE = False
+ logging.info(f"Could not import pynput: {e}")
class KeyboardTeleop(Teleoperator):
diff --git a/src/lerobot/teleoperators/koch_leader/__init__.py b/src/lerobot/teleoperators/koch_leader/__init__.py
index 1bf9d51db..7176649ec 100644
--- a/src/lerobot/teleoperators/koch_leader/__init__.py
+++ b/src/lerobot/teleoperators/koch_leader/__init__.py
@@ -16,3 +16,5 @@
from .config_koch_leader import KochLeaderConfig
from .koch_leader import KochLeader
+
+__all__ = ["KochLeader", "KochLeaderConfig"]
diff --git a/src/lerobot/teleoperators/omx_leader/__init__.py b/src/lerobot/teleoperators/omx_leader/__init__.py
index 04d96d63e..259e26143 100644
--- a/src/lerobot/teleoperators/omx_leader/__init__.py
+++ b/src/lerobot/teleoperators/omx_leader/__init__.py
@@ -16,3 +16,5 @@
from .config_omx_leader import OmxLeaderConfig
from .omx_leader import OmxLeader
+
+__all__ = ["OmxLeader", "OmxLeaderConfig"]
diff --git a/src/lerobot/teleoperators/phone/__init__.py b/src/lerobot/teleoperators/phone/__init__.py
index 2b28c1f97..2656a5014 100644
--- a/src/lerobot/teleoperators/phone/__init__.py
+++ b/src/lerobot/teleoperators/phone/__init__.py
@@ -16,3 +16,5 @@
from .config_phone import PhoneConfig
from .teleop_phone import Phone
+
+__all__ = ["Phone", "PhoneConfig"]
diff --git a/src/lerobot/teleoperators/phone/phone_processor.py b/src/lerobot/teleoperators/phone/phone_processor.py
index c498bed7d..3d57a5a71 100644
--- a/src/lerobot/teleoperators/phone/phone_processor.py
+++ b/src/lerobot/teleoperators/phone/phone_processor.py
@@ -16,11 +16,12 @@
from dataclasses import dataclass, field
-from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
+from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import ProcessorStepRegistry, RobotActionProcessorStep
-from lerobot.teleoperators.phone.config_phone import PhoneOS
from lerobot.types import RobotAction
+from .config_phone import PhoneOS
+
@ProcessorStepRegistry.register("map_phone_action_to_robot_action")
@dataclass
diff --git a/src/lerobot/teleoperators/phone/teleop_phone.py b/src/lerobot/teleoperators/phone/teleop_phone.py
index 221ee8083..f68843194 100644
--- a/src/lerobot/teleoperators/phone/teleop_phone.py
+++ b/src/lerobot/teleoperators/phone/teleop_phone.py
@@ -26,11 +26,12 @@ import hebi
import numpy as np
from teleop import Teleop
-from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
-from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.rotation import Rotation
+from ..teleoperator import Teleoperator
+from .config_phone import PhoneConfig, PhoneOS
+
logger = logging.getLogger(__name__)
diff --git a/src/lerobot/teleoperators/reachy2_teleoperator/__init__.py b/src/lerobot/teleoperators/reachy2_teleoperator/__init__.py
index a07a4a6cd..aab1aec14 100644
--- a/src/lerobot/teleoperators/reachy2_teleoperator/__init__.py
+++ b/src/lerobot/teleoperators/reachy2_teleoperator/__init__.py
@@ -23,3 +23,13 @@ from .reachy2_teleoperator import (
REACHY2_VEL,
Reachy2Teleoperator,
)
+
+__all__ = [
+ "REACHY2_ANTENNAS_JOINTS",
+ "REACHY2_L_ARM_JOINTS",
+ "REACHY2_NECK_JOINTS",
+ "REACHY2_R_ARM_JOINTS",
+ "REACHY2_VEL",
+ "Reachy2Teleoperator",
+ "Reachy2TeleoperatorConfig",
+]
diff --git a/src/lerobot/teleoperators/so_leader/__init__.py b/src/lerobot/teleoperators/so_leader/__init__.py
index e5aaa31b6..26ef66677 100644
--- a/src/lerobot/teleoperators/so_leader/__init__.py
+++ b/src/lerobot/teleoperators/so_leader/__init__.py
@@ -21,3 +21,13 @@ from .config_so_leader import (
SOLeaderTeleopConfig,
)
from .so_leader import SO100Leader, SO101Leader, SOLeader
+
+__all__ = [
+ "SO100Leader",
+ "SO100LeaderConfig",
+ "SO101Leader",
+ "SO101LeaderConfig",
+ "SOLeader",
+ "SOLeaderConfig",
+ "SOLeaderTeleopConfig",
+]
diff --git a/src/lerobot/teleoperators/unitree_g1/exo_calib.py b/src/lerobot/teleoperators/unitree_g1/exo_calib.py
index b90e8fd7e..05f5180ff 100644
--- a/src/lerobot/teleoperators/unitree_g1/exo_calib.py
+++ b/src/lerobot/teleoperators/unitree_g1/exo_calib.py
@@ -22,15 +22,24 @@ and calculate arctan2 of the unit circle to get the joint angle.
We then store the ellipse parameters and the zero offset for each joint to be used at runtime.
"""
+from __future__ import annotations
+
import json
import logging
import time
from collections import deque
from dataclasses import dataclass, field
from pathlib import Path
+from typing import TYPE_CHECKING
import numpy as np
-import serial
+
+from lerobot.utils.import_utils import _serial_available
+
+if TYPE_CHECKING or _serial_available:
+ import serial
+else:
+ serial = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
@@ -82,7 +91,7 @@ class ExoskeletonCalibration:
}
@classmethod
- def from_dict(cls, data: dict) -> "ExoskeletonCalibration":
+ def from_dict(cls, data: dict) -> ExoskeletonCalibration:
joints = [
ExoskeletonJointCalibration(
name=j["name"],
diff --git a/src/lerobot/teleoperators/unitree_g1/exo_serial.py b/src/lerobot/teleoperators/unitree_g1/exo_serial.py
index 4f45997c0..9b1c71891 100644
--- a/src/lerobot/teleoperators/unitree_g1/exo_serial.py
+++ b/src/lerobot/teleoperators/unitree_g1/exo_serial.py
@@ -14,12 +14,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations
+
import json
import logging
from dataclasses import dataclass
from pathlib import Path
+from typing import TYPE_CHECKING
-import serial
+from lerobot.utils.import_utils import _serial_available, require_package
+
+if TYPE_CHECKING or _serial_available:
+ import serial
+else:
+ serial = None # type: ignore[assignment]
from .exo_calib import ExoskeletonCalibration, exo_raw_to_angles, run_exo_calibration
@@ -68,6 +76,7 @@ class ExoskeletonArm:
calibration: ExoskeletonCalibration | None = None
def __post_init__(self):
+ require_package("pyserial", extra="hardware", import_name="serial")
if self.calibration_fpath.is_file():
self._load_calibration()
diff --git a/src/lerobot/transforms/__init__.py b/src/lerobot/transforms/__init__.py
new file mode 100644
index 000000000..6cf9699d0
--- /dev/null
+++ b/src/lerobot/transforms/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .transforms import (
+ ImageTransformConfig,
+ ImageTransforms,
+ ImageTransformsConfig,
+ RandomSubsetApply,
+ SharpnessJitter,
+ make_transform_from_config,
+)
+
+__all__ = [
+ "ImageTransformConfig",
+ "ImageTransforms",
+ "ImageTransformsConfig",
+ "RandomSubsetApply",
+ "SharpnessJitter",
+ "make_transform_from_config",
+]
diff --git a/src/lerobot/datasets/transforms.py b/src/lerobot/transforms/transforms.py
similarity index 100%
rename from src/lerobot/datasets/transforms.py
rename to src/lerobot/transforms/transforms.py
diff --git a/src/lerobot/transport/__init__.py b/src/lerobot/transport/__init__.py
new file mode 100644
index 000000000..92ed74188
--- /dev/null
+++ b/src/lerobot/transport/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+gRPC transport layer for async inference.
+
+Requires: ``pip install 'lerobot[grpcio-dep]'``
+
+Available modules (import directly)::
+
+ from lerobot.transport.utils import ...
+"""
+
+from lerobot.utils.import_utils import require_package
+
+require_package("grpcio", extra="grpcio-dep", import_name="grpc")
+
+__all__: list[str] = []
diff --git a/src/lerobot/utils/__init__.py b/src/lerobot/utils/__init__.py
new file mode 100644
index 000000000..ee4808353
--- /dev/null
+++ b/src/lerobot/utils/__init__.py
@@ -0,0 +1,65 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Public API for lightweight, base-dependency-only utilities.
+
+Heavy cross-cutting modules (train_utils, control_utils) have been moved
+to ``lerobot.common``. ``visualization_utils`` remains here but is
+intentionally NOT re-exported to avoid pulling in optional dependencies.
+"""
+
+from .constants import (
+ ACTION,
+ DEFAULT_FEATURES,
+ DONE,
+ IMAGENET_STATS,
+ OBS_ENV_STATE,
+ OBS_IMAGE,
+ OBS_IMAGES,
+ OBS_STATE,
+ OBS_STR,
+ REWARD,
+)
+from .decorators import check_if_already_connected, check_if_not_connected
+from .device_utils import auto_select_torch_device, get_safe_torch_device, is_torch_device_available
+from .errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
+from .import_utils import is_package_available, require_package
+
+__all__ = [
+ # Constants
+ "ACTION",
+ "DEFAULT_FEATURES",
+ "DONE",
+ "IMAGENET_STATS",
+ "OBS_ENV_STATE",
+ "OBS_IMAGE",
+ "OBS_IMAGES",
+ "OBS_STATE",
+ "OBS_STR",
+ "REWARD",
+ # Device utilities
+ "auto_select_torch_device",
+ "get_safe_torch_device",
+ "is_torch_device_available",
+ # Import guards
+ "is_package_available",
+ "require_package",
+ # Decorators
+ "check_if_already_connected",
+ "check_if_not_connected",
+ # Errors
+ "DeviceAlreadyConnectedError",
+ "DeviceNotConnectedError",
+]
diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py
index fd10cab35..43869228d 100644
--- a/src/lerobot/utils/constants.py
+++ b/src/lerobot/utils/constants.py
@@ -75,6 +75,21 @@ default_calibration_path = HF_LEROBOT_HOME / "calibration"
HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()
+# Dataset meta-features (auto-populated by the recording pipeline)
+DEFAULT_FEATURES = {
+ "timestamp": {"dtype": "float32", "shape": (1,), "names": None},
+ "frame_index": {"dtype": "int64", "shape": (1,), "names": None},
+ "episode_index": {"dtype": "int64", "shape": (1,), "names": None},
+ "index": {"dtype": "int64", "shape": (1,), "names": None},
+ "task_index": {"dtype": "int64", "shape": (1,), "names": None},
+}
+
+# ImageNet normalization constants
+IMAGENET_STATS = {
+ "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
+ "std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
+}
+
# streaming datasets
LOOKBACK_BACKTRACKTABLE = 100
LOOKAHEAD_BACKTRACKTABLE = 100
diff --git a/src/lerobot/utils/decorators.py b/src/lerobot/utils/decorators.py
index 8fc2f9a07..75171f637 100644
--- a/src/lerobot/utils/decorators.py
+++ b/src/lerobot/utils/decorators.py
@@ -16,7 +16,7 @@
from functools import wraps
-from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
+from .errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
def check_if_not_connected(func):
diff --git a/src/lerobot/utils/feature_utils.py b/src/lerobot/utils/feature_utils.py
new file mode 100644
index 000000000..2a4886234
--- /dev/null
+++ b/src/lerobot/utils/feature_utils.py
@@ -0,0 +1,223 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Lightweight feature-manipulation utilities.
+
+These functions are intentionally kept free of heavy dependencies (e.g. the
+HuggingFace ``datasets`` library) so that they can be imported from anywhere
+in the codebase – including modules that are part of the *minimal* install –
+without triggering the ``lerobot.datasets`` package guard.
+"""
+
+from typing import Any
+
+import numpy as np
+
+from lerobot.configs import FeatureType, PolicyFeature
+
+from .constants import ACTION, DEFAULT_FEATURES, OBS_ENV_STATE, OBS_STR
+
+
+def _validate_feature_names(features: dict[str, dict]) -> None:
+ """Validate that feature names do not contain invalid characters.
+
+ Args:
+ features (dict): The LeRobot features dictionary.
+
+ Raises:
+ ValueError: If any feature name contains '/'.
+ """
+ invalid_features = {name: ft for name, ft in features.items() if "/" in name}
+ if invalid_features:
+ raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.")
+
+
+def hw_to_dataset_features(
+ hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
+) -> dict[str, dict]:
+ """Convert hardware-specific features to a LeRobot dataset feature dictionary.
+
+ This function takes a dictionary describing hardware outputs (like joint states
+ or camera image shapes) and formats it into the standard LeRobot feature
+ specification.
+
+ Args:
+ hw_features (dict): Dictionary mapping feature names to their type (float for
+ joints) or shape (tuple for images).
+ prefix (str): The prefix to add to the feature keys (e.g., "observation"
+ or "action").
+ use_video (bool): If True, image features are marked as "video", otherwise "image".
+
+ Returns:
+ dict: A LeRobot features dictionary.
+ """
+ features = {}
+ joint_fts = {
+ key: ftype
+ for key, ftype in hw_features.items()
+ if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
+ }
+ cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
+
+ if joint_fts and prefix == ACTION:
+ features[prefix] = {
+ "dtype": "float32",
+ "shape": (len(joint_fts),),
+ "names": list(joint_fts),
+ }
+
+ if joint_fts and prefix == OBS_STR:
+ features[f"{prefix}.state"] = {
+ "dtype": "float32",
+ "shape": (len(joint_fts),),
+ "names": list(joint_fts),
+ }
+
+ for key, shape in cam_fts.items():
+ features[f"{prefix}.images.{key}"] = {
+ "dtype": "video" if use_video else "image",
+ "shape": shape,
+ "names": ["height", "width", "channels"],
+ }
+
+ _validate_feature_names(features)
+ return features
+
+
+def build_dataset_frame(
+ ds_features: dict[str, dict], values: dict[str, Any], prefix: str
+) -> dict[str, np.ndarray]:
+ """Construct a single data frame from raw values based on dataset features.
+
+ A "frame" is a dictionary containing all the data for a single timestep,
+ formatted as numpy arrays according to the feature specification.
+
+ Args:
+ ds_features (dict): The LeRobot dataset features dictionary.
+ values (dict): A dictionary of raw values from the hardware/environment.
+ prefix (str): The prefix to filter features by (e.g., "observation"
+ or "action").
+
+ Returns:
+ dict: A dictionary representing a single frame of data.
+ """
+ frame = {}
+ for key, ft in ds_features.items():
+ if key in DEFAULT_FEATURES or not key.startswith(prefix):
+ continue
+ elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
+ frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
+ elif ft["dtype"] in ["image", "video"]:
+ frame[key] = values[key.removeprefix(f"{prefix}.images.")]
+
+ return frame
+
+
+def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
+ """Convert dataset features to policy features.
+
+ This function transforms the dataset's feature specification into a format
+ that a policy can use, classifying features by type (e.g., visual, state,
+ action) and ensuring correct shapes (e.g., channel-first for images).
+
+ Args:
+ features (dict): The LeRobot dataset features dictionary.
+
+ Returns:
+ dict: A dictionary mapping feature keys to `PolicyFeature` objects.
+
+ Raises:
+ ValueError: If an image feature does not have a 3D shape.
+ """
+ # TODO(aliberts): Implement "type" in dataset features and simplify this
+ policy_features = {}
+ for key, ft in features.items():
+ shape = ft["shape"]
+ if ft["dtype"] in ["image", "video"]:
+ type = FeatureType.VISUAL
+ if len(shape) != 3:
+ raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
+
+ names = ft["names"]
+ # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
+ if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
+ shape = (shape[2], shape[0], shape[1])
+ elif key == OBS_ENV_STATE:
+ type = FeatureType.ENV
+ elif key.startswith(OBS_STR):
+ type = FeatureType.STATE
+ elif key.startswith(ACTION):
+ type = FeatureType.ACTION
+ else:
+ continue
+
+ policy_features[key] = PolicyFeature(
+ type=type,
+ shape=shape,
+ )
+
+ return policy_features
+
+
+def combine_feature_dicts(*dicts: dict) -> dict:
+ """Merge LeRobot grouped feature dicts.
+
+ - For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape.
+ - For others (e.g. `observation.images.*`), the last one wins (if they are identical).
+
+ Args:
+ *dicts: A variable number of LeRobot feature dictionaries to merge.
+
+ Returns:
+ dict: A single merged feature dictionary.
+
+ Raises:
+ ValueError: If there's a dtype mismatch for a feature being merged.
+ """
+ out: dict = {}
+ for d in dicts:
+ for key, value in d.items():
+ if not isinstance(value, dict):
+ out[key] = value
+ continue
+
+ dtype = value.get("dtype")
+ shape = value.get("shape")
+ is_vector = (
+ dtype not in ("image", "video", "string")
+ and isinstance(shape, tuple)
+ and len(shape) == 1
+ and "names" in value
+ )
+
+ if is_vector:
+ # Initialize or retrieve the accumulating dict for this feature key
+ target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)})
+ # Ensure consistent data types across merged entries
+ if "dtype" in target and dtype != target["dtype"]:
+ raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}")
+
+ # Merge feature names: append only new ones to preserve order without duplicates
+ seen = set(target["names"])
+ for n in value["names"]:
+ if n not in seen:
+ target["names"].append(n)
+ seen.add(n)
+ # Recompute the shape to reflect the updated number of features
+ target["shape"] = (len(target["names"]),)
+ else:
+ # For images/videos and non-1D entries: override with the latest definition
+ out[key] = value
+ return out
diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py
index 2b26b2302..8cd24b0fa 100644
--- a/src/lerobot/utils/import_utils.py
+++ b/src/lerobot/utils/import_utils.py
@@ -69,13 +69,64 @@ def is_package_available(
return package_exists
+def get_safe_default_codec():
+ logger = logging.getLogger(__name__)
+ if importlib.util.find_spec("torchcodec"):
+ return "torchcodec"
+ else:
+ logger.warning(
+ "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
+ )
+ return "pyav"
+
+
+_require_package_cache: dict[str, bool] = {}
+
+
+def require_package(pkg_name: str, extra: str, import_name: str | None = None) -> None:
+ """Raise an informative ImportError if a package required by an optional feature is missing."""
+ cache_key = import_name or pkg_name
+ if cache_key not in _require_package_cache:
+ _require_package_cache[cache_key] = is_package_available(pkg_name, import_name)
+ if not _require_package_cache[cache_key]:
+ raise ImportError(
+ f"'{pkg_name}' is required but not installed. Install it with: "
+ f"pip install 'lerobot[{extra}]' (or uv pip install 'lerobot[{extra}]')"
+ )
+
+
+# ── Centralised availability flags ────────────────────────────────────────
+# Every optional-dependency check lives here so that the rest of the codebase
+# can simply ``from lerobot.utils.import_utils import _foo_available``.
+# Do NOT define ad-hoc ``is_package_available(...)`` calls in other modules.
+
+# ML / training
_transformers_available = is_package_available("transformers")
_peft_available = is_package_available("peft")
_scipy_available = is_package_available("scipy")
+_diffusers_available = is_package_available("diffusers")
+_torchdiffeq_available = is_package_available("torchdiffeq")
+
+# Hardware SDKs
+_serial_available = is_package_available("pyserial", import_name="serial")
+_deepdiff_available = is_package_available("deepdiff")
+_dynamixel_sdk_available = is_package_available("dynamixel-sdk", import_name="dynamixel_sdk")
+_feetech_sdk_available = is_package_available("feetech-servo-sdk", import_name="scservo_sdk")
_reachy2_sdk_available = is_package_available("reachy2_sdk")
_can_available = is_package_available("python-can", "can")
_unitree_sdk_available = is_package_available("unitree-sdk2py", "unitree_sdk2py")
+
+# Data / serialization
+_pandas_available = is_package_available("pandas")
+_faker_available = is_package_available("faker")
+
+# Misc
+_pynput_available = is_package_available("pynput")
_pygame_available = is_package_available("pygame")
+_qwen_vl_utils_available = is_package_available("qwen-vl-utils", import_name="qwen_vl_utils")
+_wallx_deps_available = (
+ _transformers_available and _peft_available and _torchdiffeq_available and _qwen_vl_utils_available
+)
def make_device_from_device_class(config: ChoiceRegistry) -> Any:
diff --git a/src/lerobot/utils/io_utils.py b/src/lerobot/utils/io_utils.py
index d70ea8b6a..e037b412c 100644
--- a/src/lerobot/utils/io_utils.py
+++ b/src/lerobot/utils/io_utils.py
@@ -14,21 +14,80 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
-import warnings
+import logging
from pathlib import Path
+from typing import Any
-import imageio
+logger = logging.getLogger(__name__)
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
-def write_video(video_path, stacked_frames, fps):
- # Filter out DeprecationWarnings raised from pkg_resources
- with warnings.catch_warnings():
- warnings.filterwarnings(
- "ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
- )
- imageio.mimsave(video_path, stacked_frames, fps=fps)
+def load_json(fpath: Path) -> Any:
+ """Load data from a JSON file.
+
+ Args:
+ fpath (Path): Path to the JSON file.
+
+ Returns:
+ Any: The data loaded from the JSON file.
+ """
+ with open(fpath) as f:
+ return json.load(f)
+
+
+def write_json(data: dict, fpath: Path) -> None:
+ """Write data to a JSON file.
+
+ Creates parent directories if they don't exist.
+
+ Args:
+ data (dict): The dictionary to write.
+ fpath (Path): The path to the output JSON file.
+ """
+ fpath.parent.mkdir(exist_ok=True, parents=True)
+ with open(fpath, "w") as f:
+ json.dump(data, f, indent=4, ensure_ascii=False)
+
+
+def write_video(video_path: str | Path, stacked_frames: list, fps: int) -> None:
+ """Write a sequence of RGB frames to an MP4 video file using libx264.
+
+ Args:
+ video_path: Output file path.
+ stacked_frames: List of HWC uint8 numpy arrays (RGB).
+ fps: Frames per second for the output video.
+ """
+ from .import_utils import require_package
+
+ require_package("av", extra="av-dep")
+ import av
+
+ with av.open(str(video_path), mode="w") as container:
+ orig_height, orig_width = stacked_frames[0].shape[:2]
+ # yuv420p requires even dimensions; crop by one pixel if needed
+ height = orig_height if orig_height % 2 == 0 else orig_height - 1
+ width = orig_width if orig_width % 2 == 0 else orig_width - 1
+ if height != orig_height or width != orig_width:
+ logger.warning(
+ "Frame dimensions %dx%d are not even; cropping to %dx%d for yuv420p compatibility.",
+ orig_width,
+ orig_height,
+ width,
+ height,
+ )
+ stream = container.add_stream("libx264", rate=fps)
+ stream.width = width
+ stream.height = height
+ stream.pix_fmt = "yuv420p"
+ for frame_array in stacked_frames:
+ if height != orig_height or width != orig_width:
+ frame_array = frame_array[:height, :width]
+ frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
+ for packet in stream.encode(frame):
+ container.mux(packet)
+ for packet in stream.encode():
+ container.mux(packet)
def deserialize_json_into_object[T: JsonLike](fpath: Path, obj: T) -> T:
diff --git a/src/lerobot/utils/logging_utils.py b/src/lerobot/utils/logging_utils.py
index 1497c0585..0ce596f55 100644
--- a/src/lerobot/utils/logging_utils.py
+++ b/src/lerobot/utils/logging_utils.py
@@ -16,7 +16,7 @@
from collections.abc import Callable
from typing import Any
-from lerobot.utils.utils import format_big_number
+from .utils import format_big_number
class AverageMeter:
diff --git a/src/lerobot/utils/random_utils.py b/src/lerobot/utils/random_utils.py
index b34d357aa..e280fc342 100644
--- a/src/lerobot/utils/random_utils.py
+++ b/src/lerobot/utils/random_utils.py
@@ -23,8 +23,8 @@ import numpy as np
import torch
from safetensors.torch import load_file, save_file
-from lerobot.datasets.utils import flatten_dict, unflatten_dict
-from lerobot.utils.constants import RNG_STATE
+from .constants import RNG_STATE
+from .utils import flatten_dict, unflatten_dict
def serialize_python_rng_state() -> dict[str, torch.Tensor]:
diff --git a/src/lerobot/utils/transition.py b/src/lerobot/utils/transition.py
index fe3620861..a79b95151 100644
--- a/src/lerobot/utils/transition.py
+++ b/src/lerobot/utils/transition.py
@@ -18,7 +18,7 @@ from typing import TypedDict
import torch
-from lerobot.utils.constants import ACTION
+from .constants import ACTION
class Transition(TypedDict):
diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py
index f6aa93bea..2574f1fa3 100644
--- a/src/lerobot/utils/utils.py
+++ b/src/lerobot/utils/utils.py
@@ -22,11 +22,12 @@ import select
import subprocess
import sys
import time
+from collections.abc import Iterator
from copy import copy, deepcopy
from datetime import datetime
from pathlib import Path
from statistics import mean
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
import numpy as np
@@ -199,6 +200,80 @@ def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float):
return days, hours, minutes, seconds
+def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
+ """Flatten a nested dictionary by joining keys with a separator.
+
+ Example:
+ >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
+ >>> print(flatten_dict(dct))
+ {'a/b': 1, 'a/c/d': 2, 'e': 3}
+
+ Args:
+ d (dict): The dictionary to flatten.
+ parent_key (str): The base key to prepend to the keys in this level.
+ sep (str): The separator to use between keys.
+
+ Returns:
+ dict: A flattened dictionary.
+ """
+ items = []
+ for k, v in d.items():
+ new_key = f"{parent_key}{sep}{k}" if parent_key else k
+ if isinstance(v, dict):
+ items.extend(flatten_dict(v, new_key, sep=sep).items())
+ else:
+ items.append((new_key, v))
+ return dict(items)
+
+
+def unflatten_dict(d: dict, sep: str = "/") -> dict:
+ """Unflatten a dictionary with delimited keys into a nested dictionary.
+
+ Example:
+ >>> flat_dct = {"a/b": 1, "a/c/d": 2, "e": 3}
+ >>> print(unflatten_dict(flat_dct))
+ {'a': {'b': 1, 'c': {'d': 2}}, 'e': 3}
+
+ Args:
+ d (dict): A dictionary with flattened keys.
+ sep (str): The separator used in the keys.
+
+ Returns:
+ dict: A nested dictionary.
+ """
+ outdict = {}
+ for key, value in d.items():
+ parts = key.split(sep)
+ d_inner = outdict
+ for part in parts[:-1]:
+ if part not in d_inner:
+ d_inner[part] = {}
+ d_inner = d_inner[part]
+ d_inner[parts[-1]] = value
+ return outdict
+
+
+def cycle(iterable: Any) -> Iterator[Any]:
+ """Create a dataloader-safe cyclical iterator.
+
+ This is an equivalent of `itertools.cycle` but is safe for use with
+ PyTorch DataLoaders with multiple workers.
+ See https://github.com/pytorch/pytorch/issues/23900 for details.
+
+ Args:
+ iterable: The iterable to cycle over.
+
+ Yields:
+ Items from the iterable, restarting from the beginning when exhausted.
+ """
+ iterator = iter(iterable)
+ while True:
+ try:
+ yield next(iterator)
+ except StopIteration:
+ iterator = iter(iterable)
+
+
class SuppressProgressBars:
"""
Context manager to suppress progress bars.
@@ -212,14 +287,22 @@ class SuppressProgressBars:
"""
def __enter__(self):
- from datasets.utils.logging import disable_progress_bar
+ try:
+ from datasets.utils.logging import disable_progress_bar
- disable_progress_bar()
+ disable_progress_bar()
+ except ImportError:
+ logging.getLogger(__name__).debug(
+ "SuppressProgressBars is a no-op because 'datasets' is not installed."
+ )
def __exit__(self, exc_type, exc_val, exc_tb):
- from datasets.utils.logging import enable_progress_bar
+ try:
+ from datasets.utils.logging import enable_progress_bar
- enable_progress_bar()
+ enable_progress_bar()
+ except ImportError:
+ pass
class TimerManager:
diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py
index 782358c9e..d9d5bf6b5 100644
--- a/src/lerobot/utils/visualization_utils.py
+++ b/src/lerobot/utils/visualization_utils.py
@@ -16,11 +16,11 @@ import numbers
import os
import numpy as np
-import rerun as rr
from lerobot.types import RobotAction, RobotObservation
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
+from .import_utils import require_package
def init_rerun(
@@ -34,6 +34,10 @@ def init_rerun(
ip: Optional IP for connecting to a Rerun server.
port: Optional port for connecting to a Rerun server.
"""
+
+ require_package("rerun-sdk", extra="viz", import_name="rerun")
+ import rerun as rr
+
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
rr.init(session_name)
@@ -44,6 +48,15 @@ def init_rerun(
rr.spawn(memory_limit=memory_limit)
+def shutdown_rerun() -> None:
+ """Shuts down the Rerun SDK gracefully."""
+
+ require_package("rerun-sdk", extra="viz", import_name="rerun")
+ import rerun as rr
+
+ rr.rerun_shutdown()
+
+
def _is_scalar(x):
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
isinstance(x, np.ndarray) and x.ndim == 0
@@ -73,6 +86,10 @@ def log_rerun_data(
action: An optional dictionary containing action data to log.
compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality.
"""
+
+ require_package("rerun-sdk", extra="viz", import_name="rerun")
+ import rerun as rr
+
if observation:
for k, v in observation.items():
if v is None:
diff --git a/tests/artifacts/image_transforms/save_image_transforms_to_safetensors.py b/tests/artifacts/image_transforms/save_image_transforms_to_safetensors.py
index ce15d16fd..182058563 100644
--- a/tests/artifacts/image_transforms/save_image_transforms_to_safetensors.py
+++ b/tests/artifacts/image_transforms/save_image_transforms_to_safetensors.py
@@ -19,7 +19,7 @@ import torch
from safetensors.torch import save_file
from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.datasets.transforms import (
+from lerobot.transforms import (
ImageTransformConfig,
ImageTransforms,
ImageTransformsConfig,
diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py
index 7359f6169..ffb3efd03 100644
--- a/tests/artifacts/policies/save_policy_to_safetensors.py
+++ b/tests/artifacts/policies/save_policy_to_safetensors.py
@@ -21,7 +21,7 @@ from safetensors.torch import save_file
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
-from lerobot.datasets.factory import make_dataset
+from lerobot.datasets import make_dataset
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors
from lerobot.utils.constants import OBS_STR
diff --git a/tests/async_inference/test_e2e.py b/tests/async_inference/test_e2e.py
index 54ca29b48..8c5861a91 100644
--- a/tests/async_inference/test_e2e.py
+++ b/tests/async_inference/test_e2e.py
@@ -35,8 +35,10 @@ from concurrent import futures
import pytest
import torch
-# Skip entire module if grpc is not available
+# Skip entire module if required deps are not available
pytest.importorskip("grpc")
+pytest.importorskip("serial", reason="pyserial is required (install lerobot[hardware])")
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
# -----------------------------------------------------------------------------
# End-to-end test
diff --git a/tests/async_inference/test_helpers.py b/tests/async_inference/test_helpers.py
index a9e53200d..17fca2a44 100644
--- a/tests/async_inference/test_helpers.py
+++ b/tests/async_inference/test_helpers.py
@@ -16,10 +16,14 @@ import math
import pickle
import time
-import numpy as np
-import torch
+import pytest
-from lerobot.async_inference.helpers import (
+pytest.importorskip("grpc")
+
+import numpy as np # noqa: E402
+import torch # noqa: E402
+
+from lerobot.async_inference.helpers import ( # noqa: E402
FPSTracker,
TimedAction,
TimedObservation,
diff --git a/tests/async_inference/test_policy_server.py b/tests/async_inference/test_policy_server.py
index c3ee37c8f..5cec2051c 100644
--- a/tests/async_inference/test_policy_server.py
+++ b/tests/async_inference/test_policy_server.py
@@ -24,7 +24,7 @@ import torch
from lerobot.configs.types import PolicyFeature
from lerobot.utils.constants import OBS_STATE
-from tests.utils import require_package
+from tests.utils import skip_if_package_missing
# -----------------------------------------------------------------------------
# Test fixtures
@@ -62,7 +62,7 @@ class MockPolicy:
@pytest.fixture
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def policy_server():
"""Fresh `PolicyServer` instance with a stubbed-out policy model."""
# Import only when the test actually runs (after decorator check)
diff --git a/tests/async_inference/test_robot_client.py b/tests/async_inference/test_robot_client.py
index d7ef5b350..e2d840358 100644
--- a/tests/async_inference/test_robot_client.py
+++ b/tests/async_inference/test_robot_client.py
@@ -25,8 +25,10 @@ from queue import Queue
import pytest
import torch
-# Skip entire module if grpc is not available
+# Skip entire module if required deps are not available
pytest.importorskip("grpc")
+pytest.importorskip("serial", reason="pyserial is required (install lerobot[hardware])")
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
# -----------------------------------------------------------------------------
# Test fixtures
diff --git a/tests/conftest.py b/tests/conftest.py
index 2fcf878ab..cadeaf0d3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -17,24 +17,39 @@
import traceback
import pytest
-from serial import SerialException
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
+from lerobot.utils.import_utils import is_package_available
from tests.utils import DEVICE
-# Import fixture modules as plugins
+# Import fixture modules as plugins.
+# Fixtures that depend on optional packages are only registered when those packages are available,
+# so that tests can be collected and run even with a minimal install.
pytest_plugins = [
- "tests.fixtures.dataset_factories",
- "tests.fixtures.files",
- "tests.fixtures.hub",
"tests.fixtures.optimizers",
]
+if is_package_available("datasets"):
+ pytest_plugins += [
+ "tests.fixtures.dataset_factories",
+ "tests.fixtures.files",
+ "tests.fixtures.hub",
+ ]
+
def pytest_collection_finish():
print(f"\nTesting with {DEVICE=}")
+def _is_serial_exception(exc: Exception) -> bool:
+ """Check if an exception is a SerialException without requiring pyserial."""
+ if not is_package_available("pyserial", import_name="serial"):
+ return False
+ from serial import SerialException
+
+ return isinstance(exc, SerialException)
+
+
def _check_component_availability(component_type, available_components, make_component):
"""Generic helper to check if a hardware component is available"""
if component_type not in available_components:
@@ -53,7 +68,7 @@ def _check_component_availability(component_type, available_components, make_com
if isinstance(e, ModuleNotFoundError):
print(f"\nInstall module '{e.name}'")
- elif isinstance(e, SerialException):
+ elif _is_serial_exception(e):
print("\nNo physical device detected.")
elif isinstance(e, ValueError) and "camera_index" in str(e):
print("\nNo physical camera detected.")
diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py
index 4ac7e001a..b74299311 100644
--- a/tests/datasets/test_aggregate.py
+++ b/tests/datasets/test_aggregate.py
@@ -16,7 +16,11 @@
from unittest.mock import patch
-import datasets
+import pytest
+
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
+import datasets # noqa: E402
import torch
from lerobot.datasets.aggregate import aggregate_datasets
diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py
index 973c80bd8..70ba42378 100644
--- a/tests/datasets/test_compute_stats.py
+++ b/tests/datasets/test_compute_stats.py
@@ -18,6 +18,8 @@ from unittest.mock import patch
import numpy as np
import pytest
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from lerobot.datasets.compute_stats import (
RunningQuantileStats,
_assert_type_and_shape,
diff --git a/tests/datasets/test_dataset_metadata.py b/tests/datasets/test_dataset_metadata.py
index 3f3971e15..6db41d05c 100644
--- a/tests/datasets/test_dataset_metadata.py
+++ b/tests/datasets/test_dataset_metadata.py
@@ -20,6 +20,8 @@ import json
import numpy as np
import pytest
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.utils import INFO_PATH
from tests.fixtures.constants import DEFAULT_FPS, DUMMY_ROBOT_TYPE
diff --git a/tests/datasets/test_dataset_reader.py b/tests/datasets/test_dataset_reader.py
index 4c8a8b23f..bbe858b5d 100644
--- a/tests/datasets/test_dataset_reader.py
+++ b/tests/datasets/test_dataset_reader.py
@@ -15,8 +15,12 @@
# limitations under the License.
"""Contract tests for DatasetReader."""
+import pytest
+
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from lerobot.datasets.dataset_reader import DatasetReader
-from lerobot.datasets.video_utils import get_safe_default_codec
+from lerobot.utils.import_utils import get_safe_default_codec
# ── Loading ──────────────────────────────────────────────────────────
diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py
index 5ed7aa1a3..0b0862f00 100644
--- a/tests/datasets/test_dataset_tools.py
+++ b/tests/datasets/test_dataset_tools.py
@@ -21,6 +21,8 @@ import numpy as np
import pytest
import torch
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from lerobot.datasets.dataset_tools import (
add_features,
delete_episodes,
diff --git a/tests/datasets/test_dataset_utils.py b/tests/datasets/test_dataset_utils.py
index 874099e2b..bf705ba81 100644
--- a/tests/datasets/test_dataset_utils.py
+++ b/tests/datasets/test_dataset_utils.py
@@ -16,13 +16,16 @@
import pytest
import torch
-from datasets import Dataset
+
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
+from datasets import Dataset # noqa: E402
from huggingface_hub import DatasetCard
-from lerobot.datasets.feature_utils import combine_feature_dicts
from lerobot.datasets.io_utils import hf_transform_to_torch
from lerobot.datasets.utils import create_lerobot_dataset_card
from lerobot.utils.constants import ACTION, OBS_IMAGES
+from lerobot.utils.feature_utils import combine_feature_dicts
def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:
diff --git a/tests/datasets/test_dataset_writer.py b/tests/datasets/test_dataset_writer.py
index 8c6ee68bd..8d2bc0373 100644
--- a/tests/datasets/test_dataset_writer.py
+++ b/tests/datasets/test_dataset_writer.py
@@ -23,6 +23,8 @@ import pytest
import torch
from PIL import Image
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from lerobot.datasets.dataset_writer import _encode_video_worker
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py
index d4e9e88b8..6d4c41aaa 100644
--- a/tests/datasets/test_datasets.py
+++ b/tests/datasets/test_datasets.py
@@ -21,21 +21,22 @@ from pathlib import Path
import numpy as np
import pytest
import torch
+
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from huggingface_hub import HfApi
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import v2
-import lerobot
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
-from lerobot.datasets.factory import make_dataset
-from lerobot.datasets.feature_utils import get_hf_features_from_features, hw_to_dataset_features
+from lerobot.datasets import make_dataset
+from lerobot.datasets.feature_utils import get_hf_features_from_features
from lerobot.datasets.image_writer import image_array_to_pil_image
from lerobot.datasets.io_utils import hf_transform_to_torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
-from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
@@ -46,7 +47,9 @@ from lerobot.datasets.video_utils import VALID_VIDEO_CODECS
from lerobot.envs.factory import make_env_config
from lerobot.policies.factory import make_policy_config
from lerobot.robots import make_robot_from_config
+from lerobot.transforms import ImageTransforms, ImageTransformsConfig
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
+from lerobot.utils.feature_utils import hw_to_dataset_features
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
from tests.mocks.mock_robot import MockRobotConfig
from tests.utils import require_x86_64_kernel
@@ -493,13 +496,28 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
# - [ ] remove old tests
+ENV_DATASET_POLICY_TRIPLETS = [
+ ("aloha", dataset, "act")
+ for dataset in [
+ "lerobot/aloha_sim_insertion_human",
+ "lerobot/aloha_sim_insertion_scripted",
+ "lerobot/aloha_sim_transfer_cube_human",
+ "lerobot/aloha_sim_transfer_cube_scripted",
+ "lerobot/aloha_sim_insertion_human_image",
+ "lerobot/aloha_sim_insertion_scripted_image",
+ "lerobot/aloha_sim_transfer_cube_human_image",
+ "lerobot/aloha_sim_transfer_cube_scripted_image",
+ ]
+] + [
+ ("pusht", dataset, policy)
+ for dataset in ["lerobot/pusht", "lerobot/pusht_image"]
+ for policy in ["diffusion", "vqbet"]
+]
+
+
@pytest.mark.parametrize(
"env_name, repo_id, policy_name",
- # Single dataset
- lerobot.env_dataset_policy_triplets,
- # Multi-dataset
- # TODO after fix multidataset
- # + [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
+ ENV_DATASET_POLICY_TRIPLETS,
)
def test_factory(env_name, repo_id, policy_name):
"""
diff --git a/tests/datasets/test_delta_timestamps.py b/tests/datasets/test_delta_timestamps.py
index 8d9529f68..e4e5cf4f3 100644
--- a/tests/datasets/test_delta_timestamps.py
+++ b/tests/datasets/test_delta_timestamps.py
@@ -13,6 +13,8 @@
# limitations under the License.
import pytest
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from lerobot.datasets.feature_utils import (
check_delta_timestamps,
get_delta_indices,
diff --git a/tests/datasets/test_image_transforms.py b/tests/datasets/test_image_transforms.py
index ef7e8c395..4310274e4 100644
--- a/tests/datasets/test_image_transforms.py
+++ b/tests/datasets/test_image_transforms.py
@@ -21,7 +21,13 @@ from safetensors.torch import load_file
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F # noqa: N812
-from lerobot.datasets.transforms import (
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
+from lerobot.scripts.lerobot_imgtransform_viz import (
+ save_all_transforms,
+ save_each_transform,
+)
+from lerobot.transforms import (
ImageTransformConfig,
ImageTransforms,
ImageTransformsConfig,
@@ -29,10 +35,6 @@ from lerobot.datasets.transforms import (
SharpnessJitter,
make_transform_from_config,
)
-from lerobot.scripts.lerobot_imgtransform_viz import (
- save_all_transforms,
- save_each_transform,
-)
from lerobot.utils.random_utils import seeded_context
from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR
from tests.utils import require_x86_64_kernel
diff --git a/tests/datasets/test_image_writer.py b/tests/datasets/test_image_writer.py
index 55419473f..916b8f017 100644
--- a/tests/datasets/test_image_writer.py
+++ b/tests/datasets/test_image_writer.py
@@ -20,6 +20,8 @@ import numpy as np
import pytest
from PIL import Image
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from lerobot.datasets.image_writer import (
AsyncImageWriter,
image_array_to_pil_image,
diff --git a/tests/datasets/test_lerobot_dataset.py b/tests/datasets/test_lerobot_dataset.py
index 5c3c24f99..49efa84d9 100644
--- a/tests/datasets/test_lerobot_dataset.py
+++ b/tests/datasets/test_lerobot_dataset.py
@@ -25,6 +25,8 @@ from unittest.mock import Mock
import pytest
import torch
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
import lerobot.datasets.dataset_metadata as dataset_metadata_module
import lerobot.datasets.lerobot_dataset as lerobot_dataset_module
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
diff --git a/tests/datasets/test_quantiles_dataset_integration.py b/tests/datasets/test_quantiles_dataset_integration.py
index 4df7fab06..b0e8a0e3c 100644
--- a/tests/datasets/test_quantiles_dataset_integration.py
+++ b/tests/datasets/test_quantiles_dataset_integration.py
@@ -19,6 +19,8 @@
import numpy as np
import pytest
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from lerobot.datasets.lerobot_dataset import LeRobotDataset
diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py
index 18fb1c8ac..8bb3be8e9 100644
--- a/tests/datasets/test_sampler.py
+++ b/tests/datasets/test_sampler.py
@@ -17,7 +17,10 @@ import logging
import pytest
import torch
-from datasets import Dataset
+
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
+from datasets import Dataset # noqa: E402
from lerobot.datasets.io_utils import (
hf_transform_to_torch,
diff --git a/tests/datasets/test_streaming.py b/tests/datasets/test_streaming.py
index 1bd4c1787..db167f657 100644
--- a/tests/datasets/test_streaming.py
+++ b/tests/datasets/test_streaming.py
@@ -17,6 +17,8 @@ import numpy as np
import pytest
import torch
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.utils import safe_shard
from lerobot.utils.constants import ACTION
diff --git a/tests/datasets/test_streaming_video_encoder.py b/tests/datasets/test_streaming_video_encoder.py
index f7e63b06f..8b7a1540f 100644
--- a/tests/datasets/test_streaming_video_encoder.py
+++ b/tests/datasets/test_streaming_video_encoder.py
@@ -20,10 +20,13 @@ import queue
import threading
from unittest.mock import patch
-import av
import numpy as np
import pytest
+pytest.importorskip("av", reason="av is required (install lerobot[dataset])")
+
+import av # noqa: E402
+
from lerobot.datasets.video_utils import (
VALID_VIDEO_CODECS,
StreamingVideoEncoder,
diff --git a/tests/datasets/test_subtask_dataset.py b/tests/datasets/test_subtask_dataset.py
index f80a6c72d..bb77b77d1 100644
--- a/tests/datasets/test_subtask_dataset.py
+++ b/tests/datasets/test_subtask_dataset.py
@@ -23,8 +23,11 @@ These tests verify that:
- Subtask handling gracefully handles missing data
"""
-import pandas as pd
import pytest
+
+pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
+
+import pandas as pd # noqa: E402
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
diff --git a/tests/datasets/test_visualize_dataset.py b/tests/datasets/test_visualize_dataset.py
index 8e92ec82e..3bf94e6cb 100644
--- a/tests/datasets/test_visualize_dataset.py
+++ b/tests/datasets/test_visualize_dataset.py
@@ -15,6 +15,8 @@
# limitations under the License.
import pytest
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from lerobot.scripts.lerobot_dataset_viz import visualize_dataset
diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py
index 910c275eb..c6a0b077d 100644
--- a/tests/envs/test_envs.py
+++ b/tests/envs/test_envs.py
@@ -23,7 +23,6 @@ import torch
from gymnasium.envs.registration import register, registry as gym_registry
from gymnasium.utils.env_checker import check_env
-import lerobot
from lerobot.configs.types import PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.envs.factory import make_env, make_env_config
@@ -36,9 +35,16 @@ from tests.utils import require_env
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
+ENV_TASK_PAIRS = [
+ ("aloha", "AlohaInsertion-v0"),
+ ("aloha", "AlohaTransferCube-v0"),
+ ("pusht", "PushT-v0"),
+]
+AVAILABLE_ENVS = ["aloha", "pusht"]
+
@pytest.mark.parametrize("obs_type", OBS_TYPES)
-@pytest.mark.parametrize("env_name, env_task", lerobot.env_task_pairs)
+@pytest.mark.parametrize("env_name, env_task", ENV_TASK_PAIRS)
@require_env
def test_env(env_name, env_task, obs_type):
if env_name == "aloha" and obs_type == "state":
@@ -51,7 +57,7 @@ def test_env(env_name, env_task, obs_type):
env.close()
-@pytest.mark.parametrize("env_name", lerobot.available_envs)
+@pytest.mark.parametrize("env_name", AVAILABLE_ENVS)
@require_env
def test_factory(env_name):
cfg = make_env_config(env_name)
diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py
index 5ecb52145..e068484b0 100644
--- a/tests/fixtures/dataset_factories.py
+++ b/tests/fixtures/dataset_factories.py
@@ -34,12 +34,12 @@ from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
- DEFAULT_FEATURES,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
- flatten_dict,
)
from lerobot.datasets.video_utils import encode_video_frames
+from lerobot.utils.constants import DEFAULT_FEATURES
+from lerobot.utils.utils import flatten_dict
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
diff --git a/tests/mocks/mock_dynamixel.py b/tests/mocks/mock_dynamixel.py
index 84026fc34..4a424a97c 100644
--- a/tests/mocks/mock_dynamixel.py
+++ b/tests/mocks/mock_dynamixel.py
@@ -21,10 +21,25 @@ import dynamixel_sdk as dxl
import serial
from mock_serial.mock_serial import MockSerial
-from lerobot.motors.dynamixel.dynamixel import _split_into_byte_chunks
-
from .mock_serial_patch import WaitableStub
+
+def _split_into_byte_chunks(value: int, length: int) -> list[int]:
+ """Split an integer into a list of byte-sized integers (little-endian)."""
+ if length == 1:
+ data = [value]
+ elif length == 2:
+ data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)]
+ elif length == 4:
+ data = [
+ dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
+ dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
+ dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)),
+ dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)),
+ ]
+ return data
+
+
# https://emanual.robotis.com/docs/en/dxl/crc/
DXL_CRC_TABLE = [
0x0000, 0x8005, 0x800F, 0x000A, 0x801B, 0x001E, 0x0014, 0x8011,
diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py
index 33cbc41d6..6e303b56b 100644
--- a/tests/mocks/mock_feetech.py
+++ b/tests/mocks/mock_feetech.py
@@ -21,11 +21,27 @@ import scservo_sdk as scs
import serial
from mock_serial import MockSerial
-from lerobot.motors.feetech.feetech import _split_into_byte_chunks, patch_setPacketTimeout
+from lerobot.motors.feetech.feetech import patch_setPacketTimeout
from .mock_serial_patch import WaitableStub
+def _split_into_byte_chunks(value: int, length: int) -> list[int]:
+ """Split an integer into a list of byte-sized integers (little-endian)."""
+ if length == 1:
+ data = [value]
+ elif length == 2:
+ data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)]
+ elif length == 4:
+ data = [
+ scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
+ scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
+ scs.SCS_LOBYTE(scs.SCS_HIWORD(value)),
+ scs.SCS_HIBYTE(scs.SCS_HIWORD(value)),
+ ]
+ return data
+
+
class MockFeetechPacket(abc.ABC):
@classmethod
def build(cls, scs_id: int, params: list[int], length: int, *args, **kwargs) -> bytes:
diff --git a/tests/mocks/mock_motors_bus.py b/tests/mocks/mock_motors_bus.py
index a499dbfee..9cb27224f 100644
--- a/tests/mocks/mock_motors_bus.py
+++ b/tests/mocks/mock_motors_bus.py
@@ -17,6 +17,7 @@
from lerobot.motors.motors_bus import (
Motor,
MotorsBus,
+ MotorsBusBase,
)
DUMMY_CTRL_TABLE_1 = {
@@ -122,6 +123,12 @@ class MockPortHandler:
class MockMotorsBus(MotorsBus):
+ """Mock motor bus that bypasses hardware dependency checks.
+
+ Inherits from MotorsBus (alias for SerialMotorsBus) for type compatibility,
+ but calls MotorsBusBase.__init__ directly to skip the pyserial/deepdiff guards.
+ """
+
available_baudrates = [500_000, 1_000_000]
default_timeout = 1000
model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE
@@ -132,8 +139,13 @@ class MockMotorsBus(MotorsBus):
normalized_data = ["Present_Position", "Goal_Position"]
def __init__(self, port: str, motors: dict[str, Motor]):
- super().__init__(port, motors)
+ # Skip SerialMotorsBus.__init__ (which guards pyserial/deepdiff)
+ # and call the base class directly — this mock never touches real serial.
+ MotorsBusBase.__init__(self, port, motors)
self.port_handler = MockPortHandler(port)
+ self._id_to_model_dict = {m.id: m.model for m in self.motors.values()}
+ self._id_to_name_dict = {m.id: name for name, m in self.motors.items()}
+ self._model_nb_to_model_dict = {v: k for k, v in self.model_number_table.items()}
def _assert_protocol_is_compatible(self, instruction_name): ...
def _handshake(self): ...
diff --git a/tests/motors/test_motors_bus.py b/tests/motors/test_motors_bus.py
index 27650ef1b..60ecaeabb 100644
--- a/tests/motors/test_motors_bus.py
+++ b/tests/motors/test_motors_bus.py
@@ -19,6 +19,8 @@ from unittest.mock import patch
import pytest
+pytest.importorskip("serial", reason="pyserial is required (install lerobot[hardware])")
+
from lerobot.motors.motors_bus import (
Motor,
MotorNormMode,
diff --git a/tests/optim/test_schedulers.py b/tests/optim/test_schedulers.py
index 224613416..5d6687102 100644
--- a/tests/optim/test_schedulers.py
+++ b/tests/optim/test_schedulers.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import pytest
import torch
from packaging.version import Version
from torch.optim.lr_scheduler import LambdaLR
@@ -23,8 +24,10 @@ from lerobot.optim.schedulers import (
save_scheduler_state,
)
from lerobot.utils.constants import SCHEDULER_STATE
+from lerobot.utils.import_utils import is_package_available
+@pytest.mark.skipif(not is_package_available("diffusers"), reason="diffusers not installed")
def test_diffuser_scheduler(optimizer):
config = DiffuserSchedulerConfig(name="cosine", num_warmup_steps=5)
scheduler = config.build(optimizer, num_training_steps=100)
diff --git a/tests/policies/groot/test_groot_lerobot.py b/tests/policies/groot/test_groot_lerobot.py
index e299a34e2..788935d4f 100644
--- a/tests/policies/groot/test_groot_lerobot.py
+++ b/tests/policies/groot/test_groot_lerobot.py
@@ -31,7 +31,7 @@ from lerobot.policies.groot.processor_groot import make_groot_pre_post_processor
from lerobot.processor import PolicyProcessorPipeline
from lerobot.types import PolicyAction
from lerobot.utils.device_utils import auto_select_torch_device
-from tests.utils import require_cuda # noqa: E402
+from tests.utils import require_cuda
pytest.importorskip("transformers")
diff --git a/tests/policies/hilserl/test_modeling_classifier.py b/tests/policies/hilserl/test_modeling_classifier.py
index a62ef3ebb..6d262c01b 100644
--- a/tests/policies/hilserl/test_modeling_classifier.py
+++ b/tests/policies/hilserl/test_modeling_classifier.py
@@ -21,7 +21,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
from lerobot.utils.constants import OBS_IMAGE, REWARD
-from tests.utils import require_package
+from tests.utils import skip_if_package_missing
def test_classifier_output():
@@ -37,7 +37,7 @@ def test_classifier_output():
)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
@@ -81,7 +81,7 @@ def test_binary_classifier_with_default_params():
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
@@ -123,7 +123,7 @@ def test_multiclass_classifier():
assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
@@ -138,7 +138,7 @@ def test_default_device():
assert p.device == torch.device("cpu")
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
diff --git a/tests/policies/smolvla/test_smolvla_rtc.py b/tests/policies/smolvla/test_smolvla_rtc.py
index 53e74d940..8c64c8a6c 100644
--- a/tests/policies/smolvla/test_smolvla_rtc.py
+++ b/tests/policies/smolvla/test_smolvla_rtc.py
@@ -19,15 +19,15 @@
import pytest
import torch
-from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
-from lerobot.policies.factory import make_pre_post_processors # noqa: E402
-from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
+from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule
+from lerobot.policies.factory import make_pre_post_processors
+from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig # noqa: F401
-from lerobot.utils.random_utils import set_seed # noqa: E402
-from tests.utils import require_cuda, require_package # noqa: E402
+from lerobot.utils.random_utils import set_seed
+from tests.utils import require_cuda, skip_if_package_missing
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@require_cuda
def test_smolvla_rtc_initialization():
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
@@ -65,7 +65,7 @@ def test_smolvla_rtc_initialization():
print("✓ SmolVLA RTC initialization: Test passed")
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@require_cuda
def test_smolvla_rtc_initialization_without_rtc_config():
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
@@ -87,7 +87,7 @@ def test_smolvla_rtc_initialization_without_rtc_config():
print("✓ SmolVLA RTC initialization without RTC config: Test passed")
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@require_cuda
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
def test_smolvla_rtc_inference_with_prev_chunk():
@@ -170,7 +170,7 @@ def test_smolvla_rtc_inference_with_prev_chunk():
print("✓ SmolVLA RTC inference with prev_chunk: Test passed")
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@require_cuda
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
def test_smolvla_rtc_inference_without_prev_chunk():
@@ -244,7 +244,7 @@ def test_smolvla_rtc_inference_without_prev_chunk():
print("✓ SmolVLA RTC inference without prev_chunk: Test passed")
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@require_cuda
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
def test_smolvla_rtc_validation_rules():
diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py
index 4a8d3ab72..2d50446fe 100644
--- a/tests/policies/test_policies.py
+++ b/tests/policies/test_policies.py
@@ -20,16 +20,16 @@ from pathlib import Path
import einops
import pytest
import torch
+
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from packaging import version
from safetensors.torch import load_file
-from lerobot import available_policies
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, PolicyFeature
-from lerobot.datasets.factory import make_dataset
-from lerobot.datasets.feature_utils import dataset_to_policy_features
-from lerobot.datasets.utils import cycle
+from lerobot.datasets import make_dataset
from lerobot.envs.factory import make_env, make_env_config
from lerobot.envs.utils import close_envs, preprocess_observation
from lerobot.optim.factory import make_optimizer_and_scheduler
@@ -45,10 +45,23 @@ from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
+from lerobot.utils.feature_utils import dataset_to_policy_features
+from lerobot.utils.import_utils import is_package_available
from lerobot.utils.random_utils import seeded_context
+from lerobot.utils.utils import cycle
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
+# Policies that require optional heavy dependencies to instantiate
+_POLICY_REQUIRED_PACKAGES: dict[str, tuple[str, ...]] = {
+ "diffusion": ("diffusers",),
+}
+
+_ALL_POLICIES = ["act", "diffusion", "tdmpc", "vqbet"]
+AVAILABLE_POLICIES = [
+ p for p in _ALL_POLICIES if all(is_package_available(pkg) for pkg in _POLICY_REQUIRED_PACKAGES.get(p, ()))
+]
+
@pytest.fixture
def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_path):
@@ -84,7 +97,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
return ds_meta
-@pytest.mark.parametrize("policy_name", available_policies)
+@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES)
def test_get_policy_and_config_classes(policy_name: str):
"""Check that the correct policy and config classes are returned."""
policy_cls = get_policy_class(policy_name)
@@ -255,7 +268,7 @@ def test_act_backbone_lr():
assert len(optimizer.param_groups[1]["params"]) == 20
-@pytest.mark.parametrize("policy_name", available_policies)
+@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES)
def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
"""Check that the policy can be instantiated with defaults."""
policy_cls = get_policy_class(policy_name)
@@ -268,7 +281,7 @@ def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
policy_cls(policy_cfg)
-@pytest.mark.parametrize("policy_name", available_policies)
+@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES)
def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: str):
policy_cls = get_policy_class(policy_name)
policy_cfg = make_policy_config(policy_name)
@@ -343,7 +356,7 @@ def test_multikey_construction(multikey: bool):
# to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference
# that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass.
# Thus, we deactivate this test for now.
- (
+ pytest.param(
"lerobot/pusht",
"diffusion",
{
@@ -352,6 +365,7 @@ def test_multikey_construction(multikey: bool):
"down_dims": [128, 256, 512],
},
"",
+ marks=pytest.mark.skipif(not is_package_available("diffusers"), reason="diffusers not installed"),
),
("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
(
diff --git a/tests/policies/test_relative_actions.py b/tests/policies/test_relative_actions.py
index 64c2ee9c4..15ef0a31b 100644
--- a/tests/policies/test_relative_actions.py
+++ b/tests/policies/test_relative_actions.py
@@ -10,6 +10,8 @@ import numpy as np
import pytest
import torch
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.compute_stats import get_feature_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset
diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py
index a335c2b4b..2c41de22c 100644
--- a/tests/processor/test_pipeline.py
+++ b/tests/processor/test_pipeline.py
@@ -25,6 +25,8 @@ import pytest
import torch
import torch.nn as nn
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
from lerobot.processor import (
diff --git a/tests/processor/test_smolvla_processor.py b/tests/processor/test_smolvla_processor.py
index 227b1dc35..2aa7d4bdf 100644
--- a/tests/processor/test_smolvla_processor.py
+++ b/tests/processor/test_smolvla_processor.py
@@ -22,14 +22,12 @@ import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
-from lerobot.policies.smolvla.processor_smolvla import (
- SmolVLANewLineProcessor,
- make_smolvla_pre_post_processors,
-)
+from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
EnvTransition,
+ NewLineTaskProcessorStep,
NormalizerProcessorStep,
ProcessorStep,
RenameObservationsProcessorStep,
@@ -108,7 +106,7 @@ def test_make_smolvla_processor_basic():
assert len(preprocessor.steps) == 6
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
- assert isinstance(preprocessor.steps[2], SmolVLANewLineProcessor)
+ assert isinstance(preprocessor.steps[2], NewLineTaskProcessorStep)
# Step 3 would be TokenizerProcessorStep but it's mocked
assert isinstance(preprocessor.steps[4], DeviceProcessorStep)
assert isinstance(preprocessor.steps[5], NormalizerProcessorStep)
@@ -120,8 +118,8 @@ def test_make_smolvla_processor_basic():
def test_smolvla_newline_processor_single_task():
- """Test SmolVLANewLineProcessor with single task string."""
- processor = SmolVLANewLineProcessor()
+ """Test NewLineTaskProcessorStep with single task string."""
+ processor = NewLineTaskProcessorStep()
# Test with task that doesn't have newline
transition = create_transition(complementary_data={"task": "test task"})
@@ -135,8 +133,8 @@ def test_smolvla_newline_processor_single_task():
def test_smolvla_newline_processor_list_of_tasks():
- """Test SmolVLANewLineProcessor with list of task strings."""
- processor = SmolVLANewLineProcessor()
+ """Test NewLineTaskProcessorStep with list of task strings."""
+ processor = NewLineTaskProcessorStep()
# Test with list of tasks
tasks = ["task1", "task2\n", "task3"]
@@ -147,8 +145,8 @@ def test_smolvla_newline_processor_list_of_tasks():
def test_smolvla_newline_processor_empty_transition():
- """Test SmolVLANewLineProcessor with empty transition."""
- processor = SmolVLANewLineProcessor()
+ """Test NewLineTaskProcessorStep with empty transition."""
+ processor = NewLineTaskProcessorStep()
# Test with no complementary_data
transition = create_transition()
@@ -361,8 +359,8 @@ def test_smolvla_processor_without_stats():
def test_smolvla_newline_processor_state_dict():
- """Test SmolVLANewLineProcessor state dict methods."""
- processor = SmolVLANewLineProcessor()
+ """Test NewLineTaskProcessorStep state dict methods."""
+ processor = NewLineTaskProcessorStep()
# Test state_dict (should be empty)
state = processor.state_dict()
@@ -380,8 +378,8 @@ def test_smolvla_newline_processor_state_dict():
def test_smolvla_newline_processor_transform_features():
- """Test SmolVLANewLineProcessor transform_features method."""
- processor = SmolVLANewLineProcessor()
+ """Test NewLineTaskProcessorStep transform_features method."""
+ processor = NewLineTaskProcessorStep()
# Test transform_features
features = {
diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py
index 76dce2537..5708e6e81 100644
--- a/tests/processor/test_tokenizer_processor.py
+++ b/tests/processor/test_tokenizer_processor.py
@@ -36,7 +36,7 @@ from lerobot.utils.constants import (
OBS_LANGUAGE_SUBTASK_TOKENS,
OBS_STATE,
)
-from tests.utils import require_package
+from tests.utils import skip_if_package_missing
class MockTokenizer:
@@ -94,7 +94,7 @@ def mock_tokenizer():
return MockTokenizer(vocab_size=100)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_basic_tokenization(mock_auto_tokenizer):
"""Test basic string tokenization functionality."""
@@ -129,7 +129,7 @@ def test_basic_tokenization(mock_auto_tokenizer):
assert attention_mask.shape == (10,)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_basic_tokenization_with_tokenizer_object():
"""Test basic string tokenization functionality using tokenizer object directly."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -161,7 +161,7 @@ def test_basic_tokenization_with_tokenizer_object():
assert attention_mask.shape == (10,)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_list_of_strings_tokenization(mock_auto_tokenizer):
"""Test tokenization of a list of strings."""
@@ -189,7 +189,7 @@ def test_list_of_strings_tokenization(mock_auto_tokenizer):
assert attention_mask.shape == (2, 8)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_tuple_of_strings_tokenization(mock_auto_tokenizer):
"""Test tokenization of a tuple of strings (returned by VectorEnv.call())."""
@@ -213,7 +213,7 @@ def test_tuple_of_strings_tokenization(mock_auto_tokenizer):
assert attention_mask.shape == (2, 8)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_custom_keys(mock_auto_tokenizer):
"""Test using custom task_key."""
@@ -239,7 +239,7 @@ def test_custom_keys(mock_auto_tokenizer):
assert tokens.shape == (5,)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_none_complementary_data(mock_auto_tokenizer):
"""Test handling of None complementary_data."""
@@ -255,7 +255,7 @@ def test_none_complementary_data(mock_auto_tokenizer):
processor(transition)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_missing_task_key(mock_auto_tokenizer):
"""Test handling when task key is missing."""
@@ -270,7 +270,7 @@ def test_missing_task_key(mock_auto_tokenizer):
processor(transition)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_none_task_value(mock_auto_tokenizer):
"""Test handling when task value is None."""
@@ -285,7 +285,7 @@ def test_none_task_value(mock_auto_tokenizer):
processor(transition)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_unsupported_task_type(mock_auto_tokenizer):
"""Test handling of unsupported task types."""
@@ -307,14 +307,14 @@ def test_unsupported_task_type(mock_auto_tokenizer):
processor(transition)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_no_tokenizer_error():
"""Test that ValueError is raised when neither tokenizer nor tokenizer_name is provided."""
with pytest.raises(ValueError, match="Either 'tokenizer' or 'tokenizer_name' must be provided"):
TokenizerProcessorStep()
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_invalid_tokenizer_name_error():
"""Test that error is raised when invalid tokenizer_name is provided."""
with patch("lerobot.processor.tokenizer_processor.AutoTokenizer") as mock_auto_tokenizer:
@@ -325,7 +325,7 @@ def test_invalid_tokenizer_name_error():
TokenizerProcessorStep(tokenizer_name="invalid-tokenizer")
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_get_config_with_tokenizer_name(mock_auto_tokenizer):
"""Test configuration serialization when using tokenizer_name."""
@@ -354,7 +354,7 @@ def test_get_config_with_tokenizer_name(mock_auto_tokenizer):
assert config == expected
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_get_config_with_tokenizer_object():
"""Test configuration serialization when using tokenizer object."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -382,7 +382,7 @@ def test_get_config_with_tokenizer_object():
assert "tokenizer_name" not in config
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_state_dict_methods(mock_auto_tokenizer):
"""Test state_dict and load_state_dict methods."""
@@ -399,7 +399,7 @@ def test_state_dict_methods(mock_auto_tokenizer):
processor.load_state_dict({})
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_reset_method(mock_auto_tokenizer):
"""Test reset method."""
@@ -412,7 +412,7 @@ def test_reset_method(mock_auto_tokenizer):
processor.reset()
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_integration_with_robot_processor(mock_auto_tokenizer):
"""Test integration with RobotProcessor."""
@@ -449,7 +449,7 @@ def test_integration_with_robot_processor(mock_auto_tokenizer):
assert torch.equal(result[TransitionKey.ACTION], transition[TransitionKey.ACTION])
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_save_and_load_pretrained_with_tokenizer_name(mock_auto_tokenizer):
"""Test saving and loading processor with tokenizer_name."""
@@ -489,7 +489,7 @@ def test_save_and_load_pretrained_with_tokenizer_name(mock_auto_tokenizer):
assert f"{OBS_LANGUAGE}.attention_mask" in result[TransitionKey.OBSERVATION]
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_save_and_load_pretrained_with_tokenizer_object():
"""Test saving and loading processor with tokenizer object using overrides."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -528,7 +528,7 @@ def test_save_and_load_pretrained_with_tokenizer_object():
assert f"{OBS_LANGUAGE}.attention_mask" in result[TransitionKey.OBSERVATION]
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_registry_functionality():
"""Test that the processor is properly registered."""
from lerobot.processor import ProcessorStepRegistry
@@ -541,7 +541,7 @@ def test_registry_functionality():
assert retrieved_class is TokenizerProcessorStep
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_features_basic():
"""Test basic feature contract functionality."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -574,7 +574,7 @@ def test_features_basic():
assert attention_mask_feature.shape == (128,)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_features_with_custom_max_length():
"""Test feature contract with custom max_length."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -596,7 +596,7 @@ def test_features_with_custom_max_length():
assert attention_mask_feature.shape == (64,)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_features_existing_features():
"""Test feature contract when tokenized features already exist."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -618,7 +618,7 @@ def test_features_existing_features():
assert output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"].shape == (100,)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_tokenization_parameters(mock_auto_tokenizer):
"""Test that tokenization parameters are correctly passed to tokenizer."""
@@ -666,7 +666,7 @@ def test_tokenization_parameters(mock_auto_tokenizer):
assert tracking_tokenizer.last_call_kwargs["return_tensors"] == "pt"
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_preserves_other_complementary_data(mock_auto_tokenizer):
"""Test that other complementary data fields are preserved."""
@@ -701,7 +701,7 @@ def test_preserves_other_complementary_data(mock_auto_tokenizer):
assert f"{OBS_LANGUAGE}.attention_mask" in observation
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_deterministic_tokenization(mock_auto_tokenizer):
"""Test that tokenization is deterministic for the same input."""
@@ -729,7 +729,7 @@ def test_deterministic_tokenization(mock_auto_tokenizer):
assert torch.equal(attention_mask1, attention_mask2)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_empty_string_task(mock_auto_tokenizer):
"""Test handling of empty string task."""
@@ -753,7 +753,7 @@ def test_empty_string_task(mock_auto_tokenizer):
assert tokens.shape == (8,)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_very_long_task(mock_auto_tokenizer):
"""Test handling of very long task strings."""
@@ -779,7 +779,7 @@ def test_very_long_task(mock_auto_tokenizer):
assert attention_mask.shape == (5,)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_custom_padding_side(mock_auto_tokenizer):
"""Test using custom padding_side parameter."""
@@ -833,7 +833,7 @@ def test_custom_padding_side(mock_auto_tokenizer):
assert tracking_tokenizer.padding_side_calls[-1] == "right"
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_device_detection_cpu():
"""Test that tokenized tensors stay on CPU when other tensors are on CPU."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -857,7 +857,7 @@ def test_device_detection_cpu():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_device_detection_cuda():
"""Test that tokenized tensors are moved to CUDA when other tensors are on CUDA."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -882,7 +882,7 @@ def test_device_detection_cuda():
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_device_detection_multi_gpu():
"""Test that tokenized tensors match device in multi-GPU setup."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -906,7 +906,7 @@ def test_device_detection_multi_gpu():
assert attention_mask.device == device
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_device_detection_no_tensors():
"""Test that tokenized tensors stay on CPU when no other tensors exist."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -928,7 +928,7 @@ def test_device_detection_no_tensors():
assert attention_mask.device.type == "cpu"
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_device_detection_mixed_devices():
"""Test device detection when tensors are on different devices (uses first found)."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -956,7 +956,7 @@ def test_device_detection_mixed_devices():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_device_detection_from_action():
"""Test that device is detected from action tensor when no observation tensors exist."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -979,7 +979,7 @@ def test_device_detection_from_action():
assert attention_mask.device.type == "cuda"
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_device_detection_preserves_dtype():
"""Test that device detection doesn't affect dtype of tokenized tensors."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1000,7 +1000,7 @@ def test_device_detection_preserves_dtype():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_integration_with_device_processor(mock_auto_tokenizer):
"""Test that TokenizerProcessorStep works correctly with DeviceProcessorStep in pipeline."""
@@ -1039,7 +1039,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_simulated_accelerate_scenario():
"""Test scenario simulating Accelerate with data already on GPU."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1077,7 +1077,7 @@ def test_simulated_accelerate_scenario():
# =============================================================================
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_get_subtask_missing_key():
"""Test get_subtask returns None when subtask key is missing from complementary_data."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1093,7 +1093,7 @@ def test_get_subtask_missing_key():
assert result is None
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_get_subtask_none_value():
"""Test get_subtask returns None when subtask value is None."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1109,7 +1109,7 @@ def test_get_subtask_none_value():
assert result is None
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_get_subtask_none_complementary_data():
"""Test get_subtask returns None when complementary_data is None."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1125,7 +1125,7 @@ def test_get_subtask_none_complementary_data():
assert result is None
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_get_subtask_string():
"""Test get_subtask returns list with single string when subtask is a string."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1143,7 +1143,7 @@ def test_get_subtask_string():
assert len(result) == 1
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_get_subtask_list_of_strings():
"""Test get_subtask returns the list when subtask is already a list of strings."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1162,7 +1162,7 @@ def test_get_subtask_list_of_strings():
assert len(result) == 3
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_get_subtask_unsupported_type_integer():
"""Test get_subtask returns None when subtask is an unsupported type (integer)."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1178,7 +1178,7 @@ def test_get_subtask_unsupported_type_integer():
assert result is None
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_get_subtask_unsupported_type_mixed_list():
"""Test get_subtask returns None when subtask is a list with mixed types."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1194,7 +1194,7 @@ def test_get_subtask_unsupported_type_mixed_list():
assert result is None
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_get_subtask_unsupported_type_dict():
"""Test get_subtask returns None when subtask is a dictionary."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1210,7 +1210,7 @@ def test_get_subtask_unsupported_type_dict():
assert result is None
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_get_subtask_empty_string():
"""Test get_subtask with empty string returns list with empty string."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1226,7 +1226,7 @@ def test_get_subtask_empty_string():
assert result == [""]
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_get_subtask_empty_list():
"""Test get_subtask with empty list returns empty list."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1247,7 +1247,7 @@ def test_get_subtask_empty_list():
# =============================================================================
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_subtask_tokenization_when_present():
"""Test that subtask is tokenized and added to observation when present."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1276,7 +1276,7 @@ def test_subtask_tokenization_when_present():
assert subtask_attention_mask.dtype == torch.bool
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_subtask_tokenization_not_added_when_none():
"""Test that subtask tokens are NOT added to observation when subtask is None."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1300,7 +1300,7 @@ def test_subtask_tokenization_not_added_when_none():
assert f"{OBS_LANGUAGE}.attention_mask" in observation
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_subtask_tokenization_not_added_when_subtask_value_is_none():
"""Test that subtask tokens are NOT added when subtask value is explicitly None."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1320,7 +1320,7 @@ def test_subtask_tokenization_not_added_when_subtask_value_is_none():
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_subtask_tokenization_list_of_strings():
"""Test subtask tokenization with list of strings."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1346,7 +1346,7 @@ def test_subtask_tokenization_list_of_strings():
assert subtask_attention_mask.shape == (2, 8)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_subtask_tokenization_device_cpu():
"""Test that subtask tokens are on CPU when other tensors are on CPU."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1372,7 +1372,7 @@ def test_subtask_tokenization_device_cpu():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_subtask_tokenization_device_cuda():
"""Test that subtask tokens are moved to CUDA when other tensors are on CUDA."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1397,7 +1397,7 @@ def test_subtask_tokenization_device_cuda():
assert subtask_attention_mask.device.type == "cuda"
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_subtask_tokenization_preserves_other_observation_data():
"""Test that subtask tokenization preserves other observation data."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1423,7 +1423,7 @@ def test_subtask_tokenization_preserves_other_observation_data():
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_subtask_attention_mask_dtype():
"""Test that subtask attention mask has correct dtype (bool)."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1442,7 +1442,7 @@ def test_subtask_attention_mask_dtype():
assert subtask_attention_mask.dtype == torch.bool
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_subtask_tokenization_deterministic():
"""Test that subtask tokenization is deterministic for the same input."""
mock_tokenizer = MockTokenizer(vocab_size=100)
@@ -1467,7 +1467,7 @@ def test_subtask_tokenization_deterministic():
assert torch.equal(subtask_mask1, subtask_mask2)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_subtask_tokenization_integration_with_pipeline(mock_auto_tokenizer):
"""Test subtask tokenization works correctly with DataProcessorPipeline."""
@@ -1504,7 +1504,7 @@ def test_subtask_tokenization_integration_with_pipeline(mock_auto_tokenizer):
assert observation[OBS_LANGUAGE_SUBTASK_TOKENS].shape == (6,)
-@require_package("transformers")
+@skip_if_package_missing("transformers")
def test_subtask_not_added_for_unsupported_types():
"""Test that subtask tokens are not added when subtask has unsupported type."""
mock_tokenizer = MockTokenizer(vocab_size=100)
diff --git a/tests/rl/test_actor.py b/tests/rl/test_actor.py
index 54e4d2870..08746ec91 100644
--- a/tests/rl/test_actor.py
+++ b/tests/rl/test_actor.py
@@ -19,11 +19,14 @@ from unittest.mock import patch
import pytest
import torch
+
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from torch.multiprocessing import Event, Queue
from lerobot.utils.constants import OBS_STR
from lerobot.utils.transition import Transition
-from tests.utils import require_package
+from tests.utils import skip_if_package_missing
def create_learner_service_stub():
@@ -64,7 +67,7 @@ def close_service_stub(channel, server):
server.stop(None)
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_establish_learner_connection_success():
from lerobot.rl.actor import establish_learner_connection
@@ -81,7 +84,7 @@ def test_establish_learner_connection_success():
close_service_stub(channel, server)
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_establish_learner_connection_failure():
from lerobot.rl.actor import establish_learner_connection
@@ -100,7 +103,7 @@ def test_establish_learner_connection_failure():
close_service_stub(channel, server)
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_push_transitions_to_transport_queue():
from lerobot.rl.actor import push_transitions_to_transport_queue
from lerobot.transport.utils import bytes_to_transitions
@@ -135,7 +138,7 @@ def test_push_transitions_to_transport_queue():
assert_transitions_equal(deserialized_transition, transitions[i])
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_transitions_stream():
from lerobot.rl.actor import transitions_stream
@@ -167,7 +170,7 @@ def test_transitions_stream():
assert streamed_data[2].data == b"transition_data_3"
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_interactions_stream():
from lerobot.rl.actor import interactions_stream
diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py
index e13862d82..3978dfffd 100644
--- a/tests/rl/test_actor_learner.py
+++ b/tests/rl/test_actor_learner.py
@@ -20,13 +20,16 @@ import time
import pytest
import torch
+
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from torch.multiprocessing import Event, Queue
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.utils.constants import OBS_STR
from lerobot.utils.transition import Transition
-from tests.utils import require_package
+from tests.utils import skip_if_package_missing
def create_test_transitions(count: int = 3) -> list[Transition]:
@@ -88,7 +91,7 @@ def cfg():
return cfg
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
@pytest.mark.timeout(10) # force cross-platform watchdog
def test_end_to_end_transitions_flow(cfg):
from lerobot.rl.actor import (
@@ -150,7 +153,7 @@ def test_end_to_end_transitions_flow(cfg):
assert_transitions_equal(transition, input_transitions[i])
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
@pytest.mark.timeout(10)
def test_end_to_end_interactions_flow(cfg):
from lerobot.rl.actor import (
@@ -223,7 +226,7 @@ def test_end_to_end_interactions_flow(cfg):
assert received == expected
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
@pytest.mark.parametrize("data_size", ["small", "large"])
@pytest.mark.timeout(10)
def test_end_to_end_parameters_flow(cfg, data_size):
diff --git a/tests/rl/test_learner_service.py b/tests/rl/test_learner_service.py
index d967388f0..f1023f0f3 100644
--- a/tests/rl/test_learner_service.py
+++ b/tests/rl/test_learner_service.py
@@ -20,7 +20,7 @@ from multiprocessing import Event, Queue
import pytest
-from tests.utils import require_package # our gRPC servicer class
+from tests.utils import skip_if_package_missing # our gRPC servicer class
@pytest.fixture(scope="function")
@@ -39,7 +39,7 @@ def learner_service_stub():
close_learner_service_stub(channel, server)
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def create_learner_service_stub(
shutdown_event: Event,
parameters_queue: Queue,
@@ -75,7 +75,7 @@ def create_learner_service_stub(
return services_pb2_grpc.LearnerServiceStub(channel), channel, server
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def close_learner_service_stub(channel, server):
channel.close()
server.stop(None)
@@ -91,7 +91,7 @@ def test_ready_method(learner_service_stub):
assert response == services_pb2.Empty()
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_send_interactions():
from lerobot.transport import services_pb2
@@ -135,7 +135,7 @@ def test_send_interactions():
assert interactions == [b"123", b"4", b"5", b"678"]
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_send_transitions():
from lerobot.transport import services_pb2
@@ -181,7 +181,7 @@ def test_send_transitions():
assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"]
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_send_transitions_empty_stream():
from lerobot.transport import services_pb2
@@ -209,7 +209,7 @@ def test_send_transitions_empty_stream():
assert transitions_queue.empty()
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
@pytest.mark.timeout(10) # force cross-platform watchdog
def test_stream_parameters():
import time
@@ -267,7 +267,7 @@ def test_stream_parameters():
assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1)
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_stream_parameters_with_shutdown():
from lerobot.transport import services_pb2
@@ -319,7 +319,7 @@ def test_stream_parameters_with_shutdown():
assert received_params == [b"param_batch_1", b"stop"]
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
@pytest.mark.timeout(3) # force cross-platform watchdog
def test_stream_parameters_waits_and_retries_on_empty_queue():
import threading
diff --git a/tests/rl/test_queue.py b/tests/rl/test_queue.py
index b6716fbd6..cf3d6cdca 100644
--- a/tests/rl/test_queue.py
+++ b/tests/rl/test_queue.py
@@ -18,9 +18,13 @@ import threading
import time
from queue import Queue
-from torch.multiprocessing import Queue as TorchMPQueue
+import pytest
-from lerobot.rl.queue import get_last_item_from_queue
+pytest.importorskip("grpc")
+
+from torch.multiprocessing import Queue as TorchMPQueue # noqa: E402
+
+from lerobot.rl.queue import get_last_item_from_queue # noqa: E402
def test_get_last_item_single_item():
diff --git a/tests/scripts/test_edit_dataset_parsing.py b/tests/scripts/test_edit_dataset_parsing.py
index 4d758ae35..83ed5a78b 100644
--- a/tests/scripts/test_edit_dataset_parsing.py
+++ b/tests/scripts/test_edit_dataset_parsing.py
@@ -17,6 +17,8 @@
import draccus
import pytest
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from lerobot.scripts.lerobot_edit_dataset import (
ConvertImageToVideoConfig,
DeleteEpisodesConfig,
diff --git a/tests/test_available.py b/tests/test_available.py
index 19e39b2b6..7dd1cdacb 100644
--- a/tests/test_available.py
+++ b/tests/test_available.py
@@ -13,48 +13,50 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import importlib
-import gymnasium as gym
+from unittest.mock import patch
+
import pytest
import lerobot
-from lerobot.policies.act.modeling_act import ACTPolicy
-from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
-from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
-from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
-from tests.utils import require_env
+from lerobot.utils.import_utils import _require_package_cache, require_package
-@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs)
-@require_env
-def test_available_env_task(env_name: str, task_name: list):
- """
- This test verifies that all environments listed in `lerobot/__init__.py` can
- be successfully imported — if they're installed — and that their
- `available_tasks_per_env` are valid.
- """
- package_name = f"gym_{env_name}"
- importlib.import_module(package_name)
- gym_handle = f"{package_name}/{task_name}"
- assert gym_handle in gym.envs.registry, gym_handle
+def test_version():
+ """Verify the package exposes a version string."""
+ assert isinstance(lerobot.__version__, str)
+ assert len(lerobot.__version__) > 0
-def test_available_policies():
- """
- This test verifies that the class attribute `name` for all policies is
- consistent with those listed in `lerobot/__init__.py`.
- """
- policy_classes = [ACTPolicy, DiffusionPolicy, TDMPCPolicy, VQBeTPolicy]
- policies = [pol_cls.name for pol_cls in policy_classes]
- assert set(policies) == set(lerobot.available_policies), policies
+def test_require_package_raises_when_missing():
+ """require_package raises ImportError with install instructions when a package is missing."""
+ with patch("lerobot.utils.import_utils.is_package_available", return_value=False):
+ # Clear the cache so the mock takes effect
+ _require_package_cache.clear()
+ try:
+ with pytest.raises(ImportError, match=r"pip install 'lerobot\[dataset\]'"):
+ require_package("datasets", extra="dataset")
+ finally:
+ _require_package_cache.clear()
-def test_print():
- print(lerobot.available_envs)
- print(lerobot.available_tasks_per_env)
- print(lerobot.available_datasets)
- print(lerobot.available_datasets_per_env)
- print(lerobot.available_real_world_datasets)
- print(lerobot.available_policies)
- print(lerobot.available_policies_per_env)
+def test_require_package_passes_when_available():
+ """require_package does not raise when the package is installed."""
+ with patch("lerobot.utils.import_utils.is_package_available", return_value=True):
+ _require_package_cache.clear()
+ try:
+ # Should not raise
+ require_package("datasets", extra="dataset")
+ finally:
+ _require_package_cache.clear()
+
+
+def test_require_package_error_message_includes_uv():
+ """Error message includes both pip and uv install commands."""
+ with patch("lerobot.utils.import_utils.is_package_available", return_value=False):
+ _require_package_cache.clear()
+ try:
+ with pytest.raises(ImportError, match=r"uv pip install"):
+ require_package("grpcio", extra="async", import_name="grpc")
+ finally:
+ _require_package_cache.clear()
diff --git a/tests/test_cli_peft.py b/tests/test_cli_peft.py
index 42fef4741..5d653ee6b 100644
--- a/tests/test_cli_peft.py
+++ b/tests/test_cli_peft.py
@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
import pytest
from safetensors.torch import load_file
-from .utils import require_package
+from .utils import skip_if_package_missing
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
@@ -37,7 +37,7 @@ def resolve_model_id_for_peft_training(policy_type):
@pytest.mark.parametrize("policy_type", ["smolvla"])
-@require_package("peft")
+@skip_if_package_missing("peft")
def test_peft_training_push_to_hub_works(policy_type, tmp_path):
"""Ensure that push to hub stores PEFT only the adapter, not the full model weights."""
output_dir = tmp_path / f"output_{policy_type}"
@@ -76,7 +76,7 @@ def test_peft_training_push_to_hub_works(policy_type, tmp_path):
@pytest.mark.parametrize("policy_type", ["smolvla"])
-@require_package("peft")
+@skip_if_package_missing("peft")
def test_peft_training_works(policy_type, tmp_path):
"""Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""
output_dir = tmp_path / f"output_{policy_type}"
@@ -125,7 +125,7 @@ def test_peft_training_works(policy_type, tmp_path):
@pytest.mark.parametrize("policy_type", ["smolvla"])
-@require_package("peft")
+@skip_if_package_missing("peft")
def test_peft_training_params_are_fewer(policy_type, tmp_path):
"""Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works."""
output_dir = tmp_path / f"output_{policy_type}"
@@ -176,7 +176,7 @@ def dummy_make_robot_from_config(*args, **kwargs):
@pytest.mark.parametrize("policy_type", ["smolvla"])
-@require_package("peft")
+@skip_if_package_missing("peft")
def test_peft_record_loads_policy(policy_type, tmp_path):
"""Train a policy with PEFT and attempt to load it with `lerobot-record`."""
from peft import PeftModel
diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py
index 772588467..28e91a149 100644
--- a/tests/test_control_robot.py
+++ b/tests/test_control_robot.py
@@ -16,6 +16,11 @@
from unittest.mock import patch
+import pytest
+
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+pytest.importorskip("deepdiff", reason="deepdiff is required (install lerobot[hardware])")
+
from lerobot.scripts.lerobot_calibrate import CalibrateConfig, calibrate
from lerobot.scripts.lerobot_record import DatasetRecordConfig, RecordConfig, record
from lerobot.scripts.lerobot_replay import DatasetReplayConfig, ReplayConfig, replay
diff --git a/tests/training/test_multi_gpu.py b/tests/training/test_multi_gpu.py
index bb234e2e7..638dc3131 100644
--- a/tests/training/test_multi_gpu.py
+++ b/tests/training/test_multi_gpu.py
@@ -33,6 +33,8 @@ from pathlib import Path
import pytest
import torch
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from lerobot.datasets.lerobot_dataset import LeRobotDataset
diff --git a/tests/training/test_visual_validation.py b/tests/training/test_visual_validation.py
index 89351e3c2..1df8006b2 100644
--- a/tests/training/test_visual_validation.py
+++ b/tests/training/test_visual_validation.py
@@ -31,6 +31,8 @@ from pathlib import Path
import numpy as np
import pytest
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
from lerobot.configs.default import DatasetConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
diff --git a/tests/transport/test_transport_utils.py b/tests/transport/test_transport_utils.py
index 63632a8f4..d0df3d941 100644
--- a/tests/transport/test_transport_utils.py
+++ b/tests/transport/test_transport_utils.py
@@ -23,10 +23,10 @@ import torch
from lerobot.utils.constants import ACTION
from lerobot.utils.transition import Transition
-from tests.utils import require_cuda, require_package
+from tests.utils import require_cuda, skip_if_package_missing
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_bytes_buffer_size_empty_buffer():
from lerobot.transport.utils import bytes_buffer_size
@@ -37,7 +37,7 @@ def test_bytes_buffer_size_empty_buffer():
assert buffer.tell() == 0
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_bytes_buffer_size_small_buffer():
from lerobot.transport.utils import bytes_buffer_size
@@ -47,7 +47,7 @@ def test_bytes_buffer_size_small_buffer():
assert buffer.tell() == 0
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_bytes_buffer_size_large_buffer():
from lerobot.transport.utils import CHUNK_SIZE, bytes_buffer_size
@@ -58,7 +58,7 @@ def test_bytes_buffer_size_large_buffer():
assert buffer.tell() == 0
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_send_bytes_in_chunks_empty_data():
from lerobot.transport.utils import send_bytes_in_chunks, services_pb2
@@ -68,7 +68,7 @@ def test_send_bytes_in_chunks_empty_data():
assert len(chunks) == 0
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_single_chunk_small_data():
from lerobot.transport.utils import send_bytes_in_chunks, services_pb2
@@ -82,7 +82,7 @@ def test_single_chunk_small_data():
assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_not_silent_mode():
from lerobot.transport.utils import send_bytes_in_chunks, services_pb2
@@ -94,7 +94,7 @@ def test_not_silent_mode():
assert chunks[0].data == b"Some data"
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_send_bytes_in_chunks_large_data():
from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2
@@ -111,7 +111,7 @@ def test_send_bytes_in_chunks_large_data():
assert chunks[2].transfer_state == services_pb2.TransferState.TRANSFER_END
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_send_bytes_in_chunks_large_data_with_exact_chunk_size():
from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2
@@ -124,7 +124,7 @@ def test_send_bytes_in_chunks_large_data_with_exact_chunk_size():
assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_receive_bytes_in_chunks_empty_data():
from lerobot.transport.utils import receive_bytes_in_chunks
@@ -138,7 +138,7 @@ def test_receive_bytes_in_chunks_empty_data():
assert queue.empty()
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_receive_bytes_in_chunks_single_chunk():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -157,7 +157,7 @@ def test_receive_bytes_in_chunks_single_chunk():
assert queue.empty()
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_receive_bytes_in_chunks_single_not_end_chunk():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -175,7 +175,7 @@ def test_receive_bytes_in_chunks_single_not_end_chunk():
assert queue.empty()
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_receive_bytes_in_chunks_multiple_chunks():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -199,7 +199,7 @@ def test_receive_bytes_in_chunks_multiple_chunks():
assert queue.empty()
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_receive_bytes_in_chunks_multiple_messages():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -235,7 +235,7 @@ def test_receive_bytes_in_chunks_multiple_messages():
assert queue.empty()
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_receive_bytes_in_chunks_shutdown_during_receive():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -259,7 +259,7 @@ def test_receive_bytes_in_chunks_shutdown_during_receive():
assert queue.empty()
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_receive_bytes_in_chunks_only_begin_chunk():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -279,7 +279,7 @@ def test_receive_bytes_in_chunks_only_begin_chunk():
assert queue.empty()
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_receive_bytes_in_chunks_missing_begin():
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
@@ -303,7 +303,7 @@ def test_receive_bytes_in_chunks_missing_begin():
# Tests for state_to_bytes and bytes_to_state_dict
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_state_to_bytes_empty_dict():
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
@@ -314,7 +314,7 @@ def test_state_to_bytes_empty_dict():
assert reconstructed == state_dict
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_bytes_to_state_dict_empty_data():
from lerobot.transport.utils import bytes_to_state_dict
@@ -323,7 +323,7 @@ def test_bytes_to_state_dict_empty_data():
bytes_to_state_dict(b"")
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_state_to_bytes_simple_dict():
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
@@ -347,7 +347,7 @@ def test_state_to_bytes_simple_dict():
assert torch.allclose(state_dict[key], reconstructed[key])
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_state_to_bytes_various_dtypes():
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
@@ -372,7 +372,7 @@ def test_state_to_bytes_various_dtypes():
assert torch.allclose(state_dict[key], reconstructed[key])
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_bytes_to_state_dict_invalid_data():
from lerobot.transport.utils import bytes_to_state_dict
@@ -382,7 +382,7 @@ def test_bytes_to_state_dict_invalid_data():
@require_cuda
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_state_to_bytes_various_dtypes_cuda():
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
@@ -407,7 +407,7 @@ def test_state_to_bytes_various_dtypes_cuda():
assert torch.allclose(state_dict[key], reconstructed[key])
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_python_object_to_bytes_none():
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
@@ -439,7 +439,7 @@ def test_python_object_to_bytes_none():
(1, 2, 3),
],
)
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_python_object_to_bytes_simple_types(obj):
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
@@ -450,7 +450,7 @@ def test_python_object_to_bytes_simple_types(obj):
assert type(reconstructed) is type(obj)
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_python_object_to_bytes_with_tensors():
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
@@ -475,7 +475,7 @@ def test_python_object_to_bytes_with_tensors():
assert torch.equal(obj["nested"]["tensor2"], reconstructed["nested"]["tensor2"])
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_transitions_to_bytes_empty_list():
from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes
@@ -487,7 +487,7 @@ def test_transitions_to_bytes_empty_list():
assert isinstance(reconstructed, list)
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_transitions_to_bytes_single_transition():
from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes
@@ -509,7 +509,7 @@ def test_transitions_to_bytes_single_transition():
assert_transitions_equal(transitions[0], reconstructed[0])
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def assert_transitions_equal(t1: Transition, t2: Transition):
"""Helper to assert two transitions are equal."""
assert_observation_equal(t1["state"], t2["state"])
@@ -519,7 +519,7 @@ def assert_transitions_equal(t1: Transition, t2: Transition):
assert_observation_equal(t1["next_state"], t2["next_state"])
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def assert_observation_equal(o1: dict, o2: dict):
"""Helper to assert two observations are equal."""
assert set(o1.keys()) == set(o2.keys())
@@ -527,7 +527,7 @@ def assert_observation_equal(o1: dict, o2: dict):
assert torch.allclose(o1[key], o2[key])
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_transitions_to_bytes_multiple_transitions():
from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes
@@ -551,7 +551,7 @@ def test_transitions_to_bytes_multiple_transitions():
assert_transitions_equal(original, reconstructed_item)
-@require_package("grpcio", "grpc")
+@skip_if_package_missing("grpcio", "grpc")
def test_receive_bytes_in_chunks_unknown_state():
from lerobot.transport.utils import receive_bytes_in_chunks
diff --git a/tests/utils.py b/tests/utils.py
index 33c554804..f8f4b135b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,23 +20,11 @@ from functools import wraps
import pytest
import torch
-from lerobot import available_cameras, available_motors, available_robots
from lerobot.utils.device_utils import auto_select_torch_device
from lerobot.utils.import_utils import is_package_available
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", str(auto_select_torch_device()))
-TEST_ROBOT_TYPES = []
-for robot_type in available_robots:
- TEST_ROBOT_TYPES += [(robot_type, True), (robot_type, False)]
-
-TEST_CAMERA_TYPES = []
-for camera_type in available_cameras:
- TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]
-
-TEST_MOTOR_TYPES = []
-for motor_type in available_motors:
- TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]
# Camera indices used for connecting physical cameras
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
@@ -152,7 +140,7 @@ def require_env(func):
return wrapper
-def require_package_arg(func):
+def skip_if_package_arg_missing(func):
"""
Decorator that skips the test if the required package is not installed.
This is similar to `require_env` but more general in that it can check any package (not just environments).
@@ -184,7 +172,7 @@ def require_package_arg(func):
return wrapper
-def require_package(package_name, import_name=None):
+def skip_if_package_missing(package_name, import_name=None):
"""
Decorator that skips the test if the specified package is not installed.
"""
diff --git a/tests/utils/test_process.py b/tests/utils/test_process.py
index e2b00cae9..ce56db173 100644
--- a/tests/utils/test_process.py
+++ b/tests/utils/test_process.py
@@ -22,7 +22,9 @@ from unittest.mock import patch
import pytest
-from lerobot.rl.process import ProcessSignalHandler
+pytest.importorskip("grpc")
+
+from lerobot.rl.process import ProcessSignalHandler # noqa: E402
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test
diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py
index b9d3a1ac0..1b2af39f1 100644
--- a/tests/utils/test_replay_buffer.py
+++ b/tests/utils/test_replay_buffer.py
@@ -18,12 +18,16 @@ import sys
from collections.abc import Callable
import pytest
-import torch
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
-from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD
-from tests.fixtures.constants import DUMMY_REPO_ID
+pytest.importorskip("grpc")
+pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
+
+import torch # noqa: E402
+
+from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: E402
+from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized # noqa: E402
+from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD # noqa: E402
+from tests.fixtures.constants import DUMMY_REPO_ID # noqa: E402
def state_dims() -> list[str]:
diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py
index 4791caf58..8e5b3f167 100644
--- a/tests/utils/test_train_utils.py
+++ b/tests/utils/test_train_utils.py
@@ -17,6 +17,16 @@
from pathlib import Path
from unittest.mock import Mock, patch
+from lerobot.common.train_utils import (
+ get_step_checkpoint_dir,
+ get_step_identifier,
+ load_training_state,
+ load_training_step,
+ save_checkpoint,
+ save_training_state,
+ save_training_step,
+ update_last_checkpoint,
+)
from lerobot.utils.constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
@@ -27,16 +37,6 @@ from lerobot.utils.constants import (
TRAINING_STATE_DIR,
TRAINING_STEP,
)
-from lerobot.utils.train_utils import (
- get_step_checkpoint_dir,
- get_step_identifier,
- load_training_state,
- load_training_step,
- save_checkpoint,
- save_training_state,
- save_training_step,
- update_last_checkpoint,
-)
def test_get_step_identifier():
@@ -72,7 +72,7 @@ def test_update_last_checkpoint(tmp_path):
assert last_checkpoint.resolve() == checkpoint
-@patch("lerobot.utils.train_utils.save_training_state")
+@patch("lerobot.common.train_utils.save_training_state")
def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
policy = Mock()
cfg = Mock()
@@ -82,7 +82,7 @@ def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
mock_save_training_state.assert_called_once()
-@patch("lerobot.utils.train_utils.save_training_state")
+@patch("lerobot.common.train_utils.save_training_state")
def test_save_checkpoint_peft(mock_save_training_state, tmp_path, optimizer):
policy = Mock()
policy.config = Mock()
diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py
index c8e5a92a8..63ff76c77 100644
--- a/tests/utils/test_visualization_utils.py
+++ b/tests/utils/test_visualization_utils.py
@@ -21,6 +21,8 @@ from types import SimpleNamespace
import numpy as np
import pytest
+pytest.importorskip("rerun", reason="rerun-sdk is required (install lerobot[viz])")
+
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_STATE
@@ -48,6 +50,9 @@ def mock_rerun(monkeypatch):
calls.append((key, obj, kwargs))
dummy_rr = SimpleNamespace(
+ __name__="rerun",
+ __package__="rerun",
+ __spec__=SimpleNamespace(name="rerun", submodule_search_locations=None),
Scalars=DummyScalar,
Image=DummyImage,
log=dummy_log,
diff --git a/uv.lock b/uv.lock
index d549938aa..a66f044ff 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2,24 +2,33 @@ version = 1
revision = 2
requires-python = ">=3.12"
resolution-markers = [
- "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'linux'",
+ "python_full_version >= '3.14' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'",
"python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'linux'",
- "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'linux'",
+ "python_full_version == '3.13.*' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'",
"python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'linux'",
- "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'linux'",
+ "python_full_version < '3.13' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'",
"python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'linux'",
- "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'",
- "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'",
- "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'",
- "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'",
- "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'",
- "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'",
+ "(python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.14' and platform_machine == 'arm64' and sys_platform == 'linux') or (python_full_version >= '3.14' and platform_machine == 'armv7l' and sys_platform == 'linux')",
+ "(python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.13.*' and platform_machine == 'arm64' and sys_platform == 'linux') or (python_full_version == '3.13.*' and platform_machine == 'armv7l' and sys_platform == 'linux')",
+ "(python_full_version < '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.13' and platform_machine == 'arm64' and sys_platform == 'linux') or (python_full_version < '3.13' and platform_machine == 'armv7l' and sys_platform == 'linux')",
+ "(python_full_version >= '3.14' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')",
+ "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'",
+ "(python_full_version == '3.13.*' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')",
+ "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'",
+ "(python_full_version < '3.13' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')",
+ "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'",
"python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'emscripten'",
"python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'emscripten'",
"python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'emscripten'",
"python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'emscripten'",
"python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'",
"python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'",
+ "(python_full_version >= '3.14' and platform_machine == 'x86_64' and sys_platform == 'darwin') or (python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'win32')",
+ "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'win32'",
+ "(python_full_version == '3.13.*' and platform_machine == 'x86_64' and sys_platform == 'darwin') or (python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'win32')",
+ "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'win32'",
+ "(python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'darwin') or (python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32')",
+ "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32'",
]
[[package]]
@@ -820,7 +829,7 @@ name = "cuda-bindings"
version = "12.9.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "cuda-pathfinder", marker = "sys_platform == 'linux'" },
+ { name = "cuda-pathfinder", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/a9/c1/dabe88f52c3e3760d861401bb994df08f672ec893b8f7592dc91626adcf3/cuda_bindings-12.9.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fda147a344e8eaeca0c6ff113d2851ffca8f7dfc0a6c932374ee5c47caa649c8", size = 12151019, upload-time = "2025-10-21T14:51:43.167Z" },
@@ -907,7 +916,7 @@ name = "decord"
version = "0.6.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "numpy" },
+ { name = "numpy", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l') or sys_platform != 'linux'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/11/79/936af42edf90a7bd4e41a6cac89c913d4b47fa48a26b042d5129a9242ee3/decord-0.6.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:51997f20be8958e23b7c4061ba45d0efcd86bffd5fe81c695d0befee0d442976", size = 13602299, upload-time = "2021-06-14T21:30:55.486Z" },
@@ -1010,12 +1019,15 @@ name = "dm-tree"
version = "0.1.9"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
- "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'linux'",
+ "python_full_version >= '3.14' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'",
"python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'linux'",
- "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'",
- "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'",
+ "(python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.14' and platform_machine == 'arm64' and sys_platform == 'linux') or (python_full_version >= '3.14' and platform_machine == 'armv7l' and sys_platform == 'linux')",
+ "(python_full_version >= '3.14' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')",
+ "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'",
"python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'emscripten'",
"python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'emscripten'",
+ "(python_full_version >= '3.14' and platform_machine == 'x86_64' and sys_platform == 'darwin') or (python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'win32')",
+ "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'win32'",
]
dependencies = [
{ name = "absl-py", marker = "python_full_version >= '3.14'" },
@@ -1043,18 +1055,24 @@ name = "dm-tree"
version = "0.1.10"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
- "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'linux'",
+ "python_full_version == '3.13.*' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'",
"python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'linux'",
- "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'linux'",
+ "python_full_version < '3.13' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'",
"python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'linux'",
- "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'",
- "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'",
- "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'",
- "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'",
+ "(python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.13.*' and platform_machine == 'arm64' and sys_platform == 'linux') or (python_full_version == '3.13.*' and platform_machine == 'armv7l' and sys_platform == 'linux')",
+ "(python_full_version < '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.13' and platform_machine == 'arm64' and sys_platform == 'linux') or (python_full_version < '3.13' and platform_machine == 'armv7l' and sys_platform == 'linux')",
+ "(python_full_version == '3.13.*' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')",
+ "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'",
+ "(python_full_version < '3.13' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')",
+ "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'",
"python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'emscripten'",
"python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'emscripten'",
"python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'",
"python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'",
+ "(python_full_version == '3.13.*' and platform_machine == 'x86_64' and sys_platform == 'darwin') or (python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'win32')",
+ "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'win32'",
+ "(python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'darwin') or (python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32')",
+ "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32'",
]
dependencies = [
{ name = "absl-py", marker = "python_full_version < '3.14'" },
@@ -2187,37 +2205,33 @@ name = "lerobot"
version = "0.5.2"
source = { editable = "." }
dependencies = [
- { name = "accelerate" },
- { name = "av" },
{ name = "cmake" },
- { name = "datasets" },
- { name = "deepdiff" },
- { name = "diffusers" },
{ name = "draccus" },
{ name = "einops" },
{ name = "gymnasium" },
{ name = "huggingface-hub" },
- { name = "imageio", extra = ["ffmpeg"] },
- { name = "jsonlines" },
{ name = "numpy" },
{ name = "opencv-python-headless" },
{ name = "packaging" },
- { name = "pynput" },
- { name = "pyserial" },
- { name = "rerun-sdk" },
+ { name = "pillow" },
+ { name = "requests" },
+ { name = "safetensors" },
{ name = "setuptools" },
{ name = "termcolor" },
{ name = "torch" },
- { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
{ name = "torchvision" },
- { name = "wandb" },
+ { name = "tqdm" },
]
[package.optional-dependencies]
all = [
{ name = "accelerate" },
+ { name = "av" },
{ name = "contourpy" },
+ { name = "datasets" },
{ name = "debugpy" },
+ { name = "deepdiff" },
+ { name = "diffusers" },
{ name = "dynamixel-sdk" },
{ name = "faker" },
{ name = "fastapi" },
@@ -2230,6 +2244,7 @@ all = [
{ name = "hebi-py" },
{ name = "hf-libero", marker = "sys_platform == 'linux'" },
{ name = "hidapi" },
+ { name = "jsonlines" },
{ name = "matplotlib" },
{ name = "metaworld" },
{ name = "mock-serial", marker = "sys_platform != 'win32'" },
@@ -2240,26 +2255,40 @@ all = [
{ name = "placo" },
{ name = "pre-commit" },
{ name = "protobuf" },
+ { name = "pyarrow" },
+ { name = "pydantic" },
{ name = "pygame" },
{ name = "pymunk" },
+ { name = "pynput" },
{ name = "pyrealsense2", marker = "sys_platform != 'darwin'" },
{ name = "pyrealsense2-macosx", marker = "sys_platform == 'darwin'" },
+ { name = "pyserial" },
{ name = "pytest" },
{ name = "pytest-cov" },
{ name = "pytest-timeout" },
+ { name = "python-can" },
{ name = "pyzmq" },
{ name = "qwen-vl-utils" },
{ name = "reachy2-sdk" },
- { name = "safetensors" },
+ { name = "rerun-sdk" },
+ { name = "ruff" },
{ name = "scikit-image" },
{ name = "scipy" },
{ name = "teleop" },
+ { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
{ name = "torchdiffeq" },
{ name = "transformers" },
+ { name = "wandb" },
]
aloha = [
+ { name = "av" },
+ { name = "datasets" },
{ name = "gym-aloha" },
+ { name = "jsonlines" },
+ { name = "pandas" },
+ { name = "pyarrow" },
{ name = "scipy" },
+ { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
]
async = [
{ name = "contourpy" },
@@ -2267,12 +2296,44 @@ async = [
{ name = "matplotlib" },
{ name = "protobuf" },
]
+av-dep = [
+ { name = "av" },
+]
can-dep = [
{ name = "python-can" },
]
+core-scripts = [
+ { name = "av" },
+ { name = "datasets" },
+ { name = "deepdiff" },
+ { name = "jsonlines" },
+ { name = "pandas" },
+ { name = "pyarrow" },
+ { name = "pynput" },
+ { name = "pyserial" },
+ { name = "rerun-sdk" },
+ { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
+]
damiao = [
{ name = "python-can" },
]
+dataset = [
+ { name = "av" },
+ { name = "datasets" },
+ { name = "jsonlines" },
+ { name = "pandas" },
+ { name = "pyarrow" },
+ { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
+]
+dataset-viz = [
+ { name = "av" },
+ { name = "datasets" },
+ { name = "jsonlines" },
+ { name = "pandas" },
+ { name = "pyarrow" },
+ { name = "rerun-sdk" },
+ { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
+]
dev = [
{ name = "debugpy" },
{ name = "grpcio" },
@@ -2280,10 +2341,20 @@ dev = [
{ name = "mypy" },
{ name = "pre-commit" },
{ name = "protobuf" },
+ { name = "ruff" },
+]
+diffusers-dep = [
+ { name = "diffusers" },
+]
+diffusion = [
+ { name = "diffusers" },
]
dynamixel = [
{ name = "dynamixel-sdk" },
]
+evaluation = [
+ { name = "av" },
+]
feetech = [
{ name = "feetech-servo-sdk" },
]
@@ -2293,13 +2364,12 @@ gamepad = [
]
groot = [
{ name = "decord", marker = "platform_machine == 'AMD64' or platform_machine == 'x86_64'" },
+ { name = "diffusers" },
{ name = "dm-tree", version = "0.1.9", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.14'" },
{ name = "dm-tree", version = "0.1.10", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.14'" },
{ name = "flash-attn", marker = "sys_platform != 'darwin'" },
{ name = "ninja" },
{ name = "peft" },
- { name = "pillow" },
- { name = "safetensors" },
{ name = "timm" },
{ name = "transformers" },
]
@@ -2307,6 +2377,11 @@ grpcio-dep = [
{ name = "grpcio" },
{ name = "protobuf" },
]
+hardware = [
+ { name = "deepdiff" },
+ { name = "pynput" },
+ { name = "pyserial" },
+]
hilserl = [
{ name = "grpcio" },
{ name = "gym-hil" },
@@ -2330,8 +2405,14 @@ lekiwi = [
{ name = "pyzmq" },
]
libero = [
+ { name = "av" },
+ { name = "datasets" },
{ name = "hf-libero", marker = "sys_platform == 'linux'" },
+ { name = "jsonlines" },
+ { name = "pandas" },
+ { name = "pyarrow" },
{ name = "scipy" },
+ { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
{ name = "transformers" },
]
matplotlib-dep = [
@@ -2339,10 +2420,17 @@ matplotlib-dep = [
{ name = "matplotlib" },
]
metaworld = [
+ { name = "av" },
+ { name = "datasets" },
+ { name = "jsonlines" },
{ name = "metaworld" },
+ { name = "pandas" },
+ { name = "pyarrow" },
{ name = "scipy" },
+ { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
]
multi-task-dit = [
+ { name = "diffusers" },
{ name = "transformers" },
]
openarms = [
@@ -2369,8 +2457,14 @@ placo-dep = [
{ name = "placo" },
]
pusht = [
+ { name = "av" },
+ { name = "datasets" },
{ name = "gym-pusht" },
+ { name = "jsonlines" },
+ { name = "pandas" },
+ { name = "pyarrow" },
{ name = "pymunk" },
+ { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
]
pygame-dep = [
{ name = "pygame" },
@@ -2388,6 +2482,7 @@ sarm = [
{ name = "contourpy" },
{ name = "faker" },
{ name = "matplotlib" },
+ { name = "pydantic" },
{ name = "qwen-vl-utils" },
{ name = "transformers" },
]
@@ -2397,7 +2492,6 @@ scipy-dep = [
smolvla = [
{ name = "accelerate" },
{ name = "num2words" },
- { name = "safetensors" },
{ name = "transformers" },
]
test = [
@@ -2406,6 +2500,16 @@ test = [
{ name = "pytest-cov" },
{ name = "pytest-timeout" },
]
+training = [
+ { name = "accelerate" },
+ { name = "av" },
+ { name = "datasets" },
+ { name = "jsonlines" },
+ { name = "pandas" },
+ { name = "pyarrow" },
+ { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" },
+ { name = "wandb" },
+]
transformers-dep = [
{ name = "transformers" },
]
@@ -2422,6 +2526,9 @@ video-benchmark = [
{ name = "pandas" },
{ name = "scikit-image" },
]
+viz = [
+ { name = "rerun-sdk" },
+]
wallx = [
{ name = "peft" },
{ name = "qwen-vl-utils" },
@@ -2435,16 +2542,16 @@ xvla = [
[package.metadata]
requires-dist = [
- { name = "accelerate", specifier = ">=1.10.0,<2.0.0" },
{ name = "accelerate", marker = "extra == 'smolvla'", specifier = ">=1.7.0,<2.0.0" },
- { name = "av", specifier = ">=15.0.0,<16.0.0" },
+ { name = "accelerate", marker = "extra == 'training'", specifier = ">=1.10.0,<2.0.0" },
+ { name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" },
{ name = "cmake", specifier = ">=3.29.0.1,<4.2.0" },
{ name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" },
- { name = "datasets", specifier = ">=4.0.0,<5.0.0" },
+ { name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" },
{ name = "debugpy", marker = "extra == 'dev'", specifier = ">=1.8.1,<1.9.0" },
{ name = "decord", marker = "(platform_machine == 'AMD64' and extra == 'groot') or (platform_machine == 'x86_64' and extra == 'groot')", specifier = ">=0.6.0,<1.0.0" },
- { name = "deepdiff", specifier = ">=7.0.1,<9.0.0" },
- { name = "diffusers", specifier = ">=0.27.2,<0.36.0" },
+ { name = "deepdiff", marker = "extra == 'hardware'", specifier = ">=7.0.1,<9.0.0" },
+ { name = "diffusers", marker = "extra == 'diffusers-dep'", specifier = ">=0.27.2,<0.36.0" },
{ name = "dm-tree", marker = "extra == 'groot'", specifier = ">=0.1.8,<1.0.0" },
{ name = "draccus", specifier = "==0.10.0" },
{ name = "dynamixel-sdk", marker = "extra == 'dynamixel'", specifier = ">=3.7.31,<3.9.0" },
@@ -2463,21 +2570,38 @@ requires-dist = [
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.3,<0.2.0" },
{ name = "hidapi", marker = "extra == 'gamepad'", specifier = ">=0.14.0,<0.15.0" },
{ name = "huggingface-hub", specifier = ">=1.0.0,<2.0.0" },
- { name = "imageio", extras = ["ffmpeg"], specifier = ">=2.34.0,<3.0.0" },
- { name = "jsonlines", specifier = ">=4.0.0,<5.0.0" },
+ { name = "jsonlines", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" },
{ name = "lerobot", extras = ["aloha"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["async"], marker = "extra == 'all'" },
+ { name = "lerobot", extras = ["av-dep"], marker = "extra == 'dataset'" },
+ { name = "lerobot", extras = ["av-dep"], marker = "extra == 'evaluation'" },
{ name = "lerobot", extras = ["can-dep"], marker = "extra == 'damiao'" },
{ name = "lerobot", extras = ["can-dep"], marker = "extra == 'robstride'" },
+ { name = "lerobot", extras = ["damiao"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["damiao"], marker = "extra == 'openarms'" },
+ { name = "lerobot", extras = ["dataset"], marker = "extra == 'all'" },
+ { name = "lerobot", extras = ["dataset"], marker = "extra == 'aloha'" },
+ { name = "lerobot", extras = ["dataset"], marker = "extra == 'core-scripts'" },
+ { name = "lerobot", extras = ["dataset"], marker = "extra == 'dataset-viz'" },
+ { name = "lerobot", extras = ["dataset"], marker = "extra == 'libero'" },
+ { name = "lerobot", extras = ["dataset"], marker = "extra == 'metaworld'" },
+ { name = "lerobot", extras = ["dataset"], marker = "extra == 'pusht'" },
+ { name = "lerobot", extras = ["dataset"], marker = "extra == 'training'" },
{ name = "lerobot", extras = ["dev"], marker = "extra == 'all'" },
+ { name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'diffusion'" },
+ { name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'groot'" },
+ { name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'multi-task-dit'" },
+ { name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" },
+ { name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'hopejr'" },
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'lekiwi'" },
{ name = "lerobot", extras = ["gamepad"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'async'" },
{ name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'dev'" },
{ name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'hilserl'" },
+ { name = "lerobot", extras = ["hardware"], marker = "extra == 'all'" },
+ { name = "lerobot", extras = ["hardware"], marker = "extra == 'core-scripts'" },
{ name = "lerobot", extras = ["hilserl"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["hopejr"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["intelrealsense"], marker = "extra == 'all'" },
@@ -2488,10 +2612,12 @@ requires-dist = [
{ name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'sarm'" },
{ name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'unitree-g1'" },
{ name = "lerobot", extras = ["metaworld"], marker = "extra == 'all'" },
+ { name = "lerobot", extras = ["multi-task-dit"], marker = "extra == 'all'" },
+ { name = "lerobot", extras = ["openarms"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["peft"], marker = "extra == 'all'" },
- { name = "lerobot", extras = ["peft"], marker = "extra == 'groot'" },
- { name = "lerobot", extras = ["peft"], marker = "extra == 'wallx'" },
+ { name = "lerobot", extras = ["peft-dep"], marker = "extra == 'groot'" },
{ name = "lerobot", extras = ["peft-dep"], marker = "extra == 'peft'" },
+ { name = "lerobot", extras = ["peft-dep"], marker = "extra == 'wallx'" },
{ name = "lerobot", extras = ["phone"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["pi"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["placo-dep"], marker = "extra == 'hilserl'" },
@@ -2503,6 +2629,7 @@ requires-dist = [
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'sarm'" },
{ name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'wallx'" },
{ name = "lerobot", extras = ["reachy2"], marker = "extra == 'all'" },
+ { name = "lerobot", extras = ["robstride"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["sarm"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'aloha'" },
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'libero'" },
@@ -2512,6 +2639,7 @@ requires-dist = [
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'wallx'" },
{ name = "lerobot", extras = ["smolvla"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["test"], marker = "extra == 'all'" },
+ { name = "lerobot", extras = ["training"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'hilserl'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'libero'" },
@@ -2523,6 +2651,9 @@ requires-dist = [
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'wallx'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'xvla'" },
{ name = "lerobot", extras = ["video-benchmark"], marker = "extra == 'all'" },
+ { name = "lerobot", extras = ["viz"], marker = "extra == 'all'" },
+ { name = "lerobot", extras = ["viz"], marker = "extra == 'core-scripts'" },
+ { name = "lerobot", extras = ["viz"], marker = "extra == 'dataset-viz'" },
{ name = "lerobot", extras = ["wallx"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["xvla"], marker = "extra == 'all'" },
{ name = "matplotlib", marker = "extra == 'matplotlib-dep'", specifier = ">=3.10.3,<4.0.0" },
@@ -2537,18 +2668,21 @@ requires-dist = [
{ name = "onnxruntime", marker = "extra == 'unitree-g1'", specifier = ">=1.16.0,<2.0.0" },
{ name = "opencv-python-headless", specifier = ">=4.9.0,<4.14.0" },
{ name = "packaging", specifier = ">=24.2,<26.0" },
+ { name = "pandas", marker = "extra == 'dataset'", specifier = ">=2.0.0,<3.0.0" },
{ name = "pandas", marker = "extra == 'video-benchmark'", specifier = ">=2.2.2,<2.4.0" },
{ name = "peft", marker = "extra == 'peft-dep'", specifier = ">=0.18.0,<1.0.0" },
- { name = "pillow", marker = "extra == 'groot'", specifier = ">=10.0.0,<13.0.0" },
+ { name = "pillow", specifier = ">=10.0.0,<13.0.0" },
{ name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.17" },
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.7.0,<5.0.0" },
{ name = "protobuf", marker = "extra == 'grpcio-dep'", specifier = ">=6.31.1,<6.32.0" },
+ { name = "pyarrow", marker = "extra == 'dataset'", specifier = ">=21.0.0,<30.0.0" },
+ { name = "pydantic", marker = "extra == 'sarm'", specifier = ">=2.0.0,<3.0.0" },
{ name = "pygame", marker = "extra == 'pygame-dep'", specifier = ">=2.5.1,<2.7.0" },
{ name = "pymunk", marker = "extra == 'pusht'", specifier = ">=6.6.0,<7.0.0" },
- { name = "pynput", specifier = ">=1.7.8,<1.9.0" },
+ { name = "pynput", marker = "extra == 'hardware'", specifier = ">=1.7.8,<1.9.0" },
{ name = "pyrealsense2", marker = "sys_platform != 'darwin' and extra == 'intelrealsense'", specifier = ">=2.55.1.6486,<2.57.0" },
{ name = "pyrealsense2-macosx", marker = "sys_platform == 'darwin' and extra == 'intelrealsense'", specifier = ">=2.54,<2.57.0" },
- { name = "pyserial", specifier = ">=3.5,<4.0" },
+ { name = "pyserial", marker = "extra == 'hardware'", specifier = ">=3.5,<4.0" },
{ name = "pytest", marker = "extra == 'test'", specifier = ">=8.1.0,<9.0.0" },
{ name = "pytest-cov", marker = "extra == 'test'", specifier = ">=5.0.0,<8.0.0" },
{ name = "pytest-timeout", marker = "extra == 'test'", specifier = ">=2.4.0,<3.0.0" },
@@ -2557,9 +2691,10 @@ requires-dist = [
{ name = "pyzmq", marker = "extra == 'unitree-g1'", specifier = ">=26.2.1,<28.0.0" },
{ name = "qwen-vl-utils", marker = "extra == 'qwen-vl-utils-dep'", specifier = ">=0.0.11,<0.1.0" },
{ name = "reachy2-sdk", marker = "extra == 'reachy2'", specifier = ">=1.0.15,<1.1.0" },
- { name = "rerun-sdk", specifier = ">=0.24.0,<0.27.0" },
- { name = "safetensors", marker = "extra == 'groot'", specifier = ">=0.4.3,<1.0.0" },
- { name = "safetensors", marker = "extra == 'smolvla'", specifier = ">=0.4.3,<1.0.0" },
+ { name = "requests", specifier = ">=2.32.0,<3.0.0" },
+ { name = "rerun-sdk", marker = "extra == 'viz'", specifier = ">=0.24.0,<0.27.0" },
+ { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.14.1" },
+ { name = "safetensors", specifier = ">=0.4.3,<1.0.0" },
{ name = "scikit-image", marker = "extra == 'video-benchmark'", specifier = ">=0.23.2,<0.26.0" },
{ name = "scipy", marker = "extra == 'all'", specifier = ">=1.14.0,<2.0.0" },
{ name = "scipy", marker = "extra == 'scipy-dep'", specifier = ">=1.14.0,<2.0.0" },
@@ -2568,13 +2703,14 @@ requires-dist = [
{ name = "termcolor", specifier = ">=2.4.0,<4.0.0" },
{ name = "timm", marker = "extra == 'groot'", specifier = ">=1.0.0,<1.1.0" },
{ name = "torch", specifier = ">=2.7,<2.11.0" },
- { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", specifier = ">=0.3.0,<0.11.0" },
+ { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux' and extra == 'dataset') or (platform_machine != 'x86_64' and sys_platform == 'darwin' and extra == 'dataset') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'dataset')", specifier = ">=0.3.0,<0.11.0" },
{ name = "torchdiffeq", marker = "extra == 'wallx'", specifier = ">=0.2.4,<0.3.0" },
{ name = "torchvision", specifier = ">=0.22.0,<0.26.0" },
+ { name = "tqdm", specifier = ">=4.66.0,<5.0.0" },
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = "==5.3.0" },
- { name = "wandb", specifier = ">=0.24.0,<0.25.0" },
+ { name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
]
-provides-extras = ["pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "qwen-vl-utils-dep", "matplotlib-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "kinematics", "intelrealsense", "phone", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "hilserl", "async", "peft", "dev", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
+provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "hilserl", "async", "peft", "dev", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"]
[[package]]
name = "librt"
@@ -3359,7 +3495,7 @@ name = "nvidia-cudnn-cu12"
version = "9.10.2.21"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" },
+ { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" },
@@ -3370,7 +3506,7 @@ name = "nvidia-cufft-cu12"
version = "11.3.3.83"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
+ { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" },
@@ -3397,9 +3533,9 @@ name = "nvidia-cusolver-cu12"
version = "11.7.3.90"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" },
- { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" },
- { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
+ { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" },
+ { name = "nvidia-cusparse-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" },
+ { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" },
@@ -3410,7 +3546,7 @@ name = "nvidia-cusparse-cu12"
version = "12.5.8.93"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" },
+ { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" },
@@ -3677,7 +3813,7 @@ name = "pexpect"
version = "4.9.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "ptyprocess", marker = "sys_platform != 'emscripten'" },
+ { name = "ptyprocess", marker = "(platform_machine != 's390x' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/42/92/cc564bf6381ff43ce1f4d06852fc19a2f11d180f23dc32d9588bee2f149d/pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f", size = 166450, upload-time = "2023-11-25T09:07:26.339Z" }
wheels = [
@@ -4231,10 +4367,10 @@ name = "pyobjc-framework-applicationservices"
version = "12.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "pyobjc-core", marker = "sys_platform != 'linux'" },
- { name = "pyobjc-framework-cocoa", marker = "sys_platform != 'linux'" },
- { name = "pyobjc-framework-coretext", marker = "sys_platform != 'linux'" },
- { name = "pyobjc-framework-quartz", marker = "sys_platform != 'linux'" },
+ { name = "pyobjc-core", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" },
+ { name = "pyobjc-framework-cocoa", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" },
+ { name = "pyobjc-framework-coretext", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" },
+ { name = "pyobjc-framework-quartz", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/be/6a/d4e613c8e926a5744fc47a9e9fea08384a510dc4f27d844f7ad7a2d793bd/pyobjc_framework_applicationservices-12.1.tar.gz", hash = "sha256:c06abb74f119bc27aeb41bf1aef8102c0ae1288aec1ac8665ea186a067a8945b", size = 103247, upload-time = "2025-11-14T10:08:52.18Z" }
wheels = [
@@ -4250,7 +4386,7 @@ name = "pyobjc-framework-cocoa"
version = "12.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "pyobjc-core", marker = "sys_platform != 'linux'" },
+ { name = "pyobjc-core", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/02/a3/16ca9a15e77c061a9250afbae2eae26f2e1579eb8ca9462ae2d2c71e1169/pyobjc_framework_cocoa-12.1.tar.gz", hash = "sha256:5556c87db95711b985d5efdaaf01c917ddd41d148b1e52a0c66b1a2e2c5c1640", size = 2772191, upload-time = "2025-11-14T10:13:02.069Z" }
wheels = [
@@ -4266,9 +4402,9 @@ name = "pyobjc-framework-coretext"
version = "12.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "pyobjc-core", marker = "sys_platform != 'linux'" },
- { name = "pyobjc-framework-cocoa", marker = "sys_platform != 'linux'" },
- { name = "pyobjc-framework-quartz", marker = "sys_platform != 'linux'" },
+ { name = "pyobjc-core", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" },
+ { name = "pyobjc-framework-cocoa", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" },
+ { name = "pyobjc-framework-quartz", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/29/da/682c9c92a39f713bd3c56e7375fa8f1b10ad558ecb075258ab6f1cdd4a6d/pyobjc_framework_coretext-12.1.tar.gz", hash = "sha256:e0adb717738fae395dc645c9e8a10bb5f6a4277e73cba8fa2a57f3b518e71da5", size = 90124, upload-time = "2025-11-14T10:14:38.596Z" }
wheels = [
@@ -4284,8 +4420,8 @@ name = "pyobjc-framework-quartz"
version = "12.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "pyobjc-core", marker = "sys_platform != 'linux'" },
- { name = "pyobjc-framework-cocoa", marker = "sys_platform != 'linux'" },
+ { name = "pyobjc-core", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" },
+ { name = "pyobjc-framework-cocoa", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/94/18/cc59f3d4355c9456fc945eae7fe8797003c4da99212dd531ad1b0de8a0c6/pyobjc_framework_quartz-12.1.tar.gz", hash = "sha256:27f782f3513ac88ec9b6c82d9767eef95a5cf4175ce88a1e5a65875fee799608", size = 3159099, upload-time = "2025-11-14T10:21:24.31Z" }
wheels = [
@@ -4888,6 +5024,31 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/85/70/92482ccffb96f5441aab93e26c4d66489eb599efdcf96fad90c14bbfb976/rpds_py-0.30.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:dbd936cde57abfee19ab3213cf9c26be06d60750e60a8e4dd85d1ab12c8b1f40", size = 556030, upload-time = "2025-11-30T20:24:10.956Z" },
]
+[[package]]
+name = "ruff"
+version = "0.15.10"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e7/d9/aa3f7d59a10ef6b14fe3431706f854dbf03c5976be614a9796d36326810c/ruff-0.15.10.tar.gz", hash = "sha256:d1f86e67ebfdef88e00faefa1552b5e510e1d35f3be7d423dc7e84e63788c94e", size = 4631728, upload-time = "2026-04-09T14:06:09.884Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/eb/00/a1c2fdc9939b2c03691edbda290afcd297f1f389196172826b03d6b6a595/ruff-0.15.10-py3-none-linux_armv6l.whl", hash = "sha256:0744e31482f8f7d0d10a11fcbf897af272fefdfcb10f5af907b18c2813ff4d5f", size = 10563362, upload-time = "2026-04-09T14:06:21.189Z" },
+ { url = "https://files.pythonhosted.org/packages/5c/15/006990029aea0bebe9d33c73c3e28c80c391ebdba408d1b08496f00d422d/ruff-0.15.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b1e7c16ea0ff5a53b7c2df52d947e685973049be1cdfe2b59a9c43601897b22e", size = 10951122, upload-time = "2026-04-09T14:06:02.236Z" },
+ { url = "https://files.pythonhosted.org/packages/f2/c0/4ac978fe874d0618c7da647862afe697b281c2806f13ce904ad652fa87e4/ruff-0.15.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:93cc06a19e5155b4441dd72808fdf84290d84ad8a39ca3b0f994363ade4cebb1", size = 10314005, upload-time = "2026-04-09T14:06:00.026Z" },
+ { url = "https://files.pythonhosted.org/packages/da/73/c209138a5c98c0d321266372fc4e33ad43d506d7e5dd817dd89b60a8548f/ruff-0.15.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83e1dd04312997c99ea6965df66a14fb4f03ba978564574ffc68b0d61fd3989e", size = 10643450, upload-time = "2026-04-09T14:05:42.137Z" },
+ { url = "https://files.pythonhosted.org/packages/ec/76/0deec355d8ec10709653635b1f90856735302cb8e149acfdf6f82a5feb70/ruff-0.15.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8154d43684e4333360fedd11aaa40b1b08a4e37d8ffa9d95fee6fa5b37b6fab1", size = 10379597, upload-time = "2026-04-09T14:05:49.984Z" },
+ { url = "https://files.pythonhosted.org/packages/dc/be/86bba8fc8798c081e28a4b3bb6d143ccad3fd5f6f024f02002b8f08a9fa3/ruff-0.15.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ab88715f3a6deb6bde6c227f3a123410bec7b855c3ae331b4c006189e895cef", size = 11146645, upload-time = "2026-04-09T14:06:12.246Z" },
+ { url = "https://files.pythonhosted.org/packages/a8/89/140025e65911b281c57be1d385ba1d932c2366ca88ae6663685aed8d4881/ruff-0.15.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a768ff5969b4f44c349d48edf4ab4f91eddb27fd9d77799598e130fb628aa158", size = 12030289, upload-time = "2026-04-09T14:06:04.776Z" },
+ { url = "https://files.pythonhosted.org/packages/88/de/ddacca9545a5e01332567db01d44bd8cf725f2db3b3d61a80550b48308ea/ruff-0.15.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ee3ef42dab7078bda5ff6a1bcba8539e9857deb447132ad5566a038674540d0", size = 11496266, upload-time = "2026-04-09T14:05:55.485Z" },
+ { url = "https://files.pythonhosted.org/packages/bc/bb/7ddb00a83760ff4a83c4e2fc231fd63937cc7317c10c82f583302e0f6586/ruff-0.15.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51cb8cc943e891ba99989dd92d61e29b1d231e14811db9be6440ecf25d5c1609", size = 11256418, upload-time = "2026-04-09T14:05:57.69Z" },
+ { url = "https://files.pythonhosted.org/packages/dc/8d/55de0d35aacf6cd50b6ee91ee0f291672080021896543776f4170fc5c454/ruff-0.15.10-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:e59c9bdc056a320fb9ea1700a8d591718b8faf78af065484e801258d3a76bc3f", size = 11288416, upload-time = "2026-04-09T14:05:44.695Z" },
+ { url = "https://files.pythonhosted.org/packages/68/cf/9438b1a27426ec46a80e0a718093c7f958ef72f43eb3111862949ead3cc1/ruff-0.15.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:136c00ca2f47b0018b073f28cb5c1506642a830ea941a60354b0e8bc8076b151", size = 10621053, upload-time = "2026-04-09T14:05:52.782Z" },
+ { url = "https://files.pythonhosted.org/packages/4c/50/e29be6e2c135e9cd4cb15fbade49d6a2717e009dff3766dd080fcb82e251/ruff-0.15.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8b80a2f3c9c8a950d6237f2ca12b206bccff626139be9fa005f14feb881a1ae8", size = 10378302, upload-time = "2026-04-09T14:06:14.361Z" },
+ { url = "https://files.pythonhosted.org/packages/18/2f/e0b36a6f99c51bb89f3a30239bc7bf97e87a37ae80aa2d6542d6e5150364/ruff-0.15.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:e3e53c588164dc025b671c9df2462429d60357ea91af7e92e9d56c565a9f1b07", size = 10850074, upload-time = "2026-04-09T14:06:16.581Z" },
+ { url = "https://files.pythonhosted.org/packages/11/08/874da392558ce087a0f9b709dc6ec0d60cbc694c1c772dab8d5f31efe8cb/ruff-0.15.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b0c52744cf9f143a393e284125d2576140b68264a93c6716464e129a3e9adb48", size = 11358051, upload-time = "2026-04-09T14:06:18.948Z" },
+ { url = "https://files.pythonhosted.org/packages/e4/46/602938f030adfa043e67112b73821024dc79f3ab4df5474c25fa4c1d2d14/ruff-0.15.10-py3-none-win32.whl", hash = "sha256:d4272e87e801e9a27a2e8df7b21011c909d9ddd82f4f3281d269b6ba19789ca5", size = 10588964, upload-time = "2026-04-09T14:06:07.14Z" },
+ { url = "https://files.pythonhosted.org/packages/25/b6/261225b875d7a13b33a6d02508c39c28450b2041bb01d0f7f1a83d569512/ruff-0.15.10-py3-none-win_amd64.whl", hash = "sha256:28cb32d53203242d403d819fd6983152489b12e4a3ae44993543d6fe62ab42ed", size = 11745044, upload-time = "2026-04-09T14:05:39.473Z" },
+ { url = "https://files.pythonhosted.org/packages/58/ed/dea90a65b7d9e69888890fb14c90d7f51bf0c1e82ad800aeb0160e4bacfd/ruff-0.15.10-py3-none-win_arm64.whl", hash = "sha256:601d1610a9e1f1c2165a4f561eeaa2e2ea1e97f3287c5aa258d3dab8b57c6188", size = 11035607, upload-time = "2026-04-09T14:05:47.593Z" },
+]
+
[[package]]
name = "safetensors"
version = "0.7.0"