Fix LR scheduling: warm up linearly from 10% of peak LR, and count cosine-decay steps from the end of warmup

This commit is contained in:
Pepijn
2025-09-24 11:05:40 +02:00
parent 76d1430895
commit bab60cf02f

View File

@@ -92,16 +92,20 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
def lr_lambda(current_step):
def linear_warmup_schedule(current_step: int) -> float:
    """Return the LR multiplier for the linear warmup phase.

    The multiplier starts at 10% of the peak LR at step 0 and rises
    linearly to 100% at ``self.num_warmup_steps``.

    Args:
        current_step: Optimizer step index; values <= 0 are treated as step 0.

    Returns:
        A factor in [0.1, 1.0], presumably applied to the peak learning rate
        by the enclosing ``lr_lambda`` (the caller is only partly visible
        here -- confirm).
    """
    start_factor = 0.1  # warmup begins at 10% of peak LR

    if current_step <= 0:
        return start_factor
    if current_step >= self.num_warmup_steps:
        # Warmup is over: hold at 100%. This guard also means the division
        # below is only reached when 0 < current_step < num_warmup_steps,
        # so num_warmup_steps cannot be zero there.
        return 1.0

    # Linear interpolation from start_factor up to 1.0.
    progress = current_step / self.num_warmup_steps
    return start_factor + (1.0 - start_factor) * progress
def cosine_decay_schedule(current_step: int) -> float:
    """Return the LR multiplier for the cosine-decay phase that follows warmup.

    The multiplier follows half a cosine period from 1.0 down to
    ``self.decay_lr / self.peak_lr`` over ``self.num_decay_steps`` steps,
    counted from the end of warmup, and stays at that floor afterwards.

    Args:
        current_step: Global optimizer step index (warmup steps included).

    Returns:
        A factor between ``decay_lr / peak_lr`` and 1.0.
    """
    # Measure progress from the end of warmup rather than from step 0;
    # otherwise the cosine is already part-way down when decay begins and
    # the LR drops abruptly at the warmup/decay boundary.
    decay_step = current_step - self.num_warmup_steps
    # Clamp to [0, num_decay_steps]: no decay has happened before warmup
    # ends, and the multiplier holds at the floor once the window is over.
    decay_step = max(0, min(decay_step, self.num_decay_steps))

    alpha = self.decay_lr / self.peak_lr  # floor, as a fraction of peak LR
    if self.num_decay_steps <= 0:
        # No decay window configured: avoid dividing by zero and go
        # straight to the floor value.
        return alpha

    cosine_decay = 0.5 * (1 + math.cos(math.pi * decay_step / self.num_decay_steps))
    return (1 - alpha) * cosine_decay + alpha
if current_step < self.num_warmup_steps:
return linear_warmup_schedule(current_step)