Huber (smooth L1) loss

This commit is contained in:
Pepijn
2025-09-01 11:59:20 +02:00
parent 4c4462edea
commit f84c20d403

View File

@@ -522,8 +522,8 @@ class RLearNPolicy(PreTrainedPolicy):
if mismatch_tensor.any():
eps = self.config.logit_eps
zeros_target_logits = torch.logit(torch.full_like(mismatch_raw_logits, eps))
mismatch_loss_per_sample = F.mse_loss(
mismatch_raw_logits, zeros_target_logits, reduction='none'
mismatch_loss_per_sample = F.smooth_l1_loss(
mismatch_raw_logits, zeros_target_logits, beta=0.5, reduction='none'
).mean(dim=1)
L_mismatch = mismatch_loss_per_sample[mismatch_tensor].mean()