mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 12:21:27 +00:00
VQA annotations are sparse, so VQA was badly underrepresented in training: its effective share was weight x density, and blend draws that picked an ask_vqa* sub-recipe for a non-VQA frame were wasted entirely. Two pieces: 1. Recipe-side consumption (language_render.py): render_sample now routes any frame that carries a VQA annotation to a matching ask_vqa* sub-recipe, regardless of the weighted blend draw. No VQA annotation is wasted and no draw lands on a non-renderable VQA recipe — VQA's recipe-side share now equals the VQA-annotation density. 2. Dataset-side oversampling (WeightedEpisodeAwareSampler + vqa_target_fraction): a new weighted, episode-aware sampler draws frames with replacement by per-frame weight. When TrainPipelineConfig.vqa_target_fraction is set, the train script scans language_events, weights VQA frames so they make up ~that fraction of the training stream, and uses the weighted sampler. This is what actually lets VQA exceed its natural density. Default None keeps uniform episode-aware sampling unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
150 lines
6.1 KiB
Python
150 lines
6.1 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import logging
|
|
from collections.abc import Iterator
|
|
|
|
import torch
|
|
|
|
# Module-level logger; the sampler code in this module uses it to report skipped episodes.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class EpisodeAwareSampler:
|
|
def __init__(
|
|
self,
|
|
dataset_from_indices: list[int],
|
|
dataset_to_indices: list[int],
|
|
episode_indices_to_use: list | None = None,
|
|
drop_n_first_frames: int = 0,
|
|
drop_n_last_frames: int = 0,
|
|
shuffle: bool = False,
|
|
):
|
|
"""Sampler that optionally incorporates episode boundary information.
|
|
|
|
Args:
|
|
dataset_from_indices: List of indices containing the start of each episode in the dataset.
|
|
dataset_to_indices: List of indices containing the end of each episode in the dataset.
|
|
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
|
|
Assumes that episodes are indexed from 0 to N-1.
|
|
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
|
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
|
shuffle: Whether to shuffle the indices.
|
|
"""
|
|
if drop_n_first_frames < 0:
|
|
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
|
if drop_n_last_frames < 0:
|
|
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
|
|
|
|
indices = []
|
|
for episode_idx, (start_index, end_index) in enumerate(
|
|
zip(dataset_from_indices, dataset_to_indices, strict=True)
|
|
):
|
|
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
|
ep_length = end_index - start_index
|
|
if drop_n_first_frames + drop_n_last_frames >= ep_length:
|
|
logger.warning(
|
|
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
|
"drop_n_last_frames=%d removes all frames. Skipping.",
|
|
episode_idx,
|
|
ep_length,
|
|
drop_n_first_frames,
|
|
drop_n_last_frames,
|
|
)
|
|
continue
|
|
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
|
|
|
|
if not indices:
|
|
raise ValueError(
|
|
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
|
|
"All episodes were either filtered out or had too few frames."
|
|
)
|
|
|
|
self.indices = indices
|
|
self.shuffle = shuffle
|
|
|
|
def __iter__(self) -> Iterator[int]:
|
|
if self.shuffle:
|
|
for i in torch.randperm(len(self.indices)):
|
|
yield self.indices[i]
|
|
else:
|
|
for i in self.indices:
|
|
yield i
|
|
|
|
def __len__(self) -> int:
|
|
return len(self.indices)
|
|
|
|
|
|
class WeightedEpisodeAwareSampler(EpisodeAwareSampler):
    """``EpisodeAwareSampler`` that draws frames *with replacement* in
    proportion to per-frame weights.

    Used to oversample frames carrying a sparse annotation (e.g. a VQA
    question) so the policy sees them more often than their natural
    dataset density. One epoch still yields ``len(self.indices)``
    samples — the weights only change the *composition* of the stream,
    not its length. Each epoch re-draws, so the oversampled subset
    varies run to run.
    """

    # ``torch.multinomial`` refuses distributions with more than 2**24
    # categories; above this size ``__iter__`` samples via the inverse CDF.
    _MULTINOMIAL_MAX_CATEGORIES = 2**24

    def __init__(
        self,
        dataset_from_indices: list[int],
        dataset_to_indices: list[int],
        frame_weights,
        *,
        episode_indices_to_use: list | None = None,
        drop_n_first_frames: int = 0,
        drop_n_last_frames: int = 0,
    ):
        """
        Args:
            dataset_from_indices: Episode start indices (see ``EpisodeAwareSampler``).
            dataset_to_indices: Episode end indices.
            frame_weights: 1-D sequence/tensor of non-negative weights, one per
                dataset frame (length == total dataset frames). Higher weight ⇒
                that frame is sampled more often. The input is flattened and
                converted to a double tensor.
            episode_indices_to_use / drop_n_first_frames / drop_n_last_frames:
                Same meaning as ``EpisodeAwareSampler`` — the episode-boundary
                frame filtering is applied first, then weighting is restricted
                to the surviving frames.

        Raises:
            ValueError: If ``frame_weights`` is too short to cover the highest
                surviving frame index, or if any weight of a surviving frame
                is negative or non-finite. Errors raised by
                ``EpisodeAwareSampler.__init__`` propagate unchanged.
        """
        super().__init__(
            dataset_from_indices,
            dataset_to_indices,
            episode_indices_to_use=episode_indices_to_use,
            drop_n_first_frames=drop_n_first_frames,
            drop_n_last_frames=drop_n_last_frames,
            shuffle=False,
        )
        weights = torch.as_tensor(frame_weights, dtype=torch.double).flatten()
        idx = torch.tensor(self.indices, dtype=torch.long)
        max_index = int(idx.max())
        if weights.numel() <= max_index:
            raise ValueError(
                f"frame_weights has {weights.numel()} entries but the sampler "
                f"references frame index {max_index}."
            )
        # Only weights of surviving frames matter; entries for filtered-out
        # frames are never validated or sampled.
        selected = weights[idx]
        if not torch.isfinite(selected).all() or bool((selected < 0).any()):
            raise ValueError("frame_weights must be finite and non-negative.")
        if float(selected.sum()) <= 0.0:
            # All surviving frames have zero weight — fall back to uniform,
            # but say so: the caller asked for weighting and is not getting it.
            logger.warning(
                "All %d surviving frames have zero weight; falling back to uniform sampling.",
                selected.numel(),
            )
            selected = torch.ones_like(selected)
        self._weights = selected

    def __iter__(self) -> Iterator[int]:
        """Yield ``len(self.indices)`` frame indices drawn with replacement by weight."""
        num_frames = len(self.indices)
        if num_frames <= self._MULTINOMIAL_MAX_CATEGORIES:
            picks = torch.multinomial(self._weights, num_samples=num_frames, replacement=True)
        else:
            picks = self._sample_by_inverse_cdf(num_frames)
        for i in picks.tolist():
            yield self.indices[i]

    def _sample_by_inverse_cdf(self, num_samples: int) -> torch.Tensor:
        """Draw ``num_samples`` weighted positions without ``torch.multinomial``."""
        cdf = torch.cumsum(self._weights, dim=0)
        targets = torch.rand(num_samples, dtype=torch.double) * cdf[-1]
        # right=True returns the first position whose cumulative weight strictly
        # exceeds the target, so zero-weight frames (flat CDF steps) are skipped.
        picks = torch.searchsorted(cdf, targets, right=True)
        # Rounding can push a target up to the total weight, which would map one
        # past the end; clamp onto the last frame that has positive weight.
        last_positive = int(torch.nonzero(self._weights).max())
        return picks.clamp_(max=last_positive)
|