2024-05-15 12:13:09 +02:00
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2024-05-03 00:50:19 +02:00
import logging
import subprocess
import warnings
from dataclasses import dataclass , field
from pathlib import Path
from typing import Any , ClassVar
import pyarrow as pa
import torch
import torchvision
from datasets . features . features import register_feature
def load_from_videos(
    item: dict[str, torch.Tensor], video_frame_keys: list[str], videos_dir: Path, tolerance_s: float
):
    """Replace the video frame references found in `item` by the decoded frames.

    For every key of `video_frame_keys`, `item[key]` is either one reference (a mapping with a
    "path" and a "timestamp" entry) or a list of such references. A single reference is replaced
    by one decoded frame, a list of references by the stacked decoded frames. `item` is modified
    in place and also returned.

    Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
    in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in
    a Segmentation Fault. This probably happens because a memory reference to the video loader is
    created in the main process and a subprocess fails to access it.

    Raises:
        NotImplementedError: if the references listed under one key point to different videos.
    """
    # Stored paths already start with "videos" (e.g. videos_dir="data/videos",
    # path="videos/episode_0.mp4"), so they are resolved against the parent directory.
    root_dir = videos_dir.parent

    for key in video_frame_keys:
        entry = item[key]

        if not isinstance(entry, list):
            # a single frame is requested
            query_ts = [entry["timestamp"]]
            clip_path = root_dir / entry["path"]
            decoded = decode_video_frames_torchvision(clip_path, query_ts, tolerance_s)
            item[key] = decoded[0]
            continue

        # several frames are requested at once (expected when delta_timestamps is not None)
        query_ts = [ref["timestamp"] for ref in entry]
        clip_paths = [ref["path"] for ref in entry]
        distinct_paths = set(clip_paths)
        if len(distinct_paths) > 1:
            raise NotImplementedError("All video paths are expected to be the same for now.")

        clip_path = root_dir / clip_paths[0]
        item[key] = decode_video_frames_torchvision(clip_path, query_ts, tolerance_s)

    return item
def decode_video_frames_torchvision(
    video_path: str,
    timestamps: list[float],
    tolerance_s: float,
    device: str = "cpu",
    log_loaded_timestamps: bool = False,
):
    """Loads frames associated to the requested timestamps of a video

    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
    and all subsequent frames until reaching the requested frame. The number of key frames in a video
    can be adjusted during encoding to take into account decoding time and video size in bytes.

    Args:
        video_path: Path of the video file (a `pathlib.Path` is accepted, it is converted with `str`).
        timestamps: Timestamps of the requested frames. They do not need to be sorted; the returned
            frames follow the order of this list.
        tolerance_s: Maximum accepted distance between a requested timestamp and the timestamp of the
            closest decoded frame.
        device: Only "cpu" is supported for now.
        log_loaded_timestamps: If True, log the timestamp of every decoded frame and of the selected ones.

    Returns:
        A float32 tensor with one frame per requested timestamp, scaled by 1/255.

    Raises:
        NotImplementedError: if `device` is "cuda".
        ValueError: if `device` is neither "cpu" nor "cuda", or if `timestamps` is empty.
        AssertionError: if a requested timestamp has no decoded frame within `tolerance_s`.
    """
    video_path = str(video_path)

    # set backend
    keyframes_only = False
    if device == "cpu":
        # explicitly use pyav
        torchvision.set_video_backend("pyav")
        keyframes_only = True  # pyav doesn't support accurate seek
    elif device == "cuda":
        # TODO(rcadene, aliberts): implement video decoding with GPU
        # torchvision.set_video_backend("cuda")
        # torchvision.set_video_backend("video_reader")
        # requires installing torchvision from source, see: https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
        # check possible bug: https://github.com/pytorch/vision/issues/7745
        raise NotImplementedError(
            "Video decoding on gpu with cuda is currently not supported. Use `device='cpu'`."
        )
    else:
        raise ValueError(device)

    # Fail early, before any file is opened, instead of with an opaque error further down.
    if len(timestamps) == 0:
        raise ValueError("At least one timestamp is required to decode video frames.")

    # set the first and last requested timestamps (min/max so that unsorted queries are supported)
    # Note: previous timestamps are usually loaded, since we need to access the previous key frame
    first_ts = min(timestamps)
    last_ts = max(timestamps)

    # set a video stream reader
    # TODO(rcadene): also load audio stream at the same time
    reader = torchvision.io.VideoReader(video_path, "video")

    loaded_frames = []
    loaded_ts = []
    try:
        # access closest key frame of the first requested frame
        # Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
        # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
        reader.seek(first_ts, keyframes_only=keyframes_only)

        # load all frames until last requested frame
        for frame in reader:
            current_ts = frame["pts"]
            if log_loaded_timestamps:
                logging.info(f"frame loaded at timestamp={current_ts:.4f}")
            loaded_frames.append(frame["data"])
            loaded_ts.append(current_ts)
            if current_ts >= last_ts:
                break
    finally:
        # always release the underlying container, even when seeking or decoding fails
        reader.container.close()
    reader = None

    query_ts = torch.tensor(timestamps)
    loaded_ts = torch.tensor(loaded_ts)

    # compute distances between each query timestamp and timestamps of all loaded frames
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    is_within_tol = min_ < tolerance_s
    if not is_within_tol.all():
        # Raised explicitly (rather than through `assert`) so that the check also runs under `python -O`.
        raise AssertionError(
            f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
            "It means that the closest frame that can be loaded from the video is too far away in time."
            "This might be due to synchronization issues with timestamps during data collection."
            "To be safe, we advise to ignore this item during training."
        )

    # get closest frames to the query timestamps
    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
    closest_ts = loaded_ts[argmin_]

    if log_loaded_timestamps:
        logging.info(f"{closest_ts=}")

    # convert to the pytorch format which is float32 in [0,1] range (and channel first)
    closest_frames = closest_frames.type(torch.float32) / 255

    assert len(timestamps) == len(closest_frames)
    return closest_frames
def encode_video_frames(imgs_dir: Path, video_path: Path, fps: int):
    """Encode the images `frame_%06d.png` found in `imgs_dir` into a video with ffmpeg.

    More info on ffmpeg arguments tuning on `lerobot/common/datasets/_video_benchmark/README.md`

    Args:
        imgs_dir: Directory containing the input images named `frame_000000.png`, `frame_000001.png`, ...
        video_path: Path of the video to write. Missing parent directories are created.
        fps: Frame rate passed to ffmpeg through `-r`.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    imgs_dir = Path(imgs_dir)
    video_path = Path(video_path)
    video_path.parent.mkdir(parents=True, exist_ok=True)

    # The command is built as a list of arguments (not as a string split on spaces),
    # so that paths containing spaces are passed to ffmpeg as a single argument.
    ffmpeg_cmd = [
        "ffmpeg",
        "-r",
        str(fps),
        "-f",
        "image2",
        "-loglevel",
        "error",
        "-i",
        str(imgs_dir / "frame_%06d.png"),
        "-vcodec",
        "libx264",
        "-g",
        "2",
        "-pix_fmt",
        "yuv444p",
        str(video_path),
    ]
    subprocess.run(ffmpeg_cmd, check=True)
@dataclass
class VideoFrame:
    """
    Feature type for dataset columns that reference video frames.

    A frame is described by a struct with two entries: the "path" of the video file
    and the "timestamp" of the frame. Calling an instance returns the corresponding
    pyarrow type.

    Example:

    ```python
    data_dict = {"image": [{"path": "videos/episode_0.mp4", "timestamp": 0.3}]}
    features = {"image": VideoFrame()}
    Dataset.from_dict(data_dict, features=Features(features))
    ```
    """

    # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo

    # pyarrow storage type shared by all instances
    pa_type: ClassVar[Any] = pa.struct([("path", pa.string()), ("timestamp", pa.float32())])
    # name of the feature; fixed, hence excluded from `__init__` and `__repr__`
    _type: str = field(default="VideoFrame", init=False, repr=False)

    def __call__(self):
        return self.pa_type
# Register VideoFrame to make it available in HuggingFace `datasets`. The warning emitted
# about the experimental status of `register_feature` is silenced for this call only.
with warnings.catch_warnings():
    warnings.filterwarnings(
        action="ignore",
        message="'register_feature' is experimental and might be subject to breaking changes in the future.",
        category=UserWarning,
    )
    register_feature(VideoFrame, "VideoFrame")