2024-05-15 12:13:09 +02:00
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2025-02-25 15:27:29 +01:00
import contextlib
2024-12-09 10:32:25 +00:00
import importlib . resources
2024-04-25 12:23:12 +02:00
import json
2024-11-29 19:04:00 +01:00
import logging
2024-12-20 16:26:23 +01:00
from collections . abc import Iterator
2024-11-29 19:04:00 +01:00
from itertools import accumulate
2024-04-23 14:13:25 +02:00
from pathlib import Path
2024-11-29 19:04:00 +01:00
from pprint import pformat
2024-12-20 16:26:23 +01:00
from types import SimpleNamespace
2024-11-29 19:04:00 +01:00
from typing import Any
2024-03-01 14:59:05 +01:00
2024-04-18 11:43:16 +02:00
import datasets
2024-11-29 19:04:00 +01:00
import jsonlines
import numpy as np
2025-02-25 15:27:29 +01:00
import packaging . version
2024-03-31 15:05:25 +00:00
import torch
2024-11-29 19:04:00 +01:00
from datasets . table import embed_table_storage
from huggingface_hub import DatasetCard , DatasetCardData , HfApi
2025-02-28 14:36:20 +01:00
from huggingface_hub . errors import RevisionNotFoundError
2024-04-23 14:13:25 +02:00
from PIL import Image as PILImage
from torchvision import transforms
2025-02-25 15:27:29 +01:00
from lerobot . common . datasets . backward_compatibility import (
V21_MESSAGE ,
BackwardCompatibilityError ,
ForwardCompatibilityError ,
)
2024-11-29 19:04:00 +01:00
from lerobot . common . robot_devices . robots . utils import Robot
2025-02-25 15:27:29 +01:00
from lerobot . common . utils . utils import is_valid_numpy_dtype_string
2025-01-31 13:57:37 +01:00
from lerobot . configs . types import DictLike , FeatureType , PolicyFeature
2024-11-29 19:04:00 +01:00
DEFAULT_CHUNK_SIZE = 1000  # Max number of episodes per chunk

# Locations of the metadata files, relative to a dataset's root directory.
INFO_PATH = "meta/info.json"
EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json"
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
TASKS_PATH = "meta/tasks.jsonl"

# Path templates with named replacement fields (chunk, episode and frame indices are zero-padded).
# NOTE(review): presumably filled in with `str.format` by callers outside this section — confirm.
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
2024-08-16 10:08:44 +02:00
DATASET_CARD_TEMPLATE = """
- - -
# Metadata will go there
- - -
2024-09-25 16:56:05 +02:00
This dataset was created using [ LeRobot ] ( https : / / github . com / huggingface / lerobot ) .
2024-08-16 10:08:44 +02:00
2024-11-29 19:04:00 +01:00
## {}
2024-08-16 10:08:44 +02:00
"""
2024-11-29 19:04:00 +01:00
DEFAULT_FEATURES = {
" timestamp " : { " dtype " : " float32 " , " shape " : ( 1 , ) , " names " : None } ,
" frame_index " : { " dtype " : " int64 " , " shape " : ( 1 , ) , " names " : None } ,
" episode_index " : { " dtype " : " int64 " , " shape " : ( 1 , ) , " names " : None } ,
" index " : { " dtype " : " int64 " , " shape " : ( 1 , ) , " names " : None } ,
" task_index " : { " dtype " : " int64 " , " shape " : ( 1 , ) , " names " : None } ,
}
2024-04-23 14:13:25 +02:00
2024-11-29 19:04:00 +01:00
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
    """Collapse a nested dictionary into a single-level one whose keys are joined paths.

    The keys met along each path are joined with `sep`. When `parent_key` is non-empty it is
    prepended (followed by `sep`) to every key produced by this call.

    For example:
    ```
    >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
    >>> print(flatten_dict(dct))
    {"a/b": 1, "a/c/d": 2, "e": 3}
    ```

    Note that an empty nested dictionary contributes no entry to the result.
    """
    flat = {}
    for name, child in d.items():
        full_key = f"{parent_key}{sep}{name}" if parent_key else name
        if not isinstance(child, dict):
            flat[full_key] = child
            continue
        # Recurse into the sub-dictionary, carrying the accumulated key prefix.
        flat.update(flatten_dict(child, full_key, sep=sep))
    return flat
2024-11-29 19:04:00 +01:00
def unflatten_dict(d: dict, sep: str = "/") -> dict:
    """Rebuild a nested dictionary from a flat one whose keys are `sep`-joined paths.

    For example, `{"a/b": 1, "e": 3}` becomes `{"a": {"b": 1}, "e": 3}`.
    """
    nested = {}
    for flat_key, value in d.items():
        *parents, leaf = flat_key.split(sep)
        node = nested
        # Walk down the path, creating intermediate dictionaries on demand.
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return nested
2025-01-31 13:57:37 +01:00
def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
    """Fetch a value from a nested, subscriptable object using a `sep`-joined key path.

    Each component of `flattened_key` is used, in order, to index one level deeper into `obj`.
    """
    current = obj
    for name in flattened_key.split(sep):
        current = current[name]
    return current
2024-11-29 19:04:00 +01:00
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
    """Convert the leaves of a (possibly nested) dictionary into plain Python values.

    Tensors and numpy arrays become lists, numpy scalars become Python scalars, and
    `int`/`float` values are kept unchanged. The nesting of the input is preserved.

    Raises:
        NotImplementedError: If a leaf value is of any other type.
    """
    flat_serialized = {}
    for key, value in flatten_dict(stats).items():
        if isinstance(value, (torch.Tensor, np.ndarray)):
            converted = value.tolist()
        elif isinstance(value, np.generic):
            converted = value.item()
        elif isinstance(value, (int, float)):
            converted = value
        else:
            raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
        flat_serialized[key] = converted
    return unflatten_dict(flat_serialized)
2025-02-25 15:27:29 +01:00
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
    """Embed image bytes into the dataset's table, to be done before saving to parquet.

    The dataset is temporarily switched to the "arrow" format for the embedding step, and
    the format it had on entry is restored on the returned dataset.
    """
    original_format = dataset.format
    arrow_dataset = dataset.with_format("arrow")
    embedded = arrow_dataset.map(embed_table_storage, batched=False)
    return embedded.with_format(**original_format)
2024-11-29 19:04:00 +01:00
def load_json(fpath: Path) -> Any:
    """Read and parse the JSON file at `fpath`.

    Args:
        fpath: Path of the JSON file to read.

    Returns:
        The decoded JSON content.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale encoding,
    # which is not UTF-8 everywhere and would garble non-ASCII content.
    with open(fpath, encoding="utf-8") as f:
        return json.load(f)
def write_json(data: dict, fpath: Path) -> None:
    """Write `data` as indented JSON to `fpath`, creating parent directories if needed.

    Args:
        data: JSON-serializable content to write.
        fpath: Destination file path; it is overwritten if it already exists.
    """
    fpath.parent.mkdir(exist_ok=True, parents=True)
    # `ensure_ascii=False` emits non-ASCII characters verbatim, so the file must be opened
    # with an encoding that can represent them; the locale default cannot always do so.
    with open(fpath, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
def load_jsonlines(fpath: Path) -> list[Any]:
    """Read the JSON Lines file at `fpath` and return all of its items as a list."""
    with jsonlines.open(fpath, "r") as reader:
        records = list(reader)
    return records
def write_jsonlines(data: dict, fpath: Path) -> None:
    """Write `data` to a JSON Lines file at `fpath`, replacing any existing file.

    Parent directories are created when missing.
    NOTE(review): `data` is handed to `writer.write_all`, so it is presumably an iterable of
    items (one per output line) rather than a single dict as annotated — confirm with callers.
    """
    target_dir = fpath.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    with jsonlines.open(fpath, "w") as writer:
        writer.write_all(data)
def append_jsonlines(data: dict, fpath: Path) -> None:
    """Append `data` as one item to the JSON Lines file at `fpath`.

    Parent directories are created when missing; the file is opened in append mode.
    """
    target_dir = fpath.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    with jsonlines.open(fpath, "a") as writer:
        writer.write(data)
2025-02-25 15:27:29 +01:00
def write_info(info: dict, local_dir: Path):
    """Save the `info` dictionary as JSON at `INFO_PATH` under `local_dir`."""
    info_file = local_dir / INFO_PATH
    write_json(info, info_file)
2024-11-29 19:04:00 +01:00
def load_info(local_dir: Path) -> dict:
    """Load the info dictionary stored at `INFO_PATH` under `local_dir`.

    The "shape" entry of every feature is converted to a tuple before returning.
    """
    info = load_json(local_dir / INFO_PATH)
    features = info["features"]
    for feature_name in features:
        # JSON has no tuple type, so shapes are read back as lists.
        features[feature_name]["shape"] = tuple(features[feature_name]["shape"])
    return info
2025-02-25 15:27:29 +01:00
def write_stats(stats: dict, local_dir: Path):
    """Serialize `stats` with `serialize_dict` and save it as JSON at `STATS_PATH` under `local_dir`."""
    stats_file = local_dir / STATS_PATH
    write_json(serialize_dict(stats), stats_file)
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
    """Convert every leaf of a (possibly nested) stats dictionary to a numpy array.

    The nesting of the input is preserved in the returned dictionary.
    """
    flat_arrays = {}
    for key, value in flatten_dict(stats).items():
        flat_arrays[key] = np.array(value)
    return unflatten_dict(flat_arrays)
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
    """Load the stats stored at `STATS_PATH` under `local_dir` as numpy arrays.

    Returns:
        The stats dictionary with numpy array leaves, or `None` when the stats file
        does not exist.
    """
    stats_file = local_dir / STATS_PATH
    if stats_file.exists():
        return cast_stats_to_numpy(load_json(stats_file))
    return None
def write_task(task_index: int, task: dict, local_dir: Path):
    """Append one task record to the file at `TASKS_PATH` under `local_dir`.

    The record has two keys, "task_index" and "task".
    NOTE(review): `task` is annotated as a dict, but `load_tasks` uses the stored tasks as
    dictionary keys, so it is presumably a string — confirm with callers.
    """
    record = {"task_index": task_index, "task": task}
    append_jsonlines(record, local_dir / TASKS_PATH)
2024-11-29 19:04:00 +01:00
2025-02-25 15:27:29 +01:00
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
    """Load the task records stored at `TASKS_PATH` under `local_dir`.

    Returns:
        A pair `(tasks, task_to_task_index)`: the first maps each task index to its task,
        ordered by increasing task index; the second is the reverse mapping.
    """
    records = load_jsonlines(local_dir / TASKS_PATH)
    tasks = {}
    for record in sorted(records, key=lambda rec: rec["task_index"]):
        tasks[record["task_index"]] = record["task"]
    # Built from the final `tasks` mapping so both directions stay consistent.
    task_to_task_index = {}
    for task_index, task in tasks.items():
        task_to_task_index[task] = task_index
    return tasks, task_to_task_index
def write_episode(episode: dict, local_dir: Path):
    """Append the `episode` record to the file at `EPISODES_PATH` under `local_dir`."""
    episodes_file = local_dir / EPISODES_PATH
    append_jsonlines(episode, episodes_file)
2024-11-29 19:04:00 +01:00
def load_episodes(local_dir: Path) -> dict:
    """Load the episode records stored at `EPISODES_PATH` under `local_dir`.

    Returns:
        A dictionary mapping each episode index to its full record, ordered by
        increasing episode index.
    """
    records = load_jsonlines(local_dir / EPISODES_PATH)
    episodes = {}
    for record in sorted(records, key=lambda rec: rec["episode_index"]):
        episodes[record["episode_index"]] = record
    return episodes
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
    """Append the stats of one episode to the file at `EPISODES_STATS_PATH` under `local_dir`.

    The stats are serialized with `serialize_dict` and stored under a "stats" key next to
    the integer "episode_index".
    """
    # The stats are wrapped in their own entry since `episode_stats["episode_index"]`
    # is itself a dictionary of stats and not an integer.
    record = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
    append_jsonlines(record, local_dir / EPISODES_STATS_PATH)
def load_episodes_stats(local_dir: Path) -> dict:
    """Load the per-episode stats stored at `EPISODES_STATS_PATH` under `local_dir`.

    Returns:
        A dictionary mapping each episode index to its stats converted to numpy arrays,
        ordered by increasing episode index.
    """
    records = load_jsonlines(local_dir / EPISODES_STATS_PATH)
    episodes_stats = {}
    for record in sorted(records, key=lambda rec: rec["episode_index"]):
        episodes_stats[record["episode_index"]] = cast_stats_to_numpy(record["stats"])
    return episodes_stats
def backward_compatible_episodes_stats(
    stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
) -> dict[str, dict[str, np.ndarray]]:
    """Build per-episode stats by assigning the same `stats` to every episode in `episodes`.

    The returned dictionary is keyed by episode index and all of its values are the very
    same `stats` object (it is not copied).
    """
    return dict.fromkeys(episodes, stats)
2024-11-29 19:04:00 +01:00
2025-02-25 15:27:29 +01:00
def load_image_as_numpy(
    fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
) -> np.ndarray:
    """Load the image at `fpath` as an RGB numpy array.

    Args:
        fpath: Path of the image file.
        dtype: dtype of the returned array. For floating dtypes the values are divided by 255.
        channel_first: If True, the array is transposed from (H, W, C) to (C, H, W).

    Returns:
        The image as a numpy array of the requested dtype.
    """
    rgb_image = PILImage.open(fpath).convert("RGB")
    pixels = np.array(rgb_image, dtype=dtype)
    if channel_first:
        # (H, W, C) -> (C, H, W)
        pixels = np.transpose(pixels, (2, 0, 1))
    if np.issubdtype(dtype, np.floating):
        pixels /= 255.0
    return pixels
2024-05-30 16:12:21 +01:00
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
    """Convert a batch of items from a Hugging Face dataset (pyarrow) to torch tensors.

    The kind of each column is decided from its first element:
    - PIL images, which correspond to a channel last representation (h w c) of uint8 type,
      are converted to a torch image representation with channel first (c h w) of float32
      type in range [0, 1];
    - a column whose first element is None is left untouched;
    - otherwise every element is converted with `torch.tensor`, except strings which are
      kept as they are.

    `items_dict` is modified in place and also returned.
    """
    for key in items_dict:
        column = items_dict[key]
        head = column[0]
        if head is None:
            continue
        if isinstance(head, PILImage.Image):
            to_tensor = transforms.ToTensor()
            items_dict[key] = [to_tensor(img) for img in column]
        else:
            items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in column]
    return items_dict
2025-02-25 15:27:29 +01:00
def is_valid_version(version: str) -> bool:
    """Tell whether `version` can be parsed by `packaging.version.parse`."""
    try:
        packaging.version.parse(version)
    except packaging.version.InvalidVersion:
        return False
    else:
        return True
2024-11-29 19:04:00 +01:00
2025-02-25 15:27:29 +01:00
def check_version_compatibility(
    repo_id: str,
    version_to_check: str | packaging.version.Version,
    current_version: str | packaging.version.Version,
    enforce_breaking_major: bool = True,
) -> None:
    """Check a dataset version against the current codebase version.

    Args:
        repo_id: Identifier of the dataset repository, used in the error and the warning.
        version_to_check: Version of the dataset, as a string or a parsed `Version`.
        current_version: Version to compare against, as a string or a parsed `Version`.
        enforce_breaking_major: Whether an older major version raises an error.

    Raises:
        BackwardCompatibilityError: If the major of `version_to_check` is lower than the
            major of `current_version` and `enforce_breaking_major` is True.
    """
    v_check = (
        version_to_check
        if isinstance(version_to_check, packaging.version.Version)
        else packaging.version.parse(version_to_check)
    )
    v_current = (
        current_version
        if isinstance(current_version, packaging.version.Version)
        else packaging.version.parse(current_version)
    )
    if v_check.major < v_current.major and enforce_breaking_major:
        raise BackwardCompatibilityError(repo_id, v_check)
    elif v_check.major <= v_current.major and v_check.minor < v_current.minor:
        # Minor numbers are only comparable when the checked version is not from a newer
        # major: e.g. 3.0 checked against 2.1 must not be reported as an outdated format.
        logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
2024-11-29 19:04:00 +01:00
2025-02-25 15:27:29 +01:00
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
    """Returns available valid versions (branches and tags) on given repo."""
    refs = HfApi().list_repo_refs(repo_id, repo_type="dataset")
    ref_names = [ref.name for ref in refs.branches + refs.tags]

    versions = []
    for ref_name in ref_names:
        try:
            parsed = packaging.version.parse(ref_name)
        except packaging.version.InvalidVersion:
            # Refs whose name is not a valid version are ignored.
            continue
        versions.append(parsed)
    return versions
2024-11-29 19:04:00 +01:00
2025-02-25 15:27:29 +01:00
def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
    """
    Returns the version if available on repo or the latest compatible one.
    Otherwise, will throw a `CompatibilityError`.

    A compatible version has the same major as the requested one and a minor that is not
    greater. The returned string is the selected version prefixed with "v".

    Raises:
        RevisionNotFoundError: If the repo has no ref that parses as a version.
        BackwardCompatibilityError: If only versions with a lower major are available.
        ForwardCompatibilityError: If only versions greater than the requested one are available.
    """
    if isinstance(version, packaging.version.Version):
        target_version = version
    else:
        target_version = packaging.version.parse(version)

    hub_versions = get_repo_versions(repo_id)
    if not hub_versions:
        raise RevisionNotFoundError(
            f"""Your dataset must be tagged with a codebase version.
            Assuming _version_ is the codebase_version value in the info.json, you can run this:
            ```python
            from huggingface_hub import HfApi
            hub_api = HfApi()
            hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset")
            ```
            """
        )

    # Exact match: nothing else to decide.
    if target_version in hub_versions:
        return f"v{target_version}"

    same_major_not_newer = [
        candidate
        for candidate in hub_versions
        if candidate.major == target_version.major and candidate.minor <= target_version.minor
    ]
    if same_major_not_newer:
        return_version = max(same_major_not_newer)
        if return_version < target_version:
            logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
        return f"v{return_version}"

    older_majors = [candidate for candidate in hub_versions if candidate.major < target_version.major]
    if older_majors:
        raise BackwardCompatibilityError(repo_id, max(older_majors))

    newer_versions = [candidate for candidate in hub_versions if candidate > target_version]
    assert len(newer_versions) > 0
    raise ForwardCompatibilityError(repo_id, min(newer_versions))
2024-07-16 23:02:31 +02:00
2024-11-29 19:04:00 +01:00
def get_hf_features_from_features(features: dict) -> datasets.Features:
    """Translate a features dictionary into a `datasets.Features` object.

    The mapping depends on each feature's "dtype" and "shape":
    - "video" features are skipped;
    - "image" features become `datasets.Image`;
    - shape `(1,)` becomes a `datasets.Value`;
    - other 1-dimensional shapes become a fixed-length `datasets.Sequence` of `datasets.Value`;
    - 2 to 5-dimensional shapes become `datasets.Array2D` to `datasets.Array5D`.

    Raises:
        ValueError: If the shape of a feature has none of the supported numbers of dimensions.
    """
    hf_features = {}
    for key, ft in features.items():
        dtype = ft["dtype"]
        if dtype == "video":
            continue
        if dtype == "image":
            hf_features[key] = datasets.Image()
            continue

        shape = ft["shape"]
        if shape == (1,):
            hf_features[key] = datasets.Value(dtype=dtype)
            continue

        ndim = len(shape)
        if ndim == 1:
            hf_features[key] = datasets.Sequence(length=shape[0], feature=datasets.Value(dtype=dtype))
            continue

        array_types = {
            2: datasets.Array2D,
            3: datasets.Array3D,
            4: datasets.Array4D,
            5: datasets.Array5D,
        }
        if ndim not in array_types:
            raise ValueError(f"Corresponding feature is not valid: {ft}")
        hf_features[key] = array_types[ndim](shape=shape, dtype=dtype)

    return datasets.Features(hf_features)
2024-03-01 14:59:05 +01:00
2024-11-29 19:04:00 +01:00
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
    """Assemble the features dictionary of a dataset recorded with `robot`.

    The result merges, in this order, the robot's motor features, its camera features
    (only when `robot.cameras` is truthy) and `DEFAULT_FEATURES`. Camera features receive a
    "dtype" of "video" or "image" depending on `use_videos`, unless the camera feature
    already defines its own "dtype".
    """
    camera_ft = {}
    if robot.cameras:
        visual_dtype = "video" if use_videos else "image"
        for key, ft in robot.camera_features.items():
            camera_ft[key] = {"dtype": visual_dtype, **ft}
    return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
2025-01-31 13:57:37 +01:00
def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]:
    """Derive policy features from dataset features.

    Image and video features are VISUAL, "observation.environment_state" is ENV, other keys
    starting with "observation" are STATE and "action" is ACTION. Any other feature is left
    out of the result.

    Raises:
        ValueError: If an image or video feature does not have exactly 3 dimensions.
    """
    # TODO(aliberts): Implement "type" in dataset features and simplify this
    policy_features = {}
    for key, ft in features.items():
        shape = ft["shape"]
        if ft["dtype"] in ["image", "video"]:
            feature_type = FeatureType.VISUAL
            if len(shape) != 3:
                raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")

            # Backward compatibility for "channel" which is an error introduced in
            # LeRobotDataset v2.0 for ported datasets.
            names = ft["names"]
            if names[2] in ["channel", "channels"]:  # (h, w, c) -> (c, h, w)
                height, width, channels = shape[0], shape[1], shape[2]
                shape = (channels, height, width)
        elif key == "observation.environment_state":
            feature_type = FeatureType.ENV
        elif key.startswith("observation"):
            feature_type = FeatureType.STATE
        elif key == "action":
            feature_type = FeatureType.ACTION
        else:
            continue

        policy_features[key] = PolicyFeature(type=feature_type, shape=shape)
    return policy_features
2024-11-29 19:04:00 +01:00
def create_empty_dataset_info(
    codebase_version: str,
    fps: int,
    robot_type: str,
    features: dict,
    use_videos: bool,
) -> dict:
    """Build the info dictionary of a dataset that does not contain any data yet.

    All counters start at 0 and "splits" is empty. "video_path" is only set when
    `use_videos` is True, otherwise it is None.
    """
    video_path = DEFAULT_VIDEO_PATH if use_videos else None
    info = {
        "codebase_version": codebase_version,
        "robot_type": robot_type,
        "total_episodes": 0,
        "total_frames": 0,
        "total_tasks": 0,
        "total_videos": 0,
        "total_chunks": 0,
        "chunks_size": DEFAULT_CHUNK_SIZE,
        "fps": fps,
        "splits": {},
        "data_path": DEFAULT_PARQUET_PATH,
        "video_path": video_path,
        "features": features,
    }
    return info
2024-04-25 12:23:12 +02:00
2024-11-29 19:04:00 +01:00
def get_episode_data_index (
2025-02-25 15:27:29 +01:00
episode_dicts : dict [ dict ] , episodes : list [ int ] | None = None
2024-11-29 19:04:00 +01:00
) - > dict [ str , torch . Tensor ] :
2025-02-25 15:27:29 +01:00
episode_lengths = { ep_idx : ep_dict [ " length " ] for ep_idx , ep_dict in episode_dicts . items ( ) }
2024-11-29 19:04:00 +01:00
if episodes is not None :
episode_lengths = { ep_idx : episode_lengths [ ep_idx ] for ep_idx in episodes }
2024-04-25 12:23:12 +02:00
2025-02-25 23:51:15 +01:00
cumulative_lengths = list ( accumulate ( episode_lengths . values ( ) ) )
2024-11-29 19:04:00 +01:00
return {
2025-02-25 23:51:15 +01:00
" from " : torch . LongTensor ( [ 0 ] + cumulative_lengths [ : - 1 ] ) ,
" to " : torch . LongTensor ( cumulative_lengths ) ,
2024-11-29 19:04:00 +01:00
}
2024-04-25 12:23:12 +02:00
2024-05-03 00:50:19 +02:00
2024-11-29 19:04:00 +01:00
def check_timestamps_sync(
    timestamps: np.ndarray,
    episode_indices: np.ndarray,
    episode_data_index: dict[str, np.ndarray],
    fps: int,
    tolerance_s: float,
    raise_value_error: bool = True,
) -> bool:
    """
    Verify that consecutive timestamps are separated by (1/fps) +/- tolerance, the tolerance
    accounting for possible numerical error. Differences that span two episodes are ignored.

    Args:
        timestamps (np.ndarray): Array of timestamps in seconds.
        episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
        episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
            which identifies indices for the end of each episode.
        fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
        tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
        raise_value_error (bool): Whether to raise a ValueError if the check fails.

    Returns:
        bool: True if all checked timestamp differences lie within tolerance, False otherwise.

    Raises:
        ValueError: If `timestamps` and `episode_indices` differ in shape, or if the check
            fails and `raise_value_error` is True.
    """
    if timestamps.shape != episode_indices.shape:
        raise ValueError(
            "timestamps and episode_indices should have the same shape. "
            f"Found {timestamps.shape=} and {episode_indices.shape=}."
        )

    step_diffs = np.diff(timestamps)
    is_within = np.abs(step_diffs - (1.0 / fps)) <= tolerance_s

    # The difference at the last frame of each episode (except the final one) links
    # two different episodes and must not be checked.
    checked = np.ones(len(step_diffs), dtype=bool)
    episode_boundaries = episode_data_index["to"][:-1] - 1
    checked[episode_boundaries] = False

    # Positions, in the unmasked array, of the checked differences that violate the tolerance.
    violating = np.flatnonzero(checked & ~is_within)
    if violating.size == 0:
        return True

    outside_tolerances = [
        {
            "timestamps": [timestamps[idx], timestamps[idx + 1]],
            "diff": step_diffs[idx],
            "episode_index": episode_indices[idx].item()
            if hasattr(episode_indices[idx], "item")
            else episode_indices[idx],
        }
        for idx in violating
    ]

    if raise_value_error:
        raise ValueError(
            f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
            This might be due to synchronization issues during data collection.
            \n{pformat(outside_tolerances)}"""
        )
    return False
def check_delta_timestamps(
    delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
    """Check that every value in `delta_timestamps` is a multiple of 1/fps +/- tolerance.

    This ensures that these delta_timestamps, added to any timestamp from a dataset, are
    themselves actual timestamps from the dataset.

    Returns:
        True if all values are within tolerance, False otherwise.

    Raises:
        ValueError: If some values are outside of tolerance and `raise_value_error` is True.
    """

    def is_within(ts: float) -> bool:
        return abs(ts * fps - round(ts * fps)) / fps <= tolerance_s

    outside_tolerance = {}
    for key, delta_ts in delta_timestamps.items():
        offending = [ts for ts in delta_ts if not is_within(ts)]
        if offending:
            outside_tolerance[key] = offending

    if not outside_tolerance:
        return True

    if raise_value_error:
        raise ValueError(
            f"""
            The following delta_timestamps are found outside of tolerance range.
            Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
            their values accordingly.
            \n{pformat(outside_tolerance)}
            """
        )
    return False
2024-05-20 22:04:04 +10:00
2024-11-29 19:04:00 +01:00
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
    """Convert delta timestamps (in seconds) to frame offsets by rounding `delta * fps`."""
    return {key: [round(delta * fps) for delta in delta_ts] for key, delta_ts in delta_timestamps.items()}
2024-05-20 22:04:04 +10:00
2024-04-05 10:59:32 +00:00
def cycle(iterable):
    """The equivalent of itertools.cycle, but safe for Pytorch dataloaders.

    See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.

    A new iterator is requested from `iterable` each time the previous one is exhausted,
    so the items are never cached by this generator.
    """
    while True:
        for item in iterable:
            yield item
2024-08-15 18:11:33 +02:00
2024-11-29 19:04:00 +01:00
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
    """Create a branch on an existing Hugging Face repo.

    If a branch with the same name already exists, it is deleted first and then re-created.
    """
    api = HfApi()
    target_ref = f"refs/heads/{branch}"
    existing_branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches

    if any(existing.ref == target_ref for existing in existing_branches):
        api.delete_branch(repo_id, repo_type=repo_type, branch=branch)

    api.create_branch(repo_id, repo_type=repo_type, branch=branch)
2024-08-16 10:08:44 +02:00
2024-11-29 19:04:00 +01:00
def create_lerobot_dataset_card(
    tags: list | None = None,
    dataset_info: dict | None = None,
    **kwargs,
) -> DatasetCard:
    """
    Create the dataset card of a LeRobot dataset from the `card_template.md` template.

    Keyword arguments will be used to replace values in ./lerobot/common/datasets/card_template.md.
    Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.

    Args:
        tags: Extra tags, added after the default "LeRobot" tag.
        dataset_info: If given, it is dumped as JSON and passed to the template as
            `dataset_structure`.
    """
    card_tags = ["LeRobot", *tags] if tags else ["LeRobot"]

    if dataset_info:
        info_json = json.dumps(dataset_info, indent=4)
        dataset_structure = "[meta/info.json](meta/info.json):\n" + f"```json\n{info_json}\n```\n"
        kwargs = {**kwargs, "dataset_structure": dataset_structure}

    default_config = {
        "config_name": "default",
        "data_files": "data/*/*.parquet",
    }
    card_data = DatasetCardData(
        license=kwargs.get("license"),
        tags=card_tags,
        task_categories=["robotics"],
        configs=[default_config],
    )

    template_file = importlib.resources.files("lerobot.common.datasets") / "card_template.md"
    card_template = template_file.read_text()

    return DatasetCard.from_template(
        card_data=card_data,
        template_str=card_template,
        **kwargs,
    )
2024-12-20 16:26:23 +01:00
class IterableNamespace(SimpleNamespace):
    """
    A namespace whose attributes can also be read and iterated like a dictionary.

    On top of SimpleNamespace, this class provides:
    - Iteration over the attribute names
    - Access to values via both dot notation (obj.key) and brackets (obj["key"])
    - Dictionary-like methods: items(), keys(), values()
    - Recursive conversion of the dictionaries nested in `dictionary` into IterableNamespaces

    Args:
        dictionary: Optional dictionary to initialize the namespace
        **kwargs: Additional keyword arguments passed to SimpleNamespace

    Examples:
        >>> data = {"name": "Alice", "details": {"age": 25}}
        >>> ns = IterableNamespace(data)
        >>> ns.name
        'Alice'
        >>> ns.details.age
        25
        >>> list(ns.keys())
        ['name', 'details']
        >>> for key, value in ns.items():
        ...     print(f"{key}: {value}")
        name: Alice
        details: IterableNamespace(age=25)
    """

    def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
        super().__init__(**kwargs)
        if dictionary is None:
            return
        for key, value in dictionary.items():
            # Nested dictionaries are wrapped so that dot access works at every level.
            attribute = IterableNamespace(value) if isinstance(value, dict) else value
            setattr(self, key, attribute)

    def __iter__(self) -> Iterator[str]:
        return iter(self.__dict__)

    def __getitem__(self, key: str) -> Any:
        return self.__dict__[key]

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def keys(self):
        return self.__dict__.keys()
2025-02-25 15:27:29 +01:00
def validate_frame(frame: dict, features: dict):
    """Validate the keys and values of `frame` against the dataset `features`.

    A frame is expected to hold every feature except the ones in `DEFAULT_FEATURES`, plus a
    "task" string; "timestamp" is optional. All problems found are gathered in one message.

    Raises:
        ValueError: If any feature is missing, unexpected, or has an invalid value.
    """
    optional_features = {"timestamp"}
    expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
    actual_features = set(frame.keys())

    problems = [validate_features_presence(actual_features, expected_features, optional_features)]

    if "task" in frame:
        problems.append(validate_feature_string("task", frame["task"]))

    # Only the features that are both present and known can be checked further.
    common_features = actual_features & (expected_features | optional_features)
    for name in common_features - {"task"}:
        problems.append(validate_feature_dtype_and_shape(name, features[name], frame[name]))

    error_message = "".join(problems)
    if error_message:
        raise ValueError(error_message)
def validate_features_presence(
    actual_features: set[str], expected_features: set[str], optional_features: set[str]
):
    """Report the features that are missing from, or unexpected in, a frame.

    Returns:
        An empty string if every expected feature is present and no feature outside of the
        expected and optional ones is found, otherwise a message describing the mismatch.
    """
    missing_features = expected_features - actual_features
    extra_features = actual_features - (expected_features | optional_features)
    if not (missing_features or extra_features):
        return ""

    lines = ["Feature mismatch in `frame` dictionary:\n"]
    if missing_features:
        lines.append(f"Missing features: {missing_features}\n")
    if extra_features:
        lines.append(f"Extra features: {extra_features}\n")
    return "".join(lines)
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
    """Validate `value` against the "dtype" and "shape" declared in `feature`.

    The check is delegated to the validator matching the declared dtype: numpy dtypes,
    "image"/"video", or "string".

    Returns:
        The error message produced by the selected validator (empty if the value is valid).

    Raises:
        NotImplementedError: If the declared dtype is not one of the supported kinds.
    """
    expected_dtype = feature["dtype"]
    expected_shape = feature["shape"]

    if is_valid_numpy_dtype_string(expected_dtype):
        return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
    if expected_dtype in ["image", "video"]:
        return validate_feature_image_or_video(name, expected_shape, value)
    if expected_dtype == "string":
        return validate_feature_string(name, value)
    raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
def validate_feature_numpy_array(
    name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
):
    """Check that `value` is a numpy array with the expected dtype and shape.

    Args:
        name: Name of the feature, used in the error messages.
        expected_dtype: Numpy dtype string the array must have.
        expected_shape: Shape the array must have, given as a list or a tuple.
        value: The object to validate.

    Returns:
        An empty string if `value` is valid, otherwise one line per problem found.
    """
    if not isinstance(value, np.ndarray):
        return f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"

    error_message = ""
    actual_dtype = value.dtype
    actual_shape = value.shape

    if actual_dtype != np.dtype(expected_dtype):
        error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"

    # `ndarray.shape` is always a tuple, and a tuple never equals a list: normalize the
    # expected shape so that a shape given as a list is compared by content.
    if actual_shape != tuple(expected_shape):
        error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"

    return error_message
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
    """Check that `value` is an acceptable image for the feature `name`.

    A numpy array must have the 3 expected dimensions, either in the order given by
    `expected_shape` or with its first dimension moved last. A PIL image is accepted
    without further checks. Any other type is reported as an error.

    Returns:
        An empty string if `value` is valid, otherwise the error message.
    """
    # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
    if isinstance(value, np.ndarray):
        actual_shape = value.shape
        c, h, w = expected_shape
        if actual_shape not in ((c, h, w), (h, w, c)):
            return f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
        return ""

    if isinstance(value, PILImage.Image):
        return ""

    return f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
def validate_feature_string(name: str, value: str):
    """Return an error message if `value` is not a `str`, and an empty string otherwise."""
    if isinstance(value, str):
        return ""
    return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
    """Check that `episode_buffer` is ready to be saved as the next episode of a dataset.

    Args:
        episode_buffer: Buffer holding the frames of the episode. It must contain the
            "size", "task" and "episode_index" keys, plus one key per feature.
        total_episodes: Number of episodes already in the dataset.
        features: Features of the dataset.

    Raises:
        ValueError: If "size" or "task" is missing, if the buffer holds no frame, or if the
            remaining keys of the buffer differ from the keys of `features`.
        NotImplementedError: If the buffer's "episode_index" differs from `total_episodes`.
    """
    if "size" not in episode_buffer:
        raise ValueError("size key not found in episode_buffer")

    if "task" not in episode_buffer:
        raise ValueError("task key not found in episode_buffer")

    if episode_buffer["episode_index"] != total_episodes:
        # TODO(aliberts): Add option to use existing episode_index
        raise NotImplementedError(
            "You might have manually provided the episode_buffer with an episode_index that doesn't "
            "match the total number of episodes already in the dataset. This is not supported for now."
        )

    if episode_buffer["size"] == 0:
        raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")

    # "task" and "size" are bookkeeping entries of the buffer, not features.
    buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
    feature_keys = set(features)
    if buffer_keys != feature_keys:
        # One line per part of the message, so that the sentences do not run together.
        raise ValueError(
            "Features from `episode_buffer` don't match the ones in `features`.\n"
            f"In episode_buffer not in features: {buffer_keys - feature_keys}\n"
            f"In features not in episode_buffer: {feature_keys - buffer_keys}"
        )