diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index d2355aaa9..61a638b1b 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -103,8 +103,10 @@ def update_policy( # Note: Policies supporting sample weighting must implement forward(batch, reduction="none") per_sample_loss, output_dict = policy.forward(batch, reduction="none") - # Apply sample weights: L_weighted = Σ(w_i * l_i) / (Σw_i + ε) - # Weights are already normalized to sum to batch_size + # Weighted loss: each sample's contribution is scaled by its weight. + # We divide by weight sum (not batch size) so that if some weights are zero, + # the remaining samples contribute proportionally more, preserving gradient scale. + # Weights are pre-normalized to sum to batch_size for stable training dynamics. epsilon = 1e-6 loss = (per_sample_loss * sample_weights).sum() / (sample_weights.sum() + epsilon)