mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
regularizer
This commit is contained in:
@@ -71,7 +71,7 @@ class RLearNConfig(PreTrainedConfig):
|
||||
compile_model: bool = True
|
||||
|
||||
# ReWiND augmentation
|
||||
rewind_prob: float = 0.5
|
||||
rewind_prob: float = 0.8
|
||||
rewind_last3_prob: float = 0.3
|
||||
mismatch_prob: float = 0.2
|
||||
|
||||
|
||||
@@ -445,6 +445,18 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
raw_logits = self.reward_head(frame_tokens).squeeze(-1) # (B, T_eff)
|
||||
predicted_rewards = torch.sigmoid(raw_logits)
|
||||
|
||||
# Regularizers to avoid flat outputs and encourage local forward progress
|
||||
# Encourage non-flat predictions per sample
|
||||
var_min = 1e-3
|
||||
pred = predicted_rewards
|
||||
L_flat = F.relu(var_min - pred.var(dim=1, unbiased=False)).mean() if pred.shape[1] > 1 else torch.zeros((), device=device)
|
||||
# Enforce local forward progress on logits without overconstraining
|
||||
rank_margin = 0.02
|
||||
if raw_logits.shape[1] > 1:
|
||||
L_rank = F.relu(rank_margin - (raw_logits[:, 1:] - raw_logits[:, :-1])).mean()
|
||||
else:
|
||||
L_rank = torch.zeros((), device=device)
|
||||
|
||||
# Generate progress labels on-the-fly (ReWiND approach)
|
||||
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
|
||||
loss_dict: dict[str, float] = {}
|
||||
@@ -476,7 +488,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
eps = self.config.logit_eps
|
||||
target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps))
|
||||
# Robust Huber (Smooth L1) on logits
|
||||
loss = F.smooth_l1_loss(raw_logits, target_logits, beta=0.5)
|
||||
loss = F.smooth_l1_loss(raw_logits, target_logits, beta=0.25)
|
||||
total_loss = loss
|
||||
|
||||
|
||||
@@ -523,12 +535,12 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
eps = self.config.logit_eps
|
||||
zeros_target_logits = torch.logit(torch.full_like(mismatch_raw_logits, eps))
|
||||
mismatch_loss_per_sample = F.smooth_l1_loss(
|
||||
mismatch_raw_logits, zeros_target_logits, beta=0.5, reduction='none'
|
||||
mismatch_raw_logits, zeros_target_logits, beta=0.25, reduction='none'
|
||||
).mean(dim=1)
|
||||
L_mismatch = mismatch_loss_per_sample[mismatch_tensor].mean()
|
||||
|
||||
# Total loss
|
||||
total_loss = total_loss + L_mismatch
|
||||
total_loss = total_loss + L_mismatch + 0.3 * L_rank + 0.05 * L_flat
|
||||
loss_time = time.perf_counter() - loss_start
|
||||
|
||||
# DEBUG: Clean logit regression monitoring with full array printing
|
||||
@@ -583,6 +595,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
"loss": float(total_loss.detach().item()),
|
||||
"loss_main": float(loss.detach().item() if isinstance(loss, torch.Tensor) else 0.0),
|
||||
"loss_mismatch": float(L_mismatch.detach().item()),
|
||||
"loss_rank": float(L_rank.detach().item()),
|
||||
"loss_flat": float(L_flat.detach().item()),
|
||||
"t_eff": float(T_eff),
|
||||
"lang_len_mean": float(mask.sum().float().mean().item()), # Use mask to get actual lengths
|
||||
# Target statistics for monitoring
|
||||
|
||||
Reference in New Issue
Block a user