mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 04:11:24 +00:00
use siglip 2
This commit is contained in:
@@ -67,10 +67,18 @@ def update_policy(
|
||||
start_time = time.perf_counter()
|
||||
device = get_device_from_parameters(policy)
|
||||
policy.train()
|
||||
|
||||
# Forward pass timing
|
||||
forward_start = time.perf_counter()
|
||||
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
|
||||
loss, output_dict = policy.forward(batch)
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
forward_time = time.perf_counter() - forward_start
|
||||
|
||||
# Backward pass timing
|
||||
backward_start = time.perf_counter()
|
||||
grad_scaler.scale(loss).backward()
|
||||
backward_time = time.perf_counter() - backward_start
|
||||
|
||||
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
|
||||
grad_scaler.unscale_(optimizer)
|
||||
@@ -81,6 +89,9 @@ def update_policy(
|
||||
error_if_nonfinite=False,
|
||||
)
|
||||
|
||||
# Optimizer step timing
|
||||
optim_start = time.perf_counter()
|
||||
|
||||
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
|
||||
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
|
||||
with lock if lock is not None else nullcontext():
|
||||
@@ -97,6 +108,19 @@ def update_policy(
|
||||
if has_method(policy, "update"):
|
||||
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
|
||||
policy.update()
|
||||
|
||||
optim_time = time.perf_counter() - optim_start
|
||||
total_time = time.perf_counter() - start_time
|
||||
|
||||
# Print detailed timing for RLearN policy
|
||||
if getattr(policy, "name", None) == "rlearn":
|
||||
print(f"Training Step Timing:")
|
||||
print(f" Forward pass: {forward_time*1000:.2f} ms")
|
||||
print(f" Backward pass: {backward_time*1000:.2f} ms")
|
||||
print(f" Optimizer step: {optim_time*1000:.2f} ms")
|
||||
print(f" Total update: {total_time*1000:.2f} ms")
|
||||
print(f" Steps/sec: {1.0/total_time:.2f}")
|
||||
print("-" * 40)
|
||||
|
||||
train_metrics.loss = loss.item()
|
||||
train_metrics.grad_norm = grad_norm.item()
|
||||
@@ -213,10 +237,17 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
# Data loading timing
|
||||
data_start = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
data_loading_time = time.perf_counter() - data_start
|
||||
|
||||
# Preprocessing timing
|
||||
preprocess_start = time.perf_counter()
|
||||
batch = preprocessor(batch)
|
||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||
preprocess_time = time.perf_counter() - preprocess_start
|
||||
|
||||
train_tracker.dataloading_s = data_loading_time + preprocess_time
|
||||
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
@@ -256,13 +287,22 @@ def train(cfg: TrainPipelineConfig):
|
||||
total_pixels += sum(_count_pixels(t) for t in v)
|
||||
|
||||
# Avoid div-by-zero
|
||||
upd_s = max(train_tracker.update_s, 1e-8)
|
||||
meter = train_tracker.update_s
|
||||
upd_s = meter.val if isinstance(meter, AverageMeter) else float(meter)
|
||||
upd_s = max(upd_s, 1e-8)
|
||||
pix_per_s = float(total_pixels) / upd_s
|
||||
try:
|
||||
train_tracker.pix_s = pix_per_s
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Print data loading timing for RLearN
|
||||
if getattr(policy, "name", None) == "rlearn":
|
||||
print(f"Data Pipeline Timing:")
|
||||
print(f" Data loading: {data_loading_time*1000:.2f} ms")
|
||||
print(f" Preprocessing: {preprocess_time*1000:.2f} ms")
|
||||
print(f" Total data pipeline: {(data_loading_time + preprocess_time)*1000:.2f} ms")
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
# increment `step` here.
|
||||
step += 1
|
||||
|
||||
Reference in New Issue
Block a user