2024-08-08 20:19:06 +03:00
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.

Note: The last frame of an episode doesn't always correspond to a final state.
That's because our datasets are composed of transitions from one state to the next: the last
frame holds the state that precedes the final state, together with the last action, which leads
to the final state. However, there might not be a transition from a final state to another state.

Note: This script aims to visualize the data used to train the neural networks.
~What you see is what you get~. When visualizing the image modality, it is expected to observe
lossy compression artifacts, since these images have been decoded from compressed mp4 videos to
save disk space. The compression factor applied has been tuned to not affect the success rate.

Examples of usage:

- Visualize data stored on a local machine:
```bash
local$ python lerobot/scripts/visualize_dataset_html.py \
    --repo-id lerobot/pusht

local$ open http://localhost:9090
```

- Visualize data stored on a distant machine with a local viewer:
```bash
distant$ python lerobot/scripts/visualize_dataset_html.py \
    --repo-id lerobot/pusht

local$ ssh -L 9090:localhost:9090 distant  # create a ssh tunnel
local$ open http://localhost:9090
```

- Select episodes to visualize:
```bash
python lerobot/scripts/visualize_dataset_html.py \
    --repo-id lerobot/pusht \
    --episodes 7 3 5 1 4
```
"""
import argparse
2024-12-20 16:26:23 +01:00
import csv
import json
2024-08-08 20:19:06 +03:00
import logging
2024-12-20 16:26:23 +01:00
import re
2024-08-08 20:19:06 +03:00
import shutil
2024-12-20 16:26:23 +01:00
import tempfile
from io import StringIO
2024-08-08 20:19:06 +03:00
from pathlib import Path
2024-12-20 16:26:23 +01:00
import numpy as np
import pandas as pd
import requests
from flask import Flask , redirect , render_template , request , url_for
2024-08-08 20:19:06 +03:00
2024-12-20 16:26:23 +01:00
from lerobot import available_datasets
2024-08-08 20:19:06 +03:00
from lerobot . common . datasets . lerobot_dataset import LeRobotDataset
2024-12-20 16:26:23 +01:00
from lerobot . common . datasets . utils import IterableNamespace
2024-08-08 20:19:06 +03:00
from lerobot . common . utils . utils import init_logging
def run_server(
    dataset: LeRobotDataset | IterableNamespace | None,
    episodes: list[int] | None,
    host: str,
    port: int,
    static_folder: Path,
    template_folder: Path,
):
    """Create the Flask app that renders the dataset visualization pages and run it.

    Args:
        dataset: A local ``LeRobotDataset``, the hub metadata returned by ``get_dataset_info``, or
            ``None``. When ``None``, the homepage lets the user pick a dataset and each episode page
            fetches that dataset's metadata from the Hugging Face Hub.
        episodes: Episode indices passed to the episode page template. ``None`` means every episode
            of the dataset.
        host: Host interface passed to ``app.run``.
        port: Port passed to ``app.run``.
        static_folder: Folder used as Flask's static folder (local videos are served from it).
        template_folder: Folder containing the Jinja templates.

    ``app.run`` blocks, so this function only returns once the server stops.
    """
    # Seconds to wait on the Hugging Face Hub before giving up, so that a stalled connection
    # cannot hang a request handler forever.
    hub_timeout_s = 30

    app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
    app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # specifying not to cache

    @app.route("/")
    def hommepage(dataset=dataset):
        # A dataset was given on the command line: go straight to its first episode.
        if dataset is not None:
            dataset_namespace, dataset_name = dataset.repo_id.split("/")
            return redirect(
                url_for(
                    "show_episode",
                    dataset_namespace=dataset_namespace,
                    dataset_name=dataset_name,
                    episode_id=0,
                )
            )

        # `?dataset=<namespace>/<name>&episode=<int>` select what to show.
        # `type=int` makes Flask return None (instead of raising) for a non-integer value.
        dataset_param = request.args.get("dataset")
        episode_param = request.args.get("episode", type=int)
        if dataset_param:
            dataset_namespace, _, dataset_name = dataset_param.partition("/")
            if not dataset_namespace or not dataset_name or "/" in dataset_name:
                return f"Invalid dataset '{dataset_param}', expected '<namespace>/<name>'.", 400
            return redirect(
                url_for(
                    "show_episode",
                    dataset_namespace=dataset_namespace,
                    dataset_name=dataset_name,
                    episode_id=episode_param if episode_param is not None else 0,
                )
            )

        featured_datasets = [
            "lerobot/aloha_static_cups_open",
            "lerobot/columbia_cairlab_pusht_real",
            "lerobot/taco_play",
        ]
        return render_template(
            "visualize_dataset_homepage.html",
            featured_datasets=featured_datasets,
            lerobot_datasets=available_datasets,
        )

    @app.route("/<string:dataset_namespace>/<string:dataset_name>")
    def show_first_episode(dataset_namespace, dataset_name):
        first_episode_id = 0
        return redirect(
            url_for(
                "show_episode",
                dataset_namespace=dataset_namespace,
                dataset_name=dataset_name,
                episode_id=first_episode_id,
            )
        )

    @app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
    def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
        repo_id = f"{dataset_namespace}/{dataset_name}"
        if dataset is None:
            try:
                dataset = get_dataset_info(repo_id)
            except (FileNotFoundError, requests.exceptions.HTTPError):
                # `get_dataset_info` raises HTTPError (not FileNotFoundError) when the hub has no
                # `meta/info.json` for this repo, e.g. for datasets older than v2.
                return (
                    "Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461",
                    400,
                )
        is_local = isinstance(dataset, LeRobotDataset)

        dataset_version = str(dataset.meta._version) if is_local else dataset.codebase_version
        match = re.search(r"v(\d+)\.", dataset_version)
        if match and int(match.group(1)) < 2:
            return "Make sure to convert your LeRobotDataset to v2 & above.", 400

        total_episodes = dataset.num_episodes if is_local else dataset.total_episodes
        if episode_id >= total_episodes:
            return f"Episode {episode_id} does not exist in {repo_id} ({total_episodes} episodes).", 404

        episode_data_csv_str, columns = get_episode_data(dataset, episode_id)
        dataset_info = {
            "repo_id": repo_id,
            "num_samples": dataset.num_frames if is_local else dataset.total_frames,
            "num_episodes": total_episodes,
            "fps": dataset.fps,
        }

        if is_local:
            video_paths = [
                dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
            ]
            videos_info = [
                {"url": url_for("static", filename=video_path), "filename": video_path.parent.name}
                for video_path in video_paths
            ]
            tasks = dataset.meta.episodes[episode_id]["tasks"]
        else:
            video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
            videos_info = [
                {
                    "url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
                    + dataset.video_path.format(
                        episode_chunk=int(episode_id) // dataset.chunks_size,
                        video_key=video_key,
                        episode_index=episode_id,
                    ),
                    "filename": video_key,
                }
                for video_key in video_keys
            ]
            response = requests.get(
                f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl",
                timeout=hub_timeout_s,
            )
            response.raise_for_status()
            # meta/episodes.jsonl holds one JSON object per non-empty line; keep this episode's row.
            episode_rows = (json.loads(line) for line in response.text.splitlines() if line.strip())
            episode_row = next((row for row in episode_rows if row["episode_index"] == episode_id), None)
            if episode_row is None:
                return f"Episode {episode_id} not found in the metadata of {repo_id}.", 404
            tasks = episode_row["tasks"]

        # Tasks are attached to the first video entry; datasets without any video feature have no
        # entry to attach them to (indexing [0] would raise IndexError).
        if videos_info:
            videos_info[0]["language_instruction"] = tasks

        if episodes is None:
            episodes = list(range(total_episodes))

        return render_template(
            "visualize_dataset_template.html",
            episode_id=episode_id,
            episodes=episodes,
            dataset_info=dataset_info,
            videos_info=videos_info,
            episode_data_csv_str=episode_data_csv_str,
            columns=columns,
        )

    app.run(host=host, port=port)
def get_ep_csv_fname(episode_id: int):
    """Return the CSV file name associated with the episode of index ``episode_id``."""
    return f"episode_{episode_id}.csv"
def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
    """Build a CSV string holding the timeseries data (e.g. state and action) of one episode.

    The CSV is loaded by the Dygraph javascript of the page to plot the data in real time.

    Args:
        dataset: A local ``LeRobotDataset`` or the hub metadata namespace from ``get_dataset_info``.
        episode_index: Index of the episode to export.

    Returns:
        A ``(csv_string, columns)`` tuple. The CSV has a ``timestamp`` column followed by one column
        per dimension of every other float32 feature. ``columns`` holds one
        ``{"key": feature_name, "value": [column names]}`` entry per plotted feature.

    Raises:
        ValueError: if the dataset has no float32 ``timestamp`` feature.
    """
    is_local = isinstance(dataset, LeRobotDataset)

    plotted_features = [key for key, ft in dataset.features.items() if ft["dtype"] == "float32"]
    # The timestamp gets its own leading CSV column instead of being plotted as a feature.
    plotted_features.remove("timestamp")

    columns = []
    header = ["timestamp"]
    for feature in plotted_features:
        num_dims = dataset.meta.shapes[feature][0] if is_local else dataset.features[feature].shape[0]
        feature_info = dataset.features[feature]
        names = feature_info["names"] if "names" in feature_info else None
        if names:
            # Names may be nested inside dicts: descend into the first value until reaching a list.
            while not isinstance(names, list):
                names = list(names.values())[0]
        else:
            names = [f"{feature}_{i}" for i in range(num_dims)]
        columns.append({"key": feature, "value": names})
        header += names

    csv_columns = ["timestamp", *plotted_features]
    if is_local:
        start = dataset.episode_data_index["from"][episode_index]
        stop = dataset.episode_data_index["to"][episode_index]
        episode_frames = dataset.hf_dataset.select(range(start, stop))
        data = episode_frames.select_columns(csv_columns).with_format("pandas")
    else:
        parquet_path = dataset.data_path.format(
            episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
        )
        url = f"https://huggingface.co/datasets/{dataset.repo_id}/resolve/main/" + parquet_path
        data = pd.read_parquet(url)[csv_columns]

    # One row per frame: the timestamp followed by the flattened values of each plotted feature.
    timestamps = np.expand_dims(data["timestamp"], axis=1)
    feature_blocks = [np.vstack(data[feature]) for feature in plotted_features]
    rows = np.hstack((timestamps, *feature_blocks)).tolist()

    buffer = StringIO()
    writer = csv.writer(buffer)
    writer.writerow(header)
    writer.writerows(rows)
    return buffer.getvalue(), columns
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
    """Return, for each video key of the dataset, the video path stored for episode ``ep_index``.

    The path is read from the first frame of the episode (hack to get the video path of the episode).
    """
    start_frame = dataset.episode_data_index["from"][ep_index].item()
    paths = []
    for video_key in dataset.meta.video_keys:
        frame = dataset.hf_dataset.select_columns(video_key)[start_frame]
        paths.append(frame[video_key]["path"])
    return paths
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> str | None:
    """Return the language instruction of episode ``ep_index``.

    The instruction is read from the first frame of the episode.

    Returns:
        The instruction string, or ``None`` when the dataset has no ``language_instruction`` feature.
    """
    # check if the dataset has language instructions
    if "language_instruction" not in dataset.features:
        return None

    # get first frame index
    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()

    language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
    # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
    # with the tf.tensor appearing in the string
    return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
def get_dataset_info(repo_id: str, timeout: float | None = 30.0) -> IterableNamespace:
    """Fetch ``meta/info.json`` of a Hugging Face Hub dataset as an ``IterableNamespace``.

    Args:
        repo_id: Hub dataset id, e.g. ``lerobot/pusht``.
        timeout: Seconds to wait for the hub, forwarded to ``requests.get``. ``None`` waits forever.

    Returns:
        The fields of ``info.json`` plus a ``repo_id`` field set to ``repo_id``.

    Raises:
        requests.exceptions.HTTPError: if the hub answers with an error status (e.g. the repo has no
            ``meta/info.json``).
        requests.exceptions.Timeout: if the hub does not answer within ``timeout``.
    """
    response = requests.get(
        f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=timeout
    )
    response.raise_for_status()  # Raises an HTTPError for bad responses
    dataset_info = response.json()
    dataset_info["repo_id"] = repo_id
    return IterableNamespace(dataset_info)
def visualize_dataset_html(
    dataset: LeRobotDataset | IterableNamespace | None,
    episodes: list[int] | None = None,
    output_dir: Path | None = None,
    serve: bool = True,
    host: str = "127.0.0.1",
    port: int = 9090,
    force_override: bool = False,
) -> Path | None:
    """Prepare the folder served by the visualization web app and optionally start the server.

    Args:
        dataset: A local ``LeRobotDataset``, the hub metadata returned by ``get_dataset_info``, or
            ``None`` to pick a dataset from the homepage.
        episodes: Episode indices passed to the web app; ``None`` means all episodes. Ignored when
            ``dataset`` is ``None``.
        output_dir: Directory in which the ``static`` folder is created. When ``None``, a new
            directory is created with ``tempfile.mkdtemp``; it is NOT deleted afterwards.
        serve: Whether to launch the web server (a blocking call).
        host: Host interface of the web server.
        port: Port of the web server.
        force_override: Delete ``output_dir`` before use if it already exists.
    """
    init_logging()

    template_dir = Path(__file__).resolve().parent.parent / "templates"

    if output_dir is None:
        # `mkdtemp` only creates a fresh, empty directory; nothing removes it later.
        output_dir = Path(tempfile.mkdtemp(prefix="lerobot_visualize_dataset_"))
    else:
        output_dir = Path(output_dir)
        if output_dir.exists():
            if force_override:
                shutil.rmtree(output_dir)
            else:
                logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")

    output_dir.mkdir(parents=True, exist_ok=True)
    static_dir = output_dir / "static"
    static_dir.mkdir(parents=True, exist_ok=True)

    if isinstance(dataset, LeRobotDataset):
        # Create a symlink from the dataset video folder containing mp4 files to the output directory
        # so that the http server can get access to the mp4 files.
        ln_videos_dir = static_dir / "videos"
        # `exists()` follows symlinks: a dangling link left by a previous run reports False and would
        # make `symlink_to` raise FileExistsError, so remove it first.
        if ln_videos_dir.is_symlink() and not ln_videos_dir.exists():
            ln_videos_dir.unlink()
        if not ln_videos_dir.exists():
            ln_videos_dir.symlink_to((dataset.root / "videos").resolve())

    if serve:
        run_server(
            dataset=dataset,
            # Episode selection only makes sense for a specific dataset.
            episodes=episodes if dataset is not None else None,
            host=host,
            port=port,
            static_folder=static_dir,
            template_folder=template_dir,
        )
def main():
    """Parse the command-line arguments, load the requested dataset and launch the visualization."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--repo-id",
        type=str,
        default=None,
        help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht). If omitted, the homepage lets you pick a dataset.",
    )
    parser.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
    )
    parser.add_argument(
        "--load-from-hf-hub",
        type=int,
        default=0,
        help="Load videos and parquet files from HF Hub rather than local system.",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="*",
        default=None,
        help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        help="Directory path to write html files and kickoff a web server. By default a new temporary directory is created.",
    )
    parser.add_argument(
        "--serve",
        type=int,
        default=1,
        help="Launch web server.",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Web host used by the http server.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=9090,
        help="Web port used by the http server.",
    )
    parser.add_argument(
        "--force-override",
        type=int,
        default=0,
        help="Delete the output directory if it exists already.",
    )

    args = parser.parse_args()
    # The remaining entries of `kwargs` map one-to-one to `visualize_dataset_html` keyword arguments.
    kwargs = vars(args)
    repo_id = kwargs.pop("repo_id")
    load_from_hf_hub = kwargs.pop("load_from_hf_hub")
    root = kwargs.pop("root")

    dataset = None
    if repo_id:
        dataset = get_dataset_info(repo_id) if load_from_hf_hub else LeRobotDataset(repo_id, root=root)

    visualize_dataset_html(dataset, **kwargs)
# Run the command-line entry point only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()