lerobot-clone/src/lerobot/utils/import_utils.py
Khalil Meftah e963e5a0c4 RL stack refactoring (#3075)
* refactor: RL stack refactoring — RLAlgorithm, RLTrainer, DataMixer, and SAC restructuring

* chore: clarify torch.compile disabled note in SACAlgorithm

* fix(teleop): keyboard EE teleop not registering special keys and losing intervention state

Fixes #2345

Co-authored-by: jpizarrom <jpizarrom@gmail.com>

* fix: remove leftover normalization calls from reward classifier predict_reward

Fixes #2355

* fix: add thread synchronization to ReplayBuffer to prevent race condition between add() and sample()

* refactor: update SACAlgorithm to pass action_dim to _init_critics and fix encoder reference

* perf: remove redundant CPU→GPU→CPU transition move in learner

* Fix: add kwargs in reward classifier __init__()

* fix: include IS_INTERVENTION in complementary_info sent to learner for offline replay buffer

* fix: add try/finally to control_loop to ensure image writer cleanup on exit

* fix: use string key for IS_INTERVENTION in complementary_info to avoid torch.load serialization error

* fix: skip tests that require grpc if not available

* fix(tests): ensure tensor stats comparison accounts for reshaping in normalization tests

* fix(tests): skip tests that require grpc if not available

* refactor(rl): expose public API in rl/__init__ and use relative imports in sub-packages

* fix(config): update vision encoder model name to lerobot/resnet10

* fix(sac): clarify torch.compile status

* refactor(rl): update shutdown_event type hints from 'any' to 'Any' for consistency and clarity

* refactor(sac): simplify optimizer return structure

* perf(rl): use async iterators in OnlineOfflineMixer.get_iterator

* refactor(sac): decouple algorithm hyperparameters from policy config

* update losses names in tests

* fix docstring

* remove unused type alias

* fix test for flat dict structure

* refactor(policies): rename policies/sac → policies/gaussian_actor

* refactor(rl/sac): consolidate hyperparameter ownership and clean up discrete critic

* perf(observation_processor): add CUDA support for image processing

* fix(rl): correctly wire HIL-SERL gripper penalty through processor pipeline

(cherry picked from commit 9c2af818ff)

* fix(rl): add time limit processor to environment pipeline

(cherry picked from commit cd105f65cb)

* fix(rl): clarify discrete gripper action mapping in GripperVelocityToJoint for SO100

(cherry picked from commit 494f469a2b)

* fix(rl): update neutral gripper action

(cherry picked from commit 9c9064e5be)

* fix(rl): merge environment and action-processor info in transition processing

(cherry picked from commit 30e1886b64)

* fix(rl): mirror gym_manipulator in actor

(cherry picked from commit d2a046dfc5)

* fix(rl): postprocess action in actor

(cherry picked from commit c2556439e5)

* fix(rl): improve action processing for discrete and continuous actions

(cherry picked from commit f887ab3f6a)

* fix(rl): enhance intervention handling in actor and learner

(cherry picked from commit ef8bfffbd7)

* Revert "perf(observation_processor): add CUDA support for image processing"

This reverts commit 38b88c414c.

* refactor(rl): make algorithm a nested config so all SAC hyperparameters are JSON-addressable

* refactor(rl): add make_algorithm_config function for RLAlgorithmConfig instantiation

* refactor(rl): add type property to RLAlgorithmConfig for better clarity

* refactor(rl): make RLAlgorithmConfig an abstract base class for better extensibility

* refactor(tests): remove grpc import checks from test files for cleaner code

* fix(tests): gate RL tests on the `datasets` extra

* refactor: simplify docstrings for clarity and conciseness across multiple files

* fix(rl): update gripper position key and handle action absence during reset

* fix(rl): record pre-step observation so (obs, action, next.reward) align in gym_manipulator dataset

* refactor: clean up import statements

* chore: address reviewer comments

* chore: improve visual stats reshaping logic and update docstring for clarity

* refactor: enforce mandatory config_class and name attributes in RLAlgorithm

* refactor: implement NotImplementedError for abstract methods in RLAlgorithm and DataMixer

* refactor: replace build_algorithm with make_algorithm for SACAlgorithmConfig and update related tests

* refactor: add require_package calls for grpcio and gym-hil in relevant modules

* refactor(rl): move grpcio guards to runtime entry points

* feat(rl): consolidate HIL-SERL checkpoint into HF-style components

Make `RLAlgorithmConfig` and `RLAlgorithm` `HubMixin`s, add abstract
`state_dict()` / `load_state_dict()` for critic ensemble, target nets
and `log_alpha`, and persist them as a sibling `algorithm/` component
next to `pretrained_model/`. Replace the pickled `training_state.pt`
with an enriched `training_step.json` carrying `step` and
`interaction_step`, so resume restores actor + critics + target nets +
temperature + optimizers + RNG + counters from HF-standard files.

* refactor(rl): move actor weight-sync wire format from policy to algorithm

* refactor(rl): update type hints for learner and actor functions

* refactor(rl): hoist grpcio guard to module top in actor/learner

* chore(rl): manage import pattern in actor (#3564)

* chore(rl): manage import pattern in actor

* chore(rl): optional grpc imports in learner; quote grpc ServicerContext types

---------

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>

* update uv.lock

* chore(doc): update doc

---------

Co-authored-by: jpizarrom <jpizarrom@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-05-12 15:49:54 +02:00

235 lines
9.7 KiB
Python

#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import importlib.metadata
import importlib.util  # used below via importlib.util.find_spec
import logging
from typing import Any

from draccus.choice_types import ChoiceRegistry


def is_package_available(
    pkg_name: str, import_name: str | None = None, return_version: bool = False
) -> tuple[bool, str] | bool:
    """
    Check if the package spec exists and grab its version to avoid importing a local directory.

    Args:
        pkg_name: The name of the package as installed via pip (e.g. "python-can").
        import_name: The actual name used to import the package (e.g. "can").
            Defaults to pkg_name if not provided.
        return_version: Whether to return the version string.
    """
    if import_name is None:
        import_name = pkg_name

    # Check if the module spec exists using the import name
    package_exists = importlib.util.find_spec(import_name) is not None
    package_version = "N/A"
    if package_exists:
        try:
            # Primary method to get the package version
            package_version = importlib.metadata.version(pkg_name)
        except importlib.metadata.PackageNotFoundError:
            # Fallback method: Only for "torch" and versions containing "dev"
            if pkg_name == "torch":
                try:
                    package = importlib.import_module(import_name)
                    temp_version = getattr(package, "__version__", "N/A")
                    # Check if the version contains "dev"
                    if "dev" in temp_version:
                        package_version = temp_version
                        package_exists = True
                    else:
                        package_exists = False
                except ImportError:
                    # If the package can't be imported, it's not available
                    package_exists = False
            else:
                # For packages other than "torch", don't attempt the fallback and set as not available
                package_exists = False
        logging.debug(f"Detected {pkg_name} version: {package_version}")
    if return_version:
        return package_exists, package_version
    else:
        return package_exists


def get_safe_default_codec():
    logger = logging.getLogger(__name__)
    if importlib.util.find_spec("torchcodec"):
        return "torchcodec"
    else:
        logger.warning(
            "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
        )
        return "pyav"


_require_package_cache: dict[str, bool] = {}


def require_package(pkg_name: str, extra: str, import_name: str | None = None) -> None:
    """Raise an informative ImportError if a package required by an optional feature is missing."""
    cache_key = import_name or pkg_name
    if cache_key not in _require_package_cache:
        _require_package_cache[cache_key] = is_package_available(pkg_name, import_name)
    if not _require_package_cache[cache_key]:
        raise ImportError(
            f"'{pkg_name}' is required but not installed. Install it with: "
            f"pip install 'lerobot[{extra}]' (or uv pip install 'lerobot[{extra}]')"
        )


# ── Centralised availability flags ────────────────────────────────────────
# Every optional-dependency check lives here so that the rest of the codebase
# can simply ``from lerobot.utils.import_utils import _foo_available``.
# Do NOT define ad-hoc ``is_package_available(...)`` calls in other modules.

# ML / training
_transformers_available = is_package_available("transformers")
_peft_available = is_package_available("peft")
_scipy_available = is_package_available("scipy")
_diffusers_available = is_package_available("diffusers")
_torchdiffeq_available = is_package_available("torchdiffeq")

# Hardware SDKs
_serial_available = is_package_available("pyserial", import_name="serial")
_deepdiff_available = is_package_available("deepdiff")
_dynamixel_sdk_available = is_package_available("dynamixel-sdk", import_name="dynamixel_sdk")
_feetech_sdk_available = is_package_available("feetech-servo-sdk", import_name="scservo_sdk")
_reachy2_sdk_available = is_package_available("reachy2_sdk")
_can_available = is_package_available("python-can", "can")
_unitree_sdk_available = is_package_available("unitree-sdk2py", "unitree_sdk2py")
_pyrealsense2_available = is_package_available("pyrealsense2") or is_package_available(
    "pyrealsense2-macosx", import_name="pyrealsense2"
)
_zmq_available = is_package_available("pyzmq", import_name="zmq")
_hebi_available = is_package_available("hebi-py", import_name="hebi")
_teleop_available = is_package_available("teleop")
_placo_available = is_package_available("placo")
_hidapi_available = is_package_available("hidapi", import_name="hid")

# Data / serialization
_pandas_available = is_package_available("pandas")
_faker_available = is_package_available("faker")

# Misc
_pynput_available = is_package_available("pynput")
_pygame_available = is_package_available("pygame")
_qwen_vl_utils_available = is_package_available("qwen-vl-utils", import_name="qwen_vl_utils")
_grpc_available = is_package_available("grpcio", import_name="grpc")
_wallx_deps_available = (
    _transformers_available and _peft_available and _torchdiffeq_available and _qwen_vl_utils_available
)


def make_device_from_device_class(config: ChoiceRegistry) -> Any:
    """
    Dynamically instantiates an object from its `ChoiceRegistry` configuration.

    This factory uses the module path and class name from the `config` object's
    type to locate and instantiate the corresponding device class (not the config).
    It derives the device class name by removing a trailing 'Config' from the config
    class name and tries a few candidate modules where the device implementation is
    commonly located.
    """
    if not isinstance(config, ChoiceRegistry):
        raise ValueError(f"Config should be an instance of `ChoiceRegistry`, got {type(config)}")

    config_cls = config.__class__
    module_path = config_cls.__module__  # typical: lerobot_teleop_mydevice.config_mydevice
    config_name = config_cls.__name__  # typical: MyDeviceConfig

    # Derive device class name (strip "Config")
    if not config_name.endswith("Config"):
        raise ValueError(f"Config class name '{config_name}' does not end with 'Config'")

    device_class_name = config_name[:-6]  # typical: MyDeviceConfig -> MyDevice

    # Build candidate modules to search for the device class
    parts = module_path.split(".")
    parent_module = ".".join(parts[:-1]) if len(parts) > 1 else module_path
    candidates = [
        parent_module,  # typical: lerobot_teleop_mydevice
        parent_module + "." + device_class_name.lower(),  # typical: lerobot_teleop_mydevice.mydevice
    ]

    # handle modules named like "config_xxx" -> try replacing that piece with "xxx"
    last = parts[-1] if parts else ""
    if last.startswith("config_"):
        candidates.append(".".join(parts[:-1] + [last.replace("config_", "")]))

    # de-duplicate while preserving order
    seen: set[str] = set()
    candidates = [c for c in candidates if not (c in seen or seen.add(c))]

    tried: list[str] = []
    for candidate in candidates:
        tried.append(candidate)
        try:
            module = importlib.import_module(candidate)
        except ImportError:
            continue

        if hasattr(module, device_class_name):
            cls = getattr(module, device_class_name)
            if callable(cls):
                try:
                    return cls(config)
                except TypeError as e:
                    raise TypeError(
                        f"Failed to instantiate '{device_class_name}' from module '{candidate}': {e}"
                    ) from e

    raise ImportError(
        f"Could not locate device class '{device_class_name}' for config '{config_name}'. "
        f"Tried modules: {tried}. Ensure your device class name is the config class name without "
        f"'Config' and that it's importable from one of those modules."
    )


def register_third_party_plugins() -> None:
    """
    Discover and import third-party LeRobot plugins so they can register themselves.

    This function uses `importlib.metadata` to find packages installed in the environment
    (including editable installs) starting with 'lerobot_robot_', 'lerobot_camera_',
    'lerobot_teleoperator_', or 'lerobot_policy_' and imports them.
    """
    prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_", "lerobot_policy_")
    imported: list[str] = []
    failed: list[str] = []

    def attempt_import(module_name: str):
        try:
            importlib.import_module(module_name)
            imported.append(module_name)
            logging.info("Imported third-party plugin: %s", module_name)
        except Exception:
            logging.exception("Could not import third-party plugin: %s", module_name)
            failed.append(module_name)

    for dist in importlib.metadata.distributions():
        dist_name = dist.metadata.get("Name")
        if not dist_name:
            continue
        if dist_name.startswith(prefixes):
            attempt_import(dist_name)

    logging.debug("Third-party plugin import summary: imported=%s failed=%s", imported, failed)
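
The module search order used by `make_device_from_device_class` can be tried without installing any plugin. The sketch below is a minimal illustration, not part of this file: `candidate_modules` is a hypothetical helper that reproduces only the name derivation and candidate ordering from a module path and a config class name, and it deliberately imports nothing.

```python
def candidate_modules(module_path: str, config_name: str) -> list[str]:
    """Hypothetical helper: list the modules that would be searched, in order.

    Mirrors the derivation in make_device_from_device_class, but takes plain
    strings instead of a ChoiceRegistry instance and performs no imports.
    """
    if not config_name.endswith("Config"):
        raise ValueError(f"Config class name '{config_name}' does not end with 'Config'")

    # MyDeviceConfig -> MyDevice
    device_class_name = config_name[: -len("Config")]

    parts = module_path.split(".")
    parent_module = ".".join(parts[:-1]) if len(parts) > 1 else module_path
    candidates = [
        parent_module,
        f"{parent_module}.{device_class_name.lower()}",
    ]

    # A final module component named "config_xxx" also suggests a sibling "xxx".
    last = parts[-1]
    if last.startswith("config_"):
        candidates.append(".".join(parts[:-1] + [last.replace("config_", "")]))

    # dict.fromkeys drops duplicates while keeping first-seen order.
    return list(dict.fromkeys(candidates))


if __name__ == "__main__":
    print(candidate_modules("lerobot_teleop_mydevice.config_mydevice", "MyDeviceConfig"))
```

For the typical layout named in the comments above (`lerobot_teleop_mydevice.config_mydevice` with `MyDeviceConfig`), the third candidate coincides with the second, so only two distinct modules are tried.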