diff --git a/src/lerobot/policies/rlearn/configuration_rlearn.py b/src/lerobot/policies/rlearn/configuration_rlearn.py index d1d5291b7..a3ba8b5cc 100644 --- a/src/lerobot/policies/rlearn/configuration_rlearn.py +++ b/src/lerobot/policies/rlearn/configuration_rlearn.py @@ -71,7 +71,7 @@ class RLearNConfig(PreTrainedConfig): compile_model: bool = True # ReWiND augmentation - rewind_prob: float = 0.5 + rewind_prob: float = 0.8 rewind_last3_prob: float = 0.3 mismatch_prob: float = 0.2 diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 2f86b88ed..a986ae818 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -445,6 +445,18 @@ class RLearNPolicy(PreTrainedPolicy): raw_logits = self.reward_head(frame_tokens).squeeze(-1) # (B, T_eff) predicted_rewards = torch.sigmoid(raw_logits) + # Regularizers to avoid flat outputs and encourage local forward progress + # Encourage non-flat predictions per sample + var_min = 1e-3 + pred = predicted_rewards + L_flat = F.relu(var_min - pred.var(dim=1, unbiased=False)).mean() if pred.shape[1] > 1 else torch.zeros((), device=device) + # Enforce local forward progress on logits without overconstraining + rank_margin = 0.02 + if raw_logits.shape[1] > 1: + L_rank = F.relu(rank_margin - (raw_logits[:, 1:] - raw_logits[:, :-1])).mean() + else: + L_rank = torch.zeros((), device=device) + # Generate progress labels on-the-fly (ReWiND approach) # IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window loss_dict: dict[str, float] = {} @@ -476,7 +488,7 @@ class RLearNPolicy(PreTrainedPolicy): eps = self.config.logit_eps target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps)) # Robust Huber (Smooth L1) on logits - loss = F.smooth_l1_loss(raw_logits, target_logits, beta=0.5) + loss = F.smooth_l1_loss(raw_logits, target_logits, beta=0.25) total_loss = loss @@ -523,12 +535,12 @@ class RLearNPolicy(PreTrainedPolicy): eps = self.config.logit_eps zeros_target_logits = torch.logit(torch.full_like(mismatch_raw_logits, eps)) mismatch_loss_per_sample = F.smooth_l1_loss( - mismatch_raw_logits, zeros_target_logits, beta=0.5, reduction='none' + mismatch_raw_logits, zeros_target_logits, beta=0.25, reduction='none' ).mean(dim=1) L_mismatch = mismatch_loss_per_sample[mismatch_tensor].mean() # Total loss - total_loss = total_loss + L_mismatch + total_loss = total_loss + L_mismatch + 0.3 * L_rank + 0.05 * L_flat loss_time = time.perf_counter() - loss_start # DEBUG: Clean logit regression monitoring with full array printing @@ -583,6 +595,8 @@ class RLearNPolicy(PreTrainedPolicy): "loss": float(total_loss.detach().item()), "loss_main": float(loss.detach().item() if isinstance(loss, torch.Tensor) else 0.0), "loss_mismatch": float(L_mismatch.detach().item()), + "loss_rank": float(L_rank.detach().item()), + "loss_flat": float(L_flat.detach().item()), "t_eff": float(T_eff), "lang_len_mean": float(mask.sum().float().mean().item()), # Use mask to get actual lengths # Target statistics for monitoring