mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
feat(annotate): emit VQA per-camera and propagate camera field
Module 3 now produces one (vqa, user) + (vqa, assistant) pair per emission tick *per camera* rather than only against the dataset's first camera. Each emitted row carries the `camera` field added in PR 1 (language-columns), so the resolver can disambiguate per-camera VQA via `emitted_at(t, style=vqa, role=assistant, camera=...)` without ambiguity. - `frames.py`: `FrameProvider` Protocol gains a `camera_keys` property and a `camera_key=` argument on `frames_at` / `video_for_episode`. `VideoFrameProvider` exposes every `observation.images.*` key the dataset declares (not just the first) and keys its decode cache on `(episode, camera, timestamp)` so per-camera reads don't collide. Module 1 / 2 keep their old single-camera behaviour by leaving `camera_key=None` (falls back to the default camera). - `modules/general_vqa.py`: `run_episode` iterates `frame_provider .camera_keys` for each emission tick, builds one prompt per camera, batches all of them through the VLM, and stamps the resulting rows with `camera=<that key>`. Empty `camera_keys` (null provider) makes the module a no-op rather than silently emitting untagged rows. - `writer.py`: `_normalize_persistent_row` / `_normalize_event_row` carry `camera` through and call `validate_camera_field` so the invariant is enforced at the writer boundary. Event sort key now includes `camera` for deterministic ordering when several cameras share `(timestamp, style, role)`. `speech_atom` sets `camera=None`. - `validator.py`: `StagingValidator` gains a `dataset_camera_keys` field; `_check_camera_field` enforces the invariant and cross-checks every view-dependent row's `camera` against the dataset's known video keys. New `_check_vqa_uniqueness_per_frame_camera` flags duplicate `(vqa, role)` pairs at the same `(t, camera)`. - `lerobot_annotate.py`: passes the live frame provider's `camera_keys` into the validator so the cross-check uses the actual dataset camera set. - Tests: `_StubFrameProvider` exposes `camera_keys` and accepts the new `camera_key=` kwarg. `test_module3_vqa_unique_per_frame_and_camera` configures two cameras and asserts both are represented, that every emitted row has a `camera` tag, and that uniqueness holds per `(timestamp, camera, role)`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -34,10 +34,29 @@ from .reader import EpisodeRecord
|
||||
class FrameProvider(Protocol):
|
||||
"""Decodes camera frames at episode-relative timestamps."""
|
||||
|
||||
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
|
||||
"""Return one PIL.Image per timestamp; empty list if no camera available."""
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` feature keys this provider can decode."""
|
||||
|
||||
def video_for_episode(self, record: EpisodeRecord, max_frames: int) -> list[Any]:
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return one PIL.Image per timestamp from ``camera_key`` (or default).
|
||||
|
||||
Empty list if the camera is unavailable. ``camera_key=None`` falls back
|
||||
to the provider's default camera so existing single-camera callers
|
||||
(Module 1, Module 2) keep working unchanged.
|
||||
"""
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` PIL images covering the whole episode.
|
||||
|
||||
Sampling is uniform across the episode duration. The returned list is
|
||||
@@ -51,10 +70,24 @@ class FrameProvider(Protocol):
|
||||
class _NullProvider:
|
||||
"""No-op provider used when the dataset has no video keys or in tests."""
|
||||
|
||||
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
return []
|
||||
|
||||
def video_for_episode(self, record: EpisodeRecord, max_frames: int) -> list[Any]:
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
@@ -64,12 +97,18 @@ def null_provider() -> FrameProvider:
|
||||
|
||||
@dataclass
|
||||
class VideoFrameProvider:
|
||||
"""Decodes frames from the dataset's first ``observation.images.*`` stream.
|
||||
"""Decodes frames from the dataset's ``observation.images.*`` streams.
|
||||
|
||||
The first camera key is used unconditionally — Module 1/2/3 prompts care
|
||||
about *what is happening*, not which camera angle the model sees, so a
|
||||
single canonical viewpoint is enough. Override ``camera_key`` if you
|
||||
want a specific stream.
|
||||
By default the *first* camera key is used for Module 1 (subtask
|
||||
decomposition) and Module 2 (interjection scenarios) — those prompts care
|
||||
about *what is happening*, not which angle. Module 3 (VQA) instead
|
||||
iterates over every camera in :attr:`camera_keys` so each frame's
|
||||
grounded answer (bbox/keypoint/...) is tagged with the camera it was
|
||||
grounded against.
|
||||
|
||||
``camera_key`` overrides the default-camera choice but does not restrict
|
||||
:attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
|
||||
``video_for_episode`` to read a non-default stream.
|
||||
|
||||
Caches up to ``cache_size`` decoded frames per process to keep
|
||||
co-timestamped Module 2 + Module 1 plan-update calls cheap.
|
||||
@@ -81,24 +120,37 @@ class VideoFrameProvider:
|
||||
cache_size: int = 256
|
||||
_meta: Any = field(default=None, init=False, repr=False)
|
||||
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
||||
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
|
||||
|
||||
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
|
||||
keys = list(self._meta.video_keys or [])
|
||||
self._camera_keys = keys
|
||||
if self.camera_key is None:
|
||||
keys = self._meta.video_keys
|
||||
self.camera_key = keys[0] if keys else None
|
||||
|
||||
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
|
||||
if not timestamps or self.camera_key is None:
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` keys available on this dataset."""
|
||||
return list(self._camera_keys)
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if not timestamps or target is None:
|
||||
return []
|
||||
|
||||
out: list[Any] = []
|
||||
misses: list[float] = []
|
||||
miss_indices: list[int] = []
|
||||
for i, ts in enumerate(timestamps):
|
||||
key = (record.episode_index, round(float(ts), 6))
|
||||
key = (record.episode_index, target, round(float(ts), 6))
|
||||
cached = self._cache.get(key)
|
||||
if cached is not None:
|
||||
out.append(cached)
|
||||
@@ -108,20 +160,22 @@ class VideoFrameProvider:
|
||||
miss_indices.append(i)
|
||||
|
||||
if misses:
|
||||
decoded = self._decode(record.episode_index, misses)
|
||||
decoded = self._decode(record.episode_index, misses, target)
|
||||
# decoder may return fewer frames than requested when some
|
||||
# timestamps fall outside the video; pair what we have and
|
||||
# leave the rest as None to be filtered below.
|
||||
for i, img in zip(miss_indices, decoded):
|
||||
out[i] = img
|
||||
key = (record.episode_index, round(float(timestamps[i]), 6))
|
||||
key = (record.episode_index, target, round(float(timestamps[i]), 6))
|
||||
if len(self._cache) >= self.cache_size:
|
||||
self._cache.pop(next(iter(self._cache)))
|
||||
self._cache[key] = img
|
||||
# filter out any None left over from decode failures
|
||||
return [img for img in out if img is not None]
|
||||
|
||||
def _decode(self, episode_index: int, timestamps: list[float]) -> list[Any]:
|
||||
def _decode(
|
||||
self, episode_index: int, timestamps: list[float], camera_key: str
|
||||
) -> list[Any]:
|
||||
import os as _os # noqa: PLC0415
|
||||
|
||||
from PIL import Image # noqa: PLC0415
|
||||
@@ -129,9 +183,9 @@ class VideoFrameProvider:
|
||||
from lerobot.datasets.video_utils import decode_video_frames # noqa: PLC0415
|
||||
|
||||
ep = self._meta.episodes[episode_index]
|
||||
from_timestamp = ep[f"videos/{self.camera_key}/from_timestamp"]
|
||||
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
|
||||
shifted = [from_timestamp + ts for ts in timestamps]
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, self.camera_key)
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
|
||||
# ``torchcodec`` import currently bad-allocs on cu128/torch-2.8 in
|
||||
# some environments; default to ``pyav`` (always available via
|
||||
# the ``av`` package) and let users override with
|
||||
@@ -156,13 +210,19 @@ class VideoFrameProvider:
|
||||
out.append(Image.fromarray(hwc, mode="RGB"))
|
||||
return out
|
||||
|
||||
def video_for_episode(self, record: EpisodeRecord, max_frames: int) -> list[Any]:
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` images uniformly sampled across the episode.
|
||||
|
||||
The whole episode duration is covered; the model picks subtask
|
||||
boundaries from the temporal pooling it does internally.
|
||||
"""
|
||||
if max_frames <= 0 or self.camera_key is None or not record.frame_timestamps:
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if max_frames <= 0 or target is None or not record.frame_timestamps:
|
||||
return []
|
||||
n_frames = min(max_frames, len(record.frame_timestamps))
|
||||
if n_frames == len(record.frame_timestamps):
|
||||
@@ -175,7 +235,7 @@ class VideoFrameProvider:
|
||||
else:
|
||||
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
|
||||
timestamps = [float(t0 + i * step) for i in range(n_frames)]
|
||||
return self.frames_at(record, timestamps)
|
||||
return self.frames_at(record, timestamps, camera_key=target)
|
||||
|
||||
|
||||
def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvider:
|
||||
|
||||
Reference in New Issue
Block a user