mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
Replaced assert statements with FrameTimestampError exceptions in decode_video_frames_torchvision and decode_video_frames_torchcodec. Assertions are unsuitable for runtime validation because they can be silently disabled with python -O, and they produce unhelpful AssertionError tracebacks. The codebase already defines FrameTimestampError for this exact purpose but it was only used in one of the three validation sites. Also removed AssertionError from the except clause in LeRobotDataset.__init__, which was masking video timestamp errors by silently triggering a dataset re-download instead of surfacing the actual problem.
1113 lines
42 KiB
Python
1113 lines
42 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import contextlib
|
|
import glob
|
|
import importlib
|
|
import logging
|
|
import queue
|
|
import shutil
|
|
import tempfile
|
|
import threading
|
|
import warnings
|
|
from dataclasses import dataclass, field
|
|
from fractions import Fraction
|
|
from pathlib import Path
|
|
from threading import Lock
|
|
from typing import Any, ClassVar
|
|
|
|
import av
|
|
import fsspec
|
|
import numpy as np
|
|
import pyarrow as pa
|
|
import torch
|
|
import torchvision
|
|
from datasets.features.features import register_feature
|
|
from PIL import Image
|
|
|
|
# Hardware encoders probed when vcodec="auto" is requested, listed in order of preference.
# Whether a given encoder actually works depends on the platform and on the FFmpeg build.
HW_ENCODERS = [
    # macOS
    "h264_videotoolbox",
    "hevc_videotoolbox",
    # NVIDIA GPU
    "h264_nvenc",
    "hevc_nvenc",
    # Linux Intel/AMD
    "h264_vaapi",
    # Intel Quick Sync
    "h264_qsv",
]

# Every value accepted for `vcodec`: the software encoders, the "auto" selector and all hardware encoders.
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto", *HW_ENCODERS}
|
|
|
|
|
|
def _get_codec_options(
|
|
vcodec: str,
|
|
g: int | None = 2,
|
|
crf: int | None = 30,
|
|
preset: int | None = None,
|
|
) -> dict:
|
|
"""Build codec-specific options dict for video encoding."""
|
|
options = {}
|
|
|
|
# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
|
|
if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS):
|
|
options["g"] = str(g)
|
|
|
|
# Quality control (codec-specific parameter names)
|
|
if crf is not None:
|
|
if vcodec in ("h264", "hevc", "libsvtav1"):
|
|
options["crf"] = str(crf)
|
|
elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
|
|
quality = max(1, min(100, int(100 - crf * 2)))
|
|
options["q:v"] = str(quality)
|
|
elif vcodec in ("h264_nvenc", "hevc_nvenc"):
|
|
options["rc"] = "constqp"
|
|
options["qp"] = str(crf)
|
|
elif vcodec in ("h264_vaapi",):
|
|
options["qp"] = str(crf)
|
|
elif vcodec in ("h264_qsv",):
|
|
options["global_quality"] = str(crf)
|
|
|
|
# Preset (only for libsvtav1)
|
|
if vcodec == "libsvtav1":
|
|
options["preset"] = str(preset) if preset is not None else "12"
|
|
|
|
return options
|
|
|
|
|
|
def detect_available_hw_encoders() -> list[str]:
    """Return the entries of HW_ENCODERS that PyAV/FFmpeg can instantiate for writing.

    The result keeps the order of HW_ENCODERS.
    """
    usable = []
    for name in HW_ENCODERS:
        try:
            av.codec.Codec(name, "w")
        except Exception:  # nosec B112
            # Best-effort probe: any failure simply means this encoder is not usable here.
            continue
        usable.append(name)
    return usable
|
|
|
|
|
|
def resolve_vcodec(vcodec: str) -> str:
    """Check `vcodec` and turn "auto" into a concrete encoder name.

    Args:
        vcodec: One of VALID_VIDEO_CODECS.

    Returns:
        `vcodec` unchanged when it names an encoder; for "auto", the first entry of
        HW_ENCODERS that is available, or "libsvtav1" when none is.

    Raises:
        ValueError: If `vcodec` is not in VALID_VIDEO_CODECS.
    """
    if vcodec not in VALID_VIDEO_CODECS:
        raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")

    if vcodec == "auto":
        detected = detect_available_hw_encoders()
        # Walk HW_ENCODERS (not `detected`) so the preference order is always the one declared there.
        chosen = next((candidate for candidate in HW_ENCODERS if candidate in detected), None)
        if chosen is None:
            logging.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
            return "libsvtav1"
        logging.info(f"Auto-selected video codec: {chosen}")
        return chosen

    logging.info(f"Using video codec: {vcodec}")
    return vcodec
|
|
|
|
|
|
def get_safe_default_codec() -> str:
    """Pick the default video decoding backend for this platform.

    Returns:
        "torchcodec" when the `torchcodec` package can be found, otherwise "pyav"
        (a warning is logged in that case).
    """
    # `import importlib` alone does not guarantee that the `importlib.util` submodule
    # is loaded, so import it explicitly before using `find_spec`.
    import importlib.util

    if importlib.util.find_spec("torchcodec"):
        return "torchcodec"

    logging.warning(
        "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
    )
    return "pyav"
|
|
|
|
|
|
def decode_video_frames(
|
|
video_path: Path | str,
|
|
timestamps: list[float],
|
|
tolerance_s: float,
|
|
backend: str | None = None,
|
|
) -> torch.Tensor:
|
|
"""
|
|
Decodes video frames using the specified backend.
|
|
|
|
Args:
|
|
video_path (Path): Path to the video file.
|
|
timestamps (list[float]): List of timestamps to extract frames.
|
|
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
|
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav"..
|
|
|
|
Returns:
|
|
torch.Tensor: Decoded frames.
|
|
|
|
Currently supports torchcodec on cpu and pyav.
|
|
"""
|
|
if backend is None:
|
|
backend = get_safe_default_codec()
|
|
if backend == "torchcodec":
|
|
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
|
elif backend in ["pyav", "video_reader"]:
|
|
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
|
else:
|
|
raise ValueError(f"Unsupported video backend: {backend}")
|
|
|
|
|
|
def decode_video_frames_torchvision(
|
|
video_path: Path | str,
|
|
timestamps: list[float],
|
|
tolerance_s: float,
|
|
backend: str = "pyav",
|
|
log_loaded_timestamps: bool = False,
|
|
) -> torch.Tensor:
|
|
"""Loads frames associated to the requested timestamps of a video
|
|
|
|
The backend can be either "pyav" (default) or "video_reader".
|
|
"video_reader" requires installing torchvision from source, see:
|
|
https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
|
|
(note that you need to compile against ffmpeg<4.3)
|
|
|
|
While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
|
|
For more info on video decoding, see `benchmark/video/README.md`
|
|
|
|
See torchvision doc for more info on these two backends:
|
|
https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend
|
|
|
|
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
|
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
|
|
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
|
|
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
|
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
|
"""
|
|
video_path = str(video_path)
|
|
|
|
# set backend
|
|
keyframes_only = False
|
|
torchvision.set_video_backend(backend)
|
|
if backend == "pyav":
|
|
keyframes_only = True # pyav doesn't support accurate seek
|
|
|
|
# set a video stream reader
|
|
# TODO(rcadene): also load audio stream at the same time
|
|
reader = torchvision.io.VideoReader(video_path, "video")
|
|
|
|
# set the first and last requested timestamps
|
|
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
|
|
first_ts = min(timestamps)
|
|
last_ts = max(timestamps)
|
|
|
|
# access closest key frame of the first requested frame
|
|
# Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
|
|
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
|
|
reader.seek(first_ts, keyframes_only=keyframes_only)
|
|
|
|
# load all frames until last requested frame
|
|
loaded_frames = []
|
|
loaded_ts = []
|
|
for frame in reader:
|
|
current_ts = frame["pts"]
|
|
if log_loaded_timestamps:
|
|
logging.info(f"frame loaded at timestamp={current_ts:.4f}")
|
|
loaded_frames.append(frame["data"])
|
|
loaded_ts.append(current_ts)
|
|
if current_ts >= last_ts:
|
|
break
|
|
|
|
if backend == "pyav":
|
|
reader.container.close()
|
|
|
|
reader = None
|
|
|
|
query_ts = torch.tensor(timestamps)
|
|
loaded_ts = torch.tensor(loaded_ts)
|
|
|
|
# compute distances between each query timestamp and timestamps of all loaded frames
|
|
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
|
|
min_, argmin_ = dist.min(1)
|
|
|
|
is_within_tol = min_ < tolerance_s
|
|
if not is_within_tol.all():
|
|
raise FrameTimestampError(
|
|
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
|
" It means that the closest frame that can be loaded from the video is too far away in time."
|
|
" This might be due to synchronization issues with timestamps during data collection."
|
|
" To be safe, we advise to ignore this item during training."
|
|
f"\nqueried timestamps: {query_ts}"
|
|
f"\nloaded timestamps: {loaded_ts}"
|
|
f"\nvideo: {video_path}"
|
|
f"\nbackend: {backend}"
|
|
)
|
|
|
|
# get closest frames to the query timestamps
|
|
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
|
closest_ts = loaded_ts[argmin_]
|
|
|
|
if log_loaded_timestamps:
|
|
logging.info(f"{closest_ts=}")
|
|
|
|
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
|
closest_frames = closest_frames.type(torch.float32) / 255
|
|
|
|
if len(timestamps) != len(closest_frames):
|
|
raise FrameTimestampError(
|
|
f"Number of retrieved frames ({len(closest_frames)}) does not match "
|
|
f"number of queried timestamps ({len(timestamps)})"
|
|
)
|
|
return closest_frames
|
|
|
|
|
|
class VideoDecoderCache:
    """Thread-safe cache for video decoders to avoid expensive re-initialization.

    Each entry stores the decoder together with the file handle it reads from, so the
    handle can be closed by `clear()`.
    """

    def __init__(self):
        # Maps a video path to its (decoder, file_handle) pair.
        self._cache: dict[str, tuple[Any, Any]] = {}
        self._lock = Lock()

    def get_decoder(self, video_path: str):
        """Get a cached decoder or create a new one.

        Args:
            video_path: Path (or fsspec URL) of the video; converted with `str()`.

        Returns:
            The torchcodec `VideoDecoder` associated with `video_path`.

        Raises:
            ImportError: If torchcodec cannot be imported.
        """
        try:
            from torchcodec.decoders import VideoDecoder
        except ImportError as err:
            raise ImportError("torchcodec is required but not available.") from err

        video_path = str(video_path)

        with self._lock:
            entry = self._cache.get(video_path)
            if entry is None:
                # The handle is entered manually because it must outlive this call;
                # it is closed later by `clear()`.
                file_handle = fsspec.open(video_path).__enter__()
                try:
                    decoder = VideoDecoder(file_handle, seek_mode="approximate")
                except Exception:
                    # Do not leak the open file when the decoder cannot be created.
                    file_handle.close()
                    raise
                entry = (decoder, file_handle)
                self._cache[video_path] = entry

            return entry[0]

    def clear(self):
        """Clear the cache and close file handles.

        The cache is always emptied; a handle that fails to close is reported with a
        warning and does not prevent the remaining handles from being closed.
        """
        with self._lock:
            entries = list(self._cache.values())
            self._cache.clear()

        for _, file_handle in entries:
            try:
                file_handle.close()
            except Exception as err:
                logging.warning(f"Failed to close cached video file handle: {err}")

    def size(self) -> int:
        """Return the number of cached decoders."""
        with self._lock:
            return len(self._cache)
|
|
|
|
|
|
class FrameTimestampError(ValueError):
    """Raised when the frames retrieved from a video do not match the queried timestamps."""
|
|
|
|
|
|
# Module-level decoder cache, used by `decode_video_frames_torchcodec` when the caller does not pass one.
_default_decoder_cache = VideoDecoderCache()
|
|
|
|
|
|
def decode_video_frames_torchcodec(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    log_loaded_timestamps: bool = False,
    decoder_cache: VideoDecoderCache | None = None,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video using torchcodec.

    Args:
        video_path: Path to the video file.
        timestamps: List of timestamps to extract frames.
        tolerance_s: Allowed deviation in seconds for frame retrieval.
        log_loaded_timestamps: Whether to log loaded timestamps.
        decoder_cache: Optional decoder cache instance. Uses default if None.

    Returns:
        A float32 tensor of the closest frames, scaled to the [0, 1] range, one per queried timestamp.

    Raises:
        FrameTimestampError: If a queried timestamp has no loaded frame within `tolerance_s`,
            or if the number of retrieved frames differs from the number of queried timestamps.

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to take into account decoding time and video size in bytes.
    """
    if decoder_cache is None:
        decoder_cache = _default_decoder_cache

    # Use cached decoder instead of creating new one each time
    decoder = decoder_cache.get_decoder(str(video_path))

    loaded_ts = []
    loaded_frames = []

    # get metadata for frame information
    metadata = decoder.metadata
    average_fps = metadata.average_fps
    # convert timestamps to frame indices
    frame_indices = [round(ts * average_fps) for ts in timestamps]
    # retrieve frames based on indices
    frames_batch = decoder.get_frames_at(indices=frame_indices)

    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=True):
        loaded_frames.append(frame)
        loaded_ts.append(pts.item())
        if log_loaded_timestamps:
            logging.info(f"Frame loaded at timestamp={pts:.4f}")

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # compute distances between each query timestamp and loaded timestamps
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    if not is_within_tol.all():
        raise FrameTimestampError(
            f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
            " It means that the closest frame that can be loaded from the video is too far away in time."
            " This might be due to synchronization issues with timestamps during data collection."
            " To be safe, we advise to ignore this item during training."
            f"\nqueried timestamps: {query_ts}"
            f"\nloaded timestamps: {loaded_ts}"
            f"\nvideo: {video_path}"
        )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # convert to float32 in [0,1] range
    closest_frames = (closest_frames / 255.0).type(torch.float32)

    # Sanity check: exactly one frame must be returned per queried timestamp.
    # The message reports counts (same wording as the torchvision decoder) rather than
    # formatting the frame tensors themselves.
    if len(timestamps) != len(closest_frames):
        raise FrameTimestampError(
            f"Number of retrieved frames ({len(closest_frames)}) does not match "
            f"number of queried timestamps ({len(timestamps)})"
        )

    return closest_frames
|
|
|
|
|
|
def encode_video_frames(
    imgs_dir: Path | str,
    video_path: Path | str,
    fps: int,
    vcodec: str = "libsvtav1",
    pix_fmt: str = "yuv420p",
    g: int | None = 2,
    crf: int | None = 30,
    fast_decode: int = 0,
    log_level: int | None = av.logging.WARNING,
    overwrite: bool = False,
    preset: int | None = None,
    encoder_threads: int | None = None,
) -> None:
    """Encode the `frame-NNNNNN.png` images of a directory into a video file.

    More info on ffmpeg arguments tuning on `benchmark/video/README.md`

    Args:
        imgs_dir: Directory containing the input images named `frame-` + 6 digits + `.png`.
        video_path: Output video file. Parent directories are created if needed.
        fps: Frame rate of the output stream.
        vcodec: Encoder name, validated/resolved with `resolve_vcodec` ("auto" is allowed).
        pix_fmt: Pixel format of the output stream. "yuv444p" is replaced by "yuv420p"
            for libsvtav1 and hevc.
        g: GOP size forwarded to `_get_codec_options`.
        crf: Quality setting forwarded to `_get_codec_options`.
        fast_decode: When truthy, adds `fast-decode=<value>` to the svtav1 params for
            libsvtav1, or `tune=fastdecode` for any other codec.
        log_level: Level set on the "libav" logger while encoding; None leaves logging untouched.
        overwrite: When False and `video_path` already exists, encoding is skipped.
        preset: Preset forwarded to `_get_codec_options`.
        encoder_threads: Optional thread count (`lp=<n>` for libsvtav1, `threads=<n>` otherwise).

    Raises:
        FileNotFoundError: If no matching image is found in `imgs_dir`.
        OSError: If the output file does not exist after encoding.
    """
    vcodec = resolve_vcodec(vcodec)

    video_path = Path(video_path)
    imgs_dir = Path(imgs_dir)

    if video_path.exists() and not overwrite:
        logging.warning(f"Video file already exists: {video_path}. Skipping encoding.")
        return

    video_path.parent.mkdir(parents=True, exist_ok=True)

    # Encoders/pixel formats incompatibility check
    if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
        logging.warning(
            f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
        )
        pix_fmt = "yuv420p"

    # Get input frames, ordered by frame number.
    # `Path.glob` only interprets `template` as a pattern, so a directory whose name
    # contains glob metacharacters (e.g. "[" or "*") is still searched correctly.
    template = "frame-" + ("[0-9]" * 6) + ".png"
    input_list = [
        str(frame_path)
        for frame_path in sorted(imgs_dir.glob(template), key=lambda p: int(p.stem.split("-")[-1]))
    ]

    # Define video output frame size (assuming all input frames are the same size)
    if len(input_list) == 0:
        raise FileNotFoundError(f"No images found in {imgs_dir}.")
    with Image.open(input_list[0]) as dummy_image:
        width, height = dummy_image.size

    # Define video codec options
    video_options = _get_codec_options(vcodec, g, crf, preset)

    if fast_decode:
        key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
        value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
        video_options[key] = value

    if encoder_threads is not None:
        if vcodec == "libsvtav1":
            lp_param = f"lp={encoder_threads}"
            if "svtav1-params" in video_options:
                video_options["svtav1-params"] += f":{lp_param}"
            else:
                video_options["svtav1-params"] = lp_param
        else:
            video_options["threads"] = str(encoder_threads)

    # Set logging level
    if log_level is not None:
        # "While less efficient, it is generally preferable to modify logging with Python's logging"
        logging.getLogger("libav").setLevel(log_level)

    try:
        # Create and open output file (overwrite by default)
        with av.open(str(video_path), "w") as output:
            output_stream = output.add_stream(vcodec, fps, options=video_options)
            output_stream.pix_fmt = pix_fmt
            output_stream.width = width
            output_stream.height = height

            # Loop through input frames and encode them
            for input_data in input_list:
                with Image.open(input_data) as input_image:
                    input_image = input_image.convert("RGB")
                    input_frame = av.VideoFrame.from_image(input_image)
                    packet = output_stream.encode(input_frame)
                    if packet:
                        output.mux(packet)

            # Flush the encoder
            packet = output_stream.encode()
            if packet:
                output.mux(packet)
    finally:
        # Reset logging level, also when encoding fails part-way through.
        if log_level is not None:
            av.logging.restore_default_callback()

    if not video_path.exists():
        raise OSError(f"Video encoding did not work. File not found: {video_path}.")
|
|
|
|
|
|
def concatenate_video_files(
|
|
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
|
|
):
|
|
"""
|
|
Concatenate multiple video files into a single video file using pyav.
|
|
|
|
This function takes a list of video input file paths and concatenates them into a single
|
|
output video file. It uses ffmpeg's concat demuxer with stream copy mode for fast
|
|
concatenation without re-encoding.
|
|
|
|
Args:
|
|
input_video_paths: Ordered list of input video file paths to concatenate.
|
|
output_video_path: Path to the output video file.
|
|
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
|
|
|
|
Note:
|
|
- Creates a temporary directory for intermediate files that is cleaned up after use.
|
|
- Uses ffmpeg's concat demuxer which requires all input videos to have the same
|
|
codec, resolution, and frame rate for proper concatenation.
|
|
"""
|
|
|
|
output_video_path = Path(output_video_path)
|
|
|
|
if output_video_path.exists() and not overwrite:
|
|
logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
|
|
return
|
|
|
|
output_video_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
if len(input_video_paths) == 0:
|
|
raise FileNotFoundError("No input video paths provided.")
|
|
|
|
# Create a temporary .ffconcat file to list the input video paths
|
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
|
|
tmp_concatenate_file.write("ffconcat version 1.0\n")
|
|
for input_path in input_video_paths:
|
|
tmp_concatenate_file.write(f"file '{str(input_path.resolve())}'\n")
|
|
tmp_concatenate_file.flush()
|
|
tmp_concatenate_path = tmp_concatenate_file.name
|
|
|
|
# Create input and output containers
|
|
input_container = av.open(
|
|
tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
|
|
) # safe = 0 allows absolute paths as well as relative paths
|
|
|
|
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
|
|
tmp_output_video_path = tmp_named_file.name
|
|
|
|
output_container = av.open(
|
|
tmp_output_video_path, mode="w", options={"movflags": "faststart"}
|
|
) # faststart is to move the metadata to the beginning of the file to speed up loading
|
|
|
|
# Replicate input streams in output container
|
|
stream_map = {}
|
|
for input_stream in input_container.streams:
|
|
if input_stream.type in ("video", "audio", "subtitle"): # only copy compatible streams
|
|
stream_map[input_stream.index] = output_container.add_stream_from_template(
|
|
template=input_stream, opaque=True
|
|
)
|
|
|
|
# set the time base to the input stream time base (missing in the codec context)
|
|
stream_map[input_stream.index].time_base = input_stream.time_base
|
|
|
|
# Demux + remux packets (no re-encode)
|
|
for packet in input_container.demux():
|
|
# Skip packets from un-mapped streams
|
|
if packet.stream.index not in stream_map:
|
|
continue
|
|
|
|
# Skip demux flushing packets
|
|
if packet.dts is None:
|
|
continue
|
|
|
|
output_stream = stream_map[packet.stream.index]
|
|
packet.stream = output_stream
|
|
output_container.mux(packet)
|
|
|
|
input_container.close()
|
|
output_container.close()
|
|
shutil.move(tmp_output_video_path, output_video_path)
|
|
Path(tmp_concatenate_path).unlink()
|
|
|
|
|
|
class _CameraEncoderThread(threading.Thread):
|
|
"""A thread that encodes video frames streamed via a queue into an MP4 file.
|
|
|
|
One instance is created per camera per episode. Frames are received as numpy arrays
|
|
from the main thread, encoded in real-time using PyAV (which releases the GIL during
|
|
encoding), and written to disk. Stats are computed incrementally using
|
|
RunningQuantileStats and returned via result_queue.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
video_path: Path,
|
|
fps: int,
|
|
vcodec: str,
|
|
pix_fmt: str,
|
|
g: int | None,
|
|
crf: int | None,
|
|
preset: int | None,
|
|
frame_queue: queue.Queue,
|
|
result_queue: queue.Queue,
|
|
stop_event: threading.Event,
|
|
encoder_threads: int | None = None,
|
|
):
|
|
super().__init__(daemon=True)
|
|
self.video_path = video_path
|
|
self.fps = fps
|
|
self.vcodec = vcodec
|
|
self.pix_fmt = pix_fmt
|
|
self.g = g
|
|
self.crf = crf
|
|
self.preset = preset
|
|
self.frame_queue = frame_queue
|
|
self.result_queue = result_queue
|
|
self.stop_event = stop_event
|
|
self.encoder_threads = encoder_threads
|
|
|
|
def run(self) -> None:
|
|
from lerobot.datasets.compute_stats import RunningQuantileStats, auto_downsample_height_width
|
|
|
|
container = None
|
|
output_stream = None
|
|
stats_tracker = RunningQuantileStats()
|
|
frame_count = 0
|
|
|
|
try:
|
|
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
|
|
|
while True:
|
|
try:
|
|
frame_data = self.frame_queue.get(timeout=1)
|
|
except queue.Empty:
|
|
if self.stop_event.is_set():
|
|
break
|
|
continue
|
|
|
|
if frame_data is None:
|
|
# Sentinel: flush and close
|
|
break
|
|
|
|
# Ensure HWC uint8 numpy array
|
|
if isinstance(frame_data, np.ndarray):
|
|
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
|
|
# CHW -> HWC
|
|
frame_data = frame_data.transpose(1, 2, 0)
|
|
if frame_data.dtype != np.uint8:
|
|
frame_data = (frame_data * 255).astype(np.uint8)
|
|
|
|
# Open container on first frame (to get width/height)
|
|
if container is None:
|
|
height, width = frame_data.shape[:2]
|
|
video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset)
|
|
if self.encoder_threads is not None:
|
|
if self.vcodec == "libsvtav1":
|
|
lp_param = f"lp={self.encoder_threads}"
|
|
if "svtav1-params" in video_options:
|
|
video_options["svtav1-params"] += f":{lp_param}"
|
|
else:
|
|
video_options["svtav1-params"] = lp_param
|
|
else:
|
|
video_options["threads"] = str(self.encoder_threads)
|
|
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
|
container = av.open(str(self.video_path), "w")
|
|
output_stream = container.add_stream(self.vcodec, self.fps, options=video_options)
|
|
output_stream.pix_fmt = self.pix_fmt
|
|
output_stream.width = width
|
|
output_stream.height = height
|
|
output_stream.time_base = Fraction(1, self.fps)
|
|
|
|
# Encode frame with explicit timestamps
|
|
pil_img = Image.fromarray(frame_data)
|
|
video_frame = av.VideoFrame.from_image(pil_img)
|
|
video_frame.pts = frame_count
|
|
video_frame.time_base = Fraction(1, self.fps)
|
|
packet = output_stream.encode(video_frame)
|
|
if packet:
|
|
container.mux(packet)
|
|
|
|
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
|
|
img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW
|
|
img_downsampled = auto_downsample_height_width(img_chw)
|
|
# Reshape CHW to (H*W, C) for per-channel stats
|
|
channels = img_downsampled.shape[0]
|
|
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
|
|
stats_tracker.update(img_for_stats)
|
|
|
|
frame_count += 1
|
|
|
|
# Flush encoder
|
|
if output_stream is not None:
|
|
packet = output_stream.encode()
|
|
if packet:
|
|
container.mux(packet)
|
|
|
|
if container is not None:
|
|
container.close()
|
|
|
|
av.logging.restore_default_callback()
|
|
|
|
# Get stats and put on result queue
|
|
if frame_count >= 2:
|
|
stats = stats_tracker.get_statistics()
|
|
self.result_queue.put(("ok", stats))
|
|
else:
|
|
self.result_queue.put(("ok", None))
|
|
|
|
except Exception as e:
|
|
logging.error(f"Encoder thread error: {e}")
|
|
if container is not None:
|
|
with contextlib.suppress(Exception):
|
|
container.close()
|
|
self.result_queue.put(("error", str(e)))
|
|
|
|
|
|
class StreamingVideoEncoder:
|
|
"""Manages per-camera encoder threads for real-time video encoding during recording.
|
|
|
|
Instead of writing frames as PNG images and then encoding to MP4 at episode end,
|
|
this class streams frames directly to encoder threads, eliminating the
|
|
PNG round-trip and making save_episode() near-instant.
|
|
|
|
Uses threading instead of multiprocessing to avoid the overhead of pickling large
|
|
numpy arrays through multiprocessing.Queue. PyAV's encode() releases the GIL,
|
|
so encoding runs in parallel with the main recording loop.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
fps: int,
|
|
vcodec: str = "libsvtav1",
|
|
pix_fmt: str = "yuv420p",
|
|
g: int | None = 2,
|
|
crf: int | None = 30,
|
|
preset: int | None = None,
|
|
queue_maxsize: int = 30,
|
|
encoder_threads: int | None = None,
|
|
):
|
|
self.fps = fps
|
|
self.vcodec = resolve_vcodec(vcodec)
|
|
self.pix_fmt = pix_fmt
|
|
self.g = g
|
|
self.crf = crf
|
|
self.preset = preset
|
|
self.queue_maxsize = queue_maxsize
|
|
self.encoder_threads = encoder_threads
|
|
|
|
self._frame_queues: dict[str, queue.Queue] = {}
|
|
self._result_queues: dict[str, queue.Queue] = {}
|
|
self._threads: dict[str, _CameraEncoderThread] = {}
|
|
self._stop_events: dict[str, threading.Event] = {}
|
|
self._video_paths: dict[str, Path] = {}
|
|
self._dropped_frames: dict[str, int] = {}
|
|
self._episode_active = False
|
|
|
|
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
|
|
"""Start encoder threads for a new episode.
|
|
|
|
Args:
|
|
video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
|
|
temp_dir: Base directory for temporary MP4 files
|
|
"""
|
|
if self._episode_active:
|
|
self.cancel_episode()
|
|
|
|
self._dropped_frames.clear()
|
|
|
|
for video_key in video_keys:
|
|
frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize)
|
|
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
|
stop_event = threading.Event()
|
|
|
|
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
|
|
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
|
|
|
encoder_thread = _CameraEncoderThread(
|
|
video_path=video_path,
|
|
fps=self.fps,
|
|
vcodec=self.vcodec,
|
|
pix_fmt=self.pix_fmt,
|
|
g=self.g,
|
|
crf=self.crf,
|
|
preset=self.preset,
|
|
frame_queue=frame_queue,
|
|
result_queue=result_queue,
|
|
stop_event=stop_event,
|
|
encoder_threads=self.encoder_threads,
|
|
)
|
|
encoder_thread.start()
|
|
|
|
self._frame_queues[video_key] = frame_queue
|
|
self._result_queues[video_key] = result_queue
|
|
self._threads[video_key] = encoder_thread
|
|
self._stop_events[video_key] = stop_event
|
|
self._video_paths[video_key] = video_path
|
|
|
|
self._episode_active = True
|
|
|
|
def feed_frame(self, video_key: str, image: np.ndarray) -> None:
|
|
"""Feed a frame to the encoder for a specific camera.
|
|
|
|
A copy of the image is made before enqueueing to prevent race conditions
|
|
with camera drivers that may reuse buffers. If the encoder queue is full
|
|
(encoder can't keep up), the frame is dropped with a warning instead of
|
|
crashing the recording session.
|
|
|
|
Args:
|
|
video_key: The video feature key
|
|
image: numpy array in (H,W,C) or (C,H,W) format, uint8 or float
|
|
|
|
Raises:
|
|
RuntimeError: If the encoder thread has crashed
|
|
"""
|
|
if not self._episode_active:
|
|
raise RuntimeError("No active episode. Call start_episode() first.")
|
|
|
|
thread = self._threads[video_key]
|
|
if not thread.is_alive():
|
|
# Check for error
|
|
try:
|
|
status, msg = self._result_queues[video_key].get_nowait()
|
|
if status == "error":
|
|
raise RuntimeError(f"Encoder thread for {video_key} crashed: {msg}")
|
|
except queue.Empty:
|
|
pass
|
|
raise RuntimeError(f"Encoder thread for {video_key} is not alive")
|
|
|
|
try:
|
|
self._frame_queues[video_key].put(image.copy(), timeout=0.1)
|
|
except queue.Full:
|
|
self._dropped_frames[video_key] = self._dropped_frames.get(video_key, 0) + 1
|
|
count = self._dropped_frames[video_key]
|
|
# Log periodically to avoid spam (1st, then every 10th)
|
|
if count == 1 or count % 10 == 0:
|
|
logging.warning(
|
|
f"Encoder queue full for {video_key}, dropped {count} frame(s). "
|
|
f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize."
|
|
)
|
|
|
|
def finish_episode(self) -> dict[str, tuple[Path, dict | None]]:
|
|
"""Finish encoding the current episode.
|
|
|
|
Sends sentinel values, waits for encoder threads to complete,
|
|
and collects results.
|
|
|
|
Returns:
|
|
Dict mapping video_key to (mp4_path, stats_dict_or_None)
|
|
"""
|
|
if not self._episode_active:
|
|
raise RuntimeError("No active episode to finish.")
|
|
|
|
results = {}
|
|
|
|
# Report dropped frames
|
|
for video_key, count in self._dropped_frames.items():
|
|
if count > 0:
|
|
logging.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
|
|
|
|
# Send sentinel to all queues
|
|
for video_key in self._frame_queues:
|
|
self._frame_queues[video_key].put(None)
|
|
|
|
# Wait for all threads and collect results
|
|
for video_key in self._threads:
|
|
self._threads[video_key].join(timeout=120)
|
|
if self._threads[video_key].is_alive():
|
|
logging.error(f"Encoder thread for {video_key} did not finish in time")
|
|
self._stop_events[video_key].set()
|
|
self._threads[video_key].join(timeout=5)
|
|
results[video_key] = (self._video_paths[video_key], None)
|
|
continue
|
|
|
|
try:
|
|
status, data = self._result_queues[video_key].get(timeout=5)
|
|
if status == "error":
|
|
raise RuntimeError(f"Encoder thread for {video_key} failed: {data}")
|
|
results[video_key] = (self._video_paths[video_key], data)
|
|
except queue.Empty:
|
|
logging.error(f"No result from encoder thread for {video_key}")
|
|
results[video_key] = (self._video_paths[video_key], None)
|
|
|
|
self._cleanup()
|
|
self._episode_active = False
|
|
return results
|
|
|
|
def cancel_episode(self) -> None:
    """Abort the active episode: stop the encoder threads and discard their output.

    Does nothing when no episode is active. Otherwise every stop event is
    set, each encoder thread is joined for up to 5 seconds, the directory
    holding that thread's temporary MP4 is removed if the file exists, and
    the internal bookkeeping is reset.
    """
    if self._episode_active:
        # Ask every encoder thread to stop before waiting on any of them.
        for stop_event in self._stop_events.values():
            stop_event.set()

        for video_key, worker in self._threads.items():
            worker.join(timeout=5)

            # Remove the whole parent directory of the temp MP4, not just the file.
            video_path = self._video_paths.get(video_key)
            if video_path is None or not video_path.exists():
                continue
            shutil.rmtree(str(video_path.parent), ignore_errors=True)

        self._cleanup()
        self._episode_active = False
|
|
|
|
def close(self) -> None:
    """Shut the encoder down, cancelling the episode if one is still in progress."""
    if not self._episode_active:
        return
    self.cancel_episode()
|
|
|
|
def _cleanup(self) -> None:
|
|
"""Clean up queues and thread tracking dicts."""
|
|
for q in self._frame_queues.values():
|
|
with contextlib.suppress(Exception):
|
|
while not q.empty():
|
|
q.get_nowait()
|
|
self._frame_queues.clear()
|
|
self._result_queues.clear()
|
|
self._threads.clear()
|
|
self._stop_events.clear()
|
|
self._video_paths.clear()
|
|
|
|
|
|
@dataclass
class VideoFrame:
    # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
    """
    Provides a type for a dataset containing video frames.

    Each value is stored as an Arrow struct with a string `path` and a
    float32 `timestamp`. Calling an instance returns that Arrow type.

    Example:

    ```python
    data_dict = {"image": [{"path": "videos/episode_0.mp4", "timestamp": 0.3}]}
    features = {"image": VideoFrame()}
    Dataset.from_dict(data_dict, features=Features(features))
    ```
    """

    # Arrow storage type, declared as a ClassVar so it is not a dataclass field.
    pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()})
    # Type tag fixed to "VideoFrame"; not settable through __init__ and hidden from repr.
    _type: str = field(default="VideoFrame", init=False, repr=False)

    def __call__(self):
        """Return the Arrow struct type backing this feature."""
        return self.pa_type
|
|
|
|
|
|
# Register VideoFrame with HuggingFace `datasets`, silencing only the
# "experimental API" UserWarning that `register_feature` emits.
with warnings.catch_warnings():
    warnings.filterwarnings(
        action="ignore",
        message="'register_feature' is experimental and might be subject to breaking changes in the future.",
        category=UserWarning,
    )
    register_feature(VideoFrame, "VideoFrame")
|
|
|
|
|
|
def get_audio_info(video_path: Path | str) -> dict:
    """Describe the first audio stream of a media file.

    Args:
        video_path: Path to the media file to inspect.

    Returns:
        ``{"has_audio": False}`` when the file has no audio stream. Otherwise
        a dict with the keys ``audio.channels``, ``audio.codec``,
        ``audio.bit_rate``, ``audio.sample_rate``, ``audio.bit_depth``,
        ``audio.channel_layout`` and ``has_audio`` (True).
    """
    # Set logging level
    logging.getLogger("libav").setLevel(av.logging.WARNING)

    try:
        with av.open(str(video_path), "r") as audio_file:
            try:
                audio_stream = audio_file.streams.audio[0]
            except IndexError:
                return {"has_audio": False}

            return {
                "audio.channels": audio_stream.channels,
                "audio.codec": audio_stream.codec.canonical_name,
                # In an ideal lossless case: bit depth x sample rate x channels = bit rate.
                # In an actual compressed case, the bit rate is set according to the
                # compression level: the lower the bit rate, the more compression is applied.
                "audio.bit_rate": audio_stream.bit_rate,
                # Number of samples per second
                "audio.sample_rate": audio_stream.sample_rate,
                # In an ideal lossless case: fixed number of bits per sample.
                # In an actual compressed case: variable number of bits per sample
                # (often reduced to match a given depth rate).
                "audio.bit_depth": audio_stream.format.bits,
                "audio.channel_layout": audio_stream.layout.name,
                "has_audio": True,
            }
    finally:
        # Reset the PyAV logging callback on every exit path, including when
        # opening the file or reading a stream attribute raises.
        av.logging.restore_default_callback()
|
|
|
|
|
|
def get_video_info(video_path: Path | str) -> dict:
    """Describe the first video stream of a file, plus its audio stream if any.

    Args:
        video_path: Path to the video file to inspect.

    Returns:
        An empty dict when the file has no video stream. Otherwise a dict with
        the keys ``video.height``, ``video.width``, ``video.codec``,
        ``video.pix_fmt``, ``video.is_depth_map`` (always False),
        ``video.fps`` and ``video.channels``, merged with the result of
        ``get_audio_info`` for the same file.

    Raises:
        ValueError: If the stream's pixel format is not recognised by
            ``get_video_pixel_channels``.
    """
    # Set logging level
    logging.getLogger("libav").setLevel(av.logging.WARNING)

    try:
        with av.open(str(video_path), "r") as video_file:
            try:
                video_stream = video_file.streams.video[0]
            except IndexError:
                return {}

            video_info = {
                "video.height": video_stream.height,
                "video.width": video_stream.width,
                "video.codec": video_stream.codec.canonical_name,
                "video.pix_fmt": video_stream.pix_fmt,
                "video.is_depth_map": False,
                # Calculate fps from r_frame_rate.
                # NOTE(review): int() truncates, so a fractional rate such as
                # 30000/1001 is stored as 29 - confirm this is intended.
                "video.fps": int(video_stream.base_rate),
                "video.channels": get_video_pixel_channels(video_stream.pix_fmt),
            }
    finally:
        # Reset the PyAV logging callback on every exit path, including when
        # opening the file fails or the pixel format is rejected.
        av.logging.restore_default_callback()

    # Adding audio stream information
    video_info.update(get_audio_info(video_path))

    return video_info
|
|
|
|
|
|
def get_video_pixel_channels(pix_fmt: str) -> int:
    """Return the number of channels implied by a pixel format name.

    The name is matched by substring against an ordered list of rules, so the
    four-channel markers ("rgba", "yuva") are tested before the three-channel
    markers they contain ("rgb", "yuv").

    Args:
        pix_fmt: Pixel format name, e.g. "yuv420p" or "gray".

    Returns:
        1, 4 or 3 depending on the first rule that matches.

    Raises:
        ValueError: If no rule matches the pixel format name.
    """
    channel_rules = (
        (("gray", "depth", "monochrome"), 1),
        (("rgba", "yuva"), 4),
        (("rgb", "yuv"), 3),
    )
    for markers, channels in channel_rules:
        if any(marker in pix_fmt for marker in markers):
            return channels
    raise ValueError("Unknown format")
|
|
|
|
|
|
def get_video_duration_in_s(video_path: Path | str) -> float:
    """
    Return the duration, in seconds, of a video file opened with PyAV.

    The duration of the first video stream is used when the stream reports
    one; otherwise the container-level duration is used instead.

    Args:
        video_path: Path to the video file.

    Returns:
        Duration of the video in seconds.
    """
    with av.open(str(video_path)) as container:
        first_video = container.streams.video[0]

        if first_video.duration is None:
            # The stream carries no duration: use the container's, which is
            # expressed in av.time_base units.
            return float(container.duration / av.time_base)

        # stream.duration * stream.time_base gives the duration in seconds.
        return float(first_video.duration * first_video.time_base)
|
|
|
|
|
|
class VideoEncodingManager:
    """
    Context manager that finishes video encoding and tidies up recording data on exit,
    whether or not the body raised.

    On exit it:
    - Cancels and closes the streaming encoder when the dataset has one, or otherwise
      batch-encodes the episodes that have not been encoded yet
    - Finalizes the dataset
    - Removes the temporary images of the interrupted episode (non-streaming mode only)
    - Removes the images directory when it no longer holds any PNG file

    Exceptions raised inside the ``with`` body are never suppressed.

    Args:
        dataset: The LeRobotDataset instance
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        interrupted = exc_type is not None
        streaming_encoder = getattr(self.dataset, "_streaming_encoder", None)

        if streaming_encoder is not None:
            self._shutdown_streaming_encoder(streaming_encoder, interrupted)
        elif self.dataset.episodes_since_last_encoding > 0:
            self._encode_pending_episodes(interrupted)

        # Finalize the dataset to properly close all writers
        self.dataset.finalize()

        # Per-episode image cleanup only applies to non-streaming mode.
        if interrupted and streaming_encoder is None:
            self._remove_interrupted_episode_images()

        self._prune_images_dir()

        return False  # Don't suppress the original exception

    @staticmethod
    def _shutdown_streaming_encoder(streaming_encoder, interrupted):
        """Cancel the in-flight episode after a failure, then close the streaming encoder."""
        if interrupted:
            streaming_encoder.cancel_episode()
        streaming_encoder.close()

    def _encode_pending_episodes(self, interrupted):
        """Batch-encode the episodes recorded since the last encoding pass."""
        if interrupted:
            logging.info("Exception occurred. Encoding remaining episodes before exit...")
        else:
            logging.info("Recording stopped. Encoding remaining episodes...")

        dataset = self.dataset
        start_ep = dataset.num_episodes - dataset.episodes_since_last_encoding
        end_ep = dataset.num_episodes
        logging.info(
            f"Encoding remaining {dataset.episodes_since_last_encoding} episodes, "
            f"from episode {start_ep} to {end_ep - 1}"
        )
        dataset._batch_save_episode_video(start_ep, end_ep)

    def _remove_interrupted_episode_images(self):
        """Delete the per-camera image directories of the episode that was interrupted."""
        interrupted_episode_index = self.dataset.num_episodes
        for key in self.dataset.meta.video_keys:
            first_frame = self.dataset._get_image_file_path(
                episode_index=interrupted_episode_index, image_key=key, frame_index=0
            )
            episode_dir = first_frame.parent
            if not episode_dir.exists():
                continue
            logging.debug(
                f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
            )
            shutil.rmtree(episode_dir)

    def _prune_images_dir(self):
        """Remove the dataset's images directory when it contains no PNG file."""
        images_root = self.dataset.root / "images"
        if not images_root.exists():
            return

        png_files = list(images_root.rglob("*.png"))
        if png_files:
            logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
        else:
            shutil.rmtree(images_root)
            logging.debug("Cleaned up empty images directory")
|