mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
fix(annotate): decode keyframes via PyAV directly
The pyav fallback routed through lerobot's decode_video_frames(backend= "pyav"), which uses torchvision.io.VideoReader — removed in torchvision 0.23+. On modern torch stacks (e.g. vllm-openai with torchvision 0.26) both torchcodec and that path fail, leaving interjection/vqa prompts without visual context. Add _decode_frames_av: a self-contained PyAV decoder that picks the nearest frame per timestamp. It is the always-available tail of the decoder chain (torchcodec -> pyav) and the target of --video_backend=pyav. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -303,19 +303,21 @@ class VideoFrameProvider:
|
||||
shifted = [from_timestamp + ts for ts in timestamps]
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
|
||||
|
||||
# When no backend is pinned, try the platform default first and fall
|
||||
# back to ``pyav`` if it raises — torchcodec is broken in some
|
||||
# containers (e.g. vllm-openai), where pyav decodes the same file fine.
|
||||
# Build the decoder chain. torchcodec is fast but unusable in some
|
||||
# containers (vllm-openai: "Operation not permitted"); lerobot's
|
||||
# ``pyav`` backend routes through ``torchvision.io.VideoReader``,
|
||||
# removed in torchvision 0.23+. ``_decode_frames_av`` talks to the
|
||||
# ``av`` package directly and is the always-available fallback.
|
||||
if self.video_backend:
|
||||
backends: list[str | None] = [self.video_backend]
|
||||
chain = [self.video_backend]
|
||||
else:
|
||||
backends = [None]
|
||||
if get_safe_default_codec() != "pyav":
|
||||
backends.append("pyav")
|
||||
chain = (["torchcodec"] if get_safe_default_codec() == "torchcodec" else []) + ["pyav"]
|
||||
|
||||
exc: Exception | None = None
|
||||
for backend in backends:
|
||||
for backend in chain:
|
||||
try:
|
||||
if backend in ("pyav", "av"):
|
||||
return _decode_frames_av(video_path, shifted)
|
||||
# Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp.
|
||||
decoded = decode_video_frames(
|
||||
video_path, shifted, self.tolerance_s, backend=backend, return_uint8=True
|
||||
@@ -339,7 +341,7 @@ class VideoFrameProvider:
|
||||
episode_index,
|
||||
camera_key,
|
||||
video_path,
|
||||
backends,
|
||||
chain,
|
||||
exc,
|
||||
exc_info=exc,
|
||||
)
|
||||
@@ -359,6 +361,42 @@ def make_frame_provider(
|
||||
return provider
|
||||
|
||||
|
||||
def _decode_frames_av(video_path: Path, timestamps: list[float]) -> list[Any]:
|
||||
"""Decode the frames nearest to ``timestamps`` using PyAV directly.
|
||||
|
||||
lerobot's ``decode_video_frames(backend="pyav")`` routes through
|
||||
``torchvision.io.VideoReader``, removed in torchvision 0.23+. This helper
|
||||
talks to the ``av`` package directly so keyframe extraction keeps working
|
||||
on modern torch/torchvision stacks and in containers where torchcodec
|
||||
cannot decode. Returns one ``(C, H, W)`` uint8 tensor per timestamp.
|
||||
"""
|
||||
import av # noqa: PLC0415
|
||||
|
||||
first_ts = min(timestamps)
|
||||
last_ts = max(timestamps)
|
||||
loaded_frames: list[torch.Tensor] = []
|
||||
loaded_ts: list[float] = []
|
||||
with av.open(str(video_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
# Seek to the keyframe at or before the first requested timestamp.
|
||||
offset = max(int(first_ts / stream.time_base), 0) if stream.time_base else 0
|
||||
container.seek(offset, stream=stream, backward=True, any_frame=False)
|
||||
for idx, frame in enumerate(container.decode(stream)):
|
||||
ts = frame.time
|
||||
if ts is None:
|
||||
ts = float(frame.pts * stream.time_base) if frame.pts is not None else float(idx)
|
||||
loaded_ts.append(ts)
|
||||
loaded_frames.append(
|
||||
torch.from_numpy(frame.to_ndarray(format="rgb24")).permute(2, 0, 1).contiguous()
|
||||
)
|
||||
if ts >= last_ts:
|
||||
break
|
||||
if not loaded_frames:
|
||||
raise RuntimeError(f"PyAV decoded no frames from {video_path}")
|
||||
ts_tensor = torch.tensor(loaded_ts)
|
||||
return [loaded_frames[int(torch.argmin((ts_tensor - q).abs()))] for q in timestamps]
|
||||
|
||||
|
||||
def _frame_to_pil(frame: Any) -> Any:
|
||||
"""Materialise a decoded frame as a ``PIL.Image`` for the VLM message.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user