Update lerobot Python modules and add test training script

- Enhanced dataset processing and statistics computation
- Updated policy factory and normalization
- Improved SmolVLA2 modeling and expert integration
- Enhanced training and evaluation scripts
- Added utility improvements for training and wandb integration
- Added test training script with 2 datasets for validation
This commit is contained in:
danaaubakirova
2025-09-16 16:11:26 +00:00
parent 7848b15bfb
commit 6c8f1f962b
14 changed files with 440 additions and 52 deletions

View File

@@ -24,6 +24,9 @@ import torch
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
import os
from datetime import timedelta
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
@@ -54,6 +57,8 @@ from lerobot.utils.utils import (
)
from lerobot.utils.wandb_utils import WandBLogger
def is_launched_with_accelerate() -> bool:
return "ACCELERATE_MIXED_PRECISION" in os.environ
def update_policy(
train_metrics: MetricsTracker,
@@ -65,41 +70,55 @@ def update_policy(
lr_scheduler=None,
use_amp: bool = False,
lock=None,
accelerator=None,
) -> tuple[MetricsTracker, dict]:
start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train()
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
loss, output_dict = policy.forward(batch)
grad_norm = 0.0 # Initialize grad_norm to avoid undefined variable
if accelerator:
with accelerator.accumulate(policy):
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
accelerator.backward(loss)
if accelerator.sync_gradients:
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(),
grad_clip_norm,
error_if_nonfinite=False,
)
optimizer.step()
optimizer.zero_grad()
else:
# Standard training loop without accelerate
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
grad_scaler.scale(loss).backward()
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
grad_scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(),
grad_clip_norm,
error_if_nonfinite=False,
)
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
with lock if lock is not None else nullcontext():
grad_scaler.scale(loss).backward()
grad_scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
policy.parameters(),
grad_clip_norm,
error_if_nonfinite=False,
)
grad_scaler.step(optimizer)
# Updates the scale for next iteration.
grad_scaler.update()
optimizer.zero_grad()
grad_scaler.update()
optimizer.zero_grad()
# Step through pytorch scheduler at every batch instead of epoch
if lr_scheduler is not None:
lr_scheduler.step()
if has_method(policy, "update"):
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
policy.update()
if accelerator:
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
else:
policy.update()
train_metrics.loss = loss.item()
train_metrics.grad_norm = grad_norm.item()
train_metrics.lr = optimizer.param_groups[0]["lr"]
@@ -110,8 +129,34 @@ def update_policy(
@parser.wrap()
def train(cfg: TrainPipelineConfig):
cfg.validate()
accelerator = None # Initialize accelerator variable
if is_launched_with_accelerate():
import accelerate
# For example pi0 has unused params (last llm block)
from accelerate import DistributedDataParallelKwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs])
from accelerate import InitProcessGroupKwargs
# Set NCCL timeout (default 30 minutes = 1800 seconds)
nccl_timeout = getattr(cfg, 'nccl_timeout', 1800)
ddp_init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=nccl_timeout)) # FIXME(mshukor): allow user to set timeout. This should be longer than the evaluation time
# Set gradient accumulation steps (default 1)
gradient_accumulation_steps = getattr(cfg, 'gradient_accumulation_steps', 1)
accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False, gradient_accumulation_steps=gradient_accumulation_steps, kwargs_handlers=[ddp_init_kwargs, ddp_kwargs])
if accelerator is not None and not accelerator.is_main_process:
# Disable duplicate logging on non-main processes
logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
logging.getLogger().setLevel(logging.WARNING)
logging.info(pformat(cfg.to_dict()))
if accelerator and not accelerator.is_main_process:
# Disable logging on non-main processes.
cfg.wandb.enable = False
if cfg.wandb.enable and cfg.wandb.project:
wandb_logger = WandBLogger(cfg)
else:
@@ -193,7 +238,11 @@ def train(cfg: TrainPipelineConfig):
sampler=sampler,
pin_memory=device.type != "cpu",
drop_last=False,
)
) # Most important line
if accelerator:
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
policy, optimizer, dataloader, lr_scheduler
)
dl_iter = cycle(dataloader)
policy.train()
@@ -229,6 +278,7 @@ def train(cfg: TrainPipelineConfig):
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
use_amp=cfg.policy.use_amp,
accelerator=accelerator,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -251,7 +301,9 @@ def train(cfg: TrainPipelineConfig):
if cfg.save_checkpoint and is_saving_step:
logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
# Unwrap policy from accelerate if needed
unwrapped_policy = accelerator.unwrap_model(policy) if accelerator else policy
save_checkpoint(checkpoint_dir, step, cfg, unwrapped_policy, optimizer, lr_scheduler)
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
@@ -263,9 +315,11 @@ def train(cfg: TrainPipelineConfig):
torch.no_grad(),
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
):
# Unwrap policy from accelerate if needed for evaluation
unwrapped_policy = accelerator.unwrap_model(policy) if accelerator else policy
eval_info = eval_policy(
eval_env,
policy,
unwrapped_policy,
cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_episodes_rendered=4,
@@ -294,7 +348,9 @@ def train(cfg: TrainPipelineConfig):
logging.info("End of training")
if cfg.policy.push_to_hub:
policy.push_model_to_hub(cfg)
# Unwrap policy from accelerate if needed
unwrapped_policy = accelerator.unwrap_model(policy) if accelerator else policy
unwrapped_policy.push_model_to_hub(cfg)
if __name__ == "__main__":