mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
* refactor: RL stack refactoring — RLAlgorithm, RLTrainer, DataMixer, and SAC restructuring * chore: clarify torch.compile disabled note in SACAlgorithm * fix(teleop): keyboard EE teleop not registering special keys and losing intervention state Fixes #2345 Co-authored-by: jpizarrom <jpizarrom@gmail.com> * fix: remove leftover normalization calls from reward classifier predict_reward Fixes #2355 * fix: add thread synchronization to ReplayBuffer to prevent race condition between add() and sample() * refactor: update SACAlgorithm to pass action_dim to _init_critics and fix encoder reference * perf: remove redundant CPU→GPU→CPU transition move in learner * Fix: add kwargs in reward classifier __init__() * fix: include IS_INTERVENTION in complementary_info sent to learner for offline replay buffer * fix: add try/finally to control_loop to ensure image writer cleanup on exit * fix: use string key for IS_INTERVENTION in complementary_info to avoid torch.load serialization error * fix: skip tests that require grpc if not available * fix(tests): ensure tensor stats comparison accounts for reshaping in normalization tests * fix(tests): skip tests that require grpc if not available * refactor(rl): expose public API in rl/__init__ and use relative imports in sub-packages * fix(config): update vision encoder model name to lerobot/resnet10 * fix(sac): clarify torch.compile status * refactor(rl): update shutdown_event type hints from 'any' to 'Any' for consistency and clarity * refactor(sac): simplify optimizer return structure * perf(rl): use async iterators in OnlineOfflineMixer.get_iterator * refactor(sac): decouple algorithm hyperparameters from policy config * update losses names in tests * fix docstring * remove unused type alias * fix test for flat dict structure * refactor(policies): rename policies/sac → policies/gaussian_actor * refactor(rl/sac): consolidate hyperparameter ownership and clean up discrete critic * perf(observation_processor): add CUDA support for image 
processing * fix(rl): correctly wire HIL-SERL gripper penalty through processor pipeline (cherry picked from commit 9c2af818ff) * fix(rl): add time limit processor to environment pipeline (cherry picked from commit cd105f65cb) * fix(rl): clarify discrete gripper action mapping in GripperVelocityToJoint for SO100 (cherry picked from commit 494f469a2b) * fix(rl): update neutral gripper action (cherry picked from commit 9c9064e5be) * fix(rl): merge environment and action-processor info in transition processing (cherry picked from commit 30e1886b64) * fix(rl): mirror gym_manipulator in actor (cherry picked from commit d2a046dfc5) * fix(rl): postprocess action in actor (cherry picked from commit c2556439e5) * fix(rl): improve action processing for discrete and continuous actions (cherry picked from commit f887ab3f6a) * fix(rl): enhance intervention handling in actor and learner (cherry picked from commit ef8bfffbd7) * Revert "perf(observation_processor): add CUDA support for image processing" This reverts commit 38b88c414c. 
* refactor(rl): make algorithm a nested config so all SAC hyperparameters are JSON-addressable * refactor(rl): add make_algorithm_config function for RLAlgorithmConfig instantiation * refactor(rl): add type property to RLAlgorithmConfig for better clarity * refactor(rl): make RLAlgorithmConfig an abstract base class for better extensibility * refactor(tests): remove grpc import checks from test files for cleaner code * fix(tests): gate RL tests on the `datasets` extra * refactor: simplify docstrings for clarity and conciseness across multiple files * fix(rl): update gripper position key and handle action absence during reset * fix(rl): record pre-step observation so (obs, action, next.reward) align in gym_manipulator dataset * refactor: clean up import statements * chore: address reviewer comments * chore: improve visual stats reshaping logic and update docstring for clarity * refactor: enforce mandatory config_class and name attributes in RLAlgorithm * refactor: implement NotImplementedError for abstract methods in RLAlgorithm and DataMixer * refactor: replace build_algorithm with make_algorithm for SACAlgorithmConfig and update related tests * refactor: add require_package calls for grpcio and gym-hil in relevant modules * refactor(rl): move grpcio guards to runtime entry points * feat(rl): consolidate HIL-SERL checkpoint into HF-style components Make `RLAlgorithmConfig` and `RLAlgorithm` `HubMixin`s, add abstract `state_dict()` / `load_state_dict()` for critic ensemble, target nets and `log_alpha`, and persist them as a sibling `algorithm/` component next to `pretrained_model/`. Replace the pickled `training_state.pt` with an enriched `training_step.json` carrying `step` and `interaction_step`, so resume restores actor + critics + target nets + temperature + optimizers + RNG + counters from HF-standard files. 
* refactor(rl): move actor weight-sync wire format from policy to algorithm * refactor(rl): update type hints for learner and actor functions * refactor(rl): hoist grpcio guard to module top in actor/learner * chore(rl): manage import pattern in actor (#3564) * chore(rl): manage import pattern in actor * chore(rl): optional grpc imports in learner; quote grpc ServicerContext types --------- Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co> * update uv.lock * chore(doc): update doc --------- Co-authored-by: jpizarrom <jpizarrom@gmail.com> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
175 lines
6.8 KiB
Python
175 lines
6.8 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from pathlib import Path
|
|
|
|
from torch.optim import Optimizer
|
|
from torch.optim.lr_scheduler import LRScheduler
|
|
|
|
from lerobot.configs.train import TrainPipelineConfig
|
|
from lerobot.optim import (
|
|
load_optimizer_state,
|
|
load_scheduler_state,
|
|
save_optimizer_state,
|
|
save_scheduler_state,
|
|
)
|
|
from lerobot.policies import PreTrainedPolicy
|
|
from lerobot.processor import PolicyProcessorPipeline
|
|
from lerobot.utils.constants import (
|
|
CHECKPOINTS_DIR,
|
|
LAST_CHECKPOINT_LINK,
|
|
PRETRAINED_MODEL_DIR,
|
|
TRAINING_STATE_DIR,
|
|
TRAINING_STEP,
|
|
)
|
|
from lerobot.utils.io_utils import load_json, write_json
|
|
from lerobot.utils.random_utils import load_rng_state, save_rng_state
|
|
|
|
|
|
def get_step_identifier(step: int, total_steps: int) -> str:
    """Return ``step`` rendered as a zero-padded decimal string.

    The padding width equals the number of characters in ``str(total_steps)``,
    but never less than six, so identifiers for one run share a common width
    as long as ``step`` does not have more digits than that width.

    Args:
        step: The training step to render.
        total_steps: The total number of steps; only its digit count is used.

    Returns:
        The zero-padded identifier, e.g. ``"005000"`` for step 5000.
    """
    width = len(str(total_steps))
    if width < 6:
        # Enforce the six-character minimum.
        width = 6
    return format(step, f"0{width}d")
|
|
|
|
|
|
def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Path:
    """Build the path of the checkpoint sub-directory for a given step.

    Args:
        output_dir: Root output directory of the run.
        total_steps: Total number of steps, forwarded to ``get_step_identifier``.
        step: The step whose checkpoint directory is requested.

    Returns:
        ``output_dir / CHECKPOINTS_DIR / <step identifier>``. The path is only
        computed here; nothing is created on disk.
    """
    subdir_name = get_step_identifier(step, total_steps)
    checkpoints_root = output_dir / CHECKPOINTS_DIR
    return checkpoints_root / subdir_name
|
|
|
|
|
|
def save_training_step(step: int, save_dir: Path) -> None:
    """Write the training step to the ``TRAINING_STEP`` file inside ``save_dir``.

    Args:
        step: The training step to record; stored under the ``"step"`` key.
        save_dir: Directory that receives the file.
    """
    payload = {"step": step}
    destination = save_dir / TRAINING_STEP
    write_json(payload, destination)
|
|
|
|
|
|
def load_training_step(save_dir: Path) -> int:
    """Read back the training step stored in the ``TRAINING_STEP`` file of ``save_dir``.

    Args:
        save_dir: Directory containing the file.

    Returns:
        The value found under the ``"step"`` key of the loaded JSON.
    """
    source = save_dir / TRAINING_STEP
    return load_json(source)["step"]
|
|
|
|
|
|
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
    """Point the "last checkpoint" symlink at ``checkpoint_dir``.

    The link is named ``LAST_CHECKPOINT_LINK`` and sits next to
    ``checkpoint_dir`` (in its parent directory). Its target is written
    relative to that parent, so the link holds only the checkpoint
    directory's name rather than an absolute path.

    Args:
        checkpoint_dir: The checkpoint directory the link should point to.

    Returns:
        The path of the symlink that was (re)created.
    """
    last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK

    # Only an existing symlink is removed; any other kind of entry with that
    # name is left untouched and makes ``symlink_to`` below raise.
    if last_checkpoint_dir.is_symlink():
        last_checkpoint_dir.unlink()

    relative_target = checkpoint_dir.relative_to(checkpoint_dir.parent)
    last_checkpoint_dir.symlink_to(relative_target)

    # The signature promises a Path; previously the function fell off the end
    # and returned None.
    return last_checkpoint_dir
|
|
|
|
|
|
def save_checkpoint(
    checkpoint_dir: Path,
    step: int,
    cfg: TrainPipelineConfig,
    policy: PreTrainedPolicy,
    optimizer: Optimizer,
    scheduler: LRScheduler | None = None,
    preprocessor: PolicyProcessorPipeline | None = None,
    postprocessor: PolicyProcessorPipeline | None = None,
) -> None:
    """Write a complete training checkpoint under ``checkpoint_dir``.

    The resulting directory structure is:

    005000/  #  training step at checkpoint
    ├── pretrained_model/
    │   ├── config.json  # policy config
    │   ├── model.safetensors  # policy weights
    │   ├── train_config.json  # train config
    │   ├── processor.json  # processor config (if preprocessor provided)
    │   └── step_*.safetensors  # processor state files (if any)
    └── training_state/
        ├── optimizer_param_groups.json  #  optimizer param groups
        ├── optimizer_state.safetensors  # optimizer state
        ├── rng_state.safetensors  # rng states
        ├── scheduler_state.json  # scheduler state
        └── training_step.json  # training step

    Args:
        checkpoint_dir (Path): Directory of this checkpoint (e.g. ``.../005000``).
        step (int): The training step at that checkpoint.
        cfg (TrainPipelineConfig): The training config used for this run.
        policy (PreTrainedPolicy): The policy to save.
        optimizer (Optimizer): The optimizer to save the state from.
        scheduler (LRScheduler | None, optional): The scheduler to save the state from.
            Defaults to None.
        preprocessor (PolicyProcessorPipeline | None, optional): The preprocessor/pipeline
            to save. Defaults to None.
        postprocessor (PolicyProcessorPipeline | None, optional): The postprocessor/pipeline
            to save. Defaults to None.
    """
    model_dir = checkpoint_dir / PRETRAINED_MODEL_DIR

    policy.save_pretrained(model_dir)
    cfg.save_pretrained(model_dir)

    # When using PEFT, policy.save_pretrained will only write the adapter weights + config,
    # not the policy config which we need for loading the model, so it is written explicitly.
    if cfg.peft is not None:
        policy.config.save_pretrained(model_dir)

    # Both pipelines are optional and go into the same directory, preprocessor first.
    for pipeline in (preprocessor, postprocessor):
        if pipeline is not None:
            pipeline.save_pretrained(model_dir)

    save_training_state(checkpoint_dir, step, optimizer, scheduler)
|
|
|
|
|
|
def save_training_state(
    checkpoint_dir: Path,
    train_step: int,
    optimizer: Optimizer | None = None,
    scheduler: LRScheduler | None = None,
) -> None:
    """
    Saves the training step, rng state and, when given, the optimizer and scheduler state.

    Everything is written into the ``TRAINING_STATE_DIR`` sub-directory of
    ``checkpoint_dir``, which is created (with parents) if it does not exist yet.

    Args:
        checkpoint_dir (Path): The checkpoint directory; artifacts go into its
            ``TRAINING_STATE_DIR`` sub-directory.
        train_step (int): Current training step.
        optimizer (Optimizer | None, optional): The optimizer from which to save the state.
            Skipped when None. Defaults to None.
        scheduler (LRScheduler | None, optional): The scheduler from which to save the state.
            Skipped when None. Defaults to None.
    """
    state_dir = checkpoint_dir / TRAINING_STATE_DIR
    state_dir.mkdir(parents=True, exist_ok=True)

    # Always persisted: the step counter and the rng state.
    save_training_step(train_step, state_dir)
    save_rng_state(state_dir)

    # Optional components, saved in this order and only when provided.
    optional_savers = (
        (save_optimizer_state, optimizer),
        (save_scheduler_state, scheduler),
    )
    for save_fn, component in optional_savers:
        if component is not None:
            save_fn(component, state_dir)
|
|
|
|
|
|
def load_training_state(
    checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None
) -> tuple[int, Optimizer, LRScheduler | None]:
    """
    Restores the rng state, training step, optimizer state and scheduler state
    from a checkpoint. This is used to resume a training run.

    Args:
        checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir.
        optimizer (Optimizer): The optimizer to load the state into.
        scheduler (LRScheduler | None): The scheduler to load the state into (can be None,
            in which case no scheduler state is loaded).

    Raises:
        NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir.

    Returns:
        tuple[int, Optimizer, LRScheduler | None]: training step, plus the optimizer and
            scheduler as returned by their respective loaders (scheduler stays None when
            None was passed in).
    """
    state_dir = checkpoint_dir / TRAINING_STATE_DIR
    if not state_dir.is_dir():
        raise NotADirectoryError(state_dir)

    # Restore order is kept as-is: rng state, step, optimizer, then scheduler.
    load_rng_state(state_dir)
    resumed_step = load_training_step(state_dir)
    restored_optimizer = load_optimizer_state(optimizer, state_dir)
    restored_scheduler = (
        None if scheduler is None else load_scheduler_state(scheduler, state_dir)
    )

    return resumed_step, restored_optimizer, restored_scheduler
|