2024-05-15 12:13:09 +02:00
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2024-05-04 16:07:14 +02:00
""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset.
2025-01-28 20:07:10 +08:00
Note: The last frame of the episode doesn't always correspond to a final state.
2024-05-04 16:07:14 +02:00
That's because our datasets are composed of transitions from state to state up to
the antepenultimate state associated to the ultimate action to arrive in the final state.
However, there might not be a transition from a final state to another state.
Note: This script aims to visualize the data used to train the neural networks.
~What you see is what you get~. When visualizing image modality, it is often expected to observe
2025-01-28 20:07:10 +08:00
lossy compression artifacts since these images have been decoded from compressed mp4 videos to
2024-05-04 16:07:14 +02:00
save disk space. The compression factor applied has been tuned to not affect success rate.
Examples:
- Visualize data stored on a local machine:
```
local$ python lerobot/scripts/visualize_dataset.py \
    --repo-id lerobot/pusht \
    --episode-index 0
```
- Visualize data stored on a distant machine with a local viewer:
```
distant$ python lerobot/scripts/visualize_dataset.py \
    --repo-id lerobot/pusht \
    --episode-index 0 \
    --save 1 \
    --output-dir path/to/directory
local$ scp distant:path/to/directory/lerobot_pusht_episode_0.rrd .
local$ rerun lerobot_pusht_episode_0.rrd
```
- Visualize data stored on a distant machine through streaming:
2024-05-06 03:03:14 +02:00
(You need to forward the websocket port to the distant machine, with
2024-05-04 16:07:14 +02:00
`ssh -L 9087:localhost:9087 username@remote-host`)
```
distant$ python lerobot/scripts/visualize_dataset.py \
    --repo-id lerobot/pusht \
    --episode-index 0 \
    --mode distant \
    --ws-port 9087
local$ rerun ws://localhost:9087
```
"""
import argparse
2024-05-11 19:28:22 +03:00
import gc
2024-03-06 10:14:03 +00:00
import logging
2024-05-04 16:07:14 +02:00
import time
2024-02-10 15:46:24 +00:00
from pathlib import Path
2024-06-10 19:09:48 +01:00
from typing import Iterator
2024-02-10 15:46:24 +00:00
2024-06-10 19:09:48 +01:00
import numpy as np
2024-05-04 16:07:14 +02:00
import rerun as rr
2024-02-10 15:46:24 +00:00
import torch
2024-06-10 19:09:48 +01:00
import torch . utils . data
2024-05-04 16:07:14 +02:00
import tqdm
from lerobot . common . datasets . lerobot_dataset import LeRobotDataset
class EpisodeSampler(torch.utils.data.Sampler):
    """Sampler that yields, in order, the dataset frame indices of one episode.

    The frame span is read once at construction time from
    `dataset.episode_data_index`, whose "from" and "to" entries give the first
    (inclusive) and end (exclusive) frame index of each episode.
    """

    def __init__(self, dataset: LeRobotDataset, episode_index: int):
        """Resolve the frame index range covered by `episode_index` in `dataset`."""
        bounds = dataset.episode_data_index
        first_frame = bounds["from"][episode_index].item()
        end_frame = bounds["to"][episode_index].item()
        # A `range` is lazy, has a length, and can be iterated several times.
        self.frame_ids = range(first_frame, end_frame)

    def __iter__(self) -> Iterator:
        """Return a fresh iterator over the episode's frame indices."""
        return iter(self.frame_ids)

    def __len__(self) -> int:
        """Return the number of frames in the episode."""
        return len(self.frame_ids)
2024-06-10 19:09:48 +01:00
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
    """Convert a channel-first float32 image tensor to a channel-last uint8 numpy array.

    Values are clamped to [0, 1], scaled by 255 and truncated (not rounded) to uint8.

    Args:
        chw_float32_torch: 3-dimensional float32 tensor laid out as (channels, height, width).
            The channel dimension must be strictly smaller than both spatial dimensions.

    Returns:
        A uint8 numpy array of shape (height, width, channels).

    Raises:
        AssertionError: If the tensor is not float32, is not 3-dimensional, or does not
            look channel-first.
    """
    assert chw_float32_torch.dtype == torch.float32
    assert chw_float32_torch.ndim == 3
    c, h, w = chw_float32_torch.shape
    assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"

    # Clamp before casting: converting an out-of-range float to uint8 is not well defined
    # and in practice wraps around, so a slightly negative value (e.g. from resampling or
    # video decoding) would otherwise show up as a bright pixel. In-range values are untouched.
    scaled = chw_float32_torch.clamp(0.0, 1.0) * 255
    hwc_uint8_numpy = scaled.type(torch.uint8).permute(1, 2, 0).numpy()
    return hwc_uint8_numpy
def _log_frame(batch: dict, i: int, camera_keys) -> None:
    """Log the `i`-th frame of a collated `batch` to the current Rerun recording."""
    # Register the frame on both timelines (frame index and timestamp in seconds).
    rr.set_time_sequence("frame_index", batch["frame_index"][i].item())
    rr.set_time_seconds("timestamp", batch["timestamp"][i].item())

    # display each camera image
    for key in camera_keys:
        # TODO(rcadene): add `.compress()`? is it lossless?
        rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))

    # display each dimension of action space (e.g. actuators command)
    if "action" in batch:
        for dim_idx, val in enumerate(batch["action"][i]):
            rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))

    # display each dimension of observed state space (e.g. agent position in joint space)
    if "observation.state" in batch:
        for dim_idx, val in enumerate(batch["observation.state"][i]):
            rr.log(f"state/{dim_idx}", rr.Scalar(val.item()))

    # optional per-frame scalars, logged under their own key when the dataset provides them
    for key in ("next.done", "next.reward", "next.success"):
        if key in batch:
            rr.log(key, rr.Scalar(batch[key][i].item()))


def visualize_dataset(
    dataset: LeRobotDataset,
    episode_index: int,
    batch_size: int = 32,
    num_workers: int = 0,
    mode: str = "local",
    web_port: int = 9090,
    ws_port: int = 9087,
    save: bool = False,
    output_dir: Path | None = None,
) -> Path | None:
    """Log every frame of one episode of `dataset` to Rerun.

    Depending on the arguments, the recording is shown in a locally spawned viewer
    (`mode="local"`, `save` falsy), written to a `.rrd` file (`mode="local"`, `save` truthy),
    or served to a remote viewer (`mode="distant"`). In "distant" mode the function
    blocks until Ctrl-C is received, and `save` has no effect.

    Args:
        dataset: Dataset to read frames from.
        episode_index: Index of the episode to visualize.
        batch_size: Batch size used by the DataLoader.
        num_workers: Number of DataLoader worker processes.
        mode: Either "local" or "distant".
        web_port: Web port passed to `rr.serve` in "distant" mode.
        ws_port: Websocket port passed to `rr.serve` in "distant" mode.
        save: When truthy (and `mode="local"`), write a `.rrd` file instead of spawning a viewer.
        output_dir: Directory receiving the `.rrd` file; required when `save` is truthy.

    Returns:
        The path of the written `.rrd` file when one is saved, otherwise None.

    Raises:
        AssertionError: If `save` is truthy and `output_dir` is None.
        ValueError: If `mode` is neither "local" nor "distant".
    """
    if save:
        assert output_dir is not None, (
            "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
        )

    # Reject an unknown mode before building the dataloader or touching Rerun.
    if mode not in ["local", "distant"]:
        raise ValueError(mode)

    repo_id = dataset.repo_id

    logging.info("Loading dataloader")
    episode_sampler = EpisodeSampler(dataset, episode_index)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_size=batch_size,
        sampler=episode_sampler,
    )

    logging.info("Starting Rerun")

    spawn_local_viewer = mode == "local" and not save
    rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)

    # Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
    # when iterating on a dataloader with `num_workers` > 0
    # TODO(rcadene): remove `gc.collect` when rerun version 0.16 is out, which includes a fix
    gc.collect()

    if mode == "distant":
        rr.serve(open_browser=False, web_port=web_port, ws_port=ws_port)

    logging.info("Logging to Rerun")

    # The camera keys do not change while iterating, so read them once.
    camera_keys = dataset.meta.camera_keys

    for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
        # iterate over the batch
        num_frames = len(batch["index"])
        for i in range(num_frames):
            _log_frame(batch, i, camera_keys)

    if mode == "local" and save:
        # save .rrd locally
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        repo_id_str = repo_id.replace("/", "_")
        rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
        rr.save(rrd_path)
        return rrd_path

    elif mode == "distant":
        # stop the process from exiting since it is serving the websocket connection
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            print("Ctrl-C received. Exiting.")

    return None
def main():
    """Command-line entry point: parse arguments, load the dataset, visualize one episode.

    `--repo-id`, `--root` and `--tolerance-s` are used to construct the `LeRobotDataset`;
    every other option is forwarded as a keyword argument to `visualize_dataset`.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht`).",
    )
    parser.add_argument(
        "--episode-index",
        type=int,
        required=True,
        help="Episode to visualize.",
    )
    parser.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory for the dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=None,
        help="Directory path to write a .rrd file when `--save 1` is set.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size loaded by DataLoader.",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of processes of Dataloader for loading the data.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="local",
        # Reject a mistyped mode at parse time, before the dataset is loaded.
        choices=["local", "distant"],
        help=(
            "Mode of viewing between 'local' or 'distant'. "
            "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
            "'distant' creates a server on the distant machine where the data is stored. "
            "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
        ),
    )
    parser.add_argument(
        "--web-port",
        type=int,
        default=9090,
        help="Web port for rerun.io when `--mode distant` is set.",
    )
    parser.add_argument(
        "--ws-port",
        type=int,
        default=9087,
        help="Web socket port for rerun.io when `--mode distant` is set.",
    )
    parser.add_argument(
        "--save",
        type=int,
        default=0,
        help=(
            "Save a .rrd file in the directory provided by `--output-dir`. "
            "It also deactivates the spawning of a viewer. "
            "Visualize the data by running `rerun path/to/file.rrd` on your local machine."
        ),
    )
    parser.add_argument(
        "--tolerance-s",
        type=float,
        default=1e-4,
        # Adjacent literals are concatenated as-is, so each sentence carries its own
        # trailing ". " to keep the rendered help readable.
        help=(
            "Tolerance in seconds used to ensure data timestamps respect the dataset fps value. "
            "This argument is passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument. "
            "If not given, defaults to 1e-4."
        ),
    )

    args = parser.parse_args()

    # Work on a copy so the parsed namespace is left untouched, then split off the
    # options consumed by the dataset constructor; the rest goes to `visualize_dataset`.
    kwargs = dict(vars(args))
    repo_id = kwargs.pop("repo_id")
    root = kwargs.pop("root")
    tolerance_s = kwargs.pop("tolerance_s")
    # `--save` is parsed as an int flag (0/1) but `visualize_dataset` takes a bool.
    kwargs["save"] = bool(kwargs["save"])

    logging.info("Loading dataset")
    dataset = LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)

    visualize_dataset(dataset, **kwargs)
2024-03-06 10:14:03 +00:00
2024-02-10 15:46:24 +00:00
# Script entry point: run the CLI only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()