mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
huberman loss
This commit is contained in:
@@ -522,8 +522,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
if mismatch_tensor.any():
|
||||
eps = self.config.logit_eps
|
||||
zeros_target_logits = torch.logit(torch.full_like(mismatch_raw_logits, eps))
|
||||
mismatch_loss_per_sample = F.mse_loss(
|
||||
mismatch_raw_logits, zeros_target_logits, reduction='none'
|
||||
mismatch_loss_per_sample = F.smooth_l1_loss(
|
||||
mismatch_raw_logits, zeros_target_logits, beta=0.5, reduction='none'
|
||||
).mean(dim=1)
|
||||
L_mismatch = mismatch_loss_per_sample[mismatch_tensor].mean()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user