mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 03:41:25 +00:00
Update lerobot Python modules and add test training script
- Enhanced dataset processing and statistics computation - Updated policy factory and normalization - Improved SmolVLA2 modeling and expert integration - Enhanced training and evaluation scripts - Added utility improvements for training and wandb integration - Added test training script with 2 datasets for validation
This commit is contained in:
@@ -24,6 +24,9 @@ import torch
|
||||
from termcolor import colored
|
||||
from torch.amp import GradScaler
|
||||
from torch.optim import Optimizer
|
||||
import os
|
||||
from datetime import timedelta
|
||||
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
@@ -54,6 +57,8 @@ from lerobot.utils.utils import (
|
||||
)
|
||||
from lerobot.utils.wandb_utils import WandBLogger
|
||||
|
||||
def is_launched_with_accelerate() -> bool:
|
||||
return "ACCELERATE_MIXED_PRECISION" in os.environ
|
||||
|
||||
def update_policy(
|
||||
train_metrics: MetricsTracker,
|
||||
@@ -65,41 +70,55 @@ def update_policy(
|
||||
lr_scheduler=None,
|
||||
use_amp: bool = False,
|
||||
lock=None,
|
||||
accelerator=None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
start_time = time.perf_counter()
|
||||
device = get_device_from_parameters(policy)
|
||||
policy.train()
|
||||
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
grad_norm = 0.0 # Initialize grad_norm to avoid undefined variable
|
||||
|
||||
if accelerator:
|
||||
with accelerator.accumulate(policy):
|
||||
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
|
||||
loss, output_dict = policy.forward(batch)
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.parameters(),
|
||||
grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
else:
|
||||
# Standard training loop without accelerate
|
||||
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
|
||||
loss, output_dict = policy.forward(batch)
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
grad_scaler.scale(loss).backward()
|
||||
|
||||
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
|
||||
grad_scaler.unscale_(optimizer)
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.parameters(),
|
||||
grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
|
||||
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
|
||||
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
|
||||
with lock if lock is not None else nullcontext():
|
||||
|
||||
grad_scaler.scale(loss).backward()
|
||||
grad_scaler.unscale_(optimizer)
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
policy.parameters(),
|
||||
grad_clip_norm,
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
grad_scaler.step(optimizer)
|
||||
# Updates the scale for next iteration.
|
||||
grad_scaler.update()
|
||||
|
||||
optimizer.zero_grad()
|
||||
grad_scaler.update()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Step through pytorch scheduler at every batch instead of epoch
|
||||
if lr_scheduler is not None:
|
||||
lr_scheduler.step()
|
||||
|
||||
if has_method(policy, "update"):
|
||||
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
|
||||
policy.update()
|
||||
|
||||
if accelerator:
|
||||
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
|
||||
else:
|
||||
policy.update()
|
||||
|
||||
train_metrics.loss = loss.item()
|
||||
train_metrics.grad_norm = grad_norm.item()
|
||||
train_metrics.lr = optimizer.param_groups[0]["lr"]
|
||||
@@ -110,8 +129,34 @@ def update_policy(
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
cfg.validate()
|
||||
|
||||
accelerator = None # Initialize accelerator variable
|
||||
|
||||
if is_launched_with_accelerate():
|
||||
import accelerate
|
||||
|
||||
# For example pi0 has unused params (last llm block)
|
||||
from accelerate import DistributedDataParallelKwargs
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
|
||||
from accelerate import InitProcessGroupKwargs
|
||||
# Set NCCL timeout (default 30 minutes = 1800 seconds)
|
||||
nccl_timeout = getattr(cfg, 'nccl_timeout', 1800)
|
||||
ddp_init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=nccl_timeout)) # FIXME(mshukor): allow user to set timeout. This should be longer than the evaluation time
|
||||
# Set gradient accumulation steps (default 1)
|
||||
gradient_accumulation_steps = getattr(cfg, 'gradient_accumulation_steps', 1)
|
||||
accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False, gradient_accumulation_steps=gradient_accumulation_steps, kwargs_handlers=[ddp_init_kwargs, ddp_kwargs])
|
||||
if accelerator is not None and not accelerator.is_main_process:
|
||||
# Disable duplicate logging on non-main processes
|
||||
logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
|
||||
logging.info(pformat(cfg.to_dict()))
|
||||
|
||||
if accelerator and not accelerator.is_main_process:
|
||||
# Disable logging on non-main processes.
|
||||
cfg.wandb.enable = False
|
||||
|
||||
if cfg.wandb.enable and cfg.wandb.project:
|
||||
wandb_logger = WandBLogger(cfg)
|
||||
else:
|
||||
@@ -193,7 +238,11 @@ def train(cfg: TrainPipelineConfig):
|
||||
sampler=sampler,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=False,
|
||||
)
|
||||
) # Most important line
|
||||
if accelerator:
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
policy, optimizer, dataloader, lr_scheduler
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
@@ -229,6 +278,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
grad_scaler=grad_scaler,
|
||||
lr_scheduler=lr_scheduler,
|
||||
use_amp=cfg.policy.use_amp,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
@@ -251,7 +301,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
|
||||
# Unwrap policy from accelerate if needed
|
||||
unwrapped_policy = accelerator.unwrap_model(policy) if accelerator else policy
|
||||
save_checkpoint(checkpoint_dir, step, cfg, unwrapped_policy, optimizer, lr_scheduler)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
@@ -263,9 +315,11 @@ def train(cfg: TrainPipelineConfig):
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||
):
|
||||
# Unwrap policy from accelerate if needed for evaluation
|
||||
unwrapped_policy = accelerator.unwrap_model(policy) if accelerator else policy
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
unwrapped_policy,
|
||||
cfg.eval.n_episodes,
|
||||
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
|
||||
max_episodes_rendered=4,
|
||||
@@ -294,7 +348,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
logging.info("End of training")
|
||||
|
||||
if cfg.policy.push_to_hub:
|
||||
policy.push_model_to_hub(cfg)
|
||||
# Unwrap policy from accelerate if needed
|
||||
unwrapped_policy = accelerator.unwrap_model(policy) if accelerator else policy
|
||||
unwrapped_policy.push_model_to_hub(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user