mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
feat(annotate): attach camera keyframes to module prompts; default to Qwen3.6-27B-FP8
Closes the visual-grounding gap flagged after the initial PR review:
modules now decode actual camera frames at the relevant timestamps and
attach them as `{"type":"image", "image":<PIL>}` content blocks to the
VLM prompts.
- New `frames.py`:
- `FrameProvider` Protocol; `VideoFrameProvider` decodes from the
dataset's first `observation.images.*` stream via
`LeRobotDatasetMetadata.get_video_file_path` and
`decode_video_frames`, with the same `from_timestamp` shift the main
dataset uses.
- Per-process LRU cache so co-timestamped Module 1 plan-update + Module
2 calls share decode work.
- `make_frame_provider` falls back to a null provider when the dataset
has no video tracks → text-only prompts (graceful absence).
- Modules 1/2/3 take an optional `frame_provider` (default null) and
prepend image blocks before the text block.
- Module 1 attaches `keyframes_per_episode` keyframes to the subtask
decomposition prompt.
- Module 2 attaches the frame at the interjection timestamp.
- Module 3 attaches the exact emission frame to each VQA pair.
- VlmConfig: backend now defaults to `vllm`; default model is
`Qwen/Qwen3.6-27B-FP8`. New knobs: `--vlm.tensor_parallel_size`,
`--vlm.camera_key` (override the keyframe stream).
- `_make_vllm_client` honours `tensor_parallel_size` so 27B-FP8 sharded
on 2× GPUs works out of the box.
- `test_module3_attaches_frame_image_block_to_prompt` asserts modules
emit one image block per VQA prompt at the exact emission timestamp.
- Docs: example switched to `imstevenpmwork/super_poulain_draft` +
Qwen3.6-27B-FP8 + tensor_parallel_size=2; documents the keyframe
attachment behaviour and the no-video fallback.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
150
src/lerobot/annotations/steerable_pipeline/frames.py
Normal file
150
src/lerobot/annotations/steerable_pipeline/frames.py
Normal file
@@ -0,0 +1,150 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Keyframe extraction for the annotation pipeline.
|
||||
|
||||
Modules attach decoded camera frames to their VLM prompts so the model can
|
||||
ground subtask decomposition, interjection scenarios, and VQA in actual
|
||||
visual content. The pipeline shares one provider across modules and one
|
||||
episode at a time, with a small per-episode cache so multiple modules
|
||||
querying the same timestamp pay decode cost once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
|
||||
|
||||
class FrameProvider(Protocol):
|
||||
"""Decodes camera frames at episode-relative timestamps."""
|
||||
|
||||
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
|
||||
"""Return one PIL.Image per timestamp; empty list if no camera available."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class _NullProvider:
|
||||
"""No-op provider used when the dataset has no video keys or in tests."""
|
||||
|
||||
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
def null_provider() -> FrameProvider:
|
||||
return _NullProvider()
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoFrameProvider:
|
||||
"""Decodes frames from the dataset's first ``observation.images.*`` stream.
|
||||
|
||||
The first camera key is used unconditionally — Module 1/2/3 prompts care
|
||||
about *what is happening*, not which camera angle the model sees, so a
|
||||
single canonical viewpoint is enough. Override ``camera_key`` if you
|
||||
want a specific stream.
|
||||
|
||||
Caches up to ``cache_size`` decoded frames per process to keep
|
||||
co-timestamped Module 2 + Module 1 plan-update calls cheap.
|
||||
"""
|
||||
|
||||
root: Path
|
||||
camera_key: str | None = None
|
||||
tolerance_s: float = 1e-2
|
||||
cache_size: int = 256
|
||||
_meta: Any = field(default=None, init=False, repr=False)
|
||||
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
|
||||
|
||||
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
|
||||
if self.camera_key is None:
|
||||
keys = self._meta.video_keys
|
||||
self.camera_key = keys[0] if keys else None
|
||||
|
||||
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
|
||||
if not timestamps or self.camera_key is None:
|
||||
return []
|
||||
|
||||
out: list[Any] = []
|
||||
misses: list[float] = []
|
||||
miss_indices: list[int] = []
|
||||
for i, ts in enumerate(timestamps):
|
||||
key = (record.episode_index, round(float(ts), 6))
|
||||
cached = self._cache.get(key)
|
||||
if cached is not None:
|
||||
out.append(cached)
|
||||
else:
|
||||
out.append(None)
|
||||
misses.append(float(ts))
|
||||
miss_indices.append(i)
|
||||
|
||||
if misses:
|
||||
decoded = self._decode(record.episode_index, misses)
|
||||
for i, img in zip(miss_indices, decoded, strict=True):
|
||||
out[i] = img
|
||||
key = (record.episode_index, round(float(timestamps[i]), 6))
|
||||
if len(self._cache) >= self.cache_size:
|
||||
self._cache.pop(next(iter(self._cache)))
|
||||
self._cache[key] = img
|
||||
# filter out any None left over from decode failures
|
||||
return [img for img in out if img is not None]
|
||||
|
||||
def _decode(self, episode_index: int, timestamps: list[float]) -> list[Any]:
|
||||
from PIL import Image # noqa: PLC0415
|
||||
|
||||
from lerobot.datasets.video_utils import decode_video_frames # noqa: PLC0415
|
||||
|
||||
ep = self._meta.episodes[episode_index]
|
||||
from_timestamp = ep[f"videos/{self.camera_key}/from_timestamp"]
|
||||
shifted = [from_timestamp + ts for ts in timestamps]
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, self.camera_key)
|
||||
try:
|
||||
frames = decode_video_frames(
|
||||
video_path,
|
||||
shifted,
|
||||
self.tolerance_s,
|
||||
return_uint8=True,
|
||||
)
|
||||
except Exception:
|
||||
return []
|
||||
# frames: [N, C, H, W] uint8, RGB
|
||||
out: list[Any] = []
|
||||
arr = frames.cpu().numpy() if hasattr(frames, "cpu") else frames
|
||||
for i in range(arr.shape[0]):
|
||||
chw = arr[i]
|
||||
hwc = chw.transpose(1, 2, 0)
|
||||
out.append(Image.fromarray(hwc, mode="RGB"))
|
||||
return out
|
||||
|
||||
|
||||
def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvider:
|
||||
"""Build a :class:`VideoFrameProvider` if videos are present, else null."""
|
||||
try:
|
||||
provider = VideoFrameProvider(root=root, camera_key=camera_key)
|
||||
except Exception:
|
||||
return null_provider()
|
||||
if provider.camera_key is None:
|
||||
return null_provider()
|
||||
return provider
|
||||
|
||||
|
||||
def to_image_blocks(images: list[Any]) -> list[dict[str, Any]]:
|
||||
"""Convert PIL images to Qwen-VL-compatible content blocks."""
|
||||
return [{"type": "image", "image": img} for img in images]
|
||||
Reference in New Issue
Block a user