mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
improve comment
This commit is contained in:
@@ -103,8 +103,10 @@ def update_policy(
|
||||
# Note: Policies supporting sample weighting must implement forward(batch, reduction="none")
|
||||
per_sample_loss, output_dict = policy.forward(batch, reduction="none")
|
||||
|
||||
# Apply sample weights: L_weighted = Σ(w_i * l_i) / (Σw_i + ε)
|
||||
# Weights are already normalized to sum to batch_size
|
||||
# Weighted loss: each sample's contribution is scaled by its weight.
|
||||
# We divide by weight sum (not batch size) so that if some weights are zero,
|
||||
# the remaining samples contribute proportionally more, preserving gradient scale.
|
||||
# Weights are pre-normalized to sum to batch_size for stable training dynamics;
# epsilon keeps the denominator non-zero if every weight in the batch is zero.
|
||||
epsilon = 1e-6
|
||||
loss = (per_sample_loss * sample_weights).sum() / (sample_weights.sum() + epsilon)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user