#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.

Note: The last frame of the episode doesn't always correspond to a final state.
That's because our datasets are composed of transitions from state to state up to
the antepenultimate state associated to the ultimate action to arrive in the final state.
However, there might not be a transition from a final state to another state.

Note: This script aims to visualize the data used to train the neural networks.
~What you see is what you get~. When visualizing image modality, it is often expected to observe
lossy compression artifacts since these images have been decoded from compressed mp4 videos to
save disk space. The compression factor applied has been tuned to not affect success rate.

Example of usage:

- Visualize data stored on a local machine:
```bash
local$ python lerobot/scripts/visualize_dataset_html.py \
    --repo-id lerobot/pusht

local$ open http://localhost:9090
```

- Visualize data stored on a distant machine with a local viewer:
```bash
distant$ python lerobot/scripts/visualize_dataset_html.py \
    --repo-id lerobot/pusht

local$ ssh -L 9090:localhost:9090 distant  # create a ssh tunnel
local$ open http://localhost:9090
```

- Select episodes to visualize:
```bash
python lerobot/scripts/visualize_dataset_html.py \
    --repo-id lerobot/pusht \
    --episodes 7 3 5 1 4
```
"""
import argparse
import logging
import shutil
from pathlib import Path
import tqdm
from flask import Flask , redirect , render_template , url_for
from lerobot . common . datasets . lerobot_dataset import LeRobotDataset
from lerobot . common . utils . utils import init_logging
def run_server(
    dataset: LeRobotDataset,
    episodes: list[int],
    host: str,
    port: int,
    static_folder: Path,
    template_folder: Path,
):
    """Serve the episode visualization pages with Flask (blocks until the server stops).

    Two routes are registered:
      * ``/`` redirects to the page of the first entry of ``episodes``.
      * ``/<namespace>/<name>/episode_<id>`` renders ``visualize_dataset_template.html``
        with the dataset summary, the episode videos and the url of the episode csv.

    Args:
        dataset: Dataset whose episodes are displayed.
        episodes: Episode indices handed to the template; the first one is the home page target.
        host: Host passed to ``Flask.run``.
        port: Port passed to ``Flask.run``.
        static_folder: Folder served as Flask static files (csv files and videos are
            looked up relative to it).
        template_folder: Folder containing ``visualize_dataset_template.html``.

    Raises:
        ValueError: If ``episodes`` is empty, since the home page would have nothing to
            redirect to.
    """
    if not episodes:
        raise ValueError("At least one episode is required to run the visualization server.")

    app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
    app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # specifying not to cache

    @app.route("/")
    def index():
        # home page redirects to the first episode page
        dataset_namespace, dataset_name = dataset.repo_id.split("/")
        first_episode_id = episodes[0]
        return redirect(
            url_for(
                "show_episode",
                dataset_namespace=dataset_namespace,
                dataset_name=dataset_name,
                episode_id=first_episode_id,
            )
        )

    @app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
    def show_episode(dataset_namespace, dataset_name, episode_id):
        # NOTE: dataset_namespace / dataset_name only shape the url; the page is always
        # built from the `dataset` object given to `run_server`.
        dataset_info = {
            "repo_id": dataset.repo_id,
            "num_samples": dataset.num_samples,
            "num_episodes": dataset.num_episodes,
            "fps": dataset.fps,
        }
        video_paths = get_episode_video_paths(dataset, episode_id)
        language_instruction = get_episode_language_instruction(dataset, episode_id)
        videos_info = [
            {"url": url_for("static", filename=video_path), "filename": Path(video_path).name}
            for video_path in video_paths
        ]
        # The instruction is attached to the first video entry; skip it when there is
        # no video entry to attach it to instead of raising an IndexError.
        if language_instruction and videos_info:
            videos_info[0]["language_instruction"] = language_instruction

        ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
        return render_template(
            "visualize_dataset_template.html",
            episode_id=episode_id,
            episodes=episodes,
            dataset_info=dataset_info,
            videos_info=videos_info,
            ep_csv_url=ep_csv_url,
            has_policy=False,
        )

    app.run(host=host, port=port)
def get_ep_csv_fname(episode_id: int) -> str:
    """Return the file name of the csv holding the timeseries data of episode `episode_id`."""
    return f"episode_{episode_id}.csv"
def write_episode_data_csv(output_dir: Path, file_name: str, episode_index: int, dataset: LeRobotDataset) -> None:
    """Write a csv file containing timeseries data of an episode (e.g. state and action).

    This file will be loaded by Dygraph javascript to plot data in real time.

    The csv has a `timestamp` column followed by `state_<i>` columns (when the dataset has an
    "observation.state" feature) and `action_<i>` columns (when it has an "action" feature),
    with one row per frame of the episode.

    Args:
        output_dir: Directory receiving the csv file; created if missing.
        file_name: Name of the csv file inside `output_dir`.
        episode_index: Index of the episode to export.
        dataset: Dataset providing `episode_data_index` and `hf_dataset`.
    """
    # int() accepts plain ints as well as integer scalar tensors.
    from_idx = int(dataset.episode_data_index["from"][episode_index])
    to_idx = int(dataset.episode_data_index["to"][episode_index])

    features = dataset.hf_dataset.features
    has_state = "observation.state" in features
    has_action = "action" in features

    columns = ["timestamp"]
    if has_state:
        columns.append("observation.state")
    if has_action:
        columns.append("action")
    data = dataset.hf_dataset.select_columns(columns)

    # init header of csv with state and action names
    header = ["timestamp"]
    if has_state or has_action:
        # Only one row is needed to know the vector sizes; indexing a whole column of the
        # dataset just to read its first element is needlessly slow on big datasets.
        first_frame = data[0]
        if has_state:
            header += [f"state_{i}" for i in range(len(first_frame["observation.state"]))]
        if has_action:
            header += [f"action_{i}" for i in range(len(first_frame["action"]))]

    lines = [",".join(header) + "\n"]
    for i in range(from_idx, to_idx):
        frame = data[i]  # fetch the frame once instead of once per column
        row = [frame["timestamp"].item()]
        if has_state:
            row += frame["observation.state"].tolist()
        if has_action:
            row += frame["action"].tolist()
        lines.append(",".join(str(col) for col in row) + "\n")

    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / file_name, "w", encoding="utf-8") as f:
        f.writelines(lines)
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
    """Return the video path of episode `ep_index` for each key in `dataset.video_frame_keys`.

    The path is read from the "path" entry stored with the first frame of the episode.
    """
    # get first frame of episode (hack to get video_path of the episode)
    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
    return [
        dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
        for key in dataset.video_frame_keys
    ]
2024-08-28 11:50:31 +02:00
def get_episode_language_instruction ( dataset : LeRobotDataset , ep_index : int ) - > list [ str ] :
# check if the dataset has language instructions
if " language_instruction " not in dataset . hf_dataset . features :
return None
# get first frame index
first_frame_idx = dataset . episode_data_index [ " from " ] [ ep_index ] . item ( )
language_instruction = dataset . hf_dataset [ first_frame_idx ] [ " language_instruction " ]
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
# with the tf.tensor appearing in the string
return language_instruction . removeprefix ( " tf.Tensor(b ' " ) . removesuffix ( " ' , shape=(), dtype=string) " )
def visualize_dataset_html(
    repo_id: str,
    root: Path | None = None,
    episodes: list[int] | None = None,
    output_dir: Path | None = None,
    serve: bool = True,
    host: str = "127.0.0.1",
    port: int = 9090,
    force_override: bool = False,
) -> Path | None:
    """Export the csv files needed by the html viewer and optionally serve the viewer.

    Args:
        repo_id: Identifier of the dataset, passed to `LeRobotDataset`.
        root: Optional root directory, passed to `LeRobotDataset`.
        episodes: Episode indices to export; all episodes of the dataset when None.
        output_dir: Where the `static` folder (csv files and a symlink to the videos) is
            written. Defaults to `outputs/visualize_dataset_html/<repo_id>`.
        serve: When true, start the web server once the csv files are written (blocking).
        host: Host of the web server.
        port: Port of the web server.
        force_override: When true, delete `output_dir` first if it already exists.

    Returns:
        Nothing is returned explicitly by the current implementation (i.e. None).

    Raises:
        NotImplementedError: If the dataset is not a video dataset (`dataset.video` is falsy).
    """
    init_logging()

    dataset = LeRobotDataset(repo_id, root=root)

    if not dataset.video:
        raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")

    if output_dir is None:
        output_dir = f"outputs/visualize_dataset_html/{repo_id}"

    output_dir = Path(output_dir)
    if output_dir.exists():
        if force_override:
            shutil.rmtree(output_dir)
        else:
            logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")

    output_dir.mkdir(parents=True, exist_ok=True)

    # Create a symlink from the dataset video folder containing mp4 files to the output directory
    # so that the http server can get access to the mp4 files.
    static_dir = output_dir / "static"
    static_dir.mkdir(parents=True, exist_ok=True)
    ln_videos_dir = static_dir / "videos"
    # `exists()` follows symlinks, so a dangling link left by a previous run looks absent
    # and `symlink_to` would then fail with FileExistsError: drop the stale link first.
    if ln_videos_dir.is_symlink() and not ln_videos_dir.exists():
        ln_videos_dir.unlink()
    if not ln_videos_dir.exists():
        ln_videos_dir.symlink_to(dataset.videos_dir.resolve())

    template_dir = Path(__file__).resolve().parent.parent / "templates"

    if episodes is None:
        episodes = list(range(dataset.num_episodes))

    logging.info("Writing CSV files")
    for episode_index in tqdm.tqdm(episodes):
        # write states and actions in a csv (it can be slow for big datasets)
        ep_csv_fname = get_ep_csv_fname(episode_index)
        # TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
        write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)

    if serve:
        run_server(dataset, episodes, host, port, static_dir, template_dir)
def main():
    """Parse the command line arguments and forward them to `visualize_dataset_html`.

    Each option name maps onto the keyword argument of the same name
    (e.g. `--repo-id` -> `repo_id`).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
    )
    parser.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="*",
        default=None,
        help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        # The help text names the default actually used by `visualize_dataset_html`.
        help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset_html/REPO_ID'.",
    )
    parser.add_argument(
        "--serve",
        type=int,
        default=1,
        help="Launch web server.",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Web host used by the http server.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=9090,
        help="Web port used by the http server.",
    )
    parser.add_argument(
        "--force-override",
        type=int,
        default=0,
        help="Delete the output directory if it exists already.",
    )

    args = parser.parse_args()
    visualize_dataset_html(**vars(args))
# Run the command line interface only when the file is executed as a script.
if __name__ == "__main__":
    main()