2024-05-04 16:07:14 +02:00
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.

Note: The last frame of the episode doesn't always correspond to a final state.
That's because our datasets are composed of transitions from state to state up to
the antepenultimate state associated to the ultimate action to arrive in the final state.
However, there might not be a transition from a final state to another state.

Note: This script aims to visualize the data used to train the neural networks.
~What you see is what you get~. When visualizing image modality, it is often expected to observe
lossy compression artifacts since these images have been decoded from compressed mp4 videos to
save disk space. The compression factor applied has been tuned to not affect success rate.

Examples:

- Visualize data stored on a local machine:
```
local$ python lerobot/scripts/visualize_dataset.py \
    --repo-id lerobot/pusht \
    --episode-index 0
```

- Visualize data stored on a distant machine with a local viewer:
```
distant$ python lerobot/scripts/visualize_dataset.py \
    --repo-id lerobot/pusht \
    --episode-index 0 \
    --save 1 \
    --output-dir path/to/directory

local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
local$ rerun lerobot_pusht_episode_0.rrd
```

- Visualize data stored on a distant machine through streaming:
(You need to forward the websocket port to the distant machine, with
`ssh -L 9087:localhost:9087 username@remote-host`)
```
distant$ python lerobot/scripts/visualize_dataset.py \
    --repo-id lerobot/pusht \
    --episode-index 0 \
    --mode distant \
    --ws-port 9087

local$ rerun ws://localhost:9087
```
"""
import argparse
2024-05-11 19:28:22 +03:00
import gc
2024-03-06 10:14:03 +00:00
import logging
2024-05-04 16:07:14 +02:00
import time
2024-02-10 15:46:24 +00:00
from pathlib import Path
2024-05-04 16:07:14 +02:00
import rerun as rr
2024-02-10 15:46:24 +00:00
import torch
2024-05-04 16:07:14 +02:00
import tqdm
from lerobot . common . datasets . lerobot_dataset import LeRobotDataset
class EpisodeSampler(torch.utils.data.Sampler):
    """Sampler that yields, in increasing order, the dataset indices of one episode.

    The index bounds are read from `dataset.episode_data_index`: entry
    `episode_index` of its "from" tensor is the first index yielded and the
    matching entry of its "to" tensor is the exclusive upper bound.
    """

    def __init__(self, dataset, episode_index):
        bounds = dataset.episode_data_index
        first = bounds["from"][episode_index].item()
        stop = bounds["to"][episode_index].item()
        # A range is enough here: it is lazy, re-iterable and knows its own length
        self.frame_ids = range(first, stop)

    def __iter__(self):
        return iter(self.frame_ids)

    def __len__(self):
        return len(self.frame_ids)
def to_hwc_uint8_numpy(chw_float32_torch):
    """Convert a channel-first float32 image tensor to a channel-last uint8 numpy array.

    The input must be a 3-dimensional float32 tensor whose first dimension is
    smaller than the two others (this is how channel-first layout is checked).
    Values are multiplied by 255 and cast to uint8 before the axes are reordered
    from (c, h, w) to (h, w, c).
    """
    assert chw_float32_torch.dtype == torch.float32
    assert chw_float32_torch.ndim == 3
    channels, height, width = chw_float32_torch.shape
    assert (
        channels < height and channels < width
    ), f"expect channel first images, but instead {chw_float32_torch.shape}"

    scaled = chw_float32_torch * 255
    as_uint8 = scaled.type(torch.uint8)
    channel_last = as_uint8.permute(1, 2, 0)
    return channel_last.numpy()
def visualize_dataset (
repo_id : str ,
episode_index : int ,
batch_size : int = 32 ,
num_workers : int = 0 ,
mode : str = " local " ,
web_port : int = 9090 ,
ws_port : int = 9087 ,
save : bool = False ,
output_dir : Path | None = None ,
) - > Path | None :
if save :
assert (
output_dir is not None
) , " Set an output directory where to write .rrd files with `--output-dir path/to/directory`. "
logging . info ( " Loading dataset " )
dataset = LeRobotDataset ( repo_id )
logging . info ( " Loading dataloader " )
episode_sampler = EpisodeSampler ( dataset , episode_index )
2024-04-10 13:45:45 +00:00
dataloader = torch . utils . data . DataLoader (
dataset ,
2024-05-04 16:07:14 +02:00
num_workers = num_workers ,
batch_size = batch_size ,
sampler = episode_sampler ,
2024-04-10 13:45:45 +00:00
)
2024-05-04 16:07:14 +02:00
logging . info ( " Starting Rerun " )
if mode not in [ " local " , " distant " ] :
raise ValueError ( mode )
spawn_local_viewer = mode == " local " and not save
rr . init ( f " { repo_id } /episode_ { episode_index } " , spawn = spawn_local_viewer )
2024-05-11 19:28:22 +03:00
# Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
# when iterating on a dataloader with `num_workers` > 0
# TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
gc . collect ( )
2024-05-04 16:07:14 +02:00
if mode == " distant " :
rr . serve ( open_browser = False , web_port = web_port , ws_port = ws_port )
logging . info ( " Logging to Rerun " )
for batch in tqdm . tqdm ( dataloader , total = len ( dataloader ) ) :
# iterate over the batch
for i in range ( len ( batch [ " index " ] ) ) :
rr . set_time_sequence ( " frame_index " , batch [ " frame_index " ] [ i ] . item ( ) )
rr . set_time_seconds ( " timestamp " , batch [ " timestamp " ] [ i ] . item ( ) )
# display each camera image
2024-05-06 03:03:14 +02:00
for key in dataset . camera_keys :
2024-05-04 16:07:14 +02:00
# TODO(rcadene): add `.compress()`? is it lossless?
rr . log ( key , rr . Image ( to_hwc_uint8_numpy ( batch [ key ] [ i ] ) ) )
# display each dimension of action space (e.g. actuators command)
if " action " in batch :
for dim_idx , val in enumerate ( batch [ " action " ] [ i ] ) :
rr . log ( f " action/ { dim_idx } " , rr . Scalar ( val . item ( ) ) )
# display each dimension of observed state space (e.g. agent position in joint space)
if " observation.state " in batch :
for dim_idx , val in enumerate ( batch [ " observation.state " ] [ i ] ) :
rr . log ( f " state/ { dim_idx } " , rr . Scalar ( val . item ( ) ) )
if " next.done " in batch :
rr . log ( " next.done " , rr . Scalar ( batch [ " next.done " ] [ i ] . item ( ) ) )
if " next.reward " in batch :
rr . log ( " next.reward " , rr . Scalar ( batch [ " next.reward " ] [ i ] . item ( ) ) )
if " next.success " in batch :
rr . log ( " next.success " , rr . Scalar ( batch [ " next.success " ] [ i ] . item ( ) ) )
if mode == " local " and save :
# save .rrd locally
output_dir = Path ( output_dir )
output_dir . mkdir ( parents = True , exist_ok = True )
repo_id_str = repo_id . replace ( " / " , " _ " )
rrd_path = output_dir / f " { repo_id_str } _episode_ { episode_index } .rrd "
rr . save ( rrd_path )
return rrd_path
elif mode == " distant " :
# stop the process from exiting since it is serving the websocket connection
try :
while True :
time . sleep ( 1 )
except KeyboardInterrupt :
print ( " Ctrl-C received. Exiting. " )
def main():
    """Parse the command-line options and forward them to `visualize_dataset`.

    The parsed namespace is expanded into keyword arguments, so every option's
    destination name must match a parameter name of `visualize_dataset`.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
    )
    parser.add_argument(
        "--episode-index",
        type=int,
        required=True,
        help="Episode to visualize.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size loaded by DataLoader.",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of processes of Dataloader for loading the data.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="local",
        help=(
            "Mode of viewing between 'local' or 'distant'. "
            "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
            "'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
        ),
    )
    parser.add_argument(
        "--web-port",
        type=int,
        default=9090,
        help="Web port for rerun.io when `--mode distant` is set.",
    )
    parser.add_argument(
        "--ws-port",
        type=int,
        default=9087,
        help="Web socket port for rerun.io when `--mode distant` is set.",
    )
    parser.add_argument(
        "--save",
        type=int,
        default=0,
        # The pieces below must be joined by implicit concatenation only: a comma
        # between them turns `help` into a tuple, which argparse cannot format.
        help=(
            "Save a .rrd file in the directory provided by `--output-dir`. "
            "It also deactivates the spawning of a viewer. "
            "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
        ),
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        help="Directory path to write a .rrd file when `--save 1` is set.",
    )

    args = parser.parse_args()
    visualize_dataset(**vars(args))
2024-03-06 10:14:03 +00:00
2024-02-10 15:46:24 +00:00
# Only run the command-line entry point when this file is executed as a script,
# so that importing the module has no side effects.
if __name__ == "__main__":
    main()