Add LayerNorm before the reward head

This commit is contained in:
Pepijn
2025-08-31 01:13:22 +02:00
parent a1b1643ff6
commit 0ffc5b4741

View File

@@ -179,13 +179,16 @@ class RLearNPolicy(PreTrainedPolicy):
depth=config.mlp_predictor_depth
)
# Layer normalization before reward head to stabilize MLP outputs
self.pre_reward_norm = nn.LayerNorm(config.dim_model)
# MSE regression head with sigmoid activation to bound outputs to the open interval (0, 1)
self.reward_head = nn.Linear(config.dim_model, 1)
# Initialize with small weights to prevent sigmoid saturation
# Target: sigmoid(0) = 0.5, so we want raw logits around [-2, 2] range
with torch.no_grad():
self.reward_head.weight.normal_(0.0, 0.01) # Much smaller std
self.reward_head.weight.normal_(0.0, 0.02) # Small but not tiny
self.reward_head.bias.fill_(0.0) # Start at sigmoid(0) = 0.5
self.sigmoid = nn.Sigmoid()
@@ -292,7 +295,8 @@ class RLearNPolicy(PreTrainedPolicy):
video_frame_embeds = self.mlp_predictor(attended_video_tokens)
# Get rewards via linear head with sigmoid activation
return self.sigmoid(self.reward_head(video_frame_embeds)).squeeze(-1) # (B, T)
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
return self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1) # (B, T)
def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# Initial version: no-op; rely on upstream processors if any
@@ -518,7 +522,8 @@ class RLearNPolicy(PreTrainedPolicy):
# During inference, we might not want to compute loss
if not self.training and target is None:
# Return predictions without loss
rewards = self.sigmoid(self.reward_head(video_frame_embeds)).squeeze(-1)
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
rewards = self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1)
return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}
# Calculate loss using MSE
@@ -526,7 +531,8 @@ class RLearNPolicy(PreTrainedPolicy):
assert target.dtype == torch.float, "Continuous rewards require float targets"
# Get reward predictions with sigmoid activation
predicted_rewards = self.sigmoid(self.reward_head(video_frame_embeds)).squeeze(-1) # (B, T_eff)
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
predicted_rewards = self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1) # (B, T_eff)
# MSE loss over the first T_eff timesteps (NOTE(review): no explicit mask is applied in this call; confirm how variable-length sequences are handled)
loss = F.mse_loss(predicted_rewards, target[:, :T_eff], reduction='mean')
@@ -551,7 +557,8 @@ class RLearNPolicy(PreTrainedPolicy):
mismatch_embeds = self.mlp_predictor(attended_video_mm)
# Mismatched pairs should predict zero progress
mismatch_predictions = self.sigmoid(self.reward_head(mismatch_embeds)).squeeze(-1)
normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds)
mismatch_predictions = self.sigmoid(self.reward_head(normalized_mismatch_embeds)).squeeze(-1)
zeros_target = torch.zeros_like(target[:, :T_eff])
L_mismatch = F.mse_loss(mismatch_predictions, zeros_target, reduction='mean')
@@ -562,9 +569,10 @@ class RLearNPolicy(PreTrainedPolicy):
# DEBUG: Print targets and predictions occasionally during training
if self.training and torch.rand(1).item() < 0.02: # ~2% chance to debug print
with torch.no_grad():
# Get raw MLP outputs before reward head and sigmoid predictions
# Get raw MLP outputs, normalized outputs, and predictions
raw_outputs = video_frame_embeds
raw_logits = self.reward_head(video_frame_embeds).squeeze(-1)
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
raw_logits = self.reward_head(normalized_embeds).squeeze(-1)
preds = self.sigmoid(raw_logits)
print(f"\n=== DEBUG TRAINING ===")
@@ -575,13 +583,14 @@ class RLearNPolicy(PreTrainedPolicy):
print(f"Target range: [{target.min():.3f}, {target.max():.3f}]")
# Model output statistics
print(f"Raw MLP range: [{raw_outputs.min():.3f}, {raw_outputs.max():.3f}]")
print(f"Normalized MLP range: [{normalized_embeds.min():.6f}, {normalized_embeds.max():.6f}]")
print(f"Raw logits range: [{raw_logits.min():.6f}, {raw_logits.max():.6f}]")
print(f"Raw logits mean: {raw_logits.mean():.6f}")
print(f"Sigmoid pred range: [{preds.min():.3f}, {preds.max():.3f}]")
print(f"Sigmoid pred mean: {preds.mean():.3f}")
print(f"Loss: {loss:.4f}")
print("First sample targets:", target[0, :5].cpu().numpy())
print("First sample preds:", preds[0, :5].cpu().numpy())
print("First sample targets (all 16):", target[0].cpu().numpy())
print("First sample preds (all 16):", preds[0].cpu().numpy())
print("="*25)
total_forward_time = time.perf_counter() - forward_start