mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
hl-gauss
This commit is contained in:
@@ -422,6 +422,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
# Compute main loss (or just return predictions in eval)
|
||||
loss_start = time.perf_counter()
|
||||
loss = torch.tensor(0.0, device=device)
|
||||
if target is None:
|
||||
total_loss = torch.tensor(0.0, device=device)
|
||||
loss = total_loss
|
||||
@@ -435,6 +436,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
bin_idx = torch.where(video_mask, bin_idx, torch.full_like(bin_idx, -1))
|
||||
loss_ce = F.cross_entropy(video_frame_logits.permute(0, 2, 1), bin_idx, ignore_index=-1)
|
||||
total_loss = loss_ce
|
||||
loss = loss_ce
|
||||
predicted_rewards = torch.softmax(video_frame_logits, dim=-1)
|
||||
else:
|
||||
# HL-Gauss or regression
|
||||
|
||||
Reference in New Issue
Block a user