2024-05-15 12:13:09 +02:00
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2026-02-23 13:57:43 +01:00
import contextlib
2025-04-29 17:39:35 +02:00
import glob
2025-03-17 13:23:11 +01:00
import importlib
2024-05-03 00:50:19 +02:00
import logging
2026-02-23 13:57:43 +01:00
import queue
2025-07-18 19:18:52 +09:00
import shutil
2025-09-15 09:53:30 +02:00
import tempfile
2026-02-23 13:57:43 +01:00
import threading
2024-05-03 00:50:19 +02:00
import warnings
from dataclasses import dataclass , field
2026-02-23 13:57:43 +01:00
from fractions import Fraction
2024-05-03 00:50:19 +02:00
from pathlib import Path
2025-09-15 14:08:01 +02:00
from threading import Lock
2024-05-03 00:50:19 +02:00
from typing import Any , ClassVar
2025-04-29 17:39:35 +02:00
import av
2025-09-15 14:08:01 +02:00
import fsspec
2026-02-23 13:57:43 +01:00
import numpy as np
2024-05-03 00:50:19 +02:00
import pyarrow as pa
import torch
import torchvision
from datasets . features . features import register_feature
2024-11-29 19:04:00 +01:00
from PIL import Image
2025-03-17 13:23:11 +01:00
2026-03-15 22:12:09 -07:00
logger = logging.getLogger(__name__)

# Hardware encoders probed when vcodec="auto". Whether each one is usable depends on the
# platform and on how FFmpeg was built. The list order is the auto-selection preference order.
HW_ENCODERS = [
    "h264_videotoolbox",  # macOS
    "hevc_videotoolbox",  # macOS
    "h264_nvenc",  # NVIDIA GPU
    "hevc_nvenc",  # NVIDIA GPU
    "h264_vaapi",  # Linux Intel/AMD
    "h264_qsv",  # Intel Quick Sync
]

# Every accepted `vcodec` value: the software encoders, the "auto" selector and all HW encoders.
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto", *HW_ENCODERS}
def _get_codec_options (
vcodec : str ,
g : int | None = 2 ,
crf : int | None = 30 ,
preset : int | None = None ,
) - > dict :
""" Build codec-specific options dict for video encoding. """
options = { }
# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
if g is not None and ( vcodec in ( " h264_videotoolbox " , " hevc_videotoolbox " ) or vcodec not in HW_ENCODERS ) :
options [ " g " ] = str ( g )
# Quality control (codec-specific parameter names)
if crf is not None :
if vcodec in ( " h264 " , " hevc " , " libsvtav1 " ) :
options [ " crf " ] = str ( crf )
elif vcodec in ( " h264_videotoolbox " , " hevc_videotoolbox " ) :
quality = max ( 1 , min ( 100 , int ( 100 - crf * 2 ) ) )
options [ " q:v " ] = str ( quality )
elif vcodec in ( " h264_nvenc " , " hevc_nvenc " ) :
options [ " rc " ] = " constqp "
options [ " qp " ] = str ( crf )
elif vcodec in ( " h264_vaapi " , ) :
options [ " qp " ] = str ( crf )
elif vcodec in ( " h264_qsv " , ) :
options [ " global_quality " ] = str ( crf )
# Preset (only for libsvtav1)
if vcodec == " libsvtav1 " :
options [ " preset " ] = str ( preset ) if preset is not None else " 12 "
return options
def detect_available_hw_encoders() -> list[str]:
    """Probe PyAV/FFmpeg for available hardware video encoders.

    Returns:
        The entries of ``HW_ENCODERS`` for which ``av.codec.Codec(name, "w")`` succeeds,
        in ``HW_ENCODERS`` order.
    """
    found: list[str] = []
    for name in HW_ENCODERS:
        try:
            av.codec.Codec(name, "w")
        except Exception:  # nosec B110 - probing is best-effort; any failure means "unavailable"
            logger.debug("HW encoder '%s' not available", name)
        else:
            found.append(name)
    return found
def resolve_vcodec(vcodec: str) -> str:
    """Validate ``vcodec`` and resolve ``"auto"`` to a concrete encoder.

    An explicit codec is returned unchanged once validated. For ``"auto"``, the first
    ``HW_ENCODERS`` entry reported by ``detect_available_hw_encoders()`` is chosen. When
    none is available, the software encoder ``"libsvtav1"`` is used.

    Raises:
        ValueError: If ``vcodec`` is not in ``VALID_VIDEO_CODECS``.
    """
    if vcodec not in VALID_VIDEO_CODECS:
        raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")

    if vcodec == "auto":
        available = detect_available_hw_encoders()
        chosen = next((enc for enc in HW_ENCODERS if enc in available), None)
        if chosen is None:
            logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
            return "libsvtav1"
        logger.info(f"Auto-selected video codec: {chosen}")
        return chosen

    logger.info(f"Using video codec: {vcodec}")
    return vcodec
2025-03-17 13:23:11 +01:00
def get_safe_default_codec():
    """Return the default video decoding backend for this platform.

    Returns:
        ``"torchcodec"`` when the ``torchcodec`` package can be found. Otherwise ``"pyav"``,
        and a warning is logged.
    """
    # `import importlib` alone does not guarantee that the `importlib.util` submodule is
    # loaded, so import it explicitly before using `find_spec`.
    import importlib.util

    if importlib.util.find_spec("torchcodec"):
        return "torchcodec"
    logger.warning("'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder")
    return "pyav"
2025-03-14 18:53:42 +03:00
def decode_video_frames (
video_path : Path | str ,
timestamps : list [ float ] ,
tolerance_s : float ,
2025-03-17 13:23:11 +01:00
backend : str | None = None ,
2025-03-14 18:53:42 +03:00
) - > torch . Tensor :
"""
Decodes video frames using the specified backend .
Args :
video_path ( Path ) : Path to the video file .
timestamps ( list [ float ] ) : List of timestamps to extract frames .
tolerance_s ( float ) : Allowed deviation in seconds for frame retrieval .
2025-03-17 13:23:11 +01:00
backend ( str , optional ) : Backend to use for decoding . Defaults to " torchcodec " when available in the platform ; otherwise , defaults to " pyav " . .
2025-03-14 18:53:42 +03:00
Returns :
torch . Tensor : Decoded frames .
Currently supports torchcodec on cpu and pyav .
"""
2025-03-17 13:23:11 +01:00
if backend is None :
backend = get_safe_default_codec ( )
2025-03-14 18:53:42 +03:00
if backend == " torchcodec " :
return decode_video_frames_torchcodec ( video_path , timestamps , tolerance_s )
elif backend in [ " pyav " , " video_reader " ] :
return decode_video_frames_torchvision ( video_path , timestamps , tolerance_s , backend )
else :
raise ValueError ( f " Unsupported video backend: { backend } " )
2024-05-03 00:50:19 +02:00
def decode_video_frames_torchvision(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    backend: str = "pyav",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated to the requested timestamps of a video

    The backend can be either "pyav" (default) or "video_reader".
    "video_reader" requires installing torchvision from source, see:
    https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
    (note that you need to compile against ffmpeg<4.3)

    While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
    For more info on video decoding, see `benchmark/video/README.md`

    See torchvision doc for more info on these two backends:
    https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to take into account decoding time and video size in bytes.

    Args:
        video_path: Path to the video file.
        timestamps: Timestamps (in seconds) of the frames to return.
        tolerance_s: A query fails unless a loaded frame lies strictly closer than this (seconds).
        backend: torchvision video backend, "pyav" or "video_reader". It is applied through
            ``torchvision.set_video_backend``, which is a global torchvision setting.
        log_loaded_timestamps: Log the timestamp of every decoded frame and of the selected ones.

    Returns:
        ``float32`` tensor that stacks, for each query timestamp, the closest loaded frame
        divided by 255.

    Raises:
        ValueError: If ``timestamps`` is empty.
        FrameTimestampError: If no frame could be decoded after seeking, or if a query
            timestamp has no loaded frame within ``tolerance_s``.
    """
    video_path = str(video_path)

    # set the first and last requested timestamps
    # Computed before opening the video so that an empty `timestamps` fails without leaking a reader.
    first_ts = min(timestamps)
    last_ts = max(timestamps)

    # set backend
    torchvision.set_video_backend(backend)
    keyframes_only = backend == "pyav"  # pyav doesn't support accurate seek

    # set a video stream reader
    # TODO(rcadene): also load audio stream at the same time
    reader = torchvision.io.VideoReader(video_path, "video")

    loaded_frames = []
    loaded_ts = []
    try:
        # access closest key frame of the first requested frame
        # Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
        # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
        reader.seek(first_ts, keyframes_only=keyframes_only)

        # load all frames until last requested frame
        # Note: previous timestamps are usually loaded, since we need to access the previous key frame
        for frame in reader:
            current_ts = frame["pts"]
            if log_loaded_timestamps:
                logger.info(f"frame loaded at timestamp={current_ts:.4f}")
            loaded_frames.append(frame["data"])
            loaded_ts.append(current_ts)
            if current_ts >= last_ts:
                break
    finally:
        # Close the pyav container even when seeking/decoding raised, so the file is not leaked.
        if backend == "pyav":
            reader.container.close()
        reader = None

    if not loaded_ts:
        # Without this guard, `dist.min(1)` below fails on an empty dimension with an unclear error.
        raise FrameTimestampError(
            f"No frame could be loaded from the video after seeking to {first_ts=}."
            f"\nqueried timestamps: {timestamps}"
            f"\nvideo: {video_path}"
            f"\nbackend: {backend}"
        )

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # compute distances between each query timestamp and timestamps of all loaded frames
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    if not is_within_tol.all():
        raise FrameTimestampError(
            f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
            "It means that the closest frame that can be loaded from the video is too far away in time."
            "This might be due to synchronization issues with timestamps during data collection."
            "To be safe, we advise to ignore this item during training."
            f"\nqueried timestamps: {query_ts}"
            f"\nloaded timestamps: {loaded_ts}"
            f"\nvideo: {video_path}"
            f"\nbackend: {backend}"
        )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logger.info(f"{closest_ts=}")

    # convert to the pytorch format which is float32 in [0,1] range (and channel first)
    closest_frames = closest_frames.type(torch.float32) / 255

    if len(timestamps) != len(closest_frames):
        raise FrameTimestampError(
            f"Number of retrieved frames ({len(closest_frames)}) does not match "
            f"number of queried timestamps ({len(timestamps)})"
        )
    return closest_frames
2025-09-15 14:08:01 +02:00
class VideoDecoderCache:
    """Thread-safe cache for video decoders to avoid expensive re-initialization.

    Each entry maps a video path to the torchcodec ``VideoDecoder`` created for it and to
    the fsspec file handle that the decoder reads from. Handles stay open until ``clear()``.
    All accesses to the cache dict are guarded by a single lock.
    """

    def __init__(self):
        # video_path -> (decoder, open file handle backing the decoder)
        self._cache: dict[str, tuple[Any, Any]] = {}
        self._lock = Lock()

    def get_decoder(self, video_path: str):
        """Get a cached decoder or create a new one.

        Args:
            video_path: Location of the video. It is converted to ``str`` and passed to ``fsspec.open``.

        Returns:
            A torchcodec ``VideoDecoder`` created with ``seek_mode="approximate"``.

        Raises:
            ImportError: If torchcodec is not installed.
        """
        # `import importlib` alone does not guarantee that `importlib.util` is loaded.
        import importlib.util

        if importlib.util.find_spec("torchcodec") is None:
            raise ImportError("torchcodec is required but not available.")
        from torchcodec.decoders import VideoDecoder

        video_path = str(video_path)
        with self._lock:
            cached = self._cache.get(video_path)
            if cached is None:
                file_handle = fsspec.open(video_path).__enter__()
                try:
                    decoder = VideoDecoder(file_handle, seek_mode="approximate")
                except Exception:
                    # Do not leak the opened handle when the decoder cannot be built.
                    file_handle.close()
                    raise
                cached = (decoder, file_handle)
                self._cache[video_path] = cached
            return cached[0]

    def clear(self):
        """Clear the cache and close file handles."""
        with self._lock:
            for _, file_handle in self._cache.values():
                file_handle.close()
            self._cache.clear()

    def size(self) -> int:
        """Return the number of cached decoders."""
        with self._lock:
            return len(self._cache)
class FrameTimestampError(ValueError):
    """Raised when decoded frames cannot be matched to the queried timestamps.

    The decoding helpers in this module raise it in two cases. The first is when the closest
    loaded frame is farther than the allowed tolerance from a queried timestamp. The second
    is when the number of returned frames differs from the number of queried timestamps.
    """
# Module-level decoder cache that decode_video_frames_torchcodec falls back to when the
# caller does not pass its own `decoder_cache`.
_default_decoder_cache = VideoDecoderCache()
2025-03-14 18:53:42 +03:00
def decode_video_frames_torchcodec(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    log_loaded_timestamps: bool = False,
    decoder_cache: VideoDecoderCache | None = None,
) -> torch.Tensor:
    """Loads frames associated with the requested timestamps of a video using torchcodec.

    Args:
        video_path: Path to the video file.
        timestamps: List of timestamps to extract frames.
        tolerance_s: Allowed deviation in seconds for frame retrieval. A query fails unless
            the returned frame is strictly closer than this.
        log_loaded_timestamps: Whether to log loaded timestamps.
        decoder_cache: Optional decoder cache instance. Uses default if None.

    Returns:
        ``float32`` tensor that stacks one frame per query timestamp, scaled from [0, 255] to [0, 1].

    Raises:
        FrameTimestampError: If a returned frame is not within ``tolerance_s`` of its query.

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to take into account decoding time and video size in bytes.
    """
    if decoder_cache is None:
        decoder_cache = _default_decoder_cache

    # Use cached decoder instead of creating new one each time
    decoder = decoder_cache.get_decoder(str(video_path))

    # Convert timestamps to frame indices using the stream's average fps. The tolerance
    # check below catches frames whose actual pts drift from that estimate.
    average_fps = decoder.metadata.average_fps
    frame_indices = [round(ts * average_fps) for ts in timestamps]

    # retrieve frames based on indices
    frames_batch = decoder.get_frames_at(indices=frame_indices)

    loaded_frames = []
    loaded_ts = []
    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=True):
        loaded_frames.append(frame)
        loaded_ts.append(pts.item())
        if log_loaded_timestamps:
            logger.info(f"Frame loaded at timestamp={pts:.4f}")

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # compute distances between each query timestamp and loaded timestamps
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    if not is_within_tol.all():
        raise FrameTimestampError(
            f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
            "It means that the closest frame that can be loaded from the video is too far away in time."
            "This might be due to synchronization issues with timestamps during data collection."
            "To be safe, we advise to ignore this item during training."
            f"\nqueried timestamps: {query_ts}"
            f"\nloaded timestamps: {loaded_ts}"
            f"\nvideo: {video_path}"
        )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logger.info(f"{closest_ts=}")

    # convert to float32 in [0,1] range
    closest_frames = (closest_frames / 255.0).type(torch.float32)

    if len(timestamps) != len(closest_frames):
        # Report counts: a set difference of frame tensors and float timestamps is meaningless.
        raise FrameTimestampError(
            f"Number of retrieved frames ({len(closest_frames)}) does not match "
            f"number of queried timestamps ({len(timestamps)})"
        )
    return closest_frames
2024-07-09 20:20:25 +02:00
def encode_video_frames(
    imgs_dir: Path | str,
    video_path: Path | str,
    fps: int,
    vcodec: str = "libsvtav1",
    pix_fmt: str = "yuv420p",
    g: int | None = 2,
    crf: int | None = 30,
    fast_decode: int = 0,
    log_level: int | None = av.logging.WARNING,
    overwrite: bool = False,
    preset: int | None = None,
    encoder_threads: int | None = None,
) -> None:
    """Encode the ``frame-XXXXXX.png`` images of ``imgs_dir`` into a video file.

    More info on ffmpeg arguments tuning on `benchmark/video/README.md`

    Args:
        imgs_dir: Directory holding frames named ``frame-`` + 6 digits + ``.png``. They are
            encoded in numeric order.
        video_path: Output file. Its parent directories are created if needed.
        fps: Output frame rate.
        vcodec: Encoder name, validated and resolved by ``resolve_vcodec`` ("auto" picks an
            available HW encoder).
        pix_fmt: Output pixel format. "yuv444p" is replaced by "yuv420p" for libsvtav1 and hevc.
        g: GOP size, forwarded to ``_get_codec_options``.
        crf: Quality setting, forwarded to ``_get_codec_options``.
        fast_decode: When non-zero, sets ``svtav1-params=fast-decode=<value>`` for libsvtav1,
            or ``tune=fastdecode`` for other encoders.
        log_level: Level applied to the "libav" Python logger while encoding. None leaves it untouched.
        overwrite: When False and ``video_path`` exists, log a warning and return without encoding.
        preset: libsvtav1 preset, forwarded to ``_get_codec_options``.
        encoder_threads: Encoder thread count (``lp=`` in ``svtav1-params`` for libsvtav1,
            ``threads`` otherwise).

    Raises:
        FileNotFoundError: If no matching frame image is found in ``imgs_dir``.
        OSError: If the output file does not exist after encoding.
    """
    vcodec = resolve_vcodec(vcodec)

    video_path = Path(video_path)
    imgs_dir = Path(imgs_dir)

    if video_path.exists() and not overwrite:
        logger.warning(f"Video file already exists: {video_path}. Skipping encoding.")
        return

    video_path.parent.mkdir(parents=True, exist_ok=True)

    # Encoders/pixel formats incompatibility check
    if vcodec in ("libsvtav1", "hevc") and pix_fmt == "yuv444p":
        logger.warning(
            f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
        )
        pix_fmt = "yuv420p"

    # Get input frames, sorted by their numeric suffix
    template = "frame-" + ("[0-9]" * 6) + ".png"
    input_list = sorted(
        glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
    )

    # Define video output frame size (assuming all input frames are the same size)
    if len(input_list) == 0:
        raise FileNotFoundError(f"No images found in {imgs_dir}.")
    with Image.open(input_list[0]) as dummy_image:
        width, height = dummy_image.size

    # Define video codec options
    video_options = _get_codec_options(vcodec, g, crf, preset)

    if fast_decode:
        key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
        value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
        video_options[key] = value

    if encoder_threads is not None:
        if vcodec == "libsvtav1":
            lp_param = f"lp={encoder_threads}"
            if "svtav1-params" in video_options:
                video_options["svtav1-params"] += f":{lp_param}"
            else:
                video_options["svtav1-params"] = lp_param
        else:
            video_options["threads"] = str(encoder_threads)

    # Set logging level
    if log_level is not None:
        # "While less efficient, it is generally preferable to modify logging with Python's logging"
        logging.getLogger("libav").setLevel(log_level)

    try:
        # Create and open output file (overwrite by default)
        with av.open(str(video_path), "w") as output:
            output_stream = output.add_stream(vcodec, fps, options=video_options)
            output_stream.pix_fmt = pix_fmt
            output_stream.width = width
            output_stream.height = height

            # Loop through input frames and encode them
            for input_data in input_list:
                with Image.open(input_data) as input_image:
                    input_image = input_image.convert("RGB")
                    input_frame = av.VideoFrame.from_image(input_image)
                    packet = output_stream.encode(input_frame)
                    if packet:
                        output.mux(packet)

            # Flush the encoder
            packet = output_stream.encode()
            if packet:
                output.mux(packet)
    except BaseException:
        # Remove the partially written file. Otherwise a later call with overwrite=False
        # would silently skip re-encoding and keep a broken video.
        video_path.unlink(missing_ok=True)
        raise
    finally:
        # Reset logging level, also when encoding failed
        if log_level is not None:
            av.logging.restore_default_callback()

    if not video_path.exists():
        raise OSError(f"Video encoding did not work. File not found: {video_path}.")
2024-08-15 18:11:33 +02:00
2024-05-03 00:50:19 +02:00
2025-09-15 09:53:30 +02:00
def concatenate_video_files (
input_video_paths : list [ Path | str ] , output_video_path : Path , overwrite : bool = True
) :
"""
Concatenate multiple video files into a single video file using pyav .
This function takes a list of video input file paths and concatenates them into a single
output video file . It uses ffmpeg ' s concat demuxer with stream copy mode for fast
concatenation without re - encoding .
Args :
input_video_paths : Ordered list of input video file paths to concatenate .
output_video_path : Path to the output video file .
overwrite : Whether to overwrite the output video file if it already exists . Default is True .
Note :
- Creates a temporary directory for intermediate files that is cleaned up after use .
- Uses ffmpeg ' s concat demuxer which requires all input videos to have the same
codec , resolution , and frame rate for proper concatenation .
"""
output_video_path = Path ( output_video_path )
if output_video_path . exists ( ) and not overwrite :
2026-03-15 22:12:09 -07:00
logger . warning ( f " Video file already exists: { output_video_path } . Skipping concatenation. " )
2025-09-15 09:53:30 +02:00
return
output_video_path . parent . mkdir ( parents = True , exist_ok = True )
if len ( input_video_paths ) == 0 :
raise FileNotFoundError ( " No input video paths provided. " )
# Create a temporary .ffconcat file to list the input video paths
with tempfile . NamedTemporaryFile ( mode = " w " , suffix = " .ffconcat " , delete = False ) as tmp_concatenate_file :
tmp_concatenate_file . write ( " ffconcat version 1.0 \n " )
for input_path in input_video_paths :
2025-09-28 20:18:22 +08:00
tmp_concatenate_file . write ( f " file ' { str ( input_path . resolve ( ) ) } ' \n " )
2025-09-15 09:53:30 +02:00
tmp_concatenate_file . flush ( )
tmp_concatenate_path = tmp_concatenate_file . name
# Create input and output containers
input_container = av . open (
tmp_concatenate_path , mode = " r " , format = " concat " , options = { " safe " : " 0 " }
) # safe = 0 allows absolute paths as well as relative paths
2025-09-29 15:06:56 +02:00
with tempfile . NamedTemporaryFile ( suffix = " .mp4 " , delete = False ) as tmp_named_file :
tmp_output_video_path = tmp_named_file . name
2025-09-15 09:53:30 +02:00
output_container = av . open (
tmp_output_video_path , mode = " w " , options = { " movflags " : " faststart " }
) # faststart is to move the metadata to the beginning of the file to speed up loading
# Replicate input streams in output container
stream_map = { }
for input_stream in input_container . streams :
if input_stream . type in ( " video " , " audio " , " subtitle " ) : # only copy compatible streams
stream_map [ input_stream . index ] = output_container . add_stream_from_template (
template = input_stream , opaque = True
)
2025-10-10 12:32:07 +02:00
# set the time base to the input stream time base (missing in the codec context)
stream_map [ input_stream . index ] . time_base = input_stream . time_base
2025-09-15 09:53:30 +02:00
# Demux + remux packets (no re-encode)
for packet in input_container . demux ( ) :
# Skip packets from un-mapped streams
if packet . stream . index not in stream_map :
continue
# Skip demux flushing packets
if packet . dts is None :
continue
output_stream = stream_map [ packet . stream . index ]
packet . stream = output_stream
output_container . mux ( packet )
input_container . close ( )
output_container . close ( )
shutil . move ( tmp_output_video_path , output_video_path )
Path ( tmp_concatenate_path ) . unlink ( )
2026-02-23 13:57:43 +01:00
class _CameraEncoderThread ( threading . Thread ) :
""" A thread that encodes video frames streamed via a queue into an MP4 file.
One instance is created per camera per episode . Frames are received as numpy arrays
from the main thread , encoded in real - time using PyAV ( which releases the GIL during
encoding ) , and written to disk . Stats are computed incrementally using
RunningQuantileStats and returned via result_queue .
"""
def __init__ (
self ,
video_path : Path ,
fps : int ,
vcodec : str ,
pix_fmt : str ,
g : int | None ,
crf : int | None ,
preset : int | None ,
frame_queue : queue . Queue ,
result_queue : queue . Queue ,
stop_event : threading . Event ,
encoder_threads : int | None = None ,
) :
super ( ) . __init__ ( daemon = True )
self . video_path = video_path
self . fps = fps
self . vcodec = vcodec
self . pix_fmt = pix_fmt
self . g = g
self . crf = crf
self . preset = preset
self . frame_queue = frame_queue
self . result_queue = result_queue
self . stop_event = stop_event
self . encoder_threads = encoder_threads
def run ( self ) - > None :
from lerobot . datasets . compute_stats import RunningQuantileStats , auto_downsample_height_width
container = None
output_stream = None
stats_tracker = RunningQuantileStats ( )
frame_count = 0
try :
logging . getLogger ( " libav " ) . setLevel ( av . logging . WARNING )
while True :
try :
frame_data = self . frame_queue . get ( timeout = 1 )
except queue . Empty :
if self . stop_event . is_set ( ) :
break
continue
if frame_data is None :
# Sentinel: flush and close
break
# Ensure HWC uint8 numpy array
if isinstance ( frame_data , np . ndarray ) :
if frame_data . ndim == 3 and frame_data . shape [ 0 ] == 3 :
# CHW -> HWC
frame_data = frame_data . transpose ( 1 , 2 , 0 )
if frame_data . dtype != np . uint8 :
frame_data = ( frame_data * 255 ) . astype ( np . uint8 )
# Open container on first frame (to get width/height)
if container is None :
height , width = frame_data . shape [ : 2 ]
video_options = _get_codec_options ( self . vcodec , self . g , self . crf , self . preset )
if self . encoder_threads is not None :
if self . vcodec == " libsvtav1 " :
lp_param = f " lp= { self . encoder_threads } "
if " svtav1-params " in video_options :
video_options [ " svtav1-params " ] + = f " : { lp_param } "
else :
video_options [ " svtav1-params " ] = lp_param
else :
video_options [ " threads " ] = str ( self . encoder_threads )
Path ( self . video_path ) . parent . mkdir ( parents = True , exist_ok = True )
container = av . open ( str ( self . video_path ) , " w " )
output_stream = container . add_stream ( self . vcodec , self . fps , options = video_options )
output_stream . pix_fmt = self . pix_fmt
output_stream . width = width
output_stream . height = height
output_stream . time_base = Fraction ( 1 , self . fps )
# Encode frame with explicit timestamps
pil_img = Image . fromarray ( frame_data )
video_frame = av . VideoFrame . from_image ( pil_img )
video_frame . pts = frame_count
video_frame . time_base = Fraction ( 1 , self . fps )
packet = output_stream . encode ( video_frame )
if packet :
container . mux ( packet )
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
img_chw = frame_data . transpose ( 2 , 0 , 1 ) # HWC -> CHW
img_downsampled = auto_downsample_height_width ( img_chw )
# Reshape CHW to (H*W, C) for per-channel stats
channels = img_downsampled . shape [ 0 ]
img_for_stats = img_downsampled . transpose ( 1 , 2 , 0 ) . reshape ( - 1 , channels )
stats_tracker . update ( img_for_stats )
frame_count + = 1
# Flush encoder
if output_stream is not None :
packet = output_stream . encode ( )
if packet :
container . mux ( packet )
if container is not None :
container . close ( )
av . logging . restore_default_callback ( )
# Get stats and put on result queue
if frame_count > = 2 :
stats = stats_tracker . get_statistics ( )
self . result_queue . put ( ( " ok " , stats ) )
else :
self . result_queue . put ( ( " ok " , None ) )
except Exception as e :
2026-03-15 22:12:09 -07:00
logger . error ( f " Encoder thread error: { e } " )
2026-02-23 13:57:43 +01:00
if container is not None :
with contextlib . suppress ( Exception ) :
container . close ( )
self . result_queue . put ( ( " error " , str ( e ) ) )
class StreamingVideoEncoder :
""" Manages per-camera encoder threads for real-time video encoding during recording.
Instead of writing frames as PNG images and then encoding to MP4 at episode end ,
this class streams frames directly to encoder threads , eliminating the
PNG round - trip and making save_episode ( ) near - instant .
Uses threading instead of multiprocessing to avoid the overhead of pickling large
numpy arrays through multiprocessing . Queue . PyAV ' s encode() releases the GIL,
so encoding runs in parallel with the main recording loop .
"""
def __init__ (
self ,
fps : int ,
vcodec : str = " libsvtav1 " ,
pix_fmt : str = " yuv420p " ,
g : int | None = 2 ,
crf : int | None = 30 ,
preset : int | None = None ,
queue_maxsize : int = 30 ,
encoder_threads : int | None = None ,
) :
self . fps = fps
self . vcodec = resolve_vcodec ( vcodec )
self . pix_fmt = pix_fmt
self . g = g
self . crf = crf
self . preset = preset
self . queue_maxsize = queue_maxsize
self . encoder_threads = encoder_threads
self . _frame_queues : dict [ str , queue . Queue ] = { }
self . _result_queues : dict [ str , queue . Queue ] = { }
self . _threads : dict [ str , _CameraEncoderThread ] = { }
self . _stop_events : dict [ str , threading . Event ] = { }
self . _video_paths : dict [ str , Path ] = { }
self . _dropped_frames : dict [ str , int ] = { }
self . _episode_active = False
def start_episode ( self , video_keys : list [ str ] , temp_dir : Path ) - > None :
""" Start encoder threads for a new episode.
Args :
video_keys : List of video feature keys ( e . g . [ " observation.images.laptop " ] )
temp_dir : Base directory for temporary MP4 files
"""
if self . _episode_active :
self . cancel_episode ( )
self . _dropped_frames . clear ( )
for video_key in video_keys :
frame_queue : queue . Queue = queue . Queue ( maxsize = self . queue_maxsize )
result_queue : queue . Queue = queue . Queue ( maxsize = 1 )
stop_event = threading . Event ( )
temp_video_dir = Path ( tempfile . mkdtemp ( dir = temp_dir ) )
video_path = temp_video_dir / f " { video_key . replace ( ' / ' , ' _ ' ) } _streaming.mp4 "
encoder_thread = _CameraEncoderThread (
video_path = video_path ,
fps = self . fps ,
vcodec = self . vcodec ,
pix_fmt = self . pix_fmt ,
g = self . g ,
crf = self . crf ,
preset = self . preset ,
frame_queue = frame_queue ,
result_queue = result_queue ,
stop_event = stop_event ,
encoder_threads = self . encoder_threads ,
)
encoder_thread . start ( )
self . _frame_queues [ video_key ] = frame_queue
self . _result_queues [ video_key ] = result_queue
self . _threads [ video_key ] = encoder_thread
self . _stop_events [ video_key ] = stop_event
self . _video_paths [ video_key ] = video_path
self . _episode_active = True
def feed_frame(self, video_key: str, image: np.ndarray) -> None:
    """Queue one frame for the encoder thread of ``video_key``.

    The image is copied before it is enqueued, because camera drivers may reuse
    their buffers. If the queue stays full for 0.1 s (the encoder is falling
    behind), the frame is dropped and counted rather than raising. A warning is
    logged on the 1st drop and on every 10th drop after that.

    Args:
        video_key: The video feature key
        image: numpy array in (H, W, C) or (C, H, W) format, uint8 or float

    Raises:
        RuntimeError: If no episode is active, or if the encoder thread for
            ``video_key`` is no longer alive. When the dead thread left an
            ``"error"`` result, its message is included.
    """
    if not self._episode_active:
        raise RuntimeError("No active episode. Call start_episode() first.")

    if not self._threads[video_key].is_alive():
        # A dead thread may have left an ("error", message) result behind; surface it.
        try:
            status, msg = self._result_queues[video_key].get_nowait()
        except queue.Empty:
            pass
        else:
            if status == "error":
                raise RuntimeError(f"Encoder thread for {video_key} crashed: {msg}")
        raise RuntimeError(f"Encoder thread for {video_key} is not alive")

    frame = image.copy()
    try:
        self._frame_queues[video_key].put(frame, timeout=0.1)
    except queue.Full:
        dropped = self._dropped_frames.get(video_key, 0) + 1
        self._dropped_frames[video_key] = dropped
        # Throttle the log: first drop, then every 10th.
        if dropped == 1 or dropped % 10 == 0:
            logger.warning(
                f"Encoder queue full for {video_key}, dropped {dropped} frame(s). "
                f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize."
            )
def finish_episode ( self ) - > dict [ str , tuple [ Path , dict | None ] ] :
""" Finish encoding the current episode.
Sends sentinel values , waits for encoder threads to complete ,
and collects results .
Returns :
Dict mapping video_key to ( mp4_path , stats_dict_or_None )
"""
if not self . _episode_active :
raise RuntimeError ( " No active episode to finish. " )
results = { }
# Report dropped frames
for video_key , count in self . _dropped_frames . items ( ) :
if count > 0 :
2026-03-15 22:12:09 -07:00
logger . warning ( f " Episode finished with { count } dropped frame(s) for { video_key } . " )
2026-02-23 13:57:43 +01:00
# Send sentinel to all queues
for video_key in self . _frame_queues :
self . _frame_queues [ video_key ] . put ( None )
# Wait for all threads and collect results
for video_key in self . _threads :
self . _threads [ video_key ] . join ( timeout = 120 )
if self . _threads [ video_key ] . is_alive ( ) :
2026-03-15 22:12:09 -07:00
logger . error ( f " Encoder thread for { video_key } did not finish in time " )
2026-02-23 13:57:43 +01:00
self . _stop_events [ video_key ] . set ( )
self . _threads [ video_key ] . join ( timeout = 5 )
results [ video_key ] = ( self . _video_paths [ video_key ] , None )
continue
try :
status , data = self . _result_queues [ video_key ] . get ( timeout = 5 )
if status == " error " :
raise RuntimeError ( f " Encoder thread for { video_key } failed: { data } " )
results [ video_key ] = ( self . _video_paths [ video_key ] , data )
except queue . Empty :
2026-03-15 22:12:09 -07:00
logger . error ( f " No result from encoder thread for { video_key } " )
2026-02-23 13:57:43 +01:00
results [ video_key ] = ( self . _video_paths [ video_key ] , None )
self . _cleanup ( )
self . _episode_active = False
return results
def cancel_episode(self) -> None:
    """Cancel the current episode, stopping encoder threads and cleaning up.

    Does nothing when no episode is active. Every thread's stop event is set, then
    each thread is joined for up to 5 seconds. The per-camera temporary directory
    (the parent of the tracked MP4 path, created with ``tempfile.mkdtemp`` in
    ``start_episode``) is removed whether or not the MP4 inside it has been
    written yet.
    """
    if not self._episode_active:
        return

    # Signal every thread before joining any of them.
    for stop_event in self._stop_events.values():
        stop_event.set()

    for video_key, thread in self._threads.items():
        thread.join(timeout=5)
        video_path = self._video_paths.get(video_key)
        if video_path is not None:
            # Remove the dedicated temp dir even if the MP4 was never created;
            # otherwise an empty directory is leaked for every cancelled camera.
            shutil.rmtree(video_path.parent, ignore_errors=True)

    self._cleanup()
    self._episode_active = False
def close(self) -> None:
    """Close the encoder by cancelling the episode that is still in progress, if any."""
    if not self._episode_active:
        return
    self.cancel_episode()
def _cleanup(self) -> None:
    """Drain any frames still queued, then forget all per-episode bookkeeping.

    Empties the queues, threads, stop-event and video-path dicts. It does not
    touch the dropped-frame counters.
    """
    # Discard pending frames so the queued image copies are released.
    for pending in self._frame_queues.values():
        with contextlib.suppress(Exception):
            while not pending.empty():
                pending.get_nowait()

    for registry in (
        self._frame_queues,
        self._result_queues,
        self._threads,
        self._stop_events,
        self._video_paths,
    ):
        registry.clear()
2024-05-03 00:50:19 +02:00
@dataclass
class VideoFrame:
    # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
    """
    Provides a type for a dataset containing video frames.

    Each value is stored as a pyarrow struct holding the video file ``path``
    (string) and the frame ``timestamp`` (float32).

    Example:

    ```python
    data_dict = {"image": [{"path": "videos/episode_0.mp4", "timestamp": 0.3}]}
    features = {"image": VideoFrame()}
    Dataset.from_dict(data_dict, features=Features(features))
    ```
    """

    pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()})
    _type: str = field(default="VideoFrame", init=False, repr=False)

    def __call__(self):
        """Return the pyarrow storage type of this feature."""
        return self.pa_type
with warnings.catch_warnings():
    # Suppress, for this registration only, the "experimental" UserWarning that
    # `register_feature` may emit.
    warnings.filterwarnings(
        action="ignore",
        message="'register_feature' is experimental and might be subject to breaking changes in the future.",
        category=UserWarning,
    )
    # Expose VideoFrame to Hugging Face `datasets` under the name "VideoFrame".
    register_feature(VideoFrame, "VideoFrame")
2024-11-29 19:04:00 +01:00
def get_audio_info(video_path: Path | str) -> dict:
    """Describe the first audio stream of a media file using PyAV.

    Args:
        video_path: Path to the file to probe.

    Returns:
        ``{"has_audio": False}`` if the file has no audio stream. Otherwise a dict
        with ``audio.channels``, ``audio.codec``, ``audio.bit_rate``,
        ``audio.sample_rate``, ``audio.bit_depth``, ``audio.channel_layout`` and
        ``has_audio`` set to ``True``.
    """
    # Set the Python "libav" logger threshold to av.logging.WARNING while probing.
    logging.getLogger("libav").setLevel(av.logging.WARNING)
    try:
        with av.open(str(video_path), "r") as audio_file:
            try:
                audio_stream = audio_file.streams.audio[0]
            except IndexError:
                return {"has_audio": False}

            return {
                "audio.channels": audio_stream.channels,
                "audio.codec": audio_stream.codec.canonical_name,
                # Ideal lossless case: bit depth x sample rate x channels = bit rate.
                # For compressed audio the bit rate follows the compression level:
                # the lower the bit rate, the more compression is applied.
                "audio.bit_rate": audio_stream.bit_rate,
                "audio.sample_rate": audio_stream.sample_rate,  # Number of samples per second
                # Ideal lossless case: fixed number of bits per sample. Compressed audio
                # may use a variable number of bits per sample.
                "audio.bit_depth": audio_stream.format.bits,
                "audio.channel_layout": audio_stream.layout.name,
                "has_audio": True,
            }
    finally:
        # Restore libav's default log callback on every exit path, including when
        # opening or probing the file raises.
        av.logging.restore_default_callback()
2024-11-29 19:04:00 +01:00
def get_video_info(video_path: Path | str) -> dict:
    """Describe the first video stream of a file, plus its audio info, using PyAV.

    Args:
        video_path: Path to the video file.

    Returns:
        An empty dict if the file has no video stream. Otherwise a dict with
        ``video.height``, ``video.width``, ``video.codec``, ``video.pix_fmt``,
        ``video.is_depth_map`` (always ``False`` here), ``video.fps``,
        ``video.channels``, followed by the keys returned by ``get_audio_info``.

    Raises:
        ValueError: Propagated from ``get_video_pixel_channels`` when the stream's
            pixel format is not recognised.
    """
    # Set the Python "libav" logger threshold to av.logging.WARNING while probing.
    logging.getLogger("libav").setLevel(av.logging.WARNING)
    try:
        with av.open(str(video_path), "r") as video_file:
            try:
                video_stream = video_file.streams.video[0]
            except IndexError:
                return {}

            video_info = {
                "video.height": video_stream.height,
                "video.width": video_stream.width,
                "video.codec": video_stream.codec.canonical_name,
                "video.pix_fmt": video_stream.pix_fmt,
                "video.is_depth_map": False,
                # fps comes from the stream's base_rate, truncated to an int by int().
                "video.fps": int(video_stream.base_rate),
                "video.channels": get_video_pixel_channels(video_stream.pix_fmt),
            }
    finally:
        # Restore libav's default log callback on every exit path, including when
        # opening or probing the file raises.
        av.logging.restore_default_callback()

    # Adding audio stream information
    video_info.update(get_audio_info(video_path))
    return video_info
def get_video_pixel_channels(pix_fmt: str) -> int:
    """Return the number of colour channels implied by an FFmpeg pixel-format name.

    Matching is by substring, so bit-depth, endianness and planar suffixes
    (e.g. ``gray16le``, ``yuv420p10le``, ``gbrp``) are handled.

    Args:
        pix_fmt: FFmpeg/PyAV pixel format name, e.g. ``"yuv420p"`` or ``"rgb24"``.

    Returns:
        1 for grayscale/depth/monochrome formats, 4 for formats carrying an alpha
        channel, 3 for RGB/BGR/GBR and YUV colour formats.

    Raises:
        ValueError: If the format is not recognised.
    """
    # "mono" also covers FFmpeg's 1-bit "monob"/"monow" formats.
    if any(tag in pix_fmt for tag in ("gray", "depth", "mono")):
        return 1
    # Test alpha formats before 3-channel ones: "rgba"/"argb" contain "rgb",
    # "yuva"/"ayuv" contain "yuv", and "gbrap" starts like "gbrp".
    if any(tag in pix_fmt for tag in ("rgba", "bgra", "argb", "abgr", "yuva", "ayuv", "gbrap")):
        return 4
    # "rgb0"/"bgr0"/"x2rgb10" carry padding, not alpha, so they stay 3-channel.
    # nv12-family, p010/p016 and the packed 4:2:2 names are YUV layouts without "yuv" in the name.
    if any(
        tag in pix_fmt
        for tag in (
            "rgb", "bgr", "gbr", "yuv",
            "nv12", "nv21", "nv16", "nv24", "nv42", "p010", "p016",
            "yuyv", "uyvy", "yvyu",
        )
    ):
        return 3
    raise ValueError(f"Unknown format: {pix_fmt!r}")
2025-09-15 09:53:30 +02:00
def get_video_duration_in_s(video_path: Path | str) -> float:
    """
    Get the duration of a video file in seconds using PyAV.

    The first video stream's own duration is preferred. When the stream does not
    report one, the container-level duration is used instead.

    Args:
        video_path: Path to the video file.

    Returns:
        Duration of the video in seconds.

    Raises:
        IndexError: If the file has no video stream.
        ValueError: If neither the stream nor the container reports a duration.
    """
    with av.open(str(video_path)) as container:
        video_stream = container.streams.video[0]
        if video_stream.duration is not None:
            # stream.duration is counted in units of stream.time_base.
            return float(video_stream.duration * video_stream.time_base)
        if container.duration is not None:
            # container.duration is counted in units of 1 / av.time_base.
            return float(container.duration / av.time_base)
    raise ValueError(f"Could not determine the duration of video {video_path}")
2025-07-18 19:18:52 +09:00
class VideoEncodingManager:
    """
    Context manager that ensures proper video encoding and data cleanup even if exceptions occur.

    On exit it:
    - Cancels (on error) and closes the dataset's streaming encoder, when one is attached
    - Otherwise batch-encodes any episodes recorded since the last encoding
    - Calls ``dataset.finalize()`` to close the dataset writers, even if the steps above raise
    - On error in non-streaming mode, removes the image directories of the interrupted episode
    - Removes the dataset's ``images`` directory when it contains no PNG files

    Exceptions raised inside the ``with`` block are never suppressed.

    Args:
        dataset: The LeRobotDataset instance
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        streaming_encoder = getattr(self.dataset, "_streaming_encoder", None)

        try:
            if streaming_encoder is not None:
                # Handle streaming encoder cleanup
                if exc_type is not None:
                    streaming_encoder.cancel_episode()
                streaming_encoder.close()
            elif self.dataset.episodes_since_last_encoding > 0:
                # Handle any remaining episodes that haven't been batch encoded
                if exc_type is not None:
                    logger.info("Exception occurred. Encoding remaining episodes before exit...")
                else:
                    logger.info("Recording stopped. Encoding remaining episodes...")

                start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding
                end_ep = self.dataset.num_episodes
                logger.info(
                    f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
                    f"from episode {start_ep} to {end_ep - 1}"
                )
                self.dataset._batch_save_episode_video(start_ep, end_ep)
        finally:
            # Finalize the dataset to properly close all writers, even if encoding or
            # encoder shutdown raised.
            self.dataset.finalize()

        # Clean up episode images if recording was interrupted (only for non-streaming mode)
        if exc_type is not None and streaming_encoder is None:
            interrupted_episode_index = self.dataset.num_episodes
            for key in self.dataset.meta.video_keys:
                img_dir = self.dataset._get_image_file_path(
                    episode_index=interrupted_episode_index, image_key=key, frame_index=0
                ).parent
                if img_dir.exists():
                    logger.debug(
                        f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
                    )
                    shutil.rmtree(img_dir)

        # Clean up any remaining images directory if it's empty
        img_dir = self.dataset.root / "images"
        if img_dir.exists():
            png_files = list(img_dir.rglob("*.png"))
            if len(png_files) == 0:
                # Only remove the images directory if no PNG files remain
                shutil.rmtree(img_dir)
                logger.debug("Cleaned up empty images directory")
            else:
                logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")

        return False  # Don't suppress the original exception