regularizer

This commit is contained in:
Pepijn
2025-09-01 12:07:37 +02:00
parent f84c20d403
commit c2bf226082
2 changed files with 18 additions and 4 deletions

View File

@@ -71,7 +71,7 @@ class RLearNConfig(PreTrainedConfig):
compile_model: bool = True
# ReWiND augmentation
rewind_prob: float = 0.5
rewind_prob: float = 0.8
rewind_last3_prob: float = 0.3
mismatch_prob: float = 0.2

View File

@@ -445,6 +445,18 @@ class RLearNPolicy(PreTrainedPolicy):
raw_logits = self.reward_head(frame_tokens).squeeze(-1) # (B, T_eff)
predicted_rewards = torch.sigmoid(raw_logits)
# Regularizers to avoid flat outputs and encourage local forward progress
# Encourage non-flat predictions per sample
var_min = 1e-3
pred = predicted_rewards
L_flat = F.relu(var_min - pred.var(dim=1, unbiased=False)).mean() if pred.shape[1] > 1 else torch.zeros((), device=device)
# Enforce local forward progress on logits without overconstraining
rank_margin = 0.02
if raw_logits.shape[1] > 1:
L_rank = F.relu(rank_margin - (raw_logits[:, 1:] - raw_logits[:, :-1])).mean()
else:
L_rank = torch.zeros((), device=device)
# Generate progress labels on-the-fly (ReWiND approach)
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
loss_dict: dict[str, float] = {}
@@ -476,7 +488,7 @@ class RLearNPolicy(PreTrainedPolicy):
eps = self.config.logit_eps
target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps))
# Robust Huber (Smooth L1) on logits
loss = F.smooth_l1_loss(raw_logits, target_logits, beta=0.5)
loss = F.smooth_l1_loss(raw_logits, target_logits, beta=0.25)
total_loss = loss
@@ -523,12 +535,12 @@ class RLearNPolicy(PreTrainedPolicy):
eps = self.config.logit_eps
zeros_target_logits = torch.logit(torch.full_like(mismatch_raw_logits, eps))
mismatch_loss_per_sample = F.smooth_l1_loss(
mismatch_raw_logits, zeros_target_logits, beta=0.5, reduction='none'
mismatch_raw_logits, zeros_target_logits, beta=0.25, reduction='none'
).mean(dim=1)
L_mismatch = mismatch_loss_per_sample[mismatch_tensor].mean()
# Total loss
total_loss = total_loss + L_mismatch
total_loss = total_loss + L_mismatch + 0.3 * L_rank + 0.05 * L_flat
loss_time = time.perf_counter() - loss_start
# DEBUG: Clean logit regression monitoring with full array printing
@@ -583,6 +595,8 @@ class RLearNPolicy(PreTrainedPolicy):
"loss": float(total_loss.detach().item()),
"loss_main": float(loss.detach().item() if isinstance(loss, torch.Tensor) else 0.0),
"loss_mismatch": float(L_mismatch.detach().item()),
"loss_rank": float(L_rank.detach().item()),
"loss_flat": float(L_flat.detach().item()),
"t_eff": float(T_eff),
"lang_len_mean": float(mask.sum().float().mean().item()), # Use mask to get actual lengths
# Target statistics for monitoring