#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Offline policy training, optionally distributed with Hugging Face Accelerate."""

import logging
import os
import time
from contextlib import nullcontext
from datetime import timedelta
from functools import partial
from pprint import pformat
from typing import Any

import torch
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer

from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import cycle
from lerobot.datasets.utils_must import multidataset_collate_fn
from lerobot.envs.factory import make_env
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.scripts.eval import eval_policy
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import (
    get_step_checkpoint_dir,
    get_step_identifier,
    load_training_state,
    save_checkpoint,
    update_last_checkpoint,
)
from lerobot.utils.utils import (
    format_big_number,
    get_safe_torch_device,
    has_method,
    init_logging,
)
from lerobot.utils.wandb_utils import WandBLogger


def is_launched_with_accelerate() -> bool:
    """Return True when the ``ACCELERATE_MIXED_PRECISION`` environment variable is set.

    NOTE(review): relies on ``accelerate launch`` exporting that variable; confirm for
    the accelerate version in use.
    """
    return "ACCELERATE_MIXED_PRECISION" in os.environ


def _grad_norm_to_float(grad_norm: Any) -> float:
    """Convert a gradient norm (tensor, Python number or ``None``) to a float for logging.

    ``grad_norm`` stays the float ``0.0`` on accumulation micro-steps where no clipping
    happens, so calling ``.item()`` on it unconditionally would raise. Some accelerate
    backends may also return ``None`` from ``clip_grad_norm_``.
    """
    if grad_norm is None:
        return 0.0
    if isinstance(grad_norm, torch.Tensor):
        return grad_norm.item()
    return float(grad_norm)


def update_policy(
    train_metrics: MetricsTracker,
    policy: PreTrainedPolicy,
    batch: Any,
    optimizer: Optimizer,
    grad_clip_norm: float,
    grad_scaler: GradScaler,
    lr_scheduler=None,
    use_amp: bool = False,
    lock=None,
    accelerator=None,
) -> tuple[MetricsTracker, dict]:
    """Run one training update and record its metrics.

    Args:
        train_metrics: Tracker receiving ``loss``, ``grad_norm``, ``lr`` and ``update_s``.
        policy: Policy to train (possibly wrapped by ``accelerator.prepare``).
        batch: Input passed to ``policy.forward``.
        optimizer: Optimizer stepped after the backward pass.
        grad_clip_norm: Max gradient norm used for clipping.
        grad_scaler: AMP scaler; only used on the non-accelerate path.
        lr_scheduler: Optional scheduler stepped once per call.
        use_amp: Wrap the forward pass in ``torch.autocast`` when True.
        lock: Unused; kept for backward compatibility.
        accelerator: Optional ``accelerate.Accelerator``. When given, backward and
            gradient accumulation are delegated to it.

    Returns:
        ``(train_metrics, output_dict)``, where ``output_dict`` is the second value
        returned by ``policy.forward``.
    """
    start_time = time.perf_counter()
    device = get_device_from_parameters(policy)
    policy.train()

    # Stays 0.0 on accumulation micro-steps where gradients are not synchronized.
    grad_norm = 0.0
    if accelerator is not None:
        with accelerator.accumulate(policy):
            with torch.autocast(device_type=device.type) if use_amp else nullcontext():
                loss, output_dict = policy.forward(batch)
                # TODO(rcadene): policy.unnormalize_outputs(out_dict)
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                # accelerator.clip_grad_norm_ unscales fp16-scaled gradients before
                # clipping; the plain torch helper would clip the scaled values.
                grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
            # The optimizer prepared by accelerate is expected to skip the real
            # step/zero_grad on non-sync micro-steps.
            optimizer.step()
            optimizer.zero_grad()
    else:
        # Standard training loop without accelerate
        with torch.autocast(device_type=device.type) if use_amp else nullcontext():
            loss, output_dict = policy.forward(batch)
            # TODO(rcadene): policy.unnormalize_outputs(out_dict)
        grad_scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to the true gradients.
        grad_scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            policy.parameters(),
            grad_clip_norm,
            error_if_nonfinite=False,
        )
        grad_scaler.step(optimizer)
        grad_scaler.update()
        optimizer.zero_grad()

    # Step through pytorch scheduler at every batch instead of epoch
    if lr_scheduler is not None:
        lr_scheduler.step()

    # Look for `update` on the unwrapped model: a DDP wrapper does not forward
    # attribute lookups to the inner module, so checking the wrapper would miss it.
    if accelerator is not None:
        inner_policy = accelerator.unwrap_model(policy, keep_fp32_wrapper=True)
    else:
        inner_policy = policy
    if has_method(inner_policy, "update"):
        inner_policy.update()

    train_metrics.loss = loss.item()
    train_metrics.grad_norm = _grad_norm_to_float(grad_norm)
    train_metrics.lr = optimizer.param_groups[0]["lr"]
    train_metrics.update_s = time.perf_counter() - start_time
    return train_metrics, output_dict


def _make_accelerator(cfg: TrainPipelineConfig):
    """Build an ``accelerate.Accelerator`` from ``cfg``, or return None when not launched with accelerate."""
    if not is_launched_with_accelerate():
        return None

    import accelerate
    from accelerate import DistributedDataParallelKwargs, InitProcessGroupKwargs

    # For example pi0 has unused params (last llm block)
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    # Set NCCL timeout (default 30 minutes = 1800 seconds)
    # FIXME(mshukor): allow user to set timeout. This should be longer than the evaluation time,
    # since non-main processes wait for the main process while it evaluates.
    nccl_timeout = getattr(cfg, "nccl_timeout", 1800)
    ddp_init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=nccl_timeout))
    # Set gradient accumulation steps (default 1)
    gradient_accumulation_steps = getattr(cfg, "gradient_accumulation_steps", 1)
    return accelerate.Accelerator(
        step_scheduler_with_optimizer=False,
        gradient_accumulation_steps=gradient_accumulation_steps,
        kwargs_handlers=[ddp_init_kwargs, ddp_kwargs],
    )


def _unwrap_policy(policy, accelerator):
    """Return ``policy`` without any wrapper added by ``accelerator.prepare``."""
    return accelerator.unwrap_model(policy) if accelerator is not None else policy


def _make_dataloader(cfg: TrainPipelineConfig, dataset, device: torch.device) -> torch.utils.data.DataLoader:
    """Create the offline-training dataloader, padding samples with ``multidataset_collate_fn``."""
    if hasattr(cfg.policy, "drop_n_last_frames"):
        # The sampler does the shuffling, so the DataLoader itself must not.
        shuffle = False
        sampler = EpisodeAwareSampler(
            dataset.episode_data_index,
            drop_n_last_frames=cfg.policy.drop_n_last_frames,
            shuffle=True,
        )
    else:
        shuffle = True
        sampler = None

    # Fixed padding targets passed to the collate function.
    # NOTE(review): these hard-coded sizes ignore any dataset metadata; confirm they
    # cover every dataset in the mix.
    keys_to_max_dim = {
        "action": (32,),
        "observation.state": (32,),
        "observation.image": (3, 1080, 1920),
        "observation.image2": (3, 1080, 1920),
    }
    collate_fn = partial(multidataset_collate_fn, keys_to_max_dim=keys_to_max_dim)

    return torch.utils.data.DataLoader(
        dataset,
        collate_fn=collate_fn,
        num_workers=cfg.num_workers,
        batch_size=cfg.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        pin_memory=device.type != "cpu",
        drop_last=False,
    )


def _evaluate_and_log(cfg: TrainPipelineConfig, step: int, eval_env, policy, device, dataset, wandb_logger) -> None:
    """Evaluate ``policy`` in ``eval_env``, log the aggregated metrics and send them to W&B if enabled."""
    step_id = get_step_identifier(step, cfg.steps)
    logging.info(f"Eval policy at step {step}")
    with (
        torch.no_grad(),
        torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
    ):
        eval_info = eval_policy(
            eval_env,
            policy,
            cfg.eval.n_episodes,
            videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
            max_episodes_rendered=4,
            start_seed=cfg.seed,
        )

    eval_metrics = {
        "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
        "pc_success": AverageMeter("success", ":.1f"),
        "eval_s": AverageMeter("eval_s", ":.3f"),
    }
    eval_tracker = MetricsTracker(
        cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
    )
    eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
    eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
    eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
    logging.info(eval_tracker)
    if wandb_logger:
        wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
        wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
        wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")


@parser.wrap()
def train(cfg: TrainPipelineConfig):
    """Train a policy offline on a fixed dataset as configured by ``cfg``.

    When ``is_launched_with_accelerate()`` is True, the policy, optimizer, dataloader
    and scheduler go through ``accelerator.prepare``. In that case only the main
    process logs to W&B, writes checkpoints, runs evaluation and pushes to the Hub;
    the other processes wait at ``accelerator.wait_for_everyone()`` barriers.
    """
    cfg.validate()

    accelerator = _make_accelerator(cfg)
    is_main_process = accelerator is None or accelerator.is_main_process

    if not is_main_process:
        # Disable duplicate logging on non-main processes
        logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
        logging.getLogger().setLevel(logging.WARNING)

    logging.info(pformat(cfg.to_dict()))

    if not is_main_process:
        # Disable W&B logging on non-main processes.
        cfg.wandb.enable = False

    if cfg.wandb.enable and cfg.wandb.project:
        wandb_logger = WandBLogger(cfg)
    else:
        wandb_logger = None
        logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Check device is available
    device = get_safe_torch_device(cfg.policy.device, log=True)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    logging.info("Creating dataset")
    dataset = make_dataset(cfg)

    # Create environment used for evaluating checkpoints during training on simulation data.
    # On real-world data, no need to create an environment as evaluations are done outside train.py,
    # using the eval.py instead, with gym_dora environment and dora-rs.
    # Only the main process evaluates, so only it needs an environment.
    eval_env = None
    if cfg.eval_freq > 0 and cfg.env is not None and is_main_process:
        logging.info("Creating env")
        eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)

    logging.info("Creating policy")
    policy = make_policy(
        cfg=cfg.policy,
        ds_meta=dataset.meta,
    )

    logging.info("Creating optimizer and scheduler")
    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
    grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)

    step = 0  # number of policy updates (forward + backward + optim)

    if cfg.resume:
        step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)

    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())

    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
    if cfg.env is not None:
        logging.info(f"{cfg.env.task=}")
    logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
    logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
    logging.info(f"{dataset.num_episodes=}")
    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

    # create dataloader for offline training
    dataloader = _make_dataloader(cfg, dataset, device)

    # Most important line
    if accelerator is not None:
        policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
            policy, optimizer, dataloader, lr_scheduler
        )
    dl_iter = cycle(dataloader)

    policy.train()

    train_metrics = {
        "loss": AverageMeter("loss", ":.3f"),
        "grad_norm": AverageMeter("grdn", ":.3f"),
        "lr": AverageMeter("lr", ":0.1e"),
        "update_s": AverageMeter("updt_s", ":.3f"),
        "dataloading_s": AverageMeter("data_s", ":.3f"),
    }

    train_tracker = MetricsTracker(
        cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
    )

    logging.info("Start offline training on a fixed dataset")
    for _ in range(step, cfg.steps):
        start_time = time.perf_counter()
        batch = next(dl_iter)
        train_tracker.dataloading_s = time.perf_counter() - start_time

        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(device, non_blocking=True)

        train_tracker, output_dict = update_policy(
            train_tracker,
            policy,
            batch,
            optimizer,
            cfg.optimizer.grad_clip_norm,
            grad_scaler=grad_scaler,
            lr_scheduler=lr_scheduler,
            use_amp=cfg.policy.use_amp,
            accelerator=accelerator,
        )

        # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
        # increment `step` here.
        step += 1
        train_tracker.step()
        is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
        is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
        is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0

        if is_log_step:
            logging.info(train_tracker)
            if wandb_logger:
                wandb_log_dict = train_tracker.to_dict()
                if output_dict:
                    wandb_log_dict.update(output_dict)
                wandb_logger.log_dict(wandb_log_dict, step)
            train_tracker.reset_averages()

        if cfg.save_checkpoint and is_saving_step:
            if accelerator is not None:
                # Every process reaches this barrier before the main process snapshots the weights.
                accelerator.wait_for_everyone()
            # Only one process writes, so processes do not race on the same checkpoint directory.
            if is_main_process:
                logging.info(f"Checkpoint policy after step {step}")
                checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
                save_checkpoint(
                    checkpoint_dir, step, cfg, _unwrap_policy(policy, accelerator), optimizer, lr_scheduler
                )
                update_last_checkpoint(checkpoint_dir)
                if wandb_logger:
                    wandb_logger.log_policy(checkpoint_dir)

        if cfg.env and is_eval_step:
            # Only the main process evaluates (others would write the same video files).
            if is_main_process:
                _evaluate_and_log(
                    cfg, step, eval_env, _unwrap_policy(policy, accelerator), device, dataset, wandb_logger
                )
            if accelerator is not None:
                accelerator.wait_for_everyone()

    if eval_env:
        eval_env.close()
    logging.info("End of training")

    if cfg.policy.push_to_hub and is_main_process:
        _unwrap_policy(policy, accelerator).push_model_to_hub(cfg)


if __name__ == "__main__":
    init_logging()
    train()