mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 04:11:24 +00:00
* refactor: RL stack refactoring — RLAlgorithm, RLTrainer, DataMixer, and SAC restructuring * chore: clarify torch.compile disabled note in SACAlgorithm * fix(teleop): keyboard EE teleop not registering special keys and losing intervention state Fixes #2345 Co-authored-by: jpizarrom <jpizarrom@gmail.com> * fix: remove leftover normalization calls from reward classifier predict_reward Fixes #2355 * fix: add thread synchronization to ReplayBuffer to prevent race condition between add() and sample() * refactor: update SACAlgorithm to pass action_dim to _init_critics and fix encoder reference * perf: remove redundant CPU→GPU→CPU transition move in learner * Fix: add kwargs in reward classifier __init__() * fix: include IS_INTERVENTION in complementary_info sent to learner for offline replay buffer * fix: add try/finally to control_loop to ensure image writer cleanup on exit * fix: use string key for IS_INTERVENTION in complementary_info to avoid torch.load serialization error * fix: skip tests that require grpc if not available * fix(tests): ensure tensor stats comparison accounts for reshaping in normalization tests * fix(tests): skip tests that require grpc if not available * refactor(rl): expose public API in rl/__init__ and use relative imports in sub-packages * fix(config): update vision encoder model name to lerobot/resnet10 * fix(sac): clarify torch.compile status * refactor(rl): update shutdown_event type hints from 'any' to 'Any' for consistency and clarity * refactor(sac): simplify optimizer return structure * perf(rl): use async iterators in OnlineOfflineMixer.get_iterator * refactor(sac): decouple algorithm hyperparameters from policy config * update losses names in tests * fix docstring * remove unused type alias * fix test for flat dict structure * refactor(policies): rename policies/sac → policies/gaussian_actor * refactor(rl/sac): consolidate hyperparameter ownership and clean up discrete critic * perf(observation_processor): add CUDA support for image 
processing * fix(rl): correctly wire HIL-SERL gripper penalty through processor pipeline (cherry picked from commit 9c2af818ff) * fix(rl): add time limit processor to environment pipeline (cherry picked from commit cd105f65cb) * fix(rl): clarify discrete gripper action mapping in GripperVelocityToJoint for SO100 (cherry picked from commit 494f469a2b) * fix(rl): update neutral gripper action (cherry picked from commit 9c9064e5be) * fix(rl): merge environment and action-processor info in transition processing (cherry picked from commit 30e1886b64) * fix(rl): mirror gym_manipulator in actor (cherry picked from commit d2a046dfc5) * fix(rl): postprocess action in actor (cherry picked from commit c2556439e5) * fix(rl): improve action processing for discrete and continuous actions (cherry picked from commit f887ab3f6a) * fix(rl): enhance intervention handling in actor and learner (cherry picked from commit ef8bfffbd7) * Revert "perf(observation_processor): add CUDA support for image processing" This reverts commit 38b88c414c. 
* refactor(rl): make algorithm a nested config so all SAC hyperparameters are JSON-addressable * refactor(rl): add make_algorithm_config function for RLAlgorithmConfig instantiation * refactor(rl): add type property to RLAlgorithmConfig for better clarity * refactor(rl): make RLAlgorithmConfig an abstract base class for better extensibility * refactor(tests): remove grpc import checks from test files for cleaner code * fix(tests): gate RL tests on the `datasets` extra * refactor: simplify docstrings for clarity and conciseness across multiple files * fix(rl): update gripper position key and handle action absence during reset * fix(rl): record pre-step observation so (obs, action, next.reward) align in gym_manipulator dataset * refactor: clean up import statements * chore: address reviewer comments * chore: improve visual stats reshaping logic and update docstring for clarity * refactor: enforce mandatory config_class and name attributes in RLAlgorithm * refactor: implement NotImplementedError for abstract methods in RLAlgorithm and DataMixer * refactor: replace build_algorithm with make_algorithm for SACAlgorithmConfig and update related tests * refactor: add require_package calls for grpcio and gym-hil in relevant modules * refactor(rl): move grpcio guards to runtime entry points * feat(rl): consolidate HIL-SERL checkpoint into HF-style components Make `RLAlgorithmConfig` and `RLAlgorithm` `HubMixin`s, add abstract `state_dict()` / `load_state_dict()` for critic ensemble, target nets and `log_alpha`, and persist them as a sibling `algorithm/` component next to `pretrained_model/`. Replace the pickled `training_state.pt` with an enriched `training_step.json` carrying `step` and `interaction_step`, so resume restores actor + critics + target nets + temperature + optimizers + RNG + counters from HF-standard files. 
* refactor(rl): move actor weight-sync wire format from policy to algorithm * refactor(rl): update type hints for learner and actor functions * refactor(rl): hoist grpcio guard to module top in actor/learner * chore(rl): manage import pattern in actor (#3564) * chore(rl): manage import pattern in actor * chore(rl): optional grpc imports in learner; quote grpc ServicerContext types --------- Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co> * update uv.lock * chore(doc): update doc --------- Co-authored-by: jpizarrom <jpizarrom@gmail.com> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
102 lines
3.4 KiB
Python
102 lines
3.4 KiB
Python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Iterator
|
|
from typing import Any
|
|
|
|
from lerobot.types import BatchType
|
|
|
|
from .algorithms.base import RLAlgorithm
|
|
from .algorithms.configs import TrainingStats
|
|
from .data_sources.data_mixer import DataMixer
|
|
|
|
|
|
class RLTrainer:
    """Algorithm-agnostic orchestrator for a single RL training step.

    Bundles an ``RLAlgorithm``, the ``DataMixer`` it samples from, and an
    optional observation preprocessor. The batch iterator is created lazily on
    the first ``training_step`` call and reused until it is explicitly reset.
    """

    def __init__(
        self,
        algorithm: RLAlgorithm,
        data_mixer: DataMixer,
        batch_size: int,
        *,
        preprocessor: Any | None = None,
    ):
        self.algorithm = algorithm
        self.data_mixer = data_mixer
        self.batch_size = batch_size
        self._preprocessor = preprocessor

        # ``None`` means "build a fresh iterator on the next training step".
        self._iterator: Iterator[BatchType] | None = None

        # Optimizers/schedulers are created once, when the trainer is constructed.
        self.algorithm.make_optimizers_and_scheduler()

    def _build_data_iterator(self) -> Iterator[BatchType]:
        """Ask the algorithm for a new batch iterator over the current data mixer.

        When a preprocessor was supplied, each yielded batch is passed through it.
        """
        batches = self.algorithm.configure_data_iterator(
            data_mixer=self.data_mixer,
            batch_size=self.batch_size,
        )
        if self._preprocessor is None:
            return batches
        return _PreprocessedIterator(batches, self._preprocessor)

    def reset_data_iterator(self) -> None:
        """Drop the cached iterator; the next ``training_step`` rebuilds it."""
        self._iterator = None

    def set_data_mixer(self, data_mixer: DataMixer, *, reset: bool = True) -> None:
        """Replace the active data mixer.

        With ``reset=True`` (the default) the cached iterator is discarded, so the
        next step builds one from the new mixer. With ``reset=False`` the iterator
        that was built from the previous mixer stays in use until a reset.
        """
        self.data_mixer = data_mixer
        if not reset:
            return
        self.reset_data_iterator()

    def training_step(self) -> TrainingStats:
        """Perform one algorithm update, lazily creating the data iterator first."""
        iterator = self._iterator
        if iterator is None:
            iterator = self._iterator = self._build_data_iterator()
        return self.algorithm.update(iterator)
|
|
|
|
|
|
def preprocess_rl_batch(preprocessor: Any, batch: BatchType) -> BatchType:
    """Run ``preprocessor.process_observation`` on the observation entries of *batch*.

    Only the ``"state"`` and ``"next_state"`` entries are transformed, in that
    order; every other key is left untouched. Both entries are looked up before
    any processing happens, so a missing key raises ``KeyError`` without
    modifying the batch. The batch is updated in place and also returned.
    """
    observation_keys = ("state", "next_state")
    raw_observations = [batch[key] for key in observation_keys]
    for key, observation in zip(observation_keys, raw_observations):
        batch[key] = preprocessor.process_observation(observation)
    return batch
|
|
|
|
|
|
class _PreprocessedIterator:
|
|
"""Iterator wrapper that preprocesses each sampled RL batch."""
|
|
|
|
__slots__ = ("_raw", "_preprocessor")
|
|
|
|
def __init__(self, raw_iterator: Iterator[BatchType], preprocessor: Any) -> None:
|
|
self._raw = raw_iterator
|
|
self._preprocessor = preprocessor
|
|
|
|
def __iter__(self) -> _PreprocessedIterator:
|
|
return self
|
|
|
|
def __next__(self) -> BatchType:
|
|
batch = next(self._raw)
|
|
return preprocess_rl_batch(self._preprocessor, batch)
|