feat: oversample sparse VQA annotations (recipe consumption + weighted sampler)

VQA annotations are sparse, so VQA was badly underrepresented in training:
its effective share was weight x density, and blend draws that picked an
ask_vqa* sub-recipe for a non-VQA frame were wasted entirely.

Two pieces:

1. Recipe-side consumption (language_render.py): render_sample now routes
   any frame that carries a VQA annotation to a matching ask_vqa* sub-recipe,
   regardless of the weighted blend draw. No VQA annotation is wasted and no
   draw lands on a non-renderable VQA recipe — VQA's recipe-side share now
   equals the VQA-annotation density.

2. Dataset-side oversampling (WeightedEpisodeAwareSampler + vqa_target_fraction):
   a new weighted, episode-aware sampler draws frames with replacement by
   per-frame weight. When TrainPipelineConfig.vqa_target_fraction is set, the
   train script scans language_events, weights VQA frames so they make up
   ~that fraction of the training stream, and uses the weighted sampler. This
   is what actually lets VQA exceed its natural density. Default None keeps
   uniform episode-aware sampling unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-18 15:30:00 +02:00
parent b319ccf688
commit fbcb9225f5
7 changed files with 343 additions and 51 deletions

View File

@@ -84,3 +84,66 @@ class EpisodeAwareSampler:
def __len__(self) -> int:
return len(self.indices)
class WeightedEpisodeAwareSampler(EpisodeAwareSampler):
"""``EpisodeAwareSampler`` that draws frames *with replacement* in
proportion to per-frame weights.
Used to oversample frames carrying a sparse annotation (e.g. a VQA
question) so the policy sees them more often than their natural
dataset density. One epoch still yields ``len(self.indices)``
samples — the weights only change the *composition* of the stream,
not its length. Each epoch re-draws, so the oversampled subset
varies run to run.
"""
def __init__(
self,
dataset_from_indices: list[int],
dataset_to_indices: list[int],
frame_weights,
*,
episode_indices_to_use: list | None = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
):
"""
Args:
dataset_from_indices: Episode start indices (see ``EpisodeAwareSampler``).
dataset_to_indices: Episode end indices.
frame_weights: 1-D sequence/tensor of non-negative weights, one per
dataset frame (length == total dataset frames). Higher weight ⇒
that frame is sampled more often.
episode_indices_to_use / drop_n_first_frames / drop_n_last_frames:
Same meaning as ``EpisodeAwareSampler`` — the episode-boundary
frame filtering is applied first, then weighting is restricted
to the surviving frames.
"""
super().__init__(
dataset_from_indices,
dataset_to_indices,
episode_indices_to_use=episode_indices_to_use,
drop_n_first_frames=drop_n_first_frames,
drop_n_last_frames=drop_n_last_frames,
shuffle=False,
)
weights = torch.as_tensor(frame_weights, dtype=torch.double).flatten()
idx = torch.tensor(self.indices, dtype=torch.long)
if weights.numel() <= int(idx.max()):
raise ValueError(
f"frame_weights has {weights.numel()} entries but the sampler "
f"references frame index {int(idx.max())}."
)
selected = weights[idx]
if not torch.isfinite(selected).all() or bool((selected < 0).any()):
raise ValueError("frame_weights must be finite and non-negative.")
if float(selected.sum()) <= 0.0:
# All surviving frames have zero weight — fall back to uniform.
selected = torch.ones_like(selected)
self._weights = selected
def __iter__(self) -> Iterator[int]:
picks = torch.multinomial(self._weights, num_samples=len(self.indices), replacement=True)
for i in picks.tolist():
yield self.indices[i]