2024-05-15 12:13:09 +02:00
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2025-04-29 17:39:35 +02:00
import glob
2025-03-17 13:23:11 +01:00
import importlib
2024-05-03 00:50:19 +02:00
import logging
2025-07-18 19:18:52 +09:00
import shutil
2024-05-03 00:50:19 +02:00
import warnings
from dataclasses import dataclass , field
from pathlib import Path
from typing import Any , ClassVar
2025-04-29 17:39:35 +02:00
import av
2024-05-03 00:50:19 +02:00
import pyarrow as pa
import torch
import torchvision
from datasets . features . features import register_feature
2024-11-29 19:04:00 +01:00
from PIL import Image
2025-03-17 13:23:11 +01:00
def get_safe_default_codec() -> str:
    """Return the default video decoding backend usable on this platform.

    Returns:
        str: "torchcodec" when the `torchcodec` package can be located, otherwise "pyav"
            (a warning is logged in that case).
    """
    # A bare `import importlib` does not guarantee that the `importlib.util` submodule is loaded,
    # so import it explicitly instead of relying on another module having done it already.
    import importlib.util

    if importlib.util.find_spec("torchcodec"):
        return "torchcodec"

    logging.warning(
        "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
    )
    return "pyav"
2025-03-14 18:53:42 +03:00
def decode_video_frames (
video_path : Path | str ,
timestamps : list [ float ] ,
tolerance_s : float ,
2025-03-17 13:23:11 +01:00
backend : str | None = None ,
2025-03-14 18:53:42 +03:00
) - > torch . Tensor :
"""
Decodes video frames using the specified backend .
Args :
video_path ( Path ) : Path to the video file .
timestamps ( list [ float ] ) : List of timestamps to extract frames .
tolerance_s ( float ) : Allowed deviation in seconds for frame retrieval .
2025-03-17 13:23:11 +01:00
backend ( str , optional ) : Backend to use for decoding . Defaults to " torchcodec " when available in the platform ; otherwise , defaults to " pyav " . .
2025-03-14 18:53:42 +03:00
Returns :
torch . Tensor : Decoded frames .
Currently supports torchcodec on cpu and pyav .
"""
2025-03-17 13:23:11 +01:00
if backend is None :
backend = get_safe_default_codec ( )
2025-03-14 18:53:42 +03:00
if backend == " torchcodec " :
return decode_video_frames_torchcodec ( video_path , timestamps , tolerance_s )
elif backend in [ " pyav " , " video_reader " ] :
return decode_video_frames_torchvision ( video_path , timestamps , tolerance_s , backend )
else :
raise ValueError ( f " Unsupported video backend: { backend } " )
2024-05-03 00:50:19 +02:00
def decode_video_frames_torchvision(
    video_path: Path | str,
    timestamps: list[float],
    tolerance_s: float,
    backend: str = "pyav",
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """Loads frames associated to the requested timestamps of a video

    The backend can be either "pyav" (default) or "video_reader".
    "video_reader" requires installing torchvision from source, see:
    https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
    (note that you need to compile against ffmpeg<4.3)

    While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
    For more info on video decoding, see `benchmark/video/README.md`

    See torchvision doc for more info on these two backends:
    https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to take into account decoding time and video size in bytes.

    Args:
        video_path: Path to the video file (converted to `str` before being handed to torchvision).
        timestamps: Timestamps of the frames to retrieve; frames are decoded from `min(timestamps)`
            until a frame at or past `max(timestamps)` is reached.
        tolerance_s: Maximum allowed distance between a requested timestamp and the closest decoded frame.
        backend: torchvision video backend, passed to `torchvision.set_video_backend`.
        log_loaded_timestamps: When True, log the timestamp of every decoded frame and of the selected frames.

    Returns:
        torch.Tensor: One frame per requested timestamp (in request order), stacked along the first
            dimension, cast to float32 and divided by 255.

    Raises:
        AssertionError: If the closest decoded frame of at least one requested timestamp is not within
            `tolerance_s`.
    """
    video_path = str(video_path)

    # set backend
    keyframes_only = False
    torchvision.set_video_backend(backend)
    if backend == "pyav":
        keyframes_only = True  # pyav doesn't support accurate seek

    # set a video stream reader
    # TODO(rcadene): also load audio stream at the same time
    reader = torchvision.io.VideoReader(video_path, "video")

    # set the first and last requested timestamps
    # Note: previous timestamps are usually loaded, since we need to access the previous key frame
    first_ts = min(timestamps)
    last_ts = max(timestamps)

    loaded_frames = []
    loaded_ts = []
    try:
        # access closest key frame of the first requested frame
        # Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
        # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
        reader.seek(first_ts, keyframes_only=keyframes_only)

        # load all frames until last requested frame
        for frame in reader:
            current_ts = frame["pts"]
            if log_loaded_timestamps:
                logging.info(f"frame loaded at timestamp={current_ts:.4f}")
            loaded_frames.append(frame["data"])
            loaded_ts.append(current_ts)
            if current_ts >= last_ts:
                break
    finally:
        # Release the reader even when seeking or decoding raises, so the underlying container is not leaked.
        if backend == "pyav":
            reader.container.close()
        reader = None

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # compute distances between each query timestamp and timestamps of all loaded frames
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    assert is_within_tol.all(), (
        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
        "It means that the closest frame that can be loaded from the video is too far away in time."
        "This might be due to synchronization issues with timestamps during data collection."
        "To be safe, we advise to ignore this item during training."
        f"\nqueried timestamps: {query_ts}"
        f"\nloaded timestamps: {loaded_ts}"
        f"\nvideo: {video_path}"
        f"\nbackend: {backend}"
    )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # convert to the pytorch format which is float32 in [0,1] range (and channel first)
    # NOTE(review): the [0,1] range assumes the decoder yields uint8 frames -- confirm against torchvision.
    closest_frames = closest_frames.type(torch.float32) / 255

    assert len(timestamps) == len(closest_frames)
    return closest_frames
2025-03-14 18:53:42 +03:00
def decode_video_frames_torchcodec (
video_path : Path | str ,
timestamps : list [ float ] ,
tolerance_s : float ,
device : str = " cpu " ,
log_loaded_timestamps : bool = False ,
) - > torch . Tensor :
""" Loads frames associated with the requested timestamps of a video using torchcodec.
Note : Setting device = " cuda " outside the main process , e . g . in data loader workers , will lead to CUDA initialization errors .
Note : Video benefits from inter - frame compression . Instead of storing every frame individually ,
the encoder stores a reference frame ( or a key frame ) and subsequent frames as differences relative to
that key frame . As a consequence , to access a requested frame , we need to load the preceding key frame ,
and all subsequent frames until reaching the requested frame . The number of key frames in a video
can be adjusted during encoding to take into account decoding time and video size in bytes .
"""
2025-03-17 13:23:11 +01:00
if importlib . util . find_spec ( " torchcodec " ) :
from torchcodec . decoders import VideoDecoder
else :
raise ImportError ( " torchcodec is required but not available. " )
2025-03-14 18:53:42 +03:00
# initialize video decoder
decoder = VideoDecoder ( video_path , device = device , seek_mode = " approximate " )
loaded_frames = [ ]
loaded_ts = [ ]
# get metadata for frame information
metadata = decoder . metadata
average_fps = metadata . average_fps
# convert timestamps to frame indices
frame_indices = [ round ( ts * average_fps ) for ts in timestamps ]
# retrieve frames based on indices
frames_batch = decoder . get_frames_at ( indices = frame_indices )
for frame , pts in zip ( frames_batch . data , frames_batch . pts_seconds , strict = False ) :
loaded_frames . append ( frame )
loaded_ts . append ( pts . item ( ) )
if log_loaded_timestamps :
logging . info ( f " Frame loaded at timestamp= { pts : .4f } " )
query_ts = torch . tensor ( timestamps )
loaded_ts = torch . tensor ( loaded_ts )
# compute distances between each query timestamp and loaded timestamps
dist = torch . cdist ( query_ts [ : , None ] , loaded_ts [ : , None ] , p = 1 )
min_ , argmin_ = dist . min ( 1 )
is_within_tol = min_ < tolerance_s
assert is_within_tol . all ( ) , (
f " One or several query timestamps unexpectedly violate the tolerance ( { min_ [ ~ is_within_tol ] } > { tolerance_s =} ). "
" It means that the closest frame that can be loaded from the video is too far away in time. "
" This might be due to synchronization issues with timestamps during data collection. "
" To be safe, we advise to ignore this item during training. "
f " \n queried timestamps: { query_ts } "
f " \n loaded timestamps: { loaded_ts } "
f " \n video: { video_path } "
)
# get closest frames to the query timestamps
closest_frames = torch . stack ( [ loaded_frames [ idx ] for idx in argmin_ ] )
closest_ts = loaded_ts [ argmin_ ]
if log_loaded_timestamps :
logging . info ( f " { closest_ts =} " )
# convert to float32 in [0,1] range (channel first)
closest_frames = closest_frames . type ( torch . float32 ) / 255
assert len ( timestamps ) == len ( closest_frames )
return closest_frames
2024-07-09 20:20:25 +02:00
def encode_video_frames(
    imgs_dir: Path | str,
    video_path: Path | str,
    fps: int,
    vcodec: str = "libsvtav1",
    pix_fmt: str = "yuv420p",
    g: int | None = 2,
    crf: int | None = 30,
    fast_decode: int = 0,
    log_level: int | None = av.logging.ERROR,
    overwrite: bool = False,
) -> None:
    """Encode the `frame_XXXXXX.png` images of a directory into a single video file.

    More info on ffmpeg arguments tuning on `benchmark/video/README.md`

    Args:
        imgs_dir: Directory containing the input frames, named `frame_` followed by six digits and `.png`.
        video_path: Destination video file. Missing parent directories are created.
        fps: Frame rate given to the output stream.
        vcodec: Encoder name; one of "h264", "hevc" or "libsvtav1".
        pix_fmt: Pixel format of the output stream. "yuv444p" is replaced by "yuv420p" for
            "libsvtav1" and "hevc".
        g: Forwarded to the encoder as the `g` option when not None.
        crf: Forwarded to the encoder as the `crf` option when not None.
        fast_decode: When truthy, sets `svtav1-params=fast-decode=<value>` for "libsvtav1" and
            `tune=fastdecode` for the other codecs.
        log_level: Level applied to the Python logger named "libav" while encoding; None leaves
            logging untouched.
        overwrite: Whether an already existing `video_path` may be replaced.

    Raises:
        ValueError: If `vcodec` is not supported.
        FileExistsError: If `video_path` already exists and `overwrite` is False.
        FileNotFoundError: If no input frame is found in `imgs_dir`.
        OSError: If the output file does not exist after encoding.
    """
    # Check encoder availability
    if vcodec not in ["h264", "hevc", "libsvtav1"]:
        raise ValueError(f"Unsupported video codec: {vcodec}. Supported codecs are: h264, hevc, libsvtav1.")

    video_path = Path(video_path)
    imgs_dir = Path(imgs_dir)

    # `overwrite` guards the video file itself: an already existing output directory must not
    # make the call fail, otherwise a second video could never be written next to a first one.
    if video_path.exists() and not overwrite:
        raise FileExistsError(f"Video file already exists: {video_path}. Pass overwrite=True to replace it.")
    video_path.parent.mkdir(parents=True, exist_ok=True)

    # Encoders/pixel formats incompatibility check
    if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
        logging.warning(
            f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
        )
        pix_fmt = "yuv420p"

    # Get input frames, ordered by the frame number embedded in the file name
    template = "frame_" + ("[0-9]" * 6) + ".png"
    input_list = sorted(
        glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("_")[-1].split(".")[0])
    )

    # Define video output frame size (assuming all input frames are the same size)
    if not input_list:
        raise FileNotFoundError(f"No images found in {imgs_dir}.")
    with Image.open(input_list[0]) as first_image:  # close the probe image once its size is read
        width, height = first_image.size

    # Define video codec options
    video_options = {}

    if g is not None:
        video_options["g"] = str(g)

    if crf is not None:
        video_options["crf"] = str(crf)

    if fast_decode:
        key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
        value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
        video_options[key] = value

    # Set logging level
    if log_level is not None:
        # "While less efficient, it is generally preferable to modify logging with Python's logging"
        logging.getLogger("libav").setLevel(log_level)

    try:
        # Create and open output file (overwrite by default)
        with av.open(str(video_path), "w") as output:
            output_stream = output.add_stream(vcodec, fps, options=video_options)
            output_stream.pix_fmt = pix_fmt
            output_stream.width = width
            output_stream.height = height

            # Loop through input frames and encode them
            for input_data in input_list:
                with Image.open(input_data) as raw_image:
                    input_image = raw_image.convert("RGB")
                input_frame = av.VideoFrame.from_image(input_image)
                packet = output_stream.encode(input_frame)
                if packet:
                    output.mux(packet)

            # Flush the encoder
            packet = output_stream.encode()
            if packet:
                output.mux(packet)
    finally:
        # Reset logging level, also when encoding fails part-way
        if log_level is not None:
            av.logging.restore_default_callback()

    if not video_path.exists():
        raise OSError(f"Video encoding did not work. File not found: {video_path}.")
2024-08-15 18:11:33 +02:00
2024-05-03 00:50:19 +02:00
@dataclass
class VideoFrame:
    # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
    """
    Feature type for datasets whose items reference a video frame by file path and timestamp.

    Example:

    ```python
    data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
    features = {"image": VideoFrame()}
    Dataset.from_dict(data_dict, features=Features(features))
    ```
    """

    # Arrow storage type: a struct holding the video path (string) and the frame timestamp (float32).
    pa_type: ClassVar[Any] = pa.struct([("path", pa.string()), ("timestamp", pa.float32())])
    # Type tag stored on the instance; excluded from `__init__` and from the repr.
    _type: str = field(default="VideoFrame", init=False, repr=False)

    def __call__(self):
        """Return the Arrow type used to store this feature."""
        return self.pa_type
# Make VideoFrame available in HuggingFace `datasets`. The "experimental" UserWarning raised by
# `register_feature` is silenced only for the duration of this block.
with warnings.catch_warnings():
    warnings.filterwarnings(
        action="ignore",
        message="'register_feature' is experimental and might be subject to breaking changes in the future.",
        category=UserWarning,
    )
    register_feature(VideoFrame, "VideoFrame")
2024-11-29 19:04:00 +01:00
def get_audio_info(video_path: Path | str) -> dict:
    """Collect information about the first audio stream of a media file.

    Args:
        video_path: Path to the file to inspect.

    Returns:
        dict: `{"has_audio": False}` when the file has no audio stream. Otherwise a dict with the keys
            "audio.channels", "audio.codec", "audio.bit_rate", "audio.sample_rate", "audio.bit_depth",
            "audio.channel_layout" and "has_audio" (True).
    """
    # Set logging level
    logging.getLogger("libav").setLevel(av.logging.ERROR)

    try:
        # Getting audio stream information
        with av.open(str(video_path), "r") as audio_file:
            try:
                audio_stream = audio_file.streams.audio[0]
            except IndexError:
                return {"has_audio": False}

            return {
                "audio.channels": audio_stream.channels,
                "audio.codec": audio_stream.codec.canonical_name,
                # In an ideal lossless case: bit depth x sample rate x channels = bit rate.
                # In an actual compressed case, the bit rate is set according to the compression level:
                # the lower the bit rate, the more compression is applied.
                "audio.bit_rate": audio_stream.bit_rate,
                "audio.sample_rate": audio_stream.sample_rate,  # Number of samples per second
                # In an ideal lossless case: fixed number of bits per sample.
                # In an actual compressed case: variable number of bits per sample (often reduced to match a given depth rate).
                "audio.bit_depth": audio_stream.format.bits,
                "audio.channel_layout": audio_stream.layout.name,
                "has_audio": True,
            }
    finally:
        # Reset logging level on every exit path, including when the file cannot be opened or read
        av.logging.restore_default_callback()
2024-11-29 19:04:00 +01:00
def get_video_info(video_path: Path | str) -> dict:
    """Collect information about the first video stream of a file, plus its audio information.

    Args:
        video_path: Path to the video file to inspect.

    Returns:
        dict: An empty dict when the file has no video stream. Otherwise a dict with the keys
            "video.height", "video.width", "video.codec", "video.pix_fmt", "video.is_depth_map",
            "video.fps" and "video.channels", extended with the entries returned by
            `get_audio_info(video_path)`.

    Raises:
        ValueError: If the pixel format of the stream is not recognised by `get_video_pixel_channels`.
    """
    # Set logging level
    logging.getLogger("libav").setLevel(av.logging.ERROR)

    try:
        # Getting video stream information
        with av.open(str(video_path), "r") as video_file:
            try:
                video_stream = video_file.streams.video[0]
            except IndexError:
                return {}

            video_info = {
                "video.height": video_stream.height,
                "video.width": video_stream.width,
                "video.codec": video_stream.codec.canonical_name,
                "video.pix_fmt": video_stream.pix_fmt,
                "video.is_depth_map": False,
                # Calculate fps from r_frame_rate
                # NOTE(review): int() drops any fractional part of the rate (presumably a Fraction) -- confirm this is intended.
                "video.fps": int(video_stream.base_rate),
            }
            video_info["video.channels"] = get_video_pixel_channels(video_stream.pix_fmt)
    finally:
        # Reset logging level on every exit path, including when the pixel format lookup raises
        av.logging.restore_default_callback()

    # Adding audio stream information
    video_info.update(get_audio_info(video_path))

    return video_info
def get_video_pixel_channels(pix_fmt: str) -> int:
    """Return the number of pixel channels implied by a pixel format name.

    Args:
        pix_fmt: Pixel format name, e.g. "yuv420p", "rgb24", "rgba" or "gray".

    Returns:
        int: 1 for names containing "gray", "depth" or "monochrome"; 4 for RGB/BGR/YUV names carrying an
            alpha channel ("rgba", "argb", "bgra", "abgr", "yuva"); 3 for the remaining names containing
            "rgb", "bgr" or "yuv".

    Raises:
        ValueError: If the name matches none of the known families.
    """
    if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
        return 1
    # Alpha variants must be tested before the plain colour families: "rgba" also contains "rgb",
    # and "argb"/"abgr" would otherwise be reported as 3 channels.
    if any(tag in pix_fmt for tag in ("rgba", "argb", "bgra", "abgr", "yuva")):
        return 4
    if any(tag in pix_fmt for tag in ("rgb", "bgr", "yuv")):
        return 3
    raise ValueError(f"Unknown format: {pix_fmt}")
def get_image_pixel_channels(image: "Image.Image") -> int:
    """Return the number of channels of an image based on its `mode` attribute.

    Args:
        image: Image object exposing a `mode` attribute. The annotation names the `Image.Image` class
            (the bare `Image` name is the imported module) and is quoted so it is not evaluated at
            definition time.

    Returns:
        int: 1 for "L" (grayscale), 2 for "LA" (grayscale + alpha), 3 for "RGB", 4 for "RGBA".

    Raises:
        ValueError: If the mode is none of the above.
    """
    channels_by_mode = {
        "L": 1,  # Grayscale
        "LA": 2,  # Grayscale + Alpha
        "RGB": 3,  # RGB
        "RGBA": 4,  # RGBA
    }
    channels = channels_by_mode.get(image.mode)
    if channels is None:
        raise ValueError(f"Unknown format: {image.mode}")
    return channels
2025-07-18 19:18:52 +09:00
class VideoEncodingManager:
    """
    Context manager guaranteeing that pending videos get encoded and temporary data is cleaned up,
    whether the managed block finishes normally or raises.

    On exit it:
        - batch encodes the episodes recorded since the last encoding
        - deletes the image folders of the episode that was being recorded when an exception occurred
        - deletes the dataset `images` directory once it holds no PNG file anymore

    Args:
        dataset: The LeRobotDataset instance
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        dataset = self.dataset
        interrupted = exc_type is not None

        # Episodes recorded since the last batch encoding still have to be turned into videos
        pending = dataset.episodes_since_last_encoding
        if pending > 0:
            logging.info(
                "Exception occurred. Encoding remaining episodes before exit..."
                if interrupted
                else "Recording stopped. Encoding remaining episodes..."
            )

            end_ep = dataset.num_episodes
            start_ep = end_ep - pending
            logging.info(f"Encoding remaining {pending} episodes, from episode {start_ep} to {end_ep - 1}")
            dataset.batch_encode_videos(start_ep, end_ep)

        # An interruption leaves behind the frames of the episode that was being recorded
        if interrupted:
            interrupted_episode_index = dataset.num_episodes
            for camera_key in dataset.meta.video_keys:
                first_frame_path = dataset._get_image_file_path(
                    episode_index=interrupted_episode_index, image_key=camera_key, frame_index=0
                )
                episode_dir = first_frame_path.parent
                if not episode_dir.exists():
                    continue
                logging.debug(
                    f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {camera_key}"
                )
                shutil.rmtree(episode_dir)

        # Drop the top-level images directory once no PNG file is left in it
        images_root = dataset.root / "images"
        remaining_pngs = list(images_root.rglob("*.png"))
        if remaining_pngs:
            logging.debug(f"Images directory is not empty, containing {len(remaining_pngs)} PNG files")
        elif images_root.exists():
            shutil.rmtree(images_root)
            logging.debug("Cleaned up empty images directory")

        return False  # Don't suppress the original exception