From df0763a2bc8153ae69ab360af39673a328004b33 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Sun, 12 Apr 2026 20:03:04 +0200 Subject: [PATCH] feat(dependencies): minimal default tag install (#3362) --- .github/workflows/fast_tests.yml | 38 ++- docs/source/adding_benchmarks.mdx | 2 +- docs/source/async.mdx | 2 +- docs/source/backwardcomp.mdx | 2 +- docs/source/bring_your_own_policies.mdx | 8 +- docs/source/cameras.mdx | 10 +- docs/source/dataset_subtask.mdx | 13 +- docs/source/earthrover_mini_plus.mdx | 4 +- docs/source/env_processor.mdx | 6 +- docs/source/envhub.mdx | 6 +- docs/source/envhub_isaaclab_arena.mdx | 6 +- docs/source/envhub_leisaac.mdx | 6 +- docs/source/il_robots.mdx | 41 ++- docs/source/installation.mdx | 60 +++- docs/source/introduction_processors.mdx | 8 +- docs/source/lerobot-dataset-v3.mdx | 10 +- docs/source/multi_gpu_training.mdx | 4 +- docs/source/phone_teleop.mdx | 3 +- docs/source/pi0.mdx | 3 +- docs/source/pi05.mdx | 3 +- docs/source/policy_pi0_README.md | 5 +- docs/source/rtc.mdx | 5 +- docs/source/xvla.mdx | 2 +- examples/backward_compatibility/replay.py | 2 +- examples/dataset/load_lerobot_dataset.py | 10 +- examples/dataset/slurm_compute_rabc.py | 2 +- .../dataset/use_dataset_image_transforms.py | 4 +- examples/dataset/use_dataset_tools.py | 4 +- examples/hil/hil_data_collection.py | 39 +-- examples/hil/hil_utils.py | 4 +- examples/lekiwi/evaluate.py | 10 +- examples/lekiwi/record.py | 9 +- examples/lekiwi/replay.py | 5 +- examples/phone_to_so100/evaluate.py | 17 +- examples/phone_to_so100/record.py | 17 +- examples/phone_to_so100/replay.py | 6 +- examples/phone_to_so100/teleoperate.py | 8 +- examples/port_datasets/port_droid.py | 3 +- .../port_datasets/slurm_aggregate_shards.py | 2 +- examples/port_datasets/slurm_upload.py | 5 +- examples/rtc/eval_dataset.py | 13 +- examples/rtc/eval_with_real_robot.py | 18 +- examples/so100_to_so100_EE/evaluate.py | 17 +- examples/so100_to_so100_EE/record.py | 13 +- examples/so100_to_so100_EE/replay.py | 6 +- examples/so100_to_so100_EE/teleoperate.py | 4 +- examples/training/train_policy.py | 12 +- examples/training/train_with_streaming.py | 12 +- examples/tutorial/act/act_training_example.py | 12 +- examples/tutorial/act/act_using_example.py | 8 +- examples/tutorial/async-inf/robot_client.py | 2 +- .../diffusion/diffusion_training_example.py | 12 +- .../diffusion/diffusion_using_example.py | 8 +- examples/tutorial/pi0/using_pi0_example.py | 8 +- examples/tutorial/rl/hilserl_example.py | 8 +- .../tutorial/rl/reward_classifier_example.py | 5 +- .../tutorial/smolvla/using_smolvla_example.py | 8 +- pyproject.toml | 121 ++++--- src/lerobot/__init__.py | 201 ++---------- src/lerobot/async_inference/__init__.py | 30 ++ src/lerobot/async_inference/helpers.py | 4 +- src/lerobot/async_inference/policy_server.py | 2 +- src/lerobot/async_inference/robot_client.py | 4 +- src/lerobot/cameras/__init__.py | 6 + .../cameras/reachy2_camera/__init__.py | 2 + src/lerobot/cameras/realsense/__init__.py | 2 + src/lerobot/cameras/zmq/image_server.py | 4 +- src/lerobot/common/__init__.py | 30 ++ .../{utils => common}/control_utils.py | 22 +- src/lerobot/{utils => common}/train_utils.py | 12 +- src/lerobot/{rl => common}/wandb_utils.py | 0 src/lerobot/configs/__init__.py | 47 +++ src/lerobot/configs/default.py | 4 +- src/lerobot/configs/eval.py | 5 +- src/lerobot/configs/policies.py | 6 +- src/lerobot/configs/train.py | 8 +- src/lerobot/data_processing/__init__.py | 10 + .../sarm_annotations/__init__.py | 10 + .../sarm_annotations/subtask_annotation.py | 7 +- src/lerobot/datasets/__init__.py | 65 +++- src/lerobot/datasets/aggregate.py | 12 +- src/lerobot/datasets/compute_stats.py | 6 +- src/lerobot/datasets/dataset_metadata.py | 17 +- src/lerobot/datasets/dataset_reader.py | 8 +- src/lerobot/datasets/dataset_tools.py | 26 +- src/lerobot/datasets/dataset_writer.py | 14 +- src/lerobot/datasets/factory.py | 18 +- src/lerobot/datasets/feature_utils.py | 203 +----------- src/lerobot/datasets/io_utils.py | 36 +-- src/lerobot/datasets/lerobot_dataset.py | 13 +- src/lerobot/datasets/multi_dataset.py | 9 +- src/lerobot/datasets/pipeline_features.py | 4 +- src/lerobot/datasets/streaming_dataset.py | 13 +- src/lerobot/datasets/utils.py | 86 +---- src/lerobot/datasets/video_utils.py | 19 +- src/lerobot/envs/__init__.py | 25 +- src/lerobot/envs/configs.py | 15 +- src/lerobot/envs/factory.py | 4 +- src/lerobot/envs/libero.py | 3 +- src/lerobot/envs/metaworld.py | 3 +- src/lerobot/envs/utils.py | 5 +- src/lerobot/model/__init__.py | 19 ++ src/lerobot/motors/__init__.py | 2 + src/lerobot/motors/damiao/__init__.py | 4 +- src/lerobot/motors/dynamixel/__init__.py | 4 +- src/lerobot/motors/dynamixel/dynamixel.py | 41 +-- src/lerobot/motors/feetech/__init__.py | 4 +- src/lerobot/motors/feetech/feetech.py | 45 ++- src/lerobot/motors/motors_bus.py | 18 +- src/lerobot/motors/robstride/__init__.py | 4 +- src/lerobot/optim/__init__.py | 43 ++- src/lerobot/optim/factory.py | 2 +- src/lerobot/optim/optimizers.py | 5 +- src/lerobot/optim/schedulers.py | 6 +- src/lerobot/policies/__init__.py | 35 +- src/lerobot/policies/act/__init__.py | 19 ++ src/lerobot/policies/act/configuration_act.py | 5 +- src/lerobot/policies/act/modeling_act.py | 5 +- src/lerobot/policies/act/processor_act.py | 6 +- src/lerobot/policies/diffusion/__init__.py | 19 ++ .../diffusion/configuration_diffusion.py | 6 +- .../policies/diffusion/modeling_diffusion.py | 19 +- .../policies/diffusion/processor_diffusion.py | 6 +- src/lerobot/policies/factory.py | 116 ++++--- .../policies/groot/action_head/__init__.py | 2 + .../groot/action_head/cross_attention_dit.py | 33 +- .../action_head/flow_matching_action_head.py | 3 +- .../policies/groot/configuration_groot.py | 6 +- src/lerobot/policies/groot/groot_n1.py | 7 +- src/lerobot/policies/groot/modeling_groot.py | 9 +- src/lerobot/policies/groot/processor_groot.py | 7 +- .../configuration_multi_task_dit.py | 6 +- .../multi_task_dit/modeling_multi_task_dit.py | 16 +- .../processor_multi_task_dit.py | 6 +- src/lerobot/policies/pi0/configuration_pi0.py | 9 +- src/lerobot/policies/pi0/modeling_pi0.py | 11 +- src/lerobot/policies/pi0/processor_pi0.py | 8 +- .../policies/pi05/configuration_pi05.py | 9 +- src/lerobot/policies/pi05/modeling_pi05.py | 11 +- src/lerobot/policies/pi05/processor_pi05.py | 8 +- .../pi0_fast/configuration_pi0_fast.py | 9 +- .../policies/pi0_fast/modeling_pi0_fast.py | 11 +- .../policies/pi0_fast/processor_pi0_fast.py | 8 +- src/lerobot/policies/pretrained.py | 5 +- src/lerobot/policies/rtc/__init__.py | 10 +- src/lerobot/policies/rtc/action_queue.py | 2 +- src/lerobot/policies/rtc/configuration_rtc.py | 2 +- src/lerobot/policies/rtc/modeling_rtc.py | 7 +- src/lerobot/policies/sac/__init__.py | 19 ++ src/lerobot/policies/sac/configuration_sac.py | 5 +- src/lerobot/policies/sac/modeling_sac.py | 7 +- src/lerobot/policies/sac/processor_sac.py | 6 +- .../policies/sac/reward_model/__init__.py | 19 ++ .../reward_model/configuration_classifier.py | 6 +- .../sac/reward_model/modeling_classifier.py | 5 +- .../sac/reward_model/processor_classifier.py | 6 +- src/lerobot/policies/sarm/__init__.py | 18 ++ .../policies/sarm/compute_rabc_weights.py | 9 +- .../policies/sarm/configuration_sarm.py | 6 +- src/lerobot/policies/sarm/modeling_sarm.py | 9 +- src/lerobot/policies/sarm/processor_sarm.py | 70 ++-- src/lerobot/policies/smolvla/__init__.py | 19 ++ .../policies/smolvla/configuration_smolvla.py | 11 +- .../policies/smolvla/modeling_smolvla.py | 15 +- .../policies/smolvla/processor_smolvla.py | 50 +-- .../policies/smolvla/smolvlm_with_expert.py | 26 +- src/lerobot/policies/tdmpc/__init__.py | 19 ++ .../policies/tdmpc/configuration_tdmpc.py | 5 +- src/lerobot/policies/tdmpc/modeling_tdmpc.py | 7 +- src/lerobot/policies/tdmpc/processor_tdmpc.py | 6 +- src/lerobot/policies/utils.py | 5 +- src/lerobot/policies/vqbet/__init__.py | 19 ++ .../policies/vqbet/configuration_vqbet.py | 6 +- src/lerobot/policies/vqbet/modeling_vqbet.py | 9 +- src/lerobot/policies/vqbet/processor_vqbet.py | 6 +- src/lerobot/policies/vqbet/vqbet_utils.py | 2 +- src/lerobot/policies/wall_x/__init__.py | 2 + .../policies/wall_x/configuration_wall_x.py | 6 +- .../policies/wall_x/modeling_wall_x.py | 84 +++-- .../policies/wall_x/processor_wall_x.py | 8 +- src/lerobot/policies/wall_x/utils.py | 15 +- src/lerobot/policies/xvla/__init__.py | 13 +- .../policies/xvla/configuration_xvla.py | 6 +- src/lerobot/policies/xvla/modeling_xvla.py | 19 +- src/lerobot/policies/xvla/processor_xvla.py | 12 +- src/lerobot/processor/__init__.py | 26 +- src/lerobot/processor/batch_processor.py | 2 +- .../processor/delta_action_processor.py | 2 +- src/lerobot/processor/device_processor.py | 2 +- src/lerobot/processor/env_processor.py | 2 +- src/lerobot/processor/gym_action_processor.py | 6 +- src/lerobot/processor/hil_processor.py | 2 +- .../processor/migrate_policy_normalization.py | 4 +- .../processor/newline_task_processor.py | 59 ++++ src/lerobot/processor/normalize_processor.py | 9 +- .../processor/observation_processor.py | 2 +- src/lerobot/processor/pipeline.py | 5 +- src/lerobot/processor/policy_robot_bridge.py | 6 +- .../processor/relative_action_processor.py | 2 +- src/lerobot/processor/rename_processor.py | 2 +- src/lerobot/processor/tokenizer_processor.py | 2 +- src/lerobot/rl/__init__.py | 34 ++ src/lerobot/rl/actor.py | 6 +- src/lerobot/rl/buffer.py | 2 +- src/lerobot/rl/crop_dataset_roi.py | 2 +- src/lerobot/rl/eval_policy.py | 4 +- src/lerobot/rl/gym_manipulator.py | 8 +- .../rl/joint_observations_processor.py | 4 +- src/lerobot/rl/learner.py | 25 +- src/lerobot/rl/learner_service.py | 3 +- src/lerobot/robots/__init__.py | 2 + .../bi_openarm_follower.py | 2 +- .../config_bi_openarm_follower.py | 2 +- src/lerobot/robots/bi_so_follower/__init__.py | 2 + .../robots/bi_so_follower/bi_so_follower.py | 2 +- .../bi_so_follower/config_bi_so_follower.py | 3 +- src/lerobot/robots/hope_jr/__init__.py | 2 + src/lerobot/robots/hope_jr/hope_jr_arm.py | 2 +- src/lerobot/robots/hope_jr/hope_jr_hand.py | 2 +- src/lerobot/robots/koch_follower/__init__.py | 2 + .../robots/koch_follower/koch_follower.py | 2 +- src/lerobot/robots/lekiwi/__init__.py | 2 + src/lerobot/robots/lekiwi/config_lekiwi.py | 4 +- src/lerobot/robots/lekiwi/lekiwi.py | 2 +- src/lerobot/robots/omx_follower/__init__.py | 2 + .../robots/omx_follower/omx_follower.py | 2 +- .../openarm_follower/openarm_follower.py | 2 +- src/lerobot/robots/reachy2/__init__.py | 10 + .../robots/reachy2/configuration_reachy2.py | 3 +- src/lerobot/robots/reachy2/robot_reachy2.py | 2 +- src/lerobot/robots/so_follower/__init__.py | 10 + .../so_follower/robot_kinematic_processor.py | 4 +- src/lerobot/robots/so_follower/so_follower.py | 2 +- .../robots/unitree_g1/gr00t_locomotion.py | 2 +- .../robots/unitree_g1/holosoma_locomotion.py | 2 +- src/lerobot/robots/unitree_g1/unitree_g1.py | 20 +- .../robots/unitree_g1/unitree_sdk2_socket.py | 2 +- .../scripts/augment_dataset_quantile_stats.py | 12 +- .../scripts/convert_dataset_v21_to_v30.py | 11 +- src/lerobot/scripts/lerobot_calibrate.py | 6 +- src/lerobot/scripts/lerobot_dataset_viz.py | 10 +- src/lerobot/scripts/lerobot_edit_dataset.py | 6 +- src/lerobot/scripts/lerobot_eval.py | 11 +- src/lerobot/scripts/lerobot_find_cameras.py | 8 +- .../scripts/lerobot_find_joint_limits.py | 2 +- src/lerobot/scripts/lerobot_find_port.py | 5 +- .../scripts/lerobot_imgtransform_viz.py | 6 +- src/lerobot/scripts/lerobot_record.py | 57 ++-- src/lerobot/scripts/lerobot_replay.py | 4 +- src/lerobot/scripts/lerobot_teleoperate.py | 14 +- src/lerobot/scripts/lerobot_train.py | 50 +-- .../scripts/lerobot_train_tokenizer.py | 5 +- src/lerobot/teleoperators/__init__.py | 2 + .../bi_openarm_leader/bi_openarm_leader.py | 3 +- .../config_bi_openarm_leader.py | 3 +- .../teleoperators/bi_so_leader/__init__.py | 2 + .../bi_so_leader/bi_so_leader.py | 3 +- .../bi_so_leader/config_bi_so_leader.py | 3 +- src/lerobot/teleoperators/gamepad/__init__.py | 2 + .../teleoperators/homunculus/__init__.py | 8 + .../homunculus/homunculus_arm.py | 10 +- .../homunculus/homunculus_glove.py | 12 +- .../teleoperators/keyboard/teleop_keyboard.py | 27 +- .../teleoperators/koch_leader/__init__.py | 2 + .../teleoperators/omx_leader/__init__.py | 2 + src/lerobot/teleoperators/phone/__init__.py | 2 + .../teleoperators/phone/phone_processor.py | 5 +- .../teleoperators/phone/teleop_phone.py | 5 +- .../reachy2_teleoperator/__init__.py | 10 + .../teleoperators/so_leader/__init__.py | 10 + .../teleoperators/unitree_g1/exo_calib.py | 13 +- .../teleoperators/unitree_g1/exo_serial.py | 11 +- src/lerobot/transforms/__init__.py | 31 ++ .../{datasets => transforms}/transforms.py | 0 src/lerobot/transport/__init__.py | 29 ++ src/lerobot/utils/__init__.py | 65 ++++ src/lerobot/utils/constants.py | 15 + src/lerobot/utils/decorators.py | 2 +- src/lerobot/utils/feature_utils.py | 223 +++++++++++++ src/lerobot/utils/import_utils.py | 51 +++ src/lerobot/utils/io_utils.py | 77 ++++- src/lerobot/utils/logging_utils.py | 2 +- src/lerobot/utils/random_utils.py | 4 +- src/lerobot/utils/transition.py | 2 +- src/lerobot/utils/utils.py | 93 +++++- src/lerobot/utils/visualization_utils.py | 19 +- .../save_image_transforms_to_safetensors.py | 2 +- .../policies/save_policy_to_safetensors.py | 2 +- tests/async_inference/test_e2e.py | 4 +- tests/async_inference/test_helpers.py | 10 +- tests/async_inference/test_policy_server.py | 4 +- tests/async_inference/test_robot_client.py | 4 +- tests/conftest.py | 27 +- tests/datasets/test_aggregate.py | 6 +- tests/datasets/test_compute_stats.py | 2 + tests/datasets/test_dataset_metadata.py | 2 + tests/datasets/test_dataset_reader.py | 6 +- tests/datasets/test_dataset_tools.py | 2 + tests/datasets/test_dataset_utils.py | 7 +- tests/datasets/test_dataset_writer.py | 2 + tests/datasets/test_datasets.py | 36 ++- tests/datasets/test_delta_timestamps.py | 2 + tests/datasets/test_image_transforms.py | 12 +- tests/datasets/test_image_writer.py | 2 + tests/datasets/test_lerobot_dataset.py | 2 + .../test_quantiles_dataset_integration.py | 2 + tests/datasets/test_sampler.py | 5 +- tests/datasets/test_streaming.py | 2 + .../datasets/test_streaming_video_encoder.py | 5 +- tests/datasets/test_subtask_dataset.py | 5 +- tests/datasets/test_visualize_dataset.py | 2 + tests/envs/test_envs.py | 12 +- tests/fixtures/dataset_factories.py | 4 +- tests/mocks/mock_dynamixel.py | 19 +- tests/mocks/mock_feetech.py | 18 +- tests/mocks/mock_motors_bus.py | 14 +- tests/motors/test_motors_bus.py | 2 + tests/optim/test_schedulers.py | 3 + tests/policies/groot/test_groot_lerobot.py | 2 +- .../hilserl/test_modeling_classifier.py | 10 +- tests/policies/smolvla/test_smolvla_rtc.py | 20 +- tests/policies/test_policies.py | 30 +- tests/policies/test_relative_actions.py | 2 + tests/processor/test_pipeline.py | 2 + tests/processor/test_smolvla_processor.py | 28 +- tests/processor/test_tokenizer_processor.py | 118 +++---- tests/rl/test_actor.py | 15 +- tests/rl/test_actor_learner.py | 11 +- tests/rl/test_learner_service.py | 18 +- tests/rl/test_queue.py | 8 +- tests/scripts/test_edit_dataset_parsing.py | 2 + tests/test_available.py | 72 +++-- tests/test_cli_peft.py | 10 +- tests/test_control_robot.py | 5 + tests/training/test_multi_gpu.py | 2 + tests/training/test_visual_validation.py | 2 + tests/transport/test_transport_utils.py | 64 ++-- tests/utils.py | 16 +- tests/utils/test_process.py | 4 +- tests/utils/test_replay_buffer.py | 14 +- tests/utils/test_train_utils.py | 24 +- tests/utils/test_visualization_utils.py | 5 + uv.lock | 303 ++++++++++++++---- 343 files changed, 3248 insertions(+), 1930 deletions(-) create mode 100644 src/lerobot/async_inference/__init__.py create mode 100644 src/lerobot/common/__init__.py rename src/lerobot/{utils => common}/control_utils.py (95%) rename src/lerobot/{utils => common}/train_utils.py (95%) rename src/lerobot/{rl => common}/wandb_utils.py (100%) create mode 100644 src/lerobot/configs/__init__.py create mode 100644 src/lerobot/model/__init__.py create mode 100644 src/lerobot/policies/act/__init__.py create mode 100644 src/lerobot/policies/diffusion/__init__.py create mode 100644 src/lerobot/policies/sac/__init__.py create mode 100644 src/lerobot/policies/sac/reward_model/__init__.py create mode 100644 src/lerobot/policies/sarm/__init__.py create mode 100644 src/lerobot/policies/smolvla/__init__.py create mode 100644 src/lerobot/policies/tdmpc/__init__.py create mode 100644 src/lerobot/policies/vqbet/__init__.py create mode 100644 src/lerobot/processor/newline_task_processor.py create mode 100644 src/lerobot/rl/__init__.py create mode 100644 src/lerobot/transforms/__init__.py rename src/lerobot/{datasets => transforms}/transforms.py (100%) create mode 100644 src/lerobot/transport/__init__.py create mode 100644 src/lerobot/utils/__init__.py create mode 100644 src/lerobot/utils/feature_utils.py diff --git a/.github/workflows/fast_tests.yml b/.github/workflows/fast_tests.yml index d78bdd21b..b6680db73 100644 --- a/.github/workflows/fast_tests.yml +++ b/.github/workflows/fast_tests.yml @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -# This workflow handles fast testing. +# This workflow validates each optional-dependency tier in isolation. +# Each tier installs a different extra and runs the full test suite. +# Tests that require an extra not installed in the current tier are +# skipped automatically via pytest.importorskip guards. name: Fast Tests on: @@ -54,8 +57,9 @@ concurrency: cancel-in-progress: true jobs: - # This job runs pytests with the default dependencies. - # It runs everytime we commit to a PR or push to main + # This job runs pytests in isolated dependency tiers. + # Each tier installs a different extra and runs the full suite; + # tests gated behind other extras skip automatically. fast-pytest-tests: name: Fast Pytest Tests runs-on: ubuntu-latest @@ -89,8 +93,9 @@ jobs: version: ${{ env.UV_VERSION }} python-version: ${{ env.PYTHON_VERSION }} - - name: Install lerobot with test extras - run: uv sync --locked --extra "test" + # ── Tier 1: Base ────────────────────────────────────── + - name: "Tier 1 — Install: base" + run: uv sync --locked --extra test - name: Login to Hugging Face if: env.HF_USER_TOKEN != '' @@ -98,5 +103,26 @@ jobs: uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential uv run hf auth whoami - - name: Run pytest + - name: "Tier 1 — Test: base" + run: uv run pytest tests -vv --maxfail=10 + + # ── Tier 2: Dataset ────────────────────────────────── + - name: "Tier 2 — Install: dataset" + run: uv sync --locked --extra test --extra dataset + + - name: "Tier 2 — Test: dataset" + run: uv run pytest tests -vv --maxfail=10 + + # ── Tier 3: Hardware ───────────────────────────────── + - name: "Tier 3 — Install: hardware" + run: uv sync --locked --extra test --extra hardware + + - name: "Tier 3 — Test: hardware" + run: uv run pytest tests -vv --maxfail=10 + + # ── Tier 4: Viz ────────────────────────────────────── + - name: "Tier 4 — Install: viz" + run: uv sync --locked --extra test --extra viz + + - name: "Tier 4 — Test: viz" run: uv run pytest tests -vv --maxfail=10 diff --git a/docs/source/adding_benchmarks.mdx b/docs/source/adding_benchmarks.mdx index 3a024f026..6e9d23bdf 100644 --- a/docs/source/adding_benchmarks.mdx +++ b/docs/source/adding_benchmarks.mdx @@ -216,7 +216,7 @@ class MyBenchmarkEnvConfig(EnvConfig): def get_env_processors(self): """Override if your benchmark needs observation/action transforms.""" - from lerobot.processor.pipeline import PolicyProcessorPipeline + from lerobot.processor import PolicyProcessorPipeline from lerobot.processor.env_processor import MyBenchmarkProcessorStep return ( PolicyProcessorPipeline(steps=[MyBenchmarkProcessorStep()]), diff --git a/docs/source/async.mdx b/docs/source/async.mdx index a46408a0d..7b1efae97 100644 --- a/docs/source/async.mdx +++ b/docs/source/async.mdx @@ -170,7 +170,7 @@ python -m lerobot.async_inference.robot_client \ ```python import threading from lerobot.robots.so_follower import SO100FollowerConfig -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.cameras.opencv import OpenCVCameraConfig from lerobot.async_inference.configs import RobotClientConfig from lerobot.async_inference.robot_client import RobotClient from lerobot.async_inference.helpers import visualize_action_queue_size diff --git a/docs/source/backwardcomp.mdx b/docs/source/backwardcomp.mdx index 3366c8ab9..a83ee2e2e 100644 --- a/docs/source/backwardcomp.mdx +++ b/docs/source/backwardcomp.mdx @@ -41,7 +41,7 @@ The script: ```python # New usage pattern (after migration) -from lerobot.policies.factory import make_policy, make_pre_post_processors +from lerobot.policies import make_policy, make_pre_post_processors # Load model and processors separately policy = make_policy(config, ds_meta=dataset.meta) diff --git a/docs/source/bring_your_own_policies.mdx b/docs/source/bring_your_own_policies.mdx index 38c32aa71..57ecc2fb2 100644 --- a/docs/source/bring_your_own_policies.mdx +++ b/docs/source/bring_your_own_policies.mdx @@ -47,9 +47,9 @@ Here is a template to get you started, customize the parameters and methods as n ```python # configuration_my_custom_policy.py from dataclasses import dataclass, field -from lerobot.configs.policies import PreTrainedConfig -from lerobot.optim.optimizers import AdamWConfig -from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig +from lerobot.configs import PreTrainedConfig +from lerobot.optim import AdamWConfig +from lerobot.optim import CosineDecayWithWarmupSchedulerConfig @PreTrainedConfig.register_subclass("my_custom_policy") @dataclass @@ -120,7 +120,7 @@ import torch import torch.nn as nn from typing import Any -from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies import PreTrainedPolicy from lerobot.utils.constants import ACTION from .configuration_my_custom_policy import MyCustomPolicyConfig diff --git a/docs/source/cameras.mdx b/docs/source/cameras.mdx index 8af0f5ae5..2dc2859dd 100644 --- a/docs/source/cameras.mdx +++ b/docs/source/cameras.mdx @@ -79,9 +79,8 @@ The following examples show how to use the camera API to configure and capture f ```python -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.cameras.opencv.camera_opencv import OpenCVCamera -from lerobot.cameras.configs import ColorMode, Cv2Rotation +from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig +from lerobot.cameras import ColorMode, Cv2Rotation # Construct an `OpenCVCameraConfig` with your desired FPS, resolution, color mode, and rotation. config = OpenCVCameraConfig( @@ -126,9 +125,8 @@ with OpenCVCamera(config) as camera: ```python -from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig -from lerobot.cameras.realsense.camera_realsense import RealSenseCamera -from lerobot.cameras.configs import ColorMode, Cv2Rotation +from lerobot.cameras.realsense import RealSenseCamera, RealSenseCameraConfig +from lerobot.cameras import ColorMode, Cv2Rotation # Create a `RealSenseCameraConfig` specifying your camera’s serial number and enabling depth. config = RealSenseCameraConfig( diff --git a/docs/source/dataset_subtask.mdx b/docs/source/dataset_subtask.mdx index beb5d80bd..6264aca22 100644 --- a/docs/source/dataset_subtask.mdx +++ b/docs/source/dataset_subtask.mdx @@ -95,7 +95,7 @@ After completing your annotation: When you load a dataset with subtask annotations, the subtask information is automatically available: ```python -from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets import LeRobotDataset # Load a dataset with subtask annotations dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated") @@ -133,11 +133,10 @@ if has_subtasks: The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models: ```python -from lerobot.processor.tokenizer_processor import TokenizerProcessor -from lerobot.processor.pipeline import ProcessorPipeline +from lerobot.processor import TokenizerProcessorStep -# Create a tokenizer processor -tokenizer_processor = TokenizerProcessor( +# Create a tokenizer processor step +tokenizer_processor = TokenizerProcessorStep( tokenizer_name_or_path="google/paligemma-3b-pt-224", padding="max_length", max_length=64, @@ -158,7 +157,7 @@ When subtasks are available in the batch, the tokenizer processor adds: ```python import torch -from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets import LeRobotDataset dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated") @@ -182,7 +181,7 @@ for batch in dataloader: Try loading a dataset with subtask annotations: ```python -from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets import LeRobotDataset # Example dataset with subtask annotations dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated") diff --git a/docs/source/earthrover_mini_plus.mdx b/docs/source/earthrover_mini_plus.mdx index 884e84d8c..a87bd325b 100644 --- a/docs/source/earthrover_mini_plus.mdx +++ b/docs/source/earthrover_mini_plus.mdx @@ -66,10 +66,10 @@ The SDK gives you: Follow our [Installation Guide](./installation) to install LeRobot. -In addition to the base installation, install the EarthRover Mini dependencies: +In addition to the base installation, install the EarthRover Mini with hardware dependencies: ```bash -pip install -e . +pip install -e ".[hardware]" ``` ## How It Works diff --git a/docs/source/env_processor.mdx b/docs/source/env_processor.mdx index 290af3b34..8bfafdfb9 100644 --- a/docs/source/env_processor.mdx +++ b/docs/source/env_processor.mdx @@ -173,8 +173,8 @@ observation = { The `make_env_pre_post_processors` function follows the same pattern as `make_pre_post_processors` for policies: ```python -from lerobot.envs.factory import make_env_pre_post_processors -from lerobot.envs.configs import LiberoEnv, PushtEnv +from lerobot.envs import make_env_pre_post_processors, PushtEnv +from lerobot.envs.configs import LiberoEnv # For LIBERO: Returns LiberoProcessorStep in preprocessor libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"]) @@ -257,7 +257,7 @@ def eval_main(cfg: EvalPipelineConfig): The `LiberoProcessorStep` demonstrates a real-world environment processor: ```python -from lerobot.processor.pipeline import ObservationProcessorStep +from lerobot.processor import ObservationProcessorStep @dataclass @ProcessorStepRegistry.register(name="libero_processor") diff --git a/docs/source/envhub.mdx b/docs/source/envhub.mdx index 36c08a8b3..47f5567a8 100644 --- a/docs/source/envhub.mdx +++ b/docs/source/envhub.mdx @@ -34,7 +34,7 @@ Finally, your environment must implement the standard `gym.vector.VectorEnv` int Loading an environment from the Hub is as simple as: ```python -from lerobot.envs.factory import make_env +from lerobot.envs import make_env # Load a hub environment (requires explicit consent to run remote code) env = make_env("lerobot/cartpole-env", trust_remote_code=True) @@ -191,7 +191,7 @@ api.upload_folder( ### Basic Usage ```python -from lerobot.envs.factory import make_env +from lerobot.envs import make_env # Load from the hub envs_dict = make_env( @@ -314,7 +314,7 @@ env = make_env("trusted-org/verified-env@a1b2c3d4", trust_remote_code=True) Here's a complete example using the reference CartPole environment: ```python -from lerobot.envs.factory import make_env +from lerobot.envs import make_env import numpy as np # Load the environment diff --git a/docs/source/envhub_isaaclab_arena.mdx b/docs/source/envhub_isaaclab_arena.mdx index 828d51bad..b934240d6 100644 --- a/docs/source/envhub_isaaclab_arena.mdx +++ b/docs/source/envhub_isaaclab_arena.mdx @@ -58,10 +58,10 @@ pip install -e . cd .. -# 5. Install LeRobot +# 5. Install LeRobot (evaluation extra for env/policy evaluation) git clone https://github.com/huggingface/lerobot.git cd lerobot -pip install -e . +pip install -e ".[evaluation]" cd .. @@ -262,7 +262,7 @@ def main(cfg: EvalPipelineConfig): """Run random action rollout for IsaacLab Arena environment.""" logging.info(pformat(asdict(cfg))) - from lerobot.envs.factory import make_env + from lerobot.envs import make_env env_dict = make_env( cfg.env, diff --git a/docs/source/envhub_leisaac.mdx b/docs/source/envhub_leisaac.mdx index 2537700a5..91bb6a871 100644 --- a/docs/source/envhub_leisaac.mdx +++ b/docs/source/envhub_leisaac.mdx @@ -74,7 +74,7 @@ EnvHub exposes every LeIsaac-supported task in a uniform interface. The examples # envhub_random_action.py import torch -from lerobot.envs.factory import make_env +from lerobot.envs import make_env # Load from the hub envs_dict = make_env("LightwheelAI/leisaac_env:envs/so101_pick_orange.py", n_envs=1, trust_remote_code=True) @@ -142,7 +142,7 @@ from lerobot.teleoperators import ( # noqa: F401 ) from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import init_logging -from lerobot.envs.factory import make_env +from lerobot.envs import make_env @dataclass @@ -282,7 +282,7 @@ Note: when working with `bi_so101_fold_cloth`, call `initialize()` immediately a ```python import torch -from lerobot.envs.factory import make_env +from lerobot.envs import make_env # Load from the hub envs_dict = make_env("LightwheelAI/leisaac_env:envs/bi_so101_fold_cloth.py", n_envs=1, trust_remote_code=True) diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 8e50a2aec..d03e35d8d 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -58,8 +58,8 @@ lerobot-teleoperate \ ```python -from lerobot.teleoperators.so_leader import SO101LeaderConfig, SO101Leader -from lerobot.robots.so_follower import SO101FollowerConfig, SO101Follower +from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig +from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig robot_config = SO101FollowerConfig( port="/dev/tty.usbmodem58760431541", @@ -116,9 +116,9 @@ lerobot-teleoperate \ ```python -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.teleoperators.koch_leader import KochLeaderConfig, KochLeader -from lerobot.robots.koch_follower import KochFollowerConfig, KochFollower +from lerobot.cameras.opencv import OpenCVCameraConfig +from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig +from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig camera_config = { "front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30) @@ -195,13 +195,12 @@ lerobot-record \ ```python -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.utils import hw_to_dataset_features +from lerobot.cameras.opencv import OpenCVCameraConfig +from lerobot.datasets import LeRobotDataset +from lerobot.utils.feature_utils import hw_to_dataset_features from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig -from lerobot.teleoperators.so_leader.config_so100_leader import SO100LeaderConfig -from lerobot.teleoperators.so_leader.so100_leader import SO100Leader -from lerobot.utils.control_utils import init_keyboard_listener +from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig +from lerobot.common.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun from lerobot.scripts.lerobot_record import record_loop @@ -410,9 +409,8 @@ lerobot-replay \ ```python import time -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.robots.so_follower.config_so100_follower import SO100FollowerConfig -from lerobot.robots.so_follower.so100_follower import SO100Follower +from lerobot.datasets import LeRobotDataset +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say @@ -532,15 +530,14 @@ lerobot-record \ ```python -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.utils import hw_to_dataset_features -from lerobot.policies.act.modeling_act import ACTPolicy -from lerobot.policies.factory import make_pre_post_processors -from lerobot.robots.so_follower.config_so100_follower import SO100FollowerConfig -from lerobot.robots.so_follower.so100_follower import SO100Follower +from lerobot.cameras.opencv import OpenCVCameraConfig +from lerobot.datasets import LeRobotDataset +from lerobot.utils.feature_utils import hw_to_dataset_features +from lerobot.policies.act import ACTPolicy +from lerobot.policies import make_pre_post_processors +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.scripts.lerobot_record import record_loop -from lerobot.utils.control_utils import init_keyboard_listener +from lerobot.common.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index a988523b5..1d772fc97 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -116,6 +116,8 @@ brew install ffmpeg ## Step 3: Install LeRobot 🤗 +The base `lerobot` install is intentionally **lightweight** — it includes only core ML dependencies (PyTorch, torchvision, numpy, opencv, einops, draccus, huggingface-hub, gymnasium, safetensors). Heavier dependencies are gated behind optional extras so you only install what you need. + ### From Source First, clone the repository and navigate into the directory: @@ -131,12 +133,16 @@ Then, install the library in editable mode. This is useful if you plan to contri ```bash -pip install -e . +pip install -e ".[core_scripts]" # For robot workflows (recording, replaying, calibrate) +pip install -e ".[training]" # For training policies +pip install -e ".[all]" # Everything (all policies, envs, hardware, dev tools) ``` ```bash -uv pip install -e . +uv pip install -e ".[core_scripts]" # For robot workflows (recording, replaying, calibrate) +uv pip install -e ".[training]" # For training policies +uv pip install -e ".[all]" # Everything (all policies, envs, hardware, dev tools) ``` @@ -162,26 +168,48 @@ uv pip install lerobot -_This installs only the default dependencies._ +_This installs only the core ML dependencies. You will need to add extras for most workflows._ -**Extra Features:** -To install additional functionality, use one of the following (If you are using `uv`, replace `pip install` with `uv pip install` in the commands below.): +**Feature Extras:** +LeRobot provides **feature-scoped extras** that map to common workflows. If you are using `uv`, replace `pip install` with `uv pip install` in the commands below. + +| Extra | What it adds | Typical use case | +| ---------- | ------------------------------------------- | ----------------------------------- | +| `dataset` | `datasets`, `av`, `torchcodec`, `jsonlines` | Loading & creating datasets | +| `training` | `dataset` + `accelerate`, `wandb` | Training policies | +| `hardware` | `pynput`, `pyserial`, `deepdiff` | Connecting to real robots | +| `viz` | `rerun-sdk` | Visualization during recording/eval | + +**Composite Extras** combine feature extras for common CLI scripts: + +| Extra | Includes | Typical use case | +| -------------- | ------------------------------ | ------------------------------------------------------- | +| `core_scripts` | `dataset` + `hardware` + `viz` | `lerobot-record`, `lerobot-replay`, `lerobot-calibrate` | +| `evaluation` | `av` | `lerobot-eval` (add policy + env extras as needed) | +| `dataset_viz` | `dataset` + `viz` | `lerobot-dataset-viz`, `lerobot-imgtransform-viz` | ```bash -pip install 'lerobot[all]' # All available features -pip install 'lerobot[aloha,pusht]' # Specific features (Aloha & Pusht) -pip install 'lerobot[feetech]' # Feetech motor support +pip install 'lerobot[core_scripts]' # Record, replay, calibrate +pip install 'lerobot[training]' # Train policies +pip install 'lerobot[core_scripts,training]' # Record + train +pip install 'lerobot[all]' # Everything ``` -_Replace `[...]` with your desired features._ +**Policy, environment, and hardware extras** are still available for specific dependencies: -**Available Tags:** -For a full list of optional dependencies, see: -https://pypi.org/project/lerobot/ +```bash +pip install 'lerobot[pi]' # Pi0/Pi0.5/Pi0-FAST policy deps +pip install 'lerobot[smolvla]' # SmolVLA policy deps +pip install 'lerobot[diffusion]' # Diffusion policy deps (diffusers) +pip install 'lerobot[aloha,pusht]' # Simulation environments +pip install 'lerobot[feetech]' # Feetech motor support +``` + +_Multiple extras can be combined (e.g., `.[core_scripts,pi,pusht]`). For a full list of available extras, refer to `pyproject.toml`._ ### Troubleshooting -If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`. +If you encounter build errors, you may need to install additional system dependencies: `cmake`, `build-essential`, and `ffmpeg libs`. To install these for Linux run: ```bash @@ -196,8 +224,8 @@ LeRobot provides optional extras for specific functionalities. Multiple extras c ### Simulations -Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht)) -Example: +Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht)). +These automatically include the `dataset` extra. ```bash pip install -e ".[aloha]" # or "[pusht]" for example @@ -213,7 +241,7 @@ pip install -e ".[feetech]" # or "[dynamixel]" for example ### Experiment Tracking -To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with +Weights and Biases is included in the `training` extra. To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with: ```bash wandb login diff --git a/docs/source/introduction_processors.mdx b/docs/source/introduction_processors.mdx index 6f3768615..4395e889b 100644 --- a/docs/source/introduction_processors.mdx +++ b/docs/source/introduction_processors.mdx @@ -19,10 +19,10 @@ This means that your favorite policy can be used like this: ```python import torch -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.policies.factory import make_pre_post_processors +from lerobot.datasets import LeRobotDataset +from lerobot.policies import make_pre_post_processors from lerobot.policies.your_policy import YourPolicy -from lerobot.processor.pipeline import RobotProcessorPipeline, PolicyProcessorPipeline +from lerobot.processor import RobotProcessorPipeline, PolicyProcessorPipeline dataset = LeRobotDataset("hf_user/dataset", episodes=[0]) sample = dataset[10] @@ -260,7 +260,7 @@ Since processor pipelines can add new features (like velocity fields), change te These functions work together by starting with robot hardware specifications (`create_initial_features()`) then simulating the entire pipeline transformation (`aggregate_pipeline_dataset_features()`) to compute the final feature dictionary that gets passed to `LeRobotDataset.create()`, ensuring perfect alignment between what processors output and what datasets expect to store. ```python -from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features +from lerobot.datasets import aggregate_pipeline_dataset_features # Start with robot's raw features initial_features = create_initial_features( diff --git a/docs/source/lerobot-dataset-v3.mdx b/docs/source/lerobot-dataset-v3.mdx index 235a355bd..8ab4a5d40 100644 --- a/docs/source/lerobot-dataset-v3.mdx +++ b/docs/source/lerobot-dataset-v3.mdx @@ -89,7 +89,7 @@ A core v3 principle is **decoupling storage from the user API**: data is stored ```python import torch -from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets import LeRobotDataset repo_id = "yaak-ai/L2D-v3" @@ -135,7 +135,7 @@ for batch in data_loader: Use `StreamingLeRobotDataset` to iterate directly from the Hub without local copies. This allows to stream large datasets without the need to downloading them onto disk or loading them onto memory, and is a key feature of the new dataset format. ```python -from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset +from lerobot.datasets import StreamingLeRobotDataset repo_id = "yaak-ai/L2D-v3" dataset = StreamingLeRobotDataset(repo_id) # streams directly from the Hub @@ -167,8 +167,8 @@ Currently, transforms are applied during **training time only**, not during reco Use the `image_transforms` parameter when loading a dataset for training: ```python -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig, ImageTransformConfig +from lerobot.datasets import LeRobotDataset +from lerobot.transforms import ImageTransforms, ImageTransformsConfig, ImageTransformConfig # Option 1: Use default transform configuration (disabled by default) transforms_config = ImageTransformsConfig( @@ -290,7 +290,7 @@ python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id= list[float]: diff --git a/examples/tutorial/act/act_using_example.py b/examples/tutorial/act/act_using_example.py index 15254d8eb..6a8f73287 100644 --- a/examples/tutorial/act/act_using_example.py +++ b/examples/tutorial/act/act_using_example.py @@ -1,9 +1,9 @@ import torch -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata -from lerobot.policies.act.modeling_act import ACTPolicy -from lerobot.policies.factory import make_pre_post_processors +from lerobot.cameras.opencv import OpenCVCameraConfig +from lerobot.datasets import LeRobotDatasetMetadata +from lerobot.policies import make_pre_post_processors +from lerobot.policies.act import ACTPolicy from lerobot.policies.utils import build_inference_frame, make_robot_action from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig diff --git a/examples/tutorial/async-inf/robot_client.py b/examples/tutorial/async-inf/robot_client.py index db6ead3fe..ac2331f38 100644 --- a/examples/tutorial/async-inf/robot_client.py +++ b/examples/tutorial/async-inf/robot_client.py @@ -3,7 +3,7 @@ import threading from lerobot.async_inference.configs import RobotClientConfig from lerobot.async_inference.helpers import visualize_action_queue_size from lerobot.async_inference.robot_client import RobotClient -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.cameras.opencv import OpenCVCameraConfig from lerobot.robots.so_follower import SO100FollowerConfig diff --git a/examples/tutorial/diffusion/diffusion_training_example.py b/examples/tutorial/diffusion/diffusion_training_example.py index dc6ca68a3..5cca15923 100644 --- a/examples/tutorial/diffusion/diffusion_training_example.py +++ b/examples/tutorial/diffusion/diffusion_training_example.py @@ -4,13 +4,11 @@ from pathlib import Path import torch -from lerobot.configs.types import FeatureType -from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata -from lerobot.datasets.feature_utils import dataset_to_policy_features -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig -from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy -from lerobot.policies.factory import make_pre_post_processors +from lerobot.configs import FeatureType +from lerobot.datasets import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.policies import make_pre_post_processors +from lerobot.policies.diffusion import DiffusionConfig, DiffusionPolicy +from lerobot.utils.feature_utils import dataset_to_policy_features def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]: diff --git a/examples/tutorial/diffusion/diffusion_using_example.py b/examples/tutorial/diffusion/diffusion_using_example.py index 9b31cf359..8f9150ad6 100644 --- a/examples/tutorial/diffusion/diffusion_using_example.py +++ b/examples/tutorial/diffusion/diffusion_using_example.py @@ -1,9 +1,9 @@ import torch -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata -from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy -from lerobot.policies.factory import make_pre_post_processors +from lerobot.cameras.opencv import OpenCVCameraConfig +from lerobot.datasets import LeRobotDatasetMetadata +from lerobot.policies import make_pre_post_processors +from lerobot.policies.diffusion import DiffusionPolicy from lerobot.policies.utils import build_inference_frame, make_robot_action from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig diff --git a/examples/tutorial/pi0/using_pi0_example.py b/examples/tutorial/pi0/using_pi0_example.py index d8cf9dbff..66f6309c2 100644 --- a/examples/tutorial/pi0/using_pi0_example.py +++ b/examples/tutorial/pi0/using_pi0_example.py @@ -1,11 +1,11 @@ import torch -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.datasets.feature_utils import hw_to_dataset_features -from lerobot.policies.factory import make_pre_post_processors -from lerobot.policies.pi0.modeling_pi0 import PI0Policy +from lerobot.cameras.opencv import OpenCVCameraConfig +from lerobot.policies import make_pre_post_processors +from lerobot.policies.pi0 import PI0Policy from lerobot.policies.utils import build_inference_frame, make_robot_action from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig +from lerobot.utils.feature_utils import hw_to_dataset_features MAX_EPISODES = 5 MAX_STEPS_PER_EPISODE = 20 diff --git a/examples/tutorial/rl/hilserl_example.py b/examples/tutorial/rl/hilserl_example.py index d367a01ce..8a08d6d56 100644 --- a/examples/tutorial/rl/hilserl_example.py +++ b/examples/tutorial/rl/hilserl_example.py @@ -6,17 +6,17 @@ from queue import Empty, Full import torch import torch.optim as optim -from lerobot.datasets.feature_utils import hw_to_dataset_features -from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets import LeRobotDataset from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig -from lerobot.policies.sac.configuration_sac import SACConfig +from lerobot.policies import SACConfig from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.policies.sac.reward_model.modeling_classifier import Classifier from lerobot.rl.buffer import ReplayBuffer from lerobot.rl.gym_manipulator import make_robot_env from lerobot.robots.so_follower import SO100FollowerConfig +from lerobot.teleoperators import TeleopEvents from lerobot.teleoperators.so_leader import SO100LeaderConfig -from lerobot.teleoperators.utils import TeleopEvents +from lerobot.utils.feature_utils import hw_to_dataset_features LOG_EVERY = 10 SEND_EVERY = 10 diff --git a/examples/tutorial/rl/reward_classifier_example.py b/examples/tutorial/rl/reward_classifier_example.py index 4af6b899c..b386bf4db 100644 --- a/examples/tutorial/rl/reward_classifier_example.py +++ b/examples/tutorial/rl/reward_classifier_example.py @@ -1,8 +1,7 @@ import torch -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.policies.factory import make_policy, make_pre_post_processors -from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig +from lerobot.datasets import LeRobotDataset +from lerobot.policies import RewardClassifierConfig, make_policy, make_pre_post_processors def main(): diff --git a/examples/tutorial/smolvla/using_smolvla_example.py b/examples/tutorial/smolvla/using_smolvla_example.py index b99126efa..f59603db7 100644 --- a/examples/tutorial/smolvla/using_smolvla_example.py +++ b/examples/tutorial/smolvla/using_smolvla_example.py @@ -1,11 +1,11 @@ import torch -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.datasets.feature_utils import hw_to_dataset_features -from lerobot.policies.factory import make_pre_post_processors -from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy +from lerobot.cameras.opencv import OpenCVCameraConfig +from lerobot.policies import make_pre_post_processors +from lerobot.policies.smolvla import SmolVLAPolicy from lerobot.policies.utils import build_inference_frame, make_robot_action from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig +from lerobot.utils.feature_utils import hw_to_dataset_features MAX_EPISODES = 5 MAX_STEPS_PER_EPISODE = 20 diff --git a/pyproject.toml b/pyproject.toml index 79409a200..2f12840e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,45 +58,74 @@ classifiers = [ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artificial intelligence"] dependencies = [ - - # Hugging Face dependencies - "datasets>=4.0.0,<5.0.0", - "diffusers>=0.27.2,<0.36.0", - "huggingface-hub>=1.0.0,<2.0.0", - "accelerate>=1.10.0,<2.0.0", - - # Core dependencies - "numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless. - "setuptools>=71.0.0,<81.0.0", - "cmake>=3.29.0.1,<4.2.0", - "packaging>=24.2,<26.0", - + # Core ML "torch>=2.7,<2.11.0", - "torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10). "torchvision>=0.22.0,<0.26.0", - - "einops>=0.8.0,<0.9.0", + "numpy>=2.0.0,<2.3.0", # NOTE: Explicitly listing numpy helps the resolver converge faster. Upper bound imposed by opencv-python-headless. "opencv-python-headless>=4.9.0,<4.14.0", - "av>=15.0.0,<16.0.0", - "jsonlines>=4.0.0,<5.0.0", - "pynput>=1.7.8,<1.9.0", - "pyserial>=3.5,<4.0", + "Pillow>=10.0.0,<13.0.0", + "einops>=0.8.0,<0.9.0", - "wandb>=0.24.0,<0.25.0", + # Config & Hub "draccus==0.10.0", # TODO: Relax version constraint - "gymnasium>=1.1.1,<2.0.0", - "rerun-sdk>=0.24.0,<0.27.0", + "huggingface-hub>=1.0.0,<2.0.0", + "requests>=2.32.0,<3.0.0", - # Support dependencies - "deepdiff>=7.0.1,<9.0.0", - "imageio[ffmpeg]>=2.34.0,<3.0.0", + # Environments + # NOTE: gymnasium is used in lerobot.envs (lerobot-train, lerobot-eval), policies/factory, + # and robots/unitree. Moving it to an optional extra would require import guards across many + # tightly-coupled modules. Candidate for a future refactor to decouple envs from the core. + "gymnasium>=1.1.1,<2.0.0", + + # Serialization & checkpointing + "safetensors>=0.4.3,<1.0.0", + + # Lightweight utilities + "packaging>=24.2,<26.0", "termcolor>=2.4.0,<4.0.0", + "tqdm>=4.66.0,<5.0.0", + + # Build tools (required by opencv-python-headless on some platforms) + "cmake>=3.29.0.1,<4.2.0", + "setuptools>=71.0.0,<81.0.0", ] # Optional dependencies [project.optional-dependencies] +# ── Feature-scoped extras ────────────────────────────────── +dataset = [ + "datasets>=4.0.0,<5.0.0", + "pandas>=2.0.0,<3.0.0", # NOTE: Transitive dependency of datasets + "pyarrow>=21.0.0,<30.0.0", # NOTE: Transitive dependency of datasets + "lerobot[av-dep]", + "torchcodec>=0.3.0,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # NOTE: Windows support starts at version 0.7 (needs torch==2.8), ffmpeg>=8 support starts at version 0.8.1 (needs torch==2.9), system-wide ffmpeg support starts at version 0.10 (needs torch==2.10). + "jsonlines>=4.0.0,<5.0.0", +] +training = [ + "lerobot[dataset]", + "accelerate>=1.10.0,<2.0.0", + "wandb>=0.24.0,<0.25.0", +] +hardware = [ + "pynput>=1.7.8,<1.9.0", + "pyserial>=3.5,<4.0", + "deepdiff>=7.0.1,<9.0.0", +] +viz = [ + "rerun-sdk>=0.24.0,<0.27.0", +] +# ── User-facing composite extras (map to CLI scripts) ───── +# lerobot-record, lerobot-replay, lerobot-calibrate, lerobot-teleoperate, etc. +core_scripts = ["lerobot[dataset]", "lerobot[hardware]", "lerobot[viz]"] +# lerobot-eval -- base evaluation framework. You also need the policy's extra (e.g., lerobot[pi]) +# and the environment's extra (e.g., lerobot[pusht]) if evaluating in simulation. +evaluation = ["lerobot[av-dep]"] +# lerobot-dataset-viz, lerobot-imgtransform-viz +dataset_viz = ["lerobot[dataset]", "lerobot[viz]"] + # Common +av-dep = ["av>=15.0.0,<16.0.0"] pygame-dep = ["pygame>=2.5.1,<2.7.0"] placo-dep = ["placo>=0.9.6,<0.9.17"] transformers-dep = ["transformers==5.3.0"] # TODO(Steven): https://github.com/huggingface/lerobot/pull/3249 @@ -104,6 +133,7 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] can-dep = ["python-can>=4.2.0,<5.0.0"] peft-dep = ["peft>=0.18.0,<1.0.0"] scipy-dep = ["scipy>=1.14.0,<2.0.0"] +diffusers-dep = ["diffusers>=0.27.2,<0.36.0"] qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"] matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0", "contourpy>=1.3.0,<2.0.0"] # NOTE: Explicitly listing contourpy helps the resolver converge faster. @@ -136,28 +166,28 @@ intelrealsense = [ phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0", "lerobot[scipy-dep]"] # Policies +diffusion = ["lerobot[diffusers-dep]"] wallx = [ "lerobot[transformers-dep]", - "lerobot[peft]", + "lerobot[peft-dep]", "lerobot[scipy-dep]", "torchdiffeq>=0.2.4,<0.3.0", "lerobot[qwen-vl-utils-dep]", ] pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"] -smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"] -multi_task_dit = ["lerobot[transformers-dep]"] +smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"] +multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"] groot = [ "lerobot[transformers-dep]", - "lerobot[peft]", + "lerobot[peft-dep]", + "lerobot[diffusers-dep]", "dm-tree>=0.1.8,<1.0.0", "timm>=1.0.0,<1.1.0", - "safetensors>=0.4.3,<1.0.0", - "Pillow>=10.0.0,<13.0.0", "decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "ninja>=1.11.1,<2.0.0", "flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'" ] -sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"] +sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"] xvla = ["lerobot[transformers-dep]"] hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] @@ -166,31 +196,42 @@ async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"] peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"] # Development -dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"] +dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1"] test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"] video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"] # Simulation # NOTE: Explicitly listing scipy helps flatten the dependecy tree. -aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"] -pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead -libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"] -metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"] +aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"] +pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead +libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"] +metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"] # All all = [ + # Feature-scoped extras + "lerobot[dataset]", + "lerobot[training]", + "lerobot[hardware]", + "lerobot[viz]", # NOTE(resolver hint): scipy is pulled in transitively via lerobot[scipy-dep] through # multiple extras (aloha, metaworld, pi, wallx, phone). Listing it explicitly # helps pip's resolver converge by constraining scipy early, before it encounters # the loose scipy requirements from transitive deps like dm-control and metaworld. "scipy>=1.14.0,<2.0.0", "lerobot[dynamixel]", + "lerobot[feetech]", + "lerobot[damiao]", + "lerobot[robstride]", "lerobot[gamepad]", "lerobot[hopejr]", "lerobot[lekiwi]", + "lerobot[openarms]", "lerobot[reachy2]", "lerobot[kinematics]", "lerobot[intelrealsense]", + "lerobot[diffusion]", + "lerobot[multi_task_dit]", "lerobot[wallx]", "lerobot[pi]", "lerobot[smolvla]", @@ -267,7 +308,9 @@ ignore = [ ] [tool.ruff.lint.per-file-ignores] -"__init__.py" = ["F401", "F403"] +"__init__.py" = ["F401", "F403", "E402"] +# E402: conditional-import guards (TYPE_CHECKING / is_package_available) must precede the imports they protect +"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"] "src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original [tool.ruff.lint.isort] diff --git a/src/lerobot/__init__.py b/src/lerobot/__init__.py index eec574296..df43e7172 100644 --- a/src/lerobot/__init__.py +++ b/src/lerobot/__init__.py @@ -13,188 +13,39 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + """ -This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library. -We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables. +LeRobot -- PyTorch library for real-world robotics. -Example: - ```python - import lerobot - print(lerobot.available_envs) - print(lerobot.available_tasks_per_env) - print(lerobot.available_datasets) - print(lerobot.available_datasets_per_env) - print(lerobot.available_real_world_datasets) - print(lerobot.available_policies) - print(lerobot.available_policies_per_env) - print(lerobot.available_robots) - print(lerobot.available_cameras) - print(lerobot.available_motors) - ``` +Provides datasets, pretrained policies, and tools for training, evaluation, +data collection, and robot control. Integrates with Hugging Face Hub for +model and dataset sharing. -When implementing a new dataset loadable with LeRobotDataset follow these steps: -- Update `available_datasets_per_env` in `lerobot/__init__.py` +The base install is intentionally lightweight. Feature-specific dependencies +are gated behind optional extras:: -When implementing a new environment (e.g. `gym_aloha`), follow these steps: -- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py` - -When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps: -- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py` -- Set the required `name` class attribute. -- Update variables in `tests/test_available.py` by importing your new Policy class + pip install 'lerobot[dataset]' # dataset loading & creation + pip install 'lerobot[training]' # training loop + wandb + pip install 'lerobot[hardware]' # real robot control + pip install 'lerobot[core_scripts]' # dataset + hardware + viz (record, replay, calibrate, etc.) + pip install 'lerobot[all]' # everything """ -import itertools +from lerobot.__version__ import __version__ -from lerobot.__version__ import __version__ # noqa: F401 - -# TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies` -# refers to a yaml file AND a modeling name. Same for `available_envs` which refers to -# a yaml file AND a environment name. The difference should be more obvious. -available_tasks_per_env = { - "aloha": [ - "AlohaInsertion-v0", - "AlohaTransferCube-v0", +# Maps optional extras to the CLI entry-points they unlock. +available_extras: dict[str, list[str]] = { + "dataset": ["lerobot-dataset-viz", "lerobot-imgtransform-viz", "lerobot-edit-dataset"], + "training": ["lerobot-train"], + "hardware": [ + "lerobot-calibrate", + "lerobot-find-port", + "lerobot-find-cameras", + "lerobot-find-joint-limits", + "lerobot-setup-motors", ], - "pusht": ["PushT-v0"], -} -available_envs = list(available_tasks_per_env.keys()) - -available_datasets_per_env = { - "aloha": [ - "lerobot/aloha_sim_insertion_human", - "lerobot/aloha_sim_insertion_scripted", - "lerobot/aloha_sim_transfer_cube_human", - "lerobot/aloha_sim_transfer_cube_scripted", - "lerobot/aloha_sim_insertion_human_image", - "lerobot/aloha_sim_insertion_scripted_image", - "lerobot/aloha_sim_transfer_cube_human_image", - "lerobot/aloha_sim_transfer_cube_scripted_image", - ], - # TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly - # coupled with tests. - "pusht": ["lerobot/pusht", "lerobot/pusht_image"], + "core_scripts": ["lerobot-record", "lerobot-replay", "lerobot-teleoperate"], + "evaluation": ["lerobot-eval"], } -available_real_world_datasets = [ - "lerobot/aloha_mobile_cabinet", - "lerobot/aloha_mobile_chair", - "lerobot/aloha_mobile_elevator", - "lerobot/aloha_mobile_shrimp", - "lerobot/aloha_mobile_wash_pan", - "lerobot/aloha_mobile_wipe_wine", - "lerobot/aloha_static_battery", - "lerobot/aloha_static_candy", - "lerobot/aloha_static_coffee", - "lerobot/aloha_static_coffee_new", - "lerobot/aloha_static_cups_open", - "lerobot/aloha_static_fork_pick_up", - "lerobot/aloha_static_pingpong_test", - "lerobot/aloha_static_pro_pencil", - "lerobot/aloha_static_screw_driver", - "lerobot/aloha_static_tape", - "lerobot/aloha_static_thread_velcro", - "lerobot/aloha_static_towel", - "lerobot/aloha_static_vinh_cup", - "lerobot/aloha_static_vinh_cup_left", - "lerobot/aloha_static_ziploc_slide", - "lerobot/umi_cup_in_the_wild", - "lerobot/unitreeh1_fold_clothes", - "lerobot/unitreeh1_rearrange_objects", - "lerobot/unitreeh1_two_robot_greeting", - "lerobot/unitreeh1_warehouse", - "lerobot/nyu_rot_dataset", - "lerobot/utokyo_saytap", - "lerobot/imperialcollege_sawyer_wrist_cam", - "lerobot/utokyo_xarm_bimanual", - "lerobot/tokyo_u_lsmo", - "lerobot/utokyo_pr2_opening_fridge", - "lerobot/cmu_franka_exploration_dataset", - "lerobot/cmu_stretch", - "lerobot/asu_table_top", - "lerobot/utokyo_pr2_tabletop_manipulation", - "lerobot/utokyo_xarm_pick_and_place", - "lerobot/ucsd_kitchen_dataset", - "lerobot/austin_buds_dataset", - "lerobot/dlr_sara_grid_clamp", - "lerobot/conq_hose_manipulation", - "lerobot/columbia_cairlab_pusht_real", - "lerobot/dlr_sara_pour", - "lerobot/dlr_edan_shared_control", - "lerobot/ucsd_pick_and_place_dataset", - "lerobot/berkeley_cable_routing", - "lerobot/nyu_franka_play_dataset", - "lerobot/austin_sirius_dataset", - "lerobot/cmu_play_fusion", - "lerobot/berkeley_gnm_sac_son", - "lerobot/nyu_door_opening_surprising_effectiveness", - "lerobot/berkeley_fanuc_manipulation", - "lerobot/jaco_play", - "lerobot/viola", - "lerobot/kaist_nonprehensile", - "lerobot/berkeley_mvp", - "lerobot/uiuc_d3field", - "lerobot/berkeley_gnm_recon", - "lerobot/austin_sailor_dataset", - "lerobot/utaustin_mutex", - "lerobot/roboturk", - "lerobot/stanford_hydra_dataset", - "lerobot/berkeley_autolab_ur5", - "lerobot/stanford_robocook", - "lerobot/toto", - "lerobot/fmb", - "lerobot/droid_100", - "lerobot/berkeley_rpt", - "lerobot/stanford_kuka_multimodal_dataset", - "lerobot/iamlab_cmu_pickup_insert", - "lerobot/taco_play", - "lerobot/berkeley_gnm_cory_hall", - "lerobot/usc_cloth_sim", -] - -available_datasets = sorted( - set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)) -) - -# lists all available policies from `lerobot/policies` -available_policies = ["act", "diffusion", "tdmpc", "vqbet"] - -# lists all available robots from `lerobot/robots` -available_robots = [ - "koch", - "koch_bimanual", - "aloha", - "so100", - "so101", -] - -# lists all available cameras from `lerobot/cameras` -available_cameras = [ - "opencv", - "intelrealsense", -] - -# lists all available motors from `lerobot/motors` -available_motors = [ - "dynamixel", - "feetech", -] - -# keys and values refer to yaml files -available_policies_per_env = { - "aloha": ["act"], - "pusht": ["diffusion", "vqbet"], - "koch_real": ["act_koch_real"], - "aloha_real": ["act_aloha_real"], -} - -env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks] -env_dataset_pairs = [ - (env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets -] -env_dataset_policy_triplets = [ - (env, dataset, policy) - for env, datasets in available_datasets_per_env.items() - for dataset in datasets - for policy in available_policies_per_env[env] -] +__all__ = ["__version__", "available_extras"] diff --git a/src/lerobot/async_inference/__init__.py b/src/lerobot/async_inference/__init__.py new file mode 100644 index 000000000..8d7a22584 --- /dev/null +++ b/src/lerobot/async_inference/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Async inference server/client. + +Requires: ``pip install 'lerobot[async]'`` + +Available modules (import directly):: + + from lerobot.async_inference.policy_server import ... + from lerobot.async_inference.robot_client import ... +""" + +from lerobot.utils.import_utils import require_package + +require_package("grpcio", extra="async", import_name="grpc") + +__all__: list[str] = [] diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py index 9dd44eb44..4931c68c5 100644 --- a/src/lerobot/async_inference/helpers.py +++ b/src/lerobot/async_inference/helpers.py @@ -22,8 +22,7 @@ from typing import Any import torch -from lerobot.configs.types import PolicyFeature -from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features +from lerobot.configs import PolicyFeature # NOTE: Configs need to be loaded for the client to be able to instantiate the policy config from lerobot.policies import ( # noqa: F401 @@ -36,6 +35,7 @@ from lerobot.policies import ( # noqa: F401 ) from lerobot.robots.robot import Robot from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR +from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features from lerobot.utils.utils import init_logging Action = torch.Tensor diff --git a/src/lerobot/async_inference/policy_server.py b/src/lerobot/async_inference/policy_server.py index 3f63929df..787d39abf 100644 --- a/src/lerobot/async_inference/policy_server.py +++ b/src/lerobot/async_inference/policy_server.py @@ -38,7 +38,7 @@ import draccus import grpc import torch -from lerobot.policies.factory import get_policy_class, make_pre_post_processors +from lerobot.policies import get_policy_class, make_pre_post_processors from lerobot.processor import PolicyProcessorPipeline from lerobot.transport import ( services_pb2, # type: ignore diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py index 0ee70a0e6..a250a08fb 100644 --- a/src/lerobot/async_inference/robot_client.py +++ b/src/lerobot/async_inference/robot_client.py @@ -47,8 +47,8 @@ import draccus import grpc import torch -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 -from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401 +from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401 from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, diff --git a/src/lerobot/cameras/__init__.py b/src/lerobot/cameras/__init__.py index cbf1f11bf..3598d58aa 100644 --- a/src/lerobot/cameras/__init__.py +++ b/src/lerobot/cameras/__init__.py @@ -15,3 +15,9 @@ from .camera import Camera from .configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation from .utils import make_cameras_from_configs + +# NOTE: Camera submodule configs and implementations (OpenCVCameraConfig, RealSenseCamera, etc.) +# are intentionally NOT re-exported here to avoid pulling backend-specific dependencies. +# Import from submodules: ``from lerobot.cameras.opencv import OpenCVCameraConfig`` + +__all__ = ["Camera", "CameraConfig", "ColorMode", "Cv2Backends", "Cv2Rotation", "make_cameras_from_configs"] diff --git a/src/lerobot/cameras/reachy2_camera/__init__.py b/src/lerobot/cameras/reachy2_camera/__init__.py index 72e45f32a..d979a7db5 100644 --- a/src/lerobot/cameras/reachy2_camera/__init__.py +++ b/src/lerobot/cameras/reachy2_camera/__init__.py @@ -14,3 +14,5 @@ from .configuration_reachy2_camera import Reachy2CameraConfig from .reachy2_camera import Reachy2Camera + +__all__ = ["Reachy2Camera", "Reachy2CameraConfig"] diff --git a/src/lerobot/cameras/realsense/__init__.py b/src/lerobot/cameras/realsense/__init__.py index 67f2f4000..eb20c9973 100644 --- a/src/lerobot/cameras/realsense/__init__.py +++ b/src/lerobot/cameras/realsense/__init__.py @@ -14,3 +14,5 @@ from .camera_realsense import RealSenseCamera from .configuration_realsense import RealSenseCameraConfig + +__all__ = ["RealSenseCamera", "RealSenseCameraConfig"] diff --git a/src/lerobot/cameras/zmq/image_server.py b/src/lerobot/cameras/zmq/image_server.py index 8222b9fee..b8b6f8e74 100644 --- a/src/lerobot/cameras/zmq/image_server.py +++ b/src/lerobot/cameras/zmq/image_server.py @@ -31,8 +31,8 @@ import cv2 import numpy as np import zmq -from lerobot.cameras.configs import ColorMode -from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig +from ..configs import ColorMode +from ..opencv import OpenCVCamera, OpenCVCameraConfig logger = logging.getLogger(__name__) diff --git a/src/lerobot/common/__init__.py b/src/lerobot/common/__init__.py new file mode 100644 index 000000000..782ef5b77 --- /dev/null +++ b/src/lerobot/common/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Cross-cutting modules that bridge multiple lerobot packages. + +Unlike ``lerobot.utils`` (which must remain dependency-free), modules here +are allowed to import from ``lerobot.policies``, ``lerobot.processor``, +``lerobot.configs``, etc. They are deliberately NOT re-exported from the +top-level ``lerobot`` package. + +Available modules (import directly):: + + from lerobot.common.control_utils import predict_action, ... + from lerobot.common.train_utils import save_checkpoint, ... + from lerobot.common.wandb_utils import WandBLogger, ... +""" + +__all__: list[str] = [] diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/common/control_utils.py similarity index 95% rename from src/lerobot/utils/control_utils.py rename to src/lerobot/common/control_utils.py index 94cd82fa1..530955078 100644 --- a/src/lerobot/utils/control_utils.py +++ b/src/lerobot/common/control_utils.py @@ -12,26 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + ######################################################################################## # Utilities ######################################################################################## - - import logging import traceback from contextlib import nullcontext from copy import copy from functools import cache -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import torch -from deepdiff import DeepDiff -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.utils import DEFAULT_FEATURES -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.utils import prepare_observation_for_inference +from lerobot.policies import PreTrainedPolicy, prepare_observation_for_inference + +if TYPE_CHECKING: + from lerobot.datasets import LeRobotDataset from lerobot.processor import PolicyProcessorPipeline from lerobot.robots import Robot from lerobot.types import PolicyAction @@ -218,6 +217,13 @@ def sanity_check_dataset_robot_compatibility( Raises: ValueError: If any of the checked metadata fields do not match. """ + from lerobot.utils.import_utils import require_package + + require_package("deepdiff", extra="hardware") + from deepdiff import DeepDiff + + from lerobot.utils.constants import DEFAULT_FEATURES + fields = [ ("robot_type", dataset.meta.robot_type, robot.robot_type), ("fps", dataset.fps, fps), diff --git a/src/lerobot/utils/train_utils.py b/src/lerobot/common/train_utils.py similarity index 95% rename from src/lerobot/utils/train_utils.py rename to src/lerobot/common/train_utils.py index 02f6aebb3..3e96e1330 100644 --- a/src/lerobot/utils/train_utils.py +++ b/src/lerobot/common/train_utils.py @@ -19,10 +19,13 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets.io_utils import load_json, write_json -from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state -from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state -from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.optim import ( + load_optimizer_state, + load_scheduler_state, + save_optimizer_state, + save_scheduler_state, +) +from lerobot.policies import PreTrainedPolicy from lerobot.processor import PolicyProcessorPipeline from lerobot.utils.constants import ( CHECKPOINTS_DIR, @@ -31,6 +34,7 @@ from lerobot.utils.constants import ( TRAINING_STATE_DIR, TRAINING_STEP, ) +from lerobot.utils.io_utils import load_json, write_json from lerobot.utils.random_utils import load_rng_state, save_rng_state diff --git a/src/lerobot/rl/wandb_utils.py b/src/lerobot/common/wandb_utils.py similarity index 100% rename from src/lerobot/rl/wandb_utils.py rename to src/lerobot/common/wandb_utils.py diff --git a/src/lerobot/configs/__init__.py b/src/lerobot/configs/__init__.py new file mode 100644 index 000000000..3ddaec1af --- /dev/null +++ b/src/lerobot/configs/__init__.py @@ -0,0 +1,47 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Public API for lerobot configuration types and base config classes. + +NOTE: TrainPipelineConfig, EvalPipelineConfig, and TrainRLServerPipelineConfig +are intentionally NOT re-exported here to avoid circular dependencies +(they import lerobot.envs and lerobot.policies at module level). +Import them directly: ``from lerobot.configs.train import TrainPipelineConfig`` +""" + +from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig +from .policies import PreTrainedConfig +from .types import ( + FeatureType, + NormalizationMode, + PipelineFeatureType, + PolicyFeature, + RTCAttentionSchedule, +) + +__all__ = [ + # Types + "FeatureType", + "NormalizationMode", + "PipelineFeatureType", + "PolicyFeature", + "RTCAttentionSchedule", + # Config classes + "DatasetConfig", + "EvalConfig", + "PeftConfig", + "PreTrainedConfig", + "WandBConfig", +] diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index d6ad665bf..b05e96fde 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -16,8 +16,8 @@ from dataclasses import dataclass, field -from lerobot.datasets.transforms import ImageTransformsConfig -from lerobot.datasets.video_utils import get_safe_default_codec +from lerobot.transforms import ImageTransformsConfig +from lerobot.utils.import_utils import get_safe_default_codec @dataclass diff --git a/src/lerobot/configs/eval.py b/src/lerobot/configs/eval.py index da8bee6b2..d1cebd27f 100644 --- a/src/lerobot/configs/eval.py +++ b/src/lerobot/configs/eval.py @@ -19,8 +19,9 @@ from pathlib import Path from lerobot import envs, policies # noqa: F401 from lerobot.configs import parser -from lerobot.configs.default import EvalConfig -from lerobot.configs.policies import PreTrainedConfig + +from .default import EvalConfig +from .policies import PreTrainedConfig logger = getLogger(__name__) diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index ce567b8f5..91701af6d 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -26,13 +26,13 @@ from huggingface_hub import hf_hub_download from huggingface_hub.constants import CONFIG_NAME from huggingface_hub.errors import HfHubHTTPError -from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.optim.optimizers import OptimizerConfig -from lerobot.optim.schedulers import LRSchedulerConfig +from lerobot.optim import LRSchedulerConfig, OptimizerConfig from lerobot.utils.constants import ACTION, OBS_STATE from lerobot.utils.device_utils import auto_select_torch_device, is_amp_available, is_torch_device_available from lerobot.utils.hub import HubMixin +from .types import FeatureType, PolicyFeature + T = TypeVar("T", bound="PreTrainedConfig") logger = getLogger(__name__) diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 8b8aedb26..d754a0847 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -24,12 +24,12 @@ from huggingface_hub.errors import HfHubHTTPError from lerobot import envs from lerobot.configs import parser -from lerobot.configs.default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig -from lerobot.configs.policies import PreTrainedConfig -from lerobot.optim import OptimizerConfig -from lerobot.optim.schedulers import LRSchedulerConfig +from lerobot.optim import LRSchedulerConfig, OptimizerConfig from lerobot.utils.hub import HubMixin +from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig +from .policies import PreTrainedConfig + TRAIN_CONFIG_NAME = "train_config.json" diff --git a/src/lerobot/data_processing/__init__.py b/src/lerobot/data_processing/__init__.py index 2f76d5676..cd55d46fc 100644 --- a/src/lerobot/data_processing/__init__.py +++ b/src/lerobot/data_processing/__init__.py @@ -11,3 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +Data processing utilities (annotation tools, dataset transformations). + +Available sub-modules (import directly):: + + from lerobot.data_processing.sarm_annotations import ... +""" + +__all__: list[str] = [] diff --git a/src/lerobot/data_processing/sarm_annotations/__init__.py b/src/lerobot/data_processing/sarm_annotations/__init__.py index 2f76d5676..cd4c38f33 100644 --- a/src/lerobot/data_processing/sarm_annotations/__init__.py +++ b/src/lerobot/data_processing/sarm_annotations/__init__.py @@ -11,3 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +SARM subtask annotation tools. + +Available modules (import directly):: + + from lerobot.data_processing.sarm_annotations.subtask_annotation import ... +""" + +__all__: list[str] = [] diff --git a/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py b/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py index 8f3a65e39..b26257d44 100644 --- a/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py +++ b/src/lerobot/data_processing/sarm_annotations/subtask_annotation.py @@ -76,7 +76,7 @@ import torch from pydantic import BaseModel, Field from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration -from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets import LeRobotDataset # Pydantic Models for SARM Subtask Annotation @@ -746,8 +746,7 @@ def save_annotations_to_dataset( dataset_path: Path, annotations: dict[int, SubtaskAnnotation], fps: int, prefix: str = "sparse" ): """Save annotations to LeRobot dataset parquet format.""" - from lerobot.datasets.io_utils import load_episodes - from lerobot.datasets.utils import DEFAULT_EPISODES_PATH + from lerobot.datasets import DEFAULT_EPISODES_PATH, load_episodes episodes_dataset = load_episodes(dataset_path) if not episodes_dataset or len(episodes_dataset) == 0: @@ -841,7 +840,7 @@ def generate_auto_sparse_annotations( def load_annotations_from_dataset(dataset_path: Path, prefix: str = "sparse") -> dict[int, SubtaskAnnotation]: """Load annotations from LeRobot dataset parquet files.""" - from lerobot.datasets.io_utils import load_episodes + from lerobot.datasets import load_episodes episodes_dataset = load_episodes(dataset_path) if not episodes_dataset or len(episodes_dataset) == 0: diff --git a/src/lerobot/datasets/__init__.py b/src/lerobot/datasets/__init__.py index 42c4ab810..6c42959a5 100644 --- a/src/lerobot/datasets/__init__.py +++ b/src/lerobot/datasets/__init__.py @@ -15,19 +15,68 @@ # See the License for the specific language governing permissions and # limitations under the License. -from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.multi_dataset import MultiLeRobotDataset -from lerobot.datasets.sampler import EpisodeAwareSampler -from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset -from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig +from lerobot.utils.import_utils import require_package + +require_package("datasets", extra="dataset") +require_package("av", extra="dataset") + +from .aggregate import aggregate_datasets +from .compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats +from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata +from .dataset_tools import ( + add_features, + convert_image_to_video_dataset, + delete_episodes, + merge_datasets, + modify_features, + modify_tasks, + recompute_stats, + remove_feature, + split_dataset, +) +from .factory import make_dataset, resolve_delta_timestamps +from .image_writer import safe_stop_image_writer +from .io_utils import load_episodes, write_stats +from .lerobot_dataset import LeRobotDataset +from .multi_dataset import MultiLeRobotDataset +from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features +from .sampler import EpisodeAwareSampler +from .streaming_dataset import StreamingLeRobotDataset +from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card +from .video_utils import VideoEncodingManager + +# NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.) +# and legacy migration constants are intentionally NOT re-exported here. +# Import directly: ``from lerobot.datasets.io_utils import ...`` __all__ = [ + "CODEBASE_VERSION", + "DEFAULT_EPISODES_PATH", + "DEFAULT_QUANTILES", "EpisodeAwareSampler", - "ImageTransforms", - "ImageTransformsConfig", "LeRobotDataset", "LeRobotDatasetMetadata", "MultiLeRobotDataset", "StreamingLeRobotDataset", + "VideoEncodingManager", + "add_features", + "aggregate_datasets", + "aggregate_pipeline_dataset_features", + "aggregate_stats", + "convert_image_to_video_dataset", + "create_initial_features", + "create_lerobot_dataset_card", + "delete_episodes", + "get_feature_stats", + "load_episodes", + "make_dataset", + "merge_datasets", + "modify_features", + "modify_tasks", + "recompute_stats", + "remove_feature", + "resolve_delta_timestamps", + "safe_stop_image_writer", + "split_dataset", + "write_stats", ] diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index 66f055f04..0da1da964 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -23,10 +23,10 @@ import datasets import pandas as pd import tqdm -from lerobot.datasets.compute_stats import aggregate_stats -from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata -from lerobot.datasets.feature_utils import get_hf_features_from_features -from lerobot.datasets.io_utils import ( +from .compute_stats import aggregate_stats +from .dataset_metadata import LeRobotDatasetMetadata +from .feature_utils import get_hf_features_from_features +from .io_utils import ( get_file_size_in_mb, get_parquet_file_size_in_mb, to_parquet_with_hf_images, @@ -34,7 +34,7 @@ from lerobot.datasets.io_utils import ( write_stats, write_tasks, ) -from lerobot.datasets.utils import ( +from .utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_DATA_PATH, @@ -43,7 +43,7 @@ from lerobot.datasets.utils import ( DEFAULT_VIDEO_PATH, update_chunk_file_indices, ) -from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s +from .video_utils import concatenate_video_files, get_video_duration_in_s def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]): diff --git a/src/lerobot/datasets/compute_stats.py b/src/lerobot/datasets/compute_stats.py index 03eefe40e..f489c84a7 100644 --- a/src/lerobot/datasets/compute_stats.py +++ b/src/lerobot/datasets/compute_stats.py @@ -19,9 +19,11 @@ import logging import numpy as np -from lerobot.datasets.io_utils import load_image_as_numpy +from lerobot.processor import RelativeActionsProcessorStep from lerobot.utils.constants import ACTION, OBS_STATE +from .io_utils import load_image_as_numpy + DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99] @@ -696,8 +698,6 @@ def compute_relative_action_stats( ValueError: If the dataset has fewer frames than ``chunk_size``. RuntimeError: If no valid (single-episode) chunks are found. """ - from lerobot.processor.relative_action_processor import RelativeActionsProcessorStep - if exclude_joints is None: exclude_joints = [] diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py index d79f4bfba..8bf67fa39 100644 --- a/src/lerobot/datasets/dataset_metadata.py +++ b/src/lerobot/datasets/dataset_metadata.py @@ -23,9 +23,13 @@ import pyarrow as pa import pyarrow.parquet as pq from huggingface_hub import snapshot_download -from lerobot.datasets.compute_stats import aggregate_stats -from lerobot.datasets.feature_utils import _validate_feature_names, create_empty_dataset_info -from lerobot.datasets.io_utils import ( +from lerobot.utils.constants import DEFAULT_FEATURES, HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE +from lerobot.utils.feature_utils import _validate_feature_names +from lerobot.utils.utils import flatten_dict + +from .compute_stats import aggregate_stats +from .feature_utils import create_empty_dataset_info +from .io_utils import ( get_file_size_in_mb, load_episodes, load_info, @@ -37,19 +41,16 @@ from lerobot.datasets.io_utils import ( write_stats, write_tasks, ) -from lerobot.datasets.utils import ( +from .utils import ( DEFAULT_EPISODES_PATH, - DEFAULT_FEATURES, INFO_PATH, check_version_compatibility, - flatten_dict, get_safe_version, has_legacy_hub_download_metadata, is_valid_version, update_chunk_file_indices, ) -from lerobot.datasets.video_utils import get_video_info -from lerobot.utils.constants import HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE +from .video_utils import get_video_info CODEBASE_VERSION = "v3.0" diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py index 3720a5084..fc7ce36ce 100644 --- a/src/lerobot/datasets/dataset_reader.py +++ b/src/lerobot/datasets/dataset_reader.py @@ -21,17 +21,17 @@ from pathlib import Path import datasets import torch -from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata -from lerobot.datasets.feature_utils import ( +from .dataset_metadata import LeRobotDatasetMetadata +from .feature_utils import ( check_delta_timestamps, get_delta_indices, get_hf_features_from_features, ) -from lerobot.datasets.io_utils import ( +from .io_utils import ( hf_transform_to_torch, load_nested_dataset, ) -from lerobot.datasets.video_utils import decode_video_frames +from .video_utils import decode_video_frames class DatasetReader: diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index 16bf24822..cbf4e5c49 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -36,22 +36,25 @@ import pyarrow.parquet as pq import torch from tqdm import tqdm -from lerobot.datasets.aggregate import aggregate_datasets -from lerobot.datasets.compute_stats import ( +from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE +from lerobot.utils.utils import flatten_dict + +from .aggregate import aggregate_datasets +from .compute_stats import ( aggregate_stats, compute_episode_stats, compute_relative_action_stats, ) -from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata -from lerobot.datasets.io_utils import ( +from .dataset_metadata import LeRobotDatasetMetadata +from .io_utils import ( get_parquet_file_size_in_mb, load_episodes, write_info, write_stats, write_tasks, ) -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.utils import ( +from .lerobot_dataset import LeRobotDataset +from .utils import ( DATA_DIR, DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, @@ -59,8 +62,7 @@ from lerobot.datasets.utils import ( DEFAULT_EPISODES_PATH, update_chunk_file_indices, ) -from lerobot.datasets.video_utils import encode_video_frames, get_video_info -from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE +from .video_utils import encode_video_frames, get_video_info def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict: @@ -829,8 +831,6 @@ def _copy_and_reindex_episodes_metadata( data_metadata: Dict mapping new episode index to its data file metadata video_metadata: Optional dict mapping new episode index to its video metadata """ - from lerobot.datasets.utils import flatten_dict - if src_dataset.meta.episodes is None: src_dataset.meta.episodes = load_episodes(src_dataset.meta.root) @@ -922,8 +922,8 @@ def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) - This ensures images are properly embedded and the file can be loaded correctly by HF datasets. """ - from lerobot.datasets.feature_utils import get_hf_features_from_features - from lerobot.datasets.io_utils import embed_images + from .feature_utils import get_hf_features_from_features + from .io_utils import embed_images hf_features = get_hf_features_from_features(meta.features) ep_dataset = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=hf_features, split="train") @@ -1367,7 +1367,7 @@ def _copy_data_without_images( episode_indices: Episodes to include img_keys: Image keys to remove """ - from lerobot.datasets.utils import DATA_DIR + from .utils import DATA_DIR data_dir = src_dataset.root / DATA_DIR parquet_files = sorted(data_dir.glob("*/*.parquet")) diff --git a/src/lerobot/datasets/dataset_writer.py b/src/lerobot/datasets/dataset_writer.py index 787ecd337..60ec9e348 100644 --- a/src/lerobot/datasets/dataset_writer.py +++ b/src/lerobot/datasets/dataset_writer.py @@ -31,26 +31,26 @@ import PIL.Image import pyarrow.parquet as pq import torch -from lerobot.datasets.compute_stats import compute_episode_stats -from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata -from lerobot.datasets.feature_utils import ( +from .compute_stats import compute_episode_stats +from .dataset_metadata import LeRobotDatasetMetadata +from .feature_utils import ( get_hf_features_from_features, validate_episode_buffer, validate_frame, ) -from lerobot.datasets.image_writer import AsyncImageWriter, write_image -from lerobot.datasets.io_utils import ( +from .image_writer import AsyncImageWriter, write_image +from .io_utils import ( embed_images, get_file_size_in_mb, load_episodes, write_info, ) -from lerobot.datasets.utils import ( +from .utils import ( DEFAULT_EPISODES_PATH, DEFAULT_IMAGE_PATH, update_chunk_file_indices, ) -from lerobot.datasets.video_utils import ( +from .video_utils import ( StreamingVideoEncoder, concatenate_video_files, encode_video_frames, diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index 76ece8961..040cba5cb 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -18,19 +18,15 @@ from pprint import pformat import torch -from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs import PreTrainedConfig from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.multi_dataset import MultiLeRobotDataset -from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset -from lerobot.datasets.transforms import ImageTransforms -from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD +from lerobot.transforms import ImageTransforms +from lerobot.utils.constants import ACTION, IMAGENET_STATS, OBS_PREFIX, REWARD -IMAGENET_STATS = { - "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1) - "std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1) -} +from .dataset_metadata import LeRobotDatasetMetadata +from .lerobot_dataset import LeRobotDataset +from .multi_dataset import MultiLeRobotDataset +from .streaming_dataset import StreamingLeRobotDataset def resolve_delta_timestamps( diff --git a/src/lerobot/datasets/feature_utils.py b/src/lerobot/datasets/feature_utils.py index 46154d92a..b05dbf2cc 100644 --- a/src/lerobot/datasets/feature_utils.py +++ b/src/lerobot/datasets/feature_utils.py @@ -14,23 +14,21 @@ # See the License for the specific language governing permissions and # limitations under the License. from pprint import pformat -from typing import Any import datasets import numpy as np from PIL import Image as PILImage -from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.datasets.utils import ( +from lerobot.utils.constants import DEFAULT_FEATURES +from lerobot.utils.utils import is_valid_numpy_dtype_string + +from .utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_DATA_PATH, - DEFAULT_FEATURES, DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_VIDEO_PATH, ) -from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR -from lerobot.utils.utils import is_valid_numpy_dtype_string def get_hf_features_from_features(features: dict) -> datasets.Features: @@ -71,199 +69,6 @@ def get_hf_features_from_features(features: dict) -> datasets.Features: return datasets.Features(hf_features) -def _validate_feature_names(features: dict[str, dict]) -> None: - """Validate that feature names do not contain invalid characters. - - Args: - features (dict): The LeRobot features dictionary. - - Raises: - ValueError: If any feature name contains '/'. - """ - invalid_features = {name: ft for name, ft in features.items() if "/" in name} - if invalid_features: - raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.") - - -def hw_to_dataset_features( - hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True -) -> dict[str, dict]: - """Convert hardware-specific features to a LeRobot dataset feature dictionary. - - This function takes a dictionary describing hardware outputs (like joint states - or camera image shapes) and formats it into the standard LeRobot feature - specification. - - Args: - hw_features (dict): Dictionary mapping feature names to their type (float for - joints) or shape (tuple for images). - prefix (str): The prefix to add to the feature keys (e.g., "observation" - or "action"). - use_video (bool): If True, image features are marked as "video", otherwise "image". - - Returns: - dict: A LeRobot features dictionary. - """ - features = {} - joint_fts = { - key: ftype - for key, ftype in hw_features.items() - if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL) - } - cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} - - if joint_fts and prefix == ACTION: - features[prefix] = { - "dtype": "float32", - "shape": (len(joint_fts),), - "names": list(joint_fts), - } - - if joint_fts and prefix == OBS_STR: - features[f"{prefix}.state"] = { - "dtype": "float32", - "shape": (len(joint_fts),), - "names": list(joint_fts), - } - - for key, shape in cam_fts.items(): - features[f"{prefix}.images.{key}"] = { - "dtype": "video" if use_video else "image", - "shape": shape, - "names": ["height", "width", "channels"], - } - - _validate_feature_names(features) - return features - - -def build_dataset_frame( - ds_features: dict[str, dict], values: dict[str, Any], prefix: str -) -> dict[str, np.ndarray]: - """Construct a single data frame from raw values based on dataset features. - - A "frame" is a dictionary containing all the data for a single timestep, - formatted as numpy arrays according to the feature specification. - - Args: - ds_features (dict): The LeRobot dataset features dictionary. - values (dict): A dictionary of raw values from the hardware/environment. - prefix (str): The prefix to filter features by (e.g., "observation" - or "action"). - - Returns: - dict: A dictionary representing a single frame of data. - """ - frame = {} - for key, ft in ds_features.items(): - if key in DEFAULT_FEATURES or not key.startswith(prefix): - continue - elif ft["dtype"] == "float32" and len(ft["shape"]) == 1: - frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32) - elif ft["dtype"] in ["image", "video"]: - frame[key] = values[key.removeprefix(f"{prefix}.images.")] - - return frame - - -def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: - """Convert dataset features to policy features. - - This function transforms the dataset's feature specification into a format - that a policy can use, classifying features by type (e.g., visual, state, - action) and ensuring correct shapes (e.g., channel-first for images). - - Args: - features (dict): The LeRobot dataset features dictionary. - - Returns: - dict: A dictionary mapping feature keys to `PolicyFeature` objects. - - Raises: - ValueError: If an image feature does not have a 3D shape. - """ - # TODO(aliberts): Implement "type" in dataset features and simplify this - policy_features = {} - for key, ft in features.items(): - shape = ft["shape"] - if ft["dtype"] in ["image", "video"]: - type = FeatureType.VISUAL - if len(shape) != 3: - raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})") - - names = ft["names"] - # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. - if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) - shape = (shape[2], shape[0], shape[1]) - elif key == OBS_ENV_STATE: - type = FeatureType.ENV - elif key.startswith(OBS_STR): - type = FeatureType.STATE - elif key.startswith(ACTION): - type = FeatureType.ACTION - else: - continue - - policy_features[key] = PolicyFeature( - type=type, - shape=shape, - ) - - return policy_features - - -def combine_feature_dicts(*dicts: dict) -> dict: - """Merge LeRobot grouped feature dicts. - - - For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape. - - For others (e.g. `observation.images.*`), the last one wins (if they are identical). - - Args: - *dicts: A variable number of LeRobot feature dictionaries to merge. - - Returns: - dict: A single merged feature dictionary. - - Raises: - ValueError: If there's a dtype mismatch for a feature being merged. - """ - out: dict = {} - for d in dicts: - for key, value in d.items(): - if not isinstance(value, dict): - out[key] = value - continue - - dtype = value.get("dtype") - shape = value.get("shape") - is_vector = ( - dtype not in ("image", "video", "string") - and isinstance(shape, tuple) - and len(shape) == 1 - and "names" in value - ) - - if is_vector: - # Initialize or retrieve the accumulating dict for this feature key - target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)}) - # Ensure consistent data types across merged entries - if "dtype" in target and dtype != target["dtype"]: - raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}") - - # Merge feature names: append only new ones to preserve order without duplicates - seen = set(target["names"]) - for n in value["names"]: - if n not in seen: - target["names"].append(n) - seen.add(n) - # Recompute the shape to reflect the updated number of features - target["shape"] = (len(target["names"]),) - else: - # For images/videos and non-1D entries: override with the latest definition - out[key] = value - return out - - def create_empty_dataset_info( codebase_version: str, fps: int, diff --git a/src/lerobot/datasets/io_utils.py b/src/lerobot/datasets/io_utils.py index cee6cfba8..2ee859e97 100644 --- a/src/lerobot/datasets/io_utils.py +++ b/src/lerobot/datasets/io_utils.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json from pathlib import Path from typing import Any @@ -29,7 +28,10 @@ from datasets.table import embed_table_storage from PIL import Image as PILImage from torchvision import transforms -from lerobot.datasets.utils import ( +from lerobot.utils.io_utils import load_json, write_json +from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict + +from .utils import ( DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_EPISODES_PATH, DEFAULT_SUBTASKS_PATH, @@ -37,11 +39,8 @@ from lerobot.datasets.utils import ( EPISODES_DIR, INFO_PATH, STATS_PATH, - flatten_dict, serialize_dict, - unflatten_dict, ) -from lerobot.utils.utils import SuppressProgressBars def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float: @@ -116,33 +115,6 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset: return dataset -def load_json(fpath: Path) -> Any: - """Load data from a JSON file. - - Args: - fpath (Path): Path to the JSON file. - - Returns: - Any: The data loaded from the JSON file. - """ - with open(fpath) as f: - return json.load(f) - - -def write_json(data: dict, fpath: Path) -> None: - """Write data to a JSON file. - - Creates parent directories if they don't exist. - - Args: - data (dict): The dictionary to write. - fpath (Path): The path to the output JSON file. - """ - fpath.parent.mkdir(exist_ok=True, parents=True) - with open(fpath, "w") as f: - json.dump(data, f, indent=4, ensure_ascii=False) - - def write_info(info: dict, local_dir: Path) -> None: write_json(info, local_dir / INFO_PATH) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 2f0154cda..7cda5d677 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -24,20 +24,21 @@ import torch.utils from huggingface_hub import HfApi, snapshot_download from huggingface_hub.errors import RevisionNotFoundError -from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata -from lerobot.datasets.dataset_reader import DatasetReader -from lerobot.datasets.dataset_writer import DatasetWriter -from lerobot.datasets.utils import ( +from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE + +from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata +from .dataset_reader import DatasetReader +from .dataset_writer import DatasetWriter +from .utils import ( create_lerobot_dataset_card, get_safe_version, is_valid_version, ) -from lerobot.datasets.video_utils import ( +from .video_utils import ( StreamingVideoEncoder, get_safe_default_codec, resolve_vcodec, ) -from lerobot.utils.constants import HF_LEROBOT_HUB_CACHE logger = logging.getLogger(__name__) diff --git a/src/lerobot/datasets/multi_dataset.py b/src/lerobot/datasets/multi_dataset.py index 092443077..b4b7a941d 100644 --- a/src/lerobot/datasets/multi_dataset.py +++ b/src/lerobot/datasets/multi_dataset.py @@ -21,12 +21,13 @@ import datasets import torch import torch.utils -from lerobot.datasets.compute_stats import aggregate_stats -from lerobot.datasets.feature_utils import get_hf_features_from_features -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.video_utils import VideoFrame from lerobot.utils.constants import HF_LEROBOT_HOME +from .compute_stats import aggregate_stats +from .feature_utils import get_hf_features_from_features +from .lerobot_dataset import LeRobotDataset +from .video_utils import VideoFrame + logger = logging.getLogger(__name__) diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py index 96779fdc6..cf02a52ac 100644 --- a/src/lerobot/datasets/pipeline_features.py +++ b/src/lerobot/datasets/pipeline_features.py @@ -16,11 +16,11 @@ import re from collections.abc import Sequence from typing import Any -from lerobot.configs.types import PipelineFeatureType -from lerobot.datasets.feature_utils import hw_to_dataset_features +from lerobot.configs import PipelineFeatureType from lerobot.processor import DataProcessorPipeline from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR +from lerobot.utils.feature_utils import hw_to_dataset_features def create_initial_features( diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index 1767cc79d..f47d71367 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -22,20 +22,21 @@ import numpy as np import torch from datasets import load_dataset -from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata -from lerobot.datasets.feature_utils import get_delta_indices -from lerobot.datasets.io_utils import item_to_torch -from lerobot.datasets.utils import ( +from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE + +from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata +from .feature_utils import get_delta_indices +from .io_utils import item_to_torch +from .utils import ( check_version_compatibility, find_float_index, is_float_in_list, safe_shard, ) -from lerobot.datasets.video_utils import ( +from .video_utils import ( VideoDecoderCache, decode_video_frames_torchcodec, ) -from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE class LookBackError(Exception): diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 36e7934ed..c6815e0f5 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -17,9 +17,7 @@ import contextlib import importlib.resources import json import logging -from collections.abc import Iterator from pathlib import Path -from typing import Any import datasets import numpy as np @@ -28,6 +26,8 @@ import torch from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub.errors import RevisionNotFoundError +from lerobot.utils.utils import flatten_dict, unflatten_dict + V30_MESSAGE = """ The dataset you requested ({repo_id}) is in {version} format. @@ -93,14 +93,6 @@ LEGACY_EPISODES_PATH = "meta/episodes.jsonl" LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl" LEGACY_TASKS_PATH = "meta/tasks.jsonl" -DEFAULT_FEATURES = { - "timestamp": {"dtype": "float32", "shape": (1,), "names": None}, - "frame_index": {"dtype": "int64", "shape": (1,), "names": None}, - "episode_index": {"dtype": "int64", "shape": (1,), "names": None}, - "index": {"dtype": "int64", "shape": (1,), "names": None}, - "task_index": {"dtype": "int64", "shape": (1,), "names": None}, -} - def has_legacy_hub_download_metadata(root: Path) -> bool: """Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror. @@ -123,59 +115,6 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) - return chunk_idx, file_idx -def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: - """Flatten a nested dictionary by joining keys with a separator. - - Example: - >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3} - >>> print(flatten_dict(dct)) - {'a/b': 1, 'a/c/d': 2, 'e': 3} - - Args: - d (dict): The dictionary to flatten. - parent_key (str): The base key to prepend to the keys in this level. - sep (str): The separator to use between keys. - - Returns: - dict: A flattened dictionary. - """ - items = [] - for k, v in d.items(): - new_key = f"{parent_key}{sep}{k}" if parent_key else k - if isinstance(v, dict): - items.extend(flatten_dict(v, new_key, sep=sep).items()) - else: - items.append((new_key, v)) - return dict(items) - - -def unflatten_dict(d: dict, sep: str = "/") -> dict: - """Unflatten a dictionary with delimited keys into a nested dictionary. - - Example: - >>> flat_dct = {"a/b": 1, "a/c/d": 2, "e": 3} - >>> print(unflatten_dict(flat_dct)) - {'a': {'b': 1, 'c': {'d': 2}}, 'e': 3} - - Args: - d (dict): A dictionary with flattened keys. - sep (str): The separator used in the keys. - - Returns: - dict: A nested dictionary. - """ - outdict = {} - for key, value in d.items(): - parts = key.split(sep) - d = outdict - for part in parts[:-1]: - if part not in d: - d[part] = {} - d = d[part] - d[parts[-1]] = value - return outdict - - def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: """Serialize a dictionary containing tensors or numpy arrays to be JSON-compatible. @@ -332,27 +271,6 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> raise ForwardCompatibilityError(repo_id, min(upper_versions)) -def cycle(iterable: Any) -> Iterator[Any]: - """Create a dataloader-safe cyclical iterator. - - This is an equivalent of `itertools.cycle` but is safe for use with - PyTorch DataLoaders with multiple workers. - See https://github.com/pytorch/pytorch/issues/23900 for details. - - Args: - iterable: The iterable to cycle over. - - Yields: - Items from the iterable, restarting from the beginning when exhausted. - """ - iterator = iter(iterable) - while True: - try: - yield next(iterator) - except StopIteration: - iterator = iter(iterable) - - def create_branch(repo_id: str, *, branch: str, repo_type: str | None = None) -> None: """Create a branch on an existing Hugging Face repo. diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 59c8c7d3e..cabe592d0 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -37,6 +37,8 @@ import torchvision from datasets.features.features import register_feature from PIL import Image +from lerobot.utils.import_utils import get_safe_default_codec + logger = logging.getLogger(__name__) # List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build. @@ -116,16 +118,6 @@ def resolve_vcodec(vcodec: str) -> str: return "libsvtav1" -def get_safe_default_codec(): - if importlib.util.find_spec("torchcodec"): - return "torchcodec" - else: - logger.warning( - "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder" - ) - return "pyav" - - def decode_video_frames( video_path: Path | str, timestamps: list[float], @@ -271,7 +263,10 @@ class VideoDecoderCache: if importlib.util.find_spec("torchcodec"): from torchcodec.decoders import VideoDecoder else: - raise ImportError("torchcodec is required but not available.") + raise ImportError( + "'torchcodec' is required but not installed. " + "Install it with: pip install 'lerobot[dataset]' (or uv pip install 'lerobot[dataset]')" + ) video_path = str(video_path) @@ -606,7 +601,7 @@ class _CameraEncoderThread(threading.Thread): self.encoder_threads = encoder_threads def run(self) -> None: - from lerobot.datasets.compute_stats import RunningQuantileStats, auto_downsample_height_width + from .compute_stats import RunningQuantileStats, auto_downsample_height_width container = None output_stream = None diff --git a/src/lerobot/envs/__init__.py b/src/lerobot/envs/__init__.py index 183c12325..277fd04f4 100644 --- a/src/lerobot/envs/__init__.py +++ b/src/lerobot/envs/__init__.py @@ -12,4 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .configs import AlohaEnv, EnvConfig, HubEnvConfig, PushtEnv # noqa: F401 +# NOTE: gymnasium is currently a core dependency but is a candidate for moving to an +# optional extra in the future. When that transition happens, uncomment the guard below +# and update the extra name to the one that will contain gymnasium. +# from lerobot.utils.import_utils import require_package +# require_package("gymnasium", extra="", import_name="gymnasium") + +from .configs import AlohaEnv, EnvConfig, HILSerlRobotEnvConfig, HubEnvConfig, PushtEnv +from .factory import make_env, make_env_config, make_env_pre_post_processors +from .utils import check_env_attributes_and_types, close_envs, env_to_policy_features, preprocess_observation + +__all__ = [ + "AlohaEnv", + "EnvConfig", + "HILSerlRobotEnvConfig", + "HubEnvConfig", + "PushtEnv", + "check_env_attributes_and_types", + "close_envs", + "env_to_policy_features", + "make_env", + "make_env_config", + "make_env_pre_post_processors", + "preprocess_observation", +] diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index af5bda33f..2a7c52d45 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -23,7 +23,8 @@ import draccus import gymnasium as gym from gymnasium.envs.registration import registry as gym_registry -from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.configs import FeatureType, PolicyFeature +from lerobot.processor import IsaaclabArenaProcessorStep, LiberoProcessorStep, PolicyProcessorPipeline from lerobot.robots import RobotConfig from lerobot.teleoperators.config import TeleoperatorConfig from lerobot.utils.constants import ( @@ -124,8 +125,6 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC): def get_env_processors(self): """Return (preprocessor, postprocessor) for this env. Default: identity.""" - from lerobot.processor.pipeline import PolicyProcessorPipeline - return PolicyProcessorPipeline(steps=[]), PolicyProcessorPipeline(steps=[]) @@ -418,7 +417,7 @@ class LiberoEnv(EnvConfig): return kwargs def create_envs(self, n_envs: int, use_async_envs: bool = False): - from lerobot.envs.libero import create_libero_envs + from .libero import create_libero_envs if self.task is None: raise ValueError("LiberoEnv requires a task to be specified") @@ -436,9 +435,6 @@ class LiberoEnv(EnvConfig): ) def get_env_processors(self): - from lerobot.processor.env_processor import LiberoProcessorStep - from lerobot.processor.pipeline import PolicyProcessorPipeline - return ( PolicyProcessorPipeline(steps=[LiberoProcessorStep()]), PolicyProcessorPipeline(steps=[]), @@ -487,7 +483,7 @@ class MetaworldEnv(EnvConfig): } def create_envs(self, n_envs: int, use_async_envs: bool = False): - from lerobot.envs.metaworld import create_metaworld_envs + from .metaworld import create_metaworld_envs if self.task is None: raise ValueError("MetaWorld requires a task to be specified") @@ -568,9 +564,6 @@ class IsaaclabArenaEnv(HubEnvConfig): return {} def get_env_processors(self): - from lerobot.processor.env_processor import IsaaclabArenaProcessorStep - from lerobot.processor.pipeline import PolicyProcessorPipeline - state_keys = tuple(k.strip() for k in (self.state_keys or "").split(",") if k.strip()) camera_keys = tuple(k.strip() for k in (self.camera_keys or "").split(",") if k.strip()) if not state_keys and not camera_keys: diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index 40d5425cc..317cf2e6f 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -19,8 +19,8 @@ from typing import Any import gymnasium as gym -from lerobot.envs.configs import EnvConfig, HubEnvConfig -from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result +from .configs import EnvConfig, HubEnvConfig +from .utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result def make_env_config(env_type: str, **kwargs) -> EnvConfig: diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index 1b814db52..ec90d0ffd 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -29,9 +29,10 @@ from gymnasium import spaces from libero.libero import benchmark, get_libero_path from libero.libero.envs import OffScreenRenderEnv -from lerobot.envs.utils import _LazyAsyncVectorEnv from lerobot.types import RobotObservation +from .utils import _LazyAsyncVectorEnv + def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]: """Normalize camera_name into a non-empty list of strings.""" diff --git a/src/lerobot/envs/metaworld.py b/src/lerobot/envs/metaworld.py index 49c775957..1dc513a68 100644 --- a/src/lerobot/envs/metaworld.py +++ b/src/lerobot/envs/metaworld.py @@ -25,9 +25,10 @@ import metaworld.policies as policies import numpy as np from gymnasium import spaces -from lerobot.envs.utils import _LazyAsyncVectorEnv from lerobot.types import RobotObservation +from .utils import _LazyAsyncVectorEnv + # ---- Load configuration data from the external JSON file ---- CONFIG_PATH = Path(__file__).parent / "metaworld_config.json" try: diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index ff5f53735..b0d834a05 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -27,11 +27,12 @@ import torch from huggingface_hub import hf_hub_download, snapshot_download from torch import Tensor -from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.envs.configs import EnvConfig +from lerobot.configs import FeatureType, PolicyFeature from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR from lerobot.utils.utils import get_channel_first_image_shape +from .configs import EnvConfig + def _convert_nested_dict(d): result = {} diff --git a/src/lerobot/model/__init__.py b/src/lerobot/model/__init__.py new file mode 100644 index 000000000..2f82e5053 --- /dev/null +++ b/src/lerobot/model/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Kinematics utilities for robot modeling. + +from .kinematics import RobotKinematics as RobotKinematics + +__all__ = ["RobotKinematics"] diff --git a/src/lerobot/motors/__init__.py b/src/lerobot/motors/__init__.py index 5df80d5ba..63ac14a38 100644 --- a/src/lerobot/motors/__init__.py +++ b/src/lerobot/motors/__init__.py @@ -19,3 +19,5 @@ from .motors_bus import ( MotorCalibration, MotorNormMode, ) + +__all__ = ["Motor", "MotorCalibration", "MotorNormMode"] diff --git a/src/lerobot/motors/damiao/__init__.py b/src/lerobot/motors/damiao/__init__.py index 8240138cf..5a98fa4d2 100644 --- a/src/lerobot/motors/damiao/__init__.py +++ b/src/lerobot/motors/damiao/__init__.py @@ -15,4 +15,6 @@ # limitations under the License. from .damiao import DamiaoMotorsBus -from .tables import * +from .tables import * # noqa: F403 — hardware constant tables + +__all__ = ["DamiaoMotorsBus"] diff --git a/src/lerobot/motors/dynamixel/__init__.py b/src/lerobot/motors/dynamixel/__init__.py index 425f8538a..01fcadf4f 100644 --- a/src/lerobot/motors/dynamixel/__init__.py +++ b/src/lerobot/motors/dynamixel/__init__.py @@ -15,4 +15,6 @@ # limitations under the License. from .dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode -from .tables import * +from .tables import * # noqa: F403 — hardware constant tables + +__all__ = ["DriveMode", "DynamixelMotorsBus", "OperatingMode", "TorqueMode"] diff --git a/src/lerobot/motors/dynamixel/dynamixel.py b/src/lerobot/motors/dynamixel/dynamixel.py index bca455dc5..4502bd668 100644 --- a/src/lerobot/motors/dynamixel/dynamixel.py +++ b/src/lerobot/motors/dynamixel/dynamixel.py @@ -21,6 +21,9 @@ import logging from copy import deepcopy from enum import Enum +from typing import TYPE_CHECKING + +from lerobot.utils.import_utils import _dynamixel_sdk_available, require_package from ..encoding_utils import decode_twos_complement, encode_twos_complement from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address @@ -33,6 +36,11 @@ from .tables import ( MODEL_RESOLUTION, ) +if TYPE_CHECKING or _dynamixel_sdk_available: + import dynamixel_sdk as dxl +else: + dxl = None + PROTOCOL_VERSION = 2.0 DEFAULT_BAUDRATE = 1_000_000 DEFAULT_TIMEOUT_MS = 1000 @@ -82,23 +90,6 @@ class TorqueMode(Enum): DISABLED = 0 -def _split_into_byte_chunks(value: int, length: int) -> list[int]: - import dynamixel_sdk as dxl - - if length == 1: - data = [value] - elif length == 2: - data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)] - elif length == 4: - data = [ - dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), - dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)), - dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)), - dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)), - ] - return data - - class DynamixelMotorsBus(SerialMotorsBus): """ The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with @@ -123,9 +114,8 @@ class DynamixelMotorsBus(SerialMotorsBus): motors: dict[str, Motor], calibration: dict[str, MotorCalibration] | None = None, ): + require_package("dynamixel-sdk", extra="dynamixel", import_name="dynamixel_sdk") super().__init__(port, motors, calibration) - import dynamixel_sdk as dxl - self.port_handler = dxl.PortHandler(self.port) self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION) self.sync_reader = dxl.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0) @@ -244,7 +234,18 @@ class DynamixelMotorsBus(SerialMotorsBus): return half_turn_homings def _split_into_byte_chunks(self, value: int, length: int) -> list[int]: - return _split_into_byte_chunks(value, length) + if length == 1: + data = [value] + elif length == 2: + data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)] + elif length == 4: + data = [ + dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), + dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)), + dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)), + dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)), + ] + return data def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None: for n_try in range(1 + num_retry): diff --git a/src/lerobot/motors/feetech/__init__.py b/src/lerobot/motors/feetech/__init__.py index 75da2d221..6c06d8b95 100644 --- a/src/lerobot/motors/feetech/__init__.py +++ b/src/lerobot/motors/feetech/__init__.py @@ -15,4 +15,6 @@ # limitations under the License. from .feetech import DriveMode, FeetechMotorsBus, OperatingMode, TorqueMode -from .tables import * +from .tables import * # noqa: F403 — hardware constant tables + +__all__ = ["DriveMode", "FeetechMotorsBus", "OperatingMode", "TorqueMode"] diff --git a/src/lerobot/motors/feetech/feetech.py b/src/lerobot/motors/feetech/feetech.py index 58a65310d..9b1e0fb7e 100644 --- a/src/lerobot/motors/feetech/feetech.py +++ b/src/lerobot/motors/feetech/feetech.py @@ -16,6 +16,9 @@ import logging from copy import deepcopy from enum import Enum from pprint import pformat +from typing import TYPE_CHECKING + +from lerobot.utils.import_utils import _feetech_sdk_available, require_package from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address @@ -32,6 +35,11 @@ from .tables import ( SCAN_BAUDRATES, ) +if TYPE_CHECKING or _feetech_sdk_available: + import scservo_sdk as scs +else: + scs = None + DEFAULT_PROTOCOL_VERSION = 0 DEFAULT_BAUDRATE = 1_000_000 DEFAULT_TIMEOUT_MS = 1000 @@ -65,23 +73,6 @@ class TorqueMode(Enum): DISABLED = 0 -def _split_into_byte_chunks(value: int, length: int) -> list[int]: - import scservo_sdk as scs - - if length == 1: - data = [value] - elif length == 2: - data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)] - elif length == 4: - data = [ - scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), - scs.SCS_HIBYTE(scs.SCS_LOWORD(value)), - scs.SCS_LOBYTE(scs.SCS_HIWORD(value)), - scs.SCS_HIBYTE(scs.SCS_HIWORD(value)), - ] - return data - - def patch_setPacketTimeout(self, packet_length): # noqa: N802 """ HACK: This patches the PortHandler behavior to set the correct packet timeouts. @@ -119,11 +110,10 @@ class FeetechMotorsBus(SerialMotorsBus): calibration: dict[str, MotorCalibration] | None = None, protocol_version: int = DEFAULT_PROTOCOL_VERSION, ): + require_package("feetech-servo-sdk", extra="feetech", import_name="scservo_sdk") super().__init__(port, motors, calibration) self.protocol_version = protocol_version self._assert_same_protocol() - import scservo_sdk as scs - self.port_handler = scs.PortHandler(self.port) # HACK: monkeypatch self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign] @@ -195,8 +185,6 @@ class FeetechMotorsBus(SerialMotorsBus): raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.") def _find_single_motor_p1(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]: - import scservo_sdk as scs - model = self.motors[motor].model search_baudrates = ( [initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model] @@ -329,11 +317,20 @@ class FeetechMotorsBus(SerialMotorsBus): return ids_values def _split_into_byte_chunks(self, value: int, length: int) -> list[int]: - return _split_into_byte_chunks(value, length) + if length == 1: + data = [value] + elif length == 2: + data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)] + elif length == 4: + data = [ + scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), + scs.SCS_HIBYTE(scs.SCS_LOWORD(value)), + scs.SCS_LOBYTE(scs.SCS_HIWORD(value)), + scs.SCS_HIBYTE(scs.SCS_HIWORD(value)), + ] + return data def _broadcast_ping(self) -> tuple[dict[int, int], int]: - import scservo_sdk as scs - data_list: dict[int, int] = {} status_length = 6 diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py index 509f5e95f..209489bb9 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -29,12 +29,22 @@ from dataclasses import dataclass from enum import Enum from functools import cached_property from pprint import pformat -from typing import Protocol +from typing import TYPE_CHECKING, Protocol -import serial -from deepdiff import DeepDiff from tqdm import tqdm +from lerobot.utils.import_utils import _deepdiff_available, _serial_available, require_package + +if TYPE_CHECKING or _serial_available: + import serial +else: + serial = None # type: ignore[assignment] + +if TYPE_CHECKING or _deepdiff_available: + from deepdiff import DeepDiff +else: + DeepDiff = None # type: ignore[assignment, misc] + from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.utils import enter_pressed, move_cursor_up @@ -346,6 +356,8 @@ class SerialMotorsBus(MotorsBusBase): motors: dict[str, Motor], calibration: dict[str, MotorCalibration] | None = None, ): + require_package("pyserial", extra="hardware", import_name="serial") + require_package("deepdiff", extra="hardware") super().__init__(port, motors, calibration) self.port_handler: PortHandler diff --git a/src/lerobot/motors/robstride/__init__.py b/src/lerobot/motors/robstride/__init__.py index 7933ac6fa..4729b3968 100644 --- a/src/lerobot/motors/robstride/__init__.py +++ b/src/lerobot/motors/robstride/__init__.py @@ -15,4 +15,6 @@ # limitations under the License. from .robstride import RobstrideMotorsBus -from .tables import * +from .tables import * # noqa: F403 — hardware constant tables + +__all__ = ["RobstrideMotorsBus"] diff --git a/src/lerobot/optim/__init__.py b/src/lerobot/optim/__init__.py index de2c4c996..46676027b 100644 --- a/src/lerobot/optim/__init__.py +++ b/src/lerobot/optim/__init__.py @@ -12,4 +12,45 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .optimizers import OptimizerConfig as OptimizerConfig +from .optimizers import ( + AdamConfig as AdamConfig, + AdamWConfig as AdamWConfig, + MultiAdamConfig as MultiAdamConfig, + OptimizerConfig as OptimizerConfig, + SGDConfig as SGDConfig, + XVLAAdamWConfig as XVLAAdamWConfig, + load_optimizer_state, + save_optimizer_state, +) +from .schedulers import ( + CosineDecayWithWarmupSchedulerConfig as CosineDecayWithWarmupSchedulerConfig, + DiffuserSchedulerConfig as DiffuserSchedulerConfig, + LRSchedulerConfig as LRSchedulerConfig, + VQBeTSchedulerConfig as VQBeTSchedulerConfig, + load_scheduler_state, + save_scheduler_state, +) + +# NOTE: make_optimizer_and_scheduler is intentionally NOT re-exported here +# to avoid circular dependencies (it imports lerobot.configs.train and lerobot.policies). +# Import directly: ``from lerobot.optim.factory import make_optimizer_and_scheduler`` + +__all__ = [ + # Optimizer configs + "AdamConfig", + "AdamWConfig", + "MultiAdamConfig", + "OptimizerConfig", + "SGDConfig", + "XVLAAdamWConfig", + # Scheduler configs + "CosineDecayWithWarmupSchedulerConfig", + "DiffuserSchedulerConfig", + "LRSchedulerConfig", + "VQBeTSchedulerConfig", + # State management + "load_optimizer_state", + "load_scheduler_state", + "save_optimizer_state", + "save_scheduler_state", +] diff --git a/src/lerobot/optim/factory.py b/src/lerobot/optim/factory.py index 699289993..ce519e0b2 100644 --- a/src/lerobot/optim/factory.py +++ b/src/lerobot/optim/factory.py @@ -19,7 +19,7 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from lerobot.configs.train import TrainPipelineConfig -from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies import PreTrainedPolicy def make_optimizer_and_scheduler( diff --git a/src/lerobot/optim/optimizers.py b/src/lerobot/optim/optimizers.py index e2e3d8937..0bdd7a37e 100644 --- a/src/lerobot/optim/optimizers.py +++ b/src/lerobot/optim/optimizers.py @@ -23,13 +23,12 @@ import draccus import torch from safetensors.torch import load_file, save_file -from lerobot.datasets.io_utils import write_json -from lerobot.datasets.utils import flatten_dict, unflatten_dict from lerobot.utils.constants import ( OPTIMIZER_PARAM_GROUPS, OPTIMIZER_STATE, ) -from lerobot.utils.io_utils import deserialize_json_into_object +from lerobot.utils.io_utils import deserialize_json_into_object, write_json +from lerobot.utils.utils import flatten_dict, unflatten_dict # Type alias for parameters accepted by optimizer build() methods. # This matches PyTorch's optimizer signature while also supporting: diff --git a/src/lerobot/optim/schedulers.py b/src/lerobot/optim/schedulers.py index 19c3fd7bd..914edd2db 100644 --- a/src/lerobot/optim/schedulers.py +++ b/src/lerobot/optim/schedulers.py @@ -23,9 +23,8 @@ import draccus from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR, LRScheduler -from lerobot.datasets.io_utils import write_json from lerobot.utils.constants import SCHEDULER_STATE -from lerobot.utils.io_utils import deserialize_json_into_object +from lerobot.utils.io_utils import deserialize_json_into_object, write_json @dataclass @@ -48,6 +47,9 @@ class DiffuserSchedulerConfig(LRSchedulerConfig): num_warmup_steps: int | None = None def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR: + from lerobot.utils.import_utils import require_package + + require_package("diffusers", extra="diffusion") from diffusers.optimization import get_scheduler kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer} diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index 55ce09cf9..e138a84d9 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -14,30 +14,55 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig +from .factory import get_policy_class, make_policy, make_policy_config, make_pre_post_processors from .groot.configuration_groot import GrootConfig as GrootConfig from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig from .pi0.configuration_pi0 import PI0Config as PI0Config from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig from .pi05.configuration_pi05 import PI05Config as PI05Config +from .pretrained import PreTrainedPolicy as PreTrainedPolicy +from .rtc import ActionInterpolator as ActionInterpolator +from .sac.configuration_sac import SACConfig as SACConfig +from .sac.reward_model.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig +from .sarm.configuration_sarm import SARMConfig as SARMConfig from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig -from .smolvla.processor_smolvla import SmolVLANewLineProcessor from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig +from .utils import make_robot_action, prepare_observation_for_inference from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig from .wall_x.configuration_wall_x import WallXConfig as WallXConfig from .xvla.configuration_xvla import XVLAConfig as XVLAConfig +# NOTE: Policy modeling classes (e.g., SACPolicy) are intentionally NOT re-exported here. +# They have heavy optional dependencies and are loaded lazily via get_policy_class(). +# Import directly: ``from lerobot.policies.sac.modeling_sac import SACPolicy`` + __all__ = [ + # Configuration classes "ACTConfig", "DiffusionConfig", + "GrootConfig", "MultiTaskDiTConfig", "PI0Config", - "PI05Config", "PI0FastConfig", - "SmolVLAConfig", + "PI05Config", + "RewardClassifierConfig", + "SACConfig", "SARMConfig", + "SmolVLAConfig", "TDMPCConfig", "VQBeTConfig", - "GrootConfig", - "XVLAConfig", "WallXConfig", + "XVLAConfig", + # Base class + "PreTrainedPolicy", + # RTC utilities + "ActionInterpolator", + # Utility functions + "make_robot_action", + "prepare_observation_for_inference", + # Factory functions + "get_policy_class", + "make_policy", + "make_policy_config", + "make_pre_post_processors", ] diff --git a/src/lerobot/policies/act/__init__.py b/src/lerobot/policies/act/__init__.py new file mode 100644 index 000000000..44f15189f --- /dev/null +++ b/src/lerobot/policies/act/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_act import ACTConfig +from .modeling_act import ACTPolicy +from .processor_act import make_act_pre_post_processors + +__all__ = ["ACTConfig", "ACTPolicy", "make_act_pre_post_processors"] diff --git a/src/lerobot/policies/act/configuration_act.py b/src/lerobot/policies/act/configuration_act.py index bd89185fd..b5c3d68f1 100644 --- a/src/lerobot/policies/act/configuration_act.py +++ b/src/lerobot/policies/act/configuration_act.py @@ -15,9 +15,8 @@ # limitations under the License. from dataclasses import dataclass, field -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import NormalizationMode -from lerobot.optim.optimizers import AdamWConfig +from lerobot.configs import NormalizationMode, PreTrainedConfig +from lerobot.optim import AdamWConfig @PreTrainedConfig.register_subclass("act") diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index a5c48eb3d..0120258ee 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -33,10 +33,11 @@ from torch import Tensor, nn from torchvision.models._utils import IntermediateLayerGetter from torchvision.ops.misc import FrozenBatchNorm2d -from lerobot.policies.act.configuration_act import ACTConfig -from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE +from ..pretrained import PreTrainedPolicy +from .configuration_act import ACTConfig + class ACTPolicy(PreTrainedPolicy): """ diff --git a/src/lerobot/policies/act/processor_act.py b/src/lerobot/policies/act/processor_act.py index 727b18cef..d87ade900 100644 --- a/src/lerobot/policies/act/processor_act.py +++ b/src/lerobot/policies/act/processor_act.py @@ -17,7 +17,6 @@ from typing import Any import torch -from lerobot.policies.act.configuration_act import ACTConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, DeviceProcessorStep, @@ -26,10 +25,13 @@ from lerobot.processor import ( PolicyProcessorPipeline, RenameObservationsProcessorStep, UnnormalizerProcessorStep, + policy_action_to_transition, + transition_to_policy_action, ) -from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from .configuration_act import ACTConfig + def make_act_pre_post_processors( config: ACTConfig, diff --git a/src/lerobot/policies/diffusion/__init__.py b/src/lerobot/policies/diffusion/__init__.py new file mode 100644 index 000000000..4f6ee820a --- /dev/null +++ b/src/lerobot/policies/diffusion/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_diffusion import DiffusionConfig +from .modeling_diffusion import DiffusionPolicy +from .processor_diffusion import make_diffusion_pre_post_processors + +__all__ = ["DiffusionConfig", "DiffusionPolicy", "make_diffusion_pre_post_processors"] diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py index 91b3df214..8e3d4bf19 100644 --- a/src/lerobot/policies/diffusion/configuration_diffusion.py +++ b/src/lerobot/policies/diffusion/configuration_diffusion.py @@ -16,10 +16,8 @@ # limitations under the License. from dataclasses import dataclass, field -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import NormalizationMode -from lerobot.optim.optimizers import AdamConfig -from lerobot.optim.schedulers import DiffuserSchedulerConfig +from lerobot.configs import NormalizationMode, PreTrainedConfig +from lerobot.optim import AdamConfig, DiffuserSchedulerConfig @PreTrainedConfig.register_subclass("diffusion") diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index aa8d5dd14..5b3b97571 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -29,19 +29,18 @@ import numpy as np import torch import torch.nn.functional as F # noqa: N812 import torchvision -from diffusers.schedulers.scheduling_ddim import DDIMScheduler -from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from torch import Tensor, nn -from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.utils import ( +from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE + +from ..pretrained import PreTrainedPolicy +from ..utils import ( get_device_from_parameters, get_dtype_from_parameters, get_output_shape, populate_queues, ) -from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE +from .configuration_diffusion import DiffusionConfig class DiffusionPolicy(PreTrainedPolicy): @@ -151,11 +150,17 @@ class DiffusionPolicy(PreTrainedPolicy): return loss, None -def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler: +def _make_noise_scheduler(name: str, **kwargs: dict): """ Factory for noise scheduler instances of the requested type. All kwargs are passed to the scheduler. """ + from lerobot.utils.import_utils import require_package + + require_package("diffusers", extra="diffusion") + from diffusers.schedulers.scheduling_ddim import DDIMScheduler + from diffusers.schedulers.scheduling_ddpm import DDPMScheduler + if name == "DDPM": return DDPMScheduler(**kwargs) elif name == "DDIM": diff --git a/src/lerobot/policies/diffusion/processor_diffusion.py b/src/lerobot/policies/diffusion/processor_diffusion.py index a7799be64..c4bc17680 100644 --- a/src/lerobot/policies/diffusion/processor_diffusion.py +++ b/src/lerobot/policies/diffusion/processor_diffusion.py @@ -18,7 +18,6 @@ from typing import Any import torch -from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, DeviceProcessorStep, @@ -27,10 +26,13 @@ from lerobot.processor import ( PolicyProcessorPipeline, RenameObservationsProcessorStep, UnnormalizerProcessorStep, + policy_action_to_transition, + transition_to_policy_action, ) -from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from .configuration_diffusion import DiffusionConfig + def make_diffusion_pre_post_processors( config: DiffusionConfig, diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 501dd7af1..611a6e9bc 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -18,34 +18,19 @@ from __future__ import annotations import importlib import logging -from typing import Any, TypedDict, Unpack +from typing import TYPE_CHECKING, Any, TypedDict, Unpack import torch -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import FeatureType -from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata -from lerobot.datasets.feature_utils import dataset_to_policy_features -from lerobot.envs.configs import EnvConfig -from lerobot.envs.utils import env_to_policy_features -from lerobot.policies.act.configuration_act import ACTConfig -from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig -from lerobot.policies.groot.configuration_groot import GrootConfig -from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig -from lerobot.policies.pi0.configuration_pi0 import PI0Config -from lerobot.policies.pi05.configuration_pi05 import PI05Config -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.sac.configuration_sac import SACConfig -from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig -from lerobot.policies.sarm.configuration_sarm import SARMConfig -from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig -from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig -from lerobot.policies.utils import validate_visual_features_consistency -from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig -from lerobot.policies.wall_x.configuration_wall_x import WallXConfig -from lerobot.policies.xvla.configuration_xvla import XVLAConfig -from lerobot.processor import PolicyProcessorPipeline -from lerobot.processor.converters import ( +if TYPE_CHECKING: + from lerobot.datasets import LeRobotDatasetMetadata + +from lerobot.configs import FeatureType, PreTrainedConfig +from lerobot.envs import EnvConfig, env_to_policy_features +from lerobot.processor import ( + AbsoluteActionsProcessorStep, + PolicyProcessorPipeline, + RelativeActionsProcessorStep, batch_to_transition, policy_action_to_transition, transition_to_batch, @@ -57,6 +42,24 @@ from lerobot.utils.constants import ( POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME, ) +from lerobot.utils.feature_utils import dataset_to_policy_features + +from .act.configuration_act import ACTConfig +from .diffusion.configuration_diffusion import DiffusionConfig +from .groot.configuration_groot import GrootConfig +from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig +from .pi0.configuration_pi0 import PI0Config +from .pi05.configuration_pi05 import PI05Config +from .pretrained import PreTrainedPolicy +from .sac.configuration_sac import SACConfig +from .sac.reward_model.configuration_classifier import RewardClassifierConfig +from .sarm.configuration_sarm import SARMConfig +from .smolvla.configuration_smolvla import SmolVLAConfig +from .tdmpc.configuration_tdmpc import TDMPCConfig +from .utils import validate_visual_features_consistency +from .vqbet.configuration_vqbet import VQBeTConfig +from .wall_x.configuration_wall_x import WallXConfig +from .xvla.configuration_xvla import XVLAConfig def _reconnect_relative_absolute_steps( @@ -69,11 +72,6 @@ def _reconnect_relative_absolute_steps( the RelativeActionsProcessorStep so it can read the cached state at inference time. That reference is not serializable, so we re-establish it here after loading. """ - from lerobot.processor.relative_action_processor import ( - AbsoluteActionsProcessorStep, - RelativeActionsProcessorStep, - ) - relative_step = next((s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep)), None) if relative_step is None: return @@ -99,63 +97,63 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: NotImplementedError: If the policy name is not recognized. """ if name == "tdmpc": - from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy + from .tdmpc.modeling_tdmpc import TDMPCPolicy return TDMPCPolicy elif name == "diffusion": - from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy + from .diffusion.modeling_diffusion import DiffusionPolicy return DiffusionPolicy elif name == "act": - from lerobot.policies.act.modeling_act import ACTPolicy + from .act.modeling_act import ACTPolicy return ACTPolicy elif name == "multi_task_dit": - from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy + from .multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy return MultiTaskDiTPolicy elif name == "vqbet": - from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy + from .vqbet.modeling_vqbet import VQBeTPolicy return VQBeTPolicy elif name == "pi0": - from lerobot.policies.pi0.modeling_pi0 import PI0Policy + from .pi0.modeling_pi0 import PI0Policy return PI0Policy elif name == "pi0_fast": - from lerobot.policies.pi0_fast.modeling_pi0_fast import PI0FastPolicy + from .pi0_fast.modeling_pi0_fast import PI0FastPolicy return PI0FastPolicy elif name == "pi05": - from lerobot.policies.pi05.modeling_pi05 import PI05Policy + from .pi05.modeling_pi05 import PI05Policy return PI05Policy elif name == "sac": - from lerobot.policies.sac.modeling_sac import SACPolicy + from .sac.modeling_sac import SACPolicy return SACPolicy elif name == "reward_classifier": - from lerobot.policies.sac.reward_model.modeling_classifier import Classifier + from .sac.reward_model.modeling_classifier import Classifier return Classifier elif name == "smolvla": - from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy + from .smolvla.modeling_smolvla import SmolVLAPolicy return SmolVLAPolicy elif name == "sarm": - from lerobot.policies.sarm.modeling_sarm import SARMRewardModel + from .sarm.modeling_sarm import SARMRewardModel return SARMRewardModel elif name == "groot": - from lerobot.policies.groot.modeling_groot import GrootPolicy + from .groot.modeling_groot import GrootPolicy return GrootPolicy elif name == "xvla": - from lerobot.policies.xvla.modeling_xvla import XVLAPolicy + from .xvla.modeling_xvla import XVLAPolicy return XVLAPolicy elif name == "wall_x": - from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy + from .wall_x.modeling_wall_x import WallXPolicy return WallXPolicy else: @@ -315,7 +313,7 @@ def make_pre_post_processors( # Create a new processor based on policy type if isinstance(policy_cfg, TDMPCConfig): - from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors + from .tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors processors = make_tdmpc_pre_post_processors( config=policy_cfg, @@ -323,7 +321,7 @@ def make_pre_post_processors( ) elif isinstance(policy_cfg, DiffusionConfig): - from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors + from .diffusion.processor_diffusion import make_diffusion_pre_post_processors processors = make_diffusion_pre_post_processors( config=policy_cfg, @@ -331,7 +329,7 @@ def make_pre_post_processors( ) elif isinstance(policy_cfg, ACTConfig): - from lerobot.policies.act.processor_act import make_act_pre_post_processors + from .act.processor_act import make_act_pre_post_processors processors = make_act_pre_post_processors( config=policy_cfg, @@ -339,7 +337,7 @@ def make_pre_post_processors( ) elif isinstance(policy_cfg, MultiTaskDiTConfig): - from lerobot.policies.multi_task_dit.processor_multi_task_dit import ( + from .multi_task_dit.processor_multi_task_dit import ( make_multi_task_dit_pre_post_processors, ) @@ -349,7 +347,7 @@ def make_pre_post_processors( ) elif isinstance(policy_cfg, VQBeTConfig): - from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors + from .vqbet.processor_vqbet import make_vqbet_pre_post_processors processors = make_vqbet_pre_post_processors( config=policy_cfg, @@ -357,7 +355,7 @@ def make_pre_post_processors( ) elif isinstance(policy_cfg, PI0Config): - from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors + from .pi0.processor_pi0 import make_pi0_pre_post_processors processors = make_pi0_pre_post_processors( config=policy_cfg, @@ -365,7 +363,7 @@ def make_pre_post_processors( ) elif isinstance(policy_cfg, PI05Config): - from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors + from .pi05.processor_pi05 import make_pi05_pre_post_processors processors = make_pi05_pre_post_processors( config=policy_cfg, @@ -373,7 +371,7 @@ def make_pre_post_processors( ) elif isinstance(policy_cfg, SACConfig): - from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors + from .sac.processor_sac import make_sac_pre_post_processors processors = make_sac_pre_post_processors( config=policy_cfg, @@ -381,7 +379,7 @@ def make_pre_post_processors( ) elif isinstance(policy_cfg, RewardClassifierConfig): - from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor + from .sac.reward_model.processor_classifier import make_classifier_processor processors = make_classifier_processor( config=policy_cfg, @@ -389,7 +387,7 @@ def make_pre_post_processors( ) elif isinstance(policy_cfg, SmolVLAConfig): - from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors + from .smolvla.processor_smolvla import make_smolvla_pre_post_processors processors = make_smolvla_pre_post_processors( config=policy_cfg, @@ -397,7 +395,7 @@ def make_pre_post_processors( ) elif isinstance(policy_cfg, SARMConfig): - from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors + from .sarm.processor_sarm import make_sarm_pre_post_processors processors = make_sarm_pre_post_processors( config=policy_cfg, @@ -405,7 +403,7 @@ def make_pre_post_processors( dataset_meta=kwargs.get("dataset_meta"), ) elif isinstance(policy_cfg, GrootConfig): - from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors + from .groot.processor_groot import make_groot_pre_post_processors processors = make_groot_pre_post_processors( config=policy_cfg, @@ -413,7 +411,7 @@ def make_pre_post_processors( ) elif isinstance(policy_cfg, XVLAConfig): - from lerobot.policies.xvla.processor_xvla import ( + from .xvla.processor_xvla import ( make_xvla_pre_post_processors, ) @@ -423,7 +421,7 @@ def make_pre_post_processors( ) elif isinstance(policy_cfg, WallXConfig): - from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors + from .wall_x.processor_wall_x import make_wall_x_pre_post_processors processors = make_wall_x_pre_post_processors( config=policy_cfg, diff --git a/src/lerobot/policies/groot/action_head/__init__.py b/src/lerobot/policies/groot/action_head/__init__.py index 3159bfe65..63ffc39e6 100644 --- a/src/lerobot/policies/groot/action_head/__init__.py +++ b/src/lerobot/policies/groot/action_head/__init__.py @@ -12,3 +12,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +__all__: list[str] = [] diff --git a/src/lerobot/policies/groot/action_head/cross_attention_dit.py b/src/lerobot/policies/groot/action_head/cross_attention_dit.py index 40f7ba603..a4cd1a0b7 100755 --- a/src/lerobot/policies/groot/action_head/cross_attention_dit.py +++ b/src/lerobot/policies/groot/action_head/cross_attention_dit.py @@ -14,21 +14,37 @@ # limitations under the License. +from typing import TYPE_CHECKING + import torch import torch.nn.functional as F # noqa: N812 -from diffusers import ConfigMixin, ModelMixin -from diffusers.configuration_utils import register_to_config -from diffusers.models.attention import Attention, FeedForward -from diffusers.models.embeddings import ( - SinusoidalPositionalEmbedding, - TimestepEmbedding, - Timesteps, -) from torch import nn +from lerobot.utils.import_utils import _diffusers_available, require_package + +if TYPE_CHECKING or _diffusers_available: + from diffusers import ConfigMixin, ModelMixin + from diffusers.configuration_utils import register_to_config + from diffusers.models.attention import Attention, FeedForward + from diffusers.models.embeddings import ( + SinusoidalPositionalEmbedding, + TimestepEmbedding, + Timesteps, + ) +else: + ConfigMixin = object + ModelMixin = nn.Module + register_to_config = lambda fn: fn # noqa: E731 + Attention = None + FeedForward = None + SinusoidalPositionalEmbedding = None + TimestepEmbedding = None + Timesteps = None + class TimestepEncoder(nn.Module): def __init__(self, embedding_dim, compute_dtype=torch.float32): + require_package("diffusers", extra="groot") super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) @@ -88,6 +104,7 @@ class BasicTransformerBlock(nn.Module): ff_bias: bool = True, attention_out_bias: bool = True, ): + require_package("diffusers", extra="groot") super().__init__() self.dim = dim self.num_attention_heads = num_attention_heads diff --git a/src/lerobot/policies/groot/action_head/flow_matching_action_head.py b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py index bfc456ba0..4fda21ca5 100644 --- a/src/lerobot/policies/groot/action_head/flow_matching_action_head.py +++ b/src/lerobot/policies/groot/action_head/flow_matching_action_head.py @@ -31,11 +31,10 @@ else: PretrainedConfig = object BatchFeature = None -from lerobot.policies.groot.action_head.action_encoder import ( +from .action_encoder import ( SinusoidalPositionalEncoding, swish, ) - from .cross_attention_dit import DiT, SelfAttentionTransformer diff --git a/src/lerobot/policies/groot/configuration_groot.py b/src/lerobot/policies/groot/configuration_groot.py index 4f3d78222..17cb631d7 100644 --- a/src/lerobot/policies/groot/configuration_groot.py +++ b/src/lerobot/policies/groot/configuration_groot.py @@ -16,10 +16,8 @@ from dataclasses import dataclass, field -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.optim.optimizers import AdamWConfig -from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig +from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig +from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig from lerobot.utils.constants import ACTION, OBS_STATE diff --git a/src/lerobot/policies/groot/groot_n1.py b/src/lerobot/policies/groot/groot_n1.py index 06ff5a04d..fc753839a 100644 --- a/src/lerobot/policies/groot/groot_n1.py +++ b/src/lerobot/policies/groot/groot_n1.py @@ -41,12 +41,13 @@ try: except ImportError: tree = None -from lerobot.policies.groot.action_head.flow_matching_action_head import ( +from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME + +from .action_head.flow_matching_action_head import ( FlowmatchingActionHead, FlowmatchingActionHeadConfig, ) -from lerobot.policies.groot.utils import ensure_eagle_cache_ready -from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME +from .utils import ensure_eagle_cache_ready DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve()) DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5" diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py index 9a479b8f9..4b612bca4 100644 --- a/src/lerobot/policies/groot/modeling_groot.py +++ b/src/lerobot/policies/groot/modeling_groot.py @@ -41,12 +41,13 @@ from typing import TypeVar import torch from torch import Tensor -from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.policies.groot.configuration_groot import GrootConfig -from lerobot.policies.groot.groot_n1 import GR00TN15 -from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.configs import FeatureType, PolicyFeature from lerobot.utils.constants import ACTION, OBS_IMAGES +from ..pretrained import PreTrainedPolicy +from .configuration_groot import GrootConfig +from .groot_n1 import GR00TN15 + T = TypeVar("T", bound="GrootPolicy") diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index 8bf9dabca..3367de711 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -30,12 +30,11 @@ else: AutoProcessor = None ProcessorMixin = object -from lerobot.configs.types import ( +from lerobot.configs import ( FeatureType, NormalizationMode, PolicyFeature, ) -from lerobot.policies.groot.configuration_groot import GrootConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, DeviceProcessorStep, @@ -44,8 +43,6 @@ from lerobot.processor import ( ProcessorStep, ProcessorStepRegistry, RenameObservationsProcessorStep, -) -from lerobot.processor.converters import ( policy_action_to_transition, transition_to_policy_action, ) @@ -60,6 +57,8 @@ from lerobot.utils.constants import ( POLICY_PREPROCESSOR_DEFAULT_NAME, ) +from .configuration_groot import GrootConfig + # Defaults for Eagle processor locations DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5" diff --git a/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py index 061230687..33be3113f 100644 --- a/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py +++ b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py @@ -17,10 +17,8 @@ import logging from dataclasses import dataclass, field -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import NormalizationMode -from lerobot.optim.optimizers import AdamConfig -from lerobot.optim.schedulers import DiffuserSchedulerConfig +from lerobot.configs import NormalizationMode, PreTrainedConfig +from lerobot.optim import AdamConfig, DiffuserSchedulerConfig @PreTrainedConfig.register_subclass("multi_task_dit") diff --git a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py index 4fee851e0..8e5d1e3cb 100644 --- a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py +++ b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py @@ -34,21 +34,18 @@ import torch import torch.nn as nn import torch.nn.functional as F # noqa: N812 import torchvision -from diffusers.schedulers.scheduling_ddim import DDIMScheduler -from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from torch import Tensor -from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig from lerobot.utils.import_utils import _transformers_available +from .configuration_multi_task_dit import MultiTaskDiTConfig + # Conditional import for type checking and lazy loading if TYPE_CHECKING or _transformers_available: from transformers import CLIPTextModel, CLIPVisionModel else: CLIPTextModel = None CLIPVisionModel = None -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.utils import populate_queues from lerobot.utils.constants import ( ACTION, OBS_IMAGES, @@ -57,6 +54,9 @@ from lerobot.utils.constants import ( OBS_STATE, ) +from ..pretrained import PreTrainedPolicy +from ..utils import populate_queues + # -- Policy -- @@ -643,6 +643,12 @@ class DiffusionObjective(nn.Module): "prediction_type": config.prediction_type, } + from lerobot.utils.import_utils import require_package + + require_package("diffusers", extra="multi_task_dit") + from diffusers.schedulers.scheduling_ddim import DDIMScheduler + from diffusers.schedulers.scheduling_ddpm import DDPMScheduler + if config.noise_scheduler_type == "DDPM": self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs) elif config.noise_scheduler_type == "DDIM": diff --git a/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py index fc94599c2..5f5b9994e 100644 --- a/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py +++ b/src/lerobot/policies/multi_task_dit/processor_multi_task_dit.py @@ -18,7 +18,6 @@ from typing import Any import torch -from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, DeviceProcessorStep, @@ -28,10 +27,13 @@ from lerobot.processor import ( RenameObservationsProcessorStep, TokenizerProcessorStep, UnnormalizerProcessorStep, + policy_action_to_transition, + transition_to_policy_action, ) -from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from .configuration_multi_task_dit import MultiTaskDiTConfig + def make_multi_task_dit_pre_post_processors( config: MultiTaskDiTConfig, diff --git a/src/lerobot/policies/pi0/configuration_pi0.py b/src/lerobot/policies/pi0/configuration_pi0.py index cf4b636a3..a06315f07 100644 --- a/src/lerobot/policies/pi0/configuration_pi0.py +++ b/src/lerobot/policies/pi0/configuration_pi0.py @@ -16,13 +16,12 @@ from dataclasses import dataclass, field -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.optim.optimizers import AdamWConfig -from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig -from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig +from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE +from ..rtc.configuration_rtc import RTCConfig + DEFAULT_IMAGE_SIZE = 224 diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index aebf32964..22e4e6a26 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -33,7 +33,7 @@ if TYPE_CHECKING or _transformers_available: from transformers.models.auto import CONFIG_MAPPING from transformers.models.gemma import modeling_gemma - from lerobot.policies.pi_gemma import ( + from ..pi_gemma import ( PaliGemmaForConditionalGenerationWithPiGemma, PiGemmaForCausalLM, _gated_residual, @@ -48,10 +48,7 @@ else: PaliGemmaForConditionalGenerationWithPiGemma = None -from lerobot.configs.policies import PreTrainedConfig -from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config -from lerobot.policies.pretrained import PreTrainedPolicy, T -from lerobot.policies.rtc.modeling_rtc import RTCProcessor +from lerobot.configs import PreTrainedConfig from lerobot.utils.constants import ( ACTION, OBS_LANGUAGE_ATTENTION_MASK, @@ -60,6 +57,10 @@ from lerobot.utils.constants import ( OPENPI_ATTENTION_MASK_VALUE, ) +from ..pretrained import PreTrainedPolicy, T +from ..rtc.modeling_rtc import RTCProcessor +from .configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config + class ActionSelectKwargs(TypedDict, total=False): inference_delay: int | None diff --git a/src/lerobot/policies/pi0/processor_pi0.py b/src/lerobot/policies/pi0/processor_pi0.py index 0302876a1..ad861f85b 100644 --- a/src/lerobot/policies/pi0/processor_pi0.py +++ b/src/lerobot/policies/pi0/processor_pi0.py @@ -18,8 +18,7 @@ from typing import Any import torch -from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.policies.pi0.configuration_pi0 import PI0Config +from lerobot.configs import PipelineFeatureType, PolicyFeature from lerobot.processor import ( AbsoluteActionsProcessorStep, AddBatchDimensionProcessorStep, @@ -34,10 +33,13 @@ from lerobot.processor import ( RenameObservationsProcessorStep, TokenizerProcessorStep, UnnormalizerProcessorStep, + policy_action_to_transition, + transition_to_policy_action, ) -from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from .configuration_pi0 import PI0Config + @ProcessorStepRegistry.register(name="pi0_new_line_processor") class Pi0NewLineProcessor(ComplementaryDataProcessorStep): diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py index 6760be0a2..124e85cc9 100644 --- a/src/lerobot/policies/pi05/configuration_pi05.py +++ b/src/lerobot/policies/pi05/configuration_pi05.py @@ -16,13 +16,12 @@ from dataclasses import dataclass, field -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.optim.optimizers import AdamWConfig -from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig -from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig +from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE +from ..rtc.configuration_rtc import RTCConfig + DEFAULT_IMAGE_SIZE = 224 diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 96c4002f2..a44817a74 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -33,7 +33,7 @@ if TYPE_CHECKING or _transformers_available: from transformers.models.auto import CONFIG_MAPPING from transformers.models.gemma import modeling_gemma - from lerobot.policies.pi_gemma import ( + from ..pi_gemma import ( PaliGemmaForConditionalGenerationWithPiGemma, PiGemmaForCausalLM, _gated_residual, @@ -46,10 +46,7 @@ else: _gated_residual = None layernorm_forward = None PaliGemmaForConditionalGenerationWithPiGemma = None -from lerobot.configs.policies import PreTrainedConfig -from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config -from lerobot.policies.pretrained import PreTrainedPolicy, T -from lerobot.policies.rtc.modeling_rtc import RTCProcessor +from lerobot.configs import PreTrainedConfig from lerobot.utils.constants import ( ACTION, OBS_LANGUAGE_ATTENTION_MASK, @@ -57,6 +54,10 @@ from lerobot.utils.constants import ( OPENPI_ATTENTION_MASK_VALUE, ) +from ..pretrained import PreTrainedPolicy, T +from ..rtc.modeling_rtc import RTCProcessor +from .configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config + class ActionSelectKwargs(TypedDict, total=False): inference_delay: int | None diff --git a/src/lerobot/policies/pi05/processor_pi05.py b/src/lerobot/policies/pi05/processor_pi05.py index cb616af87..2d015b24f 100644 --- a/src/lerobot/policies/pi05/processor_pi05.py +++ b/src/lerobot/policies/pi05/processor_pi05.py @@ -21,8 +21,7 @@ from typing import Any import numpy as np import torch -from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.policies.pi05.configuration_pi05 import PI05Config +from lerobot.configs import PipelineFeatureType, PolicyFeature from lerobot.processor import ( AbsoluteActionsProcessorStep, AddBatchDimensionProcessorStep, @@ -36,8 +35,9 @@ from lerobot.processor import ( RenameObservationsProcessorStep, TokenizerProcessorStep, UnnormalizerProcessorStep, + policy_action_to_transition, + transition_to_policy_action, ) -from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import ( OBS_STATE, @@ -45,6 +45,8 @@ from lerobot.utils.constants import ( POLICY_PREPROCESSOR_DEFAULT_NAME, ) +from .configuration_pi05 import PI05Config + @ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step") @dataclass diff --git a/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py b/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py index 6a645fae1..e5c6851f4 100644 --- a/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py @@ -16,13 +16,12 @@ from dataclasses import dataclass, field -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.optim.optimizers import AdamWConfig -from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig -from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig +from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE +from ..rtc.configuration_rtc import RTCConfig + DEFAULT_IMAGE_SIZE = 224 diff --git a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py index 1bcf9794c..e86b8ad27 100644 --- a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py @@ -38,7 +38,7 @@ if TYPE_CHECKING or _transformers_available: from transformers import AutoTokenizer from transformers.models.auto import CONFIG_MAPPING - from lerobot.policies.pi_gemma import ( + from ..pi_gemma import ( PaliGemmaForConditionalGenerationWithPiGemma, PiGemmaModel, ) @@ -48,10 +48,7 @@ else: PiGemmaModel = None PaliGemmaForConditionalGenerationWithPiGemma = None -from lerobot.configs.policies import PreTrainedConfig -from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig -from lerobot.policies.pretrained import PreTrainedPolicy, T -from lerobot.policies.rtc.modeling_rtc import RTCProcessor +from lerobot.configs import PreTrainedConfig from lerobot.utils.constants import ( ACTION, ACTION_TOKEN_MASK, @@ -61,6 +58,10 @@ from lerobot.utils.constants import ( OPENPI_ATTENTION_MASK_VALUE, ) +from ..pretrained import PreTrainedPolicy, T +from ..rtc.modeling_rtc import RTCProcessor +from .configuration_pi0_fast import PI0FastConfig + class ActionSelectKwargs(TypedDict, total=False): temperature: float | None diff --git a/src/lerobot/policies/pi0_fast/processor_pi0_fast.py b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py index c4a510615..60a519786 100644 --- a/src/lerobot/policies/pi0_fast/processor_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py @@ -21,8 +21,7 @@ from typing import Any import numpy as np import torch -from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig +from lerobot.configs import PipelineFeatureType, PolicyFeature from lerobot.processor import ( AbsoluteActionsProcessorStep, ActionTokenizerProcessorStep, @@ -37,8 +36,9 @@ from lerobot.processor import ( RenameObservationsProcessorStep, TokenizerProcessorStep, UnnormalizerProcessorStep, + policy_action_to_transition, + transition_to_policy_action, ) -from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import ( OBS_STATE, @@ -46,6 +46,8 @@ from lerobot.utils.constants import ( POLICY_PREPROCESSOR_DEFAULT_NAME, ) +from .configuration_pi0_fast import PI0FastConfig + @ProcessorStepRegistry.register(name="pi0_fast_prepare_state_tokenizer_processor_step") @dataclass diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py index 70efeba6f..724f920f3 100644 --- a/src/lerobot/policies/pretrained.py +++ b/src/lerobot/policies/pretrained.py @@ -29,11 +29,12 @@ from huggingface_hub.errors import HfHubHTTPError from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor from torch import Tensor, nn -from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs import PreTrainedConfig from lerobot.configs.train import TrainPipelineConfig -from lerobot.policies.utils import log_model_loading_keys from lerobot.utils.hub import HubMixin +from .utils import log_model_loading_keys + T = TypeVar("T", bound="PreTrainedPolicy") diff --git a/src/lerobot/policies/rtc/__init__.py b/src/lerobot/policies/rtc/__init__.py index ac7b72ef7..7a29dcac0 100644 --- a/src/lerobot/policies/rtc/__init__.py +++ b/src/lerobot/policies/rtc/__init__.py @@ -14,11 +14,11 @@ """Real-Time Chunking (RTC) utilities for action-chunking policies.""" -from lerobot.policies.rtc.action_interpolator import ActionInterpolator -from lerobot.policies.rtc.action_queue import ActionQueue -from lerobot.policies.rtc.configuration_rtc import RTCConfig -from lerobot.policies.rtc.latency_tracker import LatencyTracker -from lerobot.policies.rtc.modeling_rtc import RTCProcessor +from .action_interpolator import ActionInterpolator +from .action_queue import ActionQueue +from .configuration_rtc import RTCConfig +from .latency_tracker import LatencyTracker +from .modeling_rtc import RTCProcessor __all__ = [ "ActionInterpolator", diff --git a/src/lerobot/policies/rtc/action_queue.py b/src/lerobot/policies/rtc/action_queue.py index 3c20d6d21..dbbdc41df 100644 --- a/src/lerobot/policies/rtc/action_queue.py +++ b/src/lerobot/policies/rtc/action_queue.py @@ -27,7 +27,7 @@ from threading import Lock import torch from torch import Tensor -from lerobot.policies.rtc.configuration_rtc import RTCConfig +from .configuration_rtc import RTCConfig logger = logging.getLogger(__name__) diff --git a/src/lerobot/policies/rtc/configuration_rtc.py b/src/lerobot/policies/rtc/configuration_rtc.py index 70a8dfb09..c70fe3de0 100644 --- a/src/lerobot/policies/rtc/configuration_rtc.py +++ b/src/lerobot/policies/rtc/configuration_rtc.py @@ -23,7 +23,7 @@ Based on: from dataclasses import dataclass -from lerobot.configs.types import RTCAttentionSchedule +from lerobot.configs import RTCAttentionSchedule @dataclass diff --git a/src/lerobot/policies/rtc/modeling_rtc.py b/src/lerobot/policies/rtc/modeling_rtc.py index 280905adf..c1aeed328 100644 --- a/src/lerobot/policies/rtc/modeling_rtc.py +++ b/src/lerobot/policies/rtc/modeling_rtc.py @@ -27,9 +27,10 @@ import math import torch from torch import Tensor -from lerobot.configs.types import RTCAttentionSchedule -from lerobot.policies.rtc.configuration_rtc import RTCConfig -from lerobot.policies.rtc.debug_tracker import Tracker +from lerobot.configs import RTCAttentionSchedule + +from .configuration_rtc import RTCConfig +from .debug_tracker import Tracker logger = logging.getLogger(__name__) diff --git a/src/lerobot/policies/sac/__init__.py b/src/lerobot/policies/sac/__init__.py new file mode 100644 index 000000000..cf5f149f3 --- /dev/null +++ b/src/lerobot/policies/sac/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_sac import SACConfig +from .modeling_sac import SACPolicy +from .processor_sac import make_sac_pre_post_processors + +__all__ = ["SACConfig", "SACPolicy", "make_sac_pre_post_processors"] diff --git a/src/lerobot/policies/sac/configuration_sac.py b/src/lerobot/policies/sac/configuration_sac.py index ada12330c..db0a77672 100644 --- a/src/lerobot/policies/sac/configuration_sac.py +++ b/src/lerobot/policies/sac/configuration_sac.py @@ -17,9 +17,8 @@ from dataclasses import dataclass, field -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import NormalizationMode -from lerobot.optim.optimizers import MultiAdamConfig +from lerobot.configs import NormalizationMode, PreTrainedConfig +from lerobot.optim import MultiAdamConfig from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py index d5dd71a48..cc7030ce2 100644 --- a/src/lerobot/policies/sac/modeling_sac.py +++ b/src/lerobot/policies/sac/modeling_sac.py @@ -28,11 +28,12 @@ import torch.nn.functional as F # noqa: N812 from torch import Tensor from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature -from lerobot.policies.utils import get_device_from_parameters from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE +from ..pretrained import PreTrainedPolicy +from ..utils import get_device_from_parameters +from .configuration_sac import SACConfig, is_image_feature + DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension diff --git a/src/lerobot/policies/sac/processor_sac.py b/src/lerobot/policies/sac/processor_sac.py index cf90e3cb4..3409307c2 100644 --- a/src/lerobot/policies/sac/processor_sac.py +++ b/src/lerobot/policies/sac/processor_sac.py @@ -19,7 +19,6 @@ from typing import Any import torch -from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, DeviceProcessorStep, @@ -28,10 +27,13 @@ from lerobot.processor import ( PolicyProcessorPipeline, RenameObservationsProcessorStep, UnnormalizerProcessorStep, + policy_action_to_transition, + transition_to_policy_action, ) -from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from .configuration_sac import SACConfig + def make_sac_pre_post_processors( config: SACConfig, diff --git a/src/lerobot/policies/sac/reward_model/__init__.py b/src/lerobot/policies/sac/reward_model/__init__.py new file mode 100644 index 000000000..1504a9947 --- /dev/null +++ b/src/lerobot/policies/sac/reward_model/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_classifier import RewardClassifierConfig +from .modeling_classifier import Classifier +from .processor_classifier import make_classifier_processor + +__all__ = ["RewardClassifierConfig", "Classifier", "make_classifier_processor"] diff --git a/src/lerobot/policies/sac/reward_model/configuration_classifier.py b/src/lerobot/policies/sac/reward_model/configuration_classifier.py index 879e3c1af..3a5bfa424 100644 --- a/src/lerobot/policies/sac/reward_model/configuration_classifier.py +++ b/src/lerobot/policies/sac/reward_model/configuration_classifier.py @@ -15,10 +15,8 @@ # limitations under the License. from dataclasses import dataclass, field -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import NormalizationMode -from lerobot.optim.optimizers import AdamWConfig, OptimizerConfig -from lerobot.optim.schedulers import LRSchedulerConfig +from lerobot.configs import NormalizationMode, PreTrainedConfig +from lerobot.optim import AdamWConfig, LRSchedulerConfig, OptimizerConfig from lerobot.utils.constants import OBS_IMAGE diff --git a/src/lerobot/policies/sac/reward_model/modeling_classifier.py b/src/lerobot/policies/sac/reward_model/modeling_classifier.py index dba6a174b..c8b7efe58 100644 --- a/src/lerobot/policies/sac/reward_model/modeling_classifier.py +++ b/src/lerobot/policies/sac/reward_model/modeling_classifier.py @@ -19,10 +19,11 @@ import logging import torch from torch import Tensor, nn -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.utils.constants import OBS_IMAGE, REWARD +from ...pretrained import PreTrainedPolicy +from .configuration_classifier import RewardClassifierConfig + class ClassifierOutput: """Wrapper for classifier outputs with additional metadata.""" diff --git a/src/lerobot/policies/sac/reward_model/processor_classifier.py b/src/lerobot/policies/sac/reward_model/processor_classifier.py index c2a34eab2..1f7a66e58 100644 --- a/src/lerobot/policies/sac/reward_model/processor_classifier.py +++ b/src/lerobot/policies/sac/reward_model/processor_classifier.py @@ -18,15 +18,17 @@ from typing import Any import torch -from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.processor import ( DeviceProcessorStep, IdentityProcessorStep, NormalizerProcessorStep, PolicyAction, PolicyProcessorPipeline, + policy_action_to_transition, + transition_to_policy_action, ) -from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + +from .configuration_classifier import RewardClassifierConfig def make_classifier_processor( diff --git a/src/lerobot/policies/sarm/__init__.py b/src/lerobot/policies/sarm/__init__.py new file mode 100644 index 000000000..b164c87ef --- /dev/null +++ b/src/lerobot/policies/sarm/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_sarm import SARMConfig +from .modeling_sarm import SARMRewardModel + +__all__ = ["SARMConfig", "SARMRewardModel"] diff --git a/src/lerobot/policies/sarm/compute_rabc_weights.py b/src/lerobot/policies/sarm/compute_rabc_weights.py index 485c1096b..07d0780b5 100644 --- a/src/lerobot/policies/sarm/compute_rabc_weights.py +++ b/src/lerobot/policies/sarm/compute_rabc_weights.py @@ -57,10 +57,11 @@ import pyarrow.parquet as pq import torch from tqdm import tqdm -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.policies.sarm.modeling_sarm import SARMRewardModel -from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors -from lerobot.policies.sarm.sarm_utils import normalize_stage_tau +from lerobot.datasets import LeRobotDataset + +from .modeling_sarm import SARMRewardModel +from .processor_sarm import make_sarm_pre_post_processors +from .sarm_utils import normalize_stage_tau def get_reward_model_path_from_parquet(parquet_path: Path) -> str | None: diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py index 673422fe2..fc8daa055 100644 --- a/src/lerobot/policies/sarm/configuration_sarm.py +++ b/src/lerobot/policies/sarm/configuration_sarm.py @@ -22,10 +22,8 @@ Paper: https://arxiv.org/abs/2509.25358 from dataclasses import dataclass, field -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.optim.optimizers import AdamWConfig -from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig +from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig +from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig from lerobot.utils.constants import OBS_IMAGES, OBS_STATE diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py index 6051d90f8..710554e4b 100644 --- a/src/lerobot/policies/sarm/modeling_sarm.py +++ b/src/lerobot/policies/sarm/modeling_sarm.py @@ -34,13 +34,14 @@ import torch.nn as nn import torch.nn.functional as F # noqa: N812 from torch import Tensor -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.sarm.configuration_sarm import SARMConfig -from lerobot.policies.sarm.sarm_utils import ( +from lerobot.utils.constants import OBS_STR + +from ..pretrained import PreTrainedPolicy +from .configuration_sarm import SARMConfig +from .sarm_utils import ( normalize_stage_tau, pad_state_to_max_dim, ) -from lerobot.utils.constants import OBS_STR class StageTransformer(nn.Module): diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py index f377a7ffa..e939b3485 100644 --- a/src/lerobot/policies/sarm/processor_sarm.py +++ b/src/lerobot/policies/sarm/processor_sarm.py @@ -16,41 +16,60 @@ """SARM Processor for encoding images/text and generating stage+tau targets.""" +from __future__ import annotations + import random -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np -import pandas as pd import torch -from faker import Faker from PIL import Image -from transformers import CLIPModel, CLIPProcessor -from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.policies.sarm.configuration_sarm import SARMConfig -from lerobot.policies.sarm.sarm_utils import ( +from lerobot.utils.import_utils import ( + _faker_available, + _pandas_available, + _transformers_available, + require_package, +) + +if TYPE_CHECKING or _transformers_available: + from transformers import CLIPModel, CLIPProcessor +else: + CLIPModel = None # type: ignore[assignment, misc] + CLIPProcessor = None # type: ignore[assignment, misc] + +if TYPE_CHECKING or _pandas_available: + import pandas as pd +else: + pd = None # type: ignore[assignment] + +if TYPE_CHECKING or _faker_available: + from faker import Faker +else: + Faker = None # type: ignore[assignment, misc] + +from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyProcessorPipeline, + ProcessorStep, + RenameObservationsProcessorStep, + from_tensor_to_numpy, + policy_action_to_transition, + transition_to_policy_action, +) +from lerobot.types import EnvTransition, PolicyAction, TransitionKey +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME + +from .configuration_sarm import SARMConfig +from .sarm_utils import ( apply_rewind_augmentation, compute_absolute_indices, find_stage_and_tau, pad_state_to_max_dim, ) -from lerobot.processor import ( - AddBatchDimensionProcessorStep, - DeviceProcessorStep, - NormalizerProcessorStep, - PolicyAction, - PolicyProcessorPipeline, - ProcessorStep, - RenameObservationsProcessorStep, -) -from lerobot.processor.converters import ( - from_tensor_to_numpy, - policy_action_to_transition, - transition_to_policy_action, -) -from lerobot.processor.pipeline import PipelineFeatureType -from lerobot.types import EnvTransition, TransitionKey -from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME class SARMEncodingProcessorStep(ProcessorStep): @@ -63,6 +82,9 @@ class SARMEncodingProcessorStep(ProcessorStep): dataset_meta=None, dataset_stats: dict | None = None, ): + require_package("transformers", extra="sarm") + require_package("faker", extra="sarm") + require_package("pandas", extra="dataset") super().__init__() self.config = config self.image_key = image_key or config.image_key diff --git a/src/lerobot/policies/smolvla/__init__.py b/src/lerobot/policies/smolvla/__init__.py new file mode 100644 index 000000000..690f15860 --- /dev/null +++ b/src/lerobot/policies/smolvla/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_smolvla import SmolVLAConfig +from .modeling_smolvla import SmolVLAPolicy +from .processor_smolvla import make_smolvla_pre_post_processors + +__all__ = ["SmolVLAConfig", "SmolVLAPolicy", "make_smolvla_pre_post_processors"] diff --git a/src/lerobot/policies/smolvla/configuration_smolvla.py b/src/lerobot/policies/smolvla/configuration_smolvla.py index 5007abbb4..6d5288db3 100644 --- a/src/lerobot/policies/smolvla/configuration_smolvla.py +++ b/src/lerobot/policies/smolvla/configuration_smolvla.py @@ -14,15 +14,12 @@ from dataclasses import dataclass, field -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.optim.optimizers import AdamWConfig -from lerobot.optim.schedulers import ( - CosineDecayWithWarmupSchedulerConfig, -) -from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig +from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig from lerobot.utils.constants import OBS_IMAGES +from ..rtc.configuration_rtc import RTCConfig + @PreTrainedConfig.register_subclass("smolvla") @dataclass diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 7110ba7d2..ee3ff4db9 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -60,16 +60,17 @@ import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.rtc.modeling_rtc import RTCProcessor -from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig -from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel -from lerobot.policies.utils import ( - populate_queues, -) from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE from lerobot.utils.device_utils import get_safe_dtype +from ..pretrained import PreTrainedPolicy +from ..rtc.modeling_rtc import RTCProcessor +from ..utils import ( + populate_queues, +) +from .configuration_smolvla import SmolVLAConfig +from .smolvlm_with_expert import SmolVLMWithExpertModel + class ActionSelectKwargs(TypedDict, total=False): inference_delay: int | None diff --git a/src/lerobot/policies/smolvla/processor_smolvla.py b/src/lerobot/policies/smolvla/processor_smolvla.py index 3fc130aa1..8d6c8aca4 100644 --- a/src/lerobot/policies/smolvla/processor_smolvla.py +++ b/src/lerobot/policies/smolvla/processor_smolvla.py @@ -18,23 +18,23 @@ from typing import Any import torch -from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, - ComplementaryDataProcessorStep, DeviceProcessorStep, + NewLineTaskProcessorStep, NormalizerProcessorStep, PolicyAction, PolicyProcessorPipeline, - ProcessorStepRegistry, RenameObservationsProcessorStep, TokenizerProcessorStep, UnnormalizerProcessorStep, + policy_action_to_transition, + transition_to_policy_action, ) -from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from .configuration_smolvla import SmolVLAConfig + def make_smolvla_pre_post_processors( config: SmolVLAConfig, @@ -69,7 +69,7 @@ def make_smolvla_pre_post_processors( input_steps = [ RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one AddBatchDimensionProcessorStep(), - SmolVLANewLineProcessor(), + NewLineTaskProcessorStep(), TokenizerProcessorStep( tokenizer_name=config.vlm_model_name, padding=config.pad_language_to, @@ -101,41 +101,3 @@ def make_smolvla_pre_post_processors( to_output=transition_to_policy_action, ), ) - - -@ProcessorStepRegistry.register(name="smolvla_new_line_processor") -class SmolVLANewLineProcessor(ComplementaryDataProcessorStep): - """ - A processor step that ensures the 'task' description ends with a newline character. - - This step is necessary for certain tokenizers (e.g., PaliGemma) that expect a - newline at the end of the prompt. It handles both single string tasks and lists - of string tasks. - """ - - def complementary_data(self, complementary_data): - if "task" not in complementary_data: - return complementary_data - - task = complementary_data["task"] - if task is None: - return complementary_data - - new_complementary_data = dict(complementary_data) - - # Handle both string and list of strings - if isinstance(task, str): - # Single string: add newline if not present - if not task.endswith("\n"): - new_complementary_data["task"] = f"{task}\n" - elif isinstance(task, list) and all(isinstance(t, str) for t in task): - # List of strings: add newline to each if not present - new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task] - # If task is neither string nor list of strings, leave unchanged - - return new_complementary_data - - def transform_features( - self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] - ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: - return features diff --git a/src/lerobot/policies/smolvla/smolvlm_with_expert.py b/src/lerobot/policies/smolvla/smolvlm_with_expert.py index caca41dab..ea806f185 100644 --- a/src/lerobot/policies/smolvla/smolvlm_with_expert.py +++ b/src/lerobot/policies/smolvla/smolvlm_with_expert.py @@ -13,16 +13,27 @@ # limitations under the License. import copy +from typing import TYPE_CHECKING import torch from torch import nn -from transformers import ( - AutoConfig, - AutoModel, - AutoModelForImageTextToText, - AutoProcessor, - SmolVLMForConditionalGeneration, -) + +from lerobot.utils.import_utils import _transformers_available, require_package + +if TYPE_CHECKING or _transformers_available: + from transformers import ( + AutoConfig, + AutoModel, + AutoModelForImageTextToText, + AutoProcessor, + SmolVLMForConditionalGeneration, + ) +else: + AutoConfig = None + AutoModel = None + AutoModelForImageTextToText = None + AutoProcessor = None + SmolVLMForConditionalGeneration = None def apply_rope(x, positions, max_wavelength=10_000): @@ -73,6 +84,7 @@ class SmolVLMWithExpertModel(nn.Module): device: str = "auto", ): super().__init__() + require_package("transformers", extra="smolvla") if load_vlm_weights: print(f"Loading {model_id} weights ...") self.vlm = AutoModelForImageTextToText.from_pretrained( diff --git a/src/lerobot/policies/tdmpc/__init__.py b/src/lerobot/policies/tdmpc/__init__.py new file mode 100644 index 000000000..5663e23c4 --- /dev/null +++ b/src/lerobot/policies/tdmpc/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_tdmpc import TDMPCConfig +from .modeling_tdmpc import TDMPCPolicy +from .processor_tdmpc import make_tdmpc_pre_post_processors + +__all__ = ["TDMPCConfig", "TDMPCPolicy", "make_tdmpc_pre_post_processors"] diff --git a/src/lerobot/policies/tdmpc/configuration_tdmpc.py b/src/lerobot/policies/tdmpc/configuration_tdmpc.py index 3ec493472..bb8a2cf96 100644 --- a/src/lerobot/policies/tdmpc/configuration_tdmpc.py +++ b/src/lerobot/policies/tdmpc/configuration_tdmpc.py @@ -16,9 +16,8 @@ # limitations under the License. from dataclasses import dataclass, field -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import NormalizationMode -from lerobot.optim.optimizers import AdamConfig +from lerobot.configs import NormalizationMode, PreTrainedConfig +from lerobot.optim import AdamConfig @PreTrainedConfig.register_subclass("tdmpc") diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py index f83c82e21..a50bb9670 100644 --- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py +++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py @@ -35,11 +35,12 @@ import torch.nn as nn import torch.nn.functional as F # noqa: N812 from torch import Tensor -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig -from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_PREFIX, OBS_STATE, OBS_STR, REWARD +from ..pretrained import PreTrainedPolicy +from ..utils import get_device_from_parameters, get_output_shape, populate_queues +from .configuration_tdmpc import TDMPCConfig + class TDMPCPolicy(PreTrainedPolicy): """Implementation of TD-MPC learning + inference. diff --git a/src/lerobot/policies/tdmpc/processor_tdmpc.py b/src/lerobot/policies/tdmpc/processor_tdmpc.py index 9b6f97e50..7afe956dc 100644 --- a/src/lerobot/policies/tdmpc/processor_tdmpc.py +++ b/src/lerobot/policies/tdmpc/processor_tdmpc.py @@ -18,7 +18,6 @@ from typing import Any import torch -from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, DeviceProcessorStep, @@ -27,10 +26,13 @@ from lerobot.processor import ( PolicyProcessorPipeline, RenameObservationsProcessorStep, UnnormalizerProcessorStep, + policy_action_to_transition, + transition_to_policy_action, ) -from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from .configuration_tdmpc import TDMPCConfig + def make_tdmpc_pre_post_processors( config: TDMPCConfig, diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py index 82ab51005..c37127813 100644 --- a/src/lerobot/policies/utils.py +++ b/src/lerobot/policies/utils.py @@ -21,11 +21,10 @@ import numpy as np import torch from torch import nn -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.datasets.feature_utils import build_dataset_frame +from lerobot.configs import FeatureType, PolicyFeature, PreTrainedConfig from lerobot.types import PolicyAction, RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_STR +from lerobot.utils.feature_utils import build_dataset_frame def populate_queues( diff --git a/src/lerobot/policies/vqbet/__init__.py b/src/lerobot/policies/vqbet/__init__.py new file mode 100644 index 000000000..842dd5d0b --- /dev/null +++ b/src/lerobot/policies/vqbet/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_vqbet import VQBeTConfig +from .modeling_vqbet import VQBeTPolicy +from .processor_vqbet import make_vqbet_pre_post_processors + +__all__ = ["VQBeTConfig", "VQBeTPolicy", "make_vqbet_pre_post_processors"] diff --git a/src/lerobot/policies/vqbet/configuration_vqbet.py b/src/lerobot/policies/vqbet/configuration_vqbet.py index 32906e528..d02745321 100644 --- a/src/lerobot/policies/vqbet/configuration_vqbet.py +++ b/src/lerobot/policies/vqbet/configuration_vqbet.py @@ -18,10 +18,8 @@ from dataclasses import dataclass, field -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import NormalizationMode -from lerobot.optim.optimizers import AdamConfig -from lerobot.optim.schedulers import VQBeTSchedulerConfig +from lerobot.configs import NormalizationMode, PreTrainedConfig +from lerobot.optim import AdamConfig, VQBeTSchedulerConfig @PreTrainedConfig.register_subclass("vqbet") diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py index 6d3976b79..153f7fe3c 100644 --- a/src/lerobot/policies/vqbet/modeling_vqbet.py +++ b/src/lerobot/policies/vqbet/modeling_vqbet.py @@ -27,12 +27,13 @@ import torch.nn.functional as F # noqa: N812 import torchvision from torch import Tensor, nn -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues -from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig -from lerobot.policies.vqbet.vqbet_utils import GPT, ResidualVQ from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE +from ..pretrained import PreTrainedPolicy +from ..utils import get_device_from_parameters, get_output_shape, populate_queues +from .configuration_vqbet import VQBeTConfig +from .vqbet_utils import GPT, ResidualVQ + # ruff: noqa: N806 diff --git a/src/lerobot/policies/vqbet/processor_vqbet.py b/src/lerobot/policies/vqbet/processor_vqbet.py index 1e19ff779..f7b6a061e 100644 --- a/src/lerobot/policies/vqbet/processor_vqbet.py +++ b/src/lerobot/policies/vqbet/processor_vqbet.py @@ -19,7 +19,6 @@ from typing import Any import torch -from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, DeviceProcessorStep, @@ -28,10 +27,13 @@ from lerobot.processor import ( PolicyProcessorPipeline, RenameObservationsProcessorStep, UnnormalizerProcessorStep, + policy_action_to_transition, + transition_to_policy_action, ) -from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from .configuration_vqbet import VQBeTConfig + def make_vqbet_pre_post_processors( config: VQBeTConfig, diff --git a/src/lerobot/policies/vqbet/vqbet_utils.py b/src/lerobot/policies/vqbet/vqbet_utils.py index 7b13577f6..f8bfcb06a 100644 --- a/src/lerobot/policies/vqbet/vqbet_utils.py +++ b/src/lerobot/policies/vqbet/vqbet_utils.py @@ -30,7 +30,7 @@ from torch import einsum, nn from torch.cuda.amp import autocast from torch.optim import Optimizer -from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig +from .configuration_vqbet import VQBeTConfig # ruff: noqa: N806 diff --git a/src/lerobot/policies/wall_x/__init__.py b/src/lerobot/policies/wall_x/__init__.py index d80c27bda..16fd2c8ab 100644 --- a/src/lerobot/policies/wall_x/__init__.py +++ b/src/lerobot/policies/wall_x/__init__.py @@ -15,5 +15,7 @@ # limitations under the License. from .configuration_wall_x import WallXConfig +from .modeling_wall_x import WallXPolicy +from .processor_wall_x import make_wall_x_pre_post_processors __all__ = ["WallXConfig", "WallXPolicy", "make_wall_x_pre_post_processors"] diff --git a/src/lerobot/policies/wall_x/configuration_wall_x.py b/src/lerobot/policies/wall_x/configuration_wall_x.py index 5269c4e10..70576a46b 100644 --- a/src/lerobot/policies/wall_x/configuration_wall_x.py +++ b/src/lerobot/policies/wall_x/configuration_wall_x.py @@ -14,10 +14,8 @@ from dataclasses import dataclass, field -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.optim.optimizers import AdamWConfig -from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig +from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig +from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig from lerobot.utils.constants import ACTION, OBS_STATE diff --git a/src/lerobot/policies/wall_x/modeling_wall_x.py b/src/lerobot/policies/wall_x/modeling_wall_x.py index 84ee05743..bfecf3852 100644 --- a/src/lerobot/policies/wall_x/modeling_wall_x.py +++ b/src/lerobot/policies/wall_x/modeling_wall_x.py @@ -34,35 +34,31 @@ lerobot-train \ ``` """ +import logging import math from collections import deque from os import PathLike -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from peft import LoraConfig, get_peft_model from PIL import Image -from qwen_vl_utils.vision_process import smart_resize from torch import Tensor from torch.distributions import Beta from torch.nn import CrossEntropyLoss -from torchdiffeq import odeint -from transformers import AutoProcessor, BatchFeature -from transformers.cache_utils import ( - StaticCache, -) -from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( - Qwen2_5_VLForConditionalGeneration, -) -from transformers.utils import is_torchdynamo_compiling, logging -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.utils import populate_queues -from lerobot.policies.wall_x.configuration_wall_x import WallXConfig -from lerobot.policies.wall_x.constant import ( +from lerobot.utils.constants import ACTION, OBS_STATE +from lerobot.utils.import_utils import ( + _wallx_deps_available, + require_package, +) + +from ..pretrained import PreTrainedPolicy +from ..utils import populate_queues +from .configuration_wall_x import WallXConfig +from .constant import ( GENERATE_SUBTASK_RATIO, IMAGE_FACTOR, MAX_PIXELS, @@ -72,21 +68,47 @@ from lerobot.policies.wall_x.constant import ( RESOLUTION, TOKENIZER_MAX_LENGTH, ) -from lerobot.policies.wall_x.qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig -from lerobot.policies.wall_x.qwen_model.qwen2_5_vl_moe import ( - Qwen2_5_VisionTransformerPretrainedModel, - Qwen2_5_VLACausalLMOutputWithPast, - Qwen2_5_VLMoEModel, -) -from lerobot.policies.wall_x.utils import ( + +if TYPE_CHECKING or _wallx_deps_available: + from peft import LoraConfig, get_peft_model + from qwen_vl_utils.vision_process import smart_resize + from torchdiffeq import odeint + from transformers import AutoProcessor, BatchFeature + from transformers.cache_utils import StaticCache + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLForConditionalGeneration, + ) + from transformers.utils import is_torchdynamo_compiling + + from .qwen_model.configuration_qwen2_5_vl import Qwen2_5_VLConfig + from .qwen_model.qwen2_5_vl_moe import ( + Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLACausalLMOutputWithPast, + Qwen2_5_VLMoEModel, + ) +else: + LoraConfig = None + get_peft_model = None + smart_resize = None + odeint = None + AutoProcessor = None + BatchFeature = None + StaticCache = None + Qwen2_5_VLForConditionalGeneration = None + is_torchdynamo_compiling = None + Qwen2_5_VLConfig = None + Qwen2_5_VisionTransformerPretrainedModel = None + Qwen2_5_VLACausalLMOutputWithPast = None + Qwen2_5_VLMoEModel = None + +from .utils import ( get_wallx_normal_text, preprocesser_call, process_grounding_points, replace_action_token, ) -from lerobot.utils.constants import ACTION, OBS_STATE -logger = logging.get_logger(__name__) +logger = logging.getLogger(__name__) class SinusoidalPosEmb(nn.Module): @@ -253,7 +275,13 @@ class ActionHead(nn.Module): return self.propri_proj(proprioception) -class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): +# Conditional base: when transformers is unavailable the class still parses +# (inheriting from nn.Module) but cannot be instantiated—require_package in +# WallXPolicy.__init__ gives the user a clear error before that happens. +_Qwen2_5_VLForAction_Base = Qwen2_5_VLForConditionalGeneration if _wallx_deps_available else nn.Module + + +class Qwen2_5_VLMoEForAction(_Qwen2_5_VLForAction_Base): """ Qwen2.5 Vision-Language Mixture of Experts model for action processing. @@ -1708,6 +1736,10 @@ class WallXPolicy(PreTrainedPolicy): name = "wall_x" def __init__(self, config: WallXConfig, **kwargs): + require_package("transformers", extra="wallx") + require_package("peft", extra="wallx") + require_package("torchdiffeq", extra="wallx") + require_package("qwen-vl-utils", extra="wallx", import_name="qwen_vl_utils") super().__init__(config) config.validate_features() self.config = config diff --git a/src/lerobot/policies/wall_x/processor_wall_x.py b/src/lerobot/policies/wall_x/processor_wall_x.py index e4e281541..069cef5d6 100644 --- a/src/lerobot/policies/wall_x/processor_wall_x.py +++ b/src/lerobot/policies/wall_x/processor_wall_x.py @@ -18,8 +18,7 @@ from typing import Any import torch -from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.policies.wall_x.configuration_wall_x import WallXConfig +from lerobot.configs import PipelineFeatureType, PolicyFeature from lerobot.processor import ( AddBatchDimensionProcessorStep, ComplementaryDataProcessorStep, @@ -30,10 +29,13 @@ from lerobot.processor import ( ProcessorStepRegistry, RenameObservationsProcessorStep, UnnormalizerProcessorStep, + policy_action_to_transition, + transition_to_policy_action, ) -from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from .configuration_wall_x import WallXConfig + def make_wall_x_pre_post_processors( config: WallXConfig, diff --git a/src/lerobot/policies/wall_x/utils.py b/src/lerobot/policies/wall_x/utils.py index e08ef69d5..d38a2d509 100644 --- a/src/lerobot/policies/wall_x/utils.py +++ b/src/lerobot/policies/wall_x/utils.py @@ -25,15 +25,22 @@ import random import re from collections import OrderedDict from dataclasses import dataclass, field -from typing import Any +from typing import TYPE_CHECKING, Any import torch -from transformers import BatchFeature -from lerobot.policies.wall_x.constant import ( +from lerobot.utils.import_utils import _transformers_available + +if TYPE_CHECKING or _transformers_available: + from transformers import BatchFeature +else: + BatchFeature = None + +from lerobot.utils.constants import OBS_IMAGES + +from .constant import ( CAMERA_NAME_MAPPING, ) -from lerobot.utils.constants import OBS_IMAGES @dataclass diff --git a/src/lerobot/policies/xvla/__init__.py b/src/lerobot/policies/xvla/__init__.py index 71b04e76f..58609e91c 100644 --- a/src/lerobot/policies/xvla/__init__.py +++ b/src/lerobot/policies/xvla/__init__.py @@ -1,6 +1,15 @@ -# register the processor steps -from lerobot.policies.xvla.processor_xvla import ( +from .configuration_xvla import XVLAConfig +from .modeling_xvla import XVLAPolicy +from .processor_xvla import ( XVLAAddDomainIdProcessorStep, XVLAImageNetNormalizeProcessorStep, XVLAImageToFloatProcessorStep, ) + +__all__ = [ + "XVLAConfig", + "XVLAPolicy", + "XVLAAddDomainIdProcessorStep", + "XVLAImageNetNormalizeProcessorStep", + "XVLAImageToFloatProcessorStep", +] diff --git a/src/lerobot/policies/xvla/configuration_xvla.py b/src/lerobot/policies/xvla/configuration_xvla.py index 30700b042..614c9a944 100644 --- a/src/lerobot/policies/xvla/configuration_xvla.py +++ b/src/lerobot/policies/xvla/configuration_xvla.py @@ -21,10 +21,8 @@ from __future__ import annotations from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.optim.optimizers import XVLAAdamWConfig -from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig +from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig +from lerobot.optim import CosineDecayWithWarmupSchedulerConfig, XVLAAdamWConfig from lerobot.utils.constants import OBS_IMAGES # Conditional import for type checking and lazy loading diff --git a/src/lerobot/policies/xvla/modeling_xvla.py b/src/lerobot/policies/xvla/modeling_xvla.py index 0436ae527..04e923fdd 100644 --- a/src/lerobot/policies/xvla/modeling_xvla.py +++ b/src/lerobot/policies/xvla/modeling_xvla.py @@ -23,22 +23,30 @@ import logging import os from collections import deque from pathlib import Path +from typing import TYPE_CHECKING import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from lerobot.configs.policies import PreTrainedConfig -from lerobot.policies.pretrained import PreTrainedPolicy, T -from lerobot.policies.utils import populate_queues +from lerobot.configs import PreTrainedConfig from lerobot.utils.constants import ACTION, OBS_LANGUAGE_TOKENS, OBS_STATE +from lerobot.utils.import_utils import _transformers_available, require_package +from ..pretrained import PreTrainedPolicy, T +from ..utils import populate_queues from .action_hub import build_action_space -from .configuration_florence2 import Florence2Config from .configuration_xvla import XVLAConfig -from .modeling_florence2 import Florence2ForConditionalGeneration from .soft_transformer import SoftPromptedTransformer +# Florence2 config and modeling depend on transformers +if TYPE_CHECKING or _transformers_available: + from .configuration_florence2 import Florence2Config + from .modeling_florence2 import Florence2ForConditionalGeneration +else: + Florence2Config = None + Florence2ForConditionalGeneration = None + class XVLAModel(nn.Module): """ @@ -274,6 +282,7 @@ class XVLAPolicy(PreTrainedPolicy): name = "xvla" def __init__(self, config: XVLAConfig, **kwargs): + require_package("transformers", extra="xvla") super().__init__(config) config.validate_features() florence_config = config.get_florence_config() diff --git a/src/lerobot/policies/xvla/processor_xvla.py b/src/lerobot/policies/xvla/processor_xvla.py index 0fa9ffe3f..0336ec722 100644 --- a/src/lerobot/policies/xvla/processor_xvla.py +++ b/src/lerobot/policies/xvla/processor_xvla.py @@ -20,10 +20,7 @@ from typing import Any import numpy as np import torch -from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.datasets.factory import IMAGENET_STATS -from lerobot.policies.xvla.configuration_xvla import XVLAConfig -from lerobot.policies.xvla.utils import rotate6d_to_axis_angle +from lerobot.configs import PipelineFeatureType, PolicyFeature from lerobot.processor import ( AddBatchDimensionProcessorStep, DeviceProcessorStep, @@ -36,10 +33,12 @@ from lerobot.processor import ( RenameObservationsProcessorStep, TokenizerProcessorStep, UnnormalizerProcessorStep, + policy_action_to_transition, + transition_to_policy_action, ) -from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import ( + IMAGENET_STATS, OBS_IMAGES, OBS_PREFIX, OBS_STATE, @@ -47,6 +46,9 @@ from lerobot.utils.constants import ( POLICY_PREPROCESSOR_DEFAULT_NAME, ) +from .configuration_xvla import XVLAConfig +from .utils import rotate6d_to_axis_angle + def make_xvla_pre_post_processors( config: XVLAConfig, diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 122b3533c..3688a4b8c 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -27,10 +27,20 @@ from .batch_processor import AddBatchDimensionProcessorStep from .converters import ( batch_to_transition, create_transition, + from_tensor_to_numpy, + identity_transition, + observation_to_transition, + policy_action_to_transition, + robot_action_observation_to_transition, + robot_action_to_transition, transition_to_batch, + transition_to_observation, + transition_to_policy_action, + transition_to_robot_action, ) from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep from .device_processor import DeviceProcessorStep +from .env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep from .factory import ( make_default_processors, make_default_robot_action_processor, @@ -51,6 +61,7 @@ from .hil_processor import ( RewardClassifierProcessorStep, TimeLimitProcessorStep, ) +from .newline_task_processor import NewLineTaskProcessorStep from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats from .observation_processor import VanillaObservationProcessorStep from .pipeline import ( @@ -81,7 +92,7 @@ from .relative_action_processor import ( to_absolute_actions, to_relative_actions, ) -from .rename_processor import RenameObservationsProcessorStep +from .rename_processor import RenameObservationsProcessorStep, rename_stats from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep __all__ = [ @@ -91,6 +102,15 @@ __all__ = [ "ComplementaryDataProcessorStep", "batch_to_transition", "create_transition", + "from_tensor_to_numpy", + "identity_transition", + "observation_to_transition", + "policy_action_to_transition", + "robot_action_observation_to_transition", + "robot_action_to_transition", + "transition_to_observation", + "transition_to_policy_action", + "transition_to_robot_action", "DeviceProcessorStep", "DoneProcessorStep", "EnvAction", @@ -110,6 +130,7 @@ __all__ = [ "RelativeActionsProcessorStep", "MapDeltaActionToRobotActionStep", "MapTensorToDeltaActionDictStep", + "NewLineTaskProcessorStep", "NormalizerProcessorStep", "Numpy2TorchActionProcessorStep", "ObservationProcessorStep", @@ -122,10 +143,13 @@ __all__ = [ "RobotAction", "RobotActionProcessorStep", "RobotObservation", + "rename_stats", "RenameObservationsProcessorStep", "RewardClassifierProcessorStep", "RewardProcessorStep", "DataProcessorPipeline", + "IsaaclabArenaProcessorStep", + "LiberoProcessorStep", "TimeLimitProcessorStep", "AddBatchDimensionProcessorStep", "RobotProcessorPipeline", diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py index c904acf84..eb7db255a 100644 --- a/src/lerobot/processor/batch_processor.py +++ b/src/lerobot/processor/batch_processor.py @@ -24,7 +24,7 @@ from dataclasses import dataclass, field from torch import Tensor -from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.configs import PipelineFeatureType, PolicyFeature from lerobot.types import EnvTransition, PolicyAction from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE diff --git a/src/lerobot/processor/delta_action_processor.py b/src/lerobot/processor/delta_action_processor.py index f7f5676ac..86b2feec1 100644 --- a/src/lerobot/processor/delta_action_processor.py +++ b/src/lerobot/processor/delta_action_processor.py @@ -16,7 +16,7 @@ from dataclasses import dataclass -from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.types import PolicyAction, RobotAction from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index 36c80e58e..1171c7e78 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -24,7 +24,7 @@ from typing import Any import torch -from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.configs import PipelineFeatureType, PolicyFeature from lerobot.types import EnvTransition, PolicyAction, TransitionKey from lerobot.utils.device_utils import get_safe_torch_device diff --git a/src/lerobot/processor/env_processor.py b/src/lerobot/processor/env_processor.py index a77e066cf..75cbb79de 100644 --- a/src/lerobot/processor/env_processor.py +++ b/src/lerobot/processor/env_processor.py @@ -17,7 +17,7 @@ from dataclasses import dataclass import torch -from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR from .pipeline import ObservationProcessorStep, ProcessorStepRegistry diff --git a/src/lerobot/processor/gym_action_processor.py b/src/lerobot/processor/gym_action_processor.py index e756ded7f..2ec5f6e64 100644 --- a/src/lerobot/processor/gym_action_processor.py +++ b/src/lerobot/processor/gym_action_processor.py @@ -16,8 +16,8 @@ from dataclasses import dataclass -from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.types import EnvAction, EnvTransition, PolicyAction +from lerobot.configs import PipelineFeatureType, PolicyFeature +from lerobot.types import EnvAction, EnvTransition, PolicyAction, TransitionKey from .converters import to_tensor from .hil_processor import TELEOP_ACTION_KEY @@ -75,8 +75,6 @@ class Numpy2TorchActionProcessorStep(ProcessorStep): def __call__(self, transition: EnvTransition) -> EnvTransition: """Converts numpy action to torch tensor if action exists, otherwise passes through.""" - from lerobot.types import TransitionKey - self._current_transition = transition.copy() new_transition = self._current_transition diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index 0b8521c2b..c6f98c689 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -24,7 +24,7 @@ import numpy as np import torch import torchvision.transforms.functional as F # noqa: N812 -from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.configs import PipelineFeatureType, PolicyFeature from lerobot.teleoperators.utils import TeleopEvents if TYPE_CHECKING: diff --git a/src/lerobot/processor/migrate_policy_normalization.py b/src/lerobot/processor/migrate_policy_normalization.py index 525b7431c..37df4be41 100644 --- a/src/lerobot/processor/migrate_policy_normalization.py +++ b/src/lerobot/processor/migrate_policy_normalization.py @@ -57,8 +57,8 @@ import torch from huggingface_hub import HfApi, hf_hub_download from safetensors.torch import load_file as load_safetensors -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors +from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature +from lerobot.policies import get_policy_class, make_policy_config, make_pre_post_processors from lerobot.utils.constants import ACTION diff --git a/src/lerobot/processor/newline_task_processor.py b/src/lerobot/processor/newline_task_processor.py new file mode 100644 index 000000000..ea61bdd71 --- /dev/null +++ b/src/lerobot/processor/newline_task_processor.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python + +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lerobot.configs import PipelineFeatureType, PolicyFeature + +from .pipeline import ComplementaryDataProcessorStep, ProcessorStepRegistry + + +# NOTE: The registry name "smolvla_new_line_processor" is kept for backward compatibility +# with serialized processor configs that reference this name. +@ProcessorStepRegistry.register(name="smolvla_new_line_processor") +class NewLineTaskProcessorStep(ComplementaryDataProcessorStep): + """ + A processor step that ensures the 'task' description ends with a newline character. + + This step is necessary for certain tokenizers (e.g., PaliGemma) that expect a + newline at the end of the prompt. It handles both single string tasks and lists + of string tasks. + """ + + def complementary_data(self, complementary_data): + if "task" not in complementary_data: + return complementary_data + + task = complementary_data["task"] + if task is None: + return complementary_data + + new_complementary_data = dict(complementary_data) + + # Handle both string and list of strings + if isinstance(task, str): + # Single string: add newline if not present + if not task.endswith("\n"): + new_complementary_data["task"] = f"{task}\n" + elif isinstance(task, list) and all(isinstance(t, str) for t in task): + # List of strings: add newline to each if not present + new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task] + # If task is neither string nor list of strings, leave unchanged + + return new_complementary_data + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 8a7a1176a..7516c7b47 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -19,14 +19,17 @@ from __future__ import annotations from copy import deepcopy from dataclasses import dataclass, field -from typing import Any +from typing import TYPE_CHECKING, Any import torch from torch import Tensor -from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature -from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.configs import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature from lerobot.types import EnvTransition, PolicyAction, TransitionKey + +if TYPE_CHECKING: + from lerobot.datasets import LeRobotDataset + from lerobot.utils.constants import ACTION from .converters import from_tensor_to_numpy, to_tensor diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index d22d8fb96..12d1f82a2 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -20,7 +20,7 @@ import numpy as np import torch from torch import Tensor -from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.configs import PipelineFeatureType, PolicyFeature from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR from .pipeline import ObservationProcessorStep, ProcessorStepRegistry diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index abfb31421..2b949d5cb 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -45,8 +45,9 @@ import torch from huggingface_hub import hf_hub_download from safetensors.torch import load_file, save_file -from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.configs import PipelineFeatureType, PolicyFeature from lerobot.types import EnvAction, EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey +from lerobot.utils.constants import HF_LEROBOT_HOME from lerobot.utils.hub import HubMixin from .converters import batch_to_transition, create_transition, transition_to_batch @@ -422,8 +423,6 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): """ if save_directory is None: # Use default directory in HF_LEROBOT_HOME - from lerobot.utils.constants import HF_LEROBOT_HOME - sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower()) save_directory = HF_LEROBOT_HOME / "processors" / sanitized_name diff --git a/src/lerobot/processor/policy_robot_bridge.py b/src/lerobot/processor/policy_robot_bridge.py index 25887d414..25d622dc2 100644 --- a/src/lerobot/processor/policy_robot_bridge.py +++ b/src/lerobot/processor/policy_robot_bridge.py @@ -19,10 +19,12 @@ from typing import Any import torch -from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature -from lerobot.processor import ActionProcessorStep, PolicyAction, ProcessorStepRegistry, RobotAction +from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.types import PolicyAction, RobotAction from lerobot.utils.constants import ACTION +from .pipeline import ActionProcessorStep, ProcessorStepRegistry + @dataclass @ProcessorStepRegistry.register("robot_action_to_policy_action_processor") diff --git a/src/lerobot/processor/relative_action_processor.py b/src/lerobot/processor/relative_action_processor.py index e00d26e98..d9f97f2c6 100644 --- a/src/lerobot/processor/relative_action_processor.py +++ b/src/lerobot/processor/relative_action_processor.py @@ -19,7 +19,7 @@ from typing import Any import torch from torch import Tensor -from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.configs import PipelineFeatureType, PolicyFeature from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import OBS_STATE diff --git a/src/lerobot/processor/rename_processor.py b/src/lerobot/processor/rename_processor.py index 6cae5921f..5ffec6868 100644 --- a/src/lerobot/processor/rename_processor.py +++ b/src/lerobot/processor/rename_processor.py @@ -17,7 +17,7 @@ from copy import deepcopy from dataclasses import dataclass, field from typing import Any -from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.configs import PipelineFeatureType, PolicyFeature from .pipeline import ObservationProcessorStep, ProcessorStepRegistry diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 0b5305dcf..a808e6127 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any import torch -from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.types import EnvTransition, RobotObservation, TransitionKey from lerobot.utils.constants import ( ACTION_TOKEN_MASK, diff --git a/src/lerobot/rl/__init__.py b/src/lerobot/rl/__init__.py new file mode 100644 index 000000000..6a7c750d3 --- /dev/null +++ b/src/lerobot/rl/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Reinforcement learning modules. + +Requires: ``pip install 'lerobot[hilserl]'`` + +Available modules (import directly):: + + from lerobot.rl.actor import ... + from lerobot.rl.learner import ... + from lerobot.rl.learner_service import ... + from lerobot.rl.buffer import ... + from lerobot.rl.eval_policy import ... + from lerobot.rl.gym_manipulator import ... +""" + +from lerobot.utils.import_utils import require_package + +require_package("grpcio", extra="hilserl", import_name="grpc") + +__all__: list[str] = [] diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 18c0ca1ea..0d785bde3 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -60,10 +60,8 @@ from torch.multiprocessing import Event, Queue from lerobot.cameras import opencv # noqa: F401 from lerobot.configs import parser from lerobot.configs.train import TrainRLServerPipelineConfig -from lerobot.policies.factory import make_policy +from lerobot.policies import make_policy from lerobot.policies.sac.modeling_sac import SACPolicy -from lerobot.rl.process import ProcessSignalHandler -from lerobot.rl.queue import get_last_item_from_queue from lerobot.robots import so_follower # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from lerobot.teleoperators.utils import TeleopEvents @@ -96,6 +94,8 @@ from .gym_manipulator import ( make_robot_env, step_env_and_process_transition, ) +from .process import ProcessSignalHandler +from .queue import get_last_item_from_queue # Main entry point diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index 68954162d..97aaa9caa 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -23,7 +23,7 @@ import torch import torch.nn.functional as F # noqa: N812 from tqdm import tqdm -from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets import LeRobotDataset from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, REWARD from lerobot.utils.transition import Transition diff --git a/src/lerobot/rl/crop_dataset_roi.py b/src/lerobot/rl/crop_dataset_roi.py index 4345fed3c..b6bde2273 100644 --- a/src/lerobot/rl/crop_dataset_roi.py +++ b/src/lerobot/rl/crop_dataset_roi.py @@ -24,7 +24,7 @@ import torch import torchvision.transforms.functional as F # type: ignore # noqa: N812 from tqdm import tqdm # type: ignore -from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets import LeRobotDataset from lerobot.utils.constants import DONE, REWARD diff --git a/src/lerobot/rl/eval_policy.py b/src/lerobot/rl/eval_policy.py index fb2504f2a..4398351c5 100644 --- a/src/lerobot/rl/eval_policy.py +++ b/src/lerobot/rl/eval_policy.py @@ -18,8 +18,8 @@ import logging from lerobot.cameras import opencv # noqa: F401 from lerobot.configs import parser from lerobot.configs.train import TrainRLServerPipelineConfig -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.policies.factory import make_policy +from lerobot.datasets import LeRobotDataset +from lerobot.policies import make_policy from lerobot.robots import ( # noqa: F401 RobotConfig, make_robot_from_config, diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index bd64d205f..b6ff7155a 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -25,9 +25,9 @@ import torch from lerobot.cameras import opencv # noqa: F401 from lerobot.configs import parser -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.envs.configs import HILSerlRobotEnvConfig -from lerobot.model.kinematics import RobotKinematics +from lerobot.datasets import LeRobotDataset +from lerobot.envs import HILSerlRobotEnvConfig +from lerobot.model import RobotKinematics from lerobot.processor import ( AddBatchDimensionProcessorStep, AddTeleopActionAsComplimentaryDataStep, @@ -50,8 +50,8 @@ from lerobot.processor import ( TransitionKey, VanillaObservationProcessorStep, create_transition, + identity_transition, ) -from lerobot.processor.converters import identity_transition from lerobot.robots import ( # noqa: F401 RobotConfig, make_robot_from_config, diff --git a/src/lerobot/rl/joint_observations_processor.py b/src/lerobot/rl/joint_observations_processor.py index 2fbcc7c46..dc677e26c 100644 --- a/src/lerobot/rl/joint_observations_processor.py +++ b/src/lerobot/rl/joint_observations_processor.py @@ -19,8 +19,8 @@ from typing import Any import torch -from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.processor.pipeline import ( +from lerobot.configs import PipelineFeatureType, PolicyFeature +from lerobot.processor import ( ObservationProcessorStep, ProcessorStepRegistry, ) diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index 2853fbcb3..073d9a65f 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -60,15 +60,18 @@ from torch.multiprocessing import Queue from torch.optim.optimizer import Optimizer from lerobot.cameras import opencv # noqa: F401 +from lerobot.common.train_utils import ( + get_step_checkpoint_dir, + load_training_state as utils_load_training_state, + save_checkpoint, + update_last_checkpoint, +) +from lerobot.common.wandb_utils import WandBLogger from lerobot.configs import parser from lerobot.configs.train import TrainRLServerPipelineConfig -from lerobot.datasets.factory import make_dataset -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.policies.factory import make_policy +from lerobot.datasets import LeRobotDataset, make_dataset +from lerobot.policies import make_policy from lerobot.policies.sac.modeling_sac import SACPolicy -from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions -from lerobot.rl.process import ProcessSignalHandler -from lerobot.rl.wandb_utils import WandBLogger from lerobot.robots import so_follower # noqa: F401 from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from lerobot.teleoperators.utils import TeleopEvents @@ -88,19 +91,15 @@ from lerobot.utils.constants import ( ) from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.random_utils import set_seed -from lerobot.utils.train_utils import ( - get_step_checkpoint_dir, - load_training_state as utils_load_training_state, - save_checkpoint, - update_last_checkpoint, -) from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device from lerobot.utils.utils import ( format_big_number, init_logging, ) +from .buffer import ReplayBuffer, concatenate_batch_transitions from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService +from .process import ProcessSignalHandler @parser.wrap() @@ -152,7 +151,7 @@ def train(cfg: TrainRLServerPipelineConfig, job_name: str | None = None): # Setup WandB logging if enabled if cfg.wandb.enable and cfg.wandb.project: - from lerobot.rl.wandb_utils import WandBLogger + from lerobot.common.wandb_utils import WandBLogger wandb_logger = WandBLogger(cfg) else: diff --git a/src/lerobot/rl/learner_service.py b/src/lerobot/rl/learner_service.py index 7ef38119b..4128cdf55 100644 --- a/src/lerobot/rl/learner_service.py +++ b/src/lerobot/rl/learner_service.py @@ -19,10 +19,11 @@ import logging import time from multiprocessing import Event, Queue -from lerobot.rl.queue import get_last_item_from_queue from lerobot.transport import services_pb2, services_pb2_grpc from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks +from .queue import get_last_item_from_queue + MAX_WORKERS = 3 # Stream parameters, send transitions and interactions SHUTDOWN_TIMEOUT = 10 diff --git a/src/lerobot/robots/__init__.py b/src/lerobot/robots/__init__.py index 1dba0f1b0..eb8b06fb8 100644 --- a/src/lerobot/robots/__init__.py +++ b/src/lerobot/robots/__init__.py @@ -17,3 +17,5 @@ from .config import RobotConfig from .robot import Robot from .utils import make_robot_from_config + +__all__ = ["Robot", "RobotConfig", "make_robot_from_config"] diff --git a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py index c48ac5934..c27398278 100644 --- a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py +++ b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py @@ -17,10 +17,10 @@ import logging from functools import cached_property -from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig from ..robot import Robot from .config_bi_openarm_follower import BiOpenArmFollowerConfig diff --git a/src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py b/src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py index ef5d70cab..9ed56aeac 100644 --- a/src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py +++ b/src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py @@ -17,9 +17,9 @@ from dataclasses import dataclass, field from lerobot.cameras import CameraConfig -from lerobot.robots.openarm_follower import OpenArmFollowerConfigBase from ..config import RobotConfig +from ..openarm_follower import OpenArmFollowerConfigBase @RobotConfig.register_subclass("bi_openarm_follower") diff --git a/src/lerobot/robots/bi_so_follower/__init__.py b/src/lerobot/robots/bi_so_follower/__init__.py index f631a14db..1d63dcb2c 100644 --- a/src/lerobot/robots/bi_so_follower/__init__.py +++ b/src/lerobot/robots/bi_so_follower/__init__.py @@ -16,3 +16,5 @@ from .bi_so_follower import BiSOFollower from .config_bi_so_follower import BiSOFollowerConfig + +__all__ = ["BiSOFollower", "BiSOFollowerConfig"] diff --git a/src/lerobot/robots/bi_so_follower/bi_so_follower.py b/src/lerobot/robots/bi_so_follower/bi_so_follower.py index ba1826e29..f592150a6 100644 --- a/src/lerobot/robots/bi_so_follower/bi_so_follower.py +++ b/src/lerobot/robots/bi_so_follower/bi_so_follower.py @@ -17,11 +17,11 @@ import logging from functools import cached_property -from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot +from ..so_follower import SOFollower, SOFollowerRobotConfig from .config_bi_so_follower import BiSOFollowerConfig logger = logging.getLogger(__name__) diff --git a/src/lerobot/robots/bi_so_follower/config_bi_so_follower.py b/src/lerobot/robots/bi_so_follower/config_bi_so_follower.py index dca74fa2d..97afbab4f 100644 --- a/src/lerobot/robots/bi_so_follower/config_bi_so_follower.py +++ b/src/lerobot/robots/bi_so_follower/config_bi_so_follower.py @@ -16,9 +16,8 @@ from dataclasses import dataclass -from lerobot.robots.so_follower import SOFollowerConfig - from ..config import RobotConfig +from ..so_follower import SOFollowerConfig @RobotConfig.register_subclass("bi_so_follower") diff --git a/src/lerobot/robots/hope_jr/__init__.py b/src/lerobot/robots/hope_jr/__init__.py index 26603ebb0..94fcf86e4 100644 --- a/src/lerobot/robots/hope_jr/__init__.py +++ b/src/lerobot/robots/hope_jr/__init__.py @@ -17,3 +17,5 @@ from .config_hope_jr import HopeJrArmConfig, HopeJrHandConfig from .hope_jr_arm import HopeJrArm from .hope_jr_hand import HopeJrHand + +__all__ = ["HopeJrArm", "HopeJrArmConfig", "HopeJrHand", "HopeJrHandConfig"] diff --git a/src/lerobot/robots/hope_jr/hope_jr_arm.py b/src/lerobot/robots/hope_jr/hope_jr_arm.py index 7f6492ef0..4918bcae3 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_arm.py +++ b/src/lerobot/robots/hope_jr/hope_jr_arm.py @@ -18,7 +18,7 @@ import logging import time from functools import cached_property -from lerobot.cameras.utils import make_cameras_from_configs +from lerobot.cameras import make_cameras_from_configs from lerobot.motors import Motor, MotorNormMode from lerobot.motors.calibration_gui import RangeFinderGUI from lerobot.motors.feetech import ( diff --git a/src/lerobot/robots/hope_jr/hope_jr_hand.py b/src/lerobot/robots/hope_jr/hope_jr_hand.py index 784804836..566628724 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_hand.py +++ b/src/lerobot/robots/hope_jr/hope_jr_hand.py @@ -18,7 +18,7 @@ import logging import time from functools import cached_property -from lerobot.cameras.utils import make_cameras_from_configs +from lerobot.cameras import make_cameras_from_configs from lerobot.motors import Motor, MotorNormMode from lerobot.motors.calibration_gui import RangeFinderGUI from lerobot.motors.feetech import ( diff --git a/src/lerobot/robots/koch_follower/__init__.py b/src/lerobot/robots/koch_follower/__init__.py index 6271c4e55..8f4435924 100644 --- a/src/lerobot/robots/koch_follower/__init__.py +++ b/src/lerobot/robots/koch_follower/__init__.py @@ -16,3 +16,5 @@ from .config_koch_follower import KochFollowerConfig from .koch_follower import KochFollower + +__all__ = ["KochFollower", "KochFollowerConfig"] diff --git a/src/lerobot/robots/koch_follower/koch_follower.py b/src/lerobot/robots/koch_follower/koch_follower.py index 44e83f6a3..3f40ac738 100644 --- a/src/lerobot/robots/koch_follower/koch_follower.py +++ b/src/lerobot/robots/koch_follower/koch_follower.py @@ -18,7 +18,7 @@ import logging import time from functools import cached_property -from lerobot.cameras.utils import make_cameras_from_configs +from lerobot.cameras import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.dynamixel import ( DynamixelMotorsBus, diff --git a/src/lerobot/robots/lekiwi/__init__.py b/src/lerobot/robots/lekiwi/__init__.py index ada2ff368..3d2191242 100644 --- a/src/lerobot/robots/lekiwi/__init__.py +++ b/src/lerobot/robots/lekiwi/__init__.py @@ -17,3 +17,5 @@ from .config_lekiwi import LeKiwiClientConfig, LeKiwiConfig from .lekiwi import LeKiwi from .lekiwi_client import LeKiwiClient + +__all__ = ["LeKiwi", "LeKiwiClient", "LeKiwiClientConfig", "LeKiwiConfig"] diff --git a/src/lerobot/robots/lekiwi/config_lekiwi.py b/src/lerobot/robots/lekiwi/config_lekiwi.py index acaf5f0ec..51fa8f03f 100644 --- a/src/lerobot/robots/lekiwi/config_lekiwi.py +++ b/src/lerobot/robots/lekiwi/config_lekiwi.py @@ -14,8 +14,8 @@ from dataclasses import dataclass, field -from lerobot.cameras.configs import CameraConfig, Cv2Rotation -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.cameras import CameraConfig, Cv2Rotation +from lerobot.cameras.opencv import OpenCVCameraConfig from ..config import RobotConfig diff --git a/src/lerobot/robots/lekiwi/lekiwi.py b/src/lerobot/robots/lekiwi/lekiwi.py index 60fac89e5..b73ebeab9 100644 --- a/src/lerobot/robots/lekiwi/lekiwi.py +++ b/src/lerobot/robots/lekiwi/lekiwi.py @@ -22,7 +22,7 @@ from typing import Any import numpy as np -from lerobot.cameras.utils import make_cameras_from_configs +from lerobot.cameras import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.feetech import ( FeetechMotorsBus, diff --git a/src/lerobot/robots/omx_follower/__init__.py b/src/lerobot/robots/omx_follower/__init__.py index db48dffe9..328ac8d80 100644 --- a/src/lerobot/robots/omx_follower/__init__.py +++ b/src/lerobot/robots/omx_follower/__init__.py @@ -19,3 +19,5 @@ from .config_omx_follower import OmxFollowerConfig from .omx_follower import OmxFollower + +__all__ = ["OmxFollower", "OmxFollowerConfig"] diff --git a/src/lerobot/robots/omx_follower/omx_follower.py b/src/lerobot/robots/omx_follower/omx_follower.py index 5d161daa2..c30eec97a 100644 --- a/src/lerobot/robots/omx_follower/omx_follower.py +++ b/src/lerobot/robots/omx_follower/omx_follower.py @@ -18,7 +18,7 @@ import logging import time from functools import cached_property -from lerobot.cameras.utils import make_cameras_from_configs +from lerobot.cameras import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.dynamixel import ( DriveMode, diff --git a/src/lerobot/robots/openarm_follower/openarm_follower.py b/src/lerobot/robots/openarm_follower/openarm_follower.py index 99e8b920b..4d1765f07 100644 --- a/src/lerobot/robots/openarm_follower/openarm_follower.py +++ b/src/lerobot/robots/openarm_follower/openarm_follower.py @@ -19,7 +19,7 @@ import time from functools import cached_property from typing import Any -from lerobot.cameras.utils import make_cameras_from_configs +from lerobot.cameras import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.damiao import DamiaoMotorsBus from lerobot.types import RobotAction, RobotObservation diff --git a/src/lerobot/robots/reachy2/__init__.py b/src/lerobot/robots/reachy2/__init__.py index 1a38fd03b..b7afd006d 100644 --- a/src/lerobot/robots/reachy2/__init__.py +++ b/src/lerobot/robots/reachy2/__init__.py @@ -23,3 +23,13 @@ from .robot_reachy2 import ( REACHY2_VEL, Reachy2Robot, ) + +__all__ = [ + "REACHY2_ANTENNAS_JOINTS", + "REACHY2_L_ARM_JOINTS", + "REACHY2_NECK_JOINTS", + "REACHY2_R_ARM_JOINTS", + "REACHY2_VEL", + "Reachy2Robot", + "Reachy2RobotConfig", +] diff --git a/src/lerobot/robots/reachy2/configuration_reachy2.py b/src/lerobot/robots/reachy2/configuration_reachy2.py index 63293e675..8cb67a495 100644 --- a/src/lerobot/robots/reachy2/configuration_reachy2.py +++ b/src/lerobot/robots/reachy2/configuration_reachy2.py @@ -14,8 +14,7 @@ from dataclasses import dataclass, field -from lerobot.cameras import CameraConfig -from lerobot.cameras.configs import ColorMode +from lerobot.cameras import CameraConfig, ColorMode from lerobot.cameras.reachy2_camera import Reachy2CameraConfig from ..config import RobotConfig diff --git a/src/lerobot/robots/reachy2/robot_reachy2.py b/src/lerobot/robots/reachy2/robot_reachy2.py index 5227a096a..ef55f71b9 100644 --- a/src/lerobot/robots/reachy2/robot_reachy2.py +++ b/src/lerobot/robots/reachy2/robot_reachy2.py @@ -18,7 +18,7 @@ from __future__ import annotations import time from typing import TYPE_CHECKING, Any -from lerobot.cameras.utils import make_cameras_from_configs +from lerobot.cameras import make_cameras_from_configs from lerobot.types import RobotAction, RobotObservation from lerobot.utils.import_utils import _reachy2_sdk_available diff --git a/src/lerobot/robots/so_follower/__init__.py b/src/lerobot/robots/so_follower/__init__.py index eea2fcbdf..45de205a8 100644 --- a/src/lerobot/robots/so_follower/__init__.py +++ b/src/lerobot/robots/so_follower/__init__.py @@ -21,3 +21,13 @@ from .config_so_follower import ( SOFollowerRobotConfig, ) from .so_follower import SO100Follower, SO101Follower, SOFollower + +__all__ = [ + "SO100Follower", + "SO100FollowerConfig", + "SO101Follower", + "SO101FollowerConfig", + "SOFollower", + "SOFollowerConfig", + "SOFollowerRobotConfig", +] diff --git a/src/lerobot/robots/so_follower/robot_kinematic_processor.py b/src/lerobot/robots/so_follower/robot_kinematic_processor.py index 2aa60e12a..8114fdc2c 100644 --- a/src/lerobot/robots/so_follower/robot_kinematic_processor.py +++ b/src/lerobot/robots/so_follower/robot_kinematic_processor.py @@ -19,8 +19,8 @@ from typing import Any import numpy as np -from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature -from lerobot.model.kinematics import RobotKinematics +from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.model import RobotKinematics from lerobot.processor import ( EnvTransition, ObservationProcessorStep, diff --git a/src/lerobot/robots/so_follower/so_follower.py b/src/lerobot/robots/so_follower/so_follower.py index ca132d102..0651f566c 100644 --- a/src/lerobot/robots/so_follower/so_follower.py +++ b/src/lerobot/robots/so_follower/so_follower.py @@ -18,7 +18,7 @@ import logging import time from functools import cached_property -from lerobot.cameras.utils import make_cameras_from_configs +from lerobot.cameras import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.feetech import ( FeetechMotorsBus, diff --git a/src/lerobot/robots/unitree_g1/gr00t_locomotion.py b/src/lerobot/robots/unitree_g1/gr00t_locomotion.py index 31166e123..12fe26073 100644 --- a/src/lerobot/robots/unitree_g1/gr00t_locomotion.py +++ b/src/lerobot/robots/unitree_g1/gr00t_locomotion.py @@ -21,7 +21,7 @@ import numpy as np import onnxruntime as ort from huggingface_hub import hf_hub_download -from lerobot.robots.unitree_g1.g1_utils import ( +from .g1_utils import ( REMOTE_AXES, REMOTE_BUTTONS, G1_29_JointIndex, diff --git a/src/lerobot/robots/unitree_g1/holosoma_locomotion.py b/src/lerobot/robots/unitree_g1/holosoma_locomotion.py index 857bb97bc..3d3bccbdc 100644 --- a/src/lerobot/robots/unitree_g1/holosoma_locomotion.py +++ b/src/lerobot/robots/unitree_g1/holosoma_locomotion.py @@ -22,7 +22,7 @@ import onnx import onnxruntime as ort from huggingface_hub import hf_hub_download -from lerobot.robots.unitree_g1.g1_utils import ( +from .g1_utils import ( REMOTE_AXES, G1_29_JointArmIndex, G1_29_JointIndex, diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index 9e373c05f..785861a5a 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -25,9 +25,14 @@ from typing import TYPE_CHECKING, Protocol, runtime_checkable import numpy as np -from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.robots.unitree_g1.g1_kinematics import G1_29_ArmIK -from lerobot.robots.unitree_g1.g1_utils import ( +from lerobot.cameras import make_cameras_from_configs +from lerobot.types import RobotAction, RobotObservation +from lerobot.utils.import_utils import _unitree_sdk_available + +from ..robot import Robot +from .config_unitree_g1 import UnitreeG1Config +from .g1_kinematics import G1_29_ArmIK +from .g1_utils import ( REMOTE_AXES, REMOTE_KEYS, G1_29_JointArmIndex, @@ -35,11 +40,6 @@ from lerobot.robots.unitree_g1.g1_utils import ( default_remote_input, make_locomotion_controller, ) -from lerobot.types import RobotAction, RobotObservation -from lerobot.utils.import_utils import _unitree_sdk_available - -from ..robot import Robot -from .config_unitree_g1 import UnitreeG1Config if TYPE_CHECKING or _unitree_sdk_available: from unitree_sdk2py.core.channel import ( @@ -127,7 +127,7 @@ class UnitreeG1(Robot): self._ChannelPublisher = _SDKChannelPublisher self._ChannelSubscriber = _SDKChannelSubscriber else: - from lerobot.robots.unitree_g1.unitree_sdk2_socket import ( + from .unitree_sdk2_socket import ( ChannelFactoryInitialize, ChannelPublisher, ChannelSubscriber, @@ -290,7 +290,7 @@ class UnitreeG1(Robot): def connect(self, calibrate: bool = True) -> None: # connect to DDS # Initialize DDS channel and simulation environment if self.config.is_simulation: - from lerobot.envs.factory import make_env + from lerobot.envs import make_env self._ChannelFactoryInitialize(0, "lo") self._env_wrapper = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True) diff --git a/src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py b/src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py index 0f1f8f8d6..4f0b787aa 100644 --- a/src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py +++ b/src/lerobot/robots/unitree_g1/unitree_sdk2_socket.py @@ -20,7 +20,7 @@ from typing import Any import zmq -from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config +from .config_unitree_g1 import UnitreeG1Config # Module-level ZMQ state mirrors the Unitree SDK's global ChannelFactory Singleton. # Only one robot connection per process is supported. diff --git a/src/lerobot/scripts/augment_dataset_quantile_stats.py b/src/lerobot/scripts/augment_dataset_quantile_stats.py index 4d80c9332..4ee99a541 100644 --- a/src/lerobot/scripts/augment_dataset_quantile_stats.py +++ b/src/lerobot/scripts/augment_dataset_quantile_stats.py @@ -44,10 +44,14 @@ from huggingface_hub import HfApi from requests import HTTPError from tqdm import tqdm -from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats -from lerobot.datasets.dataset_metadata import CODEBASE_VERSION -from lerobot.datasets.io_utils import write_stats -from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets import ( + CODEBASE_VERSION, + DEFAULT_QUANTILES, + LeRobotDataset, + aggregate_stats, + get_feature_stats, + write_stats, +) from lerobot.utils.utils import init_logging diff --git a/src/lerobot/scripts/convert_dataset_v21_to_v30.py b/src/lerobot/scripts/convert_dataset_v21_to_v30.py index 2b6dcf732..59e635712 100644 --- a/src/lerobot/scripts/convert_dataset_v21_to_v30.py +++ b/src/lerobot/scripts/convert_dataset_v21_to_v30.py @@ -51,6 +51,10 @@ import shutil from pathlib import Path from typing import Any +from lerobot.utils.import_utils import require_package + +require_package("jsonlines", extra="dataset") + import jsonlines import pandas as pd import pyarrow as pa @@ -59,8 +63,7 @@ from datasets import Dataset, Features, Image from huggingface_hub import HfApi, snapshot_download from requests import HTTPError -from lerobot.datasets.compute_stats import aggregate_stats -from lerobot.datasets.dataset_metadata import CODEBASE_VERSION +from lerobot.datasets import CODEBASE_VERSION, LeRobotDataset, aggregate_stats from lerobot.datasets.io_utils import ( cast_stats_to_numpy, get_file_size_in_mb, @@ -72,7 +75,6 @@ from lerobot.datasets.io_utils import ( write_stats, write_tasks, ) -from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, @@ -82,12 +84,11 @@ from lerobot.datasets.utils import ( LEGACY_EPISODES_PATH, LEGACY_EPISODES_STATS_PATH, LEGACY_TASKS_PATH, - flatten_dict, update_chunk_file_indices, ) from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s from lerobot.utils.constants import HF_LEROBOT_HOME -from lerobot.utils.utils import init_logging +from lerobot.utils.utils import flatten_dict, init_logging V21 = "v2.1" V30 = "v3.0" diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index 242067978..e68d7438b 100644 --- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -15,6 +15,8 @@ """ Helper to recalibrate your device (robot or teleoperator). +Requires: pip install 'lerobot[hardware]' + Example: ```shell @@ -31,8 +33,8 @@ from pprint import pformat import draccus -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 -from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401 +from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401 from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index c4b676c67..d07a2767d 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -15,6 +15,8 @@ # limitations under the License. """ Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset. +Requires: pip install 'lerobot[dataset_viz]' (includes dataset + viz extras) + Note: The last frame of the episode doesn't always correspond to a final state. That's because our datasets are composed of transition from state to state up to the antepenultimate state associated to the ultimate action to arrive in the final state. @@ -66,12 +68,11 @@ import time from pathlib import Path import numpy as np -import rerun as rr import torch import torch.utils.data import tqdm -from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets import LeRobotDataset from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD from lerobot.utils.utils import init_logging @@ -117,6 +118,11 @@ def visualize_dataset( if mode not in ["local", "distant"]: raise ValueError(mode) + from lerobot.utils.import_utils import require_package + + require_package("rerun-sdk", extra="viz", import_name="rerun") + import rerun as rr + spawn_local_viewer = mode == "local" and not save rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer) diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index db06f90c6..0cfb34325 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -17,6 +17,8 @@ """ Edit LeRobot datasets using various transformation tools. +Requires: pip install 'lerobot[dataset]' + This script allows you to delete episodes, split datasets, merge datasets, remove features, modify tasks, recompute stats, and convert image datasets to video format. When new_repo_id is specified, creates a new dataset. @@ -178,7 +180,8 @@ from pathlib import Path import draccus from lerobot.configs import parser -from lerobot.datasets.dataset_tools import ( +from lerobot.datasets import ( + LeRobotDataset, convert_image_to_video_dataset, delete_episodes, merge_datasets, @@ -187,7 +190,6 @@ from lerobot.datasets.dataset_tools import ( remove_feature, split_dataset, ) -from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.utils.constants import HF_LEROBOT_HOME from lerobot.utils.utils import init_logging diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index cd912280f..d45483d21 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -15,6 +15,9 @@ # limitations under the License. """Evaluate a policy on an environment by running rollouts and computing metrics. +Requires: pip install 'lerobot[evaluation]' plus the policy extra (e.g. lerobot[pi]) + and the environment extra (e.g. lerobot[pusht]) if evaluating in simulation. + Usage examples: You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/diffusion_pusht) @@ -71,14 +74,14 @@ from tqdm import trange from lerobot.configs import parser from lerobot.configs.eval import EvalPipelineConfig -from lerobot.envs.factory import make_env, make_env_pre_post_processors -from lerobot.envs.utils import ( +from lerobot.envs import ( check_env_attributes_and_types, close_envs, + make_env, + make_env_pre_post_processors, preprocess_observation, ) -from lerobot.policies.factory import make_policy, make_pre_post_processors -from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors from lerobot.processor import PolicyProcessorPipeline from lerobot.types import PolicyAction from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD diff --git a/src/lerobot/scripts/lerobot_find_cameras.py b/src/lerobot/scripts/lerobot_find_cameras.py index 0248a2768..72f4096da 100644 --- a/src/lerobot/scripts/lerobot_find_cameras.py +++ b/src/lerobot/scripts/lerobot_find_cameras.py @@ -37,11 +37,9 @@ from typing import Any import numpy as np from PIL import Image -from lerobot.cameras.configs import ColorMode -from lerobot.cameras.opencv.camera_opencv import OpenCVCamera -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.cameras.realsense.camera_realsense import RealSenseCamera -from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig +from lerobot.cameras import ColorMode +from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig +from lerobot.cameras.realsense import RealSenseCamera, RealSenseCameraConfig logger = logging.getLogger(__name__) diff --git a/src/lerobot/scripts/lerobot_find_joint_limits.py b/src/lerobot/scripts/lerobot_find_joint_limits.py index bcb93ba12..c4f867631 100644 --- a/src/lerobot/scripts/lerobot_find_joint_limits.py +++ b/src/lerobot/scripts/lerobot_find_joint_limits.py @@ -41,7 +41,7 @@ from dataclasses import dataclass import draccus import numpy as np -from lerobot.model.kinematics import RobotKinematics +from lerobot.model import RobotKinematics from lerobot.robots import ( # noqa: F401 RobotConfig, bi_openarm_follower, diff --git a/src/lerobot/scripts/lerobot_find_port.py b/src/lerobot/scripts/lerobot_find_port.py index e32b9cb99..93065c473 100644 --- a/src/lerobot/scripts/lerobot_find_port.py +++ b/src/lerobot/scripts/lerobot_find_port.py @@ -28,7 +28,10 @@ from pathlib import Path def find_available_ports(): - from serial.tools import list_ports # Part of pyserial library + from lerobot.utils.import_utils import require_package + + require_package("pyserial", extra="hardware", import_name="serial") + from serial.tools import list_ports if platform.system() == "Windows": # List COM ports using pyserial diff --git a/src/lerobot/scripts/lerobot_imgtransform_viz.py b/src/lerobot/scripts/lerobot_imgtransform_viz.py index bc13f0508..7cd4c782d 100644 --- a/src/lerobot/scripts/lerobot_imgtransform_viz.py +++ b/src/lerobot/scripts/lerobot_imgtransform_viz.py @@ -35,9 +35,9 @@ from pathlib import Path import draccus from torchvision.transforms import ToPILImage -from lerobot.configs.default import DatasetConfig -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.transforms import ( +from lerobot.configs import DatasetConfig +from lerobot.datasets import LeRobotDataset +from lerobot.transforms import ( ImageTransforms, ImageTransformsConfig, make_transform_from_config, diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index c58f8f103..fa92a296d 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -15,6 +15,8 @@ """ Records a dataset. Actions for the robot can be either generated by teleoperation or by a policy. +Requires: pip install 'lerobot[core_scripts]' (includes dataset + hardware + viz extras) + Example: ```shell @@ -76,24 +78,33 @@ from typing import Any import torch -from lerobot.cameras import ( # noqa: F401 - CameraConfig, # noqa: F401 +from lerobot.cameras import CameraConfig # noqa: F401 +from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401 +from lerobot.cameras.reachy2_camera import Reachy2CameraConfig # noqa: F401 +from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401 +from lerobot.common.control_utils import ( + init_keyboard_listener, + is_headless, + predict_action, + sanity_check_dataset_name, + sanity_check_dataset_robot_compatibility, +) +from lerobot.configs import PreTrainedConfig, parser +from lerobot.datasets import ( + LeRobotDataset, + VideoEncodingManager, + aggregate_pipeline_dataset_features, + create_initial_features, + safe_stop_image_writer, +) +from lerobot.policies import ( + ActionInterpolator, + PreTrainedPolicy, + make_policy, + make_pre_post_processors, + make_robot_action, ) -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 -from lerobot.cameras.reachy2_camera.configuration_reachy2_camera import Reachy2CameraConfig # noqa: F401 -from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 -from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401 -from lerobot.configs import parser -from lerobot.configs.policies import PreTrainedConfig -from lerobot.datasets.feature_utils import build_dataset_frame, combine_feature_dicts -from lerobot.datasets.image_writer import safe_stop_image_writer -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features -from lerobot.datasets.video_utils import VideoEncodingManager -from lerobot.policies.factory import make_policy, make_pre_post_processors -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.rtc import ActionInterpolator -from lerobot.policies.utils import make_robot_action from lerobot.processor import ( PolicyAction, PolicyProcessorPipeline, @@ -101,8 +112,8 @@ from lerobot.processor import ( RobotObservation, RobotProcessorPipeline, make_default_processors, + rename_stats, ) -from lerobot.processor.rename_processor import rename_stats from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -133,16 +144,10 @@ from lerobot.teleoperators import ( # noqa: F401 so_leader, unitree_g1, ) -from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop +from lerobot.teleoperators.keyboard import KeyboardTeleop from lerobot.utils.constants import ACTION, OBS_STR -from lerobot.utils.control_utils import ( - init_keyboard_listener, - is_headless, - predict_action, - sanity_check_dataset_name, - sanity_check_dataset_robot_compatibility, -) from lerobot.utils.device_utils import get_safe_torch_device +from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import ( diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index 09e7d4e8b..41d2926cc 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -15,6 +15,8 @@ """ Replays the actions of an episode from a dataset on a robot. +Requires: pip install 'lerobot[core_scripts]' (includes dataset + hardware + viz extras) + Examples: ```shell @@ -46,7 +48,7 @@ from pathlib import Path from pprint import pformat from lerobot.configs import parser -from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets import LeRobotDataset from lerobot.processor import ( make_default_robot_action_processor, ) diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index f050d572a..76157595e 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -15,6 +15,8 @@ """ Simple script to control a robot from teleoperation. +Requires: pip install 'lerobot[hardware]' + Example: ```shell @@ -56,11 +58,9 @@ import time from dataclasses import asdict, dataclass from pprint import pformat -import rerun as rr - -from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 -from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 -from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401 +from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401 +from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401 from lerobot.configs import parser from lerobot.processor import ( RobotAction, @@ -103,7 +103,7 @@ from lerobot.teleoperators import ( # noqa: F401 from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import init_logging, move_cursor_up -from lerobot.utils.visualization_utils import init_rerun, log_rerun_data +from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun @dataclass @@ -240,7 +240,7 @@ def teleoperate(cfg: TeleoperateConfig): pass finally: if cfg.display_data: - rr.rerun_shutdown() + shutdown_rerun() teleop.disconnect() robot.disconnect() diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 0a7212911..a862c640d 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -13,48 +13,53 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Train a policy. + +Requires: pip install 'lerobot[training]' (includes dataset + accelerate + wandb extras) +""" + import dataclasses import logging import time from contextlib import nullcontext from pprint import pformat -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from accelerate import Accelerator import torch -from accelerate import Accelerator from termcolor import colored from torch.optim import Optimizer from tqdm import tqdm -from lerobot.configs import parser -from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets.factory import make_dataset -from lerobot.datasets.sampler import EpisodeAwareSampler -from lerobot.datasets.utils import cycle -from lerobot.envs.factory import make_env, make_env_pre_post_processors -from lerobot.envs.utils import close_envs -from lerobot.optim.factory import make_optimizer_and_scheduler -from lerobot.policies.factory import make_policy, make_pre_post_processors -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.rl.wandb_utils import WandBLogger -from lerobot.scripts.lerobot_eval import eval_policy_all -from lerobot.utils.import_utils import register_third_party_plugins -from lerobot.utils.logging_utils import AverageMeter, MetricsTracker -from lerobot.utils.random_utils import set_seed -from lerobot.utils.train_utils import ( +from lerobot.common.train_utils import ( get_step_checkpoint_dir, get_step_identifier, load_training_state, save_checkpoint, update_last_checkpoint, ) +from lerobot.common.wandb_utils import WandBLogger +from lerobot.configs import parser +from lerobot.configs.train import TrainPipelineConfig +from lerobot.datasets import EpisodeAwareSampler, make_dataset +from lerobot.envs import close_envs, make_env, make_env_pre_post_processors +from lerobot.optim.factory import make_optimizer_and_scheduler +from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors +from lerobot.utils.import_utils import register_third_party_plugins +from lerobot.utils.logging_utils import AverageMeter, MetricsTracker +from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( + cycle, format_big_number, has_method, init_logging, inside_slurm, ) +from .lerobot_eval import eval_policy_all + def update_policy( train_metrics: MetricsTracker, @@ -62,7 +67,7 @@ def update_policy( batch: Any, optimizer: Optimizer, grad_clip_norm: float, - accelerator: Accelerator, + accelerator: "Accelerator", lr_scheduler=None, lock=None, rabc_weights_provider=None, @@ -151,7 +156,7 @@ def update_policy( @parser.wrap() -def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): +def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): """ Main function to train a policy. @@ -167,6 +172,11 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): cfg: A `TrainPipelineConfig` object containing all training configurations. accelerator: Optional Accelerator instance. If None, one will be created automatically. """ + from lerobot.utils.import_utils import require_package + + require_package("accelerate", extra="training") + from accelerate import Accelerator + cfg.validate() # Create Accelerator if not provided diff --git a/src/lerobot/scripts/lerobot_train_tokenizer.py b/src/lerobot/scripts/lerobot_train_tokenizer.py index 35c2b60cd..c821a4d54 100644 --- a/src/lerobot/scripts/lerobot_train_tokenizer.py +++ b/src/lerobot/scripts/lerobot_train_tokenizer.py @@ -60,9 +60,8 @@ if TYPE_CHECKING or _transformers_available: else: AutoProcessor = None -from lerobot.configs import parser -from lerobot.configs.types import NormalizationMode -from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.configs import NormalizationMode, parser +from lerobot.datasets import LeRobotDataset from lerobot.utils.constants import ACTION, OBS_STATE diff --git a/src/lerobot/teleoperators/__init__.py b/src/lerobot/teleoperators/__init__.py index ee508dddb..d66e4b67d 100644 --- a/src/lerobot/teleoperators/__init__.py +++ b/src/lerobot/teleoperators/__init__.py @@ -17,3 +17,5 @@ from .config import TeleoperatorConfig from .teleoperator import Teleoperator from .utils import TeleopEvents, make_teleoperator_from_config + +__all__ = ["Teleoperator", "TeleoperatorConfig", "TeleopEvents", "make_teleoperator_from_config"] diff --git a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py index b44f1fbea..624729c02 100644 --- a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py +++ b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py @@ -17,11 +17,10 @@ import logging from functools import cached_property -from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected -from ..openarm_leader import OpenArmLeader +from ..openarm_leader import OpenArmLeader, OpenArmLeaderConfig from ..teleoperator import Teleoperator from .config_bi_openarm_leader import BiOpenArmLeaderConfig diff --git a/src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py b/src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py index 39fc90add..f7ec929ed 100644 --- a/src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py +++ b/src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py @@ -16,9 +16,8 @@ from dataclasses import dataclass -from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfigBase - from ..config import TeleoperatorConfig +from ..openarm_leader import OpenArmLeaderConfigBase @TeleoperatorConfig.register_subclass("bi_openarm_leader") diff --git a/src/lerobot/teleoperators/bi_so_leader/__init__.py b/src/lerobot/teleoperators/bi_so_leader/__init__.py index b902270f9..cf78beb0c 100644 --- a/src/lerobot/teleoperators/bi_so_leader/__init__.py +++ b/src/lerobot/teleoperators/bi_so_leader/__init__.py @@ -15,3 +15,5 @@ # limitations under the License. from .bi_so_leader import BiSOLeader, BiSOLeaderConfig + +__all__ = ["BiSOLeader", "BiSOLeaderConfig"] diff --git a/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py index e84ac6f50..f2e88d20a 100644 --- a/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py +++ b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py @@ -17,10 +17,9 @@ import logging from functools import cached_property -from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected -from ..so_leader import SOLeader +from ..so_leader import SOLeader, SOLeaderTeleopConfig from ..teleoperator import Teleoperator from .config_bi_so_leader import BiSOLeaderConfig diff --git a/src/lerobot/teleoperators/bi_so_leader/config_bi_so_leader.py b/src/lerobot/teleoperators/bi_so_leader/config_bi_so_leader.py index c2f23c617..f477d0f26 100644 --- a/src/lerobot/teleoperators/bi_so_leader/config_bi_so_leader.py +++ b/src/lerobot/teleoperators/bi_so_leader/config_bi_so_leader.py @@ -16,9 +16,8 @@ from dataclasses import dataclass -from lerobot.teleoperators.so_leader import SOLeaderConfig - from ..config import TeleoperatorConfig +from ..so_leader import SOLeaderConfig @TeleoperatorConfig.register_subclass("bi_so_leader") diff --git a/src/lerobot/teleoperators/gamepad/__init__.py b/src/lerobot/teleoperators/gamepad/__init__.py index 6f9f7fbd9..3c2709dc7 100644 --- a/src/lerobot/teleoperators/gamepad/__init__.py +++ b/src/lerobot/teleoperators/gamepad/__init__.py @@ -16,3 +16,5 @@ from .configuration_gamepad import GamepadTeleopConfig from .teleop_gamepad import GamepadTeleop + +__all__ = ["GamepadTeleop", "GamepadTeleopConfig"] diff --git a/src/lerobot/teleoperators/homunculus/__init__.py b/src/lerobot/teleoperators/homunculus/__init__.py index b3c6c0bf5..ee1544e4c 100644 --- a/src/lerobot/teleoperators/homunculus/__init__.py +++ b/src/lerobot/teleoperators/homunculus/__init__.py @@ -18,3 +18,11 @@ from .config_homunculus import HomunculusArmConfig, HomunculusGloveConfig from .homunculus_arm import HomunculusArm from .homunculus_glove import HomunculusGlove from .joints_translation import homunculus_glove_to_hope_jr_hand + +__all__ = [ + "HomunculusArm", + "HomunculusArmConfig", + "HomunculusGlove", + "HomunculusGloveConfig", + "homunculus_glove_to_hope_jr_hand", +] diff --git a/src/lerobot/teleoperators/homunculus/homunculus_arm.py b/src/lerobot/teleoperators/homunculus/homunculus_arm.py index 178eed544..225235b59 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_arm.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_arm.py @@ -18,11 +18,16 @@ import logging import threading from collections import deque from pprint import pformat - -import serial +from typing import TYPE_CHECKING from lerobot.motors.motors_bus import MotorCalibration, MotorNormMode from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.import_utils import _serial_available, require_package + +if TYPE_CHECKING or _serial_available: + import serial +else: + serial = None # type: ignore[assignment] from lerobot.utils.utils import enter_pressed, move_cursor_up from ..teleoperator import Teleoperator @@ -40,6 +45,7 @@ class HomunculusArm(Teleoperator): name = "homunculus_arm" def __init__(self, config: HomunculusArmConfig): + require_package("pyserial", extra="hardware", import_name="serial") super().__init__(config) self.config = config self.serial = serial.Serial(config.port, config.baud_rate, timeout=1) diff --git a/src/lerobot/teleoperators/homunculus/homunculus_glove.py b/src/lerobot/teleoperators/homunculus/homunculus_glove.py index c4393d660..655bae726 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_glove.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_glove.py @@ -18,17 +18,22 @@ import logging import threading from collections import deque from pprint import pformat - -import serial +from typing import TYPE_CHECKING from lerobot.motors import MotorCalibration from lerobot.motors.motors_bus import MotorNormMode -from lerobot.teleoperators.homunculus.joints_translation import homunculus_glove_to_hope_jr_hand from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.import_utils import _serial_available, require_package + +if TYPE_CHECKING or _serial_available: + import serial +else: + serial = None # type: ignore[assignment] from lerobot.utils.utils import enter_pressed, move_cursor_up from ..teleoperator import Teleoperator from .config_homunculus import HomunculusGloveConfig +from .joints_translation import homunculus_glove_to_hope_jr_hand logger = logging.getLogger(__name__) @@ -66,6 +71,7 @@ class HomunculusGlove(Teleoperator): name = "homunculus_glove" def __init__(self, config: HomunculusGloveConfig): + require_package("pyserial", extra="hardware", import_name="serial") super().__init__(config) self.config = config self.serial = serial.Serial(config.port, config.baud_rate, timeout=1) diff --git a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py index 090aa7fae..0f1c7d7f1 100644 --- a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py +++ b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py @@ -23,6 +23,7 @@ from typing import Any from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.import_utils import _pynput_available from ..teleoperator import Teleoperator from ..utils import TeleopEvents @@ -32,20 +33,18 @@ from .configuration_keyboard import ( KeyboardTeleopConfig, ) -PYNPUT_AVAILABLE = True -try: - if ("DISPLAY" not in os.environ) and ("linux" in sys.platform): - logging.info("No DISPLAY set. Skipping pynput import.") - raise ImportError("pynput blocked intentionally due to no display.") - - from pynput import keyboard -except ImportError: - keyboard = None - PYNPUT_AVAILABLE = False -except Exception as e: - keyboard = None - PYNPUT_AVAILABLE = False - logging.info(f"Could not import pynput: {e}") +PYNPUT_AVAILABLE = _pynput_available +keyboard = None +if PYNPUT_AVAILABLE: + try: + if ("DISPLAY" not in os.environ) and ("linux" in sys.platform): + logging.info("No DISPLAY set. Skipping pynput import.") + PYNPUT_AVAILABLE = False + else: + from pynput import keyboard + except Exception as e: + PYNPUT_AVAILABLE = False + logging.info(f"Could not import pynput: {e}") class KeyboardTeleop(Teleoperator): diff --git a/src/lerobot/teleoperators/koch_leader/__init__.py b/src/lerobot/teleoperators/koch_leader/__init__.py index 1bf9d51db..7176649ec 100644 --- a/src/lerobot/teleoperators/koch_leader/__init__.py +++ b/src/lerobot/teleoperators/koch_leader/__init__.py @@ -16,3 +16,5 @@ from .config_koch_leader import KochLeaderConfig from .koch_leader import KochLeader + +__all__ = ["KochLeader", "KochLeaderConfig"] diff --git a/src/lerobot/teleoperators/omx_leader/__init__.py b/src/lerobot/teleoperators/omx_leader/__init__.py index 04d96d63e..259e26143 100644 --- a/src/lerobot/teleoperators/omx_leader/__init__.py +++ b/src/lerobot/teleoperators/omx_leader/__init__.py @@ -16,3 +16,5 @@ from .config_omx_leader import OmxLeaderConfig from .omx_leader import OmxLeader + +__all__ = ["OmxLeader", "OmxLeaderConfig"] diff --git a/src/lerobot/teleoperators/phone/__init__.py b/src/lerobot/teleoperators/phone/__init__.py index 2b28c1f97..2656a5014 100644 --- a/src/lerobot/teleoperators/phone/__init__.py +++ b/src/lerobot/teleoperators/phone/__init__.py @@ -16,3 +16,5 @@ from .config_phone import PhoneConfig from .teleop_phone import Phone + +__all__ = ["Phone", "PhoneConfig"] diff --git a/src/lerobot/teleoperators/phone/phone_processor.py b/src/lerobot/teleoperators/phone/phone_processor.py index c498bed7d..3d57a5a71 100644 --- a/src/lerobot/teleoperators/phone/phone_processor.py +++ b/src/lerobot/teleoperators/phone/phone_processor.py @@ -16,11 +16,12 @@ from dataclasses import dataclass, field -from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import ProcessorStepRegistry, RobotActionProcessorStep -from lerobot.teleoperators.phone.config_phone import PhoneOS from lerobot.types import RobotAction +from .config_phone import PhoneOS + @ProcessorStepRegistry.register("map_phone_action_to_robot_action") @dataclass diff --git a/src/lerobot/teleoperators/phone/teleop_phone.py b/src/lerobot/teleoperators/phone/teleop_phone.py index 221ee8083..f68843194 100644 --- a/src/lerobot/teleoperators/phone/teleop_phone.py +++ b/src/lerobot/teleoperators/phone/teleop_phone.py @@ -26,11 +26,12 @@ import hebi import numpy as np from teleop import Teleop -from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS -from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.rotation import Rotation +from ..teleoperator import Teleoperator +from .config_phone import PhoneConfig, PhoneOS + logger = logging.getLogger(__name__) diff --git a/src/lerobot/teleoperators/reachy2_teleoperator/__init__.py b/src/lerobot/teleoperators/reachy2_teleoperator/__init__.py index a07a4a6cd..aab1aec14 100644 --- a/src/lerobot/teleoperators/reachy2_teleoperator/__init__.py +++ b/src/lerobot/teleoperators/reachy2_teleoperator/__init__.py @@ -23,3 +23,13 @@ from .reachy2_teleoperator import ( REACHY2_VEL, Reachy2Teleoperator, ) + +__all__ = [ + "REACHY2_ANTENNAS_JOINTS", + "REACHY2_L_ARM_JOINTS", + "REACHY2_NECK_JOINTS", + "REACHY2_R_ARM_JOINTS", + "REACHY2_VEL", + "Reachy2Teleoperator", + "Reachy2TeleoperatorConfig", +] diff --git a/src/lerobot/teleoperators/so_leader/__init__.py b/src/lerobot/teleoperators/so_leader/__init__.py index e5aaa31b6..26ef66677 100644 --- a/src/lerobot/teleoperators/so_leader/__init__.py +++ b/src/lerobot/teleoperators/so_leader/__init__.py @@ -21,3 +21,13 @@ from .config_so_leader import ( SOLeaderTeleopConfig, ) from .so_leader import SO100Leader, SO101Leader, SOLeader + +__all__ = [ + "SO100Leader", + "SO100LeaderConfig", + "SO101Leader", + "SO101LeaderConfig", + "SOLeader", + "SOLeaderConfig", + "SOLeaderTeleopConfig", +] diff --git a/src/lerobot/teleoperators/unitree_g1/exo_calib.py b/src/lerobot/teleoperators/unitree_g1/exo_calib.py index b90e8fd7e..05f5180ff 100644 --- a/src/lerobot/teleoperators/unitree_g1/exo_calib.py +++ b/src/lerobot/teleoperators/unitree_g1/exo_calib.py @@ -22,15 +22,24 @@ and calculate arctan2 of the unit circle to get the joint angle. We then store the ellipse parameters and the zero offset for each joint to be used at runtime. """ +from __future__ import annotations + import json import logging import time from collections import deque from dataclasses import dataclass, field from pathlib import Path +from typing import TYPE_CHECKING import numpy as np -import serial + +from lerobot.utils.import_utils import _serial_available + +if TYPE_CHECKING or _serial_available: + import serial +else: + serial = None # type: ignore[assignment] logger = logging.getLogger(__name__) @@ -82,7 +91,7 @@ class ExoskeletonCalibration: } @classmethod - def from_dict(cls, data: dict) -> "ExoskeletonCalibration": + def from_dict(cls, data: dict) -> ExoskeletonCalibration: joints = [ ExoskeletonJointCalibration( name=j["name"], diff --git a/src/lerobot/teleoperators/unitree_g1/exo_serial.py b/src/lerobot/teleoperators/unitree_g1/exo_serial.py index 4f45997c0..9b1c71891 100644 --- a/src/lerobot/teleoperators/unitree_g1/exo_serial.py +++ b/src/lerobot/teleoperators/unitree_g1/exo_serial.py @@ -14,12 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import json import logging from dataclasses import dataclass from pathlib import Path +from typing import TYPE_CHECKING -import serial +from lerobot.utils.import_utils import _serial_available, require_package + +if TYPE_CHECKING or _serial_available: + import serial +else: + serial = None # type: ignore[assignment] from .exo_calib import ExoskeletonCalibration, exo_raw_to_angles, run_exo_calibration @@ -68,6 +76,7 @@ class ExoskeletonArm: calibration: ExoskeletonCalibration | None = None def __post_init__(self): + require_package("pyserial", extra="hardware", import_name="serial") if self.calibration_fpath.is_file(): self._load_calibration() diff --git a/src/lerobot/transforms/__init__.py b/src/lerobot/transforms/__init__.py new file mode 100644 index 000000000..6cf9699d0 --- /dev/null +++ b/src/lerobot/transforms/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .transforms import ( + ImageTransformConfig, + ImageTransforms, + ImageTransformsConfig, + RandomSubsetApply, + SharpnessJitter, + make_transform_from_config, +) + +__all__ = [ + "ImageTransformConfig", + "ImageTransforms", + "ImageTransformsConfig", + "RandomSubsetApply", + "SharpnessJitter", + "make_transform_from_config", +] diff --git a/src/lerobot/datasets/transforms.py b/src/lerobot/transforms/transforms.py similarity index 100% rename from src/lerobot/datasets/transforms.py rename to src/lerobot/transforms/transforms.py diff --git a/src/lerobot/transport/__init__.py b/src/lerobot/transport/__init__.py new file mode 100644 index 000000000..92ed74188 --- /dev/null +++ b/src/lerobot/transport/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +gRPC transport layer for async inference. + +Requires: ``pip install 'lerobot[grpcio-dep]'`` + +Available modules (import directly):: + + from lerobot.transport.utils import ... +""" + +from lerobot.utils.import_utils import require_package + +require_package("grpcio", extra="grpcio-dep", import_name="grpc") + +__all__: list[str] = [] diff --git a/src/lerobot/utils/__init__.py b/src/lerobot/utils/__init__.py new file mode 100644 index 000000000..ee4808353 --- /dev/null +++ b/src/lerobot/utils/__init__.py @@ -0,0 +1,65 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Public API for lightweight, base-dependency-only utilities. + +Heavy cross-cutting modules (train_utils, control_utils) have been moved +to ``lerobot.common``. ``visualization_utils`` remains here but is +intentionally NOT re-exported to avoid pulling in optional dependencies. +""" + +from .constants import ( + ACTION, + DEFAULT_FEATURES, + DONE, + IMAGENET_STATS, + OBS_ENV_STATE, + OBS_IMAGE, + OBS_IMAGES, + OBS_STATE, + OBS_STR, + REWARD, +) +from .decorators import check_if_already_connected, check_if_not_connected +from .device_utils import auto_select_torch_device, get_safe_torch_device, is_torch_device_available +from .errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from .import_utils import is_package_available, require_package + +__all__ = [ + # Constants + "ACTION", + "DEFAULT_FEATURES", + "DONE", + "IMAGENET_STATS", + "OBS_ENV_STATE", + "OBS_IMAGE", + "OBS_IMAGES", + "OBS_STATE", + "OBS_STR", + "REWARD", + # Device utilities + "auto_select_torch_device", + "get_safe_torch_device", + "is_torch_device_available", + # Import guards + "is_package_available", + "require_package", + # Decorators + "check_if_already_connected", + "check_if_not_connected", + # Errors + "DeviceAlreadyConnectedError", + "DeviceNotConnectedError", +] diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index fd10cab35..43869228d 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -75,6 +75,21 @@ default_calibration_path = HF_LEROBOT_HOME / "calibration" HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser() +# Dataset meta-features (auto-populated by the recording pipeline) +DEFAULT_FEATURES = { + "timestamp": {"dtype": "float32", "shape": (1,), "names": None}, + "frame_index": {"dtype": "int64", "shape": (1,), "names": None}, + "episode_index": {"dtype": "int64", "shape": (1,), "names": None}, + "index": {"dtype": "int64", "shape": (1,), "names": None}, + "task_index": {"dtype": "int64", "shape": (1,), "names": None}, +} + +# ImageNet normalization constants +IMAGENET_STATS = { + "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1) + "std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1) +} + # streaming datasets LOOKBACK_BACKTRACKTABLE = 100 LOOKAHEAD_BACKTRACKTABLE = 100 diff --git a/src/lerobot/utils/decorators.py b/src/lerobot/utils/decorators.py index 8fc2f9a07..75171f637 100644 --- a/src/lerobot/utils/decorators.py +++ b/src/lerobot/utils/decorators.py @@ -16,7 +16,7 @@ from functools import wraps -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from .errors import DeviceAlreadyConnectedError, DeviceNotConnectedError def check_if_not_connected(func): diff --git a/src/lerobot/utils/feature_utils.py b/src/lerobot/utils/feature_utils.py new file mode 100644 index 000000000..2a4886234 --- /dev/null +++ b/src/lerobot/utils/feature_utils.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Lightweight feature-manipulation utilities. + +These functions are intentionally kept free of heavy dependencies (e.g. the +HuggingFace ``datasets`` library) so that they can be imported from anywhere +in the codebase – including modules that are part of the *minimal* install – +without triggering the ``lerobot.datasets`` package guard. +""" + +from typing import Any + +import numpy as np + +from lerobot.configs import FeatureType, PolicyFeature + +from .constants import ACTION, DEFAULT_FEATURES, OBS_ENV_STATE, OBS_STR + + +def _validate_feature_names(features: dict[str, dict]) -> None: + """Validate that feature names do not contain invalid characters. + + Args: + features (dict): The LeRobot features dictionary. + + Raises: + ValueError: If any feature name contains '/'. + """ + invalid_features = {name: ft for name, ft in features.items() if "/" in name} + if invalid_features: + raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.") + + +def hw_to_dataset_features( + hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True +) -> dict[str, dict]: + """Convert hardware-specific features to a LeRobot dataset feature dictionary. + + This function takes a dictionary describing hardware outputs (like joint states + or camera image shapes) and formats it into the standard LeRobot feature + specification. + + Args: + hw_features (dict): Dictionary mapping feature names to their type (float for + joints) or shape (tuple for images). + prefix (str): The prefix to add to the feature keys (e.g., "observation" + or "action"). + use_video (bool): If True, image features are marked as "video", otherwise "image". + + Returns: + dict: A LeRobot features dictionary. + """ + features = {} + joint_fts = { + key: ftype + for key, ftype in hw_features.items() + if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL) + } + cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} + + if joint_fts and prefix == ACTION: + features[prefix] = { + "dtype": "float32", + "shape": (len(joint_fts),), + "names": list(joint_fts), + } + + if joint_fts and prefix == OBS_STR: + features[f"{prefix}.state"] = { + "dtype": "float32", + "shape": (len(joint_fts),), + "names": list(joint_fts), + } + + for key, shape in cam_fts.items(): + features[f"{prefix}.images.{key}"] = { + "dtype": "video" if use_video else "image", + "shape": shape, + "names": ["height", "width", "channels"], + } + + _validate_feature_names(features) + return features + + +def build_dataset_frame( + ds_features: dict[str, dict], values: dict[str, Any], prefix: str +) -> dict[str, np.ndarray]: + """Construct a single data frame from raw values based on dataset features. + + A "frame" is a dictionary containing all the data for a single timestep, + formatted as numpy arrays according to the feature specification. + + Args: + ds_features (dict): The LeRobot dataset features dictionary. + values (dict): A dictionary of raw values from the hardware/environment. + prefix (str): The prefix to filter features by (e.g., "observation" + or "action"). + + Returns: + dict: A dictionary representing a single frame of data. + """ + frame = {} + for key, ft in ds_features.items(): + if key in DEFAULT_FEATURES or not key.startswith(prefix): + continue + elif ft["dtype"] == "float32" and len(ft["shape"]) == 1: + frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32) + elif ft["dtype"] in ["image", "video"]: + frame[key] = values[key.removeprefix(f"{prefix}.images.")] + + return frame + + +def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: + """Convert dataset features to policy features. + + This function transforms the dataset's feature specification into a format + that a policy can use, classifying features by type (e.g., visual, state, + action) and ensuring correct shapes (e.g., channel-first for images). + + Args: + features (dict): The LeRobot dataset features dictionary. + + Returns: + dict: A dictionary mapping feature keys to `PolicyFeature` objects. + + Raises: + ValueError: If an image feature does not have a 3D shape. + """ + # TODO(aliberts): Implement "type" in dataset features and simplify this + policy_features = {} + for key, ft in features.items(): + shape = ft["shape"] + if ft["dtype"] in ["image", "video"]: + type = FeatureType.VISUAL + if len(shape) != 3: + raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})") + + names = ft["names"] + # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. + if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) + shape = (shape[2], shape[0], shape[1]) + elif key == OBS_ENV_STATE: + type = FeatureType.ENV + elif key.startswith(OBS_STR): + type = FeatureType.STATE + elif key.startswith(ACTION): + type = FeatureType.ACTION + else: + continue + + policy_features[key] = PolicyFeature( + type=type, + shape=shape, + ) + + return policy_features + + +def combine_feature_dicts(*dicts: dict) -> dict: + """Merge LeRobot grouped feature dicts. + + - For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape. + - For others (e.g. `observation.images.*`), the last one wins (if they are identical). + + Args: + *dicts: A variable number of LeRobot feature dictionaries to merge. + + Returns: + dict: A single merged feature dictionary. + + Raises: + ValueError: If there's a dtype mismatch for a feature being merged. + """ + out: dict = {} + for d in dicts: + for key, value in d.items(): + if not isinstance(value, dict): + out[key] = value + continue + + dtype = value.get("dtype") + shape = value.get("shape") + is_vector = ( + dtype not in ("image", "video", "string") + and isinstance(shape, tuple) + and len(shape) == 1 + and "names" in value + ) + + if is_vector: + # Initialize or retrieve the accumulating dict for this feature key + target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)}) + # Ensure consistent data types across merged entries + if "dtype" in target and dtype != target["dtype"]: + raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}") + + # Merge feature names: append only new ones to preserve order without duplicates + seen = set(target["names"]) + for n in value["names"]: + if n not in seen: + target["names"].append(n) + seen.add(n) + # Recompute the shape to reflect the updated number of features + target["shape"] = (len(target["names"]),) + else: + # For images/videos and non-1D entries: override with the latest definition + out[key] = value + return out diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index 2b26b2302..8cd24b0fa 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -69,13 +69,64 @@ def is_package_available( return package_exists +def get_safe_default_codec(): + logger = logging.getLogger(__name__) + if importlib.util.find_spec("torchcodec"): + return "torchcodec" + else: + logger.warning( + "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder" + ) + return "pyav" + + +_require_package_cache: dict[str, bool] = {} + + +def require_package(pkg_name: str, extra: str, import_name: str | None = None) -> None: + """Raise an informative ImportError if a package required by an optional feature is missing.""" + cache_key = import_name or pkg_name + if cache_key not in _require_package_cache: + _require_package_cache[cache_key] = is_package_available(pkg_name, import_name) + if not _require_package_cache[cache_key]: + raise ImportError( + f"'{pkg_name}' is required but not installed. Install it with: " + f"pip install 'lerobot[{extra}]' (or uv pip install 'lerobot[{extra}]')" + ) + + +# ── Centralised availability flags ──────────────────────────────────────── +# Every optional-dependency check lives here so that the rest of the codebase +# can simply ``from lerobot.utils.import_utils import _foo_available``. +# Do NOT define ad-hoc ``is_package_available(...)`` calls in other modules. + +# ML / training _transformers_available = is_package_available("transformers") _peft_available = is_package_available("peft") _scipy_available = is_package_available("scipy") +_diffusers_available = is_package_available("diffusers") +_torchdiffeq_available = is_package_available("torchdiffeq") + +# Hardware SDKs +_serial_available = is_package_available("pyserial", import_name="serial") +_deepdiff_available = is_package_available("deepdiff") +_dynamixel_sdk_available = is_package_available("dynamixel-sdk", import_name="dynamixel_sdk") +_feetech_sdk_available = is_package_available("feetech-servo-sdk", import_name="scservo_sdk") _reachy2_sdk_available = is_package_available("reachy2_sdk") _can_available = is_package_available("python-can", "can") _unitree_sdk_available = is_package_available("unitree-sdk2py", "unitree_sdk2py") + +# Data / serialization +_pandas_available = is_package_available("pandas") +_faker_available = is_package_available("faker") + +# Misc +_pynput_available = is_package_available("pynput") _pygame_available = is_package_available("pygame") +_qwen_vl_utils_available = is_package_available("qwen-vl-utils", import_name="qwen_vl_utils") +_wallx_deps_available = ( + _transformers_available and _peft_available and _torchdiffeq_available and _qwen_vl_utils_available +) def make_device_from_device_class(config: ChoiceRegistry) -> Any: diff --git a/src/lerobot/utils/io_utils.py b/src/lerobot/utils/io_utils.py index d70ea8b6a..e037b412c 100644 --- a/src/lerobot/utils/io_utils.py +++ b/src/lerobot/utils/io_utils.py @@ -14,21 +14,80 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -import warnings +import logging from pathlib import Path +from typing import Any -import imageio +logger = logging.getLogger(__name__) JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...] -def write_video(video_path, stacked_frames, fps): - # Filter out DeprecationWarnings raised from pkg_resources - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning - ) - imageio.mimsave(video_path, stacked_frames, fps=fps) +def load_json(fpath: Path) -> Any: + """Load data from a JSON file. + + Args: + fpath (Path): Path to the JSON file. + + Returns: + Any: The data loaded from the JSON file. + """ + with open(fpath) as f: + return json.load(f) + + +def write_json(data: dict, fpath: Path) -> None: + """Write data to a JSON file. + + Creates parent directories if they don't exist. + + Args: + data (dict): The dictionary to write. + fpath (Path): The path to the output JSON file. + """ + fpath.parent.mkdir(exist_ok=True, parents=True) + with open(fpath, "w") as f: + json.dump(data, f, indent=4, ensure_ascii=False) + + +def write_video(video_path: str | Path, stacked_frames: list, fps: int) -> None: + """Write a sequence of RGB frames to an MP4 video file using libx264. + + Args: + video_path: Output file path. + stacked_frames: List of HWC uint8 numpy arrays (RGB). + fps: Frames per second for the output video. + """ + from .import_utils import require_package + + require_package("av", extra="av-dep") + import av + + with av.open(str(video_path), mode="w") as container: + orig_height, orig_width = stacked_frames[0].shape[:2] + # yuv420p requires even dimensions; crop by one pixel if needed + height = orig_height if orig_height % 2 == 0 else orig_height - 1 + width = orig_width if orig_width % 2 == 0 else orig_width - 1 + if height != orig_height or width != orig_width: + logger.warning( + "Frame dimensions %dx%d are not even; cropping to %dx%d for yuv420p compatibility.", + orig_width, + orig_height, + width, + height, + ) + stream = container.add_stream("libx264", rate=fps) + stream.width = width + stream.height = height + stream.pix_fmt = "yuv420p" + for frame_array in stacked_frames: + if height != orig_height or width != orig_width: + frame_array = frame_array[:height, :width] + frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24") + for packet in stream.encode(frame): + container.mux(packet) + for packet in stream.encode(): + container.mux(packet) def deserialize_json_into_object[T: JsonLike](fpath: Path, obj: T) -> T: diff --git a/src/lerobot/utils/logging_utils.py b/src/lerobot/utils/logging_utils.py index 1497c0585..0ce596f55 100644 --- a/src/lerobot/utils/logging_utils.py +++ b/src/lerobot/utils/logging_utils.py @@ -16,7 +16,7 @@ from collections.abc import Callable from typing import Any -from lerobot.utils.utils import format_big_number +from .utils import format_big_number class AverageMeter: diff --git a/src/lerobot/utils/random_utils.py b/src/lerobot/utils/random_utils.py index b34d357aa..e280fc342 100644 --- a/src/lerobot/utils/random_utils.py +++ b/src/lerobot/utils/random_utils.py @@ -23,8 +23,8 @@ import numpy as np import torch from safetensors.torch import load_file, save_file -from lerobot.datasets.utils import flatten_dict, unflatten_dict -from lerobot.utils.constants import RNG_STATE +from .constants import RNG_STATE +from .utils import flatten_dict, unflatten_dict def serialize_python_rng_state() -> dict[str, torch.Tensor]: diff --git a/src/lerobot/utils/transition.py b/src/lerobot/utils/transition.py index fe3620861..a79b95151 100644 --- a/src/lerobot/utils/transition.py +++ b/src/lerobot/utils/transition.py @@ -18,7 +18,7 @@ from typing import TypedDict import torch -from lerobot.utils.constants import ACTION +from .constants import ACTION class Transition(TypedDict): diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index f6aa93bea..2574f1fa3 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -22,11 +22,12 @@ import select import subprocess import sys import time +from collections.abc import Iterator from copy import copy, deepcopy from datetime import datetime from pathlib import Path from statistics import mean -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np @@ -199,6 +200,80 @@ def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float): return days, hours, minutes, seconds +def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: + """Flatten a nested dictionary by joining keys with a separator. + + Example: + >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3} + >>> print(flatten_dict(dct)) + {'a/b': 1, 'a/c/d': 2, 'e': 3} + + Args: + d (dict): The dictionary to flatten. + parent_key (str): The base key to prepend to the keys in this level. + sep (str): The separator to use between keys. + + Returns: + dict: A flattened dictionary. + """ + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def unflatten_dict(d: dict, sep: str = "/") -> dict: + """Unflatten a dictionary with delimited keys into a nested dictionary. + + Example: + >>> flat_dct = {"a/b": 1, "a/c/d": 2, "e": 3} + >>> print(unflatten_dict(flat_dct)) + {'a': {'b': 1, 'c': {'d': 2}}, 'e': 3} + + Args: + d (dict): A dictionary with flattened keys. + sep (str): The separator used in the keys. + + Returns: + dict: A nested dictionary. + """ + outdict = {} + for key, value in d.items(): + parts = key.split(sep) + d_inner = outdict + for part in parts[:-1]: + if part not in d_inner: + d_inner[part] = {} + d_inner = d_inner[part] + d_inner[parts[-1]] = value + return outdict + + +def cycle(iterable: Any) -> Iterator[Any]: + """Create a dataloader-safe cyclical iterator. + + This is an equivalent of `itertools.cycle` but is safe for use with + PyTorch DataLoaders with multiple workers. + See https://github.com/pytorch/pytorch/issues/23900 for details. + + Args: + iterable: The iterable to cycle over. + + Yields: + Items from the iterable, restarting from the beginning when exhausted. + """ + iterator = iter(iterable) + while True: + try: + yield next(iterator) + except StopIteration: + iterator = iter(iterable) + + class SuppressProgressBars: """ Context manager to suppress progress bars. @@ -212,14 +287,22 @@ class SuppressProgressBars: """ def __enter__(self): - from datasets.utils.logging import disable_progress_bar + try: + from datasets.utils.logging import disable_progress_bar - disable_progress_bar() + disable_progress_bar() + except ImportError: + logging.getLogger(__name__).debug( + "SuppressProgressBars is a no-op because 'datasets' is not installed." + ) def __exit__(self, exc_type, exc_val, exc_tb): - from datasets.utils.logging import enable_progress_bar + try: + from datasets.utils.logging import enable_progress_bar - enable_progress_bar() + enable_progress_bar() + except ImportError: + pass class TimerManager: diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index 782358c9e..d9d5bf6b5 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -16,11 +16,11 @@ import numbers import os import numpy as np -import rerun as rr from lerobot.types import RobotAction, RobotObservation from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR +from .import_utils import require_package def init_rerun( @@ -34,6 +34,10 @@ def init_rerun( ip: Optional IP for connecting to a Rerun server. port: Optional port for connecting to a Rerun server. """ + + require_package("rerun-sdk", extra="viz", import_name="rerun") + import rerun as rr + batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000") os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size rr.init(session_name) @@ -44,6 +48,15 @@ def init_rerun( rr.spawn(memory_limit=memory_limit) +def shutdown_rerun() -> None: + """Shuts down the Rerun SDK gracefully.""" + + require_package("rerun-sdk", extra="viz", import_name="rerun") + import rerun as rr + + rr.rerun_shutdown() + + def _is_scalar(x): return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or ( isinstance(x, np.ndarray) and x.ndim == 0 @@ -73,6 +86,10 @@ def log_rerun_data( action: An optional dictionary containing action data to log. compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality. """ + + require_package("rerun-sdk", extra="viz", import_name="rerun") + import rerun as rr + if observation: for k, v in observation.items(): if v is None: diff --git a/tests/artifacts/image_transforms/save_image_transforms_to_safetensors.py b/tests/artifacts/image_transforms/save_image_transforms_to_safetensors.py index ce15d16fd..182058563 100644 --- a/tests/artifacts/image_transforms/save_image_transforms_to_safetensors.py +++ b/tests/artifacts/image_transforms/save_image_transforms_to_safetensors.py @@ -19,7 +19,7 @@ import torch from safetensors.torch import save_file from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.transforms import ( +from lerobot.transforms import ( ImageTransformConfig, ImageTransforms, ImageTransformsConfig, diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py index 7359f6169..ffb3efd03 100644 --- a/tests/artifacts/policies/save_policy_to_safetensors.py +++ b/tests/artifacts/policies/save_policy_to_safetensors.py @@ -21,7 +21,7 @@ from safetensors.torch import save_file from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets.factory import make_dataset +from lerobot.datasets import make_dataset from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors from lerobot.utils.constants import OBS_STR diff --git a/tests/async_inference/test_e2e.py b/tests/async_inference/test_e2e.py index 54ca29b48..8c5861a91 100644 --- a/tests/async_inference/test_e2e.py +++ b/tests/async_inference/test_e2e.py @@ -35,8 +35,10 @@ from concurrent import futures import pytest import torch -# Skip entire module if grpc is not available +# Skip entire module if required deps are not available pytest.importorskip("grpc") +pytest.importorskip("serial", reason="pyserial is required (install lerobot[hardware])") +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") # ----------------------------------------------------------------------------- # End-to-end test diff --git a/tests/async_inference/test_helpers.py b/tests/async_inference/test_helpers.py index a9e53200d..17fca2a44 100644 --- a/tests/async_inference/test_helpers.py +++ b/tests/async_inference/test_helpers.py @@ -16,10 +16,14 @@ import math import pickle import time -import numpy as np -import torch +import pytest -from lerobot.async_inference.helpers import ( +pytest.importorskip("grpc") + +import numpy as np # noqa: E402 +import torch # noqa: E402 + +from lerobot.async_inference.helpers import ( # noqa: E402 FPSTracker, TimedAction, TimedObservation, diff --git a/tests/async_inference/test_policy_server.py b/tests/async_inference/test_policy_server.py index c3ee37c8f..5cec2051c 100644 --- a/tests/async_inference/test_policy_server.py +++ b/tests/async_inference/test_policy_server.py @@ -24,7 +24,7 @@ import torch from lerobot.configs.types import PolicyFeature from lerobot.utils.constants import OBS_STATE -from tests.utils import require_package +from tests.utils import skip_if_package_missing # ----------------------------------------------------------------------------- # Test fixtures @@ -62,7 +62,7 @@ class MockPolicy: @pytest.fixture -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def policy_server(): """Fresh `PolicyServer` instance with a stubbed-out policy model.""" # Import only when the test actually runs (after decorator check) diff --git a/tests/async_inference/test_robot_client.py b/tests/async_inference/test_robot_client.py index d7ef5b350..e2d840358 100644 --- a/tests/async_inference/test_robot_client.py +++ b/tests/async_inference/test_robot_client.py @@ -25,8 +25,10 @@ from queue import Queue import pytest import torch -# Skip entire module if grpc is not available +# Skip entire module if required deps are not available pytest.importorskip("grpc") +pytest.importorskip("serial", reason="pyserial is required (install lerobot[hardware])") +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") # ----------------------------------------------------------------------------- # Test fixtures diff --git a/tests/conftest.py b/tests/conftest.py index 2fcf878ab..cadeaf0d3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,24 +17,39 @@ import traceback import pytest -from serial import SerialException from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.utils.import_utils import is_package_available from tests.utils import DEVICE -# Import fixture modules as plugins +# Import fixture modules as plugins. +# Fixtures that depend on optional packages are only registered when those packages are available, +# so that tests can be collected and run even with a minimal install. pytest_plugins = [ - "tests.fixtures.dataset_factories", - "tests.fixtures.files", - "tests.fixtures.hub", "tests.fixtures.optimizers", ] +if is_package_available("datasets"): + pytest_plugins += [ + "tests.fixtures.dataset_factories", + "tests.fixtures.files", + "tests.fixtures.hub", + ] + def pytest_collection_finish(): print(f"\nTesting with {DEVICE=}") +def _is_serial_exception(exc: Exception) -> bool: + """Check if an exception is a SerialException without requiring pyserial.""" + if not is_package_available("pyserial", import_name="serial"): + return False + from serial import SerialException + + return isinstance(exc, SerialException) + + def _check_component_availability(component_type, available_components, make_component): """Generic helper to check if a hardware component is available""" if component_type not in available_components: @@ -53,7 +68,7 @@ def _check_component_availability(component_type, available_components, make_com if isinstance(e, ModuleNotFoundError): print(f"\nInstall module '{e.name}'") - elif isinstance(e, SerialException): + elif _is_serial_exception(e): print("\nNo physical device detected.") elif isinstance(e, ValueError) and "camera_index" in str(e): print("\nNo physical camera detected.") diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py index 4ac7e001a..b74299311 100644 --- a/tests/datasets/test_aggregate.py +++ b/tests/datasets/test_aggregate.py @@ -16,7 +16,11 @@ from unittest.mock import patch -import datasets +import pytest + +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + +import datasets # noqa: E402 import torch from lerobot.datasets.aggregate import aggregate_datasets diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py index 973c80bd8..70ba42378 100644 --- a/tests/datasets/test_compute_stats.py +++ b/tests/datasets/test_compute_stats.py @@ -18,6 +18,8 @@ from unittest.mock import patch import numpy as np import pytest +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from lerobot.datasets.compute_stats import ( RunningQuantileStats, _assert_type_and_shape, diff --git a/tests/datasets/test_dataset_metadata.py b/tests/datasets/test_dataset_metadata.py index 3f3971e15..6db41d05c 100644 --- a/tests/datasets/test_dataset_metadata.py +++ b/tests/datasets/test_dataset_metadata.py @@ -20,6 +20,8 @@ import json import numpy as np import pytest +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata from lerobot.datasets.utils import INFO_PATH from tests.fixtures.constants import DEFAULT_FPS, DUMMY_ROBOT_TYPE diff --git a/tests/datasets/test_dataset_reader.py b/tests/datasets/test_dataset_reader.py index 4c8a8b23f..bbe858b5d 100644 --- a/tests/datasets/test_dataset_reader.py +++ b/tests/datasets/test_dataset_reader.py @@ -15,8 +15,12 @@ # limitations under the License. """Contract tests for DatasetReader.""" +import pytest + +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from lerobot.datasets.dataset_reader import DatasetReader -from lerobot.datasets.video_utils import get_safe_default_codec +from lerobot.utils.import_utils import get_safe_default_codec # ── Loading ────────────────────────────────────────────────────────── diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py index 5ed7aa1a3..0b0862f00 100644 --- a/tests/datasets/test_dataset_tools.py +++ b/tests/datasets/test_dataset_tools.py @@ -21,6 +21,8 @@ import numpy as np import pytest import torch +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from lerobot.datasets.dataset_tools import ( add_features, delete_episodes, diff --git a/tests/datasets/test_dataset_utils.py b/tests/datasets/test_dataset_utils.py index 874099e2b..bf705ba81 100644 --- a/tests/datasets/test_dataset_utils.py +++ b/tests/datasets/test_dataset_utils.py @@ -16,13 +16,16 @@ import pytest import torch -from datasets import Dataset + +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + +from datasets import Dataset # noqa: E402 from huggingface_hub import DatasetCard -from lerobot.datasets.feature_utils import combine_feature_dicts from lerobot.datasets.io_utils import hf_transform_to_torch from lerobot.datasets.utils import create_lerobot_dataset_card from lerobot.utils.constants import ACTION, OBS_IMAGES +from lerobot.utils.feature_utils import combine_feature_dicts def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]: diff --git a/tests/datasets/test_dataset_writer.py b/tests/datasets/test_dataset_writer.py index 8c6ee68bd..8d2bc0373 100644 --- a/tests/datasets/test_dataset_writer.py +++ b/tests/datasets/test_dataset_writer.py @@ -23,6 +23,8 @@ import pytest import torch from PIL import Image +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from lerobot.datasets.dataset_writer import _encode_video_worker from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import DEFAULT_IMAGE_PATH diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index d4e9e88b8..6d4c41aaa 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -21,21 +21,22 @@ from pathlib import Path import numpy as np import pytest import torch + +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from huggingface_hub import HfApi from PIL import Image from safetensors.torch import load_file from torchvision.transforms import v2 -import lerobot from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets.factory import make_dataset -from lerobot.datasets.feature_utils import get_hf_features_from_features, hw_to_dataset_features +from lerobot.datasets import make_dataset +from lerobot.datasets.feature_utils import get_hf_features_from_features from lerobot.datasets.image_writer import image_array_to_pil_image from lerobot.datasets.io_utils import hf_transform_to_torch from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.multi_dataset import MultiLeRobotDataset -from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, @@ -46,7 +47,9 @@ from lerobot.datasets.video_utils import VALID_VIDEO_CODECS from lerobot.envs.factory import make_env_config from lerobot.policies.factory import make_policy_config from lerobot.robots import make_robot_from_config +from lerobot.transforms import ImageTransforms, ImageTransformsConfig from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD +from lerobot.utils.feature_utils import hw_to_dataset_features from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID from tests.mocks.mock_robot import MockRobotConfig from tests.utils import require_x86_64_kernel @@ -493,13 +496,28 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory): # - [ ] remove old tests +ENV_DATASET_POLICY_TRIPLETS = [ + ("aloha", dataset, "act") + for dataset in [ + "lerobot/aloha_sim_insertion_human", + "lerobot/aloha_sim_insertion_scripted", + "lerobot/aloha_sim_transfer_cube_human", + "lerobot/aloha_sim_transfer_cube_scripted", + "lerobot/aloha_sim_insertion_human_image", + "lerobot/aloha_sim_insertion_scripted_image", + "lerobot/aloha_sim_transfer_cube_human_image", + "lerobot/aloha_sim_transfer_cube_scripted_image", + ] +] + [ + ("pusht", dataset, policy) + for dataset in ["lerobot/pusht", "lerobot/pusht_image"] + for policy in ["diffusion", "vqbet"] +] + + @pytest.mark.parametrize( "env_name, repo_id, policy_name", - # Single dataset - lerobot.env_dataset_policy_triplets, - # Multi-dataset - # TODO after fix multidataset - # + [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")], + ENV_DATASET_POLICY_TRIPLETS, ) def test_factory(env_name, repo_id, policy_name): """ diff --git a/tests/datasets/test_delta_timestamps.py b/tests/datasets/test_delta_timestamps.py index 8d9529f68..e4e5cf4f3 100644 --- a/tests/datasets/test_delta_timestamps.py +++ b/tests/datasets/test_delta_timestamps.py @@ -13,6 +13,8 @@ # limitations under the License. import pytest +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from lerobot.datasets.feature_utils import ( check_delta_timestamps, get_delta_indices, diff --git a/tests/datasets/test_image_transforms.py b/tests/datasets/test_image_transforms.py index ef7e8c395..4310274e4 100644 --- a/tests/datasets/test_image_transforms.py +++ b/tests/datasets/test_image_transforms.py @@ -21,7 +21,13 @@ from safetensors.torch import load_file from torchvision.transforms import v2 from torchvision.transforms.v2 import functional as F # noqa: N812 -from lerobot.datasets.transforms import ( +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + +from lerobot.scripts.lerobot_imgtransform_viz import ( + save_all_transforms, + save_each_transform, +) +from lerobot.transforms import ( ImageTransformConfig, ImageTransforms, ImageTransformsConfig, @@ -29,10 +35,6 @@ from lerobot.datasets.transforms import ( SharpnessJitter, make_transform_from_config, ) -from lerobot.scripts.lerobot_imgtransform_viz import ( - save_all_transforms, - save_each_transform, -) from lerobot.utils.random_utils import seeded_context from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR from tests.utils import require_x86_64_kernel diff --git a/tests/datasets/test_image_writer.py b/tests/datasets/test_image_writer.py index 55419473f..916b8f017 100644 --- a/tests/datasets/test_image_writer.py +++ b/tests/datasets/test_image_writer.py @@ -20,6 +20,8 @@ import numpy as np import pytest from PIL import Image +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from lerobot.datasets.image_writer import ( AsyncImageWriter, image_array_to_pil_image, diff --git a/tests/datasets/test_lerobot_dataset.py b/tests/datasets/test_lerobot_dataset.py index 5c3c24f99..49efa84d9 100644 --- a/tests/datasets/test_lerobot_dataset.py +++ b/tests/datasets/test_lerobot_dataset.py @@ -25,6 +25,8 @@ from unittest.mock import Mock import pytest import torch +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + import lerobot.datasets.dataset_metadata as dataset_metadata_module import lerobot.datasets.lerobot_dataset as lerobot_dataset_module from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata diff --git a/tests/datasets/test_quantiles_dataset_integration.py b/tests/datasets/test_quantiles_dataset_integration.py index 4df7fab06..b0e8a0e3c 100644 --- a/tests/datasets/test_quantiles_dataset_integration.py +++ b/tests/datasets/test_quantiles_dataset_integration.py @@ -19,6 +19,8 @@ import numpy as np import pytest +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from lerobot.datasets.lerobot_dataset import LeRobotDataset diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index 18fb1c8ac..8bb3be8e9 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -17,7 +17,10 @@ import logging import pytest import torch -from datasets import Dataset + +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + +from datasets import Dataset # noqa: E402 from lerobot.datasets.io_utils import ( hf_transform_to_torch, diff --git a/tests/datasets/test_streaming.py b/tests/datasets/test_streaming.py index 1bd4c1787..db167f657 100644 --- a/tests/datasets/test_streaming.py +++ b/tests/datasets/test_streaming.py @@ -17,6 +17,8 @@ import numpy as np import pytest import torch +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.utils import safe_shard from lerobot.utils.constants import ACTION diff --git a/tests/datasets/test_streaming_video_encoder.py b/tests/datasets/test_streaming_video_encoder.py index f7e63b06f..8b7a1540f 100644 --- a/tests/datasets/test_streaming_video_encoder.py +++ b/tests/datasets/test_streaming_video_encoder.py @@ -20,10 +20,13 @@ import queue import threading from unittest.mock import patch -import av import numpy as np import pytest +pytest.importorskip("av", reason="av is required (install lerobot[dataset])") + +import av # noqa: E402 + from lerobot.datasets.video_utils import ( VALID_VIDEO_CODECS, StreamingVideoEncoder, diff --git a/tests/datasets/test_subtask_dataset.py b/tests/datasets/test_subtask_dataset.py index f80a6c72d..bb77b77d1 100644 --- a/tests/datasets/test_subtask_dataset.py +++ b/tests/datasets/test_subtask_dataset.py @@ -23,8 +23,11 @@ These tests verify that: - Subtask handling gracefully handles missing data """ -import pandas as pd import pytest + +pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])") + +import pandas as pd # noqa: E402 import torch from lerobot.datasets.lerobot_dataset import LeRobotDataset diff --git a/tests/datasets/test_visualize_dataset.py b/tests/datasets/test_visualize_dataset.py index 8e92ec82e..3bf94e6cb 100644 --- a/tests/datasets/test_visualize_dataset.py +++ b/tests/datasets/test_visualize_dataset.py @@ -15,6 +15,8 @@ # limitations under the License. import pytest +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from lerobot.scripts.lerobot_dataset_viz import visualize_dataset diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index 910c275eb..c6a0b077d 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -23,7 +23,6 @@ import torch from gymnasium.envs.registration import register, registry as gym_registry from gymnasium.utils.env_checker import check_env -import lerobot from lerobot.configs.types import PolicyFeature from lerobot.envs.configs import EnvConfig from lerobot.envs.factory import make_env, make_env_config @@ -36,9 +35,16 @@ from tests.utils import require_env OBS_TYPES = ["state", "pixels", "pixels_agent_pos"] +ENV_TASK_PAIRS = [ + ("aloha", "AlohaInsertion-v0"), + ("aloha", "AlohaTransferCube-v0"), + ("pusht", "PushT-v0"), +] +AVAILABLE_ENVS = ["aloha", "pusht"] + @pytest.mark.parametrize("obs_type", OBS_TYPES) -@pytest.mark.parametrize("env_name, env_task", lerobot.env_task_pairs) +@pytest.mark.parametrize("env_name, env_task", ENV_TASK_PAIRS) @require_env def test_env(env_name, env_task, obs_type): if env_name == "aloha" and obs_type == "state": @@ -51,7 +57,7 @@ def test_env(env_name, env_task, obs_type): env.close() -@pytest.mark.parametrize("env_name", lerobot.available_envs) +@pytest.mark.parametrize("env_name", AVAILABLE_ENVS) @require_env def test_factory(env_name): cfg = make_env_config(env_name) diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index 5ecb52145..e068484b0 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -34,12 +34,12 @@ from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_DATA_PATH, - DEFAULT_FEATURES, DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_VIDEO_PATH, - flatten_dict, ) from lerobot.datasets.video_utils import encode_video_frames +from lerobot.utils.constants import DEFAULT_FEATURES +from lerobot.utils.utils import flatten_dict from tests.fixtures.constants import ( DEFAULT_FPS, DUMMY_CAMERA_FEATURES, diff --git a/tests/mocks/mock_dynamixel.py b/tests/mocks/mock_dynamixel.py index 84026fc34..4a424a97c 100644 --- a/tests/mocks/mock_dynamixel.py +++ b/tests/mocks/mock_dynamixel.py @@ -21,10 +21,25 @@ import dynamixel_sdk as dxl import serial from mock_serial.mock_serial import MockSerial -from lerobot.motors.dynamixel.dynamixel import _split_into_byte_chunks - from .mock_serial_patch import WaitableStub + +def _split_into_byte_chunks(value: int, length: int) -> list[int]: + """Split an integer into a list of byte-sized integers (little-endian).""" + if length == 1: + data = [value] + elif length == 2: + data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)] + elif length == 4: + data = [ + dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), + dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)), + dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)), + dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)), + ] + return data + + # https://emanual.robotis.com/docs/en/dxl/crc/ DXL_CRC_TABLE = [ 0x0000, 0x8005, 0x800F, 0x000A, 0x801B, 0x001E, 0x0014, 0x8011, diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py index 33cbc41d6..6e303b56b 100644 --- a/tests/mocks/mock_feetech.py +++ b/tests/mocks/mock_feetech.py @@ -21,11 +21,27 @@ import scservo_sdk as scs import serial from mock_serial import MockSerial -from lerobot.motors.feetech.feetech import _split_into_byte_chunks, patch_setPacketTimeout +from lerobot.motors.feetech.feetech import patch_setPacketTimeout from .mock_serial_patch import WaitableStub +def _split_into_byte_chunks(value: int, length: int) -> list[int]: + """Split an integer into a list of byte-sized integers (little-endian).""" + if length == 1: + data = [value] + elif length == 2: + data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)] + elif length == 4: + data = [ + scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), + scs.SCS_HIBYTE(scs.SCS_LOWORD(value)), + scs.SCS_LOBYTE(scs.SCS_HIWORD(value)), + scs.SCS_HIBYTE(scs.SCS_HIWORD(value)), + ] + return data + + class MockFeetechPacket(abc.ABC): @classmethod def build(cls, scs_id: int, params: list[int], length: int, *args, **kwargs) -> bytes: diff --git a/tests/mocks/mock_motors_bus.py b/tests/mocks/mock_motors_bus.py index a499dbfee..9cb27224f 100644 --- a/tests/mocks/mock_motors_bus.py +++ b/tests/mocks/mock_motors_bus.py @@ -17,6 +17,7 @@ from lerobot.motors.motors_bus import ( Motor, MotorsBus, + MotorsBusBase, ) DUMMY_CTRL_TABLE_1 = { @@ -122,6 +123,12 @@ class MockPortHandler: class MockMotorsBus(MotorsBus): + """Mock motor bus that bypasses hardware dependency checks. + + Inherits from MotorsBus (alias for SerialMotorsBus) for type compatibility, + but calls MotorsBusBase.__init__ directly to skip the pyserial/deepdiff guards. + """ + available_baudrates = [500_000, 1_000_000] default_timeout = 1000 model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE @@ -132,8 +139,13 @@ class MockMotorsBus(MotorsBus): normalized_data = ["Present_Position", "Goal_Position"] def __init__(self, port: str, motors: dict[str, Motor]): - super().__init__(port, motors) + # Skip SerialMotorsBus.__init__ (which guards pyserial/deepdiff) + # and call the base class directly — this mock never touches real serial. + MotorsBusBase.__init__(self, port, motors) self.port_handler = MockPortHandler(port) + self._id_to_model_dict = {m.id: m.model for m in self.motors.values()} + self._id_to_name_dict = {m.id: name for name, m in self.motors.items()} + self._model_nb_to_model_dict = {v: k for k, v in self.model_number_table.items()} def _assert_protocol_is_compatible(self, instruction_name): ... def _handshake(self): ... diff --git a/tests/motors/test_motors_bus.py b/tests/motors/test_motors_bus.py index 27650ef1b..60ecaeabb 100644 --- a/tests/motors/test_motors_bus.py +++ b/tests/motors/test_motors_bus.py @@ -19,6 +19,8 @@ from unittest.mock import patch import pytest +pytest.importorskip("serial", reason="pyserial is required (install lerobot[hardware])") + from lerobot.motors.motors_bus import ( Motor, MotorNormMode, diff --git a/tests/optim/test_schedulers.py b/tests/optim/test_schedulers.py index 224613416..5d6687102 100644 --- a/tests/optim/test_schedulers.py +++ b/tests/optim/test_schedulers.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch from packaging.version import Version from torch.optim.lr_scheduler import LambdaLR @@ -23,8 +24,10 @@ from lerobot.optim.schedulers import ( save_scheduler_state, ) from lerobot.utils.constants import SCHEDULER_STATE +from lerobot.utils.import_utils import is_package_available +@pytest.mark.skipif(not is_package_available("diffusers"), reason="diffusers not installed") def test_diffuser_scheduler(optimizer): config = DiffuserSchedulerConfig(name="cosine", num_warmup_steps=5) scheduler = config.build(optimizer, num_training_steps=100) diff --git a/tests/policies/groot/test_groot_lerobot.py b/tests/policies/groot/test_groot_lerobot.py index e299a34e2..788935d4f 100644 --- a/tests/policies/groot/test_groot_lerobot.py +++ b/tests/policies/groot/test_groot_lerobot.py @@ -31,7 +31,7 @@ from lerobot.policies.groot.processor_groot import make_groot_pre_post_processor from lerobot.processor import PolicyProcessorPipeline from lerobot.types import PolicyAction from lerobot.utils.device_utils import auto_select_torch_device -from tests.utils import require_cuda # noqa: E402 +from tests.utils import require_cuda pytest.importorskip("transformers") diff --git a/tests/policies/hilserl/test_modeling_classifier.py b/tests/policies/hilserl/test_modeling_classifier.py index a62ef3ebb..6d262c01b 100644 --- a/tests/policies/hilserl/test_modeling_classifier.py +++ b/tests/policies/hilserl/test_modeling_classifier.py @@ -21,7 +21,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput from lerobot.utils.constants import OBS_IMAGE, REWARD -from tests.utils import require_package +from tests.utils import skip_if_package_missing def test_classifier_output(): @@ -37,7 +37,7 @@ def test_classifier_output(): ) -@require_package("transformers") +@skip_if_package_missing("transformers") @pytest.mark.skip( reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" ) @@ -81,7 +81,7 @@ def test_binary_classifier_with_default_params(): assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values" -@require_package("transformers") +@skip_if_package_missing("transformers") @pytest.mark.skip( reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" ) @@ -123,7 +123,7 @@ def test_multiclass_classifier(): assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values" -@require_package("transformers") +@skip_if_package_missing("transformers") @pytest.mark.skip( reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" ) @@ -138,7 +138,7 @@ def test_default_device(): assert p.device == torch.device("cpu") -@require_package("transformers") +@skip_if_package_missing("transformers") @pytest.mark.skip( reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" ) diff --git a/tests/policies/smolvla/test_smolvla_rtc.py b/tests/policies/smolvla/test_smolvla_rtc.py index 53e74d940..8c64c8a6c 100644 --- a/tests/policies/smolvla/test_smolvla_rtc.py +++ b/tests/policies/smolvla/test_smolvla_rtc.py @@ -19,15 +19,15 @@ import pytest import torch -from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402 -from lerobot.policies.factory import make_pre_post_processors # noqa: E402 -from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402 +from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule +from lerobot.policies.factory import make_pre_post_processors +from lerobot.policies.rtc.configuration_rtc import RTCConfig from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig # noqa: F401 -from lerobot.utils.random_utils import set_seed # noqa: E402 -from tests.utils import require_cuda, require_package # noqa: E402 +from lerobot.utils.random_utils import set_seed +from tests.utils import require_cuda, skip_if_package_missing -@require_package("transformers") +@skip_if_package_missing("transformers") @require_cuda def test_smolvla_rtc_initialization(): from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401 @@ -65,7 +65,7 @@ def test_smolvla_rtc_initialization(): print("✓ SmolVLA RTC initialization: Test passed") -@require_package("transformers") +@skip_if_package_missing("transformers") @require_cuda def test_smolvla_rtc_initialization_without_rtc_config(): from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401 @@ -87,7 +87,7 @@ def test_smolvla_rtc_initialization_without_rtc_config(): print("✓ SmolVLA RTC initialization without RTC config: Test passed") -@require_package("transformers") +@skip_if_package_missing("transformers") @require_cuda @pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights") def test_smolvla_rtc_inference_with_prev_chunk(): @@ -170,7 +170,7 @@ def test_smolvla_rtc_inference_with_prev_chunk(): print("✓ SmolVLA RTC inference with prev_chunk: Test passed") -@require_package("transformers") +@skip_if_package_missing("transformers") @require_cuda @pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights") def test_smolvla_rtc_inference_without_prev_chunk(): @@ -244,7 +244,7 @@ def test_smolvla_rtc_inference_without_prev_chunk(): print("✓ SmolVLA RTC inference without prev_chunk: Test passed") -@require_package("transformers") +@skip_if_package_missing("transformers") @require_cuda @pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights") def test_smolvla_rtc_validation_rules(): diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 4a8d3ab72..2d50446fe 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -20,16 +20,16 @@ from pathlib import Path import einops import pytest import torch + +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from packaging import version from safetensors.torch import load_file -from lerobot import available_policies from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.datasets.factory import make_dataset -from lerobot.datasets.feature_utils import dataset_to_policy_features -from lerobot.datasets.utils import cycle +from lerobot.datasets import make_dataset from lerobot.envs.factory import make_env, make_env_config from lerobot.envs.utils import close_envs, preprocess_observation from lerobot.optim.factory import make_optimizer_and_scheduler @@ -45,10 +45,23 @@ from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE +from lerobot.utils.feature_utils import dataset_to_policy_features +from lerobot.utils.import_utils import is_package_available from lerobot.utils.random_utils import seeded_context +from lerobot.utils.utils import cycle from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel +# Policies that require optional heavy dependencies to instantiate +_POLICY_REQUIRED_PACKAGES: dict[str, tuple[str, ...]] = { + "diffusion": ("diffusers",), +} + +_ALL_POLICIES = ["act", "diffusion", "tdmpc", "vqbet"] +AVAILABLE_POLICIES = [ + p for p in _ALL_POLICIES if all(is_package_available(pkg) for pkg in _POLICY_REQUIRED_PACKAGES.get(p, ())) +] + @pytest.fixture def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_path): @@ -84,7 +97,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p return ds_meta -@pytest.mark.parametrize("policy_name", available_policies) +@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES) def test_get_policy_and_config_classes(policy_name: str): """Check that the correct policy and config classes are returned.""" policy_cls = get_policy_class(policy_name) @@ -255,7 +268,7 @@ def test_act_backbone_lr(): assert len(optimizer.param_groups[1]["params"]) == 20 -@pytest.mark.parametrize("policy_name", available_policies) +@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES) def test_policy_defaults(dummy_dataset_metadata, policy_name: str): """Check that the policy can be instantiated with defaults.""" policy_cls = get_policy_class(policy_name) @@ -268,7 +281,7 @@ def test_policy_defaults(dummy_dataset_metadata, policy_name: str): policy_cls(policy_cfg) -@pytest.mark.parametrize("policy_name", available_policies) +@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES) def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: str): policy_cls = get_policy_class(policy_name) policy_cfg = make_policy_config(policy_name) @@ -343,7 +356,7 @@ def test_multikey_construction(multikey: bool): # to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference # that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass. # Thus, we deactivate this test for now. - ( + pytest.param( "lerobot/pusht", "diffusion", { @@ -352,6 +365,7 @@ def test_multikey_construction(multikey: bool): "down_dims": [128, 256, 512], }, "", + marks=pytest.mark.skipif(not is_package_available("diffusers"), reason="diffusers not installed"), ), ("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""), ( diff --git a/tests/policies/test_relative_actions.py b/tests/policies/test_relative_actions.py index 64c2ee9c4..15ef0a31b 100644 --- a/tests/policies/test_relative_actions.py +++ b/tests/policies/test_relative_actions.py @@ -10,6 +10,8 @@ import numpy as np import pytest import torch +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.datasets.compute_stats import get_feature_stats from lerobot.datasets.lerobot_dataset import LeRobotDataset diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index a335c2b4b..2c41de22c 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -25,6 +25,8 @@ import pytest import torch import torch.nn as nn +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features from lerobot.processor import ( diff --git a/tests/processor/test_smolvla_processor.py b/tests/processor/test_smolvla_processor.py index 227b1dc35..2aa7d4bdf 100644 --- a/tests/processor/test_smolvla_processor.py +++ b/tests/processor/test_smolvla_processor.py @@ -22,14 +22,12 @@ import torch from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig -from lerobot.policies.smolvla.processor_smolvla import ( - SmolVLANewLineProcessor, - make_smolvla_pre_post_processors, -) +from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors from lerobot.processor import ( AddBatchDimensionProcessorStep, DeviceProcessorStep, EnvTransition, + NewLineTaskProcessorStep, NormalizerProcessorStep, ProcessorStep, RenameObservationsProcessorStep, @@ -108,7 +106,7 @@ def test_make_smolvla_processor_basic(): assert len(preprocessor.steps) == 6 assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep) assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep) - assert isinstance(preprocessor.steps[2], SmolVLANewLineProcessor) + assert isinstance(preprocessor.steps[2], NewLineTaskProcessorStep) # Step 3 would be TokenizerProcessorStep but it's mocked assert isinstance(preprocessor.steps[4], DeviceProcessorStep) assert isinstance(preprocessor.steps[5], NormalizerProcessorStep) @@ -120,8 +118,8 @@ def test_make_smolvla_processor_basic(): def test_smolvla_newline_processor_single_task(): - """Test SmolVLANewLineProcessor with single task string.""" - processor = SmolVLANewLineProcessor() + """Test NewLineTaskProcessorStep with single task string.""" + processor = NewLineTaskProcessorStep() # Test with task that doesn't have newline transition = create_transition(complementary_data={"task": "test task"}) @@ -135,8 +133,8 @@ def test_smolvla_newline_processor_single_task(): def test_smolvla_newline_processor_list_of_tasks(): - """Test SmolVLANewLineProcessor with list of task strings.""" - processor = SmolVLANewLineProcessor() + """Test NewLineTaskProcessorStep with list of task strings.""" + processor = NewLineTaskProcessorStep() # Test with list of tasks tasks = ["task1", "task2\n", "task3"] @@ -147,8 +145,8 @@ def test_smolvla_newline_processor_list_of_tasks(): def test_smolvla_newline_processor_empty_transition(): - """Test SmolVLANewLineProcessor with empty transition.""" - processor = SmolVLANewLineProcessor() + """Test NewLineTaskProcessorStep with empty transition.""" + processor = NewLineTaskProcessorStep() # Test with no complementary_data transition = create_transition() @@ -361,8 +359,8 @@ def test_smolvla_processor_without_stats(): def test_smolvla_newline_processor_state_dict(): - """Test SmolVLANewLineProcessor state dict methods.""" - processor = SmolVLANewLineProcessor() + """Test NewLineTaskProcessorStep state dict methods.""" + processor = NewLineTaskProcessorStep() # Test state_dict (should be empty) state = processor.state_dict() @@ -380,8 +378,8 @@ def test_smolvla_newline_processor_state_dict(): def test_smolvla_newline_processor_transform_features(): - """Test SmolVLANewLineProcessor transform_features method.""" - processor = SmolVLANewLineProcessor() + """Test NewLineTaskProcessorStep transform_features method.""" + processor = NewLineTaskProcessorStep() # Test transform_features features = { diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index 76dce2537..5708e6e81 100644 --- a/tests/processor/test_tokenizer_processor.py +++ b/tests/processor/test_tokenizer_processor.py @@ -36,7 +36,7 @@ from lerobot.utils.constants import ( OBS_LANGUAGE_SUBTASK_TOKENS, OBS_STATE, ) -from tests.utils import require_package +from tests.utils import skip_if_package_missing class MockTokenizer: @@ -94,7 +94,7 @@ def mock_tokenizer(): return MockTokenizer(vocab_size=100) -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_basic_tokenization(mock_auto_tokenizer): """Test basic string tokenization functionality.""" @@ -129,7 +129,7 @@ def test_basic_tokenization(mock_auto_tokenizer): assert attention_mask.shape == (10,) -@require_package("transformers") +@skip_if_package_missing("transformers") def test_basic_tokenization_with_tokenizer_object(): """Test basic string tokenization functionality using tokenizer object directly.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -161,7 +161,7 @@ def test_basic_tokenization_with_tokenizer_object(): assert attention_mask.shape == (10,) -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_list_of_strings_tokenization(mock_auto_tokenizer): """Test tokenization of a list of strings.""" @@ -189,7 +189,7 @@ def test_list_of_strings_tokenization(mock_auto_tokenizer): assert attention_mask.shape == (2, 8) -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_tuple_of_strings_tokenization(mock_auto_tokenizer): """Test tokenization of a tuple of strings (returned by VectorEnv.call()).""" @@ -213,7 +213,7 @@ def test_tuple_of_strings_tokenization(mock_auto_tokenizer): assert attention_mask.shape == (2, 8) -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_custom_keys(mock_auto_tokenizer): """Test using custom task_key.""" @@ -239,7 +239,7 @@ def test_custom_keys(mock_auto_tokenizer): assert tokens.shape == (5,) -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_none_complementary_data(mock_auto_tokenizer): """Test handling of None complementary_data.""" @@ -255,7 +255,7 @@ def test_none_complementary_data(mock_auto_tokenizer): processor(transition) -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_missing_task_key(mock_auto_tokenizer): """Test handling when task key is missing.""" @@ -270,7 +270,7 @@ def test_missing_task_key(mock_auto_tokenizer): processor(transition) -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_none_task_value(mock_auto_tokenizer): """Test handling when task value is None.""" @@ -285,7 +285,7 @@ def test_none_task_value(mock_auto_tokenizer): processor(transition) -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_unsupported_task_type(mock_auto_tokenizer): """Test handling of unsupported task types.""" @@ -307,14 +307,14 @@ def test_unsupported_task_type(mock_auto_tokenizer): processor(transition) -@require_package("transformers") +@skip_if_package_missing("transformers") def test_no_tokenizer_error(): """Test that ValueError is raised when neither tokenizer nor tokenizer_name is provided.""" with pytest.raises(ValueError, match="Either 'tokenizer' or 'tokenizer_name' must be provided"): TokenizerProcessorStep() -@require_package("transformers") +@skip_if_package_missing("transformers") def test_invalid_tokenizer_name_error(): """Test that error is raised when invalid tokenizer_name is provided.""" with patch("lerobot.processor.tokenizer_processor.AutoTokenizer") as mock_auto_tokenizer: @@ -325,7 +325,7 @@ def test_invalid_tokenizer_name_error(): TokenizerProcessorStep(tokenizer_name="invalid-tokenizer") -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_get_config_with_tokenizer_name(mock_auto_tokenizer): """Test configuration serialization when using tokenizer_name.""" @@ -354,7 +354,7 @@ def test_get_config_with_tokenizer_name(mock_auto_tokenizer): assert config == expected -@require_package("transformers") +@skip_if_package_missing("transformers") def test_get_config_with_tokenizer_object(): """Test configuration serialization when using tokenizer object.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -382,7 +382,7 @@ def test_get_config_with_tokenizer_object(): assert "tokenizer_name" not in config -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_state_dict_methods(mock_auto_tokenizer): """Test state_dict and load_state_dict methods.""" @@ -399,7 +399,7 @@ def test_state_dict_methods(mock_auto_tokenizer): processor.load_state_dict({}) -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_reset_method(mock_auto_tokenizer): """Test reset method.""" @@ -412,7 +412,7 @@ def test_reset_method(mock_auto_tokenizer): processor.reset() -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_integration_with_robot_processor(mock_auto_tokenizer): """Test integration with RobotProcessor.""" @@ -449,7 +449,7 @@ def test_integration_with_robot_processor(mock_auto_tokenizer): assert torch.equal(result[TransitionKey.ACTION], transition[TransitionKey.ACTION]) -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_save_and_load_pretrained_with_tokenizer_name(mock_auto_tokenizer): """Test saving and loading processor with tokenizer_name.""" @@ -489,7 +489,7 @@ def test_save_and_load_pretrained_with_tokenizer_name(mock_auto_tokenizer): assert f"{OBS_LANGUAGE}.attention_mask" in result[TransitionKey.OBSERVATION] -@require_package("transformers") +@skip_if_package_missing("transformers") def test_save_and_load_pretrained_with_tokenizer_object(): """Test saving and loading processor with tokenizer object using overrides.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -528,7 +528,7 @@ def test_save_and_load_pretrained_with_tokenizer_object(): assert f"{OBS_LANGUAGE}.attention_mask" in result[TransitionKey.OBSERVATION] -@require_package("transformers") +@skip_if_package_missing("transformers") def test_registry_functionality(): """Test that the processor is properly registered.""" from lerobot.processor import ProcessorStepRegistry @@ -541,7 +541,7 @@ def test_registry_functionality(): assert retrieved_class is TokenizerProcessorStep -@require_package("transformers") +@skip_if_package_missing("transformers") def test_features_basic(): """Test basic feature contract functionality.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -574,7 +574,7 @@ def test_features_basic(): assert attention_mask_feature.shape == (128,) -@require_package("transformers") +@skip_if_package_missing("transformers") def test_features_with_custom_max_length(): """Test feature contract with custom max_length.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -596,7 +596,7 @@ def test_features_with_custom_max_length(): assert attention_mask_feature.shape == (64,) -@require_package("transformers") +@skip_if_package_missing("transformers") def test_features_existing_features(): """Test feature contract when tokenized features already exist.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -618,7 +618,7 @@ def test_features_existing_features(): assert output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"].shape == (100,) -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_tokenization_parameters(mock_auto_tokenizer): """Test that tokenization parameters are correctly passed to tokenizer.""" @@ -666,7 +666,7 @@ def test_tokenization_parameters(mock_auto_tokenizer): assert tracking_tokenizer.last_call_kwargs["return_tensors"] == "pt" -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_preserves_other_complementary_data(mock_auto_tokenizer): """Test that other complementary data fields are preserved.""" @@ -701,7 +701,7 @@ def test_preserves_other_complementary_data(mock_auto_tokenizer): assert f"{OBS_LANGUAGE}.attention_mask" in observation -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_deterministic_tokenization(mock_auto_tokenizer): """Test that tokenization is deterministic for the same input.""" @@ -729,7 +729,7 @@ def test_deterministic_tokenization(mock_auto_tokenizer): assert torch.equal(attention_mask1, attention_mask2) -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_empty_string_task(mock_auto_tokenizer): """Test handling of empty string task.""" @@ -753,7 +753,7 @@ def test_empty_string_task(mock_auto_tokenizer): assert tokens.shape == (8,) -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_very_long_task(mock_auto_tokenizer): """Test handling of very long task strings.""" @@ -779,7 +779,7 @@ def test_very_long_task(mock_auto_tokenizer): assert attention_mask.shape == (5,) -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_custom_padding_side(mock_auto_tokenizer): """Test using custom padding_side parameter.""" @@ -833,7 +833,7 @@ def test_custom_padding_side(mock_auto_tokenizer): assert tracking_tokenizer.padding_side_calls[-1] == "right" -@require_package("transformers") +@skip_if_package_missing("transformers") def test_device_detection_cpu(): """Test that tokenized tensors stay on CPU when other tensors are on CPU.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -857,7 +857,7 @@ def test_device_detection_cpu(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@require_package("transformers") +@skip_if_package_missing("transformers") def test_device_detection_cuda(): """Test that tokenized tensors are moved to CUDA when other tensors are on CUDA.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -882,7 +882,7 @@ def test_device_detection_cuda(): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") -@require_package("transformers") +@skip_if_package_missing("transformers") def test_device_detection_multi_gpu(): """Test that tokenized tensors match device in multi-GPU setup.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -906,7 +906,7 @@ def test_device_detection_multi_gpu(): assert attention_mask.device == device -@require_package("transformers") +@skip_if_package_missing("transformers") def test_device_detection_no_tensors(): """Test that tokenized tensors stay on CPU when no other tensors exist.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -928,7 +928,7 @@ def test_device_detection_no_tensors(): assert attention_mask.device.type == "cpu" -@require_package("transformers") +@skip_if_package_missing("transformers") def test_device_detection_mixed_devices(): """Test device detection when tensors are on different devices (uses first found).""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -956,7 +956,7 @@ def test_device_detection_mixed_devices(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@require_package("transformers") +@skip_if_package_missing("transformers") def test_device_detection_from_action(): """Test that device is detected from action tensor when no observation tensors exist.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -979,7 +979,7 @@ def test_device_detection_from_action(): assert attention_mask.device.type == "cuda" -@require_package("transformers") +@skip_if_package_missing("transformers") def test_device_detection_preserves_dtype(): """Test that device detection doesn't affect dtype of tokenized tensors.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1000,7 +1000,7 @@ def test_device_detection_preserves_dtype(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_integration_with_device_processor(mock_auto_tokenizer): """Test that TokenizerProcessorStep works correctly with DeviceProcessorStep in pipeline.""" @@ -1039,7 +1039,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@require_package("transformers") +@skip_if_package_missing("transformers") def test_simulated_accelerate_scenario(): """Test scenario simulating Accelerate with data already on GPU.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1077,7 +1077,7 @@ def test_simulated_accelerate_scenario(): # ============================================================================= -@require_package("transformers") +@skip_if_package_missing("transformers") def test_get_subtask_missing_key(): """Test get_subtask returns None when subtask key is missing from complementary_data.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1093,7 +1093,7 @@ def test_get_subtask_missing_key(): assert result is None -@require_package("transformers") +@skip_if_package_missing("transformers") def test_get_subtask_none_value(): """Test get_subtask returns None when subtask value is None.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1109,7 +1109,7 @@ def test_get_subtask_none_value(): assert result is None -@require_package("transformers") +@skip_if_package_missing("transformers") def test_get_subtask_none_complementary_data(): """Test get_subtask returns None when complementary_data is None.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1125,7 +1125,7 @@ def test_get_subtask_none_complementary_data(): assert result is None -@require_package("transformers") +@skip_if_package_missing("transformers") def test_get_subtask_string(): """Test get_subtask returns list with single string when subtask is a string.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1143,7 +1143,7 @@ def test_get_subtask_string(): assert len(result) == 1 -@require_package("transformers") +@skip_if_package_missing("transformers") def test_get_subtask_list_of_strings(): """Test get_subtask returns the list when subtask is already a list of strings.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1162,7 +1162,7 @@ def test_get_subtask_list_of_strings(): assert len(result) == 3 -@require_package("transformers") +@skip_if_package_missing("transformers") def test_get_subtask_unsupported_type_integer(): """Test get_subtask returns None when subtask is an unsupported type (integer).""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1178,7 +1178,7 @@ def test_get_subtask_unsupported_type_integer(): assert result is None -@require_package("transformers") +@skip_if_package_missing("transformers") def test_get_subtask_unsupported_type_mixed_list(): """Test get_subtask returns None when subtask is a list with mixed types.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1194,7 +1194,7 @@ def test_get_subtask_unsupported_type_mixed_list(): assert result is None -@require_package("transformers") +@skip_if_package_missing("transformers") def test_get_subtask_unsupported_type_dict(): """Test get_subtask returns None when subtask is a dictionary.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1210,7 +1210,7 @@ def test_get_subtask_unsupported_type_dict(): assert result is None -@require_package("transformers") +@skip_if_package_missing("transformers") def test_get_subtask_empty_string(): """Test get_subtask with empty string returns list with empty string.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1226,7 +1226,7 @@ def test_get_subtask_empty_string(): assert result == [""] -@require_package("transformers") +@skip_if_package_missing("transformers") def test_get_subtask_empty_list(): """Test get_subtask with empty list returns empty list.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1247,7 +1247,7 @@ def test_get_subtask_empty_list(): # ============================================================================= -@require_package("transformers") +@skip_if_package_missing("transformers") def test_subtask_tokenization_when_present(): """Test that subtask is tokenized and added to observation when present.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1276,7 +1276,7 @@ def test_subtask_tokenization_when_present(): assert subtask_attention_mask.dtype == torch.bool -@require_package("transformers") +@skip_if_package_missing("transformers") def test_subtask_tokenization_not_added_when_none(): """Test that subtask tokens are NOT added to observation when subtask is None.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1300,7 +1300,7 @@ def test_subtask_tokenization_not_added_when_none(): assert f"{OBS_LANGUAGE}.attention_mask" in observation -@require_package("transformers") +@skip_if_package_missing("transformers") def test_subtask_tokenization_not_added_when_subtask_value_is_none(): """Test that subtask tokens are NOT added when subtask value is explicitly None.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1320,7 +1320,7 @@ def test_subtask_tokenization_not_added_when_subtask_value_is_none(): assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation -@require_package("transformers") +@skip_if_package_missing("transformers") def test_subtask_tokenization_list_of_strings(): """Test subtask tokenization with list of strings.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1346,7 +1346,7 @@ def test_subtask_tokenization_list_of_strings(): assert subtask_attention_mask.shape == (2, 8) -@require_package("transformers") +@skip_if_package_missing("transformers") def test_subtask_tokenization_device_cpu(): """Test that subtask tokens are on CPU when other tensors are on CPU.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1372,7 +1372,7 @@ def test_subtask_tokenization_device_cpu(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@require_package("transformers") +@skip_if_package_missing("transformers") def test_subtask_tokenization_device_cuda(): """Test that subtask tokens are moved to CUDA when other tensors are on CUDA.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1397,7 +1397,7 @@ def test_subtask_tokenization_device_cuda(): assert subtask_attention_mask.device.type == "cuda" -@require_package("transformers") +@skip_if_package_missing("transformers") def test_subtask_tokenization_preserves_other_observation_data(): """Test that subtask tokenization preserves other observation data.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1423,7 +1423,7 @@ def test_subtask_tokenization_preserves_other_observation_data(): assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation -@require_package("transformers") +@skip_if_package_missing("transformers") def test_subtask_attention_mask_dtype(): """Test that subtask attention mask has correct dtype (bool).""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1442,7 +1442,7 @@ def test_subtask_attention_mask_dtype(): assert subtask_attention_mask.dtype == torch.bool -@require_package("transformers") +@skip_if_package_missing("transformers") def test_subtask_tokenization_deterministic(): """Test that subtask tokenization is deterministic for the same input.""" mock_tokenizer = MockTokenizer(vocab_size=100) @@ -1467,7 +1467,7 @@ def test_subtask_tokenization_deterministic(): assert torch.equal(subtask_mask1, subtask_mask2) -@require_package("transformers") +@skip_if_package_missing("transformers") @patch("lerobot.processor.tokenizer_processor.AutoTokenizer") def test_subtask_tokenization_integration_with_pipeline(mock_auto_tokenizer): """Test subtask tokenization works correctly with DataProcessorPipeline.""" @@ -1504,7 +1504,7 @@ def test_subtask_tokenization_integration_with_pipeline(mock_auto_tokenizer): assert observation[OBS_LANGUAGE_SUBTASK_TOKENS].shape == (6,) -@require_package("transformers") +@skip_if_package_missing("transformers") def test_subtask_not_added_for_unsupported_types(): """Test that subtask tokens are not added when subtask has unsupported type.""" mock_tokenizer = MockTokenizer(vocab_size=100) diff --git a/tests/rl/test_actor.py b/tests/rl/test_actor.py index 54e4d2870..08746ec91 100644 --- a/tests/rl/test_actor.py +++ b/tests/rl/test_actor.py @@ -19,11 +19,14 @@ from unittest.mock import patch import pytest import torch + +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from torch.multiprocessing import Event, Queue from lerobot.utils.constants import OBS_STR from lerobot.utils.transition import Transition -from tests.utils import require_package +from tests.utils import skip_if_package_missing def create_learner_service_stub(): @@ -64,7 +67,7 @@ def close_service_stub(channel, server): server.stop(None) -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_establish_learner_connection_success(): from lerobot.rl.actor import establish_learner_connection @@ -81,7 +84,7 @@ def test_establish_learner_connection_success(): close_service_stub(channel, server) -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_establish_learner_connection_failure(): from lerobot.rl.actor import establish_learner_connection @@ -100,7 +103,7 @@ def test_establish_learner_connection_failure(): close_service_stub(channel, server) -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_push_transitions_to_transport_queue(): from lerobot.rl.actor import push_transitions_to_transport_queue from lerobot.transport.utils import bytes_to_transitions @@ -135,7 +138,7 @@ def test_push_transitions_to_transport_queue(): assert_transitions_equal(deserialized_transition, transitions[i]) -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_transitions_stream(): from lerobot.rl.actor import transitions_stream @@ -167,7 +170,7 @@ def test_transitions_stream(): assert streamed_data[2].data == b"transition_data_3" -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_interactions_stream(): from lerobot.rl.actor import interactions_stream diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py index e13862d82..3978dfffd 100644 --- a/tests/rl/test_actor_learner.py +++ b/tests/rl/test_actor_learner.py @@ -20,13 +20,16 @@ import time import pytest import torch + +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from torch.multiprocessing import Event, Queue from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.utils.constants import OBS_STR from lerobot.utils.transition import Transition -from tests.utils import require_package +from tests.utils import skip_if_package_missing def create_test_transitions(count: int = 3) -> list[Transition]: @@ -88,7 +91,7 @@ def cfg(): return cfg -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") @pytest.mark.timeout(10) # force cross-platform watchdog def test_end_to_end_transitions_flow(cfg): from lerobot.rl.actor import ( @@ -150,7 +153,7 @@ def test_end_to_end_transitions_flow(cfg): assert_transitions_equal(transition, input_transitions[i]) -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") @pytest.mark.timeout(10) def test_end_to_end_interactions_flow(cfg): from lerobot.rl.actor import ( @@ -223,7 +226,7 @@ def test_end_to_end_interactions_flow(cfg): assert received == expected -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") @pytest.mark.parametrize("data_size", ["small", "large"]) @pytest.mark.timeout(10) def test_end_to_end_parameters_flow(cfg, data_size): diff --git a/tests/rl/test_learner_service.py b/tests/rl/test_learner_service.py index d967388f0..f1023f0f3 100644 --- a/tests/rl/test_learner_service.py +++ b/tests/rl/test_learner_service.py @@ -20,7 +20,7 @@ from multiprocessing import Event, Queue import pytest -from tests.utils import require_package # our gRPC servicer class +from tests.utils import skip_if_package_missing # our gRPC servicer class @pytest.fixture(scope="function") @@ -39,7 +39,7 @@ def learner_service_stub(): close_learner_service_stub(channel, server) -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def create_learner_service_stub( shutdown_event: Event, parameters_queue: Queue, @@ -75,7 +75,7 @@ def create_learner_service_stub( return services_pb2_grpc.LearnerServiceStub(channel), channel, server -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def close_learner_service_stub(channel, server): channel.close() server.stop(None) @@ -91,7 +91,7 @@ def test_ready_method(learner_service_stub): assert response == services_pb2.Empty() -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_send_interactions(): from lerobot.transport import services_pb2 @@ -135,7 +135,7 @@ def test_send_interactions(): assert interactions == [b"123", b"4", b"5", b"678"] -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_send_transitions(): from lerobot.transport import services_pb2 @@ -181,7 +181,7 @@ def test_send_transitions(): assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"] -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_send_transitions_empty_stream(): from lerobot.transport import services_pb2 @@ -209,7 +209,7 @@ def test_send_transitions_empty_stream(): assert transitions_queue.empty() -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") @pytest.mark.timeout(10) # force cross-platform watchdog def test_stream_parameters(): import time @@ -267,7 +267,7 @@ def test_stream_parameters(): assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1) -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_stream_parameters_with_shutdown(): from lerobot.transport import services_pb2 @@ -319,7 +319,7 @@ def test_stream_parameters_with_shutdown(): assert received_params == [b"param_batch_1", b"stop"] -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_stream_parameters_waits_and_retries_on_empty_queue(): import threading diff --git a/tests/rl/test_queue.py b/tests/rl/test_queue.py index b6716fbd6..cf3d6cdca 100644 --- a/tests/rl/test_queue.py +++ b/tests/rl/test_queue.py @@ -18,9 +18,13 @@ import threading import time from queue import Queue -from torch.multiprocessing import Queue as TorchMPQueue +import pytest -from lerobot.rl.queue import get_last_item_from_queue +pytest.importorskip("grpc") + +from torch.multiprocessing import Queue as TorchMPQueue # noqa: E402 + +from lerobot.rl.queue import get_last_item_from_queue # noqa: E402 def test_get_last_item_single_item(): diff --git a/tests/scripts/test_edit_dataset_parsing.py b/tests/scripts/test_edit_dataset_parsing.py index 4d758ae35..83ed5a78b 100644 --- a/tests/scripts/test_edit_dataset_parsing.py +++ b/tests/scripts/test_edit_dataset_parsing.py @@ -17,6 +17,8 @@ import draccus import pytest +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from lerobot.scripts.lerobot_edit_dataset import ( ConvertImageToVideoConfig, DeleteEpisodesConfig, diff --git a/tests/test_available.py b/tests/test_available.py index 19e39b2b6..7dd1cdacb 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -13,48 +13,50 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import importlib -import gymnasium as gym +from unittest.mock import patch + import pytest import lerobot -from lerobot.policies.act.modeling_act import ACTPolicy -from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy -from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy -from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy -from tests.utils import require_env +from lerobot.utils.import_utils import _require_package_cache, require_package -@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs) -@require_env -def test_available_env_task(env_name: str, task_name: list): - """ - This test verifies that all environments listed in `lerobot/__init__.py` can - be successfully imported — if they're installed — and that their - `available_tasks_per_env` are valid. - """ - package_name = f"gym_{env_name}" - importlib.import_module(package_name) - gym_handle = f"{package_name}/{task_name}" - assert gym_handle in gym.envs.registry, gym_handle +def test_version(): + """Verify the package exposes a version string.""" + assert isinstance(lerobot.__version__, str) + assert len(lerobot.__version__) > 0 -def test_available_policies(): - """ - This test verifies that the class attribute `name` for all policies is - consistent with those listed in `lerobot/__init__.py`. - """ - policy_classes = [ACTPolicy, DiffusionPolicy, TDMPCPolicy, VQBeTPolicy] - policies = [pol_cls.name for pol_cls in policy_classes] - assert set(policies) == set(lerobot.available_policies), policies +def test_require_package_raises_when_missing(): + """require_package raises ImportError with install instructions when a package is missing.""" + with patch("lerobot.utils.import_utils.is_package_available", return_value=False): + # Clear the cache so the mock takes effect + _require_package_cache.clear() + try: + with pytest.raises(ImportError, match=r"pip install 'lerobot\[dataset\]'"): + require_package("datasets", extra="dataset") + finally: + _require_package_cache.clear() -def test_print(): - print(lerobot.available_envs) - print(lerobot.available_tasks_per_env) - print(lerobot.available_datasets) - print(lerobot.available_datasets_per_env) - print(lerobot.available_real_world_datasets) - print(lerobot.available_policies) - print(lerobot.available_policies_per_env) +def test_require_package_passes_when_available(): + """require_package does not raise when the package is installed.""" + with patch("lerobot.utils.import_utils.is_package_available", return_value=True): + _require_package_cache.clear() + try: + # Should not raise + require_package("datasets", extra="dataset") + finally: + _require_package_cache.clear() + + +def test_require_package_error_message_includes_uv(): + """Error message includes both pip and uv install commands.""" + with patch("lerobot.utils.import_utils.is_package_available", return_value=False): + _require_package_cache.clear() + try: + with pytest.raises(ImportError, match=r"uv pip install"): + require_package("grpcio", extra="async", import_name="grpc") + finally: + _require_package_cache.clear() diff --git a/tests/test_cli_peft.py b/tests/test_cli_peft.py index 42fef4741..5d653ee6b 100644 --- a/tests/test_cli_peft.py +++ b/tests/test_cli_peft.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch import pytest from safetensors.torch import load_file -from .utils import require_package +from .utils import skip_if_package_missing # Skip this entire module in CI pytestmark = pytest.mark.skipif( @@ -37,7 +37,7 @@ def resolve_model_id_for_peft_training(policy_type): @pytest.mark.parametrize("policy_type", ["smolvla"]) -@require_package("peft") +@skip_if_package_missing("peft") def test_peft_training_push_to_hub_works(policy_type, tmp_path): """Ensure that push to hub stores PEFT only the adapter, not the full model weights.""" output_dir = tmp_path / f"output_{policy_type}" @@ -76,7 +76,7 @@ def test_peft_training_push_to_hub_works(policy_type, tmp_path): @pytest.mark.parametrize("policy_type", ["smolvla"]) -@require_package("peft") +@skip_if_package_missing("peft") def test_peft_training_works(policy_type, tmp_path): """Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works.""" output_dir = tmp_path / f"output_{policy_type}" @@ -125,7 +125,7 @@ def test_peft_training_works(policy_type, tmp_path): @pytest.mark.parametrize("policy_type", ["smolvla"]) -@require_package("peft") +@skip_if_package_missing("peft") def test_peft_training_params_are_fewer(policy_type, tmp_path): """Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works.""" output_dir = tmp_path / f"output_{policy_type}" @@ -176,7 +176,7 @@ def dummy_make_robot_from_config(*args, **kwargs): @pytest.mark.parametrize("policy_type", ["smolvla"]) -@require_package("peft") +@skip_if_package_missing("peft") def test_peft_record_loads_policy(policy_type, tmp_path): """Train a policy with PEFT and attempt to load it with `lerobot-record`.""" from peft import PeftModel diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index 772588467..28e91a149 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -16,6 +16,11 @@ from unittest.mock import patch +import pytest + +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") +pytest.importorskip("deepdiff", reason="deepdiff is required (install lerobot[hardware])") + from lerobot.scripts.lerobot_calibrate import CalibrateConfig, calibrate from lerobot.scripts.lerobot_record import DatasetRecordConfig, RecordConfig, record from lerobot.scripts.lerobot_replay import DatasetReplayConfig, ReplayConfig, replay diff --git a/tests/training/test_multi_gpu.py b/tests/training/test_multi_gpu.py index bb234e2e7..638dc3131 100644 --- a/tests/training/test_multi_gpu.py +++ b/tests/training/test_multi_gpu.py @@ -33,6 +33,8 @@ from pathlib import Path import pytest import torch +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from lerobot.datasets.lerobot_dataset import LeRobotDataset diff --git a/tests/training/test_visual_validation.py b/tests/training/test_visual_validation.py index 89351e3c2..1df8006b2 100644 --- a/tests/training/test_visual_validation.py +++ b/tests/training/test_visual_validation.py @@ -31,6 +31,8 @@ from pathlib import Path import numpy as np import pytest +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + from lerobot.configs.default import DatasetConfig from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.train import TrainPipelineConfig diff --git a/tests/transport/test_transport_utils.py b/tests/transport/test_transport_utils.py index 63632a8f4..d0df3d941 100644 --- a/tests/transport/test_transport_utils.py +++ b/tests/transport/test_transport_utils.py @@ -23,10 +23,10 @@ import torch from lerobot.utils.constants import ACTION from lerobot.utils.transition import Transition -from tests.utils import require_cuda, require_package +from tests.utils import require_cuda, skip_if_package_missing -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_bytes_buffer_size_empty_buffer(): from lerobot.transport.utils import bytes_buffer_size @@ -37,7 +37,7 @@ def test_bytes_buffer_size_empty_buffer(): assert buffer.tell() == 0 -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_bytes_buffer_size_small_buffer(): from lerobot.transport.utils import bytes_buffer_size @@ -47,7 +47,7 @@ def test_bytes_buffer_size_small_buffer(): assert buffer.tell() == 0 -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_bytes_buffer_size_large_buffer(): from lerobot.transport.utils import CHUNK_SIZE, bytes_buffer_size @@ -58,7 +58,7 @@ def test_bytes_buffer_size_large_buffer(): assert buffer.tell() == 0 -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_send_bytes_in_chunks_empty_data(): from lerobot.transport.utils import send_bytes_in_chunks, services_pb2 @@ -68,7 +68,7 @@ def test_send_bytes_in_chunks_empty_data(): assert len(chunks) == 0 -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_single_chunk_small_data(): from lerobot.transport.utils import send_bytes_in_chunks, services_pb2 @@ -82,7 +82,7 @@ def test_single_chunk_small_data(): assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_not_silent_mode(): from lerobot.transport.utils import send_bytes_in_chunks, services_pb2 @@ -94,7 +94,7 @@ def test_not_silent_mode(): assert chunks[0].data == b"Some data" -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_send_bytes_in_chunks_large_data(): from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2 @@ -111,7 +111,7 @@ def test_send_bytes_in_chunks_large_data(): assert chunks[2].transfer_state == services_pb2.TransferState.TRANSFER_END -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_send_bytes_in_chunks_large_data_with_exact_chunk_size(): from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2 @@ -124,7 +124,7 @@ def test_send_bytes_in_chunks_large_data_with_exact_chunk_size(): assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_receive_bytes_in_chunks_empty_data(): from lerobot.transport.utils import receive_bytes_in_chunks @@ -138,7 +138,7 @@ def test_receive_bytes_in_chunks_empty_data(): assert queue.empty() -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_receive_bytes_in_chunks_single_chunk(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -157,7 +157,7 @@ def test_receive_bytes_in_chunks_single_chunk(): assert queue.empty() -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_receive_bytes_in_chunks_single_not_end_chunk(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -175,7 +175,7 @@ def test_receive_bytes_in_chunks_single_not_end_chunk(): assert queue.empty() -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_receive_bytes_in_chunks_multiple_chunks(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -199,7 +199,7 @@ def test_receive_bytes_in_chunks_multiple_chunks(): assert queue.empty() -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_receive_bytes_in_chunks_multiple_messages(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -235,7 +235,7 @@ def test_receive_bytes_in_chunks_multiple_messages(): assert queue.empty() -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_receive_bytes_in_chunks_shutdown_during_receive(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -259,7 +259,7 @@ def test_receive_bytes_in_chunks_shutdown_during_receive(): assert queue.empty() -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_receive_bytes_in_chunks_only_begin_chunk(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -279,7 +279,7 @@ def test_receive_bytes_in_chunks_only_begin_chunk(): assert queue.empty() -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_receive_bytes_in_chunks_missing_begin(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -303,7 +303,7 @@ def test_receive_bytes_in_chunks_missing_begin(): # Tests for state_to_bytes and bytes_to_state_dict -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_state_to_bytes_empty_dict(): from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes @@ -314,7 +314,7 @@ def test_state_to_bytes_empty_dict(): assert reconstructed == state_dict -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_bytes_to_state_dict_empty_data(): from lerobot.transport.utils import bytes_to_state_dict @@ -323,7 +323,7 @@ def test_bytes_to_state_dict_empty_data(): bytes_to_state_dict(b"") -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_state_to_bytes_simple_dict(): from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes @@ -347,7 +347,7 @@ def test_state_to_bytes_simple_dict(): assert torch.allclose(state_dict[key], reconstructed[key]) -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_state_to_bytes_various_dtypes(): from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes @@ -372,7 +372,7 @@ def test_state_to_bytes_various_dtypes(): assert torch.allclose(state_dict[key], reconstructed[key]) -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_bytes_to_state_dict_invalid_data(): from lerobot.transport.utils import bytes_to_state_dict @@ -382,7 +382,7 @@ def test_bytes_to_state_dict_invalid_data(): @require_cuda -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_state_to_bytes_various_dtypes_cuda(): from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes @@ -407,7 +407,7 @@ def test_state_to_bytes_various_dtypes_cuda(): assert torch.allclose(state_dict[key], reconstructed[key]) -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_python_object_to_bytes_none(): from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes @@ -439,7 +439,7 @@ def test_python_object_to_bytes_none(): (1, 2, 3), ], ) -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_python_object_to_bytes_simple_types(obj): from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes @@ -450,7 +450,7 @@ def test_python_object_to_bytes_simple_types(obj): assert type(reconstructed) is type(obj) -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_python_object_to_bytes_with_tensors(): from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes @@ -475,7 +475,7 @@ def test_python_object_to_bytes_with_tensors(): assert torch.equal(obj["nested"]["tensor2"], reconstructed["nested"]["tensor2"]) -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_transitions_to_bytes_empty_list(): from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes @@ -487,7 +487,7 @@ def test_transitions_to_bytes_empty_list(): assert isinstance(reconstructed, list) -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_transitions_to_bytes_single_transition(): from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes @@ -509,7 +509,7 @@ def test_transitions_to_bytes_single_transition(): assert_transitions_equal(transitions[0], reconstructed[0]) -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def assert_transitions_equal(t1: Transition, t2: Transition): """Helper to assert two transitions are equal.""" assert_observation_equal(t1["state"], t2["state"]) @@ -519,7 +519,7 @@ def assert_transitions_equal(t1: Transition, t2: Transition): assert_observation_equal(t1["next_state"], t2["next_state"]) -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def assert_observation_equal(o1: dict, o2: dict): """Helper to assert two observations are equal.""" assert set(o1.keys()) == set(o2.keys()) @@ -527,7 +527,7 @@ def assert_observation_equal(o1: dict, o2: dict): assert torch.allclose(o1[key], o2[key]) -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_transitions_to_bytes_multiple_transitions(): from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes @@ -551,7 +551,7 @@ def test_transitions_to_bytes_multiple_transitions(): assert_transitions_equal(original, reconstructed_item) -@require_package("grpcio", "grpc") +@skip_if_package_missing("grpcio", "grpc") def test_receive_bytes_in_chunks_unknown_state(): from lerobot.transport.utils import receive_bytes_in_chunks diff --git a/tests/utils.py b/tests/utils.py index 33c554804..f8f4b135b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,23 +20,11 @@ from functools import wraps import pytest import torch -from lerobot import available_cameras, available_motors, available_robots from lerobot.utils.device_utils import auto_select_torch_device from lerobot.utils.import_utils import is_package_available DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", str(auto_select_torch_device())) -TEST_ROBOT_TYPES = [] -for robot_type in available_robots: - TEST_ROBOT_TYPES += [(robot_type, True), (robot_type, False)] - -TEST_CAMERA_TYPES = [] -for camera_type in available_cameras: - TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)] - -TEST_MOTOR_TYPES = [] -for motor_type in available_motors: - TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)] # Camera indices used for connecting physical cameras OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0)) @@ -152,7 +140,7 @@ def require_env(func): return wrapper -def require_package_arg(func): +def skip_if_package_arg_missing(func): """ Decorator that skips the test if the required package is not installed. This is similar to `require_env` but more general in that it can check any package (not just environments). @@ -184,7 +172,7 @@ def require_package_arg(func): return wrapper -def require_package(package_name, import_name=None): +def skip_if_package_missing(package_name, import_name=None): """ Decorator that skips the test if the specified package is not installed. """ diff --git a/tests/utils/test_process.py b/tests/utils/test_process.py index e2b00cae9..ce56db173 100644 --- a/tests/utils/test_process.py +++ b/tests/utils/test_process.py @@ -22,7 +22,9 @@ from unittest.mock import patch import pytest -from lerobot.rl.process import ProcessSignalHandler +pytest.importorskip("grpc") + +from lerobot.rl.process import ProcessSignalHandler # noqa: E402 # Fixture to reset shutdown_event_counter and original signal handlers before and after each test diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py index b9d3a1ac0..1b2af39f1 100644 --- a/tests/utils/test_replay_buffer.py +++ b/tests/utils/test_replay_buffer.py @@ -18,12 +18,16 @@ import sys from collections.abc import Callable import pytest -import torch -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized -from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD -from tests.fixtures.constants import DUMMY_REPO_ID +pytest.importorskip("grpc") +pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") + +import torch # noqa: E402 + +from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: E402 +from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized # noqa: E402 +from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD # noqa: E402 +from tests.fixtures.constants import DUMMY_REPO_ID # noqa: E402 def state_dims() -> list[str]: diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py index 4791caf58..8e5b3f167 100644 --- a/tests/utils/test_train_utils.py +++ b/tests/utils/test_train_utils.py @@ -17,6 +17,16 @@ from pathlib import Path from unittest.mock import Mock, patch +from lerobot.common.train_utils import ( + get_step_checkpoint_dir, + get_step_identifier, + load_training_state, + load_training_step, + save_checkpoint, + save_training_state, + save_training_step, + update_last_checkpoint, +) from lerobot.utils.constants import ( CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK, @@ -27,16 +37,6 @@ from lerobot.utils.constants import ( TRAINING_STATE_DIR, TRAINING_STEP, ) -from lerobot.utils.train_utils import ( - get_step_checkpoint_dir, - get_step_identifier, - load_training_state, - load_training_step, - save_checkpoint, - save_training_state, - save_training_step, - update_last_checkpoint, -) def test_get_step_identifier(): @@ -72,7 +72,7 @@ def test_update_last_checkpoint(tmp_path): assert last_checkpoint.resolve() == checkpoint -@patch("lerobot.utils.train_utils.save_training_state") +@patch("lerobot.common.train_utils.save_training_state") def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer): policy = Mock() cfg = Mock() @@ -82,7 +82,7 @@ def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer): mock_save_training_state.assert_called_once() -@patch("lerobot.utils.train_utils.save_training_state") +@patch("lerobot.common.train_utils.save_training_state") def test_save_checkpoint_peft(mock_save_training_state, tmp_path, optimizer): policy = Mock() policy.config = Mock() diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py index c8e5a92a8..63ff76c77 100644 --- a/tests/utils/test_visualization_utils.py +++ b/tests/utils/test_visualization_utils.py @@ -21,6 +21,8 @@ from types import SimpleNamespace import numpy as np import pytest +pytest.importorskip("rerun", reason="rerun-sdk is required (install lerobot[viz])") + from lerobot.types import TransitionKey from lerobot.utils.constants import OBS_STATE @@ -48,6 +50,9 @@ def mock_rerun(monkeypatch): calls.append((key, obj, kwargs)) dummy_rr = SimpleNamespace( + __name__="rerun", + __package__="rerun", + __spec__=SimpleNamespace(name="rerun", submodule_search_locations=None), Scalars=DummyScalar, Image=DummyImage, log=dummy_log, diff --git a/uv.lock b/uv.lock index d549938aa..a66f044ff 100644 --- a/uv.lock +++ b/uv.lock @@ -2,24 +2,33 @@ version = 1 revision = 2 requires-python = ">=3.12" resolution-markers = [ - "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'", "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'linux'", - "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'linux'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'linux'", + "python_full_version < '3.13' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'linux'", - "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'", - "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'", - "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'", - "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'", - "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'", + "(python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.14' and platform_machine == 'arm64' and sys_platform == 'linux') or (python_full_version >= '3.14' and platform_machine == 'armv7l' and sys_platform == 'linux')", + "(python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.13.*' and platform_machine == 'arm64' and sys_platform == 'linux') or (python_full_version == '3.13.*' and platform_machine == 'armv7l' and sys_platform == 'linux')", + "(python_full_version < '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.13' and platform_machine == 'arm64' and sys_platform == 'linux') or (python_full_version < '3.13' and platform_machine == 'armv7l' and sys_platform == 'linux')", + "(python_full_version >= '3.14' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "(python_full_version == '3.13.*' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "(python_full_version < '3.13' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'emscripten'", "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'emscripten'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'emscripten'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'emscripten'", "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "(python_full_version >= '3.14' and platform_machine == 'x86_64' and sys_platform == 'darwin') or (python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'win32')", + "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'win32'", + "(python_full_version == '3.13.*' and platform_machine == 'x86_64' and sys_platform == 'darwin') or (python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'win32')", + "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'win32'", + "(python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'darwin') or (python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32')", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", ] [[package]] @@ -820,7 +829,7 @@ name = "cuda-bindings" version = "12.9.4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cuda-pathfinder", marker = "sys_platform == 'linux'" }, + { name = "cuda-pathfinder", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/a9/c1/dabe88f52c3e3760d861401bb994df08f672ec893b8f7592dc91626adcf3/cuda_bindings-12.9.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fda147a344e8eaeca0c6ff113d2851ffca8f7dfc0a6c932374ee5c47caa649c8", size = 12151019, upload-time = "2025-10-21T14:51:43.167Z" }, @@ -907,7 +916,7 @@ name = "decord" version = "0.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy" }, + { name = "numpy", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l') or sys_platform != 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/11/79/936af42edf90a7bd4e41a6cac89c913d4b47fa48a26b042d5129a9242ee3/decord-0.6.0-py3-none-manylinux2010_x86_64.whl", hash = "sha256:51997f20be8958e23b7c4061ba45d0efcd86bffd5fe81c695d0befee0d442976", size = 13602299, upload-time = "2021-06-14T21:30:55.486Z" }, @@ -1010,12 +1019,15 @@ name = "dm-tree" version = "0.1.9" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'", "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'linux'", - "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'", - "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'", + "(python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.14' and platform_machine == 'arm64' and sys_platform == 'linux') or (python_full_version >= '3.14' and platform_machine == 'armv7l' and sys_platform == 'linux')", + "(python_full_version >= '3.14' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", "python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'emscripten'", "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "(python_full_version >= '3.14' and platform_machine == 'x86_64' and sys_platform == 'darwin') or (python_full_version >= '3.14' and platform_machine != 's390x' and sys_platform == 'win32')", + "python_full_version >= '3.14' and platform_machine == 's390x' and sys_platform == 'win32'", ] dependencies = [ { name = "absl-py", marker = "python_full_version >= '3.14'" }, @@ -1043,18 +1055,24 @@ name = "dm-tree" version = "0.1.10" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'linux'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'linux'", + "python_full_version < '3.13' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and platform_machine != 's390x' and sys_platform == 'linux'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'linux'", - "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'", - "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'", - "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'", - "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux'", + "(python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.13.*' and platform_machine == 'arm64' and sys_platform == 'linux') or (python_full_version == '3.13.*' and platform_machine == 'armv7l' and sys_platform == 'linux')", + "(python_full_version < '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.13' and platform_machine == 'arm64' and sys_platform == 'linux') or (python_full_version < '3.13' and platform_machine == 'armv7l' and sys_platform == 'linux')", + "(python_full_version == '3.13.*' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", + "(python_full_version < '3.13' and platform_machine != 's390x' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version < '3.13' and platform_machine != 's390x' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32'", "python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'emscripten'", "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'emscripten'", "python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'emscripten'", "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'emscripten'", + "(python_full_version == '3.13.*' and platform_machine == 'x86_64' and sys_platform == 'darwin') or (python_full_version == '3.13.*' and platform_machine != 's390x' and sys_platform == 'win32')", + "python_full_version == '3.13.*' and platform_machine == 's390x' and sys_platform == 'win32'", + "(python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'darwin') or (python_full_version < '3.13' and platform_machine != 's390x' and sys_platform == 'win32')", + "python_full_version < '3.13' and platform_machine == 's390x' and sys_platform == 'win32'", ] dependencies = [ { name = "absl-py", marker = "python_full_version < '3.14'" }, @@ -2187,37 +2205,33 @@ name = "lerobot" version = "0.5.2" source = { editable = "." } dependencies = [ - { name = "accelerate" }, - { name = "av" }, { name = "cmake" }, - { name = "datasets" }, - { name = "deepdiff" }, - { name = "diffusers" }, { name = "draccus" }, { name = "einops" }, { name = "gymnasium" }, { name = "huggingface-hub" }, - { name = "imageio", extra = ["ffmpeg"] }, - { name = "jsonlines" }, { name = "numpy" }, { name = "opencv-python-headless" }, { name = "packaging" }, - { name = "pynput" }, - { name = "pyserial" }, - { name = "rerun-sdk" }, + { name = "pillow" }, + { name = "requests" }, + { name = "safetensors" }, { name = "setuptools" }, { name = "termcolor" }, { name = "torch" }, - { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, { name = "torchvision" }, - { name = "wandb" }, + { name = "tqdm" }, ] [package.optional-dependencies] all = [ { name = "accelerate" }, + { name = "av" }, { name = "contourpy" }, + { name = "datasets" }, { name = "debugpy" }, + { name = "deepdiff" }, + { name = "diffusers" }, { name = "dynamixel-sdk" }, { name = "faker" }, { name = "fastapi" }, @@ -2230,6 +2244,7 @@ all = [ { name = "hebi-py" }, { name = "hf-libero", marker = "sys_platform == 'linux'" }, { name = "hidapi" }, + { name = "jsonlines" }, { name = "matplotlib" }, { name = "metaworld" }, { name = "mock-serial", marker = "sys_platform != 'win32'" }, @@ -2240,26 +2255,40 @@ all = [ { name = "placo" }, { name = "pre-commit" }, { name = "protobuf" }, + { name = "pyarrow" }, + { name = "pydantic" }, { name = "pygame" }, { name = "pymunk" }, + { name = "pynput" }, { name = "pyrealsense2", marker = "sys_platform != 'darwin'" }, { name = "pyrealsense2-macosx", marker = "sys_platform == 'darwin'" }, + { name = "pyserial" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "pytest-timeout" }, + { name = "python-can" }, { name = "pyzmq" }, { name = "qwen-vl-utils" }, { name = "reachy2-sdk" }, - { name = "safetensors" }, + { name = "rerun-sdk" }, + { name = "ruff" }, { name = "scikit-image" }, { name = "scipy" }, { name = "teleop" }, + { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, { name = "torchdiffeq" }, { name = "transformers" }, + { name = "wandb" }, ] aloha = [ + { name = "av" }, + { name = "datasets" }, { name = "gym-aloha" }, + { name = "jsonlines" }, + { name = "pandas" }, + { name = "pyarrow" }, { name = "scipy" }, + { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, ] async = [ { name = "contourpy" }, @@ -2267,12 +2296,44 @@ async = [ { name = "matplotlib" }, { name = "protobuf" }, ] +av-dep = [ + { name = "av" }, +] can-dep = [ { name = "python-can" }, ] +core-scripts = [ + { name = "av" }, + { name = "datasets" }, + { name = "deepdiff" }, + { name = "jsonlines" }, + { name = "pandas" }, + { name = "pyarrow" }, + { name = "pynput" }, + { name = "pyserial" }, + { name = "rerun-sdk" }, + { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, +] damiao = [ { name = "python-can" }, ] +dataset = [ + { name = "av" }, + { name = "datasets" }, + { name = "jsonlines" }, + { name = "pandas" }, + { name = "pyarrow" }, + { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, +] +dataset-viz = [ + { name = "av" }, + { name = "datasets" }, + { name = "jsonlines" }, + { name = "pandas" }, + { name = "pyarrow" }, + { name = "rerun-sdk" }, + { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, +] dev = [ { name = "debugpy" }, { name = "grpcio" }, @@ -2280,10 +2341,20 @@ dev = [ { name = "mypy" }, { name = "pre-commit" }, { name = "protobuf" }, + { name = "ruff" }, +] +diffusers-dep = [ + { name = "diffusers" }, +] +diffusion = [ + { name = "diffusers" }, ] dynamixel = [ { name = "dynamixel-sdk" }, ] +evaluation = [ + { name = "av" }, +] feetech = [ { name = "feetech-servo-sdk" }, ] @@ -2293,13 +2364,12 @@ gamepad = [ ] groot = [ { name = "decord", marker = "platform_machine == 'AMD64' or platform_machine == 'x86_64'" }, + { name = "diffusers" }, { name = "dm-tree", version = "0.1.9", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.14'" }, { name = "dm-tree", version = "0.1.10", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.14'" }, { name = "flash-attn", marker = "sys_platform != 'darwin'" }, { name = "ninja" }, { name = "peft" }, - { name = "pillow" }, - { name = "safetensors" }, { name = "timm" }, { name = "transformers" }, ] @@ -2307,6 +2377,11 @@ grpcio-dep = [ { name = "grpcio" }, { name = "protobuf" }, ] +hardware = [ + { name = "deepdiff" }, + { name = "pynput" }, + { name = "pyserial" }, +] hilserl = [ { name = "grpcio" }, { name = "gym-hil" }, @@ -2330,8 +2405,14 @@ lekiwi = [ { name = "pyzmq" }, ] libero = [ + { name = "av" }, + { name = "datasets" }, { name = "hf-libero", marker = "sys_platform == 'linux'" }, + { name = "jsonlines" }, + { name = "pandas" }, + { name = "pyarrow" }, { name = "scipy" }, + { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, { name = "transformers" }, ] matplotlib-dep = [ @@ -2339,10 +2420,17 @@ matplotlib-dep = [ { name = "matplotlib" }, ] metaworld = [ + { name = "av" }, + { name = "datasets" }, + { name = "jsonlines" }, { name = "metaworld" }, + { name = "pandas" }, + { name = "pyarrow" }, { name = "scipy" }, + { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, ] multi-task-dit = [ + { name = "diffusers" }, { name = "transformers" }, ] openarms = [ @@ -2369,8 +2457,14 @@ placo-dep = [ { name = "placo" }, ] pusht = [ + { name = "av" }, + { name = "datasets" }, { name = "gym-pusht" }, + { name = "jsonlines" }, + { name = "pandas" }, + { name = "pyarrow" }, { name = "pymunk" }, + { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, ] pygame-dep = [ { name = "pygame" }, @@ -2388,6 +2482,7 @@ sarm = [ { name = "contourpy" }, { name = "faker" }, { name = "matplotlib" }, + { name = "pydantic" }, { name = "qwen-vl-utils" }, { name = "transformers" }, ] @@ -2397,7 +2492,6 @@ scipy-dep = [ smolvla = [ { name = "accelerate" }, { name = "num2words" }, - { name = "safetensors" }, { name = "transformers" }, ] test = [ @@ -2406,6 +2500,16 @@ test = [ { name = "pytest-cov" }, { name = "pytest-timeout" }, ] +training = [ + { name = "accelerate" }, + { name = "av" }, + { name = "datasets" }, + { name = "jsonlines" }, + { name = "pandas" }, + { name = "pyarrow" }, + { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')" }, + { name = "wandb" }, +] transformers-dep = [ { name = "transformers" }, ] @@ -2422,6 +2526,9 @@ video-benchmark = [ { name = "pandas" }, { name = "scikit-image" }, ] +viz = [ + { name = "rerun-sdk" }, +] wallx = [ { name = "peft" }, { name = "qwen-vl-utils" }, @@ -2435,16 +2542,16 @@ xvla = [ [package.metadata] requires-dist = [ - { name = "accelerate", specifier = ">=1.10.0,<2.0.0" }, { name = "accelerate", marker = "extra == 'smolvla'", specifier = ">=1.7.0,<2.0.0" }, - { name = "av", specifier = ">=15.0.0,<16.0.0" }, + { name = "accelerate", marker = "extra == 'training'", specifier = ">=1.10.0,<2.0.0" }, + { name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" }, { name = "cmake", specifier = ">=3.29.0.1,<4.2.0" }, { name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" }, - { name = "datasets", specifier = ">=4.0.0,<5.0.0" }, + { name = "datasets", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" }, { name = "debugpy", marker = "extra == 'dev'", specifier = ">=1.8.1,<1.9.0" }, { name = "decord", marker = "(platform_machine == 'AMD64' and extra == 'groot') or (platform_machine == 'x86_64' and extra == 'groot')", specifier = ">=0.6.0,<1.0.0" }, - { name = "deepdiff", specifier = ">=7.0.1,<9.0.0" }, - { name = "diffusers", specifier = ">=0.27.2,<0.36.0" }, + { name = "deepdiff", marker = "extra == 'hardware'", specifier = ">=7.0.1,<9.0.0" }, + { name = "diffusers", marker = "extra == 'diffusers-dep'", specifier = ">=0.27.2,<0.36.0" }, { name = "dm-tree", marker = "extra == 'groot'", specifier = ">=0.1.8,<1.0.0" }, { name = "draccus", specifier = "==0.10.0" }, { name = "dynamixel-sdk", marker = "extra == 'dynamixel'", specifier = ">=3.7.31,<3.9.0" }, @@ -2463,21 +2570,38 @@ requires-dist = [ { name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.3,<0.2.0" }, { name = "hidapi", marker = "extra == 'gamepad'", specifier = ">=0.14.0,<0.15.0" }, { name = "huggingface-hub", specifier = ">=1.0.0,<2.0.0" }, - { name = "imageio", extras = ["ffmpeg"], specifier = ">=2.34.0,<3.0.0" }, - { name = "jsonlines", specifier = ">=4.0.0,<5.0.0" }, + { name = "jsonlines", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" }, { name = "lerobot", extras = ["aloha"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["async"], marker = "extra == 'all'" }, + { name = "lerobot", extras = ["av-dep"], marker = "extra == 'dataset'" }, + { name = "lerobot", extras = ["av-dep"], marker = "extra == 'evaluation'" }, { name = "lerobot", extras = ["can-dep"], marker = "extra == 'damiao'" }, { name = "lerobot", extras = ["can-dep"], marker = "extra == 'robstride'" }, + { name = "lerobot", extras = ["damiao"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["damiao"], marker = "extra == 'openarms'" }, + { name = "lerobot", extras = ["dataset"], marker = "extra == 'all'" }, + { name = "lerobot", extras = ["dataset"], marker = "extra == 'aloha'" }, + { name = "lerobot", extras = ["dataset"], marker = "extra == 'core-scripts'" }, + { name = "lerobot", extras = ["dataset"], marker = "extra == 'dataset-viz'" }, + { name = "lerobot", extras = ["dataset"], marker = "extra == 'libero'" }, + { name = "lerobot", extras = ["dataset"], marker = "extra == 'metaworld'" }, + { name = "lerobot", extras = ["dataset"], marker = "extra == 'pusht'" }, + { name = "lerobot", extras = ["dataset"], marker = "extra == 'training'" }, { name = "lerobot", extras = ["dev"], marker = "extra == 'all'" }, + { name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'diffusion'" }, + { name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'groot'" }, + { name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'multi-task-dit'" }, + { name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" }, + { name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["feetech"], marker = "extra == 'hopejr'" }, { name = "lerobot", extras = ["feetech"], marker = "extra == 'lekiwi'" }, { name = "lerobot", extras = ["gamepad"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'async'" }, { name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'dev'" }, { name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'hilserl'" }, + { name = "lerobot", extras = ["hardware"], marker = "extra == 'all'" }, + { name = "lerobot", extras = ["hardware"], marker = "extra == 'core-scripts'" }, { name = "lerobot", extras = ["hilserl"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["hopejr"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["intelrealsense"], marker = "extra == 'all'" }, @@ -2488,10 +2612,12 @@ requires-dist = [ { name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'sarm'" }, { name = "lerobot", extras = ["matplotlib-dep"], marker = "extra == 'unitree-g1'" }, { name = "lerobot", extras = ["metaworld"], marker = "extra == 'all'" }, + { name = "lerobot", extras = ["multi-task-dit"], marker = "extra == 'all'" }, + { name = "lerobot", extras = ["openarms"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["peft"], marker = "extra == 'all'" }, - { name = "lerobot", extras = ["peft"], marker = "extra == 'groot'" }, - { name = "lerobot", extras = ["peft"], marker = "extra == 'wallx'" }, + { name = "lerobot", extras = ["peft-dep"], marker = "extra == 'groot'" }, { name = "lerobot", extras = ["peft-dep"], marker = "extra == 'peft'" }, + { name = "lerobot", extras = ["peft-dep"], marker = "extra == 'wallx'" }, { name = "lerobot", extras = ["phone"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["pi"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["placo-dep"], marker = "extra == 'hilserl'" }, @@ -2503,6 +2629,7 @@ requires-dist = [ { name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'sarm'" }, { name = "lerobot", extras = ["qwen-vl-utils-dep"], marker = "extra == 'wallx'" }, { name = "lerobot", extras = ["reachy2"], marker = "extra == 'all'" }, + { name = "lerobot", extras = ["robstride"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["sarm"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'aloha'" }, { name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'libero'" }, @@ -2512,6 +2639,7 @@ requires-dist = [ { name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'wallx'" }, { name = "lerobot", extras = ["smolvla"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["test"], marker = "extra == 'all'" }, + { name = "lerobot", extras = ["training"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'hilserl'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'libero'" }, @@ -2523,6 +2651,9 @@ requires-dist = [ { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'wallx'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'xvla'" }, { name = "lerobot", extras = ["video-benchmark"], marker = "extra == 'all'" }, + { name = "lerobot", extras = ["viz"], marker = "extra == 'all'" }, + { name = "lerobot", extras = ["viz"], marker = "extra == 'core-scripts'" }, + { name = "lerobot", extras = ["viz"], marker = "extra == 'dataset-viz'" }, { name = "lerobot", extras = ["wallx"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["xvla"], marker = "extra == 'all'" }, { name = "matplotlib", marker = "extra == 'matplotlib-dep'", specifier = ">=3.10.3,<4.0.0" }, @@ -2537,18 +2668,21 @@ requires-dist = [ { name = "onnxruntime", marker = "extra == 'unitree-g1'", specifier = ">=1.16.0,<2.0.0" }, { name = "opencv-python-headless", specifier = ">=4.9.0,<4.14.0" }, { name = "packaging", specifier = ">=24.2,<26.0" }, + { name = "pandas", marker = "extra == 'dataset'", specifier = ">=2.0.0,<3.0.0" }, { name = "pandas", marker = "extra == 'video-benchmark'", specifier = ">=2.2.2,<2.4.0" }, { name = "peft", marker = "extra == 'peft-dep'", specifier = ">=0.18.0,<1.0.0" }, - { name = "pillow", marker = "extra == 'groot'", specifier = ">=10.0.0,<13.0.0" }, + { name = "pillow", specifier = ">=10.0.0,<13.0.0" }, { name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.17" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.7.0,<5.0.0" }, { name = "protobuf", marker = "extra == 'grpcio-dep'", specifier = ">=6.31.1,<6.32.0" }, + { name = "pyarrow", marker = "extra == 'dataset'", specifier = ">=21.0.0,<30.0.0" }, + { name = "pydantic", marker = "extra == 'sarm'", specifier = ">=2.0.0,<3.0.0" }, { name = "pygame", marker = "extra == 'pygame-dep'", specifier = ">=2.5.1,<2.7.0" }, { name = "pymunk", marker = "extra == 'pusht'", specifier = ">=6.6.0,<7.0.0" }, - { name = "pynput", specifier = ">=1.7.8,<1.9.0" }, + { name = "pynput", marker = "extra == 'hardware'", specifier = ">=1.7.8,<1.9.0" }, { name = "pyrealsense2", marker = "sys_platform != 'darwin' and extra == 'intelrealsense'", specifier = ">=2.55.1.6486,<2.57.0" }, { name = "pyrealsense2-macosx", marker = "sys_platform == 'darwin' and extra == 'intelrealsense'", specifier = ">=2.54,<2.57.0" }, - { name = "pyserial", specifier = ">=3.5,<4.0" }, + { name = "pyserial", marker = "extra == 'hardware'", specifier = ">=3.5,<4.0" }, { name = "pytest", marker = "extra == 'test'", specifier = ">=8.1.0,<9.0.0" }, { name = "pytest-cov", marker = "extra == 'test'", specifier = ">=5.0.0,<8.0.0" }, { name = "pytest-timeout", marker = "extra == 'test'", specifier = ">=2.4.0,<3.0.0" }, @@ -2557,9 +2691,10 @@ requires-dist = [ { name = "pyzmq", marker = "extra == 'unitree-g1'", specifier = ">=26.2.1,<28.0.0" }, { name = "qwen-vl-utils", marker = "extra == 'qwen-vl-utils-dep'", specifier = ">=0.0.11,<0.1.0" }, { name = "reachy2-sdk", marker = "extra == 'reachy2'", specifier = ">=1.0.15,<1.1.0" }, - { name = "rerun-sdk", specifier = ">=0.24.0,<0.27.0" }, - { name = "safetensors", marker = "extra == 'groot'", specifier = ">=0.4.3,<1.0.0" }, - { name = "safetensors", marker = "extra == 'smolvla'", specifier = ">=0.4.3,<1.0.0" }, + { name = "requests", specifier = ">=2.32.0,<3.0.0" }, + { name = "rerun-sdk", marker = "extra == 'viz'", specifier = ">=0.24.0,<0.27.0" }, + { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.14.1" }, + { name = "safetensors", specifier = ">=0.4.3,<1.0.0" }, { name = "scikit-image", marker = "extra == 'video-benchmark'", specifier = ">=0.23.2,<0.26.0" }, { name = "scipy", marker = "extra == 'all'", specifier = ">=1.14.0,<2.0.0" }, { name = "scipy", marker = "extra == 'scipy-dep'", specifier = ">=1.14.0,<2.0.0" }, @@ -2568,13 +2703,14 @@ requires-dist = [ { name = "termcolor", specifier = ">=2.4.0,<4.0.0" }, { name = "timm", marker = "extra == 'groot'", specifier = ">=1.0.0,<1.1.0" }, { name = "torch", specifier = ">=2.7,<2.11.0" }, - { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux') or (platform_machine != 'x86_64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32')", specifier = ">=0.3.0,<0.11.0" }, + { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux' and extra == 'dataset') or (platform_machine != 'x86_64' and sys_platform == 'darwin' and extra == 'dataset') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'dataset')", specifier = ">=0.3.0,<0.11.0" }, { name = "torchdiffeq", marker = "extra == 'wallx'", specifier = ">=0.2.4,<0.3.0" }, { name = "torchvision", specifier = ">=0.22.0,<0.26.0" }, + { name = "tqdm", specifier = ">=4.66.0,<5.0.0" }, { name = "transformers", marker = "extra == 'transformers-dep'", specifier = "==5.3.0" }, - { name = "wandb", specifier = ">=0.24.0,<0.25.0" }, + { name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" }, ] -provides-extras = ["pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "qwen-vl-utils-dep", "matplotlib-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "kinematics", "intelrealsense", "phone", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "hilserl", "async", "peft", "dev", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"] +provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "smolvla", "multi-task-dit", "groot", "sarm", "xvla", "hilserl", "async", "peft", "dev", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"] [[package]] name = "librt" @@ -3359,7 +3495,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, @@ -3370,7 +3506,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, @@ -3397,9 +3533,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, @@ -3410,7 +3546,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, @@ -3677,7 +3813,7 @@ name = "pexpect" version = "4.9.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "ptyprocess", marker = "sys_platform != 'emscripten'" }, + { name = "ptyprocess", marker = "(platform_machine != 's390x' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" }, ] sdist = { url = "https://files.pythonhosted.org/packages/42/92/cc564bf6381ff43ce1f4d06852fc19a2f11d180f23dc32d9588bee2f149d/pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f", size = 166450, upload-time = "2023-11-25T09:07:26.339Z" } wheels = [ @@ -4231,10 +4367,10 @@ name = "pyobjc-framework-applicationservices" version = "12.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyobjc-core", marker = "sys_platform != 'linux'" }, - { name = "pyobjc-framework-cocoa", marker = "sys_platform != 'linux'" }, - { name = "pyobjc-framework-coretext", marker = "sys_platform != 'linux'" }, - { name = "pyobjc-framework-quartz", marker = "sys_platform != 'linux'" }, + { name = "pyobjc-core", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, + { name = "pyobjc-framework-cocoa", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, + { name = "pyobjc-framework-coretext", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, + { name = "pyobjc-framework-quartz", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/be/6a/d4e613c8e926a5744fc47a9e9fea08384a510dc4f27d844f7ad7a2d793bd/pyobjc_framework_applicationservices-12.1.tar.gz", hash = "sha256:c06abb74f119bc27aeb41bf1aef8102c0ae1288aec1ac8665ea186a067a8945b", size = 103247, upload-time = "2025-11-14T10:08:52.18Z" } wheels = [ @@ -4250,7 +4386,7 @@ name = "pyobjc-framework-cocoa" version = "12.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyobjc-core", marker = "sys_platform != 'linux'" }, + { name = "pyobjc-core", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/02/a3/16ca9a15e77c061a9250afbae2eae26f2e1579eb8ca9462ae2d2c71e1169/pyobjc_framework_cocoa-12.1.tar.gz", hash = "sha256:5556c87db95711b985d5efdaaf01c917ddd41d148b1e52a0c66b1a2e2c5c1640", size = 2772191, upload-time = "2025-11-14T10:13:02.069Z" } wheels = [ @@ -4266,9 +4402,9 @@ name = "pyobjc-framework-coretext" version = "12.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyobjc-core", marker = "sys_platform != 'linux'" }, - { name = "pyobjc-framework-cocoa", marker = "sys_platform != 'linux'" }, - { name = "pyobjc-framework-quartz", marker = "sys_platform != 'linux'" }, + { name = "pyobjc-core", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, + { name = "pyobjc-framework-cocoa", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, + { name = "pyobjc-framework-quartz", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/29/da/682c9c92a39f713bd3c56e7375fa8f1b10ad558ecb075258ab6f1cdd4a6d/pyobjc_framework_coretext-12.1.tar.gz", hash = "sha256:e0adb717738fae395dc645c9e8a10bb5f6a4277e73cba8fa2a57f3b518e71da5", size = 90124, upload-time = "2025-11-14T10:14:38.596Z" } wheels = [ @@ -4284,8 +4420,8 @@ name = "pyobjc-framework-quartz" version = "12.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pyobjc-core", marker = "sys_platform != 'linux'" }, - { name = "pyobjc-framework-cocoa", marker = "sys_platform != 'linux'" }, + { name = "pyobjc-core", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, + { name = "pyobjc-framework-cocoa", marker = "sys_platform != 'emscripten' and sys_platform != 'linux'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/94/18/cc59f3d4355c9456fc945eae7fe8797003c4da99212dd531ad1b0de8a0c6/pyobjc_framework_quartz-12.1.tar.gz", hash = "sha256:27f782f3513ac88ec9b6c82d9767eef95a5cf4175ce88a1e5a65875fee799608", size = 3159099, upload-time = "2025-11-14T10:21:24.31Z" } wheels = [ @@ -4888,6 +5024,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/85/70/92482ccffb96f5441aab93e26c4d66489eb599efdcf96fad90c14bbfb976/rpds_py-0.30.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:dbd936cde57abfee19ab3213cf9c26be06d60750e60a8e4dd85d1ab12c8b1f40", size = 556030, upload-time = "2025-11-30T20:24:10.956Z" }, ] +[[package]] +name = "ruff" +version = "0.15.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/d9/aa3f7d59a10ef6b14fe3431706f854dbf03c5976be614a9796d36326810c/ruff-0.15.10.tar.gz", hash = "sha256:d1f86e67ebfdef88e00faefa1552b5e510e1d35f3be7d423dc7e84e63788c94e", size = 4631728, upload-time = "2026-04-09T14:06:09.884Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/eb/00/a1c2fdc9939b2c03691edbda290afcd297f1f389196172826b03d6b6a595/ruff-0.15.10-py3-none-linux_armv6l.whl", hash = "sha256:0744e31482f8f7d0d10a11fcbf897af272fefdfcb10f5af907b18c2813ff4d5f", size = 10563362, upload-time = "2026-04-09T14:06:21.189Z" }, + { url = "https://files.pythonhosted.org/packages/5c/15/006990029aea0bebe9d33c73c3e28c80c391ebdba408d1b08496f00d422d/ruff-0.15.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b1e7c16ea0ff5a53b7c2df52d947e685973049be1cdfe2b59a9c43601897b22e", size = 10951122, upload-time = "2026-04-09T14:06:02.236Z" }, + { url = "https://files.pythonhosted.org/packages/f2/c0/4ac978fe874d0618c7da647862afe697b281c2806f13ce904ad652fa87e4/ruff-0.15.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:93cc06a19e5155b4441dd72808fdf84290d84ad8a39ca3b0f994363ade4cebb1", size = 10314005, upload-time = "2026-04-09T14:06:00.026Z" }, + { url = "https://files.pythonhosted.org/packages/da/73/c209138a5c98c0d321266372fc4e33ad43d506d7e5dd817dd89b60a8548f/ruff-0.15.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83e1dd04312997c99ea6965df66a14fb4f03ba978564574ffc68b0d61fd3989e", size = 10643450, upload-time = "2026-04-09T14:05:42.137Z" }, + { url = "https://files.pythonhosted.org/packages/ec/76/0deec355d8ec10709653635b1f90856735302cb8e149acfdf6f82a5feb70/ruff-0.15.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8154d43684e4333360fedd11aaa40b1b08a4e37d8ffa9d95fee6fa5b37b6fab1", size = 10379597, upload-time = "2026-04-09T14:05:49.984Z" }, + { url = "https://files.pythonhosted.org/packages/dc/be/86bba8fc8798c081e28a4b3bb6d143ccad3fd5f6f024f02002b8f08a9fa3/ruff-0.15.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ab88715f3a6deb6bde6c227f3a123410bec7b855c3ae331b4c006189e895cef", size = 11146645, upload-time = "2026-04-09T14:06:12.246Z" }, + { url = "https://files.pythonhosted.org/packages/a8/89/140025e65911b281c57be1d385ba1d932c2366ca88ae6663685aed8d4881/ruff-0.15.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a768ff5969b4f44c349d48edf4ab4f91eddb27fd9d77799598e130fb628aa158", size = 12030289, upload-time = "2026-04-09T14:06:04.776Z" }, + { url = "https://files.pythonhosted.org/packages/88/de/ddacca9545a5e01332567db01d44bd8cf725f2db3b3d61a80550b48308ea/ruff-0.15.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ee3ef42dab7078bda5ff6a1bcba8539e9857deb447132ad5566a038674540d0", size = 11496266, upload-time = "2026-04-09T14:05:55.485Z" }, + { url = "https://files.pythonhosted.org/packages/bc/bb/7ddb00a83760ff4a83c4e2fc231fd63937cc7317c10c82f583302e0f6586/ruff-0.15.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51cb8cc943e891ba99989dd92d61e29b1d231e14811db9be6440ecf25d5c1609", size = 11256418, upload-time = "2026-04-09T14:05:57.69Z" }, + { url = "https://files.pythonhosted.org/packages/dc/8d/55de0d35aacf6cd50b6ee91ee0f291672080021896543776f4170fc5c454/ruff-0.15.10-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:e59c9bdc056a320fb9ea1700a8d591718b8faf78af065484e801258d3a76bc3f", size = 11288416, upload-time = "2026-04-09T14:05:44.695Z" }, + { url = "https://files.pythonhosted.org/packages/68/cf/9438b1a27426ec46a80e0a718093c7f958ef72f43eb3111862949ead3cc1/ruff-0.15.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:136c00ca2f47b0018b073f28cb5c1506642a830ea941a60354b0e8bc8076b151", size = 10621053, upload-time = "2026-04-09T14:05:52.782Z" }, + { url = "https://files.pythonhosted.org/packages/4c/50/e29be6e2c135e9cd4cb15fbade49d6a2717e009dff3766dd080fcb82e251/ruff-0.15.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8b80a2f3c9c8a950d6237f2ca12b206bccff626139be9fa005f14feb881a1ae8", size = 10378302, upload-time = "2026-04-09T14:06:14.361Z" }, + { url = "https://files.pythonhosted.org/packages/18/2f/e0b36a6f99c51bb89f3a30239bc7bf97e87a37ae80aa2d6542d6e5150364/ruff-0.15.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:e3e53c588164dc025b671c9df2462429d60357ea91af7e92e9d56c565a9f1b07", size = 10850074, upload-time = "2026-04-09T14:06:16.581Z" }, + { url = "https://files.pythonhosted.org/packages/11/08/874da392558ce087a0f9b709dc6ec0d60cbc694c1c772dab8d5f31efe8cb/ruff-0.15.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b0c52744cf9f143a393e284125d2576140b68264a93c6716464e129a3e9adb48", size = 11358051, upload-time = "2026-04-09T14:06:18.948Z" }, + { url = "https://files.pythonhosted.org/packages/e4/46/602938f030adfa043e67112b73821024dc79f3ab4df5474c25fa4c1d2d14/ruff-0.15.10-py3-none-win32.whl", hash = "sha256:d4272e87e801e9a27a2e8df7b21011c909d9ddd82f4f3281d269b6ba19789ca5", size = 10588964, upload-time = "2026-04-09T14:06:07.14Z" }, + { url = "https://files.pythonhosted.org/packages/25/b6/261225b875d7a13b33a6d02508c39c28450b2041bb01d0f7f1a83d569512/ruff-0.15.10-py3-none-win_amd64.whl", hash = "sha256:28cb32d53203242d403d819fd6983152489b12e4a3ae44993543d6fe62ab42ed", size = 11745044, upload-time = "2026-04-09T14:05:39.473Z" }, + { url = "https://files.pythonhosted.org/packages/58/ed/dea90a65b7d9e69888890fb14c90d7f51bf0c1e82ad800aeb0160e4bacfd/ruff-0.15.10-py3-none-win_arm64.whl", hash = "sha256:601d1610a9e1f1c2165a4f561eeaa2e2ea1e97f3287c5aa258d3dab8b57c6188", size = 11035607, upload-time = "2026-04-09T14:05:47.593Z" }, +] + [[package]] name = "safetensors" version = "0.7.0"