small improvements in train

This commit is contained in:
Pepijn
2025-10-14 13:53:38 +02:00
parent cabc47c5ad
commit a0d0b00e04

View File

@@ -95,7 +95,8 @@ def update_policy(
start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train()
with torch.autocast(device_type=device.type) if use_amp and accelerator is None else nullcontext():
# Let accelerator handle mixed precision
with accelerator.autocast() if accelerator else (torch.autocast(device_type=device.type) if use_amp else nullcontext()):
loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
@@ -217,6 +218,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
)
policy.to(device)
# Wait for all processes to finish policy creation before continuing
if accelerator:
accelerator.wait_for_everyone()
# Create processors - only provide dataset_stats if not resuming from saved processors
processor_kwargs = {}
postprocessor_kwargs = {}
@@ -299,6 +304,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
prefetch_factor=2 if cfg.num_workers > 0 else None,
)
if accelerator:
accelerator.wait_for_everyone()
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
policy, optimizer, dataloader, lr_scheduler
)
@@ -380,8 +386,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
if accelerator:
accelerator.wait_for_everyone()
if accelerator:
accelerator.wait_for_everyone()
if cfg.env and is_eval_step:
step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}")
@@ -431,6 +438,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Callable | None = None):
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
if accelerator:
accelerator.wait_for_everyone()
if eval_env:
close_envs(eval_env)