use siglip 2

This commit is contained in:
Pepijn
2025-08-30 14:28:55 +02:00
parent 76e260c401
commit bde397e891
7 changed files with 224 additions and 1397 deletions

View File

@@ -67,10 +67,18 @@ def update_policy(
start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train()
# Forward pass timing
forward_start = time.perf_counter()
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
forward_time = time.perf_counter() - forward_start
# Backward pass timing
backward_start = time.perf_counter()
grad_scaler.scale(loss).backward()
backward_time = time.perf_counter() - backward_start
# Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
grad_scaler.unscale_(optimizer)
@@ -81,6 +89,9 @@ def update_policy(
error_if_nonfinite=False,
)
# Optimizer step timing
optim_start = time.perf_counter()
# Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
# although it still skips optimizer.step() if the gradients contain infs or NaNs.
with lock if lock is not None else nullcontext():
@@ -97,6 +108,19 @@ def update_policy(
if has_method(policy, "update"):
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
policy.update()
optim_time = time.perf_counter() - optim_start
total_time = time.perf_counter() - start_time
# Print detailed timing for RLearN policy
if getattr(policy, "name", None) == "rlearn":
print(f"Training Step Timing:")
print(f" Forward pass: {forward_time*1000:.2f} ms")
print(f" Backward pass: {backward_time*1000:.2f} ms")
print(f" Optimizer step: {optim_time*1000:.2f} ms")
print(f" Total update: {total_time*1000:.2f} ms")
print(f" Steps/sec: {1.0/total_time:.2f}")
print("-" * 40)
train_metrics.loss = loss.item()
train_metrics.grad_norm = grad_norm.item()
@@ -213,10 +237,17 @@ def train(cfg: TrainPipelineConfig):
logging.info("Start offline training on a fixed dataset")
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
# Data loading timing
data_start = time.perf_counter()
batch = next(dl_iter)
data_loading_time = time.perf_counter() - data_start
# Preprocessing timing
preprocess_start = time.perf_counter()
batch = preprocessor(batch)
train_tracker.dataloading_s = time.perf_counter() - start_time
preprocess_time = time.perf_counter() - preprocess_start
train_tracker.dataloading_s = data_loading_time + preprocess_time
for key in batch:
if isinstance(batch[key], torch.Tensor):
@@ -256,13 +287,22 @@ def train(cfg: TrainPipelineConfig):
total_pixels += sum(_count_pixels(t) for t in v)
# Avoid div-by-zero
upd_s = max(train_tracker.update_s, 1e-8)
meter = train_tracker.update_s
upd_s = meter.val if isinstance(meter, AverageMeter) else float(meter)
upd_s = max(upd_s, 1e-8)
pix_per_s = float(total_pixels) / upd_s
try:
train_tracker.pix_s = pix_per_s
except Exception:
pass
# Print data loading timing for RLearN
if getattr(policy, "name", None) == "rlearn":
print(f"Data Pipeline Timing:")
print(f" Data loading: {data_loading_time*1000:.2f} ms")
print(f" Preprocessing: {preprocess_time*1000:.2f} ms")
print(f" Total data pipeline: {(data_loading_time + preprocess_time)*1000:.2f} ms")
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here.
step += 1