mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
* Add extensive language support

* Address review: split persistent/event schemas, drop event timestamps

  - recipe.py: derive _VALID_ROLES/_VALID_STREAMS from MessageRole/MessageStream Literals
  - dataset_metadata.py: keep CODEBASE_VERSION at v3.0
  - language.py: remove RESERVED_STYLES; split arrow/feature schemas into persistent (with timestamp) and event (without timestamp); add docstrings
  - language_render.py: events use frame-row timestamp implicitly; no per-event timestamp filtering or sorting
  - converters.py: drop unused subtask_key passthrough
  - add docstrings to new public APIs (recipe, render_messages_processor, collate)
  - update tests for split schemas; revert uv.lock

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Add docstrings to all new helpers; revert uv.lock

  Covers private helpers in recipe.py, language.py, language_render.py, and render_messages_processor.py. Also reverts uv.lock to main (it was re-generated by `uv run` during local checks).

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* feat(language): add motion (persistent) and trace (event-only) styles

  Promote the previously-reserved motion/trace styles to first-class core styles. motion routes to language_persistent (it tracks robot state over time); trace routes to language_events (single-moment annotations).

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* feat(language): per-camera tagging on view-dependent styles

  Adds a nullable `camera` field to the language row struct (both persistent and event variants) so view-dependent styles like `vqa` can carry which `observation.images.*` view they were grounded against. Without this, multi-camera datasets ended up with multiple `(vqa, role)` rows at the same timestamp that the resolver could not disambiguate.

  - `language.py`: add `camera` to PERSISTENT_ROW_FIELDS / EVENT_ROW_FIELDS, to both Arrow struct types and the HF datasets feature mappings; introduce VIEW_DEPENDENT_STYLES = {vqa, motion, trace} plus `is_view_dependent_style` and `validate_camera_field` helpers (camera required iff style is view-dependent).
  - `language_render.py`: thread an optional `camera=` kwarg through every resolver (`active_at`, `emitted_at`, `nth_prev`, `nth_next`) and through `_matching_rows` / `_select_*`, so recipes can disambiguate per-camera VQA with `emitted_at(t, style=vqa, role=assistant, camera=...)`. Without a `camera` filter, multi-row matches keep raising the existing ambiguity error, which is the desired behaviour on multi-camera data.
  - `recipes/pi05_hirobot.yaml`: replace the single `ask_vqa` branch with `ask_vqa_top` and `ask_vqa_wrist` per-camera sub-recipes (each carrying the matching image block), keeping the original 0.20 budget and documenting the customization point for datasets with different cameras.
  - Tests: schema test asserts the new field order; new tests cover `is_view_dependent_style`, `validate_camera_field` (both required and forbidden directions), per-camera `emitted_at` filtering, and the ambiguity error when two cameras emit `(vqa, assistant)` at the same timestamp without a `camera=` filter. RenderMessagesStep + dataset passthrough fixtures updated to include the new field.
  - `docs/source/language_and_recipes.mdx`: document the `camera` field, the per-camera resolver pattern, and the canonical recipe convention.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(language): drop motion from VIEW_DEPENDENT_STYLES

  Motion primitives are described in robot-frame (joint / Cartesian) terms, not pixel space, so they are camera-agnostic. Only `vqa` (event) and `trace` (event, pixel-trajectory) are view-dependent.

  The `camera` field stays on PERSISTENT_ROW_FIELDS for schema symmetry (the validator, resolver, and HF feature mapping behave identically across the two columns regardless of which styles populate `camera` today), but persistent rows now always have `camera=None` in practice.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* feat(language): task_aug style + automatic ${task} rephrasing rotation

  Adds task-prompt diversity (Xiao 2022 / CAST) without touching ``meta/tasks.parquet`` or forcing recipes to opt in. The plan reserved ``task_aug`` as a future style; this lands it now.

  - ``language.py``: add ``task_aug`` to ``CORE_STYLES`` and ``PERSISTENT_STYLES``. ``column_for_style("task_aug")`` returns ``language_persistent`` so PR 2 writers route it correctly.
  - ``language_render.py``: ``_resolve_task`` now consults the persistent slice for rows of ``style="task_aug", role="user"``. When any exist it picks one deterministically by ``sample_idx`` (blake2b-keyed, not Python's randomized hash) so an epoch sees every rephrasing of every episode while the same sample still resolves identically across reruns. Falls back to the canonical ``meta/tasks.parquet`` task when no rephrasings are present, so existing datasets and unannotated runs keep their behaviour. Explicit ``task=`` overrides still win.
  - Tests: rephrasing coverage across samples, determinism on repeat ``sample_idx``, fallback when persistent has no ``task_aug`` rows, and explicit override priority.

  Recipes get this for free: any ``${task}`` placeholder rotates through the available rephrasings. Recipes that want the literal canonical task can override the binding.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* feat(language): tool catalog in meta/info.json + LeRobotDatasetMetadata.tools

  Stores OpenAI-style function schemas at ``meta/info.json["tools"]`` so datasets can declare which tools are available (today: just ``say``; tomorrow: per-dataset extensions).
  The ``DEFAULT_TOOLS`` constant fills in for unannotated datasets so chat-template consumers don't have to special-case anything.

  Three pieces:

  - ``language.py``: ``SAY_TOOL_SCHEMA`` and ``DEFAULT_TOOLS`` constants. Single source of truth: PR 2's writer and PR 3's runtime tool registry will both import from here instead of duplicating the dict.
  - ``dataset_metadata.py``: ``LeRobotDatasetMetadata.tools`` property reads ``info.json["tools"]`` and falls back to ``DEFAULT_TOOLS``. Returns deep-copied dicts so callers can mutate the result safely.
  - ``docs/source/tools.mdx``: spec page covering the catalog, per-row invocations, and the three-step "how to add a new tool" workflow (declare schema, implement, register). Linked from the docs toctree under the Datasets section.

  This lays the groundwork for PR 2's pipeline writing the catalog out during annotation, and PR 3's ``src/lerobot/tools/`` package shipping runnable implementations (one file per tool; first up: ``say.py`` wrapping Kyutai's pocket-tts).

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Apply ruff and prettier formatting after merge

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* refactor(language): unify resolver dispatch and prune redundant test scaffolding

  * Drop the unused `events` kwarg from `active_at`/`nth_prev`/`nth_next`; only `emitted_at` actually consults events. The dispatcher in `_resolve_spec` now passes events conditionally.
  * Replace the dual `_persistent_sort_key`/`_event_sort_key` pair with a single `_row_sort_key` and drop the `sort_key` parameter from `_select_one`. Event rows lack `timestamp` (it is implicit in the frame) and now default to `0.0` for sort purposes; the `(style, role)` tiebreaker is unchanged.
  * Inline `_select_latest` into `active_at` (its only caller).
  * Collapse `emitted_at`'s dual-branch into one `_select_one` call.
  * Tighten `_validate_persistent_resolver` to a single `column_for_style(style) != LANGUAGE_PERSISTENT` check.
  * Parameterize `test_per_camera_blend_renders_both_views` over the two cameras and factor the sub-recipe builder into `_vqa_subrecipe` so the test no longer hand-rolls two near-identical recipe blocks.

  Net -98 LOC; behavior, public resolver names, and test expectations unchanged.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(language): always raise on ambiguous resolver matches

  `_select_one` previously skipped its ambiguity check whenever any of `role`/`tool_name`/`camera` was set, on the assumption that the caller had already pinned down a unique row. That left a real ambiguity hole for VQA: with two cameras emitting `(vqa, assistant)` at the same frame, `emitted_at(..., role="assistant")` silently picked the first sorted row instead of telling the recipe to add `camera=...`. The existing `test_emitted_at_raises_on_ambiguous_per_camera_vqa` test already encoded the desired behavior.

  Tighten the check: any time `len(rows) > 1` we now raise with the selectors echoed back, so users see exactly which fields they passed and that more is needed to disambiguate.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* chore: fix CI - collapse short ValueError to one line, refresh uv.lock

  * `ruff format` on CI (newer version) wants the short `camera=None` ValueError on a single line.
  * `uv.lock` was stale relative to `pyproject.toml`'s `datasets>=4.7.0` pin (and picked up upstream `s390x` marker fixes for cuda packages). CI runs `uv sync --locked` which rejected the divergence.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(language): keep base install green - drop processor re-export, gate dataset-extra tests

  `lerobot.processor` re-exported `RenderMessagesStep` at the package level, so importing anything from `lerobot.processor` pulled in `lerobot.datasets.language` → `lerobot.datasets/__init__.py` → `require_package("datasets")`, which fails in the Tier 1 base install that intentionally omits the `[dataset]` extra. The chain bricked collection for unrelated suites (`tests/policies/pi0_pi05/...`, `tests/envs/...`, etc.).

  * Stop re-exporting `RenderMessagesStep` from `lerobot.processor`. The only consumer (the test) already imports from the submodule. Document the deliberate omission in the module docstring.
  * Add `pytest.importorskip("datasets", ...)` (and `pandas` where needed) at the top of the four PR-added tests that exercise the language stack:
    - tests/datasets/test_language.py
    - tests/datasets/test_language_render.py
    - tests/processor/test_render_messages_processor.py
    - tests/utils/test_collate.py

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(language): address review - tools accessor, motion docs, conditional collate

  * **`meta.tools` actually reads `info.json["tools"]`.** `DatasetInfo` had no `tools` field, so `from_dict` silently dropped the key (it warned about unknown fields then discarded them) and the property always returned `DEFAULT_TOOLS`. Added `tools: list[dict] | None` to the dataclass; `to_dict()` drops it when unset so existing datasets keep a clean `info.json`. Fixed the accessor to read `self.info.tools` (the previous `.get(...)` would have raised AttributeError on the dataclass anyway). Added regression tests: fallback when absent, round-trip from disk, and round-trip through `DatasetInfo.from_dict` / `to_dict`.
  * **`motion` is not view-dependent - fix the docs.** The mdx claimed rows of style `motion` must carry `camera`, but `VIEW_DEPENDENT_STYLES = {"vqa", "trace"}` and the validator agrees: motion primitives are joint/Cartesian-frame, not pixel-space. Updated both call-out paragraphs in `language_and_recipes.mdx`.
  * **Conditional `collate_fn` swap.** Added `meta.has_language_columns` and gate the `lerobot_collate_fn` swap in `lerobot_train.py` on it, so non-language datasets keep PyTorch's `default_collate`. Also added a pass-through test in `test_collate.py` that asserts on a plain tensor batch the custom collate matches `default_collate` key-for-key, plus a test for the `None`-sample drop path.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* review: dedupe regex, centralize column names, harden collate, more tests

  * **#2 - dedupe `_PLACEHOLDER_RE`.** The same regex was compiled in `recipe.py` and `language_render.py`. Promote to module-level `PLACEHOLDER_RE` in `recipe.py` (its primary owner: declares template syntax) and import from `language_render.py`.
  * **#3 - centralize language column names.** `io_utils.py` had hardcoded `{"language_persistent", "language_events"}` literals at two sites. Replace with `LANGUAGE_COLUMNS` import so a future column rename can't silently desync.
  * **#4 - defensive collate preserved-keys.** `lerobot_collate_fn` silently filtered language fields from samples that didn't have them, which would hand downstream consumers a preserved list shorter than the tensor batch. Now: if any sample carries a key, every sample in the batch must carry it; otherwise raise a `ValueError` so the upstream rendering bug surfaces at the boundary.
  * **#5 - `_scalar` rejects non-singleton lists.** Previously a zero- or multi-element list fell through and triggered confusing `float([])` errors downstream. Now raises `ValueError` with the actual length.
  * **#6 - refactor `_extract_complementary_data`.** Replace 11 lines of `key = {... if ... else {}}` plus an 11-line splat dict with a single `_COMPLEMENTARY_KEYS` tuple iterated once.
  * **#7 - document `EXTENDED_STYLES`.** Was an empty `set()` with no comment. Add a docstring explaining it's an intentional extension point: downstream modules append project-local styles before `column_for_style` is called.
  * **#9 - `tools.mdx` notes the runtime layer is future work.** The page referenced `src/lerobot/tools/`, `registry.py`, and `get_tools(meta)`; none exist in this PR. Added a callout at the start of "How to add your own tool" plus a note on the implementations paragraph.
  * **#10 - tests for YAML round-trip, malformed rows, blend validation.** `test_recipe.py` grew from 1 case to 12 covering: blend-or-messages exclusivity, target-turn requirement, blend emptiness, weight presence/positivity, nested-blend rejection, `from_dict` with nested blends, `from_yaml` / `load_recipe` agreement, top-level non-mapping rejection. Added a malformed-row test for `_normalize_rows` that asserts non-dict entries raise `TypeError`.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* review: emitted_at uses 0.1s tolerance; MessageTurn requires stream at construction

  * **Float tolerance in `emitted_at` for persistent styles.** The ``_timestamp(row) == t`` exact-equality check silently missed any caller that derived ``t`` arithmetically (e.g. ``frame_idx / fps``) even though the parquet timestamp would only differ by ULPs. Added ``EMITTED_AT_TOLERANCE_S = 0.1`` and check ``abs(...) <= tolerance`` instead, with a docstring explaining why exact equality wasn't enough and why 0.1 s is safe at typical 30–100 Hz control rates. Test asserts the new behavior at half-window (matches) and double-window (no match) using the constant so it stays in sync.
  * **`MessageTurn.stream` is required at construction.** It was typed ``MessageStream | None = None`` so YAML could omit ``stream:`` and pass the dataclass invariant, but ``_validate_rendered`` rejected ``None`` streams later, surfacing the error at the first sample instead of at recipe load. Now ``__post_init__`` raises ``ValueError`` if ``stream`` is ``None``, with the list of valid streams in the message. The redundant late-stage check in ``_validate_rendered`` is replaced with a one-line comment that cites the upstream invariant. Test pins the new construction-time rejection.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* docs(tools): drop follow-up-PR references

  Reword the two callouts in `tools.mdx` to describe the runtime layer in present tense ("not part of the catalog layer shipped today", "those modules don't yet exist in the tree") instead of pointing at a specific follow-up PR. Keeps the doc honest about what works now without coupling it to a particular release order.

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* review: address CarolinePascal feedback

  - language timestamps: float64 -> float32 to match LeRobotDataset frame timestamps (Arrow struct + HF feature)
  - dataset_metadata: hoist `.language` imports to module top: language.py has no lerobot imports, so there is no circular-import risk
  - dataset_metadata: add a `meta.tools` setter that persists the catalog to info.json and reloads `meta.info`
  - feature_utils: validate the `language` dtype instead of returning ""; warn (non-fatal) when a non-empty value is written at record time
  - centralize the scalar-unwrap helper as `lerobot.utils.utils.unwrap_scalar`, shared by render_messages_processor and language_render
  - docs: move `## Layer 2 — recipe anatomy` ahead of the resolver sections, which describe recipe bindings rather than dataset layout
  - language_render: note in EMITTED_AT_TOLERANCE_S that persistent rows change on a human-action timescale, not the camera frame rate

  Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
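
The `task_aug` commit above describes picking one rephrasing deterministically from `sample_idx` using blake2b rather than Python's per-process salted hash. The block below is a minimal sketch of that idea only; it is not the code in `language_render.py`. The helper name `pick_rephrasing`, the 8-byte digest size, and hashing the decimal string of `sample_idx` are assumptions made for illustration.

```python
import hashlib


def pick_rephrasing(rephrasings: list[str], sample_idx: int) -> str:
    """Hypothetical helper: map sample_idx to one rephrasing, stably across runs."""
    if not rephrasings:
        raise ValueError("rephrasings must not be empty")
    # blake2b gives the same digest in every process and on every machine.
    # Built-in hash() on str/bytes is salted per interpreter run (PYTHONHASHSEED),
    # so it cannot be used for selections that must be reproducible.
    digest = hashlib.blake2b(str(sample_idx).encode("utf-8"), digest_size=8).digest()
    # Interpret the digest as an unsigned integer and reduce it to a list index.
    return rephrasings[int.from_bytes(digest, "big") % len(rephrasings)]


if __name__ == "__main__":
    options = ["pick up the cube", "grab the cube", "lift the cube"]
    for idx in range(5):
        print(idx, pick_rephrasing(options, idx))
```

Because the index depends only on `sample_idx` and the number of rephrasings, repeated runs resolve a given sample to the same text, while different samples spread across the available rephrasings.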
606 lines
24 KiB
Python
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train a policy.

Requires: pip install 'lerobot[training]' (includes dataset + accelerate + wandb extras)
"""

import dataclasses
import logging
import time
from contextlib import nullcontext
from pprint import pformat
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from accelerate import Accelerator

import torch
from termcolor import colored
from torch.optim import Optimizer
from tqdm import tqdm

from lerobot.common.train_utils import (
    get_step_checkpoint_dir,
    get_step_identifier,
    load_training_state,
    save_checkpoint,
    update_last_checkpoint,
)
from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets import EpisodeAwareSampler, make_dataset
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.rewards import make_reward_pre_post_processors
from lerobot.utils.collate import lerobot_collate_fn
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
    cycle,
    format_big_number,
    has_method,
    init_logging,
    inside_slurm,
)

from .lerobot_eval import eval_policy_all


def update_policy(
    train_metrics: MetricsTracker,
    policy: PreTrainedPolicy,
    batch: Any,
    optimizer: Optimizer,
    grad_clip_norm: float,
    accelerator: "Accelerator",
    lr_scheduler=None,
    lock=None,
    sample_weighter=None,
) -> tuple[MetricsTracker, dict | None]:
    """
    Performs a single training step to update the policy's weights.

    This function executes the forward and backward passes, clips gradients, and steps the optimizer and
    learning rate scheduler. Accelerator handles mixed-precision training automatically.

    Args:
        train_metrics: A MetricsTracker instance to record training statistics.
        policy: The policy model to be trained.
        batch: A batch of training data.
        optimizer: The optimizer used to update the policy's parameters.
        grad_clip_norm: The maximum norm for gradient clipping.
        accelerator: The Accelerator instance for distributed training and mixed precision.
        lr_scheduler: An optional learning rate scheduler.
        lock: An optional lock for thread-safe optimizer updates.
        sample_weighter: Optional SampleWeighter instance for per-sample loss weighting.

    Returns:
        A tuple containing:
        - The updated MetricsTracker with new statistics for this step.
        - A dictionary of outputs from the policy's forward pass, for logging purposes.
    """
    start_time = time.perf_counter()
    policy.train()

    # Compute sample weights if a weighter is provided
    sample_weights = None
    weight_stats = None
    if sample_weighter is not None:
        sample_weights, weight_stats = sample_weighter.compute_batch_weights(batch)

    # Let accelerator handle mixed precision
    with accelerator.autocast():
        if sample_weights is not None:
            # Use per-sample loss for weighted training
            # Note: Policies supporting sample weighting must implement forward(batch, reduction="none")
            per_sample_loss, output_dict = policy.forward(batch, reduction="none")

            # Weighted loss: each sample's contribution is scaled by its weight.
            # We divide by weight sum (not batch size) so that if some weights are zero,
            # the remaining samples contribute proportionally more, preserving gradient scale.
            # Weights are pre-normalized to sum to batch_size for stable training dynamics.
            epsilon = 1e-6
            loss = (per_sample_loss * sample_weights).sum() / (sample_weights.sum() + epsilon)

            # Log weighting statistics
            if output_dict is None:
                output_dict = {}
            for key, value in weight_stats.items():
                output_dict[f"sample_weight_{key}"] = value
        else:
            loss, output_dict = policy.forward(batch)

        # TODO(rcadene): policy.unnormalize_outputs(out_dict)

    # Use accelerator's backward method
    accelerator.backward(loss)

    # Clip gradients if specified
    if grad_clip_norm > 0:
        grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
    else:
        grad_norm = torch.nn.utils.clip_grad_norm_(
            policy.parameters(), float("inf"), error_if_nonfinite=False
        )

    # Optimizer step
    with lock if lock is not None else nullcontext():
        optimizer.step()

    optimizer.zero_grad()

    # Step through pytorch scheduler at every batch instead of epoch
    if lr_scheduler is not None:
        lr_scheduler.step()

    # Update internal buffers if policy has update method
    if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
        accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()

    train_metrics.loss = loss.item()
    train_metrics.grad_norm = grad_norm.item()
    train_metrics.lr = optimizer.param_groups[0]["lr"]
    train_metrics.update_s = time.perf_counter() - start_time
    return train_metrics, output_dict


@parser.wrap()
def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
    """
    Main function to train a policy.

    This function orchestrates the entire training pipeline, including:
    - Setting up logging, seeding, and device configuration.
    - Creating the dataset, evaluation environment (if applicable), policy, and optimizer.
    - Handling resumption from a checkpoint.
    - Running the main training loop, which involves fetching data batches and calling `update_policy`.
    - Periodically logging metrics, saving model checkpoints, and evaluating the policy.
    - Pushing the final trained model to the Hugging Face Hub if configured.

    Args:
        cfg: A `TrainPipelineConfig` object containing all training configurations.
        accelerator: Optional Accelerator instance. If None, one will be created automatically.
    """
    from lerobot.utils.import_utils import require_package

    require_package("accelerate", extra="training")
    from accelerate import Accelerator

    cfg.validate()

    # Create Accelerator if not provided
    # It will automatically detect if running in distributed mode or single-process mode
    # We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
    # We set find_unused_parameters=True to handle models with conditional computation
    if accelerator is None:
        from accelerate.utils import DistributedDataParallelKwargs

        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        # Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
        # Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
        force_cpu = cfg.trainable_config.device == "cpu"
        accelerator = Accelerator(
            step_scheduler_with_optimizer=False,
            kwargs_handlers=[ddp_kwargs],
            cpu=force_cpu,
        )

    init_logging(accelerator=accelerator)

    # Determine if this is the main process (for logging and checkpointing)
    # When using accelerate, only the main process should log to avoid duplicate outputs
    is_main_process = accelerator.is_main_process

    # Only log on main process
    if is_main_process:
        logging.info(pformat(cfg.to_dict()))

    # Initialize wandb only on main process
    if cfg.wandb.enable and cfg.wandb.project and is_main_process:
        wandb_logger = WandBLogger(cfg)
    else:
        wandb_logger = None
        if is_main_process:
            logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))

    if cfg.seed is not None:
        set_seed(cfg.seed, accelerator=accelerator)

    # Use accelerator's device
    device = accelerator.device
    if cfg.cudnn_deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    # Dataset loading synchronization: main process downloads first to avoid race conditions
    if is_main_process:
        logging.info("Creating dataset")
        dataset = make_dataset(cfg)

    accelerator.wait_for_everyone()

    # Now all other processes can safely load the dataset
    if not is_main_process:
        dataset = make_dataset(cfg)

    # Create environment used for evaluating checkpoints during training on simulation data.
    # On real-world data, no need to create an environment as evaluations are done outside train.py,
    # using the eval.py instead, with gym_dora environment and dora-rs.
    eval_env = None
    if cfg.eval_freq > 0 and cfg.env is not None and is_main_process:
        logging.info("Creating env")
        eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)

    if cfg.is_reward_model_training:
        if is_main_process:
            logging.info("Creating reward model")
        from lerobot.rewards import make_reward_model

        policy = make_reward_model(
            cfg=cfg.reward_model,
            dataset_stats=dataset.meta.stats,
            dataset_meta=dataset.meta,
        )
        if not policy.is_trainable:
            raise ValueError(
                f"Reward model '{policy.name}' is zero-shot and cannot be trained via lerobot-train. "
                "Use it directly for inference via compute_reward() (e.g. offline precompute)."
            )
    else:
        if is_main_process:
            logging.info("Creating policy")
        policy = make_policy(
            cfg=cfg.policy,
            ds_meta=dataset.meta,
            rename_map=cfg.rename_map,
        )

    if cfg.peft is not None:
        if cfg.is_reward_model_training:
            raise ValueError("PEFT is only supported for policy training. ")
        from peft import PeftModel

        if isinstance(policy, PeftModel):
            logging.info("PEFT adapter already loaded from checkpoint, skipping wrap_with_peft.")
        else:
            logging.info("Using PEFT! Wrapping model.")
            peft_cli_overrides = dataclasses.asdict(cfg.peft)
            policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)

    # Wait for all processes to finish model creation before continuing
    accelerator.wait_for_everyone()

    active_cfg = cfg.trainable_config
    processor_pretrained_path = active_cfg.pretrained_path
    if (
        getattr(active_cfg, "use_relative_actions", False)
        and processor_pretrained_path is not None
        and not cfg.resume
    ):
        logging.warning(
            "use_relative_actions=true with pretrained processors can skip relative transforms if "
            "the checkpoint processors do not define them. Building processors from current policy config."
        )
        processor_pretrained_path = None

    processor_kwargs = {}
    postprocessor_kwargs = {}
    if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
        processor_kwargs["dataset_stats"] = dataset.meta.stats

    if cfg.is_reward_model_training:
        processor_kwargs["dataset_meta"] = dataset.meta

    if not cfg.is_reward_model_training and processor_pretrained_path is not None:
        processor_kwargs["preprocessor_overrides"] = {
            "device_processor": {"device": device.type},
            "normalizer_processor": {
                "stats": dataset.meta.stats,
                "features": {**policy.config.input_features, **policy.config.output_features},
                "norm_map": policy.config.normalization_mapping,
            },
        }
        processor_kwargs["preprocessor_overrides"]["rename_observations_processor"] = {
            "rename_map": cfg.rename_map
        }
        postprocessor_kwargs["postprocessor_overrides"] = {
            "unnormalizer_processor": {
                "stats": dataset.meta.stats,
                "features": policy.config.output_features,
                "norm_map": policy.config.normalization_mapping,
            },
        }

    if cfg.is_reward_model_training:
        preprocessor, postprocessor = make_reward_pre_post_processors(
            cfg.reward_model,
            **processor_kwargs,
        )
    else:
        preprocessor, postprocessor = make_pre_post_processors(
            policy_cfg=cfg.policy,
            pretrained_path=processor_pretrained_path,
            **processor_kwargs,
            **postprocessor_kwargs,
        )

    if is_main_process:
        logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
|
|
|
# Create sample weighter if configured (e.g., for RA-BC training)
|
|
sample_weighter = None
|
|
if cfg.sample_weighting is not None:
|
|
from lerobot.utils.sample_weighting import make_sample_weighter
|
|
|
|
if is_main_process:
|
|
logging.info(f"Creating sample weighter: {cfg.sample_weighting.type}")
|
|
sample_weighter = make_sample_weighter(
|
|
cfg.sample_weighting,
|
|
policy,
|
|
device,
|
|
dataset_root=cfg.dataset.root,
|
|
dataset_repo_id=cfg.dataset.repo_id,
|
|
)
|
|
|
|
    step = 0  # number of policy updates (forward + backward + optim)

    if cfg.resume:
        step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)

    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())

    if is_main_process:
        logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
        if cfg.env is not None:
            logging.info(f"{cfg.env.task=}")
            logging.info("Creating environment processors")
            env_preprocessor, env_postprocessor = make_env_pre_post_processors(
                env_cfg=cfg.env, policy_cfg=cfg.policy
            )
        logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
        logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
        logging.info(f"{dataset.num_episodes=}")
        num_processes = accelerator.num_processes
        effective_bs = cfg.batch_size * num_processes
        logging.info(f"Effective batch size: {cfg.batch_size} x {num_processes} = {effective_bs}")
        logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
        logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

    # create dataloader for offline training
    if hasattr(active_cfg, "drop_n_last_frames"):
        shuffle = False
        sampler = EpisodeAwareSampler(
            dataset.meta.episodes["dataset_from_index"],
            dataset.meta.episodes["dataset_to_index"],
            episode_indices_to_use=dataset.episodes,
            drop_n_last_frames=active_cfg.drop_n_last_frames,
            shuffle=True,
        )
    else:
        shuffle = True
        sampler = None

    # Only swap in the language-aware collate when the dataset actually
    # declares language columns; otherwise stay on PyTorch's default
    # collate so non-language training runs are unaffected.
    collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.num_workers,
        batch_size=cfg.batch_size,
        shuffle=shuffle and not cfg.dataset.streaming,
        sampler=sampler,
        pin_memory=device.type == "cuda",
        drop_last=False,
        collate_fn=collate_fn,
        prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
        persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
    )

    # Prepare everything with accelerator
    accelerator.wait_for_everyone()
    policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
        policy, optimizer, dataloader, lr_scheduler
    )
    dl_iter = cycle(dataloader)

    policy.train()

    train_metrics = {
        "loss": AverageMeter("loss", ":.3f"),
        "grad_norm": AverageMeter("grdn", ":.3f"),
        "lr": AverageMeter("lr", ":0.1e"),
        "update_s": AverageMeter("updt_s", ":.3f"),
        "dataloading_s": AverageMeter("data_s", ":.3f"),
    }

    # Keep global batch size for logging; MetricsTracker handles world size internally.
    effective_batch_size = cfg.batch_size * accelerator.num_processes
    train_tracker = MetricsTracker(
        cfg.batch_size,
        dataset.num_frames,
        dataset.num_episodes,
        train_metrics,
        initial_step=step,
        accelerator=accelerator,
    )

    if is_main_process:
        progbar = tqdm(
            total=cfg.steps - step,
            desc="Training",
            unit="step",
            disable=inside_slurm(),
            position=0,
            leave=True,
        )
        logging.info(
            f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
        )

    for _ in range(step, cfg.steps):
        start_time = time.perf_counter()
        batch = next(dl_iter)
        for cam_key in dataset.meta.camera_keys:
            if cam_key in batch and batch[cam_key].dtype == torch.uint8:
                batch[cam_key] = batch[cam_key].to(dtype=torch.float32) / 255.0
        batch = preprocessor(batch)
        train_tracker.dataloading_s = time.perf_counter() - start_time

        train_tracker, output_dict = update_policy(
            train_tracker,
            policy,
            batch,
            optimizer,
            cfg.optimizer.grad_clip_norm,
            accelerator=accelerator,
            lr_scheduler=lr_scheduler,
            sample_weighter=sample_weighter,
        )

        # Note: eval and checkpoint happen *after* the `step`th training update has completed, so we
        # increment `step` here.
        step += 1
        if is_main_process:
            progbar.update(1)
        train_tracker.step()
        is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
        is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
        is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0

        if is_log_step:
            logging.info(train_tracker)
            if wandb_logger:
                wandb_log_dict = train_tracker.to_dict()
                if output_dict:
                    wandb_log_dict.update(output_dict)
                # Log sample weighting statistics if enabled
                if sample_weighter is not None:
                    weighter_stats = sample_weighter.get_stats()
                    wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
                wandb_logger.log_dict(wandb_log_dict, step)
            train_tracker.reset_averages()

        if cfg.save_checkpoint and is_saving_step:
            if is_main_process:
                logging.info(f"Checkpoint policy after step {step}")
                checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
                save_checkpoint(
                    checkpoint_dir=checkpoint_dir,
                    step=step,
                    cfg=cfg,
                    policy=accelerator.unwrap_model(policy),
                    optimizer=optimizer,
                    scheduler=lr_scheduler,
                    preprocessor=preprocessor,
                    postprocessor=postprocessor,
                )
                update_last_checkpoint(checkpoint_dir)
                if wandb_logger:
                    wandb_logger.log_policy(checkpoint_dir)

            accelerator.wait_for_everyone()

        if cfg.env and is_eval_step:
            if is_main_process:
                step_id = get_step_identifier(step, cfg.steps)
                logging.info(f"Eval policy at step {step}")
                with torch.no_grad(), accelerator.autocast():
                    eval_info = eval_policy_all(
                        envs=eval_env,  # dict[suite][task_id] -> vec_env
                        policy=accelerator.unwrap_model(policy),
                        env_preprocessor=env_preprocessor,
                        env_postprocessor=env_postprocessor,
                        preprocessor=preprocessor,
                        postprocessor=postprocessor,
                        n_episodes=cfg.eval.n_episodes,
                        videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
                        max_episodes_rendered=4,
                        start_seed=cfg.seed,
                        max_parallel_tasks=cfg.env.max_parallel_tasks,
                    )
                # overall metrics (suite-agnostic)
                aggregated = eval_info["overall"]

                # optional: per-suite logging
                for suite, suite_info in eval_info.items():
                    logging.info("Suite %s aggregated: %s", suite, suite_info)

                # meters/tracker
                eval_metrics = {
                    "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
                    "pc_success": AverageMeter("success", ":.1f"),
                    "eval_s": AverageMeter("eval_s", ":.3f"),
                }
                eval_tracker = MetricsTracker(
                    cfg.batch_size,
                    dataset.num_frames,
                    dataset.num_episodes,
                    eval_metrics,
                    initial_step=step,
                    accelerator=accelerator,
                )
                eval_tracker.eval_s = aggregated.pop("eval_s")
                eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
                eval_tracker.pc_success = aggregated.pop("pc_success")
                if wandb_logger:
                    wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
                    wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
                    wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")

            accelerator.wait_for_everyone()

    if is_main_process:
        progbar.close()

    if eval_env:
        close_envs(eval_env)

    if is_main_process:
        logging.info("End of training")

        if getattr(active_cfg, "push_to_hub", False):
            unwrapped_model = accelerator.unwrap_model(policy)
            # PEFT only applies when training a policy; reward models use the plain path.
            if not cfg.is_reward_model_training and cfg.policy.use_peft:
                unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
            else:
                unwrapped_model.push_model_to_hub(cfg)
            preprocessor.push_to_hub(active_cfg.repo_id)
            postprocessor.push_to_hub(active_cfg.repo_id)

    # Properly clean up the distributed process group
    accelerator.wait_for_everyone()
    accelerator.end_training()


def main():
    register_third_party_plugins()
    train()


if __name__ == "__main__":
    main()