mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
359 lines
13 KiB
Python
359 lines
13 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import abc
|
|
from collections.abc import Iterable
|
|
from dataclasses import asdict, dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import draccus
|
|
import torch
|
|
from safetensors.torch import load_file, save_file
|
|
|
|
from lerobot.utils.constants import (
|
|
OPTIMIZER_PARAM_GROUPS,
|
|
OPTIMIZER_STATE,
|
|
)
|
|
from lerobot.utils.io_utils import deserialize_json_into_object, write_json
|
|
from lerobot.utils.utils import flatten_dict, unflatten_dict
|
|
|
|
# Type alias for parameters accepted by optimizer build() methods.
|
|
# This matches PyTorch's optimizer signature while also supporting:
|
|
# - dict[str, Parameter]: Named parameters for differential LR by name (e.g., XVLA)
|
|
# - dict[str, Iterable]: Multiple parameter groups for multi-optimizer configs (e.g., SAC)
|
|
OptimizerParams = (
|
|
Iterable[torch.nn.Parameter] # From model.parameters()
|
|
| Iterable[dict[str, Any]] # List of param groups with lr/weight_decay overrides
|
|
| dict[str, torch.nn.Parameter] # From dict(model.named_parameters()) for name-based LR
|
|
| dict[str, Any] # For multi-optimizer configs (SAC) with multiple param groups
|
|
)
|
|
|
|
|
|
@dataclass
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
    """Abstract base class for optimizer configurations.

    Concrete configurations register themselves through the draccus choice registry
    and implement `build()` to create the actual optimizer(s). Every configuration
    carries a learning rate, a weight decay and a gradient clipping norm.
    """

    lr: float
    weight_decay: float
    grad_clip_norm: float

    @property
    def type(self) -> str:
        """Choice name reported by the registry for this configuration's class."""
        concrete_cls = self.__class__
        return self.get_choice_name(concrete_cls)

    @classmethod
    def default_choice_name(cls) -> str | None:
        """Return the choice name used as the default: "adam"."""
        return "adam"

    @abc.abstractmethod
    def build(self, params: OptimizerParams) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
        """
        Create the optimizer(s) described by this configuration.

        The result is either one optimizer or a dictionary of optimizers.

        NOTE: Several optimizers are useful when different models are optimized
        separately, e.g. one optimizer for the policy and another one for the value
        function in reinforcement learning settings.

        Args:
            params: Parameters to optimize. Which format is accepted depends on the optimizer:
                - Iterable[Parameter]: From model.parameters() - standard PyTorch usage
                - Iterable[dict]: List of param groups with a 'params' key and optional
                  'lr', 'weight_decay' overrides (e.g., ACT, VQBeT policies)
                - dict[str, Parameter]: From dict(model.named_parameters()) for optimizers
                  that apply differential learning rates by parameter name (e.g., XVLA)
                - dict[str, Iterable]: For multi-optimizer configs where each key maps to
                  a separate optimizer's parameters (e.g., SAC with actor/critic/temperature)

        Returns:
            The optimizer or a dictionary of optimizers.
        """
        raise NotImplementedError
|
|
|
|
|
|
@OptimizerConfig.register_subclass("adam")
@dataclass
class AdamConfig(OptimizerConfig):
    """Configuration that builds a `torch.optim.Adam` optimizer (choice name "adam")."""

    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0

    def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
        """Create an Adam optimizer over `params`.

        Every field of this config except `grad_clip_norm` is forwarded to
        `torch.optim.Adam` as a keyword argument.

        Args:
            params: Parameters (or param groups) handed to `torch.optim.Adam` unchanged.

        Returns:
            The constructed Adam optimizer.
        """
        # grad_clip_norm lives on the config but is not forwarded to the optimizer.
        adam_kwargs = {key: value for key, value in asdict(self).items() if key != "grad_clip_norm"}
        return torch.optim.Adam(params, **adam_kwargs)
|
|
|
|
|
|
@OptimizerConfig.register_subclass("adamw")
@dataclass
class AdamWConfig(OptimizerConfig):
    """Configuration that builds a `torch.optim.AdamW` optimizer (choice name "adamw")."""

    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 1e-2
    grad_clip_norm: float = 10.0

    def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
        """Create an AdamW optimizer over `params`.

        Every field of this config except `grad_clip_norm` is forwarded to
        `torch.optim.AdamW` as a keyword argument.

        Args:
            params: Parameters (or param groups) handed to `torch.optim.AdamW` unchanged.

        Returns:
            The constructed AdamW optimizer.
        """
        # grad_clip_norm lives on the config but is not forwarded to the optimizer.
        adamw_kwargs = {key: value for key, value in asdict(self).items() if key != "grad_clip_norm"}
        return torch.optim.AdamW(params, **adamw_kwargs)
|
|
|
|
|
|
@OptimizerConfig.register_subclass("sgd")
@dataclass
class SGDConfig(OptimizerConfig):
    """Configuration that builds a `torch.optim.SGD` optimizer (choice name "sgd")."""

    lr: float = 1e-3
    momentum: float = 0.0
    dampening: float = 0.0
    nesterov: bool = False
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0

    def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
        """Create an SGD optimizer over `params`.

        Every field of this config except `grad_clip_norm` is forwarded to
        `torch.optim.SGD` as a keyword argument.

        Args:
            params: Parameters (or param groups) handed to `torch.optim.SGD` unchanged.

        Returns:
            The constructed SGD optimizer.
        """
        # grad_clip_norm lives on the config but is not forwarded to the optimizer.
        sgd_kwargs = {key: value for key, value in asdict(self).items() if key != "grad_clip_norm"}
        return torch.optim.SGD(params, **sgd_kwargs)
|
|
|
|
|
|
@OptimizerConfig.register_subclass("xvla-adamw")
@dataclass
class XVLAAdamWConfig(OptimizerConfig):
    """Custom AdamW optimizer for XVLA with differential learning rates.

    The Vision-Language Model (VLM) is trained with 1/10 of the base learning rate
    for stable optimization, while all other components use the full LR.

    This LR ratio is crucial for achieving strong and stable finetuning performance.

    Soft-prompts can optionally use a separate learning rate with warm-up support.
    Set `soft_prompt_lr_scale` to a value < 1.0 (e.g., 0.1) to start soft-prompts
    at a lower LR. Combine with a warmup scheduler for optimal results.

    Note:
        Completely matching official reported performance may require an additional
        warm-up LR schedule for soft-prompts, which can bring minor improvements.
        When `soft_prompt_warmup_lr_scale` is set, it takes precedence over
        `soft_prompt_lr_scale` for the initial LR: soft-prompts start at
        `lr * soft_prompt_warmup_lr_scale` and should be warmed up via the scheduler.

    Parameter Groups (each carries a "name" entry; groups without any trainable
    parameter are dropped, so positional indices may shift):
        - "vlm": VLM parameters at lr * 0.1, weight_decay * 0.1
        - "soft_prompts": Soft-prompt parameters at lr * soft_prompt_lr_scale
          (or lr * soft_prompt_warmup_lr_scale when that is set)
        - "other": All other parameters at full lr
    """

    lr: float = 1e-4
    betas: tuple[float, float] = (0.9, 0.99)
    eps: float = 1e-8
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0
    # Soft-prompt specific settings
    soft_prompt_lr_scale: float = 1.0  # Scale factor for soft-prompt LR (1.0 = same as base LR)
    soft_prompt_warmup_lr_scale: float | None = None  # If set, start soft-prompts at this scale (e.g., 0.01)

    def build(self, params: OptimizerParams) -> torch.optim.Optimizer:
        """
        Build AdamW optimizer with differential learning rates.

        Parameters are assigned to a group by a case-insensitive substring match on
        their name: names containing "vlm" go to the VLM group, otherwise names
        containing "soft_prompt" go to the soft-prompt group, and everything else goes
        to the "other" group. Parameters with `requires_grad=False` are skipped.

        Args:
            params: Must be a dict[str, Parameter] from dict(model.named_parameters())
                or equivalent.

        Returns:
            AdamW optimizer with parameter groups for VLM, soft-prompts, and other components

        Raises:
            AssertionError: If params is not a dict (e.g., from model.parameters())
        """
        assert isinstance(params, dict), "Custom LR optimizer requires `named_parameters()` as inputs."

        vlm_group, soft_prompt_group, other_group = [], [], []
        for name, p in params.items():
            if not p.requires_grad:
                continue
            lowered_name = name.lower()
            # "vlm" is tested first, so a name matching both patterns lands in the VLM group.
            if "vlm" in lowered_name:
                vlm_group.append(p)
            elif "soft_prompt" in lowered_name:
                soft_prompt_group.append(p)
            else:
                other_group.append(p)

        # Determine soft-prompt LR
        soft_prompt_lr = self.lr * self.soft_prompt_lr_scale
        if self.soft_prompt_warmup_lr_scale is not None:
            # Start at warmup scale, scheduler will warm up to soft_prompt_lr
            soft_prompt_lr = self.lr * self.soft_prompt_warmup_lr_scale

        param_groups: list[dict[str, Any]] = [
            {
                "params": vlm_group,
                "lr": self.lr * 0.1,
                "weight_decay": self.weight_decay * 0.1,
                "name": "vlm",
            },
            {
                "params": soft_prompt_group,
                "lr": soft_prompt_lr,
                "weight_decay": self.weight_decay,
                "name": "soft_prompts",
            },
            {
                "params": other_group,
                "lr": self.lr,
                "weight_decay": self.weight_decay,
                "name": "other",
            },
        ]

        # Filter out empty groups
        param_groups = [g for g in param_groups if len(g["params"]) > 0]

        # Each group above sets its own lr/weight_decay, so these constructor values do
        # not alter them. They only make `optimizer.defaults` mirror this config rather
        # than torch's built-in AdamW defaults, which matters for groups added later
        # through `add_param_group` without explicit values.
        return torch.optim.AdamW(
            param_groups,
            lr=self.lr,
            betas=self.betas,
            eps=self.eps,
            weight_decay=self.weight_decay,
        )
|
|
|
|
|
|
@OptimizerConfig.register_subclass("multi_adam")
@dataclass
class MultiAdamConfig(OptimizerConfig):
    """Configuration producing one Adam optimizer per named parameter group.

    `build()` returns a dictionary of Adam optimizers, each with its own hyperparameters.

    Args:
        lr: Default learning rate (used when a group does not specify one)
        weight_decay: Default weight decay (used when a group does not specify one)
        optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
        grad_clip_norm: Gradient clipping norm
    """

    lr: float = 1e-3
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0
    optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)

    def build(self, params: OptimizerParams) -> dict[str, torch.optim.Optimizer]:
        """Create one Adam optimizer for every entry of `params`.

        For each group, the keys "lr", "betas", "eps" and "weight_decay" are read from
        `optimizer_groups[name]` when present. Missing values fall back to `self.lr`,
        (0.9, 0.999), 1e-5 and `self.weight_decay` respectively. Groups absent from
        `optimizer_groups` use the fallbacks for everything.

        Args:
            params: Must be a dict[str, Iterable[Parameter]] mapping parameter group names
                to iterables of parameters. The keys should match the keys in optimizer_groups.
                Typically from policies that need separate optimizers (e.g., SAC with
                actor/critic/temperature).

        Returns:
            Dictionary mapping parameter group names to their optimizers

        Raises:
            AssertionError: If params is not a dict
        """
        assert isinstance(params, dict), "MultiAdamConfig requires a dict of parameter groups as inputs."

        built: dict[str, torch.optim.Optimizer] = {}
        for group_name, group_params in params.items():
            # Per-group overrides; an empty mapping means "use the fallbacks".
            overrides = self.optimizer_groups.get(group_name, {})
            built[group_name] = torch.optim.Adam(
                group_params,
                lr=overrides.get("lr", self.lr),
                betas=overrides.get("betas", (0.9, 0.999)),
                eps=overrides.get("eps", 1e-5),
                weight_decay=overrides.get("weight_decay", self.weight_decay),
            )
        return built
|
|
|
|
|
|
def save_optimizer_state(
    optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
) -> None:
    """Write the state of one optimizer, or of several named optimizers, to disk.

    A single optimizer is saved directly into `save_dir`. For a dictionary of
    optimizers, each one is saved into the subdirectory `save_dir / name`, which is
    created if it does not exist yet.

    Args:
        optimizer: Either a single optimizer or a dictionary of optimizers.
        save_dir: Directory to save the optimizer state.
    """
    if not isinstance(optimizer, dict):
        # Single optimizer: everything goes straight into save_dir.
        _save_single_optimizer_state(optimizer, save_dir)
        return

    # Dictionary of optimizers: one subdirectory per entry.
    for name, opt in optimizer.items():
        target_dir = save_dir / name
        target_dir.mkdir(parents=True, exist_ok=True)
        _save_single_optimizer_state(opt, target_dir)
|
|
|
|
|
|
def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
    """Write one optimizer's state dict into `save_dir` as two files.

    The "param_groups" entry is written as JSON to `OPTIMIZER_PARAM_GROUPS`. The
    remainder of the state dict is flattened and written with safetensors to
    `OPTIMIZER_STATE`.
    """
    state_dict = optimizer.state_dict()
    # Split the param groups off first so only the remaining entries get flattened.
    groups = state_dict.pop("param_groups")
    save_file(flatten_dict(state_dict), save_dir / OPTIMIZER_STATE)
    write_json(groups, save_dir / OPTIMIZER_PARAM_GROUPS)
|
|
|
|
|
|
def load_optimizer_state(
    optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
    """Restore the state of one optimizer, or of several named optimizers, from disk.

    A single optimizer is restored from `save_dir`. For a dictionary of optimizers,
    each one is restored from `save_dir / name`. An optimizer whose subdirectory does
    not exist is returned as it was passed in.

    Args:
        optimizer: Either a single optimizer or a dictionary of optimizers.
        save_dir: Directory to load the optimizer state from.

    Returns:
        The updated optimizer(s) with loaded state.
    """
    if not isinstance(optimizer, dict):
        # Single optimizer: its files sit directly in save_dir.
        return _load_single_optimizer_state(optimizer, save_dir)

    # Dictionary of optimizers: look for one subdirectory per entry.
    restored = {}
    for name, opt in optimizer.items():
        source_dir = save_dir / name
        restored[name] = _load_single_optimizer_state(opt, source_dir) if source_dir.exists() else opt
    return restored
|
|
|
|
|
|
def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
    """Restore one optimizer's state from the files in `save_dir`.

    The state is read with safetensors from `OPTIMIZER_STATE` and the param groups
    from the JSON file `OPTIMIZER_PARAM_GROUPS`. Both are then applied with
    `optimizer.load_state_dict`. The same optimizer object is returned.
    """
    current = optimizer.state_dict()
    nested = unflatten_dict(load_file(save_dir / OPTIMIZER_STATE))

    # The saved file may have no "state" entry (e.g. for newly created optimizers);
    # in that case an empty state is loaded. Keys are converted back to int.
    restored = {"state": {int(idx): entry for idx, entry in nested.get("state", {}).items()}}

    if "param_groups" in current:
        # The optimizer's current param groups are passed alongside the JSON path
        # (presumably as the template to deserialize into; see the helper).
        restored["param_groups"] = deserialize_json_into_object(
            save_dir / OPTIMIZER_PARAM_GROUPS, current["param_groups"]
        )

    optimizer.load_state_dict(restored)
    return optimizer
|