mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
regularizer
This commit is contained in:
@@ -71,7 +71,7 @@ class RLearNConfig(PreTrainedConfig):
|
|||||||
compile_model: bool = True
|
compile_model: bool = True
|
||||||
|
|
||||||
# ReWiND augmentation
|
# ReWiND augmentation
|
||||||
rewind_prob: float = 0.5
|
rewind_prob: float = 0.8
|
||||||
rewind_last3_prob: float = 0.3
|
rewind_last3_prob: float = 0.3
|
||||||
mismatch_prob: float = 0.2
|
mismatch_prob: float = 0.2
|
||||||
|
|
||||||
|
|||||||
@@ -445,6 +445,18 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
raw_logits = self.reward_head(frame_tokens).squeeze(-1) # (B, T_eff)
|
raw_logits = self.reward_head(frame_tokens).squeeze(-1) # (B, T_eff)
|
||||||
predicted_rewards = torch.sigmoid(raw_logits)
|
predicted_rewards = torch.sigmoid(raw_logits)
|
||||||
|
|
||||||
|
# Regularizers to avoid flat outputs and encourage local forward progress
|
||||||
|
# Encourage non-flat predictions per sample
|
||||||
|
var_min = 1e-3
|
||||||
|
pred = predicted_rewards
|
||||||
|
L_flat = F.relu(var_min - pred.var(dim=1, unbiased=False)).mean() if pred.shape[1] > 1 else torch.zeros((), device=device)
|
||||||
|
# Enforce local forward progress on logits without overconstraining
|
||||||
|
rank_margin = 0.02
|
||||||
|
if raw_logits.shape[1] > 1:
|
||||||
|
L_rank = F.relu(rank_margin - (raw_logits[:, 1:] - raw_logits[:, :-1])).mean()
|
||||||
|
else:
|
||||||
|
L_rank = torch.zeros((), device=device)
|
||||||
|
|
||||||
# Generate progress labels on-the-fly (ReWiND approach)
|
# Generate progress labels on-the-fly (ReWiND approach)
|
||||||
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
|
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
|
||||||
loss_dict: dict[str, float] = {}
|
loss_dict: dict[str, float] = {}
|
||||||
@@ -476,7 +488,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
eps = self.config.logit_eps
|
eps = self.config.logit_eps
|
||||||
target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps))
|
target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps))
|
||||||
# Robust Huber (Smooth L1) on logits
|
# Robust Huber (Smooth L1) on logits
|
||||||
loss = F.smooth_l1_loss(raw_logits, target_logits, beta=0.5)
|
loss = F.smooth_l1_loss(raw_logits, target_logits, beta=0.25)
|
||||||
total_loss = loss
|
total_loss = loss
|
||||||
|
|
||||||
|
|
||||||
@@ -523,12 +535,12 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
eps = self.config.logit_eps
|
eps = self.config.logit_eps
|
||||||
zeros_target_logits = torch.logit(torch.full_like(mismatch_raw_logits, eps))
|
zeros_target_logits = torch.logit(torch.full_like(mismatch_raw_logits, eps))
|
||||||
mismatch_loss_per_sample = F.smooth_l1_loss(
|
mismatch_loss_per_sample = F.smooth_l1_loss(
|
||||||
mismatch_raw_logits, zeros_target_logits, beta=0.5, reduction='none'
|
mismatch_raw_logits, zeros_target_logits, beta=0.25, reduction='none'
|
||||||
).mean(dim=1)
|
).mean(dim=1)
|
||||||
L_mismatch = mismatch_loss_per_sample[mismatch_tensor].mean()
|
L_mismatch = mismatch_loss_per_sample[mismatch_tensor].mean()
|
||||||
|
|
||||||
# Total loss
|
# Total loss
|
||||||
total_loss = total_loss + L_mismatch
|
total_loss = total_loss + L_mismatch + 0.3 * L_rank + 0.05 * L_flat
|
||||||
loss_time = time.perf_counter() - loss_start
|
loss_time = time.perf_counter() - loss_start
|
||||||
|
|
||||||
# DEBUG: Clean logit regression monitoring with full array printing
|
# DEBUG: Clean logit regression monitoring with full array printing
|
||||||
@@ -583,6 +595,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
"loss": float(total_loss.detach().item()),
|
"loss": float(total_loss.detach().item()),
|
||||||
"loss_main": float(loss.detach().item() if isinstance(loss, torch.Tensor) else 0.0),
|
"loss_main": float(loss.detach().item() if isinstance(loss, torch.Tensor) else 0.0),
|
||||||
"loss_mismatch": float(L_mismatch.detach().item()),
|
"loss_mismatch": float(L_mismatch.detach().item()),
|
||||||
|
"loss_rank": float(L_rank.detach().item()),
|
||||||
|
"loss_flat": float(L_flat.detach().item()),
|
||||||
"t_eff": float(T_eff),
|
"t_eff": float(T_eff),
|
||||||
"lang_len_mean": float(mask.sum().float().mean().item()), # Use mask to get actual lengths
|
"lang_len_mean": float(mask.sum().float().mean().item()), # Use mask to get actual lengths
|
||||||
# Target statistics for monitoring
|
# Target statistics for monitoring
|
||||||
|
|||||||
Reference in New Issue
Block a user