Reverted back missing files in src/lerobot/configs/

This commit is contained in:
Michel Aractingi
2025-07-02 17:33:51 +02:00
parent 1c17419224
commit 012d428f7b
4 changed files with 510 additions and 0 deletions

View File

@@ -0,0 +1,71 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot import (
policies, # noqa: F401
)
from lerobot.datasets.transforms import ImageTransformsConfig
from lerobot.datasets.video_utils import get_safe_default_codec
@dataclass
class DatasetConfig:
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
# keys common between the datasets are kept. Each dataset gets and additional transform that inserts the
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
repo_id: str
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | None = None
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_codec)
@dataclass
class WandBConfig:
enable: bool = False
# Set to true to disable saving an artifact despite training.save_checkpoint=True
disable_artifact: bool = False
project: str = "lerobot"
entity: str | None = None
notes: str | None = None
run_id: str | None = None
mode: str | None = None # Allowed values: 'online', 'offline' 'disabled'. Defaults to 'online'
@dataclass
class EvalConfig:
n_episodes: int = 50
# `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
batch_size: int = 50
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
use_async_envs: bool = False
def __post_init__(self):
if self.batch_size > self.n_episodes:
raise ValueError(
"The eval batch size is greater than the number of eval episodes "
f"({self.batch_size} > {self.n_episodes}). As a result, {self.batch_size} "
f"eval environments will be instantiated, but only {self.n_episodes} will be used. "
"This might significantly slow down evaluation. To fix this, you should update your command "
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), "
f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)."
)

View File

@@ -0,0 +1,65 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime as dt
import logging
from dataclasses import dataclass, field
from pathlib import Path
from lerobot import envs, policies # noqa: F401
from lerobot.configs import parser
from lerobot.configs.default import EvalConfig
from lerobot.configs.policies import PreTrainedConfig
@dataclass
class EvalPipelineConfig:
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch
# (useful for debugging). This argument is mutually exclusive with `--config`.
env: envs.EnvConfig
eval: EvalConfig = field(default_factory=EvalConfig)
policy: PreTrainedConfig | None = None
output_dir: Path | None = None
job_name: str | None = None
seed: int | None = 1000
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
else:
logging.warning(
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
)
if not self.job_name:
if self.env is None:
self.job_name = f"{self.policy.type}"
else:
self.job_name = f"{self.env.type}_{self.policy.type}"
if not self.output_dir:
now = dt.datetime.now()
eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
self.output_dir = Path("outputs/eval") / eval_dir
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]

View File

@@ -0,0 +1,190 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Type, TypeVar
import draccus
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
# Generic variable that is either PreTrainedConfig or a subclass thereof
T = TypeVar("T", bound="PreTrainedConfig")
@dataclass
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
"""
Base configuration class for policy models.
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
input_shapes: A dictionary defining the shapes of the input data for the policy.
output_shapes: A dictionary defining the shapes of the output data for the policy.
input_normalization_modes: A dictionary with key representing the modality and the value specifies the
normalization mode to apply.
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
the original scale.
"""
n_obs_steps: int = 1
normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
device: str | None = None # cuda | cpu | mp
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool = False
push_to_hub: bool = True
repo_id: str | None = None
# Upload on private repository on the Hugging Face hub.
private: bool | None = None
# Add tags to your policy on the hub.
tags: list[str] | None = None
# Add tags to your policy on the hub.
license: str | None = None
def __post_init__(self):
self.pretrained_path = None
if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device.type
# Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device):
logging.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
)
self.use_amp = False
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@property
@abc.abstractmethod
def observation_delta_indices(self) -> list | None:
raise NotImplementedError
@property
@abc.abstractmethod
def action_delta_indices(self) -> list | None:
raise NotImplementedError
@property
@abc.abstractmethod
def reward_delta_indices(self) -> list | None:
raise NotImplementedError
@abc.abstractmethod
def get_optimizer_preset(self) -> OptimizerConfig:
raise NotImplementedError
@abc.abstractmethod
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
raise NotImplementedError
@abc.abstractmethod
def validate_features(self) -> None:
raise NotImplementedError
@property
def robot_state_feature(self) -> PolicyFeature | None:
for _, ft in self.input_features.items():
if ft.type is FeatureType.STATE:
return ft
return None
@property
def env_state_feature(self) -> PolicyFeature | None:
for _, ft in self.input_features.items():
if ft.type is FeatureType.ENV:
return ft
return None
@property
def image_features(self) -> dict[str, PolicyFeature]:
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
@property
def action_feature(self) -> PolicyFeature | None:
for _, ft in self.output_features.items():
if ft.type is FeatureType.ACTION:
return ft
return None
def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
draccus.dump(self, f, indent=4)
@classmethod
def from_pretrained(
cls: Type[T],
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**policy_kwargs,
) -> T:
model_id = str(pretrained_name_or_path)
config_file: str | None = None
if Path(model_id).is_dir():
if CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, CONFIG_NAME)
else:
print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
else:
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
) from e
# HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus
# something like --policy.path (in addition to --policy.type)
cli_overrides = policy_kwargs.pop("cli_overrides", [])
with draccus.config_type("json"):
return draccus.parse(cls, config_file, args=cli_overrides)

View File

@@ -0,0 +1,184 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime as dt
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Type
import draccus
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import HfHubHTTPError
from lerobot import envs
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.optim import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.hub import HubMixin
TRAIN_CONFIG_NAME = "train_config.json"
@dataclass
class TrainPipelineConfig(HubMixin):
dataset: DatasetConfig
env: envs.EnvConfig | None = None
policy: PreTrainedConfig | None = None
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
output_dir: Path | None = None
job_name: str | None = None
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
# `dir` is the directory of an existing run with at least one checkpoint in it.
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
# regardless of what's provided with the training command at the time of resumption.
resume: bool = False
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
seed: int | None = 1000
# Number of workers for the dataloader.
num_workers: int = 4
batch_size: int = 8
steps: int = 100_000
eval_freq: int = 20_000
log_freq: int = 200
save_checkpoint: bool = True
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
save_freq: int = 20_000
use_policy_training_preset: bool = True
optimizer: OptimizerConfig | None = None
scheduler: LRSchedulerConfig | None = None
eval: EvalConfig = field(default_factory=EvalConfig)
wandb: WandBConfig = field(default_factory=WandBConfig)
def __post_init__(self):
self.checkpoint_path = None
def validate(self):
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
policy_path = parser.get_path_arg("policy")
if policy_path:
# Only load the policy config
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
elif self.resume:
# The entire train config is already loaded, we just need to get the checkpoint dir
config_path = parser.parse_arg("config_path")
if not config_path:
raise ValueError(
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
)
if not Path(config_path).resolve().exists():
raise NotADirectoryError(
f"{config_path=} is expected to be a local path. "
"Resuming from the hub is not supported for now."
)
policy_path = Path(config_path).parent
self.policy.pretrained_path = policy_path
self.checkpoint_path = policy_path.parent
if not self.job_name:
if self.env is None:
self.job_name = f"{self.policy.type}"
else:
self.job_name = f"{self.env.type}_{self.policy.type}"
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
raise FileExistsError(
f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
f"Please change your output directory so that {self.output_dir} is not overwritten."
)
elif not self.output_dir:
now = dt.datetime.now()
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
self.output_dir = Path("outputs/train") / train_dir
if isinstance(self.dataset.repo_id, list):
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
elif self.use_policy_training_preset and not self.resume:
self.optimizer = self.policy.get_optimizer_preset()
self.scheduler = self.policy.get_scheduler_preset()
if self.policy.push_to_hub and not self.policy.repo_id:
raise ValueError(
"'policy.repo_id' argument missing. Please specify it to push the model to the hub."
)
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
def to_dict(self) -> dict:
return draccus.encode(self)
def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
draccus.dump(self, f, indent=4)
@classmethod
def from_pretrained(
cls: Type["TrainPipelineConfig"],
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**kwargs,
) -> "TrainPipelineConfig":
model_id = str(pretrained_name_or_path)
config_file: str | None = None
if Path(model_id).is_dir():
if TRAIN_CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, TRAIN_CONFIG_NAME)
else:
print(f"{TRAIN_CONFIG_NAME} not found in {Path(model_id).resolve()}")
elif Path(model_id).is_file():
config_file = model_id
else:
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=TRAIN_CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
) from e
cli_args = kwargs.pop("cli_args", [])
with draccus.config_type("json"):
return draccus.parse(cls, config_file, args=cli_args)
@dataclass(kw_only=True)
class TrainRLServerPipelineConfig(TrainPipelineConfig):
dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset