2024-05-15 12:13:09 +02:00
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2024-02-29 23:13:06 +00:00
import logging
2024-04-29 12:27:58 +02:00
import time
2024-05-20 18:57:54 +01:00
from contextlib import nullcontext
2025-07-11 03:55:05 +00:00
from functools import partial
2024-05-28 12:04:23 +01:00
from pprint import pformat
2025-02-11 10:36:06 +01:00
from typing import Any
2024-02-10 15:46:24 +00:00
2024-01-29 12:49:30 +00:00
import torch
2025-02-11 10:36:06 +01:00
from termcolor import colored
2025-01-31 13:57:37 +01:00
from torch . amp import GradScaler
2025-02-11 10:36:06 +01:00
from torch . optim import Optimizer
2025-09-16 16:11:26 +00:00
import os
from datetime import timedelta
2024-01-29 12:49:30 +00:00
2025-07-01 16:34:46 +02:00
from lerobot . configs import parser
from lerobot . configs . train import TrainPipelineConfig
from lerobot . datasets . factory import make_dataset
from lerobot . datasets . sampler import EpisodeAwareSampler
from lerobot . datasets . utils import cycle
2025-07-11 03:55:05 +00:00
from lerobot . datasets . utils_must import multidataset_collate_fn
2025-07-01 16:34:46 +02:00
from lerobot . envs . factory import make_env
from lerobot . optim . factory import make_optimizer_and_scheduler
from lerobot . policies . factory import make_policy
from lerobot . policies . pretrained import PreTrainedPolicy
from lerobot . policies . utils import get_device_from_parameters
from lerobot . scripts . eval import eval_policy
from lerobot . utils . logging_utils import AverageMeter , MetricsTracker
from lerobot . utils . random_utils import set_seed
from lerobot . utils . train_utils import (
2025-02-11 10:36:06 +01:00
get_step_checkpoint_dir ,
get_step_identifier ,
load_training_state ,
save_checkpoint ,
update_last_checkpoint ,
)
2025-07-01 16:34:46 +02:00
from lerobot . utils . utils import (
2024-04-05 10:59:32 +00:00
format_big_number ,
get_safe_torch_device ,
2025-01-31 13:57:37 +01:00
has_method ,
2024-04-05 10:59:32 +00:00
init_logging ,
)
2025-07-01 16:34:46 +02:00
from lerobot . utils . wandb_utils import WandBLogger
2025-07-11 03:55:05 +00:00
2025-09-16 16:11:26 +00:00
def is_launched_with_accelerate() -> bool:
    """Return True when the ``ACCELERATE_MIXED_PRECISION`` environment variable is set.

    The check is a plain membership test on ``os.environ``; the variable's value is
    not inspected.
    NOTE(review): presumably this variable is exported by ``accelerate launch`` — confirm
    against the Accelerate version in use.
    """
    return "ACCELERATE_MIXED_PRECISION" in os.environ
2024-01-29 12:49:30 +00:00
2024-05-20 18:57:54 +01:00
def update_policy(
    train_metrics: MetricsTracker,
    policy: PreTrainedPolicy,
    batch: Any,
    optimizer: Optimizer,
    grad_clip_norm: float,
    grad_scaler: GradScaler,
    lr_scheduler=None,
    use_amp: bool = False,
    lock=None,
    accelerator=None,
) -> tuple[MetricsTracker, dict]:
    """Run one forward/backward/optimizer step and record its metrics.

    Args:
        train_metrics: Tracker whose ``loss``, ``grad_norm``, ``lr`` and ``update_s``
            attributes are overwritten with the values of this step.
        policy: Policy to optimise. It is switched to train mode and its ``forward``
            must return ``(loss, output_dict)``.
        batch: Input passed unchanged to ``policy.forward``.
        optimizer: Optimizer stepped (and zeroed) once per call.
        grad_clip_norm: Max norm handed to ``torch.nn.utils.clip_grad_norm_``.
        grad_scaler: Gradient scaler used on the non-accelerate path only.
        lr_scheduler: Optional scheduler, stepped once per call.
        use_amp: When true the forward pass runs under ``torch.autocast``.
        lock: Unused by this function; kept so existing callers keep working.
        accelerator: Optional ``accelerate.Accelerator``. When given, backward and
            gradient accumulation go through it instead of ``grad_scaler``.

    Returns:
        The same ``train_metrics`` object and the ``output_dict`` from the forward pass.
    """
    start_time = time.perf_counter()
    device = get_device_from_parameters(policy)
    policy.train()

    def _amp_context():
        # A fresh context manager is needed for every `with` statement.
        return torch.autocast(device_type=device.type) if use_amp else nullcontext()

    # Reported as-is on accumulation micro-steps, where no clipping happens.
    grad_norm = 0.0
    if accelerator is not None:
        with accelerator.accumulate(policy):
            with _amp_context():
                loss, output_dict = policy.forward(batch)
                # TODO(rcadene): policy.unnormalize_outputs(out_dict)
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                # NOTE(review): Accelerate documents `accelerator.clip_grad_norm_` as the
                # replacement for the torch function here — consider switching.
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    policy.parameters(),
                    grad_clip_norm,
                    error_if_nonfinite=False,
                )
            optimizer.step()
            optimizer.zero_grad()
    else:
        # Standard training loop without accelerate
        with _amp_context():
            loss, output_dict = policy.forward(batch)
            # TODO(rcadene): policy.unnormalize_outputs(out_dict)
        grad_scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to the real gradient values.
        grad_scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            policy.parameters(),
            grad_clip_norm,
            error_if_nonfinite=False,
        )
        grad_scaler.step(optimizer)
        grad_scaler.update()
        optimizer.zero_grad()

    # Step through pytorch scheduler at every batch instead of epoch
    if lr_scheduler is not None:
        lr_scheduler.step()

    # Look for the optional `update` hook on the unwrapped policy: a distributed wrapper
    # does not expose the wrapped module's custom methods, so checking the wrapper would
    # silently skip the hook.
    if accelerator is not None:
        update_target = accelerator.unwrap_model(policy, keep_fp32_wrapper=True)
    else:
        update_target = policy
    if has_method(update_target, "update"):
        update_target.update()

    train_metrics.loss = loss.item()
    # `grad_norm` is a tensor after clipping but stays a plain float on accumulation
    # micro-steps; `float()` handles both, whereas `.item()` fails on a float.
    train_metrics.grad_norm = float(grad_norm)
    train_metrics.lr = optimizer.param_groups[0]["lr"]
    train_metrics.update_s = time.perf_counter() - start_time
    return train_metrics, output_dict
2024-02-26 01:10:09 +00:00
2025-01-31 13:57:37 +01:00
@parser.wrap()
def train(cfg: TrainPipelineConfig):
    """Train a policy offline on a fixed dataset, with optional eval and checkpointing.

    The function validates ``cfg``, builds the dataset, optional evaluation environment,
    policy, optimizer and scheduler, then runs ``cfg.steps`` calls to ``update_policy``
    (resuming from the stored step when ``cfg.resume`` is set). Metrics are logged every
    ``cfg.log_freq`` steps, checkpoints are written every ``cfg.save_freq`` steps and on
    the last step, and the policy is evaluated every ``cfg.eval_freq`` steps when an
    environment is configured.

    When ``is_launched_with_accelerate()`` is true, an ``accelerate.Accelerator`` wraps the
    policy, optimizer, dataloader and scheduler; logging, WandB, checkpoint writing and
    the final hub push are then restricted to the main process.

    Args:
        cfg: Training pipeline configuration (filled in by ``parser.wrap``).
    """
    cfg.validate()

    accelerator = None
    if is_launched_with_accelerate():
        # Imported lazily so `accelerate` is only needed for accelerate-launched runs.
        import accelerate
        from accelerate import DistributedDataParallelKwargs, InitProcessGroupKwargs

        # For example pi0 has unused params (last llm block)
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        # Set NCCL timeout (default 30 minutes = 1800 seconds)
        # FIXME(mshukor): allow user to set timeout. This should be longer than the evaluation time
        nccl_timeout = getattr(cfg, "nccl_timeout", 1800)
        ddp_init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=nccl_timeout))
        # Set gradient accumulation steps (default 1)
        gradient_accumulation_steps = getattr(cfg, "gradient_accumulation_steps", 1)
        accelerator = accelerate.Accelerator(
            step_scheduler_with_optimizer=False,
            gradient_accumulation_steps=gradient_accumulation_steps,
            kwargs_handlers=[ddp_init_kwargs, ddp_kwargs],
        )

    # True for plain single-process runs and for rank 0 of an accelerate run.
    is_main_process = accelerator is None or accelerator.is_main_process

    if not is_main_process:
        # Disable duplicate logging on non-main processes
        logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
        logging.getLogger().setLevel(logging.WARNING)

    logging.info(pformat(cfg.to_dict()))

    if not is_main_process:
        # Disable WandB on non-main processes.
        cfg.wandb.enable = False

    if cfg.wandb.enable and cfg.wandb.project:
        wandb_logger = WandBLogger(cfg)
    else:
        wandb_logger = None
        logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))

    if cfg.seed is not None:
        set_seed(cfg.seed)

    # Check device is available
    device = get_safe_torch_device(cfg.policy.device, log=True)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    logging.info("Creating dataset")
    dataset = make_dataset(cfg)

    # Create environment used for evaluating checkpoints during training on simulation data.
    # On real-world data, no need to create an environment as evaluations are done outside train.py,
    # using the eval.py instead, with gym_dora environment and dora-rs.
    eval_env = None
    if cfg.eval_freq > 0 and cfg.env is not None:
        logging.info("Creating env")
        eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)

    logging.info("Creating policy")
    policy = make_policy(
        cfg=cfg.policy,
        ds_meta=dataset.meta,
    )

    logging.info("Creating optimizer and scheduler")
    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
    grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)

    step = 0  # number of policy updates (forward + backward + optim)
    if cfg.resume:
        step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)

    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())

    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
    if cfg.env is not None:
        logging.info(f"{cfg.env.task=}")
    logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
    logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
    logging.info(f"{dataset.num_episodes=}")
    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

    # create dataloader for offline training
    if hasattr(cfg.policy, "drop_n_last_frames"):
        # The sampler does the shuffling itself, so DataLoader-level shuffle must be off.
        shuffle = False
        sampler = EpisodeAwareSampler(
            dataset.episode_data_index,
            drop_n_last_frames=cfg.policy.drop_n_last_frames,
            shuffle=True,
        )
    else:
        shuffle = True
        sampler = None

    # NOTE(review): these per-key shapes are hard-coded. The previous code also read
    # `dataset.meta.keys_to_max_dim` but overwrote the result immediately, so that dead
    # lookup was dropped. Presumably they are the maximum shapes the collate function
    # pads to — confirm, and consider taking them from the dataset metadata instead.
    keys_to_max_dim = {
        "action": (32,),
        "observation.state": (32,),
        "observation.image": (3, 1080, 1920),
        "observation.image2": (3, 1080, 1920),
    }
    collate_fn = partial(multidataset_collate_fn, keys_to_max_dim=keys_to_max_dim)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=collate_fn,
        num_workers=cfg.num_workers,
        batch_size=cfg.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        pin_memory=device.type != "cpu",
        drop_last=False,
    )

    if accelerator is not None:
        # Hand the training objects to accelerate; from here on `policy` may be a wrapper.
        policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
            policy, optimizer, dataloader, lr_scheduler
        )
    dl_iter = cycle(dataloader)

    policy.train()

    train_metrics = {
        "loss": AverageMeter("loss", ":.3f"),
        "grad_norm": AverageMeter("grdn", ":.3f"),
        "lr": AverageMeter("lr", ":0.1e"),
        "update_s": AverageMeter("updt_s", ":.3f"),
        "dataloading_s": AverageMeter("data_s", ":.3f"),
    }

    train_tracker = MetricsTracker(
        cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
    )

    logging.info("Start offline training on a fixed dataset")
    for _ in range(step, cfg.steps):
        start_time = time.perf_counter()
        batch = next(dl_iter)
        train_tracker.dataloading_s = time.perf_counter() - start_time

        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(device, non_blocking=True)

        train_tracker, output_dict = update_policy(
            train_tracker,
            policy,
            batch,
            optimizer,
            cfg.optimizer.grad_clip_norm,
            grad_scaler=grad_scaler,
            lr_scheduler=lr_scheduler,
            use_amp=cfg.policy.use_amp,
            accelerator=accelerator,
        )

        # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
        # increment `step` here.
        step += 1
        train_tracker.step()
        is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
        is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
        is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0

        if is_log_step:
            logging.info(train_tracker)
            if wandb_logger:
                wandb_log_dict = train_tracker.to_dict()
                if output_dict:
                    wandb_log_dict.update(output_dict)
                wandb_logger.log_dict(wandb_log_dict, step)
            train_tracker.reset_averages()

        if cfg.save_checkpoint and is_saving_step:
            # Only one process writes the checkpoint: every rank targets the same
            # directory, so concurrent writers would race on the same files.
            # NOTE(review): assumes each rank holds the full model and optimizer state
            # (DDP-style replication) — confirm before using a sharded strategy.
            if is_main_process:
                logging.info(f"Checkpoint policy after step {step}")
                checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
                # Unwrap policy from accelerate if needed
                unwrapped_policy = accelerator.unwrap_model(policy) if accelerator is not None else policy
                save_checkpoint(checkpoint_dir, step, cfg, unwrapped_policy, optimizer, lr_scheduler)
                update_last_checkpoint(checkpoint_dir)
                if wandb_logger:
                    wandb_logger.log_policy(checkpoint_dir)
            if accelerator is not None:
                # Keep the other ranks from running ahead (or exiting) mid-save.
                accelerator.wait_for_everyone()

        if cfg.env and is_eval_step:
            # NOTE(review): under accelerate this still runs on every rank and all ranks
            # write to the same videos directory — consider restricting it to the main
            # process as well.
            step_id = get_step_identifier(step, cfg.steps)
            logging.info(f"Eval policy at step {step}")
            with (
                torch.no_grad(),
                torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
            ):
                # Unwrap policy from accelerate if needed for evaluation
                unwrapped_policy = accelerator.unwrap_model(policy) if accelerator is not None else policy
                eval_info = eval_policy(
                    eval_env,
                    unwrapped_policy,
                    cfg.eval.n_episodes,
                    videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
                    max_episodes_rendered=4,
                    start_seed=cfg.seed,
                )

            eval_metrics = {
                "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
                "pc_success": AverageMeter("success", ":.1f"),
                "eval_s": AverageMeter("eval_s", ":.3f"),
            }
            eval_tracker = MetricsTracker(
                cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
            )
            # These three values are popped, so they no longer appear in `eval_info` below.
            eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
            eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
            eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
            logging.info(eval_tracker)
            if wandb_logger:
                wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
                wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
                wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")

    if eval_env:
        eval_env.close()
    logging.info("End of training")

    # Push once, from the main process, so that ranks do not upload duplicates.
    if cfg.policy.push_to_hub and is_main_process:
        # Unwrap policy from accelerate if needed
        unwrapped_policy = accelerator.unwrap_model(policy) if accelerator is not None else policy
        unwrapped_policy.push_model_to_hub(cfg)
2025-06-26 14:36:16 +02:00
2024-01-29 12:49:30 +00:00
# Script entry point: set up logging, then start training (the config is parsed from
# the command line by the `parser.wrap` decorator on `train`).
if __name__ == "__main__":
    init_logging()
    train()