mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
add layernorm in head
This commit is contained in:
@@ -179,13 +179,16 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
depth=config.mlp_predictor_depth
|
||||
)
|
||||
|
||||
# Layer normalization before reward head to stabilize MLP outputs
|
||||
self.pre_reward_norm = nn.LayerNorm(config.dim_model)
|
||||
|
||||
# MSE regression head with sigmoid activation to bound outputs to [0,1]
|
||||
self.reward_head = nn.Linear(config.dim_model, 1)
|
||||
|
||||
# Initialize with small weights to prevent sigmoid saturation
|
||||
# Target: sigmoid(0) = 0.5, so we want raw logits around [-2, 2] range
|
||||
with torch.no_grad():
|
||||
self.reward_head.weight.normal_(0.0, 0.01) # Much smaller std
|
||||
self.reward_head.weight.normal_(0.0, 0.02) # Small but not tiny
|
||||
self.reward_head.bias.fill_(0.0) # Start at sigmoid(0) = 0.5
|
||||
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
@@ -292,7 +295,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
video_frame_embeds = self.mlp_predictor(attended_video_tokens)
|
||||
|
||||
# Get rewards via linear head with sigmoid activation
|
||||
return self.sigmoid(self.reward_head(video_frame_embeds)).squeeze(-1) # (B, T)
|
||||
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
|
||||
return self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1) # (B, T)
|
||||
|
||||
def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# Initial version: no-op; rely on upstream processors if any
|
||||
@@ -518,7 +522,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# During inference, we might not want to compute loss
|
||||
if not self.training and target is None:
|
||||
# Return predictions without loss
|
||||
rewards = self.sigmoid(self.reward_head(video_frame_embeds)).squeeze(-1)
|
||||
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
|
||||
rewards = self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1)
|
||||
return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}
|
||||
|
||||
# Calculate loss using MSE
|
||||
@@ -526,7 +531,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
assert target.dtype == torch.float, "Continuous rewards require float targets"
|
||||
|
||||
# Get reward predictions with sigmoid activation
|
||||
predicted_rewards = self.sigmoid(self.reward_head(video_frame_embeds)).squeeze(-1) # (B, T_eff)
|
||||
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
|
||||
predicted_rewards = self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1) # (B, T_eff)
|
||||
|
||||
# MSE loss with masking for variable length sequences
|
||||
loss = F.mse_loss(predicted_rewards, target[:, :T_eff], reduction='mean')
|
||||
@@ -551,7 +557,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
mismatch_embeds = self.mlp_predictor(attended_video_mm)
|
||||
|
||||
# Mismatched pairs should predict zero progress
|
||||
mismatch_predictions = self.sigmoid(self.reward_head(mismatch_embeds)).squeeze(-1)
|
||||
normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds)
|
||||
mismatch_predictions = self.sigmoid(self.reward_head(normalized_mismatch_embeds)).squeeze(-1)
|
||||
zeros_target = torch.zeros_like(target[:, :T_eff])
|
||||
L_mismatch = F.mse_loss(mismatch_predictions, zeros_target, reduction='mean')
|
||||
|
||||
@@ -562,9 +569,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# DEBUG: Print targets and predictions occasionally during training
|
||||
if self.training and torch.rand(1).item() < 0.02: # ~2% chance to debug print
|
||||
with torch.no_grad():
|
||||
# Get raw MLP outputs before reward head and sigmoid predictions
|
||||
# Get raw MLP outputs, normalized outputs, and predictions
|
||||
raw_outputs = video_frame_embeds
|
||||
raw_logits = self.reward_head(video_frame_embeds).squeeze(-1)
|
||||
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
|
||||
raw_logits = self.reward_head(normalized_embeds).squeeze(-1)
|
||||
preds = self.sigmoid(raw_logits)
|
||||
|
||||
print(f"\n=== DEBUG TRAINING ===")
|
||||
@@ -575,13 +583,14 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
print(f"Target range: [{target.min():.3f}, {target.max():.3f}]")
|
||||
# Model output statistics
|
||||
print(f"Raw MLP range: [{raw_outputs.min():.3f}, {raw_outputs.max():.3f}]")
|
||||
print(f"Normalized MLP range: [{normalized_embeds.min():.6f}, {normalized_embeds.max():.6f}]")
|
||||
print(f"Raw logits range: [{raw_logits.min():.6f}, {raw_logits.max():.6f}]")
|
||||
print(f"Raw logits mean: {raw_logits.mean():.6f}")
|
||||
print(f"Sigmoid pred range: [{preds.min():.3f}, {preds.max():.3f}]")
|
||||
print(f"Sigmoid pred mean: {preds.mean():.3f}")
|
||||
print(f"Loss: {loss:.4f}")
|
||||
print("First sample targets:", target[0, :5].cpu().numpy())
|
||||
print("First sample preds:", preds[0, :5].cpu().numpy())
|
||||
print("First sample targets (all 16):", target[0].cpu().numpy())
|
||||
print("First sample preds (all 16):", preds[0].cpu().numpy())
|
||||
print("="*25)
|
||||
|
||||
total_forward_time = time.perf_counter() - forward_start
|
||||
|
||||
Reference in New Issue
Block a user