mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
107 lines
3.6 KiB
Python
107 lines
3.6 KiB
Python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
import abc
|
|
from collections.abc import Iterator
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import torch
|
|
from torch.optim import Optimizer
|
|
|
|
from lerobot.rl.algorithms.configs import RLAlgorithmConfig, TrainingStats
|
|
|
|
if TYPE_CHECKING:
|
|
from lerobot.rl.data_sources.data_mixer import DataMixer
|
|
|
|
# Type alias for one batch yielded by the trainer's data iterator: a dict
# with string keys and values of any type.
BatchType = dict[str, Any]
|
|
|
|
|
|
class RLAlgorithm(abc.ABC):
    """Abstract interface shared by every reinforcement-learning algorithm.

    Subclasses are validated when the ``class`` statement executes: each one
    must end up with a truthy ``config_class`` and a truthy ``name`` (own or
    inherited), otherwise ``TypeError`` is raised. Subclasses must also
    implement the abstract methods ``update`` and ``load_weights`` before
    they can be instantiated.
    """

    # Configuration class paired with the algorithm; subclasses must set it.
    config_class: type[RLAlgorithmConfig] | None = None
    # Name of the algorithm; subclasses must set it to a non-empty value.
    name: str | None = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # NOTE: this runs for *every* subclass, intermediate ones included.
        # Attributes are checked in this order, so a class missing both
        # reports ``config_class`` first. Absent or falsy values (``None``,
        # an empty string, ...) count as missing.
        for required in ("config_class", "name"):
            if getattr(cls, required, None):
                continue
            raise TypeError(f"Class {cls.__name__} must define '{required}'")

    @abc.abstractmethod
    def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
        """Run one full training step and return its statistics.

        The trainer owns ``batch_iterator``; an implementation simply pulls
        from it with ``next(batch_iterator)`` as many times as it needs fresh
        batches (for instance ``utd_ratio`` times in SAC).
        """

    def configure_data_iterator(
        self,
        data_mixer: DataMixer,
        batch_size: int,
        *,
        async_prefetch: bool = True,
        queue_size: int = 2,
    ) -> Iterator[BatchType]:
        """Build the batch iterator that this algorithm will consume.

        By default this delegates straight to ``data_mixer.get_iterator``,
        forwarding ``batch_size``, ``async_prefetch`` and ``queue_size`` as
        keyword arguments. Override it when an algorithm requires a custom
        sampling scheme.
        """
        iterator = data_mixer.get_iterator(
            batch_size=batch_size,
            async_prefetch=async_prefetch,
            queue_size=queue_size,
        )
        return iterator

    def make_optimizers_and_scheduler(self) -> dict[str, Optimizer]:
        """Build the training optimizers and return them keyed by name.

        Meant to be called on the **learner** side once the algorithm has
        been constructed. This base version creates nothing and returns an
        empty dict; subclasses are expected to override it with their own
        optimizer setup (storing what they create).
        """
        return dict()

    def get_optimizers(self) -> dict[str, Optimizer]:
        """Expose optimizers by name for checkpointing or external scheduling.

        The base implementation has none and returns an empty dict.
        """
        return dict()

    @property
    def optimization_step(self) -> int:
        """Learner optimization-step counter, part of the checkpoint/resume contract.

        Returns ``self._optimization_step`` once it has been assigned and
        ``0`` before that. Subclasses may rely on this default storage or
        override the property.
        """
        try:
            return self._optimization_step
        except AttributeError:
            return 0

    @optimization_step.setter
    def optimization_step(self, value: int) -> None:
        # Normalise the incoming value with ``int()`` before storing it.
        step = int(value)
        self._optimization_step = step

    def get_weights(self) -> dict[str, Any]:
        """Return the policy state-dict to send to actors (empty by default)."""
        return dict()

    @abc.abstractmethod
    def load_weights(
        self,
        weights: dict[str, Any],
        device: str | torch.device = "cpu",
    ) -> None:
        """Apply a policy state-dict that arrived from the learner.

        ``device`` defaults to ``"cpu"``.
        """
|