From 012d428f7b12d552983db3169b0675fd2767c1bc Mon Sep 17 00:00:00 2001
From: Michel Aractingi
Date: Wed, 2 Jul 2025 17:33:51 +0200
Subject: [PATCH] Reverted back missing files in `src/lerobot/configs/`

---
 src/lerobot/configs/default.py  |  71 ++++++++++++
 src/lerobot/configs/eval.py     |  65 +++++++++++
 src/lerobot/configs/policies.py | 190 ++++++++++++++++++++++++++++++++
 src/lerobot/configs/train.py    | 184 +++++++++++++++++++++++++++++++
 4 files changed, 510 insertions(+)

diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py
index e69de29bb..53cfe58e7 100644
--- a/src/lerobot/configs/default.py
+++ b/src/lerobot/configs/default.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+
+from lerobot import (
+    policies,  # noqa: F401
+)
+from lerobot.datasets.transforms import ImageTransformsConfig
+from lerobot.datasets.video_utils import get_safe_default_codec
+
+
+@dataclass
+class DatasetConfig:
+    # You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
+    # keys common between the datasets are kept. Each dataset gets an additional transform that inserts the
+    # "dataset_index" into the returned item. The index mapping is made according to the order in which the
+    # datasets are provided.
+    repo_id: str
+    # Root directory where the dataset will be stored (e.g. 'dataset/path').
+    root: str | None = None
+    episodes: list[int] | None = None
+    image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
+    revision: str | None = None
+    use_imagenet_stats: bool = True
+    video_backend: str = field(default_factory=get_safe_default_codec)
+
+
+@dataclass
+class WandBConfig:
+    enable: bool = False
+    # Set to true to disable saving an artifact despite training.save_checkpoint=True
+    disable_artifact: bool = False
+    project: str = "lerobot"
+    entity: str | None = None
+    notes: str | None = None
+    run_id: str | None = None
+    mode: str | None = None  # Allowed values: 'online', 'offline', 'disabled'. Defaults to 'online'
+
+
+@dataclass
+class EvalConfig:
+    n_episodes: int = 50
+    # `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
+    batch_size: int = 50
+    # `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
+    use_async_envs: bool = False
+
+    def __post_init__(self):
+        if self.batch_size > self.n_episodes:
+            raise ValueError(
+                "The eval batch size is greater than the number of eval episodes "
+                f"({self.batch_size} > {self.n_episodes}). As a result, {self.batch_size} "
+                f"eval environments will be instantiated, but only {self.n_episodes} will be used. "
+                "This might significantly slow down evaluation. To fix this, you should update your command "
+                f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), "
+                f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)."
+            )
diff --git a/src/lerobot/configs/eval.py b/src/lerobot/configs/eval.py
index e69de29bb..cfe48cf87 100644
--- a/src/lerobot/configs/eval.py
+++ b/src/lerobot/configs/eval.py
@@ -0,0 +1,65 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime as dt
+import logging
+from dataclasses import dataclass, field
+from pathlib import Path
+
+from lerobot import envs, policies  # noqa: F401
+from lerobot.configs import parser
+from lerobot.configs.default import EvalConfig
+from lerobot.configs.policies import PreTrainedConfig
+
+
+@dataclass
+class EvalPipelineConfig:
+    env: envs.EnvConfig
+    eval: EvalConfig = field(default_factory=EvalConfig)
+    # Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
+    # saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch
+    # (useful for debugging). This argument is mutually exclusive with `--config`.
+    policy: PreTrainedConfig | None = None
+    output_dir: Path | None = None
+    job_name: str | None = None
+    seed: int | None = 1000
+
+    def __post_init__(self):
+        # HACK: We parse again the cli args here to get the pretrained path if there was one.
+        policy_path = parser.get_path_arg("policy")
+        if policy_path:
+            cli_overrides = parser.get_cli_overrides("policy")
+            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
+            self.policy.pretrained_path = policy_path
+
+        else:
+            logging.warning(
+                "No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
+            )
+
+        if not self.job_name:
+            if self.env is None:
+                self.job_name = f"{self.policy.type}"
+            else:
+                self.job_name = f"{self.env.type}_{self.policy.type}"
+
+        if not self.output_dir:
+            now = dt.datetime.now()
+            eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
+            self.output_dir = Path("outputs/eval") / eval_dir
+
+    @classmethod
+    def __get_path_fields__(cls) -> list[str]:
+        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
+        return ["policy"]
diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py
index e69de29bb..36e6ea2e5 100644
--- a/src/lerobot/configs/policies.py
+++ b/src/lerobot/configs/policies.py
@@ -0,0 +1,190 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import abc
+import logging
+import os
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Type, TypeVar
+
+import draccus
+from huggingface_hub import hf_hub_download
+from huggingface_hub.constants import CONFIG_NAME
+from huggingface_hub.errors import HfHubHTTPError
+
+from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
+from lerobot.optim.optimizers import OptimizerConfig
+from lerobot.optim.schedulers import LRSchedulerConfig
+from lerobot.utils.hub import HubMixin
+from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
+
+# Generic variable that is either PreTrainedConfig or a subclass thereof
+T = TypeVar("T", bound="PreTrainedConfig")
+
+
+@dataclass
+class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
+    """
+    Base configuration class for policy models.
+
+    Args:
+        n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
+            current step and additional steps going back).
+        normalization_mapping: Dictionary mapping a string key to the `NormalizationMode` to apply
+            (presumably keyed by feature type; confirm against the normalization code).
+        input_features: Dictionary mapping a feature name to the `PolicyFeature` given to the policy.
+        output_features: Dictionary mapping a feature name to the `PolicyFeature` the policy outputs.
+        device: Torch device name. When unset or unavailable, `__post_init__` replaces it with an
+            automatically selected device.
+    """
+
+    n_obs_steps: int = 1
+    normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
+
+    input_features: dict[str, PolicyFeature] = field(default_factory=dict)
+    output_features: dict[str, PolicyFeature] = field(default_factory=dict)
+
+    device: str | None = None  # cuda | cpu | mps
+    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
+    # automatic gradient scaling is used.
+    use_amp: bool = False
+
+    push_to_hub: bool = True
+    repo_id: str | None = None
+
+    # Upload on private repository on the Hugging Face hub.
+    private: bool | None = None
+    # Add tags to your policy on the hub.
+    tags: list[str] | None = None
+    # Add a license to your policy on the hub.
+    license: str | None = None
+
+    def __post_init__(self):
+        self.pretrained_path = None
+        if not self.device or not is_torch_device_available(self.device):
+            auto_device = auto_select_torch_device()
+            logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
+            self.device = auto_device.type
+
+        # Automatically deactivate AMP if necessary
+        if self.use_amp and not is_amp_available(self.device):
+            logging.warning(
+                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
+            )
+            self.use_amp = False
+
+    @property
+    def type(self) -> str:
+        return self.get_choice_name(self.__class__)
+
+    @property
+    @abc.abstractmethod
+    def observation_delta_indices(self) -> list | None:
+        raise NotImplementedError
+
+    @property
+    @abc.abstractmethod
+    def action_delta_indices(self) -> list | None:
+        raise NotImplementedError
+
+    @property
+    @abc.abstractmethod
+    def reward_delta_indices(self) -> list | None:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def get_optimizer_preset(self) -> OptimizerConfig:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def get_scheduler_preset(self) -> LRSchedulerConfig | None:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def validate_features(self) -> None:
+        raise NotImplementedError
+
+    @property
+    def robot_state_feature(self) -> PolicyFeature | None:
+        for _, ft in self.input_features.items():
+            if ft.type is FeatureType.STATE:
+                return ft
+        return None
+
+    @property
+    def env_state_feature(self) -> PolicyFeature | None:
+        for _, ft in self.input_features.items():
+            if ft.type is FeatureType.ENV:
+                return ft
+        return None
+
+    @property
+    def image_features(self) -> dict[str, PolicyFeature]:
+        return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
+
+    @property
+    def action_feature(self) -> PolicyFeature | None:
+        for _, ft in self.output_features.items():
+            if ft.type is FeatureType.ACTION:
+                return ft
+        return None
+
+    def _save_pretrained(self, save_directory: Path) -> None:
+        with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
+            draccus.dump(self, f, indent=4)
+
+    @classmethod
+    def from_pretrained(
+        cls: Type[T],
+        pretrained_name_or_path: str | Path,
+        *,
+        force_download: bool = False,
+        resume_download: bool | None = None,
+        proxies: dict | None = None,
+        token: str | bool | None = None,
+        cache_dir: str | Path | None = None,
+        local_files_only: bool = False,
+        revision: str | None = None,
+        **policy_kwargs,
+    ) -> T:
+        model_id = str(pretrained_name_or_path)
+        config_file: str | None = None
+        if Path(model_id).is_dir():
+            if CONFIG_NAME in os.listdir(model_id):
+                config_file = os.path.join(model_id, CONFIG_NAME)
+            else:
+                print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
+        else:
+            try:
+                config_file = hf_hub_download(
+                    repo_id=model_id,
+                    filename=CONFIG_NAME,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    token=token,
+                    local_files_only=local_files_only,
+                )
+            except HfHubHTTPError as e:
+                raise FileNotFoundError(
+                    f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
+                ) from e
+
+        # HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus
+        # something like --policy.path (in addition to --policy.type)
+        cli_overrides = policy_kwargs.pop("cli_overrides", [])
+        with draccus.config_type("json"):
+            return draccus.parse(cls, config_file, args=cli_overrides)
diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py
index e69de29bb..c088a5fa1 100644
--- a/src/lerobot/configs/train.py
+++ b/src/lerobot/configs/train.py
@@ -0,0 +1,184 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import datetime as dt
+import os
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Type
+
+import draccus
+from huggingface_hub import hf_hub_download
+from huggingface_hub.errors import HfHubHTTPError
+
+from lerobot import envs
+from lerobot.configs import parser
+from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.optim import OptimizerConfig
+from lerobot.optim.schedulers import LRSchedulerConfig
+from lerobot.utils.hub import HubMixin
+
+TRAIN_CONFIG_NAME = "train_config.json"
+
+
+@dataclass
+class TrainPipelineConfig(HubMixin):
+    dataset: DatasetConfig
+    env: envs.EnvConfig | None = None
+    policy: PreTrainedConfig | None = None
+    # Set `output_dir` to where you would like to save all of the run outputs. If that directory already
+    # exists, `validate` raises a `FileExistsError` unless you set `resume` to true.
+    output_dir: Path | None = None
+    job_name: str | None = None
+    # Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
+    # `output_dir` is the directory of an existing run with at least one checkpoint in it.
+    # Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
+    # regardless of what's provided with the training command at the time of resumption.
+    resume: bool = False
+    # `seed` is used for training (eg: model initialization, dataset shuffling)
+    # AND for the evaluation environments.
+    seed: int | None = 1000
+    # Number of workers for the dataloader.
+    num_workers: int = 4
+    batch_size: int = 8
+    steps: int = 100_000
+    eval_freq: int = 20_000
+    log_freq: int = 200
+    save_checkpoint: bool = True
+    # Checkpoint is saved every `save_freq` training iterations and after the last training step.
+    save_freq: int = 20_000
+    use_policy_training_preset: bool = True
+    optimizer: OptimizerConfig | None = None
+    scheduler: LRSchedulerConfig | None = None
+    eval: EvalConfig = field(default_factory=EvalConfig)
+    wandb: WandBConfig = field(default_factory=WandBConfig)
+
+    def __post_init__(self):
+        self.checkpoint_path = None
+
+    def validate(self):
+        # HACK: We parse again the cli args here to get the pretrained paths if there was some.
+        policy_path = parser.get_path_arg("policy")
+        if policy_path:
+            # Only load the policy config
+            cli_overrides = parser.get_cli_overrides("policy")
+            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
+            self.policy.pretrained_path = policy_path
+        elif self.resume:
+            # The entire train config is already loaded, we just need to get the checkpoint dir
+            config_path = parser.parse_arg("config_path")
+            if not config_path:
+                raise ValueError(
+                    f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
+                )
+            if not Path(config_path).resolve().exists():
+                raise NotADirectoryError(
+                    f"{config_path=} is expected to be a local path. "
+                    "Resuming from the hub is not supported for now."
+                )
+            policy_path = Path(config_path).parent
+            self.policy.pretrained_path = policy_path
+            self.checkpoint_path = policy_path.parent
+
+        if not self.job_name:
+            if self.env is None:
+                self.job_name = f"{self.policy.type}"
+            else:
+                self.job_name = f"{self.env.type}_{self.policy.type}"
+
+        if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
+            raise FileExistsError(
+                f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
+                f"Please change your output directory so that {self.output_dir} is not overwritten."
+            )
+        elif not self.output_dir:
+            now = dt.datetime.now()
+            train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
+            self.output_dir = Path("outputs/train") / train_dir
+
+        if self.dataset is not None and isinstance(self.dataset.repo_id, list):
+            raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
+
+        if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
+            raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
+        elif self.use_policy_training_preset and not self.resume:
+            self.optimizer = self.policy.get_optimizer_preset()
+            self.scheduler = self.policy.get_scheduler_preset()
+
+        if self.policy.push_to_hub and not self.policy.repo_id:
+            raise ValueError(
+                "'policy.repo_id' argument missing. Please specify it to push the model to the hub."
+            )
+
+    @classmethod
+    def __get_path_fields__(cls) -> list[str]:
+        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
+        return ["policy"]
+
+    def to_dict(self) -> dict:
+        return draccus.encode(self)
+
+    def _save_pretrained(self, save_directory: Path) -> None:
+        with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
+            draccus.dump(self, f, indent=4)
+
+    @classmethod
+    def from_pretrained(
+        cls: Type["TrainPipelineConfig"],
+        pretrained_name_or_path: str | Path,
+        *,
+        force_download: bool = False,
+        resume_download: bool | None = None,
+        proxies: dict | None = None,
+        token: str | bool | None = None,
+        cache_dir: str | Path | None = None,
+        local_files_only: bool = False,
+        revision: str | None = None,
+        **kwargs,
+    ) -> "TrainPipelineConfig":
+        model_id = str(pretrained_name_or_path)
+        config_file: str | None = None
+        if Path(model_id).is_dir():
+            if TRAIN_CONFIG_NAME in os.listdir(model_id):
+                config_file = os.path.join(model_id, TRAIN_CONFIG_NAME)
+            else:
+                print(f"{TRAIN_CONFIG_NAME} not found in {Path(model_id).resolve()}")
+        elif Path(model_id).is_file():
+            config_file = model_id
+        else:
+            try:
+                config_file = hf_hub_download(
+                    repo_id=model_id,
+                    filename=TRAIN_CONFIG_NAME,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    token=token,
+                    local_files_only=local_files_only,
+                )
+            except HfHubHTTPError as e:
+                raise FileNotFoundError(
+                    f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
+                ) from e
+
+        cli_args = kwargs.pop("cli_args", [])
+        with draccus.config_type("json"):
+            return draccus.parse(cls, config_file, args=cli_args)
+
+
+@dataclass(kw_only=True)
+class TrainRLServerPipelineConfig(TrainPipelineConfig):
+    dataset: DatasetConfig | None = None  # NOTE: In RL, we don't need an offline dataset