mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 12:21:27 +00:00
1090 lines
41 KiB
Python
1090 lines
41 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import contextlib
|
|
import glob
|
|
import importlib
|
|
import logging
|
|
import queue
|
|
import shutil
|
|
import tempfile
|
|
import threading
|
|
import warnings
|
|
from dataclasses import dataclass, field
|
|
from fractions import Fraction
|
|
from pathlib import Path
|
|
from threading import Lock
|
|
from typing import Any, ClassVar
|
|
|
|
import av
|
|
import fsspec
|
|
import numpy as np
|
|
import pyarrow as pa
|
|
import torch
|
|
import torchvision
|
|
from datasets.features.features import register_feature
|
|
from PIL import Image
|
|
|
|
from lerobot.utils.import_utils import get_safe_default_codec
|
|
|
|
logger = logging.getLogger(__name__)

# Hardware encoders that are probed when vcodec="auto" is requested. Whether an entry is
# actually usable depends on the platform and on how FFmpeg was built.
# The list order is the order of preference for auto-selection.
HW_ENCODERS = [
    "h264_videotoolbox",  # macOS
    "hevc_videotoolbox",  # macOS
    "h264_nvenc",  # NVIDIA GPU
    "hevc_nvenc",  # NVIDIA GPU
    "h264_vaapi",  # Linux Intel/AMD
    "h264_qsv",  # Intel Quick Sync
]

# Every value accepted for `vcodec`: the software codecs, "auto", and each hardware encoder.
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto", *HW_ENCODERS}
|
|
|
|
|
|
def _get_codec_options(
|
|
vcodec: str,
|
|
g: int | None = 2,
|
|
crf: int | None = 30,
|
|
preset: int | None = None,
|
|
) -> dict:
|
|
"""Build codec-specific options dict for video encoding."""
|
|
options = {}
|
|
|
|
# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
|
|
if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS):
|
|
options["g"] = str(g)
|
|
|
|
# Quality control (codec-specific parameter names)
|
|
if crf is not None:
|
|
if vcodec in ("h264", "hevc", "libsvtav1"):
|
|
options["crf"] = str(crf)
|
|
elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
|
|
quality = max(1, min(100, int(100 - crf * 2)))
|
|
options["q:v"] = str(quality)
|
|
elif vcodec in ("h264_nvenc", "hevc_nvenc"):
|
|
options["rc"] = "constqp"
|
|
options["qp"] = str(crf)
|
|
elif vcodec in ("h264_vaapi",):
|
|
options["qp"] = str(crf)
|
|
elif vcodec in ("h264_qsv",):
|
|
options["global_quality"] = str(crf)
|
|
|
|
# Preset (only for libsvtav1)
|
|
if vcodec == "libsvtav1":
|
|
options["preset"] = str(preset) if preset is not None else "12"
|
|
|
|
return options
|
|
|
|
|
|
def detect_available_hw_encoders() -> list[str]:
    """Return the entries of ``HW_ENCODERS`` that PyAV can instantiate as encoders.

    Each candidate is probed by constructing ``av.codec.Codec(name, "w")``; any
    exception is treated as "not available" and logged at debug level. The result
    keeps the order of ``HW_ENCODERS``.
    """
    found = []
    for name in HW_ENCODERS:
        try:
            av.codec.Codec(name, "w")
        except Exception:  # nosec B110
            logger.debug("HW encoder '%s' not available", name)  # nosec B110
        else:
            found.append(name)
    return found
|
|
|
|
|
|
def resolve_vcodec(vcodec: str) -> str:
    """Validate ``vcodec`` and turn ``"auto"`` into a concrete encoder name.

    Args:
        vcodec: A member of ``VALID_VIDEO_CODECS``.

    Returns:
        ``vcodec`` itself when it is not ``"auto"``. Otherwise the first entry of
        ``HW_ENCODERS`` reported as available, or ``"libsvtav1"`` when none is.

    Raises:
        ValueError: If ``vcodec`` is not in ``VALID_VIDEO_CODECS``.
    """
    if vcodec not in VALID_VIDEO_CODECS:
        raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")

    if vcodec == "auto":
        usable = detect_available_hw_encoders()
        # HW_ENCODERS is ordered by preference, so the first usable entry wins.
        chosen = next((candidate for candidate in HW_ENCODERS if candidate in usable), None)
        if chosen is None:
            logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
            return "libsvtav1"
        logger.info(f"Auto-selected video codec: {chosen}")
        return chosen

    logger.info(f"Using video codec: {vcodec}")
    return vcodec
|
|
|
|
|
|
def decode_video_frames(
|
|
video_path: Path | str,
|
|
timestamps: list[float],
|
|
tolerance_s: float,
|
|
backend: str | None = None,
|
|
) -> torch.Tensor:
|
|
"""
|
|
Decodes video frames using the specified backend.
|
|
|
|
Args:
|
|
video_path (Path): Path to the video file.
|
|
timestamps (list[float]): List of timestamps to extract frames.
|
|
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
|
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav"..
|
|
|
|
Returns:
|
|
torch.Tensor: Decoded frames.
|
|
|
|
Currently supports torchcodec on cpu and pyav.
|
|
"""
|
|
if backend is None:
|
|
backend = get_safe_default_codec()
|
|
if backend == "torchcodec":
|
|
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
|
elif backend in ["pyav", "video_reader"]:
|
|
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
|
else:
|
|
raise ValueError(f"Unsupported video backend: {backend}")
|
|
|
|
|
|
def decode_video_frames_torchvision(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    backend: str = "pyav",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated to the requested timestamps of a video

    The backend can be either "pyav" (default) or "video_reader".
    "video_reader" requires installing torchvision from source, see:
    https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
    (note that you need to compile against ffmpeg<4.3)

    While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
    For more info on video decoding, see `benchmark/video/README.md`

    See torchvision doc for more info on these two backends:
    https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to take into account decoding time and video size in bytes.

    Args:
        video_path: Path to the video file.
        timestamps: Timestamps (in seconds) of the frames to retrieve. Must not be empty.
        tolerance_s: Maximum allowed distance, in seconds, between a queried timestamp and
            the timestamp of the frame returned for it.
        backend: torchvision video backend, "pyav" or "video_reader".
        log_loaded_timestamps: Whether to log the timestamp of every decoded frame.

    Returns:
        A float32 tensor in [0, 1] holding one frame per queried timestamp, in query order.

    Raises:
        ValueError: If `timestamps` is empty.
        FrameTimestampError: If a queried timestamp has no decoded frame within `tolerance_s`.
    """
    video_path = str(video_path)

    # set the first and last requested timestamps
    # Note: previous timestamps are usually loaded, since we need to access the previous key frame
    # Computed before the reader is opened so that an empty `timestamps` fails without leaking a reader.
    first_ts = min(timestamps)
    last_ts = max(timestamps)

    # set backend
    torchvision.set_video_backend(backend)
    keyframes_only = backend == "pyav"  # pyav doesn't support accurate seek

    # set a video stream reader
    # TODO(rcadene): also load audio stream at the same time
    reader = torchvision.io.VideoReader(video_path, "video")

    loaded_frames = []
    loaded_ts = []
    try:
        # access closest key frame of the first requested frame
        # Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
        # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
        reader.seek(first_ts, keyframes_only=keyframes_only)

        # load all frames until last requested frame
        for frame in reader:
            current_ts = frame["pts"]
            if log_loaded_timestamps:
                logger.info(f"frame loaded at timestamp={current_ts:.4f}")
            loaded_frames.append(frame["data"])
            loaded_ts.append(current_ts)
            if current_ts >= last_ts:
                break
    finally:
        # Release the container even when seeking or decoding raises.
        if backend == "pyav":
            reader.container.close()
        reader = None

    # float64 keeps the comparison against `tolerance_s` meaningful for long videos:
    # float32 spacing grows with the timestamp and eventually exceeds small tolerances.
    query_ts = torch.tensor(timestamps, dtype=torch.float64)
    loaded_ts = torch.tensor(loaded_ts, dtype=torch.float64)

    # compute distances between each query timestamp and timestamps of all loaded frames
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    if not is_within_tol.all():
        raise FrameTimestampError(
            f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
            " It means that the closest frame that can be loaded from the video is too far away in time."
            " This might be due to synchronization issues with timestamps during data collection."
            " To be safe, we advise to ignore this item during training."
            f"\nqueried timestamps: {query_ts}"
            f"\nloaded timestamps: {loaded_ts}"
            f"\nvideo: {video_path}"
            f"\nbackend: {backend}"
        )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logger.info(f"{closest_ts=}")

    # convert to the pytorch format which is float32 in [0,1] range (and channel first)
    closest_frames = closest_frames.type(torch.float32) / 255

    if len(timestamps) != len(closest_frames):
        raise FrameTimestampError(
            f"Number of retrieved frames ({len(closest_frames)}) does not match "
            f"number of queried timestamps ({len(timestamps)})"
        )
    return closest_frames
|
|
|
|
|
|
class VideoDecoderCache:
    """Thread-safe cache for video decoders to avoid expensive re-initialization."""

    def __init__(self):
        # Maps a video path to (decoder, file handle the decoder reads from).
        self._cache: dict[str, tuple[Any, Any]] = {}
        self._lock = Lock()

    def get_decoder(self, video_path: str):
        """Get a cached decoder or create a new one.

        Args:
            video_path: Path of the video; it is converted to `str` and opened through fsspec.

        Returns:
            The torchcodec `VideoDecoder` associated with `video_path`.

        Raises:
            ImportError: If `torchcodec` is not installed.
        """
        if importlib.util.find_spec("torchcodec"):
            from torchcodec.decoders import VideoDecoder
        else:
            raise ImportError(
                "'torchcodec' is required but not installed. "
                "Install it with: pip install 'lerobot[dataset]' (or uv pip install 'lerobot[dataset]')"
            )

        video_path = str(video_path)

        with self._lock:
            entry = self._cache.get(video_path)
            if entry is None:
                file_handle = fsspec.open(video_path).__enter__()
                try:
                    decoder = VideoDecoder(file_handle, seek_mode="approximate")
                except Exception:
                    # The handle is not tracked by the cache yet, so close it here
                    # instead of leaking it when the decoder cannot be created.
                    file_handle.close()
                    raise
                entry = (decoder, file_handle)
                self._cache[video_path] = entry

            return entry[0]

    def clear(self):
        """Clear the cache and close file handles.

        Entries are removed one at a time before their handle is closed, so a failing
        `close()` never leaves an already-closed handle in the cache.
        """
        with self._lock:
            while self._cache:
                _, (_, file_handle) = self._cache.popitem()
                file_handle.close()

    def size(self) -> int:
        """Return the number of cached decoders."""
        with self._lock:
            return len(self._cache)
|
|
|
|
|
|
class FrameTimestampError(ValueError):
    """Error signalling that the retrieved frame timestamps do not match the queried ones."""
|
|
|
|
|
|
# Module-level cache used by `decode_video_frames_torchcodec` when no `decoder_cache` is passed.
_default_decoder_cache = VideoDecoderCache()
|
|
|
|
|
|
def decode_video_frames_torchcodec(
|
|
video_path: Path | str,
|
|
timestamps: list[float],
|
|
tolerance_s: float,
|
|
log_loaded_timestamps: bool = False,
|
|
decoder_cache: VideoDecoderCache | None = None,
|
|
) -> torch.Tensor:
|
|
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
|
|
|
Args:
|
|
video_path: Path to the video file.
|
|
timestamps: List of timestamps to extract frames.
|
|
tolerance_s: Allowed deviation in seconds for frame retrieval.
|
|
log_loaded_timestamps: Whether to log loaded timestamps.
|
|
decoder_cache: Optional decoder cache instance. Uses default if None.
|
|
|
|
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
|
|
|
|
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
|
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
|
|
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
|
|
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
|
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
|
"""
|
|
if decoder_cache is None:
|
|
decoder_cache = _default_decoder_cache
|
|
|
|
# Use cached decoder instead of creating new one each time
|
|
decoder = decoder_cache.get_decoder(str(video_path))
|
|
|
|
loaded_ts = []
|
|
loaded_frames = []
|
|
|
|
# get metadata for frame information
|
|
metadata = decoder.metadata
|
|
average_fps = metadata.average_fps
|
|
# convert timestamps to frame indices
|
|
frame_indices = [round(ts * average_fps) for ts in timestamps]
|
|
# retrieve frames based on indices
|
|
frames_batch = decoder.get_frames_at(indices=frame_indices)
|
|
|
|
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=True):
|
|
loaded_frames.append(frame)
|
|
loaded_ts.append(pts.item())
|
|
if log_loaded_timestamps:
|
|
logger.info(f"Frame loaded at timestamp={pts:.4f}")
|
|
|
|
query_ts = torch.tensor(timestamps)
|
|
loaded_ts = torch.tensor(loaded_ts)
|
|
|
|
# compute distances between each query timestamp and loaded timestamps
|
|
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
|
|
min_, argmin_ = dist.min(1)
|
|
|
|
is_within_tol = min_ < tolerance_s
|
|
if not is_within_tol.all():
|
|
raise FrameTimestampError(
|
|
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
|
" It means that the closest frame that can be loaded from the video is too far away in time."
|
|
" This might be due to synchronization issues with timestamps during data collection."
|
|
" To be safe, we advise to ignore this item during training."
|
|
f"\nqueried timestamps: {query_ts}"
|
|
f"\nloaded timestamps: {loaded_ts}"
|
|
f"\nvideo: {video_path}"
|
|
)
|
|
|
|
# get closest frames to the query timestamps
|
|
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
|
closest_ts = loaded_ts[argmin_]
|
|
|
|
if log_loaded_timestamps:
|
|
logger.info(f"{closest_ts=}")
|
|
|
|
# convert to float32 in [0,1] range
|
|
closest_frames = (closest_frames / 255.0).type(torch.float32)
|
|
|
|
if not len(timestamps) == len(closest_frames):
|
|
raise FrameTimestampError(
|
|
f"Retrieved timestamps differ from queried {set(closest_frames) - set(timestamps)}"
|
|
)
|
|
|
|
return closest_frames
|
|
|
|
|
|
def encode_video_frames(
    imgs_dir: Path | str,
    video_path: Path | str,
    fps: int,
    vcodec: str = "libsvtav1",
    pix_fmt: str = "yuv420p",
    g: int | None = 2,
    crf: int | None = 30,
    fast_decode: int = 0,
    log_level: int | None = av.logging.WARNING,
    overwrite: bool = False,
    preset: int | None = None,
    encoder_threads: int | None = None,
) -> None:
    """Encode the `frame-XXXXXX.png` images of a directory into a video file.

    More info on ffmpeg arguments tuning on `benchmark/video/README.md`

    Args:
        imgs_dir: Directory holding the input images, named `frame-` followed by six digits and `.png`.
        video_path: Output video file. Parent directories are created when missing.
        fps: Frame rate passed to the output stream.
        vcodec: Encoder name, validated and resolved by `resolve_vcodec`.
        pix_fmt: Pixel format of the output stream. "yuv444p" is replaced by "yuv420p"
            for "libsvtav1" and "hevc".
        g: GOP size forwarded to `_get_codec_options`.
        crf: Quality value forwarded to `_get_codec_options`.
        fast_decode: When non-zero, adds a fast-decode option ("svtav1-params" for
            libsvtav1, "tune" otherwise).
        log_level: Level applied to the "libav" logger; None leaves logging untouched.
        overwrite: When False and `video_path` exists, encoding is skipped with a warning.
        preset: Preset forwarded to `_get_codec_options`.
        encoder_threads: Thread count passed to the encoder; None leaves it unset.

    Raises:
        FileNotFoundError: If no matching image is found in `imgs_dir`.
        OSError: If the output file does not exist after encoding.
    """
    vcodec = resolve_vcodec(vcodec)
    is_svtav1 = vcodec == "libsvtav1"

    video_path = Path(video_path)
    imgs_dir = Path(imgs_dir)

    if video_path.exists() and not overwrite:
        logger.warning(f"Video file already exists: {video_path}. Skipping encoding.")
        return

    video_path.parent.mkdir(parents=True, exist_ok=True)

    # Encoders/pixel formats incompatibility check
    if pix_fmt == "yuv444p" and vcodec in ("libsvtav1", "hevc"):
        logger.warning(
            f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
        )
        pix_fmt = "yuv420p"

    # Collect the input frames, ordered by the number embedded in their file name.
    def frame_number(image_path: str) -> int:
        return int(image_path.split("-")[-1].split(".")[0])

    pattern = "frame-" + ("[0-9]" * 6) + ".png"
    input_list = sorted(glob.glob(str(imgs_dir / pattern)), key=frame_number)
    if not input_list:
        raise FileNotFoundError(f"No images found in {imgs_dir}.")

    # The output frame size is taken from the first image (all inputs are assumed to share it).
    with Image.open(input_list[0]) as first_image:
        width, height = first_image.size

    # Define video codec options
    video_options = _get_codec_options(vcodec, g, crf, preset)

    if fast_decode:
        if is_svtav1:
            video_options["svtav1-params"] = f"fast-decode={fast_decode}"
        else:
            video_options["tune"] = "fastdecode"

    if encoder_threads is not None:
        if is_svtav1:
            # libsvtav1 takes its thread count through svtav1-params, ':'-joined with existing params.
            lp_param = f"lp={encoder_threads}"
            current = video_options.get("svtav1-params")
            video_options["svtav1-params"] = lp_param if current is None else f"{current}:{lp_param}"
        else:
            video_options["threads"] = str(encoder_threads)

    adjust_logging = log_level is not None
    if adjust_logging:
        # "While less efficient, it is generally preferable to modify logging with Python's logging"
        logging.getLogger("libav").setLevel(log_level)

    # Create and open output file (overwrite by default)
    with av.open(str(video_path), "w") as output:
        output_stream = output.add_stream(vcodec, fps, options=video_options)
        output_stream.pix_fmt = pix_fmt
        output_stream.width = width
        output_stream.height = height

        # Encode the input images one by one
        for image_path in input_list:
            with Image.open(image_path) as image:
                rgb_image = image.convert("RGB")
                packet = output_stream.encode(av.VideoFrame.from_image(rgb_image))
                if packet:
                    output.mux(packet)

        # Flush the encoder
        packet = output_stream.encode()
        if packet:
            output.mux(packet)

    # Reset logging level
    if adjust_logging:
        av.logging.restore_default_callback()

    if not video_path.exists():
        raise OSError(f"Video encoding did not work. File not found: {video_path}.")
|
|
|
|
|
|
def concatenate_video_files(
    input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
):
    """
    Concatenate multiple video files into a single video file using pyav.

    This function takes a list of video input file paths and concatenates them into a single
    output video file. It uses ffmpeg's concat demuxer with stream copy mode for fast
    concatenation without re-encoding.

    Args:
        input_video_paths: Ordered list of input video file paths (`Path` or `str`) to concatenate.
        output_video_path: Path to the output video file.
        overwrite: Whether to overwrite the output video file if it already exists. Default is True.

    Raises:
        FileNotFoundError: If `input_video_paths` is empty.

    Note:
        - Creates temporary files (the ffconcat list and the intermediate MP4) that are
          removed afterwards, including when concatenation fails.
        - Uses ffmpeg's concat demuxer which requires all input videos to have the same
          codec, resolution, and frame rate for proper concatenation.
    """

    output_video_path = Path(output_video_path)

    if output_video_path.exists() and not overwrite:
        logger.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
        return

    output_video_path.parent.mkdir(parents=True, exist_ok=True)

    if len(input_video_paths) == 0:
        raise FileNotFoundError("No input video paths provided.")

    tmp_concatenate_path = None
    tmp_output_video_path = None
    try:
        # Create a temporary .ffconcat file to list the input video paths
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".ffconcat", delete=False
        ) as tmp_concatenate_file:
            tmp_concatenate_path = tmp_concatenate_file.name
            tmp_concatenate_file.write("ffconcat version 1.0\n")
            for input_path in input_video_paths:
                # Inputs may be given as str, so normalise to Path before resolving.
                resolved = str(Path(input_path).resolve())
                # A single quote cannot appear inside a single-quoted ffconcat string:
                # close the quote, emit an escaped quote, then reopen it.
                escaped = resolved.replace("'", "'\\''")
                tmp_concatenate_file.write(f"file '{escaped}'\n")

        # Reserve a temporary path for the output; it is moved into place once complete.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
            tmp_output_video_path = tmp_named_file.name

        # safe = 0 allows absolute paths as well as relative paths
        with av.open(
            tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
        ) as input_container:
            # faststart is to move the metadata to the beginning of the file to speed up loading
            with av.open(
                tmp_output_video_path, mode="w", options={"movflags": "faststart"}
            ) as output_container:
                # Replicate input streams in output container
                stream_map = {}
                for input_stream in input_container.streams:
                    if input_stream.type in ("video", "audio", "subtitle"):  # only copy compatible streams
                        output_stream = output_container.add_stream_from_template(
                            template=input_stream, opaque=True
                        )
                        # set the time base to the input stream time base (missing in the codec context)
                        output_stream.time_base = input_stream.time_base
                        stream_map[input_stream.index] = output_stream

                # Demux + remux packets (no re-encode)
                for packet in input_container.demux():
                    # Skip packets from un-mapped streams
                    if packet.stream.index not in stream_map:
                        continue

                    # Skip demux flushing packets
                    if packet.dts is None:
                        continue

                    packet.stream = stream_map[packet.stream.index]
                    output_container.mux(packet)

        shutil.move(tmp_output_video_path, output_video_path)
    finally:
        # Remove whatever temporary file is still around (the MP4 is gone after a successful move).
        for leftover in (tmp_concatenate_path, tmp_output_video_path):
            if leftover is not None:
                Path(leftover).unlink(missing_ok=True)
|
|
|
|
|
|
class _CameraEncoderThread(threading.Thread):
|
|
"""A thread that encodes video frames streamed via a queue into an MP4 file.
|
|
|
|
One instance is created per camera per episode. Frames are received as numpy arrays
|
|
from the main thread, encoded in real-time using PyAV (which releases the GIL during
|
|
encoding), and written to disk. Stats are computed incrementally using
|
|
RunningQuantileStats and returned via result_queue.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
video_path: Path,
|
|
fps: int,
|
|
vcodec: str,
|
|
pix_fmt: str,
|
|
g: int | None,
|
|
crf: int | None,
|
|
preset: int | None,
|
|
frame_queue: queue.Queue,
|
|
result_queue: queue.Queue,
|
|
stop_event: threading.Event,
|
|
encoder_threads: int | None = None,
|
|
):
|
|
super().__init__(daemon=True)
|
|
self.video_path = video_path
|
|
self.fps = fps
|
|
self.vcodec = vcodec
|
|
self.pix_fmt = pix_fmt
|
|
self.g = g
|
|
self.crf = crf
|
|
self.preset = preset
|
|
self.frame_queue = frame_queue
|
|
self.result_queue = result_queue
|
|
self.stop_event = stop_event
|
|
self.encoder_threads = encoder_threads
|
|
|
|
def run(self) -> None:
|
|
from .compute_stats import RunningQuantileStats, auto_downsample_height_width
|
|
|
|
container = None
|
|
output_stream = None
|
|
stats_tracker = RunningQuantileStats()
|
|
frame_count = 0
|
|
|
|
try:
|
|
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
|
|
|
while True:
|
|
try:
|
|
frame_data = self.frame_queue.get(timeout=1)
|
|
except queue.Empty:
|
|
if self.stop_event.is_set():
|
|
break
|
|
continue
|
|
|
|
if frame_data is None:
|
|
# Sentinel: flush and close
|
|
break
|
|
|
|
# Ensure HWC uint8 numpy array
|
|
if isinstance(frame_data, np.ndarray):
|
|
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
|
|
# CHW -> HWC
|
|
frame_data = frame_data.transpose(1, 2, 0)
|
|
if frame_data.dtype != np.uint8:
|
|
frame_data = (frame_data * 255).astype(np.uint8)
|
|
|
|
# Open container on first frame (to get width/height)
|
|
if container is None:
|
|
height, width = frame_data.shape[:2]
|
|
video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset)
|
|
if self.encoder_threads is not None:
|
|
if self.vcodec == "libsvtav1":
|
|
lp_param = f"lp={self.encoder_threads}"
|
|
if "svtav1-params" in video_options:
|
|
video_options["svtav1-params"] += f":{lp_param}"
|
|
else:
|
|
video_options["svtav1-params"] = lp_param
|
|
else:
|
|
video_options["threads"] = str(self.encoder_threads)
|
|
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
|
container = av.open(str(self.video_path), "w")
|
|
output_stream = container.add_stream(self.vcodec, self.fps, options=video_options)
|
|
output_stream.pix_fmt = self.pix_fmt
|
|
output_stream.width = width
|
|
output_stream.height = height
|
|
output_stream.time_base = Fraction(1, self.fps)
|
|
|
|
# Encode frame with explicit timestamps
|
|
pil_img = Image.fromarray(frame_data)
|
|
video_frame = av.VideoFrame.from_image(pil_img)
|
|
video_frame.pts = frame_count
|
|
video_frame.time_base = Fraction(1, self.fps)
|
|
packet = output_stream.encode(video_frame)
|
|
if packet:
|
|
container.mux(packet)
|
|
|
|
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
|
|
img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW
|
|
img_downsampled = auto_downsample_height_width(img_chw)
|
|
# Reshape CHW to (H*W, C) for per-channel stats
|
|
channels = img_downsampled.shape[0]
|
|
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
|
|
stats_tracker.update(img_for_stats)
|
|
|
|
frame_count += 1
|
|
|
|
# Flush encoder
|
|
if output_stream is not None:
|
|
packet = output_stream.encode()
|
|
if packet:
|
|
container.mux(packet)
|
|
|
|
if container is not None:
|
|
container.close()
|
|
|
|
av.logging.restore_default_callback()
|
|
|
|
# Get stats and put on result queue
|
|
if frame_count >= 2:
|
|
stats = stats_tracker.get_statistics()
|
|
self.result_queue.put(("ok", stats))
|
|
else:
|
|
self.result_queue.put(("ok", None))
|
|
|
|
except Exception as e:
|
|
logger.error(f"Encoder thread error: {e}")
|
|
if container is not None:
|
|
with contextlib.suppress(Exception):
|
|
container.close()
|
|
self.result_queue.put(("error", str(e)))
|
|
|
|
|
|
class StreamingVideoEncoder:
|
|
"""Manages per-camera encoder threads for real-time video encoding during recording.
|
|
|
|
Instead of writing frames as PNG images and then encoding to MP4 at episode end,
|
|
this class streams frames directly to encoder threads, eliminating the
|
|
PNG round-trip and making save_episode() near-instant.
|
|
|
|
Uses threading instead of multiprocessing to avoid the overhead of pickling large
|
|
numpy arrays through multiprocessing.Queue. PyAV's encode() releases the GIL,
|
|
so encoding runs in parallel with the main recording loop.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
fps: int,
|
|
vcodec: str = "libsvtav1",
|
|
pix_fmt: str = "yuv420p",
|
|
g: int | None = 2,
|
|
crf: int | None = 30,
|
|
preset: int | None = None,
|
|
queue_maxsize: int = 30,
|
|
encoder_threads: int | None = None,
|
|
):
|
|
self.fps = fps
|
|
self.vcodec = resolve_vcodec(vcodec)
|
|
self.pix_fmt = pix_fmt
|
|
self.g = g
|
|
self.crf = crf
|
|
self.preset = preset
|
|
self.queue_maxsize = queue_maxsize
|
|
self.encoder_threads = encoder_threads
|
|
|
|
self._frame_queues: dict[str, queue.Queue] = {}
|
|
self._result_queues: dict[str, queue.Queue] = {}
|
|
self._threads: dict[str, _CameraEncoderThread] = {}
|
|
self._stop_events: dict[str, threading.Event] = {}
|
|
self._video_paths: dict[str, Path] = {}
|
|
self._dropped_frames: dict[str, int] = {}
|
|
self._episode_active = False
|
|
self._closed = False
|
|
|
|
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
|
|
"""Start encoder threads for a new episode.
|
|
|
|
Args:
|
|
video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
|
|
temp_dir: Base directory for temporary MP4 files
|
|
"""
|
|
if self._episode_active:
|
|
self.cancel_episode()
|
|
|
|
self._dropped_frames.clear()
|
|
|
|
for video_key in video_keys:
|
|
frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize)
|
|
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
|
stop_event = threading.Event()
|
|
|
|
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
|
|
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
|
|
|
encoder_thread = _CameraEncoderThread(
|
|
video_path=video_path,
|
|
fps=self.fps,
|
|
vcodec=self.vcodec,
|
|
pix_fmt=self.pix_fmt,
|
|
g=self.g,
|
|
crf=self.crf,
|
|
preset=self.preset,
|
|
frame_queue=frame_queue,
|
|
result_queue=result_queue,
|
|
stop_event=stop_event,
|
|
encoder_threads=self.encoder_threads,
|
|
)
|
|
encoder_thread.start()
|
|
|
|
self._frame_queues[video_key] = frame_queue
|
|
self._result_queues[video_key] = result_queue
|
|
self._threads[video_key] = encoder_thread
|
|
self._stop_events[video_key] = stop_event
|
|
self._video_paths[video_key] = video_path
|
|
|
|
self._episode_active = True
|
|
|
|
def feed_frame(self, video_key: str, image: np.ndarray) -> None:
|
|
"""Feed a frame to the encoder for a specific camera.
|
|
|
|
A copy of the image is made before enqueueing to prevent race conditions
|
|
with camera drivers that may reuse buffers. If the encoder queue is full
|
|
(encoder can't keep up), the frame is dropped with a warning instead of
|
|
crashing the recording session.
|
|
|
|
Args:
|
|
video_key: The video feature key
|
|
image: numpy array in (H,W,C) or (C,H,W) format, uint8 or float
|
|
|
|
Raises:
|
|
RuntimeError: If the encoder thread has crashed
|
|
"""
|
|
if not self._episode_active:
|
|
raise RuntimeError("No active episode. Call start_episode() first.")
|
|
|
|
thread = self._threads[video_key]
|
|
if not thread.is_alive():
|
|
# Check for error
|
|
try:
|
|
status, msg = self._result_queues[video_key].get_nowait()
|
|
if status == "error":
|
|
raise RuntimeError(f"Encoder thread for {video_key} crashed: {msg}")
|
|
except queue.Empty:
|
|
pass
|
|
raise RuntimeError(f"Encoder thread for {video_key} is not alive")
|
|
|
|
try:
|
|
self._frame_queues[video_key].put(image.copy(), timeout=0.1)
|
|
except queue.Full:
|
|
self._dropped_frames[video_key] = self._dropped_frames.get(video_key, 0) + 1
|
|
count = self._dropped_frames[video_key]
|
|
# Log periodically to avoid spam (1st, then every 10th)
|
|
if count == 1 or count % 10 == 0:
|
|
logger.warning(
|
|
f"Encoder queue full for {video_key}, dropped {count} frame(s). "
|
|
f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize."
|
|
)
|
|
|
|
def finish_episode(self) -> dict[str, tuple[Path, dict | None]]:
|
|
"""Finish encoding the current episode.
|
|
|
|
Sends sentinel values, waits for encoder threads to complete,
|
|
and collects results.
|
|
|
|
Returns:
|
|
Dict mapping video_key to (mp4_path, stats_dict_or_None)
|
|
"""
|
|
if not self._episode_active:
|
|
raise RuntimeError("No active episode to finish.")
|
|
|
|
results = {}
|
|
|
|
# Report dropped frames
|
|
for video_key, count in self._dropped_frames.items():
|
|
if count > 0:
|
|
logger.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
|
|
|
|
# Send sentinel to all queues
|
|
for video_key in self._frame_queues:
|
|
self._frame_queues[video_key].put(None)
|
|
|
|
# Wait for all threads and collect results
|
|
for video_key in self._threads:
|
|
self._threads[video_key].join(timeout=120)
|
|
if self._threads[video_key].is_alive():
|
|
logger.error(f"Encoder thread for {video_key} did not finish in time")
|
|
self._stop_events[video_key].set()
|
|
self._threads[video_key].join(timeout=5)
|
|
results[video_key] = (self._video_paths[video_key], None)
|
|
continue
|
|
|
|
try:
|
|
status, data = self._result_queues[video_key].get(timeout=5)
|
|
if status == "error":
|
|
raise RuntimeError(f"Encoder thread for {video_key} failed: {data}")
|
|
results[video_key] = (self._video_paths[video_key], data)
|
|
except queue.Empty:
|
|
logger.error(f"No result from encoder thread for {video_key}")
|
|
results[video_key] = (self._video_paths[video_key], None)
|
|
|
|
self._cleanup()
|
|
self._episode_active = False
|
|
return results
|
|
|
|
def cancel_episode(self) -> None:
|
|
"""Cancel the current episode, stopping encoder threads and cleaning up."""
|
|
if not self._episode_active:
|
|
return
|
|
|
|
# Signal all threads to stop
|
|
for video_key in self._stop_events:
|
|
self._stop_events[video_key].set()
|
|
|
|
# Wait for threads to finish
|
|
for video_key in self._threads:
|
|
self._threads[video_key].join(timeout=5)
|
|
|
|
# Clean up temp MP4 files
|
|
video_path = self._video_paths.get(video_key)
|
|
if video_path is not None and video_path.exists():
|
|
shutil.rmtree(str(video_path.parent), ignore_errors=True)
|
|
|
|
self._cleanup()
|
|
self._episode_active = False
|
|
|
|
def close(self) -> None:
    """Shut the encoder down.

    Calling this on an already closed encoder is a no-op. If an episode is
    still active it is cancelled first via ``cancel_episode()``; afterwards
    the encoder is flagged as closed.
    """
    if not self._closed:
        if self._episode_active:
            self.cancel_episode()
        self._closed = True
|
|
|
|
def _cleanup(self) -> None:
|
|
"""Clean up queues and thread tracking dicts."""
|
|
for q in self._frame_queues.values():
|
|
with contextlib.suppress(Exception):
|
|
while not q.empty():
|
|
q.get_nowait()
|
|
self._frame_queues.clear()
|
|
self._result_queues.clear()
|
|
self._threads.clear()
|
|
self._stop_events.clear()
|
|
self._video_paths.clear()
|
|
|
|
|
|
@dataclass
class VideoFrame:
    # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
    """
    Feature type describing one frame of a video stored in a dataset.

    A value of this type is a struct with the path of the video file and the
    timestamp of the frame inside that video.

    Example:

    ```python
    data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
    features = {"image": VideoFrame()}
    Dataset.from_dict(data_dict, features=Features(features))
    ```
    """

    # Arrow storage type shared by all instances: {path: string, timestamp: float32}.
    pa_type: ClassVar[Any] = pa.struct(
        [
            ("path", pa.string()),
            ("timestamp", pa.float32()),
        ]
    )
    # Type tag; excluded from ``__init__`` and ``repr``.
    _type: str = field(default="VideoFrame", init=False, repr=False)

    def __call__(self):
        """Return the Arrow struct type backing this feature."""
        return self.pa_type
|
|
|
|
|
|
# Register VideoFrame with HuggingFace `datasets` so it can be used as a feature type.
# The "experimental" UserWarning emitted by `register_feature` is silenced for this call only.
with warnings.catch_warnings():
    warnings.filterwarnings(
        action="ignore",
        message="'register_feature' is experimental and might be subject to breaking changes in the future.",
        category=UserWarning,
    )
    register_feature(VideoFrame, "VideoFrame")
|
|
|
|
|
|
def get_audio_info(video_path: Path | str) -> dict:
    """
    Read the properties of the first audio stream of a media file using PyAV.

    Args:
        video_path: Path to the media file.

    Returns:
        ``{"has_audio": False}`` when the file has no audio stream. Otherwise a
        dict with the keys ``audio.channels``, ``audio.codec``, ``audio.bit_rate``,
        ``audio.sample_rate``, ``audio.bit_depth``, ``audio.channel_layout`` and
        ``has_audio`` (``True``).
    """
    # Set logging level
    logging.getLogger("libav").setLevel(av.logging.WARNING)

    try:
        with av.open(str(video_path), "r") as audio_file:
            try:
                audio_stream = audio_file.streams.audio[0]
            except IndexError:
                return {"has_audio": False}

            return {
                "audio.channels": audio_stream.channels,
                "audio.codec": audio_stream.codec.canonical_name,
                # In an ideal lossless case: bit depth x sample rate x channels = bit rate.
                # In an actual compressed case, the bit rate is set according to the compression
                # level: the lower the bit rate, the more compression is applied.
                "audio.bit_rate": audio_stream.bit_rate,
                "audio.sample_rate": audio_stream.sample_rate,  # Number of samples per second
                # In an ideal lossless case: fixed number of bits per sample.
                # In an actual compressed case: variable number of bits per sample
                # (often reduced to match a given depth rate).
                "audio.bit_depth": audio_stream.format.bits,
                "audio.channel_layout": audio_stream.layout.name,
                "has_audio": True,
            }
    finally:
        # Restore PyAV's default log callback on every exit path, including
        # when opening the file or reading a stream attribute raises.
        av.logging.restore_default_callback()
|
|
|
|
|
|
def get_video_info(video_path: Path | str) -> dict:
    """
    Read the properties of the first video stream of a file using PyAV.

    Args:
        video_path: Path to the video file.

    Returns:
        An empty dict when the file has no video stream. Otherwise a dict with
        the keys ``video.height``, ``video.width``, ``video.codec``,
        ``video.pix_fmt``, ``video.is_depth_map``, ``video.fps`` and
        ``video.channels``, merged with the result of ``get_audio_info``.

    Raises:
        ValueError: If the stream's pixel format is not recognized by
            ``get_video_pixel_channels``.
    """
    # Set logging level
    logging.getLogger("libav").setLevel(av.logging.WARNING)

    # Getting video stream information
    video_info = {}
    try:
        with av.open(str(video_path), "r") as video_file:
            try:
                video_stream = video_file.streams.video[0]
            except IndexError:
                return {}

            video_info["video.height"] = video_stream.height
            video_info["video.width"] = video_stream.width
            video_info["video.codec"] = video_stream.codec.canonical_name
            video_info["video.pix_fmt"] = video_stream.pix_fmt
            video_info["video.is_depth_map"] = False

            # fps comes from the stream's base rate; int() drops any fractional part.
            video_info["video.fps"] = int(video_stream.base_rate)

            video_info["video.channels"] = get_video_pixel_channels(video_stream.pix_fmt)
    finally:
        # Restore PyAV's default log callback on every exit path, including
        # when opening the file fails or the pixel format is rejected.
        av.logging.restore_default_callback()

    # Adding audio stream information
    video_info.update(get_audio_info(video_path))

    return video_info
|
|
|
|
|
|
def get_video_pixel_channels(pix_fmt: str) -> int:
    """
    Infer the number of pixel channels from a pixel format name.

    The name is matched by substring, checked in this order:
    ``gray`` / ``depth`` / ``monochrome`` -> 1, ``rgba`` / ``yuva`` -> 4,
    ``rgb`` / ``yuv`` -> 3.

    Args:
        pix_fmt: Pixel format name, e.g. ``"yuv420p"``.

    Returns:
        The number of channels (1, 3 or 4).

    Raises:
        ValueError: If the name matches none of the known families. The
            message includes the rejected format.
    """
    if any(tag in pix_fmt for tag in ("gray", "depth", "monochrome")):
        return 1
    # Alpha variants must be tested before their 3-channel prefixes ("rgb" is in "rgba").
    if "rgba" in pix_fmt or "yuva" in pix_fmt:
        return 4
    if "rgb" in pix_fmt or "yuv" in pix_fmt:
        return 3
    raise ValueError(f"Unknown format: {pix_fmt!r}")
|
|
|
|
|
|
def get_video_duration_in_s(video_path: Path | str) -> float:
    """
    Return the length of a video file in seconds, read with PyAV.

    The duration of the first video stream is used when the stream reports
    one; otherwise the container-level duration is used instead.

    Args:
        video_path: Path to the video file.

    Returns:
        Duration of the video in seconds.
    """
    with av.open(str(video_path)) as container:
        first_video = container.streams.video[0]
        stream_ticks = first_video.duration

        if stream_ticks is None:
            # No per-stream duration: the container duration is expressed in av.time_base units.
            return float(container.duration / av.time_base)

        # Stream duration is expressed in units of the stream's own time base.
        return float(stream_ticks * first_video.time_base)
|
|
|
|
|
|
class VideoEncodingManager:
    """
    Context manager that finalizes a dataset and tidies up when a recording block exits.

    On exit, whether or not an exception occurred, it:
    - cancels pending videos when an exception occurred and the writer has a streaming encoder
    - calls ``dataset.finalize()``
    - cleans up the interrupted episode when an exception occurred and the writer
      has no streaming encoder
    - removes the ``images`` directory under the dataset root when it holds no PNG files

    Exceptions raised inside the ``with`` block are never suppressed.

    Args:
        dataset: The LeRobotDataset instance
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        interrupted = exc_type is not None
        writer = self.dataset.writer

        if writer is None:
            self.dataset.finalize()
        else:
            if interrupted and writer._streaming_encoder is not None:
                writer.cancel_pending_videos()

            # finalize() handles flush_pending_videos + parquet + metadata
            self.dataset.finalize()

            # Image cleanup applies to non-streaming mode only; the encoder attribute
            # is deliberately read again after finalize().
            if interrupted and writer._streaming_encoder is None:
                writer.cleanup_interrupted_episode(self.dataset.num_episodes)

        self._remove_images_dir_without_pngs()

        # Returning False lets the original exception propagate.
        return False

    def _remove_images_dir_without_pngs(self):
        """Delete ``<root>/images`` when it exists and contains no PNG file."""
        images_root = self.dataset.root / "images"
        if not images_root.exists():
            return

        leftover_pngs = list(images_root.rglob("*.png"))
        if leftover_pngs:
            logger.debug(f"Images directory is not empty, containing {len(leftover_pngs)} PNG files")
            return

        shutil.rmtree(images_root)
        logger.debug("Cleaned up empty images directory")
|