improve comment

Michel Aractingi
2026-01-14 17:13:26 +01:00
parent 0264ac717b
commit 80b0f1aaa2

@@ -103,8 +103,10 @@ def update_policy(
 # Note: Policies supporting sample weighting must implement forward(batch, reduction="none")
 per_sample_loss, output_dict = policy.forward(batch, reduction="none")
-# Apply sample weights: L_weighted = Σ(w_i * l_i) / (Σw_i + ε)
-# Weights are already normalized to sum to batch_size
+# Weighted loss: each sample's contribution is scaled by its weight.
+# We divide by weight sum (not batch size) so that if some weights are zero,
+# the remaining samples contribute proportionally more, preserving gradient scale.
+# Weights are pre-normalized to sum to batch_size for stable training dynamics.
 epsilon = 1e-6
 loss = (per_sample_loss * sample_weights).sum() / (sample_weights.sum() + epsilon)
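
The behaviour the new comment describes can be checked with a small standalone sketch. It is not the code from this commit: it uses plain Python lists instead of tensors, and the names `weighted_loss`, `losses`, `uniform`, `masked`, and `all_zero` are hypothetical, chosen only for illustration. It assumes the same formula as the hunk, sum(w_i * l_i) / (sum(w_i) + epsilon).

```python
def weighted_loss(per_sample_loss, sample_weights, epsilon=1e-6):
    """Weighted mean of per-sample losses: sum(w_i * l_i) / (sum(w_i) + epsilon)."""
    if len(per_sample_loss) != len(sample_weights):
        raise ValueError("per_sample_loss and sample_weights must have the same length")
    weighted_sum = sum(w * l for w, l in zip(sample_weights, per_sample_loss))
    # Dividing by the weight sum (not the batch size) keeps the result a
    # weighted average even when some weights are zero.
    return weighted_sum / (sum(sample_weights) + epsilon)


losses = [1.0, 2.0, 3.0, 4.0]

# Uniform weights that sum to the batch size: close to the plain mean.
uniform = weighted_loss(losses, [1.0, 1.0, 1.0, 1.0])

# Two samples masked out, weights still sum to the batch size:
# close to the mean of the two remaining samples.
masked = weighted_loss(losses, [2.0, 2.0, 0.0, 0.0])

# All weights zero: epsilon prevents a division by zero.
all_zero = weighted_loss(losses, [0.0, 0.0, 0.0, 0.0])
```

With these inputs, `uniform` lands just under 2.5 and `masked` just under 1.5 (the small gap comes from epsilon in the denominator), while `all_zero` is exactly 0.0.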