This commit is contained in:
Pepijn
2025-09-01 13:34:55 +02:00
parent 45348d7b69
commit a305f5f46a

View File

@@ -422,6 +422,7 @@ class RLearNPolicy(PreTrainedPolicy):
# Compute main loss (or just return predictions in eval)
loss_start = time.perf_counter()
loss = torch.tensor(0.0, device=device)
if target is None:
total_loss = torch.tensor(0.0, device=device)
loss = total_loss
@@ -435,6 +436,7 @@ class RLearNPolicy(PreTrainedPolicy):
bin_idx = torch.where(video_mask, bin_idx, torch.full_like(bin_idx, -1))
loss_ce = F.cross_entropy(video_frame_logits.permute(0, 2, 1), bin_idx, ignore_index=-1)
total_loss = loss_ce
loss = loss_ce
predicted_rewards = torch.softmax(video_frame_logits, dim=-1)
else:
# HL-Gauss or regression