From a305f5f46a2c844e7a75e0bb2be9ee3006d2d629 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Mon, 1 Sep 2025 13:34:55 +0200 Subject: [PATCH] hl-gauss --- src/lerobot/policies/rlearn/modeling_rlearn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index ebe62574c..e12d380fa 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -422,6 +422,7 @@ class RLearNPolicy(PreTrainedPolicy): # Compute main loss (or just return predictions in eval) loss_start = time.perf_counter() + loss = torch.tensor(0.0, device=device) if target is None: total_loss = torch.tensor(0.0, device=device) loss = total_loss @@ -435,6 +436,7 @@ class RLearNPolicy(PreTrainedPolicy): bin_idx = torch.where(video_mask, bin_idx, torch.full_like(bin_idx, -1)) loss_ce = F.cross_entropy(video_frame_logits.permute(0, 2, 1), bin_idx, ignore_index=-1) total_loss = loss_ce + loss = loss_ce predicted_rewards = torch.softmax(video_frame_logits, dim=-1) else: # HL-Gauss or regression