mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
small improvements in train
This commit is contained in:
@@ -95,7 +95,8 @@ def update_policy(
|
||||
start_time = time.perf_counter()
|
||||
device = get_device_from_parameters(policy)
|
||||
policy.train()
|
||||
with torch.autocast(device_type=device.type) if use_amp and accelerator is None else nullcontext():
|
||||
# Let accelerator handle mixed precision
|
||||
with accelerator.autocast() if accelerator else (torch.autocast(device_type=device.type) if use_amp else nullcontext()):
|
||||
loss, output_dict = policy.forward(batch)
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
|
||||
@@ -217,6 +218,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
)
|
||||
policy.to(device)
|
||||
|
||||
# Wait for all processes to finish policy creation before continuing
|
||||
if accelerator:
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Create processors - only provide dataset_stats if not resuming from saved processors
|
||||
processor_kwargs = {}
|
||||
postprocessor_kwargs = {}
|
||||
@@ -299,6 +304,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
prefetch_factor=2 if cfg.num_workers > 0 else None,
|
||||
)
|
||||
if accelerator:
|
||||
accelerator.wait_for_everyone()
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
policy, optimizer, dataloader, lr_scheduler
|
||||
)
|
||||
@@ -380,8 +386,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
|
||||
if accelerator:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator:
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if cfg.env and is_eval_step:
|
||||
step_id = get_step_identifier(step, cfg.steps)
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
@@ -431,6 +438,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
|
||||
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
|
||||
wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
|
||||
|
||||
if accelerator:
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if eval_env:
|
||||
close_envs(eval_env)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user