Files
lerobot-clone/src/lerobot/policies/factory.py
Steven Palma e489ba24fc feat(dependencies): require Python 3.12+ as minimum version (#3023)
* feat(dependecies): upgrade to python3.12

* fix(test): processor regex message

* fix(test): processor regex message

* fix(dependecies): resolve all tags in python 3.12

* fix(dependecies): add more hints to faster resolve

* chore(dependecies): remove cli tag huggingface-hub dep

* refactor(policy): update eagle for python3.12

* chore(docs): update policy creation for python 3.12

* chore(test): skip failing tests in macos
2026-03-06 10:15:13 +01:00

591 lines
23 KiB
Python

#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import importlib
import logging
from typing import Any, TypedDict, Unpack
import torch
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.envs.configs import EnvConfig
from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sarm.configuration_sarm import SARMConfig
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import validate_visual_features_consistency
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor.converters import (
batch_to_transition,
policy_action_to_transition,
transition_to_batch,
transition_to_policy_action,
)
from lerobot.utils.constants import (
ACTION,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
"""
Retrieves a policy class by its registered name.
This function uses dynamic imports to avoid loading all policy classes into memory
at once, improving startup time and reducing dependencies.
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
Raises:
NotImplementedError: If the policy name is not recognized.
"""
if name == "tdmpc":
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
return TDMPCPolicy
elif name == "diffusion":
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
return DiffusionPolicy
elif name == "act":
from lerobot.policies.act.modeling_act import ACTPolicy
return ACTPolicy
elif name == "vqbet":
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
return VQBeTPolicy
elif name == "pi0":
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
return PI0Policy
elif name == "pi0_fast":
from lerobot.policies.pi0_fast.modeling_pi0_fast import PI0FastPolicy
return PI0FastPolicy
elif name == "pi05":
from lerobot.policies.pi05.modeling_pi05 import PI05Policy
return PI05Policy
elif name == "sac":
from lerobot.policies.sac.modeling_sac import SACPolicy
return SACPolicy
elif name == "reward_classifier":
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
return Classifier
elif name == "smolvla":
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
return SmolVLAPolicy
elif name == "sarm":
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
return SARMRewardModel
elif name == "groot":
from lerobot.policies.groot.modeling_groot import GrootPolicy
return GrootPolicy
elif name == "xvla":
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
return XVLAPolicy
elif name == "wall_x":
from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy
return WallXPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
except Exception as e:
raise ValueError(f"Policy type '{name}' is not available.") from e
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
"""
Instantiates a policy configuration object based on the policy type.
This factory function simplifies the creation of policy configuration objects by
mapping a string identifier to the corresponding config class.
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
"reward_classifier", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
An instance of a `PreTrainedConfig` subclass.
Raises:
ValueError: If the `policy_type` is not recognized.
"""
if policy_type == "tdmpc":
return TDMPCConfig(**kwargs)
elif policy_type == "diffusion":
return DiffusionConfig(**kwargs)
elif policy_type == "act":
return ACTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
elif policy_type == "pi0":
return PI0Config(**kwargs)
elif policy_type == "pi05":
return PI05Config(**kwargs)
elif policy_type == "sac":
return SACConfig(**kwargs)
elif policy_type == "smolvla":
return SmolVLAConfig(**kwargs)
elif policy_type == "reward_classifier":
return RewardClassifierConfig(**kwargs)
elif policy_type == "groot":
return GrootConfig(**kwargs)
elif policy_type == "xvla":
return XVLAConfig(**kwargs)
elif policy_type == "wall_x":
return WallXConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
return config_cls(**kwargs)
except Exception as e:
raise ValueError(f"Policy type '{policy_type}' is not available.") from e
class ProcessorConfigKwargs(TypedDict, total=False):
"""
A TypedDict defining the keyword arguments for processor configuration.
This provides type hints for the optional arguments passed to `make_pre_post_processors`,
improving code clarity and enabling static analysis.
Attributes:
preprocessor_config_filename: The filename for the preprocessor configuration.
postprocessor_config_filename: The filename for the postprocessor configuration.
preprocessor_overrides: A dictionary of overrides for the preprocessor configuration.
postprocessor_overrides: A dictionary of overrides for the postprocessor configuration.
dataset_stats: Dataset statistics for normalization.
"""
preprocessor_config_filename: str | None
postprocessor_config_filename: str | None
preprocessor_overrides: dict[str, Any] | None
postprocessor_overrides: dict[str, Any] | None
dataset_stats: dict[str, dict[str, torch.Tensor]] | None
def make_pre_post_processors(
policy_cfg: PreTrainedConfig,
pretrained_path: str | None = None,
**kwargs: Unpack[ProcessorConfigKwargs],
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Create or load pre- and post-processor pipelines for a given policy.
This function acts as a factory. It can either load existing processor pipelines
from a pretrained path or create new ones from scratch based on the policy
configuration. Each policy type has a dedicated factory function for its
processors (e.g., `make_tdmpc_pre_post_processors`).
Args:
policy_cfg: The configuration of the policy for which to create processors.
pretrained_path: An optional path to load pretrained processor pipelines from.
If provided, pipelines are loaded from this path.
**kwargs: Keyword arguments for processor configuration, as defined in
`ProcessorConfigKwargs`.
Returns:
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
Raises:
NotImplementedError: If a processor factory is not implemented for the given
policy configuration type.
"""
if pretrained_path:
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
if isinstance(policy_cfg, GrootConfig):
# GROOT handles normalization in groot_pack_inputs_v3 step
# Need to override both stats AND normalize_min_max since saved config might be empty
preprocessor_overrides = {}
postprocessor_overrides = {}
preprocessor_overrides["groot_pack_inputs_v3"] = {
"stats": kwargs.get("dataset_stats"),
"normalize_min_max": True,
}
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
env_action_dim = policy_cfg.output_features[ACTION].shape[0]
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
"stats": kwargs.get("dataset_stats"),
"normalize_min_max": True,
"env_action_dim": env_action_dim,
}
kwargs["preprocessor_overrides"] = preprocessor_overrides
kwargs["postprocessor_overrides"] = postprocessor_overrides
return (
PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get(
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
),
overrides=kwargs.get("preprocessor_overrides", {}),
to_transition=batch_to_transition,
to_output=transition_to_batch,
),
PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
config_filename=kwargs.get(
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
),
overrides=kwargs.get("postprocessor_overrides", {}),
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
# Create a new processor based on policy type
if isinstance(policy_cfg, TDMPCConfig):
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors
processors = make_tdmpc_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, DiffusionConfig):
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors
processors = make_diffusion_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, ACTConfig):
from lerobot.policies.act.processor_act import make_act_pre_post_processors
processors = make_act_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, VQBeTConfig):
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
processors = make_vqbet_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, PI0Config):
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors
processors = make_pi0_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, PI05Config):
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors
processors = make_pi05_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SACConfig):
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
processors = make_sac_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, RewardClassifierConfig):
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
processors = make_classifier_processor(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SmolVLAConfig):
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors
processors = make_smolvla_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, SARMConfig):
from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
processors = make_sarm_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
dataset_meta=kwargs.get("dataset_meta"),
)
elif isinstance(policy_cfg, GrootConfig):
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
processors = make_groot_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, XVLAConfig):
from lerobot.policies.xvla.processor_xvla import (
make_xvla_pre_post_processors,
)
processors = make_xvla_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, WallXConfig):
from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors
processors = make_wall_x_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:
processors = _make_processors_from_policy_config(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
except Exception as e:
raise ValueError(f"Processor for policy type '{policy_cfg.type}' is not implemented.") from e
return processors
def make_policy(
cfg: PreTrainedConfig,
ds_meta: LeRobotDatasetMetadata | None = None,
env_cfg: EnvConfig | None = None,
rename_map: dict[str, str] | None = None,
) -> PreTrainedPolicy:
"""
Instantiate a policy model.
This factory function handles the logic of creating a policy, which requires
determining the input and output feature shapes. These shapes can be derived
either from a `LeRobotDatasetMetadata` object or an `EnvConfig` object. The function
can either initialize a new policy from scratch or load a pretrained one.
Args:
cfg: The configuration for the policy to be created. If `cfg.pretrained_path` is
set, the policy will be loaded with weights from that path.
ds_meta: Dataset metadata used to infer feature shapes and types. Also provides
statistics for normalization layers.
env_cfg: Environment configuration used to infer feature shapes and types.
One of `ds_meta` or `env_cfg` must be provided.
rename_map: Optional mapping of dataset or environment feature keys to match
expected policy feature names (e.g., `"left"` → `"camera1"`).
Returns:
An instantiated and device-placed policy model.
Raises:
ValueError: If both or neither of `ds_meta` and `env_cfg` are provided.
NotImplementedError: If attempting to use an unsupported policy-backend
combination (e.g., VQBeT with 'mps').
"""
if bool(ds_meta) == bool(env_cfg):
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
# NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
# TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
# NotImplementedError: The operator 'aten::unique_dim' is not currently implemented for the MPS device. If
# you want this op to be added in priority during the prototype phase of this feature, please comment on
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
# slower than running natively on MPS.
if cfg.type == "vqbet" and cfg.device == "mps":
raise NotImplementedError(
"Current implementation of VQBeT does not support `mps` backend. "
"Please use `cpu` or `cuda` backend."
)
policy_cls = get_policy_class(cfg.type)
kwargs = {}
if ds_meta is not None:
features = dataset_to_policy_features(ds_meta.features)
else:
if not cfg.pretrained_path:
logging.warning(
"You are instantiating a policy from scratch and its features are parsed from an environment "
"rather than a dataset. Normalization modules inside the policy will have infinite values "
"by default without stats from a dataset."
)
if env_cfg is None:
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
features = env_to_policy_features(env_cfg)
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
if ds_meta is not None and hasattr(ds_meta, "stats"):
kwargs["dataset_stats"] = ds_meta.stats
if ds_meta is not None:
kwargs["dataset_meta"] = ds_meta
if not cfg.pretrained_path and cfg.use_peft:
raise ValueError(
"Instantiating a policy with `use_peft=True` without a checkpoint is not supported since that requires "
"the PEFT config parameters to be set. For training with PEFT, see `lerobot_train.py` on how to do that."
)
if cfg.pretrained_path and not cfg.use_peft:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary).
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
policy = policy_cls.from_pretrained(**kwargs)
elif cfg.pretrained_path and cfg.use_peft:
# Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo
# of the adapter and the adapter's config contains the path to the base policy. So we need the
# adapter config first, then load the correct policy and then apply PEFT.
from peft import PeftConfig, PeftModel
logging.info("Loading policy's PEFT adapter.")
peft_pretrained_path = cfg.pretrained_path
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
kwargs["pretrained_name_or_path"] = peft_config.base_model_name_or_path
if not kwargs["pretrained_name_or_path"]:
# This means that there's a bug or we trained a policy from scratch using PEFT.
# It is more likely that this is a bug so we'll raise an error.
raise ValueError(
"No pretrained model name found in adapter config. Can't instantiate the pre-trained policy on which "
"the adapter was trained."
)
policy = policy_cls.from_pretrained(**kwargs)
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
else:
# Make a fresh policy.
policy = policy_cls(**kwargs)
policy.to(cfg.device)
assert isinstance(policy, torch.nn.Module)
# policy = torch.compile(policy, mode="reduce-overhead")
if not rename_map:
validate_visual_features_consistency(cfg, features)
# TODO: (jadechoghari) - add a check_state(cfg, features) and check_action(cfg, features)
return policy
def _get_policy_cls_from_policy_name(name: str) -> type[PreTrainedConfig]:
"""Get policy class from its registered name using dynamic imports.
This is used as a helper function to import policies from 3rd party lerobot plugins.
Args:
name: The name of the policy.
Returns:
The policy class corresponding to the given name.
"""
if name not in PreTrainedConfig.get_known_choices():
raise ValueError(
f"Unknown policy name '{name}'. Available policies: {PreTrainedConfig.get_known_choices()}"
)
config_cls = PreTrainedConfig.get_choice_class(name)
config_cls_name = config_cls.__name__
model_name = config_cls_name.removesuffix("Config") # e.g., DiffusionConfig -> Diffusion
if model_name == config_cls_name:
raise ValueError(
f"The config class name '{config_cls_name}' does not follow the expected naming convention."
f"Make sure it ends with 'Config'!"
)
cls_name = model_name + "Policy" # e.g., DiffusionConfig -> DiffusionPolicy
module_path = config_cls.__module__.replace(
"configuration_", "modeling_"
) # e.g., configuration_diffusion -> modeling_diffusion
module = importlib.import_module(module_path)
policy_cls = getattr(module, cls_name)
return policy_cls
def _make_processors_from_policy_config(
config: PreTrainedConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[Any, Any]:
"""Create pre- and post-processors from a policy configuration using dynamic imports.
This is used as a helper function to import processor factories from 3rd party lerobot plugins.
Args:
config: The policy configuration object.
dataset_stats: Dataset statistics for normalization.
Returns:
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
"""
policy_type = config.type
function_name = f"make_{policy_type}_pre_post_processors"
module_path = config.__class__.__module__.replace(
"configuration_", "processor_"
) # e.g., configuration_diffusion -> processor_diffusion
logging.debug(
f"Instantiating pre/post processors using function '{function_name}' from module '{module_path}'"
)
module = importlib.import_module(module_path)
function = getattr(module, function_name)
return function(config, dataset_stats=dataset_stats)