diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 452662bde..e343f16c3 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -179,13 +179,16 @@ class RLearNPolicy(PreTrainedPolicy): depth=config.mlp_predictor_depth ) + # Layer normalization before reward head to stabilize MLP outputs + self.pre_reward_norm = nn.LayerNorm(config.dim_model) + # MSE regression head with sigmoid activation to bound outputs to [0,1] self.reward_head = nn.Linear(config.dim_model, 1) # Initialize with small weights to prevent sigmoid saturation # Target: sigmoid(0) = 0.5, so we want raw logits around [-2, 2] range with torch.no_grad(): - self.reward_head.weight.normal_(0.0, 0.01) # Much smaller std + self.reward_head.weight.normal_(0.0, 0.02) # Small but not tiny self.reward_head.bias.fill_(0.0) # Start at sigmoid(0) = 0.5 self.sigmoid = nn.Sigmoid() @@ -292,7 +295,8 @@ class RLearNPolicy(PreTrainedPolicy): video_frame_embeds = self.mlp_predictor(attended_video_tokens) # Get rewards via linear head with sigmoid activation - return self.sigmoid(self.reward_head(video_frame_embeds)).squeeze(-1) # (B, T) + normalized_embeds = self.pre_reward_norm(video_frame_embeds) + return self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1) # (B, T) def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: # Initial version: no-op; rely on upstream processors if any @@ -518,7 +522,8 @@ class RLearNPolicy(PreTrainedPolicy): # During inference, we might not want to compute loss if not self.training and target is None: # Return predictions without loss - rewards = self.sigmoid(self.reward_head(video_frame_embeds)).squeeze(-1) + normalized_embeds = self.pre_reward_norm(video_frame_embeds) + rewards = self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1) return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()} # Calculate loss using MSE @@ -526,7 +531,8 @@ class RLearNPolicy(PreTrainedPolicy): assert target.dtype == torch.float, "Continuous rewards require float targets" # Get reward predictions with sigmoid activation - predicted_rewards = self.sigmoid(self.reward_head(video_frame_embeds)).squeeze(-1) # (B, T_eff) + normalized_embeds = self.pre_reward_norm(video_frame_embeds) + predicted_rewards = self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1) # (B, T_eff) # MSE loss with masking for variable length sequences loss = F.mse_loss(predicted_rewards, target[:, :T_eff], reduction='mean') @@ -551,7 +557,8 @@ class RLearNPolicy(PreTrainedPolicy): mismatch_embeds = self.mlp_predictor(attended_video_mm) # Mismatched pairs should predict zero progress - mismatch_predictions = self.sigmoid(self.reward_head(mismatch_embeds)).squeeze(-1) + normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds) + mismatch_predictions = self.sigmoid(self.reward_head(normalized_mismatch_embeds)).squeeze(-1) zeros_target = torch.zeros_like(target[:, :T_eff]) L_mismatch = F.mse_loss(mismatch_predictions, zeros_target, reduction='mean') @@ -562,9 +569,10 @@ class RLearNPolicy(PreTrainedPolicy): # DEBUG: Print targets and predictions occasionally during training if self.training and torch.rand(1).item() < 0.02: # ~2% chance to debug print with torch.no_grad(): - # Get raw MLP outputs before reward head and sigmoid predictions + # Get raw MLP outputs, normalized outputs, and predictions raw_outputs = video_frame_embeds - raw_logits = self.reward_head(video_frame_embeds).squeeze(-1) + normalized_embeds = self.pre_reward_norm(video_frame_embeds) + raw_logits = self.reward_head(normalized_embeds).squeeze(-1) preds = self.sigmoid(raw_logits) print(f"\n=== DEBUG TRAINING ===") @@ -575,13 +583,14 @@ class RLearNPolicy(PreTrainedPolicy): print(f"Target range: [{target.min():.3f}, {target.max():.3f}]") # Model output statistics print(f"Raw MLP range: [{raw_outputs.min():.3f}, {raw_outputs.max():.3f}]") + print(f"Normalized MLP range: [{normalized_embeds.min():.6f}, {normalized_embeds.max():.6f}]") print(f"Raw logits range: [{raw_logits.min():.6f}, {raw_logits.max():.6f}]") print(f"Raw logits mean: {raw_logits.mean():.6f}") print(f"Sigmoid pred range: [{preds.min():.3f}, {preds.max():.3f}]") print(f"Sigmoid pred mean: {preds.mean():.3f}") print(f"Loss: {loss:.4f}") - print("First sample targets:", target[0, :5].cpu().numpy()) - print("First sample preds:", preds[0, :5].cpu().numpy()) + print("First sample targets (all 16):", target[0].cpu().numpy()) + print("First sample preds (all 16):", preds[0].cpu().numpy()) print("="*25) total_forward_time = time.perf_counter() - forward_start