# Copyright 2026 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import abc from collections.abc import Iterator from typing import TYPE_CHECKING, Any import torch from torch.optim import Optimizer from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats if TYPE_CHECKING: from lerobot.rl.data_sources.data_mixer import DataMixer BatchType = dict[str, Any] class RLAlgorithm(abc.ABC): """Base for all RL algorithms.""" config_class: type[RLAlgorithmConfig] | None = None name: str | None = None def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) if not getattr(cls, "config_class", None): raise TypeError(f"Class {cls.__name__} must define 'config_class'") if not getattr(cls, "name", None): raise TypeError(f"Class {cls.__name__} must define 'name'") @abc.abstractmethod def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats: """One complete training step. The algorithm calls ``next(batch_iterator)`` as many times as it needs (e.g. ``utd_ratio`` times for SAC) to obtain fresh batches. The iterator is owned by the trainer; the algorithm just consumes from it. """ ... def configure_data_iterator( self, data_mixer: DataMixer, batch_size: int, *, async_prefetch: bool = True, queue_size: int = 2, ) -> Iterator[BatchType]: """Create the data iterator this algorithm needs. The default implementation uses the standard ``data_mixer.get_iterator()``. Algorithms that need specialised sampling should override this method. """ return data_mixer.get_iterator( batch_size=batch_size, async_prefetch=async_prefetch, queue_size=queue_size, ) def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]: """Create, store, and return the optimizers needed for training. Called on the **learner** side after construction. Subclasses must override this with algorithm-specific optimizer setup. """ return {} def get_optimizers(self) -> dict[str, Optimizer]: """Return optimizers for checkpointing / external scheduling.""" return {} @property def optimization_step(self) -> int: """Current learner optimization step. Part of the stable contract for checkpoint/resume. Algorithms can either use this default storage or override for custom behavior. """ return getattr(self, "_optimization_step", 0) @optimization_step.setter def optimization_step(self, value: int) -> None: self._optimization_step = int(value) def get_weights(self) -> dict[str, Any]: """Policy state-dict to push to actors.""" return {} @abc.abstractmethod def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None: """Load policy state-dict received from the learner."""