From a07b1d76f17b359f55ffcfceed1a3b532448ab25 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Sun, 15 Mar 2026 20:26:06 -0700 Subject: [PATCH] chore(dependecies): untangle dependecies across internal modules (#3149) --- examples/phone_to_so100/evaluate.py | 3 +- examples/phone_to_so100/record.py | 3 +- examples/phone_to_so100/replay.py | 3 +- examples/phone_to_so100/teleoperate.py | 3 +- examples/so100_to_so100_EE/evaluate.py | 3 +- examples/so100_to_so100_EE/record.py | 3 +- examples/so100_to_so100_EE/replay.py | 3 +- examples/so100_to_so100_EE/teleoperate.py | 3 +- src/lerobot/async_inference/policy_server.py | 6 +- src/lerobot/configs/policies.py | 2 +- src/lerobot/datasets/pipeline_features.py | 3 +- src/lerobot/envs/libero.py | 2 +- src/lerobot/envs/metaworld.py | 2 +- src/lerobot/envs/utils.py | 2 +- src/lerobot/policies/factory.py | 3 +- src/lerobot/policies/groot/processor_groot.py | 2 +- src/lerobot/policies/pi05/processor_pi05.py | 2 +- .../policies/pi0_fast/processor_pi0_fast.py | 2 +- src/lerobot/policies/sarm/processor_sarm.py | 2 +- .../policies/smolvla/modeling_smolvla.py | 2 +- src/lerobot/policies/utils.py | 2 +- src/lerobot/policies/xvla/processor_xvla.py | 2 +- src/lerobot/processor/__init__.py | 15 +-- src/lerobot/processor/batch_processor.py | 2 +- src/lerobot/processor/converters.py | 3 +- .../processor/delta_action_processor.py | 2 +- src/lerobot/processor/device_processor.py | 4 +- src/lerobot/processor/factory.py | 3 +- src/lerobot/processor/gym_action_processor.py | 4 +- src/lerobot/processor/hil_processor.py | 3 +- src/lerobot/processor/normalize_processor.py | 2 +- src/lerobot/processor/pipeline.py | 2 +- src/lerobot/processor/tokenizer_processor.py | 2 +- src/lerobot/rl/actor.py | 4 +- src/lerobot/rl/learner.py | 2 +- .../bi_openarm_follower.py | 2 +- .../robots/bi_so_follower/bi_so_follower.py | 2 +- .../robot_earthrover_mini_plus.py | 2 +- src/lerobot/robots/hope_jr/hope_jr_arm.py | 2 +- src/lerobot/robots/hope_jr/hope_jr_hand.py | 2 +- .../robots/koch_follower/koch_follower.py | 2 +- src/lerobot/robots/lekiwi/lekiwi.py | 2 +- src/lerobot/robots/lekiwi/lekiwi_client.py | 2 +- .../robots/omx_follower/omx_follower.py | 2 +- .../openarm_follower/openarm_follower.py | 2 +- src/lerobot/robots/reachy2/robot_reachy2.py | 2 +- src/lerobot/robots/robot.py | 2 +- src/lerobot/robots/so_follower/so_follower.py | 2 +- src/lerobot/robots/unitree_g1/unitree_g1.py | 5 +- src/lerobot/scripts/lerobot_eval.py | 5 +- src/lerobot/scripts/lerobot_record.py | 2 +- .../bi_openarm_leader/bi_openarm_leader.py | 2 +- .../teleoperators/gamepad/teleop_gamepad.py | 2 +- .../teleoperators/keyboard/teleop_keyboard.py | 2 +- .../openarm_leader/openarm_leader.py | 2 +- .../openarm_mini/openarm_mini.py | 2 +- .../teleoperators/phone/phone_processor.py | 3 +- src/lerobot/teleoperators/teleoperator.py | 2 +- src/lerobot/{processor/core.py => types.py} | 0 src/lerobot/utils/control_utils.py | 3 +- src/lerobot/utils/device_utils.py | 109 ++++++++++++++++++ src/lerobot/utils/utils.py | 103 ++--------------- src/lerobot/utils/visualization_utils.py | 2 +- tests/mocks/mock_robot.py | 2 +- tests/mocks/mock_teleop.py | 2 +- tests/policies/groot/test_groot_lerobot.py | 5 +- .../policies/groot/test_groot_vs_original.py | 3 +- .../test_pi0_fast_original_vs_lerobot.py | 3 +- .../pi0_pi05/test_pi05_original_vs_lerobot.py | 3 +- .../pi0_pi05/test_pi0_original_vs_lerobot.py | 3 +- tests/policies/test_sarm_processor.py | 2 +- .../xvla/test_xvla_original_vs_lerobot.py | 3 +- tests/processor/test_batch_conversion.py | 3 +- tests/processor/test_converters.py | 2 +- tests/processor/test_device_processor.py | 3 +- tests/processor/test_normalize_processor.py | 2 +- tests/processor/test_observation_processor.py | 3 +- tests/processor/test_tokenizer_processor.py | 3 +- tests/training/test_visual_validation.py | 2 +- tests/utils.py | 2 +- tests/utils/test_visualization_utils.py | 2 +- 81 files changed, 235 insertions(+), 189 deletions(-) rename src/lerobot/{processor/core.py => types.py} (100%) create mode 100644 src/lerobot/utils/device_utils.py diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py index 837217eda..c1291d101 100644 --- a/examples/phone_to_so100/evaluate.py +++ b/examples/phone_to_so100/evaluate.py @@ -23,8 +23,6 @@ from lerobot.model.kinematics import RobotKinematics from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors from lerobot.processor import ( - RobotAction, - RobotObservation, RobotProcessorPipeline, make_default_teleop_action_processor, ) @@ -40,6 +38,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) from lerobot.scripts.lerobot_record import record_loop +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py index 1f5005db9..756c6f42d 100644 --- a/examples/phone_to_so100/record.py +++ b/examples/phone_to_so100/record.py @@ -19,7 +19,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features from lerobot.datasets.utils import combine_feature_dicts from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( observation_to_transition, robot_action_observation_to_transition, @@ -38,6 +38,7 @@ from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction from lerobot.teleoperators.phone.teleop_phone import Phone +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py index 9d7806cf4..7b955cdb7 100644 --- a/examples/phone_to_so100/replay.py +++ b/examples/phone_to_so100/replay.py @@ -18,7 +18,7 @@ import time from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( robot_action_observation_to_transition, transition_to_robot_action, @@ -27,6 +27,7 @@ from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say diff --git a/examples/phone_to_so100/teleoperate.py b/examples/phone_to_so100/teleoperate.py index 6eaaec806..7242c39ce 100644 --- a/examples/phone_to_so100/teleoperate.py +++ b/examples/phone_to_so100/teleoperate.py @@ -16,7 +16,7 @@ import time from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( robot_action_observation_to_transition, transition_to_robot_action, @@ -31,6 +31,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction from lerobot.teleoperators.phone.teleop_phone import Phone +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.visualization_utils import init_rerun, log_rerun_data diff --git a/examples/so100_to_so100_EE/evaluate.py b/examples/so100_to_so100_EE/evaluate.py index b614b89f2..45a87ebad 100644 --- a/examples/so100_to_so100_EE/evaluate.py +++ b/examples/so100_to_so100_EE/evaluate.py @@ -23,8 +23,6 @@ from lerobot.model.kinematics import RobotKinematics from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors from lerobot.processor import ( - RobotAction, - RobotObservation, RobotProcessorPipeline, make_default_teleop_action_processor, ) @@ -40,6 +38,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) from lerobot.scripts.lerobot_record import record_loop +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py index d85a1c5cc..8fa862d6e 100644 --- a/examples/so100_to_so100_EE/record.py +++ b/examples/so100_to_so100_EE/record.py @@ -20,7 +20,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features from lerobot.datasets.utils import combine_feature_dicts from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( observation_to_transition, robot_action_observation_to_transition, @@ -35,6 +35,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( ) from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun diff --git a/examples/so100_to_so100_EE/replay.py b/examples/so100_to_so100_EE/replay.py index 47a2f6635..b042e02dd 100644 --- a/examples/so100_to_so100_EE/replay.py +++ b/examples/so100_to_so100_EE/replay.py @@ -19,7 +19,7 @@ import time from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( robot_action_observation_to_transition, transition_to_robot_action, @@ -28,6 +28,7 @@ from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say diff --git a/examples/so100_to_so100_EE/teleoperate.py b/examples/so100_to_so100_EE/teleoperate.py index 71d2899de..af21f079b 100644 --- a/examples/so100_to_so100_EE/teleoperate.py +++ b/examples/so100_to_so100_EE/teleoperate.py @@ -17,7 +17,7 @@ import time from lerobot.model.kinematics import RobotKinematics -from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor import RobotProcessorPipeline from lerobot.processor.converters import ( robot_action_observation_to_transition, robot_action_to_transition, @@ -30,6 +30,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.visualization_utils import init_rerun, log_rerun_data diff --git a/src/lerobot/async_inference/policy_server.py b/src/lerobot/async_inference/policy_server.py index aedce2a74..3f63929df 100644 --- a/src/lerobot/async_inference/policy_server.py +++ b/src/lerobot/async_inference/policy_server.py @@ -39,15 +39,13 @@ import grpc import torch from lerobot.policies.factory import get_policy_class, make_pre_post_processors -from lerobot.processor import ( - PolicyAction, - PolicyProcessorPipeline, -) +from lerobot.processor import PolicyProcessorPipeline from lerobot.transport import ( services_pb2, # type: ignore services_pb2_grpc, # type: ignore ) from lerobot.transport.utils import receive_bytes_in_chunks +from lerobot.types import PolicyAction from .configs import PolicyServerConfig from .constants import SUPPORTED_POLICIES diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index 44b013c29..ce567b8f5 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -30,8 +30,8 @@ from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.optim.optimizers import OptimizerConfig from lerobot.optim.schedulers import LRSchedulerConfig from lerobot.utils.constants import ACTION, OBS_STATE +from lerobot.utils.device_utils import auto_select_torch_device, is_amp_available, is_torch_device_available from lerobot.utils.hub import HubMixin -from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available T = TypeVar("T", bound="PreTrainedConfig") logger = getLogger(__name__) diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py index 161633f26..f824eb9bc 100644 --- a/src/lerobot/datasets/pipeline_features.py +++ b/src/lerobot/datasets/pipeline_features.py @@ -18,7 +18,8 @@ from typing import Any from lerobot.configs.types import PipelineFeatureType from lerobot.datasets.utils import hw_to_dataset_features -from lerobot.processor import DataProcessorPipeline, RobotAction, RobotObservation +from lerobot.processor import DataProcessorPipeline +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index d20dae8ea..6d3589fed 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -29,7 +29,7 @@ from gymnasium import spaces from libero.libero import benchmark, get_libero_path from libero.libero.envs import OffScreenRenderEnv -from lerobot.processor import RobotObservation +from lerobot.types import RobotObservation def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]: diff --git a/src/lerobot/envs/metaworld.py b/src/lerobot/envs/metaworld.py index 4d91e002d..e9e29f304 100644 --- a/src/lerobot/envs/metaworld.py +++ b/src/lerobot/envs/metaworld.py @@ -25,7 +25,7 @@ import metaworld.policies as policies import numpy as np from gymnasium import spaces -from lerobot.processor import RobotObservation +from lerobot.types import RobotObservation # ---- Load configuration data from the external JSON file ---- CONFIG_PATH = Path(__file__).parent / "metaworld_config.json" diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index 09431a18d..fd17a6762 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -29,7 +29,7 @@ from torch import Tensor from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.envs.configs import EnvConfig -from lerobot.processor import RobotObservation +from lerobot.types import RobotObservation from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR from lerobot.utils.utils import get_channel_first_image_shape diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index d50d8652a..9515d5b82 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -43,13 +43,14 @@ from lerobot.policies.utils import validate_visual_features_consistency from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.policies.wall_x.configuration_wall_x import WallXConfig from lerobot.policies.xvla.configuration_xvla import XVLAConfig -from lerobot.processor import PolicyAction, PolicyProcessorPipeline +from lerobot.processor import PolicyProcessorPipeline from lerobot.processor.converters import ( batch_to_transition, policy_action_to_transition, transition_to_batch, transition_to_policy_action, ) +from lerobot.types import PolicyAction from lerobot.utils.constants import ( ACTION, POLICY_POSTPROCESSOR_DEFAULT_NAME, diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index 14149cf2f..8bf9dabca 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -49,7 +49,7 @@ from lerobot.processor.converters import ( policy_action_to_transition, transition_to_policy_action, ) -from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import ( ACTION, HF_LEROBOT_HOME, diff --git a/src/lerobot/policies/pi05/processor_pi05.py b/src/lerobot/policies/pi05/processor_pi05.py index 6e01a4e16..425a85577 100644 --- a/src/lerobot/policies/pi05/processor_pi05.py +++ b/src/lerobot/policies/pi05/processor_pi05.py @@ -36,7 +36,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action -from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import ( OBS_STATE, POLICY_POSTPROCESSOR_DEFAULT_NAME, diff --git a/src/lerobot/policies/pi0_fast/processor_pi0_fast.py b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py index fde7d5c80..46e54432a 100644 --- a/src/lerobot/policies/pi0_fast/processor_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py @@ -37,7 +37,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action -from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import ( OBS_STATE, POLICY_POSTPROCESSOR_DEFAULT_NAME, diff --git a/src/lerobot/policies/sarm/processor_sarm.py b/src/lerobot/policies/sarm/processor_sarm.py index 8f2bc23db..f377a7ffa 100644 --- a/src/lerobot/policies/sarm/processor_sarm.py +++ b/src/lerobot/policies/sarm/processor_sarm.py @@ -48,8 +48,8 @@ from lerobot.processor.converters import ( policy_action_to_transition, transition_to_policy_action, ) -from lerobot.processor.core import EnvTransition, TransitionKey from lerobot.processor.pipeline import PipelineFeatureType +from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 32165eba8..7110ba7d2 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -68,7 +68,7 @@ from lerobot.policies.utils import ( populate_queues, ) from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE -from lerobot.utils.utils import get_safe_dtype +from lerobot.utils.device_utils import get_safe_dtype class ActionSelectKwargs(TypedDict, total=False): diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py index 1a14b2925..9ad5dac4a 100644 --- a/src/lerobot/policies/utils.py +++ b/src/lerobot/policies/utils.py @@ -24,7 +24,7 @@ from torch import nn from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.datasets.utils import build_dataset_frame -from lerobot.processor import PolicyAction, RobotAction, RobotObservation +from lerobot.types import PolicyAction, RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_STR diff --git a/src/lerobot/policies/xvla/processor_xvla.py b/src/lerobot/policies/xvla/processor_xvla.py index c4e3f2d6f..0fa9ffe3f 100644 --- a/src/lerobot/policies/xvla/processor_xvla.py +++ b/src/lerobot/policies/xvla/processor_xvla.py @@ -38,7 +38,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action -from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.types import EnvTransition, TransitionKey from lerobot.utils.constants import ( OBS_IMAGES, OBS_PREFIX, diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 0b63e1606..12dcf0c6d 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -14,13 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .batch_processor import AddBatchDimensionProcessorStep -from .converters import ( - batch_to_transition, - create_transition, - transition_to_batch, -) -from .core import ( +from lerobot.types import ( EnvAction, EnvTransition, PolicyAction, @@ -28,6 +22,13 @@ from .core import ( RobotObservation, TransitionKey, ) + +from .batch_processor import AddBatchDimensionProcessorStep +from .converters import ( + batch_to_transition, + create_transition, + transition_to_batch, +) from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep from .device_processor import DeviceProcessorStep from .factory import ( diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py index e1a90421f..c904acf84 100644 --- a/src/lerobot/processor/batch_processor.py +++ b/src/lerobot/processor/batch_processor.py @@ -25,9 +25,9 @@ from dataclasses import dataclass, field from torch import Tensor from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.types import EnvTransition, PolicyAction from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE -from .core import EnvTransition, PolicyAction from .pipeline import ( ComplementaryDataProcessorStep, ObservationProcessorStep, diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 18c7b0220..ffdf0098c 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -23,10 +23,9 @@ from typing import Any import numpy as np import torch +from lerobot.types import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey from lerobot.utils.constants import ACTION, DONE, INFO, OBS_PREFIX, REWARD, TRUNCATED -from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey - @singledispatch def to_tensor( diff --git a/src/lerobot/processor/delta_action_processor.py b/src/lerobot/processor/delta_action_processor.py index a8395637c..f7f5676ac 100644 --- a/src/lerobot/processor/delta_action_processor.py +++ b/src/lerobot/processor/delta_action_processor.py @@ -17,8 +17,8 @@ from dataclasses import dataclass from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.types import PolicyAction, RobotAction -from .core import PolicyAction, RobotAction from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index 2d0dd0880..36c80e58e 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -25,9 +25,9 @@ from typing import Any import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.utils.utils import get_safe_torch_device +from lerobot.types import EnvTransition, PolicyAction, TransitionKey +from lerobot.utils.device_utils import get_safe_torch_device -from .core import EnvTransition, PolicyAction, TransitionKey from .pipeline import ProcessorStep, ProcessorStepRegistry diff --git a/src/lerobot/processor/factory.py b/src/lerobot/processor/factory.py index 5a0c41072..5028122f1 100644 --- a/src/lerobot/processor/factory.py +++ b/src/lerobot/processor/factory.py @@ -14,13 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lerobot.types import RobotAction, RobotObservation + from .converters import ( observation_to_transition, robot_action_observation_to_transition, transition_to_observation, transition_to_robot_action, ) -from .core import RobotAction, RobotObservation from .pipeline import IdentityProcessorStep, RobotProcessorPipeline diff --git a/src/lerobot/processor/gym_action_processor.py b/src/lerobot/processor/gym_action_processor.py index 4f225af92..e756ded7f 100644 --- a/src/lerobot/processor/gym_action_processor.py +++ b/src/lerobot/processor/gym_action_processor.py @@ -17,9 +17,9 @@ from dataclasses import dataclass from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.types import EnvAction, EnvTransition, PolicyAction from .converters import to_tensor -from .core import EnvAction, EnvTransition, PolicyAction from .hil_processor import TELEOP_ACTION_KEY from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry @@ -75,7 +75,7 @@ class Numpy2TorchActionProcessorStep(ProcessorStep): def __call__(self, transition: EnvTransition) -> EnvTransition: """Converts numpy action to torch tensor if action exists, otherwise passes through.""" - from .core import TransitionKey + from lerobot.types import TransitionKey self._current_transition = transition.copy() new_transition = self._current_transition diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index 34eaeed51..0b8521c2b 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -30,7 +30,8 @@ from lerobot.teleoperators.utils import TeleopEvents if TYPE_CHECKING: from lerobot.teleoperators.teleoperator import Teleoperator -from .core import EnvTransition, PolicyAction, TransitionKey +from lerobot.types import EnvTransition, PolicyAction, TransitionKey + from .pipeline import ( ComplementaryDataProcessorStep, InfoProcessorStep, diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 4769b91ac..8a7a1176a 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -26,10 +26,10 @@ from torch import Tensor from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.types import EnvTransition, PolicyAction, TransitionKey from lerobot.utils.constants import ACTION from .converters import from_tensor_to_numpy, to_tensor -from .core import EnvTransition, PolicyAction, TransitionKey from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry, RobotObservation diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index db1c3015c..abfb31421 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -46,10 +46,10 @@ from huggingface_hub import hf_hub_download from safetensors.torch import load_file, save_file from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.types import EnvAction, EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey from lerobot.utils.hub import HubMixin from .converters import batch_to_transition, create_transition, transition_to_batch -from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey # Generic type variables for pipeline input and output. TInput = TypeVar("TInput") diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index da6e600af..2a972ecc8 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -30,6 +30,7 @@ from typing import TYPE_CHECKING, Any import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.types import EnvTransition, RobotObservation, TransitionKey from lerobot.utils.constants import ( ACTION_TOKEN_MASK, ACTION_TOKENS, @@ -40,7 +41,6 @@ from lerobot.utils.constants import ( ) from lerobot.utils.import_utils import _transformers_available -from .core import EnvTransition, RobotObservation, TransitionKey from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry # Conditional import for type checking and lazy loading diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 7427633d2..18c0ca1ea 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -62,7 +62,6 @@ from lerobot.configs import parser from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.policies.factory import make_policy from lerobot.policies.sac.modeling_sac import SACPolicy -from lerobot.processor import TransitionKey from lerobot.rl.process import ProcessSignalHandler from lerobot.rl.queue import get_last_item_from_queue from lerobot.robots import so_follower # noqa: F401 @@ -77,6 +76,8 @@ from lerobot.transport.utils import ( send_bytes_in_chunks, transitions_to_bytes, ) +from lerobot.types import TransitionKey +from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.random_utils import set_seed from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.transition import ( @@ -86,7 +87,6 @@ from lerobot.utils.transition import ( ) from lerobot.utils.utils import ( TimerManager, - get_safe_torch_device, init_logging, ) diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index ee09ac9ac..2853fbcb3 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -86,6 +86,7 @@ from lerobot.utils.constants import ( PRETRAINED_MODEL_DIR, TRAINING_STATE_DIR, ) +from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.random_utils import set_seed from lerobot.utils.train_utils import ( get_step_checkpoint_dir, @@ -96,7 +97,6 @@ from lerobot.utils.train_utils import ( from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device from lerobot.utils.utils import ( format_big_number, - get_safe_torch_device, init_logging, ) diff --git a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py index 2e3885e67..7f5e92271 100644 --- a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py +++ b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py @@ -17,8 +17,8 @@ import logging from functools import cached_property -from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/bi_so_follower/bi_so_follower.py b/src/lerobot/robots/bi_so_follower/bi_so_follower.py index 28c58b898..ba1826e29 100644 --- a/src/lerobot/robots/bi_so_follower/bi_so_follower.py +++ b/src/lerobot/robots/bi_so_follower/bi_so_follower.py @@ -17,8 +17,8 @@ import logging from functools import cached_property -from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py b/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py index cdf6efde1..299206a1e 100644 --- a/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py +++ b/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py @@ -23,7 +23,7 @@ import cv2 import numpy as np import requests -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.errors import DeviceNotConnectedError diff --git a/src/lerobot/robots/hope_jr/hope_jr_arm.py b/src/lerobot/robots/hope_jr/hope_jr_arm.py index e8269ae46..7f6492ef0 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_arm.py +++ b/src/lerobot/robots/hope_jr/hope_jr_arm.py @@ -24,7 +24,7 @@ from lerobot.motors.calibration_gui import RangeFinderGUI from lerobot.motors.feetech import ( FeetechMotorsBus, ) -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/hope_jr/hope_jr_hand.py b/src/lerobot/robots/hope_jr/hope_jr_hand.py index a05c4bbcb..784804836 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_hand.py +++ b/src/lerobot/robots/hope_jr/hope_jr_hand.py @@ -24,7 +24,7 @@ from lerobot.motors.calibration_gui import RangeFinderGUI from lerobot.motors.feetech import ( FeetechMotorsBus, ) -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/koch_follower/koch_follower.py b/src/lerobot/robots/koch_follower/koch_follower.py index 53a32beed..44e83f6a3 100644 --- a/src/lerobot/robots/koch_follower/koch_follower.py +++ b/src/lerobot/robots/koch_follower/koch_follower.py @@ -24,7 +24,7 @@ from lerobot.motors.dynamixel import ( DynamixelMotorsBus, OperatingMode, ) -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/lekiwi/lekiwi.py b/src/lerobot/robots/lekiwi/lekiwi.py index 9d11a000f..60fac89e5 100644 --- a/src/lerobot/robots/lekiwi/lekiwi.py +++ b/src/lerobot/robots/lekiwi/lekiwi.py @@ -28,7 +28,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/lekiwi/lekiwi_client.py b/src/lerobot/robots/lekiwi/lekiwi_client.py index 1d5ea64a6..fd43e84fe 100644 --- a/src/lerobot/robots/lekiwi/lekiwi_client.py +++ b/src/lerobot/robots/lekiwi/lekiwi_client.py @@ -22,7 +22,7 @@ from functools import cached_property import cv2 import numpy as np -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_STATE from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.errors import DeviceNotConnectedError diff --git a/src/lerobot/robots/omx_follower/omx_follower.py b/src/lerobot/robots/omx_follower/omx_follower.py index e0b612c60..5d161daa2 100644 --- a/src/lerobot/robots/omx_follower/omx_follower.py +++ b/src/lerobot/robots/omx_follower/omx_follower.py @@ -25,7 +25,7 @@ from lerobot.motors.dynamixel import ( DynamixelMotorsBus, OperatingMode, ) -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/openarm_follower/openarm_follower.py b/src/lerobot/robots/openarm_follower/openarm_follower.py index c865f1ec1..99e8b920b 100644 --- a/src/lerobot/robots/openarm_follower/openarm_follower.py +++ b/src/lerobot/robots/openarm_follower/openarm_follower.py @@ -22,7 +22,7 @@ from typing import Any from lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.damiao import DamiaoMotorsBus -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/reachy2/robot_reachy2.py b/src/lerobot/robots/reachy2/robot_reachy2.py index fb466f85b..5227a096a 100644 --- a/src/lerobot/robots/reachy2/robot_reachy2.py +++ b/src/lerobot/robots/reachy2/robot_reachy2.py @@ -19,7 +19,7 @@ import time from typing import TYPE_CHECKING, Any from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.import_utils import _reachy2_sdk_available from ..robot import Robot diff --git a/src/lerobot/robots/robot.py b/src/lerobot/robots/robot.py index d165886b9..1b556f963 100644 --- a/src/lerobot/robots/robot.py +++ b/src/lerobot/robots/robot.py @@ -19,7 +19,7 @@ from pathlib import Path import draccus from lerobot.motors import MotorCalibration -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, ROBOTS from .config import RobotConfig diff --git a/src/lerobot/robots/so_follower/so_follower.py b/src/lerobot/robots/so_follower/so_follower.py index c898e9137..ca132d102 100644 --- a/src/lerobot/robots/so_follower/so_follower.py +++ b/src/lerobot/robots/so_follower/so_follower.py @@ -24,7 +24,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index 41146ebe6..9e373c05f 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -26,8 +26,6 @@ from typing import TYPE_CHECKING, Protocol, runtime_checkable import numpy as np from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.envs.factory import make_env -from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.unitree_g1.g1_kinematics import G1_29_ArmIK from lerobot.robots.unitree_g1.g1_utils import ( REMOTE_AXES, @@ -37,6 +35,7 @@ from lerobot.robots.unitree_g1.g1_utils import ( default_remote_input, make_locomotion_controller, ) +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.import_utils import _unitree_sdk_available from ..robot import Robot @@ -291,6 +290,8 @@ class UnitreeG1(Robot): def connect(self, calibrate: bool = True) -> None: # connect to DDS # Initialize DDS channel and simulation environment if self.config.is_simulation: + from lerobot.envs.factory import make_env + self._ChannelFactoryInitialize(0, "lo") self._env_wrapper = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True) # Extract the actual gym env from the dict structure diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index e32b80404..6d814f498 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -80,13 +80,14 @@ from lerobot.envs.utils import ( ) from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.processor import PolicyAction, PolicyProcessorPipeline +from lerobot.processor import PolicyProcessorPipeline +from lerobot.types import PolicyAction from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD +from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.io_utils import write_video from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( - get_safe_torch_device, init_logging, inside_slurm, ) diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index dc682fe6f..345d18f23 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -139,10 +139,10 @@ from lerobot.utils.control_utils import ( sanity_check_dataset_name, sanity_check_dataset_robot_compatibility, ) +from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import ( - get_safe_torch_device, init_logging, log_say, ) diff --git a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py index 74b0c9b83..b44f1fbea 100644 --- a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py +++ b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py @@ -17,8 +17,8 @@ import logging from functools import cached_property -from lerobot.processor import RobotAction from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig +from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..openarm_leader import OpenArmLeader diff --git a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py index 69cb0f971..8c1796e45 100644 --- a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py +++ b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py @@ -20,7 +20,7 @@ from typing import Any import numpy as np -from lerobot.processor import RobotAction +from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_not_connected from ..teleoperator import Teleoperator diff --git a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py index 919f463d3..6c1ef7492 100644 --- a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py +++ b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py @@ -21,7 +21,7 @@ import time from queue import Queue from typing import Any -from lerobot.processor import RobotAction +from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator diff --git a/src/lerobot/teleoperators/openarm_leader/openarm_leader.py b/src/lerobot/teleoperators/openarm_leader/openarm_leader.py index d9eaabe0f..65da7416a 100644 --- a/src/lerobot/teleoperators/openarm_leader/openarm_leader.py +++ b/src/lerobot/teleoperators/openarm_leader/openarm_leader.py @@ -20,7 +20,7 @@ from typing import Any from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.damiao import DamiaoMotorsBus -from lerobot.processor import RobotAction +from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator diff --git a/src/lerobot/teleoperators/openarm_mini/openarm_mini.py b/src/lerobot/teleoperators/openarm_mini/openarm_mini.py index 3fbcecf24..23594caa9 100644 --- a/src/lerobot/teleoperators/openarm_mini/openarm_mini.py +++ b/src/lerobot/teleoperators/openarm_mini/openarm_mini.py @@ -23,7 +23,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) -from lerobot.processor import RobotAction +from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator diff --git a/src/lerobot/teleoperators/phone/phone_processor.py b/src/lerobot/teleoperators/phone/phone_processor.py index 67e64c7d5..c498bed7d 100644 --- a/src/lerobot/teleoperators/phone/phone_processor.py +++ b/src/lerobot/teleoperators/phone/phone_processor.py @@ -17,8 +17,9 @@ from dataclasses import dataclass, field from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature -from lerobot.processor import ProcessorStepRegistry, RobotAction, RobotActionProcessorStep +from lerobot.processor import ProcessorStepRegistry, RobotActionProcessorStep from lerobot.teleoperators.phone.config_phone import PhoneOS +from lerobot.types import RobotAction @ProcessorStepRegistry.register("map_phone_action_to_robot_action") diff --git a/src/lerobot/teleoperators/teleoperator.py b/src/lerobot/teleoperators/teleoperator.py index 847b88b7f..f47904423 100644 --- a/src/lerobot/teleoperators/teleoperator.py +++ b/src/lerobot/teleoperators/teleoperator.py @@ -20,7 +20,7 @@ from typing import Any import draccus from lerobot.motors.motors_bus import MotorCalibration -from lerobot.processor import RobotAction +from lerobot.types import RobotAction from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS from .config import TeleoperatorConfig diff --git a/src/lerobot/processor/core.py b/src/lerobot/types.py similarity index 100% rename from src/lerobot/processor/core.py rename to src/lerobot/types.py diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/utils/control_utils.py index 7c605af17..94cd82fa1 100644 --- a/src/lerobot/utils/control_utils.py +++ b/src/lerobot/utils/control_utils.py @@ -32,8 +32,9 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import DEFAULT_FEATURES from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import prepare_observation_for_inference -from lerobot.processor import PolicyAction, PolicyProcessorPipeline +from lerobot.processor import PolicyProcessorPipeline from lerobot.robots import Robot +from lerobot.types import PolicyAction @cache diff --git a/src/lerobot/utils/device_utils.py b/src/lerobot/utils/device_utils.py new file mode 100644 index 000000000..37981f07f --- /dev/null +++ b/src/lerobot/utils/device_utils.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import torch + + +def auto_select_torch_device() -> torch.device: + """Tries to select automatically a torch device.""" + if torch.cuda.is_available(): + logging.info("Cuda backend detected, using cuda.") + return torch.device("cuda") + elif torch.backends.mps.is_available(): + logging.info("Metal backend detected, using mps.") + return torch.device("mps") + elif torch.xpu.is_available(): + logging.info("Intel XPU backend detected, using xpu.") + return torch.device("xpu") + else: + logging.warning("No accelerated backend detected. Using default cpu, this will be slow.") + return torch.device("cpu") + + +# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level +def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device: + """Given a string, return a torch.device with checks on whether the device is available.""" + try_device = str(try_device) + if try_device.startswith("cuda"): + assert torch.cuda.is_available() + device = torch.device(try_device) + elif try_device == "mps": + assert torch.backends.mps.is_available() + device = torch.device("mps") + elif try_device == "xpu": + assert torch.xpu.is_available() + device = torch.device("xpu") + elif try_device == "cpu": + device = torch.device("cpu") + if log: + logging.warning("Using CPU, this will be slow.") + else: + device = torch.device(try_device) + if log: + logging.warning(f"Using custom {try_device} device.") + return device + + +def get_safe_dtype(dtype: torch.dtype, device: str | torch.device): + """ + mps is currently not compatible with float64 + """ + if isinstance(device, torch.device): + device = device.type + if device == "mps" and dtype == torch.float64: + return torch.float32 + if device == "xpu" and dtype == torch.float64: + if hasattr(torch.xpu, "get_device_capability"): + device_capability = torch.xpu.get_device_capability() + # NOTE: Some Intel XPU devices do not support double precision (FP64). + # The `has_fp64` flag is returned by `torch.xpu.get_device_capability()` + # when available; if False, we fall back to float32 for compatibility. + if not device_capability.get("has_fp64", False): + logging.warning(f"Device {device} does not support float64, using float32 instead.") + return torch.float32 + else: + logging.warning( + f"Device {device} capability check failed. Assuming no support for float64, using float32 instead." + ) + return torch.float32 + return dtype + else: + return dtype + + +def is_torch_device_available(try_device: str) -> bool: + try_device = str(try_device) # Ensure try_device is a string + if try_device.startswith("cuda"): + return torch.cuda.is_available() + elif try_device == "mps": + return torch.backends.mps.is_available() + elif try_device == "xpu": + return torch.xpu.is_available() + elif try_device == "cpu": + return True + else: + raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps, xpu or cpu.") + + +def is_amp_available(device: str): + if device in ["cuda", "xpu", "cpu"]: + return True + elif device == "mps": + return False + else: + raise ValueError(f"Unknown device '{device}.") diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index c7ad2bbdb..b9f8441d6 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -13,6 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging import os import platform @@ -24,11 +26,12 @@ from copy import copy, deepcopy from datetime import datetime from pathlib import Path from statistics import mean +from typing import TYPE_CHECKING import numpy as np -import torch -from accelerate import Accelerator -from datasets.utils.logging import disable_progress_bar, enable_progress_bar + +if TYPE_CHECKING: + from accelerate import Accelerator def inside_slurm(): @@ -37,96 +40,6 @@ def inside_slurm(): return "SLURM_JOB_ID" in os.environ -def auto_select_torch_device() -> torch.device: - """Tries to select automatically a torch device.""" - if torch.cuda.is_available(): - logging.info("Cuda backend detected, using cuda.") - return torch.device("cuda") - elif torch.backends.mps.is_available(): - logging.info("Metal backend detected, using mps.") - return torch.device("mps") - elif torch.xpu.is_available(): - logging.info("Intel XPU backend detected, using xpu.") - return torch.device("xpu") - else: - logging.warning("No accelerated backend detected. Using default cpu, this will be slow.") - return torch.device("cpu") - - -# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level -def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device: - """Given a string, return a torch.device with checks on whether the device is available.""" - try_device = str(try_device) - if try_device.startswith("cuda"): - assert torch.cuda.is_available() - device = torch.device(try_device) - elif try_device == "mps": - assert torch.backends.mps.is_available() - device = torch.device("mps") - elif try_device == "xpu": - assert torch.xpu.is_available() - device = torch.device("xpu") - elif try_device == "cpu": - device = torch.device("cpu") - if log: - logging.warning("Using CPU, this will be slow.") - else: - device = torch.device(try_device) - if log: - logging.warning(f"Using custom {try_device} device.") - return device - - -def get_safe_dtype(dtype: torch.dtype, device: str | torch.device): - """ - mps is currently not compatible with float64 - """ - if isinstance(device, torch.device): - device = device.type - if device == "mps" and dtype == torch.float64: - return torch.float32 - if device == "xpu" and dtype == torch.float64: - if hasattr(torch.xpu, "get_device_capability"): - device_capability = torch.xpu.get_device_capability() - # NOTE: Some Intel XPU devices do not support double precision (FP64). - # The `has_fp64` flag is returned by `torch.xpu.get_device_capability()` - # when available; if False, we fall back to float32 for compatibility. - if not device_capability.get("has_fp64", False): - logging.warning(f"Device {device} does not support float64, using float32 instead.") - return torch.float32 - else: - logging.warning( - f"Device {device} capability check failed. Assuming no support for float64, using float32 instead." - ) - return torch.float32 - return dtype - else: - return dtype - - -def is_torch_device_available(try_device: str) -> bool: - try_device = str(try_device) # Ensure try_device is a string - if try_device.startswith("cuda"): - return torch.cuda.is_available() - elif try_device == "mps": - return torch.backends.mps.is_available() - elif try_device == "xpu": - return torch.xpu.is_available() - elif try_device == "cpu": - return True - else: - raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps, xpu or cpu.") - - -def is_amp_available(device: str): - if device in ["cuda", "xpu", "cpu"]: - return True - elif device == "mps": - return False - else: - raise ValueError(f"Unknown device '{device}.") - - def init_logging( log_file: Path | None = None, display_pid: bool = False, @@ -297,9 +210,13 @@ class SuppressProgressBars: """ def __enter__(self): + from datasets.utils.logging import disable_progress_bar + disable_progress_bar() def __exit__(self, exc_type, exc_val, exc_tb): + from datasets.utils.logging import enable_progress_bar + enable_progress_bar() diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index 31ca8d247..782358c9e 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -18,7 +18,7 @@ import os import numpy as np import rerun as rr -from lerobot.processor import RobotAction, RobotObservation +from lerobot.types import RobotAction, RobotObservation from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR diff --git a/tests/mocks/mock_robot.py b/tests/mocks/mock_robot.py index f69a2c02a..5504b30bf 100644 --- a/tests/mocks/mock_robot.py +++ b/tests/mocks/mock_robot.py @@ -20,8 +20,8 @@ from functools import cached_property from lerobot.cameras import CameraConfig, make_cameras_from_configs from lerobot.motors.motors_bus import Motor, MotorNormMode -from lerobot.processor import RobotAction, RobotObservation from lerobot.robots import Robot, RobotConfig +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from tests.mocks.mock_motors_bus import MockMotorsBus diff --git a/tests/mocks/mock_teleop.py b/tests/mocks/mock_teleop.py index 89174dadf..b84b2b891 100644 --- a/tests/mocks/mock_teleop.py +++ b/tests/mocks/mock_teleop.py @@ -19,8 +19,8 @@ from dataclasses import dataclass from functools import cached_property from typing import Any -from lerobot.processor import RobotAction from lerobot.teleoperators import Teleoperator, TeleoperatorConfig +from lerobot.types import RobotAction from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected diff --git a/tests/policies/groot/test_groot_lerobot.py b/tests/policies/groot/test_groot_lerobot.py index 760f13a5f..e299a34e2 100644 --- a/tests/policies/groot/test_groot_lerobot.py +++ b/tests/policies/groot/test_groot_lerobot.py @@ -28,8 +28,9 @@ import torch from lerobot.policies.groot.configuration_groot import GrootConfig from lerobot.policies.groot.modeling_groot import GrootPolicy from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors -from lerobot.processor import PolicyAction, PolicyProcessorPipeline -from lerobot.utils.utils import auto_select_torch_device +from lerobot.processor import PolicyProcessorPipeline +from lerobot.types import PolicyAction +from lerobot.utils.device_utils import auto_select_torch_device from tests.utils import require_cuda # noqa: E402 pytest.importorskip("transformers") diff --git a/tests/policies/groot/test_groot_vs_original.py b/tests/policies/groot/test_groot_vs_original.py index e9dd1df00..0adad96ca 100644 --- a/tests/policies/groot/test_groot_vs_original.py +++ b/tests/policies/groot/test_groot_vs_original.py @@ -28,7 +28,8 @@ import torch from lerobot.policies.groot.configuration_groot import GrootConfig from lerobot.policies.groot.modeling_groot import GrootPolicy from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors -from lerobot.processor import PolicyAction, PolicyProcessorPipeline +from lerobot.processor import PolicyProcessorPipeline +from lerobot.types import PolicyAction pytest.importorskip("gr00t") pytest.importorskip("transformers") diff --git a/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py b/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py index d24bb11d7..b757d5a94 100644 --- a/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py +++ b/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py @@ -31,7 +31,8 @@ pytest.importorskip("scipy") from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig from lerobot.policies.pi0_fast.modeling_pi0_fast import PI0FastPolicy from lerobot.policies.pi0_fast.processor_pi0_fast import make_pi0_fast_pre_post_processors -from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402 +from lerobot.processor import PolicyProcessorPipeline # noqa: E402 +from lerobot.types import PolicyAction # noqa: E402 from lerobot.utils.constants import ( ACTION_TOKEN_MASK, ACTION_TOKENS, diff --git a/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py b/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py index f70707262..a965132b0 100644 --- a/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py +++ b/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py @@ -42,7 +42,8 @@ from transformers import AutoTokenizer # noqa: E402 from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402 from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402 -from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402 +from lerobot.processor import PolicyProcessorPipeline # noqa: E402 +from lerobot.types import PolicyAction # noqa: E402 # TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG DUMMY_ACTION_DIM = 32 diff --git a/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py b/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py index d3d1c1908..62e34b70d 100644 --- a/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py +++ b/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py @@ -41,7 +41,8 @@ from transformers import AutoTokenizer # noqa: E402 from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402 from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402 -from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402 +from lerobot.processor import PolicyProcessorPipeline # noqa: E402 +from lerobot.types import PolicyAction # noqa: E402 # TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG DUMMY_ACTION_DIM = 32 diff --git a/tests/policies/test_sarm_processor.py b/tests/policies/test_sarm_processor.py index 66404f663..5b90784a6 100644 --- a/tests/policies/test_sarm_processor.py +++ b/tests/policies/test_sarm_processor.py @@ -25,7 +25,7 @@ import pandas as pd import pytest import torch -from lerobot.processor.core import TransitionKey +from lerobot.types import TransitionKey class MockDatasetMeta: diff --git a/tests/policies/xvla/test_xvla_original_vs_lerobot.py b/tests/policies/xvla/test_xvla_original_vs_lerobot.py index e36d14d01..3cea11329 100644 --- a/tests/policies/xvla/test_xvla_original_vs_lerobot.py +++ b/tests/policies/xvla/test_xvla_original_vs_lerobot.py @@ -30,7 +30,8 @@ pytest.importorskip("transformers") from lerobot.policies.xvla.configuration_xvla import XVLAConfig from lerobot.policies.xvla.modeling_xvla import XVLAPolicy from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors -from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402 +from lerobot.processor import PolicyProcessorPipeline # noqa: E402 +from lerobot.types import PolicyAction # noqa: E402 from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402 from tests.utils import require_cuda # noqa: E402 diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py index 477381618..d589b6c5e 100644 --- a/tests/processor/test_batch_conversion.py +++ b/tests/processor/test_batch_conversion.py @@ -16,8 +16,9 @@ import torch -from lerobot.processor import DataProcessorPipeline, TransitionKey +from lerobot.processor import DataProcessorPipeline from lerobot.processor.converters import batch_to_transition, transition_to_batch +from lerobot.types import TransitionKey from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_PREFIX, OBS_STATE, REWARD, TRUNCATED diff --git a/tests/processor/test_converters.py b/tests/processor/test_converters.py index 47a6eea18..91afdd0e5 100644 --- a/tests/processor/test_converters.py +++ b/tests/processor/test_converters.py @@ -18,13 +18,13 @@ import numpy as np import pytest import torch -from lerobot.processor import TransitionKey from lerobot.processor.converters import ( batch_to_transition, create_transition, to_tensor, transition_to_batch, ) +from lerobot.types import TransitionKey from lerobot.utils.constants import ACTION, DONE, OBS_STATE, OBS_STR, REWARD diff --git a/tests/processor/test_device_processor.py b/tests/processor/test_device_processor.py index bb7d467bf..57b923076 100644 --- a/tests/processor/test_device_processor.py +++ b/tests/processor/test_device_processor.py @@ -19,8 +19,9 @@ import pytest import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature -from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey +from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep from lerobot.processor.converters import create_transition, identity_transition +from lerobot.types import TransitionKey from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 208a6b5c5..cd5c75005 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -30,7 +30,7 @@ from lerobot.processor import ( ) from lerobot.processor.converters import create_transition, identity_transition, to_tensor from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE, OBS_STR -from lerobot.utils.utils import auto_select_torch_device +from lerobot.utils.device_utils import auto_select_torch_device def test_numpy_conversion(): diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py index 11b58a66c..923059210 100644 --- a/tests/processor/test_observation_processor.py +++ b/tests/processor/test_observation_processor.py @@ -19,8 +19,9 @@ import pytest import torch from lerobot.configs.types import FeatureType, PipelineFeatureType -from lerobot.processor import TransitionKey, VanillaObservationProcessorStep +from lerobot.processor import VanillaObservationProcessorStep from lerobot.processor.converters import create_transition +from lerobot.types import TransitionKey from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from tests.conftest import assert_contract_is_typed diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index 64cc8aac8..2f1c4cc9c 100644 --- a/tests/processor/test_tokenizer_processor.py +++ b/tests/processor/test_tokenizer_processor.py @@ -25,8 +25,9 @@ import pytest import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature -from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey +from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep from lerobot.processor.converters import create_transition, identity_transition +from lerobot.types import TransitionKey from lerobot.utils.constants import ( ACTION, OBS_IMAGE, diff --git a/tests/training/test_visual_validation.py b/tests/training/test_visual_validation.py index af693fe5e..89351e3c2 100644 --- a/tests/training/test_visual_validation.py +++ b/tests/training/test_visual_validation.py @@ -37,7 +37,7 @@ from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.policies.factory import make_policy_config from lerobot.scripts.lerobot_train import train -from lerobot.utils.utils import auto_select_torch_device +from lerobot.utils.device_utils import auto_select_torch_device pytest.importorskip("transformers") diff --git a/tests/utils.py b/tests/utils.py index a77082ea9..33c554804 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,8 +21,8 @@ import pytest import torch from lerobot import available_cameras, available_motors, available_robots +from lerobot.utils.device_utils import auto_select_torch_device from lerobot.utils.import_utils import is_package_available -from lerobot.utils.utils import auto_select_torch_device DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", str(auto_select_torch_device())) diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py index 408f636cb..c8e5a92a8 100644 --- a/tests/utils/test_visualization_utils.py +++ b/tests/utils/test_visualization_utils.py @@ -21,7 +21,7 @@ from types import SimpleNamespace import numpy as np import pytest -from lerobot.processor import TransitionKey +from lerobot.types import TransitionKey from lerobot.utils.constants import OBS_STATE