mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
Compare commits
77 Commits
peft-preco
...
user/miche
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f3182bee9a | ||
|
|
e43ece3271 | ||
|
|
67485b1edc | ||
|
|
bb85f4ebea | ||
|
|
0cd84cc9f9 | ||
|
|
e9e16d77f5 | ||
|
|
8ad085d882 | ||
|
|
f1fbc29819 | ||
|
|
aa07b858ea | ||
|
|
980cd8dbc5 | ||
|
|
593d63cd68 | ||
|
|
715e33b3df | ||
|
|
858baf73db | ||
|
|
6b0dcfe30a | ||
|
|
8c54e3eece | ||
|
|
2a769a2de7 | ||
|
|
b3257fba6c | ||
|
|
114e0524db | ||
|
|
4110b168eb | ||
|
|
5afe1804a8 | ||
|
|
2818b2a55c | ||
|
|
1b41b807db | ||
|
|
dc0c7f2e93 | ||
|
|
2958b0aede | ||
|
|
d56e2d438f | ||
|
|
c49a3ecb1e | ||
|
|
38c7ac5b07 | ||
|
|
20b74ae1eb | ||
|
|
b9b880bd8b | ||
|
|
2866d0770f | ||
|
|
4375a05a9f | ||
|
|
4acf99f622 | ||
|
|
5a6ea09248 | ||
|
|
9c0836c8d0 | ||
|
|
b0cca75e5e | ||
|
|
54b5c805bf | ||
|
|
eab5543750 | ||
|
|
6b6a990f4c | ||
|
|
c2a05a1fde | ||
|
|
6c4d122198 | ||
|
|
34c5d4ce07 | ||
|
|
c1b28f0b58 | ||
|
|
53ecec5fb2 | ||
|
|
65738f0a80 | ||
|
|
5d184a7811 | ||
|
|
1a5c1ef9c7 | ||
|
|
7866c1f7d1 | ||
|
|
3666ac9346 | ||
|
|
3daab2acbb | ||
|
|
c36d2253d0 | ||
|
|
e2e6f6e666 | ||
|
|
ff0029f84b | ||
|
|
39ad2d16d4 | ||
|
|
689c5efc72 | ||
|
|
eda0b996cd | ||
|
|
15e7a9d541 | ||
|
|
52fb4143b5 | ||
|
|
93c80b2cb1 | ||
|
|
5fbbaa1bc0 | ||
|
|
71d1f5e2c9 | ||
|
|
b520941cd9 | ||
|
|
64ed5258e6 | ||
|
|
392a8c32a7 | ||
|
|
969ef745a2 | ||
|
|
6fe42a72db | ||
|
|
2487228ea7 | ||
|
|
76436ca1de | ||
|
|
fbf2f2222a | ||
|
|
02bc4e03e0 | ||
|
|
624eaf1175 | ||
|
|
aed3eb4a94 | ||
|
|
8426c64f42 | ||
|
|
7c2bbee613 | ||
|
|
9d6886dd08 | ||
|
|
d67ca342e9 | ||
|
|
57c9c21c39 | ||
|
|
38c14571cc |
@@ -108,7 +108,8 @@ def save_decoded_frames(
|
|||||||
|
|
||||||
|
|
||||||
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||||
ep_num_images = dataset.episode_data_index["to"][0].item()
|
episode_index = 0
|
||||||
|
ep_num_images = dataset.meta.episodes["length"][episode_index]
|
||||||
if imgs_dir.exists() and len(list(imgs_dir.glob("frame_*.png"))) == ep_num_images:
|
if imgs_dir.exists() and len(list(imgs_dir.glob("frame_*.png"))) == ep_num_images:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -265,7 +266,8 @@ def benchmark_encoding_decoding(
|
|||||||
overwrite=True,
|
overwrite=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
ep_num_images = dataset.episode_data_index["to"][0].item()
|
episode_index = 0
|
||||||
|
ep_num_images = dataset.meta.episodes["length"][episode_index]
|
||||||
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
|
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
|
||||||
num_pixels = width * height
|
num_pixels = width * height
|
||||||
video_size_bytes = video_path.stat().st_size
|
video_size_bytes = video_path.stat().st_size
|
||||||
|
|||||||
@@ -92,11 +92,11 @@ print(dataset.hf_dataset)
|
|||||||
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
|
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
|
||||||
# with the latter, like iterating through the dataset.
|
# with the latter, like iterating through the dataset.
|
||||||
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
|
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
|
||||||
# episodes, you can access the frame indices of any episode using the episode_data_index. Here, we access
|
# episodes, you can access the frame indices of any episode using dataset.meta.episodes. Here, we access
|
||||||
# frame indices associated to the first episode:
|
# frame indices associated to the first episode:
|
||||||
episode_index = 0
|
episode_index = 0
|
||||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||||
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||||
|
|
||||||
# Then we grab all the image frames from the first camera:
|
# Then we grab all the image frames from the first camera:
|
||||||
camera_key = dataset.meta.camera_keys[0]
|
camera_key = dataset.meta.camera_keys[0]
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
|
|||||||
# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)`
|
# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)`
|
||||||
|
|
||||||
# Get the index of the first observation in the first episode
|
# Get the index of the first observation in the first episode
|
||||||
first_idx = dataset.episode_data_index["from"][0].item()
|
first_idx = dataset.meta.episodes["dataset_from_index"][0]
|
||||||
|
|
||||||
# Get the frame corresponding to the first camera
|
# Get the frame corresponding to the first camera
|
||||||
frame = dataset[first_idx][dataset.meta.camera_keys[0]]
|
frame = dataset[first_idx][dataset.meta.camera_keys[0]]
|
||||||
|
|||||||
@@ -51,10 +51,10 @@ while i < NB_CYCLES_CLIENT_CONNECTION:
|
|||||||
action_sent = robot.send_action(action)
|
action_sent = robot.send_action(action)
|
||||||
observation = robot.get_observation()
|
observation = robot.get_observation()
|
||||||
|
|
||||||
frame = {**action_sent, **observation}
|
|
||||||
task = "Dummy Example Task Dataset"
|
task = "Dummy Example Task Dataset"
|
||||||
|
frame = {**action_sent, **observation, "task": task}
|
||||||
|
|
||||||
dataset.add_frame(frame, task)
|
dataset.add_frame(frame)
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
print("Disconnecting Teleop Devices and LeKiwi Client")
|
print("Disconnecting Teleop Devices and LeKiwi Client")
|
||||||
|
|||||||
503
examples/port_datasets/agibot_hdf5/port_agibot.py
Normal file
503
examples/port_datasets/agibot_hdf5/port_agibot.py
Normal file
@@ -0,0 +1,503 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.common.datasets.utils import (
|
||||||
|
DEFAULT_CHUNK_SIZE,
|
||||||
|
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||||
|
DEFAULT_VIDEO_PATH,
|
||||||
|
EPISODES_DIR,
|
||||||
|
concat_video_files,
|
||||||
|
get_video_duration_in_s,
|
||||||
|
get_video_size_in_mb,
|
||||||
|
update_chunk_file_indices,
|
||||||
|
write_info,
|
||||||
|
)
|
||||||
|
from lerobot.common.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
|
||||||
|
|
||||||
|
AGIBOT_FPS = 30
|
||||||
|
AGIBOT_ROBOT_TYPE = "AgiBot_A2D"
|
||||||
|
AGIBOT_FEATURES = {
|
||||||
|
# gripper open range in mm (0 for full open, 1 for full close)
|
||||||
|
"observation.state.effector.position": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (2,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["left_gripper", "right_gripper"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# flange xyz in meters
|
||||||
|
"observation.state.end.position": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (6,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["left_x", "left_y", "left_z", "right_x", "right_y", "right_z"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# flange quaternion with xyzw
|
||||||
|
"observation.state.end.orientation": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (8,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["left_x", "left_y", "left_z", "left_w", "right_x", "right_y", "right_z", "right_w"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# in radians
|
||||||
|
"observation.state.head.position": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (2,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["yaw", "pitch"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# in motor steps
|
||||||
|
"observation.state.joint.current_value": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (14,),
|
||||||
|
"names": {
|
||||||
|
"axes": [f"left_joint_{i}" for i in range(7)] + [f"right_joint_{i}" for i in range(7)],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# same as current_value but in radians
|
||||||
|
"observation.state.joint.position": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (14,),
|
||||||
|
"names": {
|
||||||
|
"axes": [f"left_joint_{i}" for i in range(7)] + [f"right_joint_{i}" for i in range(7)],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# pitch in radians, lift in meters
|
||||||
|
"observation.state.waist.position": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (2,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["pitch", "lift"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# concatenation of head.position, joint.position, effector.position, waist.position
|
||||||
|
"observation.state": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (20,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["head_yaw", "head_pitch"]
|
||||||
|
+ [f"left_joint_{i}" for i in range(7)]
|
||||||
|
+ ["left_gripper"]
|
||||||
|
+ [f"right_joint_{i}" for i in range(7)]
|
||||||
|
+ ["right_gripper"]
|
||||||
|
+ ["waist_pitch", "waist_lift"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# gripper open range in mm (0 for full open, 1 for full close)
|
||||||
|
"action.effector.position": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (2,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["left_gripper", "right_gripper"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# flange xyz in meters
|
||||||
|
"action.end.position": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (6,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["left_x", "left_y", "left_z", "right_x", "right_y", "right_z"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# flange quaternion with xyzw
|
||||||
|
"action.end.orientation": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (8,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["left_x", "left_y", "left_z", "left_w", "right_x", "right_y", "right_z", "right_w"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# in radians
|
||||||
|
"action.head.position": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (2,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["yaw", "pitch"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# goal joint position in radians
|
||||||
|
"action.joint.position": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (14,),
|
||||||
|
"names": {
|
||||||
|
"axes": [f"left_joint_{i}" for i in range(7)] + [f"right_joint_{i}" for i in range(7)],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"action.robot.velocity": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (2,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["velocity_x", "yaw_rate"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# pitch in radians, lift in meters
|
||||||
|
"action.waist.position": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (2,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["pitch", "lift"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# concatenation of head.position, joint.position, effector.position, waist.position, robot.velocity
|
||||||
|
"action": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (22,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["head_yaw", "head_pitch"]
|
||||||
|
+ [f"left_joint_{i}" for i in range(7)]
|
||||||
|
+ ["left_gripper"]
|
||||||
|
+ [f"right_joint_{i}" for i in range(7)]
|
||||||
|
+ ["right_gripper"]
|
||||||
|
+ ["waist_pitch", "waist_lift"]
|
||||||
|
+ ["velocity_x", "yaw_rate"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# episode level annotation
|
||||||
|
"init_scene_text": {
|
||||||
|
"dtype": "string",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
# frame level annotation
|
||||||
|
"action_text": {
|
||||||
|
"dtype": "string",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
# frame level annotation
|
||||||
|
"skill": {
|
||||||
|
"dtype": "string",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
AGIBOT_IMAGES_FEATURES = {
|
||||||
|
"observation.images.top_head": {
|
||||||
|
"dtype": "video",
|
||||||
|
"shape": (480, 640, 3),
|
||||||
|
"names": ["height", "width", "channel"],
|
||||||
|
},
|
||||||
|
"observation.images.hand_left": {
|
||||||
|
"dtype": "video",
|
||||||
|
"shape": (480, 640, 3),
|
||||||
|
"names": ["height", "width", "channel"],
|
||||||
|
},
|
||||||
|
"observation.images.hand_right": {
|
||||||
|
"dtype": "video",
|
||||||
|
"shape": (480, 640, 3),
|
||||||
|
"names": ["height", "width", "channel"],
|
||||||
|
},
|
||||||
|
"observation.images.head_center_fisheye": {
|
||||||
|
"dtype": "video",
|
||||||
|
"shape": (748, 960, 3),
|
||||||
|
"names": ["height", "width", "channel"],
|
||||||
|
},
|
||||||
|
"observation.images.head_left_fisheye": {
|
||||||
|
"dtype": "video",
|
||||||
|
"shape": (748, 960, 3),
|
||||||
|
"names": ["height", "width", "channel"],
|
||||||
|
},
|
||||||
|
"observation.images.head_right_fisheye": {
|
||||||
|
"dtype": "video",
|
||||||
|
"shape": (748, 960, 3),
|
||||||
|
"names": ["height", "width", "channel"],
|
||||||
|
},
|
||||||
|
"observation.images.back_left_fisheye": {
|
||||||
|
"dtype": "video",
|
||||||
|
"shape": (748, 960, 3),
|
||||||
|
"names": ["height", "width", "channel"],
|
||||||
|
},
|
||||||
|
"observation.images.back_right_fisheye": {
|
||||||
|
"dtype": "video",
|
||||||
|
"shape": (748, 960, 3),
|
||||||
|
"names": ["height", "width", "channel"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_info_per_task(raw_dir):
    """Load the per-episode annotations of every task under ``raw_dir / "task_info"``.

    Every ``task_<index>.json`` file is read as a JSON list of episode records,
    each of which must carry an ``"episode_id"`` key.

    Args:
        raw_dir: Root directory of the raw dataset (``str`` or ``Path``).

    Returns:
        A dict mapping the integer task index parsed from the file name to a
        dict that maps each ``episode_id`` to its full episode record.

    Raises:
        ValueError: If the part of a file name after ``task_`` is not an integer.
        KeyError: If an episode record has no ``"episode_id"`` key.
    """
    info_per_task = {}
    task_info_dir = Path(raw_dir) / "task_info"
    for path in task_info_dir.glob("task_*.json"):
        # Only strip the leading "task_"; `stem` already dropped the ".json" suffix.
        task_index = int(path.stem.removeprefix("task_"))
        # JSON is UTF-8 by specification, so do not depend on the platform default encoding.
        with open(path, encoding="utf-8") as f:
            episodes = json.load(f)

        info_per_task[task_index] = {ep["episode_id"]: ep for ep in episodes}

    return info_per_task
|
||||||
|
|
||||||
|
|
||||||
|
def create_frame_idx_to_frames_label_idx(ep_info):
    """Map every annotated frame index to the position of its label in ``action_config``.

    Each entry of ``ep_info["label_info"]["action_config"]`` covers the half-open
    frame range ``[start_frame, end_frame)``. When ranges overlap, the entry that
    appears later in the list wins. Frames not covered by any entry are absent
    from the returned dict.
    """
    label_idx_by_frame = {}
    segments = ep_info["label_info"]["action_config"]
    for position, segment in enumerate(segments):
        covered_frames = range(segment["start_frame"], segment["end_frame"])
        label_idx_by_frame.update(dict.fromkeys(covered_frames, position))
    return label_idx_by_frame
|
||||||
|
|
||||||
|
|
||||||
|
def generate_lerobot_frames(raw_dir: Path, task_index: int, episode_index: int):
    r"""Yield one frame dict per timestep of a raw AgiBot episode.

    /!\ The frames don't contain observation.images.* (videos are handled separately).

    Proprioceptive data is read from
    ``raw_dir/proprio_stats/<task_index>/<episode_index>/proprio_stats.h5`` and the
    annotations come from the task info JSON files.

    Args:
        raw_dir: Root directory of the raw dataset.
        task_index: Index of the task the episode belongs to.
        episode_index: Identifier of the episode inside the task.

    Yields:
        A dict with every numeric feature as a ``float32`` array, the concatenated
        ``"observation.state"`` and ``"action"`` vectors, the episode-level
        ``"task"`` and ``"init_scene_text"`` strings, and the frame-level
        ``"action_text"`` and ``"skill"`` strings (empty when the frame has no label).

    Raises:
        ValueError: If an HDF5 column does not have the same number of frames as
            ``state/joint/position``.
    """
    info_per_task = load_info_per_task(raw_dir)
    ep_info = info_per_task[task_index][episode_index]
    frame_idx_to_frames_label_idx = create_frame_idx_to_frames_label_idx(ep_info)

    # Empty features are commented out.
    keys_mapping = {
        # STATE
        # "observation.state.effector.force": "state/effector/force",
        "observation.state.effector.position": "state/effector/position",
        # "observation.state.end.angular": "state/end/angular",
        "observation.state.end.position": "state/end/position",
        "observation.state.end.orientation": "state/end/orientation",
        # "observation.state.end.velocity": "state/end/velocity",
        # "observation.state.end.wrench": "state/end/wrench",
        # "observation.state.head.effort": "state/head/effort",
        "observation.state.head.position": "state/head/position",
        # "observation.state.head.velocity": "state/head/velocity",
        "observation.state.joint.current_value": "state/joint/current_value",
        # "observation.state.joint.effort": "state/joint/effort",
        "observation.state.joint.position": "state/joint/position",
        # "observation.state.joint.velocity": "state/joint/velocity",
        # "observation.state.robot.orientation": "state/robot/orientation",
        # "observation.state.robot.orientation_drift": "state/robot/orientation_drift",
        # "observation.state.robot.position": "state/robot/position",
        # "observation.state.robot.position_drift": "state/robot/position_drift",
        # "observation.state.waist.effort": "state/waist/effort",
        "observation.state.waist.position": "state/waist/position",
        # "observation.state.waist.velocity": "state/waist/velocity",
        # ----- ACTION (index are also commented out) -----
        # "action.effector.index": "action/effector/index",
        "action.effector.position": "action/effector/position",
        # "action.effector.force": "action/effector/force",
        # "action.end.index": "action/end/index",
        "action.end.position": "action/end/position",
        "action.end.orientation": "action/end/orientation",
        # "action.head.index": "action/head/index",
        "action.head.position": "action/head/position",
        # "action.joint.index": "action/joint/index",
        "action.joint.position": "action/joint/position",
        # "action.joint.effort": "action/joint/effort",
        # "action.joint.velocity": "action/joint/velocity",
        # "action.robot.index": "action/robot/index",
        # "action.robot.position": "action/robot/position",
        # "action.robot.orientation": "action/robot/orientation",
        # "action.robot.angular": "action/robot/angular",
        "action.robot.velocity": "action/robot/velocity",
        # "action.waist.index": "action/waist/index",
        "action.waist.position": "action/waist/position",
    }

    h5_path = raw_dir / f"proprio_stats/{task_index}/{episode_index}/proprio_stats.h5"
    with h5py.File(h5_path, "r") as h5:
        num_frames = len(h5["state/joint/position"])

        for h5_key in keys_mapping.values():
            col_num_frames = h5[h5_key].shape[0]
            if col_num_frames != num_frames:
                raise ValueError(
                    f"HDF5 column '{h5_key}' is expected to have {num_frames} frames "
                    f"but has {col_num_frames} frames instead."
                )

        # Read every column once instead of hitting the HDF5 file once per frame and
        # per key. This also lets the file be closed before the first frame is yielded.
        columns = {
            new_key: np.asarray(h5[h5_key][()], dtype=np.float32)
            for new_key, h5_key in keys_mapping.items()
        }

    action_config = ep_info["label_info"]["action_config"]

    for i in range(num_frames):
        # Create frame. `np.array` copies, so frames never alias the shared column buffers.
        f = {key: np.array(column[i]) for key, column in columns.items()}

        # Flatten the per-arm (left, right) layout into a single vector.
        f["observation.state.end.position"] = f["observation.state.end.position"].reshape(6)
        f["observation.state.end.orientation"] = f["observation.state.end.orientation"].reshape(8)
        f["observation.state"] = np.concatenate(
            [
                f["observation.state.head.position"],
                f["observation.state.joint.position"][:7],  # left
                f["observation.state.effector.position"][[0]],  # left
                f["observation.state.joint.position"][7:],  # right
                f["observation.state.effector.position"][[1]],  # right
                f["observation.state.waist.position"],
            ]
        )

        f["action.end.position"] = f["action.end.position"].reshape(6)
        f["action.end.orientation"] = f["action.end.orientation"].reshape(8)
        f["action"] = np.concatenate(
            [
                f["action.head.position"],
                f["action.joint.position"][:7],  # left
                f["action.effector.position"][[0]],  # left
                f["action.joint.position"][7:],  # right
                f["action.effector.position"][[1]],  # right
                f["action.waist.position"],
                f["action.robot.velocity"],
            ]
        )

        # episode level annotation
        f["task"] = ep_info["task_name"]
        f["init_scene_text"] = ep_info["init_scene_text"]

        # frame level annotation
        if i in frame_idx_to_frames_label_idx:
            frames_label = action_config[frame_idx_to_frames_label_idx[i]]
            f["action_text"] = frames_label["action_text"]
            f["skill"] = frames_label["skill"]
        else:
            f["action_text"] = ""
            f["skill"] = ""

        yield f
|
||||||
|
|
||||||
|
|
||||||
|
def update_meta_data(
    df,
    ep_to_meta,
):
    """Return a copy of ``df`` with the per-camera video columns filled in for each row.

    For every row, the entry of ``ep_to_meta`` selected by the row's
    ``"episode_index"`` is looked up. For each video key found there, the columns
    ``videos/<key>/chunk_index``, ``videos/<key>/file_index``,
    ``videos/<key>/from_timestamp`` and ``videos/<key>/to_timestamp`` are set
    from the corresponding meta dict.
    """
    video_fields = ("chunk_index", "file_index", "from_timestamp", "to_timestamp")

    def _with_video_columns(record):
        per_camera = ep_to_meta[record["episode_index"]]
        for video_key, video_meta in per_camera.items():
            for field in video_fields:
                record[f"videos/{video_key}/{field}"] = video_meta[field]
        return record

    return df.apply(_with_video_columns, axis=1)
|
||||||
|
|
||||||
|
|
||||||
|
def move_videos_to_lerobot_directory(lerobot_dataset, raw_dir, task_index, episode_names):
    """Copy the raw per-episode videos into the dataset's aggregated video files.

    For each episode (in the order of ``episode_names``) and each camera, the raw
    ``.mp4`` is either copied as a new aggregated file or concatenated to the
    current one, and the location of the episode inside that file is recorded.
    The episodes metadata parquet files are then rewritten with those locations.

    Args:
        lerobot_dataset: Dataset whose ``root`` receives the videos and whose
            ``meta.info["features"]`` must already list every camera key.
        raw_dir: Root directory of the raw dataset.
        task_index: Task index used to locate ``observations/<task_index>/``.
        episode_names: Raw episode directory names. The position of a name in this
            sequence is used as the episode index in the metadata.
            NOTE(review): this assumes episodes were saved to the dataset in the
            same order — confirm against the caller.

    Raises:
        ValueError: If a camera key is missing from the dataset features.
    """
    # Dataset video key -> raw video file name (without ".mp4").
    keys_mapping = {
        "observation.images.top_head": "head_color",
        "observation.images.hand_left": "hand_left_color",
        "observation.images.hand_right": "hand_right_color",
        "observation.images.head_center_fisheye": "head_center_fisheye_color",
        "observation.images.head_left_fisheye": "head_left_fisheye_color",
        "observation.images.head_right_fisheye": "head_right_fisheye_color",
        "observation.images.back_left_fisheye": "back_left_fisheye_color",
        "observation.images.back_right_fisheye": "back_right_fisheye_color",
    }

    # sanity check
    for key in keys_mapping:
        if key not in lerobot_dataset.meta.info["features"]:
            raise ValueError(f"Key '{key}' not found in features.")

    video_keys = keys_mapping.keys()
    # Per-camera cursor: current chunk/file and the duration already stored in that file.
    chunk_idx = dict.fromkeys(video_keys, 0)
    file_idx = dict.fromkeys(video_keys, 0)
    latest_duration_in_s = dict.fromkeys(video_keys, 0)
    ep_to_meta = {}
    for ep_idx, ep_name in enumerate(episode_names):
        for key in video_keys:
            raw_videos_dir = raw_dir / f"observations/{task_index}/{ep_name}/videos"
            old_key = keys_mapping[key]
            ep_path = raw_videos_dir / f"{old_key}.mp4"
            ep_duration_in_s = get_video_duration_in_s(ep_path)

            aggr_path = lerobot_dataset.root / DEFAULT_VIDEO_PATH.format(
                video_key=key,
                chunk_index=chunk_idx[key],
                file_index=file_idx[key],
            )
            if not aggr_path.exists():
                # First video
                aggr_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy(str(ep_path), str(aggr_path))
            else:
                size_in_mb = get_video_size_in_mb(ep_path)
                aggr_size_in_mb = get_video_size_in_mb(aggr_path)

                if aggr_size_in_mb + size_in_mb >= DEFAULT_VIDEO_FILE_SIZE_IN_MB:
                    # Size limit is reached, start a new aggregated video file
                    # (presumably the helper advances the file index and rolls over to
                    # the next chunk when needed — verify against its definition).
                    chunk_idx[key], file_idx[key] = update_chunk_file_indices(
                        chunk_idx[key], file_idx[key], DEFAULT_CHUNK_SIZE
                    )
                    aggr_path = lerobot_dataset.root / DEFAULT_VIDEO_PATH.format(
                        video_key=key,
                        chunk_index=chunk_idx[key],
                        file_index=file_idx[key],
                    )
                    aggr_path.parent.mkdir(parents=True, exist_ok=True)
                    shutil.copy(str(ep_path), str(aggr_path))
                    # The new file starts empty, so timestamps restart from zero.
                    latest_duration_in_s[key] = 0
                else:
                    # Append the episode video to the existing aggregated video file
                    concat_video_files(
                        [aggr_path, ep_path],
                        lerobot_dataset.root,
                        key,
                        chunk_idx[key],
                        file_idx[key],
                    )

            if ep_idx not in ep_to_meta:
                ep_to_meta[ep_idx] = {}
            # Where this episode lives inside the aggregated file for this camera.
            ep_to_meta[ep_idx][key] = {
                "chunk_index": chunk_idx[key],
                "file_index": file_idx[key],
                "from_timestamp": latest_duration_in_s[key],
                "to_timestamp": latest_duration_in_s[key] + ep_duration_in_s,
            }
            latest_duration_in_s[key] += ep_duration_in_s

    # Update episodes meta data
    for meta_path in (lerobot_dataset.root / EPISODES_DIR).glob("chunk-*/file-*.parquet"):
        df = pd.read_parquet(meta_path)
        df = update_meta_data(df, ep_to_meta)
        df.to_parquet(meta_path)
|
||||||
|
|
||||||
|
|
||||||
|
def port_agibot(
    raw_dir: Path, repo_id: str, task_index: int, episode_indices: list[int], push_to_hub: bool = False
):
    """Convert the given raw AgiBot episodes of one task into a new LeRobotDataset.

    The numeric features and annotations are written frame by frame. The videos
    are then attached afterwards by registering the image features in the
    dataset info and moving the raw video files into place.

    Args:
        raw_dir: Root directory of the raw dataset.
        repo_id: Identifier of the dataset to create.
        task_index: Index of the task the episodes belong to.
        episode_indices: Raw episode identifiers to port, in order.
        push_to_hub: When true, push the resulting dataset with the ``agibot`` tag.
    """
    lerobot_dataset = LeRobotDataset.create(
        repo_id=repo_id,
        robot_type=AGIBOT_ROBOT_TYPE,
        fps=AGIBOT_FPS,
        features=AGIBOT_FEATURES,
    )

    start_time = time.time()
    num_episodes = len(episode_indices)
    logging.info(f"Number of episodes {num_episodes}")

    for num_done, episode_index in enumerate(episode_indices):
        days, hours, minutes, seconds = get_elapsed_time_in_days_hours_minutes_seconds(
            time.time() - start_time
        )
        progress = (
            f"{num_done} / {num_episodes} episodes processed "
            f"(after {days} days, {hours} hours, {minutes} minutes, {seconds:.3f} seconds)"
        )
        logging.info(progress)

        for frame in generate_lerobot_frames(raw_dir, task_index, episode_index):
            lerobot_dataset.add_frame(frame)

        lerobot_dataset.save_episode()
        logging.info("Save_episode")

    # Videos have already been encoded with the proper format, so we rely on hacks
    # HACK: Add extra images features
    dataset_meta = lerobot_dataset.meta
    dataset_meta.info["features"].update(AGIBOT_IMAGES_FEATURES)
    write_info(dataset_meta.info, dataset_meta.root)
    move_videos_to_lerobot_directory(lerobot_dataset, raw_dir, task_index, episode_indices)

    if not push_to_hub:
        return

    # Add agibot tag, since it belongs to the agibot collection of datasets
    lerobot_dataset.push_to_hub(tags=["agibot"], private=False)
|
||||||
183
examples/port_datasets/agibot_hdf5/slurm_port_shards.py
Normal file
183
examples/port_datasets/agibot_hdf5/slurm_port_shards.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import tarfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from datatrove.executor import LocalPipelineExecutor
|
||||||
|
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||||
|
from datatrove.pipeline.base import PipelineStep
|
||||||
|
|
||||||
|
from examples.port_datasets.agibot_hdf5.download import (
|
||||||
|
RAW_REPO_ID,
|
||||||
|
download_meta_data,
|
||||||
|
get_observations_files,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PortAgiBotShards(PipelineStep):
    """Pipeline step that ports one raw AgiBot observations archive per rank."""

    def __init__(
        self,
        raw_dir: Path | str,
        repo_id: str | None = None,
    ):
        """Store the raw dataset directory and the base dataset identifier.

        Args:
            raw_dir: Directory holding (or receiving) the raw dataset files.
            repo_id: Base identifier; each shard appends its world size and rank to it.
        """
        super().__init__()
        self.raw_dir = Path(raw_dir)
        self.repo_id = repo_id

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Download, extract, port, clean up and validate the archive assigned to ``rank``.

        Args:
            data: Unused.
            rank: Index of the observations archive to process.
            world_size: Total number of ranks; only used to name the shard dataset.
        """
        # Imports are local to this method rather than at module level.
        # NOTE(review): presumably so the step stays picklable/lightweight for
        # remote workers — confirm.
        import shutil

        from datasets.utils.tqdm import disable_progress_bars

        from examples.port_datasets.agibot_hdf5.download import (
            RAW_REPO_ID,
            download,
            get_observations_files,
            no_depth,
        )
        from examples.port_datasets.agibot_hdf5.port_agibot import port_agibot
        from examples.port_datasets.droid_rlds.port_droid import validate_dataset
        from lerobot.common.constants import HF_LEROBOT_HOME
        from lerobot.common.utils.utils import init_logging

        init_logging()
        disable_progress_bars()

        shard_repo_id = f"{self.repo_id}_world_{world_size}_rank_{rank}"

        # Start from a clean slate: remove any dataset left by a previous run of this shard.
        dataset_dir = HF_LEROBOT_HOME / shard_repo_id
        if dataset_dir.exists():
            shutil.rmtree(dataset_dir)

        obs_files, _ = get_observations_files(self.raw_dir, RAW_REPO_ID)
        obs_file = obs_files[rank]

        # Download subset
        download(self.raw_dir, allow_patterns=obs_file)

        tar_path = self.raw_dir / obs_file
        with tarfile.open(tar_path, "r") as tar:
            extracted_files = tar.getnames()

        # The archive's parent directory is named after the task index, and its
        # top-level entries (names without "/") are parsed as integer episode names.
        task_index = int(tar_path.parent.name)
        episode_names = [int(p) for p in extracted_files if "/" not in p]

        # Untar if needed
        if not all((tar_path.parent / f"{ep_name}").exists() for ep_name in episode_names):
            logging.info(f"Untar-ing {tar_path}...")
            with tarfile.open(tar_path, "r") as tar:
                tar.extractall(path=tar_path.parent, filter=no_depth)  # nosec B202

        port_agibot(self.raw_dir, shard_repo_id, task_index, episode_names, push_to_hub=False)

        # Free disk space: drop the extracted episode directories and the archive itself.
        for ep_name in episode_names:
            shutil.rmtree(str(tar_path.parent / f"{ep_name}"))

        tar_path.unlink()

        validate_dataset(shard_repo_id)
|
||||||
|
|
||||||
|
|
||||||
|
def make_port_executor(
    raw_dir, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
):
    """Build the datatrove executor that runs one ``PortAgiBotShards`` task per archive.

    The dataset metadata is downloaded first so the number of observations
    archives (one task each) is known. With ``slurm`` truthy a
    ``SlurmPipelineExecutor`` is returned, otherwise a single-worker
    ``LocalPipelineExecutor``.
    """
    download_meta_data(raw_dir)
    obs_files, _ = get_observations_files(raw_dir, RAW_REPO_ID)

    # Settings shared by both executors.
    shared_kwargs = {
        "pipeline": [
            PortAgiBotShards(raw_dir, repo_id),
        ],
        "logging_dir": str(logs_dir / job_name),
        "tasks": len(obs_files),
    }

    if not slurm:
        return LocalPipelineExecutor(**shared_kwargs, workers=1)

    return SlurmPipelineExecutor(
        **shared_kwargs,
        job_name=job_name,
        workers=workers,
        time="08:00:00",
        partition=partition,
        cpus_per_task=cpus_per_task,
        sbatch_args={"mem-per-cpu": mem_per_cpu},
    )
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse the command line, build the port executor and run it.

    ``--logs-dir`` is required: the executor's logging directory is built from it,
    so a missing value is rejected by the parser instead of failing later with a
    ``TypeError`` after the metadata download has already started.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--raw-dir",
        type=Path,
        required=True,
        help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
    )
    parser.add_argument(
        "--logs-dir",
        type=Path,
        required=True,
        help="Path to logs directory for `datatrove`.",
    )
    parser.add_argument(
        "--job-name",
        type=str,
        default="port_droid",
        help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
    )
    parser.add_argument(
        "--slurm",
        type=int,
        default=1,
        help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=2048,
        help="Number of slurm workers. It should be less than the maximum number of shards.",
    )
    parser.add_argument(
        "--partition",
        type=str,
        help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
    )
    parser.add_argument(
        "--cpus-per-task",
        type=int,
        default=8,
        help="Number of cpus that each slurm worker will use.",
    )
    parser.add_argument(
        "--mem-per-cpu",
        type=str,
        default="1950M",
        help="Memory per cpu that each worker will use.",
    )

    args = parser.parse_args()
    kwargs = vars(args)
    # The flag is an int on the command line; the executor factory expects a bool.
    kwargs["slurm"] = kwargs.pop("slurm") == 1
    port_executor = make_port_executor(**kwargs)
    port_executor.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
144
examples/port_datasets/droid_rlds/README.md
Normal file
144
examples/port_datasets/droid_rlds/README.md
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
# Port DROID 1.0.1 dataset to LeRobotDataset
|
||||||
|
|
||||||
|
## Download
|
||||||
|
|
||||||
|
TODO
|
||||||
|
|
||||||
|
It will take 2 TB on your local disk.
|
||||||
|
|
||||||
|
## Port on a single computer
|
||||||
|
|
||||||
|
First, install tensorflow dataset utilities to read from raw files:
|
||||||
|
```bash
|
||||||
|
pip install tensorflow
|
||||||
|
pip install tensorflow_datasets
|
||||||
|
```
|
||||||
|
|
||||||
|
Then run this script to start porting the dataset:
|
||||||
|
```bash
|
||||||
|
python examples/port_datasets/droid_rlds/port_droid.py \
|
||||||
|
--raw-dir /your/data/droid/1.0.1 \
|
||||||
|
--repo-id your_id/droid_1.0.1 \
|
||||||
|
--push-to-hub
|
||||||
|
```
|
||||||
|
|
||||||
|
It will take 400 GB on your local disk.
|
||||||
|
|
||||||
|
As usual, your LeRobotDataset will be stored in your huggingface/lerobot cache folder.
|
||||||
|
|
||||||
|
WARNING: it will take 7 days for porting the dataset locally and 3 days to upload, so we will need to parallelize over multiple nodes on a slurm cluster.
|
||||||
|
|
||||||
|
NOTE: For development, run this script to start porting a shard:
|
||||||
|
```bash
|
||||||
|
python examples/port_datasets/droid_rlds/port_droid.py \
|
||||||
|
--raw-dir /your/data/droid/1.0.1 \
|
||||||
|
--repo-id your_id/droid_1.0.1 \
|
||||||
|
--num-shards 2048 \
|
||||||
|
--shard-index 0
|
||||||
|
```
|
||||||
|
|
||||||
|
## Port over SLURM
|
||||||
|
|
||||||
|
Install slurm utilities from Hugging Face:
|
||||||
|
```bash
|
||||||
|
pip install datatrove
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### 1. Port one shard per job
|
||||||
|
|
||||||
|
Run this script to start porting shards of the dataset:
|
||||||
|
```bash
|
||||||
|
python examples/port_datasets/droid_rlds/slurm_port_shards.py \
|
||||||
|
--raw-dir /your/data/droid/1.0.1 \
|
||||||
|
--repo-id your_id/droid_1.0.1 \
|
||||||
|
--logs-dir /your/logs \
|
||||||
|
--job-name port_droid \
|
||||||
|
--partition your_partition \
|
||||||
|
--workers 2048 \
|
||||||
|
--cpus-per-task 8 \
|
||||||
|
--mem-per-cpu 1950M
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note on how to set your command line arguments**
|
||||||
|
|
||||||
|
Regarding `--partition`, find yours by running:
|
||||||
|
```bash
|
||||||
|
sinfo --format="%R"
|
||||||
|
```
|
||||||
|
and select the CPU partition if you have one. No GPU needed.
|
||||||
|
|
||||||
|
Regarding `--workers`, it is the number of slurm jobs you will launch in parallel. 2048 is the maximum number, since there are 2048 shards in Droid. This big number will certainly max out your cluster.
|
||||||
|
|
||||||
|
Regarding `--cpus-per-task` and `--mem-per-cpu`, by default it will use ~16GB of RAM (8*1950M) which is recommended to load the raw frames and 8 CPUs which can be useful to parallelize the encoding of the frames.
|
||||||
|
|
||||||
|
Find the number of CPUs and Memory of the nodes of your partition by running:
|
||||||
|
```bash
|
||||||
|
sinfo -N -p your_partition -h -o "%N cpus=%c mem=%m"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Useful commands to check progress and debug**
|
||||||
|
|
||||||
|
Check if your jobs are running:
|
||||||
|
```bash
|
||||||
|
squeue -u $USER
|
||||||
|
```
|
||||||
|
|
||||||
|
You should see a list with job indices like `15125385_155` where `15125385` is the index of the run and `155` is the worker index. The output/print of this worker is written in real time in `/your/logs/job_name/slurm_jobs/15125385_155.out`. For instance, you can inspect the content of this file by running `less /your/logs/job_name/slurm_jobs/15125385_155.out`.
|
||||||
|
|
||||||
|
Check the progression of your jobs by running:
|
||||||
|
```bash
|
||||||
|
jobs_status /your/logs
|
||||||
|
```
|
||||||
|
|
||||||
|
If it's not 100% and no more slurm job is running, it means that some of them failed. Inspect the logs by running:
|
||||||
|
```bash
|
||||||
|
failed_logs /your/logs/job_name
|
||||||
|
```
|
||||||
|
|
||||||
|
If there is an issue in the code, you can fix it in debug mode with `--slurm 0`, which lets you set breakpoints:
|
||||||
|
```bash
|
||||||
|
python examples/port_datasets/droid_rlds/slurm_port_shards.py --slurm 0 ...
|
||||||
|
```
|
||||||
|
|
||||||
|
And you can relaunch the same command, which will skip the completed jobs:
|
||||||
|
```bash
|
||||||
|
python examples/port_datasets/droid_rlds/slurm_port_shards.py --slurm 1 ...
|
||||||
|
```
|
||||||
|
|
||||||
|
Once all jobs are completed, you will have one dataset per shard (e.g. `droid_1.0.1_world_2048_rank_1594`) saved on disk in your `/lerobot/home/dir/your_id` directory. You can find your `/lerobot/home/dir` by running:
|
||||||
|
```bash
|
||||||
|
python -c "from lerobot.common.constants import HF_LEROBOT_HOME;print(HF_LEROBOT_HOME)"
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### 2. Aggregate all shards
|
||||||
|
|
||||||
|
Run this script to start aggregation:
|
||||||
|
```bash
|
||||||
|
python examples/port_datasets/droid_rlds/slurm_aggregate_shards.py \
|
||||||
|
--repo-id your_id/droid_1.0.1 \
|
||||||
|
--logs-dir /your/logs \
|
||||||
|
--job-name aggr_droid \
|
||||||
|
--partition your_partition \
|
||||||
|
--workers 2048 \
|
||||||
|
--cpus-per-task 8 \
|
||||||
|
--mem-per-cpu 1950M
|
||||||
|
```
|
||||||
|
|
||||||
|
Once all jobs are completed, you will have one dataset in your `/lerobot/home/dir/your_id/droid_1.0.1` directory.
|
||||||
|
|
||||||
|
|
||||||
|
### 3. Upload dataset
|
||||||
|
|
||||||
|
Run this script to start uploading:
|
||||||
|
```bash
|
||||||
|
python examples/port_datasets/droid_rlds/slurm_upload.py \
|
||||||
|
--repo-id your_id/droid_1.0.1 \
|
||||||
|
--logs-dir /your/logs \
|
||||||
|
--job-name upload_droid \
|
||||||
|
--partition your_partition \
|
||||||
|
--workers 50 \
|
||||||
|
--cpus-per-task 4 \
|
||||||
|
--mem-per-cpu 1950M
|
||||||
|
```
|
||||||
69
examples/port_datasets/droid_rlds/display_error_files.py
Normal file
69
examples/port_datasets/droid_rlds/display_error_files.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def find_missing_workers(completions_dir, world_size):
    """Return the set of worker indices below `world_size` that have no completion entry.

    Each entry of `completions_dir` is named after a worker index written with
    leading zeros; an all-zero name stands for worker 0.
    """
    done = set()
    for entry in completions_dir.glob("*"):
        name = entry.name
        if name == "." or name == "..":
            continue
        # Drop the zero padding before converting; nothing left means index 0.
        digits = name.lstrip("0")
        done.add(int(digits) if digits else 0)

    return set(range(world_size)) - done
|
||||||
|
|
||||||
|
|
||||||
|
def find_output_files(slurm_dir, worker_indices):
    """Collect the `*.out` files of `slurm_dir` that belong to the given workers.

    File names are expected to look like `<run>_<worker>.out`. Returns a list of
    `(worker index, output file path)` tuples for the workers in `worker_indices`.
    """
    matches = []
    for out_path in slurm_dir.glob("*.out"):
        base_name = out_path.name.replace(".out", "")
        _, worker_part = base_name.split("_")
        worker = int(worker_part)
        if worker not in worker_indices:
            continue
        matches.append((worker, out_path))
    return matches
|
||||||
|
|
||||||
|
|
||||||
|
def display_error_files(logs_dir, job_name):
    """Print, in decreasing order, the indices of the workers of a job that did not complete.

    The world size is read from `<logs_dir>/<job_name>/executor.json` and the
    completed workers from `<logs_dir>/<job_name>/completions`.
    """
    job_dir = Path(logs_dir) / job_name

    with open(job_dir / "executor.json") as f:
        executor_config = json.load(f)

    not_completed = find_missing_workers(job_dir / "completions", executor_config["world_size"])

    for worker_index in sorted(not_completed, reverse=True):
        print(worker_index)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point: read `--logs-dir` / `--job-name` and list the missing workers."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--logs-dir",
        type=str,
        help="Path to logs directory for `datatrove`.",
    )
    cli.add_argument(
        "--job-name",
        type=str,
        default="port_droid",
        help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
    )

    options = vars(cli.parse_args())
    display_error_files(logs_dir=options["logs_dir"], job_name=options["job_name"])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
430
examples/port_datasets/droid_rlds/port_droid.py
Normal file
430
examples/port_datasets/droid_rlds/port_droid.py
Normal file
@@ -0,0 +1,430 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow_datasets as tfds
|
||||||
|
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||||
|
from lerobot.common.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
|
||||||
|
|
||||||
|
DROID_SHARDS = 2048
|
||||||
|
DROID_FPS = 15
|
||||||
|
DROID_ROBOT_TYPE = "Franka"
|
||||||
|
|
||||||
|
# Dataset schema slightly adapted from: https://droid-dataset.github.io/droid/the-droid-dataset.html#-dataset-schema
|
||||||
|
DROID_FEATURES = {
|
||||||
|
# true on first step of the episode
|
||||||
|
"is_first": {
|
||||||
|
"dtype": "bool",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
# true on last step of the episode
|
||||||
|
"is_last": {
|
||||||
|
"dtype": "bool",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
# true on last step of the episode if it is a terminal step, True for demos
|
||||||
|
"is_terminal": {
|
||||||
|
"dtype": "bool",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
# language_instruction is also stored as "task" to follow LeRobot standard
|
||||||
|
"language_instruction": {
|
||||||
|
"dtype": "string",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
"language_instruction_2": {
|
||||||
|
"dtype": "string",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
"language_instruction_3": {
|
||||||
|
"dtype": "string",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
"observation.state.gripper_position": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["gripper"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"observation.state.cartesian_position": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (6,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"observation.state.joint_position": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (7,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# Add this new feature to follow LeRobot standard of using joint position + gripper
|
||||||
|
"observation.state": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (8,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6", "gripper"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# Initially called wrist_image_left
|
||||||
|
"observation.images.wrist_left": {
|
||||||
|
"dtype": "video",
|
||||||
|
"shape": (180, 320, 3),
|
||||||
|
"names": [
|
||||||
|
"height",
|
||||||
|
"width",
|
||||||
|
"channels",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
# Initially called exterior_image_1_left
|
||||||
|
"observation.images.exterior_1_left": {
|
||||||
|
"dtype": "video",
|
||||||
|
"shape": (180, 320, 3),
|
||||||
|
"names": [
|
||||||
|
"height",
|
||||||
|
"width",
|
||||||
|
"channels",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
# Initially called exterior_image_2_left
|
||||||
|
"observation.images.exterior_2_left": {
|
||||||
|
"dtype": "video",
|
||||||
|
"shape": (180, 320, 3),
|
||||||
|
"names": [
|
||||||
|
"height",
|
||||||
|
"width",
|
||||||
|
"channels",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"action.gripper_position": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["gripper"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"action.gripper_velocity": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["gripper"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"action.cartesian_position": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (6,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"action.cartesian_velocity": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (6,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"action.joint_position": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (7,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"action.joint_velocity": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (7,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# This feature was called "action" in RLDS dataset and consists of [6x joint velocities, 1x gripper position]
|
||||||
|
"action.original": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (7,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
# Add this new feature to follow LeRobot standard of using joint position + gripper
|
||||||
|
"action": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (8,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6", "gripper"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"discount": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
"reward": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
# Meta data that are the same for all frames in the episode
|
||||||
|
"task_category": {
|
||||||
|
"dtype": "string",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
"building": {
|
||||||
|
"dtype": "string",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
"collector_id": {
|
||||||
|
"dtype": "string",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
"date": {
|
||||||
|
"dtype": "string",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
"camera_extrinsics.wrist_left": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (6,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"camera_extrinsics.exterior_1_left": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (6,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"camera_extrinsics.exterior_2_left": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (6,),
|
||||||
|
"names": {
|
||||||
|
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"is_episode_successful": {
|
||||||
|
"dtype": "bool",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def is_episode_successful(tf_episode_metadata):
    """Tell whether an episode is a success, based on its recorded file path."""
    # Adapted from: https://github.com/droid-dataset/droid_policy_learning/blob/dd1020eb20d981f90b5ff07dc80d80d5c0cb108b/robomimic/utils/rlds_utils.py#L8
    file_path = tf_episode_metadata["file_path"].numpy().decode()
    return "/success/" in file_path
|
||||||
|
|
||||||
|
|
||||||
|
def generate_lerobot_frames(tf_episode):
    """Yield one frame dict per step of a DROID RLDS episode, using LeRobot feature names.

    Args:
        tf_episode: mapping with an "episode_metadata" entry and an iterable "steps" entry.
            Every leaf is read through `.numpy()` (presumably TensorFlow tensors coming
            from `tensorflow_datasets` -- TODO confirm).

    Yields:
        dict: the per-step values, the derived "task", "observation.state" and "action"
        entries, and the episode-level metadata. Arrays of dtype float64 are converted
        to float32 before the frame is yielded.
    """
    m = tf_episode["episode_metadata"]
    # Episode-level values: computed once here, then merged into every frame below.
    frame_meta = {
        # NOTE(review): "task_category" reads m["building"], exactly like the next entry.
        # This looks like a copy-paste slip; confirm which metadata key was intended.
        "task_category": m["building"].numpy().decode(),
        "building": m["building"].numpy().decode(),
        "collector_id": m["collector_id"].numpy().decode(),
        "date": m["date"].numpy().decode(),
        "camera_extrinsics.wrist_left": m["extrinsics_wrist_cam"].numpy(),
        "camera_extrinsics.exterior_1_left": m["extrinsics_exterior_cam_1"].numpy(),
        "camera_extrinsics.exterior_2_left": m["extrinsics_exterior_cam_2"].numpy(),
        "is_episode_successful": np.array([is_episode_successful(m)]),
    }
    for f in tf_episode["steps"]:
        # Dataset schema slightly adapted from: https://droid-dataset.github.io/droid/the-droid-dataset.html#-dataset-schema
        frame = {
            # Scalar step values are wrapped in a 1-element array.
            "is_first": np.array([f["is_first"].numpy()]),
            "is_last": np.array([f["is_last"].numpy()]),
            "is_terminal": np.array([f["is_terminal"].numpy()]),
            "language_instruction": f["language_instruction"].numpy().decode(),
            "language_instruction_2": f["language_instruction_2"].numpy().decode(),
            "language_instruction_3": f["language_instruction_3"].numpy().decode(),
            "observation.state.gripper_position": f["observation"]["gripper_position"].numpy(),
            "observation.state.cartesian_position": f["observation"]["cartesian_position"].numpy(),
            "observation.state.joint_position": f["observation"]["joint_position"].numpy(),
            "observation.images.wrist_left": f["observation"]["wrist_image_left"].numpy(),
            "observation.images.exterior_1_left": f["observation"]["exterior_image_1_left"].numpy(),
            "observation.images.exterior_2_left": f["observation"]["exterior_image_2_left"].numpy(),
            "action.gripper_position": f["action_dict"]["gripper_position"].numpy(),
            "action.gripper_velocity": f["action_dict"]["gripper_velocity"].numpy(),
            "action.cartesian_position": f["action_dict"]["cartesian_position"].numpy(),
            "action.cartesian_velocity": f["action_dict"]["cartesian_velocity"].numpy(),
            "action.joint_position": f["action_dict"]["joint_position"].numpy(),
            "action.joint_velocity": f["action_dict"]["joint_velocity"].numpy(),
            "discount": np.array([f["discount"].numpy()]),
            "reward": np.array([f["reward"].numpy()]),
            # The raw RLDS "action" entry is kept under a separate name, because "action"
            # is redefined below.
            "action.original": f["action"].numpy(),
        }

        # language_instruction is also stored as "task" to follow LeRobot standard
        frame["task"] = frame["language_instruction"]

        # Add this new feature to follow LeRobot standard of using joint position + gripper
        frame["observation.state"] = np.concatenate(
            [frame["observation.state.joint_position"], frame["observation.state.gripper_position"]]
        )
        frame["action"] = np.concatenate([frame["action.joint_position"], frame["action.gripper_position"]])

        # Meta data that are the same for all frames in the episode
        frame.update(frame_meta)

        # Cast fp64 to fp32
        # (only values are replaced, no key is added or removed, so iterating over
        # `frame` while assigning is safe)
        for key in frame:
            if isinstance(frame[key], np.ndarray) and frame[key].dtype == np.float64:
                frame[key] = frame[key].astype(np.float32)

        yield frame
|
||||||
|
|
||||||
|
|
||||||
|
def port_droid(
|
||||||
|
raw_dir: Path,
|
||||||
|
repo_id: str,
|
||||||
|
push_to_hub: bool = False,
|
||||||
|
num_shards: int | None = None,
|
||||||
|
shard_index: int | None = None,
|
||||||
|
):
|
||||||
|
dataset_name = raw_dir.parent.name
|
||||||
|
version = raw_dir.name
|
||||||
|
data_dir = raw_dir.parent.parent
|
||||||
|
|
||||||
|
builder = tfds.builder(f"{dataset_name}/{version}", data_dir=data_dir, version="")
|
||||||
|
|
||||||
|
if num_shards is not None:
|
||||||
|
tfds_num_shards = builder.info.splits["train"].num_shards
|
||||||
|
if tfds_num_shards != DROID_SHARDS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of shards of Droid dataset is expected to be {DROID_SHARDS} but is {tfds_num_shards}."
|
||||||
|
)
|
||||||
|
if num_shards != tfds_num_shards:
|
||||||
|
raise ValueError(
|
||||||
|
f"We only shard over the fixed number of shards provided by tensorflow dataset ({tfds_num_shards}), but {num_shards} shards provided instead."
|
||||||
|
)
|
||||||
|
if shard_index >= tfds_num_shards:
|
||||||
|
raise ValueError(
|
||||||
|
f"Shard index is greater than the num of shards ({shard_index} >= {num_shards})."
|
||||||
|
)
|
||||||
|
|
||||||
|
raw_dataset = builder.as_dataset(split=f"train[{shard_index}shard]")
|
||||||
|
else:
|
||||||
|
raw_dataset = builder.as_dataset(split="train")
|
||||||
|
|
||||||
|
lerobot_dataset = LeRobotDataset.create(
|
||||||
|
repo_id=repo_id,
|
||||||
|
robot_type=DROID_ROBOT_TYPE,
|
||||||
|
fps=DROID_FPS,
|
||||||
|
features=DROID_FEATURES,
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
num_episodes = raw_dataset.cardinality().numpy().item()
|
||||||
|
logging.info(f"Number of episodes {num_episodes}")
|
||||||
|
|
||||||
|
for episode_index, episode in enumerate(raw_dataset):
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time)
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
f"{episode_index} / {num_episodes} episodes processed (after {d} days, {h} hours, {m} minutes, {s:.3f} seconds)"
|
||||||
|
)
|
||||||
|
|
||||||
|
for frame in generate_lerobot_frames(episode):
|
||||||
|
lerobot_dataset.add_frame(frame)
|
||||||
|
|
||||||
|
lerobot_dataset.save_episode()
|
||||||
|
logging.info("Save_episode")
|
||||||
|
|
||||||
|
if push_to_hub:
|
||||||
|
lerobot_dataset.push_to_hub(
|
||||||
|
# Add openx tag, since it belongs to the openx collection of datasets
|
||||||
|
tags=["openx"],
|
||||||
|
private=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_dataset(repo_id):
    """Sanity check: the metadata loads, there is at least one episode, and every
    episode has its data file and one video file per video key on disk.

    Raises:
        ValueError: If the dataset has no episode or if an expected file is missing.
    """
    meta = LeRobotDatasetMetadata(repo_id)

    if meta.total_episodes == 0:
        raise ValueError("Number of episodes is 0.")

    for episode in range(meta.total_episodes):
        parquet_file = meta.root / meta.get_data_file_path(episode)
        if not parquet_file.exists():
            raise ValueError(f"Parquet file is missing in: {parquet_file}")

        for video_key in meta.video_keys:
            video_file = meta.root / meta.get_video_file_path(episode, video_key)
            if video_file.exists():
                continue
            raise ValueError(f"Video file is missing in: {video_file}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point: parse the arguments and forward them to `port_droid`."""
    # Imported locally to keep this change self-contained. Logging must be configured
    # here, otherwise the `logging.info` progress messages emitted by `port_droid` are
    # dropped by the default WARNING level of the root logger.
    from lerobot.common.utils.utils import init_logging

    init_logging()

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--raw-dir",
        type=Path,
        required=True,
        # The dataset name and version are taken from the last two path components,
        # so the version directory itself must be given.
        help="Directory containing the raw dataset version (e.g. `path/to/dataset/version`).",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True",
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Upload to hub.",
    )
    parser.add_argument(
        "--num-shards",
        type=int,
        default=None,
        help="Number of shards. Can be either None to load the full dataset, or 2048 to load one of the 2048 tensorflow dataset files.",
    )
    parser.add_argument(
        "--shard-index",
        type=int,
        default=None,
        help="Index of the shard. Can be either None to load the full dataset, or in [0,2047] to load one of the 2048 tensorflow dataset files.",
    )

    args = parser.parse_args()

    port_droid(**vars(args))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
293
examples/port_datasets/droid_rlds/slurm_aggregate_shards.py
Normal file
293
examples/port_datasets/droid_rlds/slurm_aggregate_shards.py
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import tqdm
|
||||||
|
from datatrove.executor import LocalPipelineExecutor
|
||||||
|
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||||
|
from datatrove.pipeline.base import PipelineStep
|
||||||
|
|
||||||
|
from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
||||||
|
from lerobot.common.datasets.aggregate import validate_all_metadata
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||||
|
from lerobot.common.datasets.utils import (
|
||||||
|
legacy_write_episode_stats,
|
||||||
|
legacy_write_task,
|
||||||
|
write_episode,
|
||||||
|
write_info,
|
||||||
|
)
|
||||||
|
from lerobot.common.utils.utils import init_logging
|
||||||
|
|
||||||
|
|
||||||
|
class AggregateDatasets(PipelineStep):
|
||||||
|
def __init__(
    self,
    repo_ids: list[str],
    aggregated_repo_id: str,
):
    """Store the repo ids and immediately write the metadata of the aggregated dataset.

    Args:
        repo_ids: Repo ids of the datasets to aggregate.
        aggregated_repo_id: Repo id of the resulting aggregated dataset.
    """
    super().__init__()
    self.repo_ids = repo_ids
    self.aggr_repo_id = aggregated_repo_id

    # NOTE: constructing the step has side effects: this call loads the metadata of
    # every dataset and writes the aggregated metadata files.
    self.create_aggr_dataset()
|
||||||
|
|
||||||
|
def create_aggr_dataset(self):
    """Create the aggregated dataset metadata and compute the index mappings.

    Loads the metadata of every dataset in `self.repo_ids`, merges their tasks,
    episodes, episode stats and info counters into a new metadata object created for
    `self.aggr_repo_id`, and writes the result to disk. The mappings computed along the
    way are stored on `self`, keyed by the position of each dataset in `self.repo_ids`:

    - `datasets_task_index_to_aggr_task_index`
    - `datasets_ep_idx_to_aggr_ep_idx`
    - `datasets_aggr_episode_index_shift`
    """
    init_logging()

    logging.info("Start aggregate_datasets")

    all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in self.repo_ids]

    fps, robot_type, features = validate_all_metadata(all_metadata)

    # Create resulting dataset folder
    aggr_meta = LeRobotDatasetMetadata.create(
        repo_id=self.aggr_repo_id,
        fps=fps,
        robot_type=robot_type,
        features=features,
    )

    logging.info("Find all tasks")
    # find all tasks, deduplicate them, create new task indices for each dataset
    # indexed by dataset index
    datasets_task_index_to_aggr_task_index = {}
    aggr_task_index = 0
    for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Find all tasks")):
        task_index_to_aggr_task_index = {}

        for task_index, task in meta.tasks.items():
            if task not in aggr_meta.task_to_task_index:
                # add the task to aggr tasks mappings
                aggr_meta.tasks[aggr_task_index] = task
                aggr_meta.task_to_task_index[task] = aggr_task_index
                aggr_task_index += 1

            # add task_index anyway
            task_index_to_aggr_task_index[task_index] = aggr_meta.task_to_task_index[task]

        datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index

    logging.info("Prepare copy data and videos")
    datasets_ep_idx_to_aggr_ep_idx = {}
    datasets_aggr_episode_index_shift = {}
    # Number of episodes of all the datasets already visited; added to a local episode
    # index to get its index in the aggregated dataset.
    aggr_episode_index_shift = 0
    for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata, desc="Prepare copy data and videos")):
        ep_idx_to_aggr_ep_idx = {}

        for episode_index in range(meta.total_episodes):
            aggr_episode_index = episode_index + aggr_episode_index_shift
            ep_idx_to_aggr_ep_idx[episode_index] = aggr_episode_index

        datasets_ep_idx_to_aggr_ep_idx[dataset_index] = ep_idx_to_aggr_ep_idx
        datasets_aggr_episode_index_shift[dataset_index] = aggr_episode_index_shift

        # populate episodes
        # NOTE(review): `episode_dict` is the object held by `meta.episodes`, so the
        # assignment below also modifies the in-memory metadata of the source dataset.
        for episode_index, episode_dict in meta.episodes.items():
            aggr_episode_index = episode_index + aggr_episode_index_shift
            episode_dict["episode_index"] = aggr_episode_index
            aggr_meta.episodes[aggr_episode_index] = episode_dict

        # populate episodes_stats
        for episode_index, episode_stats in meta.episodes_stats.items():
            aggr_episode_index = episode_index + aggr_episode_index_shift
            aggr_meta.episodes_stats[aggr_episode_index] = episode_stats

        # populate info
        aggr_meta.info["total_episodes"] += meta.total_episodes
        aggr_meta.info["total_frames"] += meta.total_frames
        aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes

        aggr_episode_index_shift += meta.total_episodes

    logging.info("Write meta data")
    aggr_meta.info["total_tasks"] = len(aggr_meta.tasks)
    # NOTE(review): this stores the chunk of the last episode. If `get_episode_chunk`
    # returns a 0-based chunk index, the total is off by one -- TODO confirm.
    aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1)
    aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}

    # create a new episodes jsonl with updated episode_index using write_episode
    for episode_dict in tqdm.tqdm(aggr_meta.episodes.values(), desc="Write episodes"):
        write_episode(episode_dict, aggr_meta.root)

    # create a new episode_stats jsonl with updated episode_index using write_episode_stats
    for episode_index, episode_stats in tqdm.tqdm(
        aggr_meta.episodes_stats.items(), desc="Write episodes stats"
    ):
        legacy_write_episode_stats(episode_index, episode_stats, aggr_meta.root)

    # create a new task jsonl with updated episode_index using write_task
    for task_index, task in tqdm.tqdm(aggr_meta.tasks.items(), desc="Write tasks"):
        legacy_write_task(task_index, task, aggr_meta.root)

    write_info(aggr_meta.info, aggr_meta.root)

    # Keep the mappings on the step: `run` reads them to rewrite and copy the files of
    # the dataset assigned to its rank.
    self.datasets_task_index_to_aggr_task_index = datasets_task_index_to_aggr_task_index
    self.datasets_ep_idx_to_aggr_ep_idx = datasets_ep_idx_to_aggr_ep_idx
    self.datasets_aggr_episode_index_shift = datasets_aggr_episode_index_shift

    logging.info("Meta data done writing!")
|
||||||
|
|
||||||
|
def run(self, data=None, rank: int = 0, world_size: int = 1):
    """Copy the data and video files of one source dataset into the aggregated dataset.

    Each rank handles exactly one source dataset (`self.repo_ids[rank]`), so the
    pipeline must be launched with as many tasks as there are source datasets.

    Args:
        data: Unused.
        rank: Index of the source dataset handled by this worker.
        world_size: Total number of workers; must equal `len(self.repo_ids)`.

    Raises:
        ValueError: If `world_size` differs from the number of source datasets.
    """
    import logging
    import shutil

    import pandas as pd

    from lerobot.common.datasets.aggregate import get_update_episode_and_task_func
    from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
    from lerobot.common.utils.utils import init_logging

    init_logging()

    if world_size != len(self.repo_ids):
        raise ValueError(
            f"Expected one worker per source dataset, but got world_size={world_size} "
            f"for {len(self.repo_ids)} datasets."
        )

    dataset_index = rank
    aggr_meta = LeRobotDatasetMetadata(self.aggr_repo_id)
    # Only the metadata of the dataset assigned to this rank is needed.
    meta = LeRobotDatasetMetadata(self.repo_ids[dataset_index])
    aggr_episode_index_shift = self.datasets_aggr_episode_index_shift[dataset_index]

    # The row-update function only depends on per-dataset values: build it once.
    update_row_func = get_update_episode_and_task_func(
        aggr_episode_index_shift, self.datasets_task_index_to_aggr_task_index[dataset_index]
    )

    logging.info("Copy data")
    for episode_index in range(meta.total_episodes):
        aggr_episode_index = self.datasets_ep_idx_to_aggr_ep_idx[dataset_index][episode_index]
        data_path = meta.root / meta.get_data_file_path(episode_index)
        aggr_data_path = aggr_meta.root / aggr_meta.get_data_file_path(aggr_episode_index)

        # update episode_index and task_index
        df = pd.read_parquet(data_path)
        df = df.apply(update_row_func, axis=1)

        aggr_data_path.parent.mkdir(parents=True, exist_ok=True)
        df.to_parquet(aggr_data_path)

    logging.info("Copy videos")
    for episode_index in range(meta.total_episodes):
        aggr_episode_index = episode_index + aggr_episode_index_shift
        for vid_key in meta.video_keys:
            video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
            aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
            aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy(video_path, aggr_video_path)

    logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
def make_aggregate_executor(
    repo_ids, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
):
    """Build the pipeline executor that runs the `AggregateDatasets` step.

    The pipeline always contains a single `AggregateDatasets(repo_ids, repo_id)` step,
    runs `DROID_SHARDS` tasks, and logs to `logs_dir / job_name`.

    Args:
        repo_ids: Repo ids of the source datasets to aggregate.
        repo_id: Repo id of the aggregated dataset.
        job_name: Slurm job name, also used as the logging sub-directory name.
        logs_dir: Parent directory of the logging directory.
        workers: Number of slurm workers (ignored when `slurm` is false).
        partition: Slurm partition (ignored when `slurm` is false).
        cpus_per_task: CPUs per slurm task (ignored when `slurm` is false).
        mem_per_cpu: Value of the `mem-per-cpu` sbatch argument (ignored when `slurm` is false).
        slurm: When true return a `SlurmPipelineExecutor`, otherwise a
            `LocalPipelineExecutor` with a single worker.

    Returns:
        The configured executor (not yet started).
    """
    steps = [
        AggregateDatasets(repo_ids, repo_id),
    ]
    logging_dir = str(logs_dir / job_name)

    if not slurm:
        # Local run: one worker goes through all the tasks.
        return LocalPipelineExecutor(
            pipeline=steps,
            logging_dir=logging_dir,
            tasks=DROID_SHARDS,
            workers=1,
        )

    return SlurmPipelineExecutor(
        pipeline=steps,
        logging_dir=logging_dir,
        job_name=job_name,
        tasks=DROID_SHARDS,
        workers=workers,
        time="08:00:00",
        partition=partition,
        cpus_per_task=cpus_per_task,
        sbatch_args={"mem-per-cpu": mem_per_cpu},
    )
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse the command line, then build and run the aggregation executor.

    `--repo-id` and `--logs-dir` are required: the shard repo ids are derived from
    `--repo-id`, and the executor logging directory is `--logs-dir / --job-name`.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset.",
    )
    parser.add_argument(
        "--logs-dir",
        type=Path,
        required=True,
        help="Path to logs directory for `datatrove`.",
    )
    parser.add_argument(
        "--job-name",
        type=str,
        default="aggr_droid",
        help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
    )
    parser.add_argument(
        "--slurm",
        type=int,
        default=1,
        help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=2048,
        help="Number of slurm workers. It should be less than the maximum number of shards.",
    )
    parser.add_argument(
        "--partition",
        type=str,
        help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
    )
    parser.add_argument(
        "--cpus-per-task",
        type=int,
        default=8,
        help="Number of cpus that each slurm worker will use.",
    )
    parser.add_argument(
        "--mem-per-cpu",
        type=str,
        default="1950M",
        help="Memory per cpu that each worker will use.",
    )

    args = parser.parse_args()
    kwargs = vars(args)
    # `--slurm` is passed as an integer flag; convert it to a boolean.
    kwargs["slurm"] = kwargs.pop("slurm") == 1

    # One source dataset per shard, named after the shard's world size and rank.
    repo_ids = [f"{args.repo_id}_world_{DROID_SHARDS}_rank_{rank}" for rank in range(DROID_SHARDS)]
    aggregate_executor = make_aggregate_executor(repo_ids, **kwargs)
    aggregate_executor.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
147
examples/port_datasets/droid_rlds/slurm_port_shards.py
Normal file
147
examples/port_datasets/droid_rlds/slurm_port_shards.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from datatrove.executor import LocalPipelineExecutor
|
||||||
|
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||||
|
from datatrove.pipeline.base import PipelineStep
|
||||||
|
|
||||||
|
from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
||||||
|
|
||||||
|
|
||||||
|
class PortDroidShards(PipelineStep):
    """Pipeline step that ports one shard of the raw DROID dataset per task.

    Each task ports the shard `rank` out of `world_size` into a dataset named
    `{repo_id}_world_{world_size}_rank_{rank}`.
    """

    def __init__(
        self,
        raw_dir: Path | str,
        repo_id: str | None = None,
    ):
        """
        Args:
            raw_dir: Directory containing the raw dataset.
            repo_id: Prefix used to build the repo id of every shard.
        """
        super().__init__()
        self.raw_dir = Path(raw_dir)
        self.repo_id = repo_id

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Port the shard assigned to this task, unless it already validates.

        Args:
            data: Unused.
            rank: Index of the shard to port.
            world_size: Total number of shards.
        """
        import logging

        from datasets.utils.tqdm import disable_progress_bars

        from examples.port_datasets.droid_rlds.port_droid import port_droid, validate_dataset
        from lerobot.common.utils.utils import init_logging

        init_logging()
        disable_progress_bars()

        shard_repo_id = f"{self.repo_id}_world_{world_size}_rank_{rank}"

        try:
            validate_dataset(shard_repo_id)
        except Exception as exc:  # nosec B110 - Dataset doesn't exist yet, continue with porting
            # Best effort: any validation failure means the shard has to be (re)ported.
            logging.info("Shard %s is not valid yet (%r), porting it.", shard_repo_id, exc)
        else:
            logging.info("Shard %s is already valid, skipping.", shard_repo_id)
            return

        port_droid(
            self.raw_dir,
            shard_repo_id,
            push_to_hub=False,
            num_shards=world_size,
            shard_index=rank,
        )

        validate_dataset(shard_repo_id)
|
||||||
|
|
||||||
|
|
||||||
|
def make_port_executor(
    raw_dir, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
):
    """Build the pipeline executor that runs the `PortDroidShards` step.

    The pipeline contains a single `PortDroidShards(raw_dir, repo_id)` step and logs
    to `logs_dir / job_name`.

    Args:
        raw_dir: Directory containing the raw dataset.
        repo_id: Prefix of the shard repo ids.
        job_name: Slurm job name, also used as the logging sub-directory name.
        logs_dir: Parent directory of the logging directory.
        workers: Number of slurm workers (ignored when `slurm` is false).
        partition: Slurm partition (ignored when `slurm` is false).
        cpus_per_task: CPUs per slurm task (ignored when `slurm` is false).
        mem_per_cpu: Value of the `mem-per-cpu` sbatch argument (ignored when `slurm` is false).
        slurm: When true return a `SlurmPipelineExecutor` running `DROID_SHARDS` tasks,
            otherwise a `LocalPipelineExecutor` running a single task with a single worker.

    Returns:
        The configured executor (not yet started).
    """
    steps = [
        PortDroidShards(raw_dir, repo_id),
    ]
    logging_dir = str(logs_dir / job_name)

    if not slurm:
        # Local run: a single task handled by a single worker.
        return LocalPipelineExecutor(
            pipeline=steps,
            logging_dir=logging_dir,
            tasks=1,
            workers=1,
        )

    return SlurmPipelineExecutor(
        pipeline=steps,
        logging_dir=logging_dir,
        job_name=job_name,
        tasks=DROID_SHARDS,
        workers=workers,
        time="08:00:00",
        partition=partition,
        cpus_per_task=cpus_per_task,
        sbatch_args={"mem-per-cpu": mem_per_cpu},
    )
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse the command line, then build and run the porting executor.

    `--raw-dir`, `--repo-id` and `--logs-dir` are required: the shard repo ids are
    derived from `--repo-id`, and the executor logging directory is
    `--logs-dir / --job-name`.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--raw-dir",
        type=Path,
        required=True,
        help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset.",
    )
    parser.add_argument(
        "--logs-dir",
        type=Path,
        required=True,
        help="Path to logs directory for `datatrove`.",
    )
    parser.add_argument(
        "--job-name",
        type=str,
        default="port_droid",
        help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
    )
    parser.add_argument(
        "--slurm",
        type=int,
        default=1,
        help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=2048,
        help="Number of slurm workers. It should be less than the maximum number of shards.",
    )
    parser.add_argument(
        "--partition",
        type=str,
        help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
    )
    parser.add_argument(
        "--cpus-per-task",
        type=int,
        default=8,
        help="Number of cpus that each slurm worker will use.",
    )
    parser.add_argument(
        "--mem-per-cpu",
        type=str,
        default="1950M",
        help="Memory per cpu that each worker will use.",
    )

    args = parser.parse_args()
    kwargs = vars(args)
    # `--slurm` is passed as an integer flag; convert it to a boolean.
    kwargs["slurm"] = kwargs.pop("slurm") == 1
    port_executor = make_port_executor(**kwargs)
    port_executor.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
263
examples/port_datasets/droid_rlds/slurm_upload.py
Normal file
263
examples/port_datasets/droid_rlds/slurm_upload.py
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from datatrove.executor import LocalPipelineExecutor
|
||||||
|
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||||
|
from datatrove.pipeline.base import PipelineStep
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
from huggingface_hub.constants import REPOCARD_NAME
|
||||||
|
|
||||||
|
from examples.port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||||
|
from lerobot.common.datasets.utils import create_lerobot_dataset_card
|
||||||
|
from lerobot.common.utils.utils import init_logging
|
||||||
|
|
||||||
|
|
||||||
|
class UploadDataset(PipelineStep):
    """Pipeline step that uploads the files of a local dataset to a Hugging Face dataset repo.

    The repo (and optionally a branch and a dataset card) is created when the step is
    instantiated; every task then uploads its own slice of the sorted file list.
    """

    def __init__(
        self,
        repo_id: str,
        branch: str | None = None,
        revision: str | None = None,
        tags: list | None = None,
        license: str | None = "apache-2.0",
        private: bool = False,
        distant_repo_id: str | None = None,
        **card_kwargs,
    ):
        """
        Args:
            repo_id: Repo id of the local dataset to upload.
            branch: Optional branch to create and upload to.
            revision: Revision the branch is created from; defaults to `CODEBASE_VERSION`.
            tags: Tags forwarded to the dataset card.
            license: License forwarded to the dataset card.
            private: Whether the created repo is private.
            distant_repo_id: Repo id on the hub; defaults to `repo_id`.
            **card_kwargs: Extra keyword arguments forwarded to `create_lerobot_dataset_card`.
        """
        super().__init__()
        self.repo_id = repo_id
        self.distant_repo_id = self.repo_id if distant_repo_id is None else distant_repo_id
        self.branch = branch
        self.tags = tags
        self.license = license
        self.private = private
        self.card_kwargs = card_kwargs
        self.revision = revision if revision else CODEBASE_VERSION

        if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") != "1":
            logging.warning(
                'HF_HUB_ENABLE_HF_TRANSFER is not set to "1". Install hf_transfer and set the env '
                "variable for faster uploads:\npip install hf-transfer\nexport HF_HUB_ENABLE_HF_TRANSFER=1"
            )

        self.create_repo()

    def create_repo(self):
        """Create the hub repo, branch and dataset card if needed, then list the local files.

        Sets `self.file_paths` to the sorted list of file paths relative to the local
        dataset root.
        """
        logging.info(f"Loading meta data from {self.repo_id}...")
        meta = LeRobotDatasetMetadata(self.repo_id)

        logging.info(f"Creating repo {self.distant_repo_id}...")
        hub_api = HfApi()
        hub_api.create_repo(
            repo_id=self.distant_repo_id,
            private=self.private,
            repo_type="dataset",
            exist_ok=True,
        )
        if self.branch:
            hub_api.create_branch(
                repo_id=self.distant_repo_id,
                branch=self.branch,
                revision=self.revision,
                repo_type="dataset",
                exist_ok=True,
            )

        # Only push a dataset card when the repo does not have one yet.
        if not hub_api.file_exists(
            self.distant_repo_id, REPOCARD_NAME, repo_type="dataset", revision=self.branch
        ):
            card = create_lerobot_dataset_card(
                tags=self.tags, dataset_info=meta.info, license=self.license, **self.card_kwargs
            )
            card.push_to_hub(repo_id=self.distant_repo_id, repo_type="dataset", revision=self.branch)

        def list_files_recursively(directory):
            base_path = Path(directory)
            return [str(file.relative_to(base_path)) for file in base_path.rglob("*") if file.is_file()]

        logging.info(f"Listing all local files from {self.repo_id}...")
        # Sorted so that every task computes the same slices in `run`.
        self.file_paths = sorted(list_files_recursively(meta.root))

    def create_chunks(self, lst, n):
        """Split `lst` into `n` contiguous chunks whose sizes differ by at most one.

        The first `len(lst) % n` chunks get one extra element. Returns an empty list
        when `n` is 0.
        """
        from itertools import islice

        it = iter(lst)
        return [list(islice(it, size)) for size in [len(lst) // n + (i < len(lst) % n) for i in range(n)]]

    def create_commits(self, additions):
        """Commit `additions` to the hub in small batches, retrying on commit races.

        Args:
            additions: Commit operations to push.

        Raises:
            HfHubHTTPError: For any error other than a commit race, or when a commit
                race persists after `MAX_RETRIES` retries.
        """
        import logging
        import math
        import random
        import time

        from huggingface_hub import create_commit
        from huggingface_hub.utils import HfHubHTTPError

        FILES_BETWEEN_COMMITS = 10  # noqa: N806
        BASE_DELAY = 0.1  # noqa: N806
        MAX_RETRIES = 12  # noqa: N806

        # Split the files into smaller chunks for faster commit
        # and avoiding "A commit has happened since" error
        num_chunks = math.ceil(len(additions) / FILES_BETWEEN_COMMITS)
        chunks = self.create_chunks(additions, num_chunks)

        for chunk in chunks:
            retries = 0
            while True:
                try:
                    create_commit(
                        self.distant_repo_id,
                        repo_type="dataset",
                        operations=chunk,
                        commit_message=f"DataTrove upload ({len(chunk)} files)",
                        revision=self.branch,
                    )
                    # TODO: every 100 chunks super_squash_commits()
                    logging.info("create_commit completed!")
                    break
                except HfHubHTTPError as e:
                    # `server_message` may be None when the server sent no message.
                    if "A commit has happened since" not in (e.server_message or ""):
                        raise
                    if retries >= MAX_RETRIES:
                        logging.error(f"Failed to create commit after {MAX_RETRIES=}. Giving up.")
                        raise
                    logging.info("Commit creation race condition issue. Waiting...")
                    # Exponential backoff with jitter.
                    time.sleep(BASE_DELAY * 2**retries + random.uniform(0, 2))
                    retries += 1

    def run(self, data=None, rank: int = 0, world_size: int = 1):
        """Upload the slice of files assigned to this task.

        Args:
            data: Unused.
            rank: Index of the slice of `self.file_paths` to upload.
            world_size: Number of slices the file list is split into.

        Raises:
            ValueError: If no file is assigned to this rank.
        """
        import logging

        from datasets.utils.tqdm import disable_progress_bars
        from huggingface_hub import CommitOperationAdd, preupload_lfs_files

        from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
        from lerobot.common.utils.utils import init_logging

        init_logging()
        disable_progress_bars()

        chunks = self.create_chunks(self.file_paths, world_size)
        file_paths = chunks[rank]

        if len(file_paths) == 0:
            raise ValueError(
                f"No file to upload for rank {rank}: {len(self.file_paths)} files were split "
                f"across {world_size} tasks."
            )

        logging.info("Pre-uploading LFS files...")
        for i, path in enumerate(file_paths):
            logging.info(f"{i}: {path}")

        meta = LeRobotDatasetMetadata(self.repo_id)
        additions = [
            CommitOperationAdd(path_in_repo=path, path_or_fileobj=meta.root / path) for path in file_paths
        ]
        preupload_lfs_files(
            repo_id=self.distant_repo_id, repo_type="dataset", additions=additions, revision=self.branch
        )

        logging.info("Creating commits...")
        self.create_commits(additions)
        logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
def make_upload_executor(
    repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
):
    """Build the pipeline executor that runs the `UploadDataset` step.

    The pipeline contains a single `UploadDataset(repo_id)` step, runs `DROID_SHARDS`
    tasks, and logs to `logs_dir / job_name`.

    Args:
        repo_id: Repo id of the dataset to upload.
        job_name: Slurm job name, also used as the logging sub-directory name.
        logs_dir: Parent directory of the logging directory.
        workers: Number of slurm workers (ignored when `slurm` is false).
        partition: Slurm partition (ignored when `slurm` is false).
        cpus_per_task: CPUs per slurm task (ignored when `slurm` is false).
        mem_per_cpu: Value of the `mem-per-cpu` sbatch argument (ignored when `slurm` is false).
        slurm: When true return a `SlurmPipelineExecutor`, otherwise a
            `LocalPipelineExecutor` with a single worker.

    Returns:
        The configured executor (not yet started).
    """
    steps = [
        UploadDataset(repo_id),
    ]
    logging_dir = str(logs_dir / job_name)

    if not slurm:
        # Local run: one worker goes through all the tasks.
        return LocalPipelineExecutor(
            pipeline=steps,
            logging_dir=logging_dir,
            tasks=DROID_SHARDS,
            workers=1,
        )

    return SlurmPipelineExecutor(
        pipeline=steps,
        logging_dir=logging_dir,
        job_name=job_name,
        tasks=DROID_SHARDS,
        workers=workers,
        time="08:00:00",
        partition=partition,
        cpus_per_task=cpus_per_task,
        sbatch_args={"mem-per-cpu": mem_per_cpu},
    )
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse the command line, then build and run the upload executor.

    `--repo-id` and `--logs-dir` are required: the dataset to upload is identified by
    `--repo-id`, and the executor logging directory is `--logs-dir / --job-name`.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset.",
    )
    parser.add_argument(
        "--logs-dir",
        type=Path,
        required=True,
        help="Path to logs directory for `datatrove`.",
    )
    parser.add_argument(
        "--job-name",
        type=str,
        default="upload_droid",
        help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
    )
    parser.add_argument(
        "--slurm",
        type=int,
        default=1,
        help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=50,
        help="Number of slurm workers. It should be less than the maximum number of shards.",
    )
    parser.add_argument(
        "--partition",
        type=str,
        help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
    )
    parser.add_argument(
        "--cpus-per-task",
        type=int,
        default=8,
        help="Number of cpus that each slurm worker will use.",
    )
    parser.add_argument(
        "--mem-per-cpu",
        type=str,
        default="1950M",
        help="Memory per cpu that each worker will use.",
    )

    init_logging()

    args = parser.parse_args()
    kwargs = vars(args)
    # `--slurm` is passed as an integer flag; convert it to a boolean.
    kwargs["slurm"] = kwargs.pop("slurm") == 1
    upload_executor = make_upload_executor(**kwargs)
    upload_executor.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
359
lerobot/common/datasets/aggregate.py
Normal file
359
lerobot/common/datasets/aggregate.py
Normal file
@@ -0,0 +1,359 @@
|
|||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||||
|
from lerobot.common.datasets.utils import (
|
||||||
|
DEFAULT_CHUNK_SIZE,
|
||||||
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
|
DEFAULT_DATA_PATH,
|
||||||
|
DEFAULT_EPISODES_PATH,
|
||||||
|
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||||
|
DEFAULT_VIDEO_PATH,
|
||||||
|
concat_video_files,
|
||||||
|
get_parquet_file_size_in_mb,
|
||||||
|
get_video_size_in_mb,
|
||||||
|
to_parquet_with_hf_images,
|
||||||
|
update_chunk_file_indices,
|
||||||
|
write_info,
|
||||||
|
write_stats,
|
||||||
|
write_tasks,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
    """Check that all datasets share the same fps, robot type and features.

    Args:
        all_metadata: Metadata of the datasets to aggregate; the first entry is the reference.

    Returns:
        The shared `(fps, robot_type, features)` tuple.

    Raises:
        ValueError: If `all_metadata` is empty, or if any dataset differs from the
            first one in fps, robot type or features.
    """
    if not all_metadata:
        raise ValueError("At least one dataset metadata is required, but got an empty list.")

    # validate same fps, robot_type, features
    reference = all_metadata[0]
    fps = reference.fps
    robot_type = reference.robot_type
    features = reference.features

    for meta in tqdm.tqdm(all_metadata, desc="Validate all meta data"):
        if fps != meta.fps:
            raise ValueError(f"Same fps is expected, but got fps={meta.fps} instead of {fps}.")
        if robot_type != meta.robot_type:
            raise ValueError(
                f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}."
            )
        if features != meta.features:
            raise ValueError(
                f"Same features is expected, but got features={meta.features} instead of {features}."
            )

    return fps, robot_type, features
|
||||||
|
|
||||||
|
|
||||||
|
def update_data_df(df, src_meta, dst_meta):
    """Re-index the frames of a source data file for the aggregated dataset.

    For every row, `episode_index` and `index` are shifted by the episode and frame
    counts currently stored in `dst_meta.info`, and `task_index` is remapped from the
    source task table to the destination task table through the task name.

    Args:
        df: Frames read from a source data file.
        src_meta: Source metadata; `tasks` is indexed by task name.
        dst_meta: Destination metadata; provides `info` and `tasks`.

    Returns:
        The result of applying the update to every row of `df`.
    """

    def _reindex(frame):
        frame["episode_index"] = frame["episode_index"] + dst_meta.info["total_episodes"]
        frame["index"] = frame["index"] + dst_meta.info["total_frames"]
        # Source task index -> task name -> destination task index.
        task_name = src_meta.tasks.iloc[frame["task_index"]].name
        dst_task_row = dst_meta.tasks.loc[task_name]
        frame["task_index"] = dst_task_row.task_index.item()
        return frame

    reindexed = df.apply(_reindex, axis=1)
    return reindexed
|
||||||
|
|
||||||
|
|
||||||
|
def update_meta_data(
    df,
    dst_meta,
    meta_idx,
    data_idx,
    videos_idx,
):
    """Shift the indices stored in a source episodes-metadata file for the aggregated dataset.

    For every row, the chunk/file indices of the metadata, data and video files are
    offset by the matching counters, the video timestamps by `latest_duration`, and
    the dataset/episode indices by the totals currently stored in `dst_meta.info`.

    Args:
        df: Rows read from a source episodes-metadata file.
        dst_meta: Destination metadata; provides `info`.
        meta_idx: Dict with the "chunk" and "file" offsets of the metadata files.
        data_idx: Dict with the "chunk" and "file" offsets of the data files.
        videos_idx: Per video key, dict with "chunk", "file" and "latest_duration" offsets.

    Returns:
        The result of applying the update to every row of `df`.
    """

    def _shift(entry):
        # (column, offset) pairs, applied in order.
        shifts = [
            ("meta/episodes/chunk_index", meta_idx["chunk"]),
            ("meta/episodes/file_index", meta_idx["file"]),
            ("data/chunk_index", data_idx["chunk"]),
            ("data/file_index", data_idx["file"]),
        ]
        for key, video_idx in videos_idx.items():
            shifts.extend(
                [
                    (f"videos/{key}/chunk_index", video_idx["chunk"]),
                    (f"videos/{key}/file_index", video_idx["file"]),
                    (f"videos/{key}/from_timestamp", video_idx["latest_duration"]),
                    (f"videos/{key}/to_timestamp", video_idx["latest_duration"]),
                ]
            )
        shifts.extend(
            [
                ("dataset_from_index", dst_meta.info["total_frames"]),
                ("dataset_to_index", dst_meta.info["total_frames"]),
                ("episode_index", dst_meta.info["total_episodes"]),
            ]
        )

        for column, offset in shifts:
            entry[column] = entry[column] + offset
        return entry

    shifted = df.apply(_shift, axis=1)
    return shifted
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] = None, aggr_root=None):
|
||||||
|
logging.info("Start aggregate_datasets")
|
||||||
|
|
||||||
|
# Load metadata
|
||||||
|
all_metadata = (
|
||||||
|
[LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
|
||||||
|
if roots is None
|
||||||
|
else [
|
||||||
|
LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||||
|
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||||
|
|
||||||
|
# Initialize output dataset metadata
|
||||||
|
dst_meta = LeRobotDatasetMetadata.create(
|
||||||
|
repo_id=aggr_repo_id,
|
||||||
|
fps=fps,
|
||||||
|
robot_type=robot_type,
|
||||||
|
features=features,
|
||||||
|
root=aggr_root,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Aggregate task info
|
||||||
|
logging.info("Find all tasks")
|
||||||
|
unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
|
||||||
|
dst_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
|
||||||
|
|
||||||
|
# Track counters and indices
|
||||||
|
meta_idx = {"chunk": 0, "file": 0}
|
||||||
|
data_idx = {"chunk": 0, "file": 0}
|
||||||
|
videos_idx = {
|
||||||
|
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
|
||||||
|
}
|
||||||
|
|
||||||
|
dst_meta.episodes = {}
|
||||||
|
|
||||||
|
# Process each dataset
|
||||||
|
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||||
|
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx)
|
||||||
|
data_idx = aggregate_data(src_meta, dst_meta, data_idx)
|
||||||
|
|
||||||
|
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
||||||
|
|
||||||
|
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||||
|
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||||
|
|
||||||
|
finalize_aggregation(dst_meta, all_metadata)
|
||||||
|
logging.info("Aggregation complete.")
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# Helper Functions
|
||||||
|
# -------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_videos(src_meta, dst_meta, videos_idx):
    """Aggregate the video files of a source dataset into the aggregated dataset folder.

    For every video key, each source video file is either copied to the current
    destination file (when it does not exist yet), appended to it, or copied to the
    next destination file when appending would reach `DEFAULT_VIDEO_FILE_SIZE_IN_MB`.

    Args:
        src_meta: Source metadata; provides `root` and the per-episode
            `videos/{key}/chunk_index` and `videos/{key}/file_index` columns.
        dst_meta: Destination metadata; provides `root`.
        videos_idx: Per video key, dict holding the current destination "chunk" and
            "file" indices. Updated in place.

    Returns:
        `videos_idx`, with the final chunk and file indices of every key.
    """

    def _video_path(root, key, chunk_index, file_index):
        return root / DEFAULT_VIDEO_PATH.format(
            video_key=key,
            chunk_index=chunk_index,
            file_index=file_index,
        )

    for key, video_idx in videos_idx.items():
        # Unique (chunk, file) combinations, sorted so that the source files are
        # processed in a deterministic, ascending order rather than in set order.
        chunk_file_pairs = sorted(
            set(
                zip(
                    src_meta.episodes[f"videos/{key}/chunk_index"],
                    src_meta.episodes[f"videos/{key}/file_index"],
                    strict=False,
                )
            )
        )

        # Current target chunk/file index
        chunk_idx = video_idx["chunk"]
        file_idx = video_idx["file"]

        for src_chunk_idx, src_file_idx in chunk_file_pairs:
            src_path = _video_path(src_meta.root, key, src_chunk_idx, src_file_idx)
            dst_path = _video_path(dst_meta.root, key, chunk_idx, file_idx)

            if not dst_path.exists():
                # First write to this destination file
                dst_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy(str(src_path), str(dst_path))
                continue

            # Check file sizes before appending
            src_size = get_video_size_in_mb(src_path)
            dst_size = get_video_size_in_mb(dst_path)

            if dst_size + src_size >= DEFAULT_VIDEO_FILE_SIZE_IN_MB:
                # Rotate to a new chunk/file
                chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
                dst_path = _video_path(dst_meta.root, key, chunk_idx, file_idx)
                dst_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy(str(src_path), str(dst_path))
            else:
                # Append to existing video file
                concat_video_files(
                    [dst_path, src_path],
                    dst_meta.root,
                    key,
                    chunk_idx,
                    file_idx,
                )

        # Update the videos_idx with the final chunk and file indices for this key
        videos_idx[key]["chunk"] = chunk_idx
        videos_idx[key]["file"] = file_idx

    return videos_idx
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_data(src_meta, dst_meta, data_idx):
|
||||||
|
unique_chunk_file_ids = {
|
||||||
|
(c, f)
|
||||||
|
for c, f in zip(
|
||||||
|
src_meta.episodes["data/chunk_index"], src_meta.episodes["data/file_index"], strict=False
|
||||||
|
)
|
||||||
|
}
|
||||||
|
for src_chunk_idx, src_file_idx in unique_chunk_file_ids:
|
||||||
|
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
|
||||||
|
chunk_index=src_chunk_idx, file_index=src_file_idx
|
||||||
|
)
|
||||||
|
df = pd.read_parquet(src_path)
|
||||||
|
df = update_data_df(df, src_meta, dst_meta)
|
||||||
|
|
||||||
|
data_idx = append_or_create_parquet_file(
|
||||||
|
df,
|
||||||
|
src_path,
|
||||||
|
data_idx,
|
||||||
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
|
DEFAULT_CHUNK_SIZE,
|
||||||
|
DEFAULT_DATA_PATH,
|
||||||
|
contains_images=len(dst_meta.image_keys) > 0,
|
||||||
|
aggr_root=dst_meta.root,
|
||||||
|
)
|
||||||
|
|
||||||
|
return data_idx
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||||
|
chunk_file_ids = {
|
||||||
|
(c, f)
|
||||||
|
for c, f in zip(
|
||||||
|
src_meta.episodes["meta/episodes/chunk_index"],
|
||||||
|
src_meta.episodes["meta/episodes/file_index"],
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
for chunk_idx, file_idx in chunk_file_ids:
|
||||||
|
src_path = src_meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||||
|
df = pd.read_parquet(src_path)
|
||||||
|
df = update_meta_data(
|
||||||
|
df,
|
||||||
|
dst_meta,
|
||||||
|
meta_idx,
|
||||||
|
data_idx,
|
||||||
|
videos_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
for k in videos_idx:
|
||||||
|
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
|
||||||
|
|
||||||
|
meta_idx = append_or_create_parquet_file(
|
||||||
|
df,
|
||||||
|
src_path,
|
||||||
|
meta_idx,
|
||||||
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
|
DEFAULT_CHUNK_SIZE,
|
||||||
|
DEFAULT_EPISODES_PATH,
|
||||||
|
contains_images=False,
|
||||||
|
aggr_root=dst_meta.root,
|
||||||
|
)
|
||||||
|
|
||||||
|
return meta_idx
|
||||||
|
|
||||||
|
|
||||||
|
def append_or_create_parquet_file(
|
||||||
|
df: pd.DataFrame,
|
||||||
|
src_path: Path,
|
||||||
|
idx: dict[str, int],
|
||||||
|
max_mb: float,
|
||||||
|
chunk_size: int,
|
||||||
|
default_path: str,
|
||||||
|
contains_images: bool = False,
|
||||||
|
aggr_root: Path = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Safely appends or creates a Parquet file at dst_path based on size constraints.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
df (pd.DataFrame): Data to write.
|
||||||
|
src_path (Path): Path to source file (used to get size).
|
||||||
|
idx (dict): Dictionary containing 'chunk' and 'file' indices.
|
||||||
|
max_mb (float): Maximum allowed file size in MB.
|
||||||
|
chunk_size (int): Maximum number of files per chunk.
|
||||||
|
default_path (str): Format string for generating a new file path.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Updated index dictionary.
|
||||||
|
"""
|
||||||
|
# Initial destination path - use the correct default_path parameter
|
||||||
|
dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
||||||
|
|
||||||
|
# If destination file doesn't exist, just write the new one
|
||||||
|
if not dst_path.exists():
|
||||||
|
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
if contains_images:
|
||||||
|
to_parquet_with_hf_images(df, dst_path)
|
||||||
|
else:
|
||||||
|
df.to_parquet(dst_path)
|
||||||
|
return idx
|
||||||
|
|
||||||
|
# Otherwise, check if we exceed the size limit
|
||||||
|
src_size = get_parquet_file_size_in_mb(src_path)
|
||||||
|
dst_size = get_parquet_file_size_in_mb(dst_path)
|
||||||
|
|
||||||
|
if dst_size + src_size >= max_mb:
|
||||||
|
# File is too large, move to a new one
|
||||||
|
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
|
||||||
|
new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
||||||
|
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
final_df = df
|
||||||
|
target_path = new_path
|
||||||
|
else:
|
||||||
|
# Append to existing file
|
||||||
|
existing_df = pd.read_parquet(dst_path)
|
||||||
|
final_df = pd.concat([existing_df, df], ignore_index=True)
|
||||||
|
target_path = dst_path
|
||||||
|
|
||||||
|
if contains_images:
|
||||||
|
to_parquet_with_hf_images(final_df, target_path)
|
||||||
|
else:
|
||||||
|
final_df.to_parquet(target_path)
|
||||||
|
|
||||||
|
return idx
|
||||||
|
|
||||||
|
|
||||||
|
def finalize_aggregation(aggr_meta, all_metadata):
    """Write the tasks, info and stats files of the aggregated dataset.

    Args:
        aggr_meta: Metadata of the aggregated dataset. Its `info` and `stats` are
            updated, and everything is written under `aggr_meta.root`.
        all_metadata: Iterable of the source datasets' metadata; each item must
            expose `total_episodes`, `total_frames` and `stats`.
    """
    # Materialize once: the sources are traversed several times below, and a
    # one-shot iterable would otherwise yield zeros / empty stats after the first pass.
    all_metadata = list(all_metadata)

    logging.info("write tasks")
    write_tasks(aggr_meta.tasks, aggr_meta.root)

    logging.info("write info")
    # Compute each total once so "total_episodes" and the "train" split cannot disagree.
    total_episodes = sum(m.total_episodes for m in all_metadata)
    total_frames = sum(m.total_frames for m in all_metadata)
    aggr_meta.info.update(
        {
            "total_tasks": len(aggr_meta.tasks),
            "total_episodes": total_episodes,
            "total_frames": total_frames,
            "splits": {"train": f"0:{total_episodes}"},
        }
    )
    write_info(aggr_meta.info, aggr_meta.root)

    logging.info("write stats")
    aggr_meta.stats = aggregate_stats([m.stats for m in all_metadata])
    write_stats(aggr_meta.stats, aggr_meta.root)
|
||||||
@@ -47,6 +47,18 @@ If you encounter a problem, contact LeRobot maintainers on [Discord](https://dis
|
|||||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
V30_MESSAGE = """
|
||||||
|
The dataset you requested ({repo_id}) is in {version} format.
|
||||||
|
While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
|
||||||
|
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
|
||||||
|
```
|
||||||
|
python lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py --repo-id={repo_id}
|
||||||
|
```
|
||||||
|
|
||||||
|
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||||
|
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||||
|
"""
|
||||||
|
|
||||||
FUTURE_MESSAGE = """
|
FUTURE_MESSAGE = """
|
||||||
The dataset you requested ({repo_id}) is only available in {version} format.
|
The dataset you requested ({repo_id}) is only available in {version} format.
|
||||||
As we cannot ensure forward compatibility with it, please update your current version of lerobot.
|
As we cannot ensure forward compatibility with it, please update your current version of lerobot.
|
||||||
@@ -58,7 +70,14 @@ class CompatibilityError(Exception): ...
|
|||||||
|
|
||||||
class BackwardCompatibilityError(CompatibilityError):
|
class BackwardCompatibilityError(CompatibilityError):
|
||||||
def __init__(self, repo_id: str, version: packaging.version.Version):
|
def __init__(self, repo_id: str, version: packaging.version.Version):
|
||||||
message = V2_MESSAGE.format(repo_id=repo_id, version=version)
|
if version.major == 3:
|
||||||
|
message = V30_MESSAGE.format(repo_id=repo_id, version=version)
|
||||||
|
elif version.major == 2:
|
||||||
|
message = V2_MESSAGE.format(repo_id=repo_id, version=version)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb)."
|
||||||
|
)
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -16,16 +16,18 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import packaging.version
|
import packaging.version
|
||||||
|
import pandas as pd
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
import torch
|
import torch
|
||||||
import torch.utils
|
import torch.utils
|
||||||
from datasets import concatenate_datasets, load_dataset
|
from datasets import Dataset
|
||||||
from huggingface_hub import HfApi, snapshot_download
|
from huggingface_hub import HfApi, snapshot_download
|
||||||
from huggingface_hub.constants import REPOCARD_NAME
|
from huggingface_hub.constants import REPOCARD_NAME
|
||||||
from huggingface_hub.errors import RevisionNotFoundError
|
from huggingface_hub.errors import RevisionNotFoundError
|
||||||
@@ -34,36 +36,41 @@ from lerobot.common.constants import HF_LEROBOT_HOME
|
|||||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||||
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
|
DEFAULT_EPISODES_PATH,
|
||||||
DEFAULT_FEATURES,
|
DEFAULT_FEATURES,
|
||||||
DEFAULT_IMAGE_PATH,
|
DEFAULT_IMAGE_PATH,
|
||||||
INFO_PATH,
|
INFO_PATH,
|
||||||
TASKS_PATH,
|
|
||||||
_validate_feature_names,
|
_validate_feature_names,
|
||||||
append_jsonlines,
|
|
||||||
backward_compatible_episodes_stats,
|
|
||||||
check_delta_timestamps,
|
check_delta_timestamps,
|
||||||
check_timestamps_sync,
|
|
||||||
check_version_compatibility,
|
check_version_compatibility,
|
||||||
|
concat_video_files,
|
||||||
create_empty_dataset_info,
|
create_empty_dataset_info,
|
||||||
create_lerobot_dataset_card,
|
create_lerobot_dataset_card,
|
||||||
embed_images,
|
embed_images,
|
||||||
|
flatten_dict,
|
||||||
get_delta_indices,
|
get_delta_indices,
|
||||||
get_episode_data_index,
|
get_hf_dataset_size_in_mb,
|
||||||
get_hf_features_from_features,
|
get_hf_features_from_features,
|
||||||
|
get_parquet_file_size_in_mb,
|
||||||
|
get_parquet_num_frames,
|
||||||
get_safe_version,
|
get_safe_version,
|
||||||
|
get_video_duration_in_s,
|
||||||
|
get_video_size_in_mb,
|
||||||
hf_transform_to_torch,
|
hf_transform_to_torch,
|
||||||
is_valid_version,
|
is_valid_version,
|
||||||
load_episodes,
|
load_episodes,
|
||||||
load_episodes_stats,
|
|
||||||
load_info,
|
load_info,
|
||||||
|
load_nested_dataset,
|
||||||
load_stats,
|
load_stats,
|
||||||
load_tasks,
|
load_tasks,
|
||||||
|
to_parquet_with_hf_images,
|
||||||
|
update_chunk_file_indices,
|
||||||
validate_episode_buffer,
|
validate_episode_buffer,
|
||||||
validate_frame,
|
validate_frame,
|
||||||
write_episode,
|
|
||||||
write_episode_stats,
|
|
||||||
write_info,
|
write_info,
|
||||||
write_json,
|
write_json,
|
||||||
|
write_stats,
|
||||||
|
write_tasks,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.video_utils import (
|
from lerobot.common.datasets.video_utils import (
|
||||||
VideoFrame,
|
VideoFrame,
|
||||||
@@ -73,7 +80,7 @@ from lerobot.common.datasets.video_utils import (
|
|||||||
get_video_info,
|
get_video_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
CODEBASE_VERSION = "v2.1"
|
CODEBASE_VERSION = "v3.0"
|
||||||
|
|
||||||
|
|
||||||
class LeRobotDatasetMetadata:
|
class LeRobotDatasetMetadata:
|
||||||
@@ -97,20 +104,18 @@ class LeRobotDatasetMetadata:
|
|||||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||||
|
|
||||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||||
|
# TODO(rcadene): instead of downloading all episodes metadata files,
|
||||||
|
# download only the ones associated to the requested episodes. This would
|
||||||
|
# require adding `episodes: list[int]` as argument.
|
||||||
self.pull_from_repo(allow_patterns="meta/")
|
self.pull_from_repo(allow_patterns="meta/")
|
||||||
self.load_metadata()
|
self.load_metadata()
|
||||||
|
|
||||||
def load_metadata(self):
|
def load_metadata(self):
|
||||||
self.info = load_info(self.root)
|
self.info = load_info(self.root)
|
||||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||||
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
self.tasks = load_tasks(self.root)
|
||||||
self.episodes = load_episodes(self.root)
|
self.episodes = load_episodes(self.root)
|
||||||
if self._version < packaging.version.parse("v2.1"):
|
self.stats = load_stats(self.root)
|
||||||
self.stats = load_stats(self.root)
|
|
||||||
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
|
||||||
else:
|
|
||||||
self.episodes_stats = load_episodes_stats(self.root)
|
|
||||||
self.stats = aggregate_stats(list(self.episodes_stats.values()))
|
|
||||||
|
|
||||||
def pull_from_repo(
|
def pull_from_repo(
|
||||||
self,
|
self,
|
||||||
@@ -132,18 +137,19 @@ class LeRobotDatasetMetadata:
|
|||||||
return packaging.version.parse(self.info["codebase_version"])
|
return packaging.version.parse(self.info["codebase_version"])
|
||||||
|
|
||||||
def get_data_file_path(self, ep_index: int) -> Path:
|
def get_data_file_path(self, ep_index: int) -> Path:
|
||||||
ep_chunk = self.get_episode_chunk(ep_index)
|
ep = self.episodes[ep_index]
|
||||||
fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
|
chunk_idx = ep["data/chunk_index"]
|
||||||
|
file_idx = ep["data/file_index"]
|
||||||
|
fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||||
return Path(fpath)
|
return Path(fpath)
|
||||||
|
|
||||||
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
||||||
ep_chunk = self.get_episode_chunk(ep_index)
|
ep = self.episodes[ep_index]
|
||||||
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
chunk_idx = ep[f"videos/{vid_key}/chunk_index"]
|
||||||
|
file_idx = ep[f"videos/{vid_key}/file_index"]
|
||||||
|
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
|
||||||
return Path(fpath)
|
return Path(fpath)
|
||||||
|
|
||||||
def get_episode_chunk(self, ep_index: int) -> int:
|
|
||||||
return ep_index // self.chunks_size
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data_path(self) -> str:
|
def data_path(self) -> str:
|
||||||
"""Formattable string for the parquet files."""
|
"""Formattable string for the parquet files."""
|
||||||
@@ -210,39 +216,108 @@ class LeRobotDatasetMetadata:
|
|||||||
return self.info["total_tasks"]
|
return self.info["total_tasks"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def total_chunks(self) -> int:
|
def chunks_size(self) -> int:
|
||||||
"""Total number of chunks (groups of episodes)."""
|
"""Max number of files per chunk."""
|
||||||
return self.info["total_chunks"]
|
return self.info["chunks_size"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def chunks_size(self) -> int:
|
def data_files_size_in_mb(self) -> int:
|
||||||
"""Max number of episodes per chunk."""
|
"""Max size of data file in mega bytes."""
|
||||||
return self.info["chunks_size"]
|
return self.info["data_files_size_in_mb"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def video_files_size_in_mb(self) -> int:
|
||||||
|
"""Max size of video file in mega bytes."""
|
||||||
|
return self.info["video_files_size_in_mb"]
|
||||||
|
|
||||||
def get_task_index(self, task: str) -> int | None:
|
def get_task_index(self, task: str) -> int | None:
|
||||||
"""
|
"""
|
||||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||||
otherwise return None.
|
otherwise return None.
|
||||||
"""
|
"""
|
||||||
return self.task_to_task_index.get(task, None)
|
if task in self.tasks.index:
|
||||||
|
return int(self.tasks.loc[task].task_index)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
def add_task(self, task: str):
|
def save_episode_tasks(self, tasks: list[str]):
|
||||||
|
if len(set(tasks)) != len(tasks):
|
||||||
|
raise ValueError(f"Tasks are not unique: {tasks}")
|
||||||
|
|
||||||
|
if self.tasks is None:
|
||||||
|
new_tasks = tasks
|
||||||
|
task_indices = range(len(tasks))
|
||||||
|
self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks)
|
||||||
|
else:
|
||||||
|
new_tasks = [task for task in tasks if task not in self.tasks.index]
|
||||||
|
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
|
||||||
|
for task_idx, task in zip(new_task_indices, new_tasks, strict=False):
|
||||||
|
self.tasks.loc[task] = task_idx
|
||||||
|
|
||||||
|
if len(new_tasks) > 0:
|
||||||
|
# Update on disk
|
||||||
|
write_tasks(self.tasks, self.root)
|
||||||
|
|
||||||
|
def _save_episode_metadata(self, episode_dict: dict) -> None:
|
||||||
|
"""Save episode metadata to a parquet file and update the Hugging Face dataset of episodes metadata.
|
||||||
|
|
||||||
|
This function processes episodes metadata from a dictionary, converts it into a Hugging Face dataset,
|
||||||
|
and saves it as a parquet file. It handles both the creation of new parquet files and the
|
||||||
|
updating of existing ones based on size constraints. After saving the metadata, it reloads
|
||||||
|
the Hugging Face dataset to ensure it is up-to-date.
|
||||||
|
|
||||||
|
Notes: We both need to update parquet files and HF dataset:
|
||||||
|
- `pandas` loads parquet file in RAM
|
||||||
|
- `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
|
||||||
|
or loads directly from pyarrow cache.
|
||||||
"""
|
"""
|
||||||
Given a task in natural language, add it to the dictionary of tasks.
|
# Convert buffer into HF Dataset
|
||||||
"""
|
episode_dict = {key: [value] for key, value in episode_dict.items()}
|
||||||
if task in self.task_to_task_index:
|
ep_dataset = Dataset.from_dict(episode_dict)
|
||||||
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
|
ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
|
||||||
|
df = pd.DataFrame(ep_dataset)
|
||||||
|
num_frames = episode_dict["length"][0]
|
||||||
|
|
||||||
task_index = self.info["total_tasks"]
|
if self.episodes is None:
|
||||||
self.task_to_task_index[task] = task_index
|
# Initialize indices and frame count for a new dataset made of the first episode data
|
||||||
self.tasks[task_index] = task
|
chunk_idx, file_idx = 0, 0
|
||||||
self.info["total_tasks"] += 1
|
df["meta/episodes/chunk_index"] = [chunk_idx]
|
||||||
|
df["meta/episodes/file_index"] = [file_idx]
|
||||||
|
df["dataset_from_index"] = [0]
|
||||||
|
df["dataset_to_index"] = [num_frames]
|
||||||
|
else:
|
||||||
|
# Retrieve information from the latest parquet file
|
||||||
|
latest_ep = self.episodes[-1]
|
||||||
|
chunk_idx = latest_ep["meta/episodes/chunk_index"]
|
||||||
|
file_idx = latest_ep["meta/episodes/file_index"]
|
||||||
|
|
||||||
task_dict = {
|
latest_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||||
"task_index": task_index,
|
latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
|
||||||
"task": task,
|
|
||||||
}
|
if latest_size_in_mb + ep_size_in_mb >= self.data_files_size_in_mb:
|
||||||
append_jsonlines(task_dict, self.root / TASKS_PATH)
|
# Size limit is reached, prepare new parquet file
|
||||||
|
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size)
|
||||||
|
|
||||||
|
# Update the existing pandas dataframe with new row
|
||||||
|
df["meta/episodes/chunk_index"] = [chunk_idx]
|
||||||
|
df["meta/episodes/file_index"] = [file_idx]
|
||||||
|
df["dataset_from_index"] = [latest_ep["dataset_to_index"]]
|
||||||
|
df["dataset_to_index"] = [latest_ep["dataset_to_index"] + num_frames]
|
||||||
|
|
||||||
|
if latest_size_in_mb + ep_size_in_mb < self.data_files_size_in_mb:
|
||||||
|
# Size limit wasnt reached, concatenate latest dataframe with new one
|
||||||
|
latest_df = pd.read_parquet(latest_path)
|
||||||
|
df = pd.concat([latest_df, df], ignore_index=True)
|
||||||
|
|
||||||
|
# Write the resulting dataframe from RAM to disk
|
||||||
|
path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
df.to_parquet(path, index=False)
|
||||||
|
|
||||||
|
# Update the Hugging Face dataset by reloading it.
|
||||||
|
# This process should be fast because only the latest Parquet file has been modified.
|
||||||
|
# Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache.
|
||||||
|
self.episodes = load_episodes(self.root)
|
||||||
|
|
||||||
def save_episode(
|
def save_episode(
|
||||||
self,
|
self,
|
||||||
@@ -250,32 +325,28 @@ class LeRobotDatasetMetadata:
|
|||||||
episode_length: int,
|
episode_length: int,
|
||||||
episode_tasks: list[str],
|
episode_tasks: list[str],
|
||||||
episode_stats: dict[str, dict],
|
episode_stats: dict[str, dict],
|
||||||
|
episode_metadata: dict,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.info["total_episodes"] += 1
|
|
||||||
self.info["total_frames"] += episode_length
|
|
||||||
|
|
||||||
chunk = self.get_episode_chunk(episode_index)
|
|
||||||
if chunk >= self.total_chunks:
|
|
||||||
self.info["total_chunks"] += 1
|
|
||||||
|
|
||||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
|
||||||
self.info["total_videos"] += len(self.video_keys)
|
|
||||||
if len(self.video_keys) > 0:
|
|
||||||
self.update_video_info()
|
|
||||||
|
|
||||||
write_info(self.info, self.root)
|
|
||||||
|
|
||||||
episode_dict = {
|
episode_dict = {
|
||||||
"episode_index": episode_index,
|
"episode_index": episode_index,
|
||||||
"tasks": episode_tasks,
|
"tasks": episode_tasks,
|
||||||
"length": episode_length,
|
"length": episode_length,
|
||||||
}
|
}
|
||||||
self.episodes[episode_index] = episode_dict
|
episode_dict.update(episode_metadata)
|
||||||
write_episode(episode_dict, self.root)
|
episode_dict.update(flatten_dict({"stats": episode_stats}))
|
||||||
|
self._save_episode_metadata(episode_dict)
|
||||||
|
|
||||||
self.episodes_stats[episode_index] = episode_stats
|
# Update info
|
||||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
|
self.info["total_episodes"] += 1
|
||||||
write_episode_stats(episode_index, episode_stats, self.root)
|
self.info["total_frames"] += episode_length
|
||||||
|
self.info["total_tasks"] = len(self.tasks)
|
||||||
|
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||||
|
if len(self.video_keys) > 0:
|
||||||
|
self.update_video_info()
|
||||||
|
write_info(self.info, self.root)
|
||||||
|
|
||||||
|
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
|
||||||
|
write_stats(self.stats, self.root)
|
||||||
|
|
||||||
def update_video_info(self) -> None:
|
def update_video_info(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -315,12 +386,12 @@ class LeRobotDatasetMetadata:
|
|||||||
|
|
||||||
obj.root.mkdir(parents=True, exist_ok=False)
|
obj.root.mkdir(parents=True, exist_ok=False)
|
||||||
|
|
||||||
# TODO(aliberts, rcadene): implement sanity check for features
|
|
||||||
features = {**features, **DEFAULT_FEATURES}
|
features = {**features, **DEFAULT_FEATURES}
|
||||||
_validate_feature_names(features)
|
_validate_feature_names(features)
|
||||||
|
|
||||||
obj.tasks, obj.task_to_task_index = {}, {}
|
obj.tasks = None
|
||||||
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
|
obj.episodes = None
|
||||||
|
obj.stats = None
|
||||||
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type)
|
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type)
|
||||||
if len(obj.video_keys) > 0 and not use_videos:
|
if len(obj.video_keys) > 0 and not use_videos:
|
||||||
raise ValueError()
|
raise ValueError()
|
||||||
@@ -465,29 +536,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.meta = LeRobotDatasetMetadata(
|
self.meta = LeRobotDatasetMetadata(
|
||||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||||
)
|
)
|
||||||
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
|
|
||||||
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
|
||||||
self.stats = aggregate_stats(episodes_stats)
|
|
||||||
|
|
||||||
# Load actual data
|
# Load actual data
|
||||||
try:
|
try:
|
||||||
if force_cache_sync:
|
if force_cache_sync:
|
||||||
raise FileNotFoundError
|
raise FileNotFoundError
|
||||||
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
|
|
||||||
self.hf_dataset = self.load_hf_dataset()
|
self.hf_dataset = self.load_hf_dataset()
|
||||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||||
self.download_episodes(download_videos)
|
self.download(download_videos)
|
||||||
self.hf_dataset = self.load_hf_dataset()
|
self.hf_dataset = self.load_hf_dataset()
|
||||||
|
|
||||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
|
||||||
|
|
||||||
# Check timestamps
|
|
||||||
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
|
|
||||||
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
|
|
||||||
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
|
|
||||||
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
|
|
||||||
|
|
||||||
# Setup delta_indices
|
# Setup delta_indices
|
||||||
if self.delta_timestamps is not None:
|
if self.delta_timestamps is not None:
|
||||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||||
@@ -563,7 +622,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
ignore_patterns=ignore_patterns,
|
ignore_patterns=ignore_patterns,
|
||||||
)
|
)
|
||||||
|
|
||||||
def download_episodes(self, download_videos: bool = True) -> None:
|
def download(self, download_videos: bool = True) -> None:
|
||||||
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
|
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
|
||||||
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
|
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
|
||||||
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
|
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
|
||||||
@@ -571,11 +630,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
"""
|
"""
|
||||||
# TODO(rcadene, aliberts): implement faster transfer
|
# TODO(rcadene, aliberts): implement faster transfer
|
||||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||||
files = None
|
|
||||||
ignore_patterns = None if download_videos else "videos/"
|
ignore_patterns = None if download_videos else "videos/"
|
||||||
|
files = None
|
||||||
if self.episodes is not None:
|
if self.episodes is not None:
|
||||||
files = self.get_episodes_file_paths()
|
files = self.get_episodes_file_paths()
|
||||||
|
|
||||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||||
|
|
||||||
def get_episodes_file_paths(self) -> list[Path]:
|
def get_episodes_file_paths(self) -> list[Path]:
|
||||||
@@ -588,19 +646,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
for ep_idx in episodes
|
for ep_idx in episodes
|
||||||
]
|
]
|
||||||
fpaths += video_files
|
fpaths += video_files
|
||||||
|
# episodes are stored in the same files, so we return unique paths only
|
||||||
|
fpaths = list(set(fpaths))
|
||||||
return fpaths
|
return fpaths
|
||||||
|
|
||||||
def load_hf_dataset(self) -> datasets.Dataset:
|
def load_hf_dataset(self) -> datasets.Dataset:
|
||||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||||
if self.episodes is None:
|
hf_dataset = load_nested_dataset(self.root / "data")
|
||||||
path = str(self.root / "data")
|
|
||||||
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
|
|
||||||
else:
|
|
||||||
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
|
||||||
hf_dataset = load_dataset("parquet", data_files=files, split="train")
|
|
||||||
|
|
||||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
|
||||||
hf_dataset.set_transform(hf_transform_to_torch)
|
hf_dataset.set_transform(hf_transform_to_torch)
|
||||||
return hf_dataset
|
return hf_dataset
|
||||||
|
|
||||||
@@ -608,8 +660,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
features = get_hf_features_from_features(self.features)
|
features = get_hf_features_from_features(self.features)
|
||||||
ft_dict = {col: [] for col in features}
|
ft_dict = {col: [] for col in features}
|
||||||
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
|
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
|
||||||
|
|
||||||
# TODO(aliberts): hf_dataset.set_format("torch")
|
|
||||||
hf_dataset.set_transform(hf_transform_to_torch)
|
hf_dataset.set_transform(hf_transform_to_torch)
|
||||||
return hf_dataset
|
return hf_dataset
|
||||||
|
|
||||||
@@ -641,15 +691,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
return get_hf_features_from_features(self.features)
|
return get_hf_features_from_features(self.features)
|
||||||
|
|
||||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||||
ep_start = self.episode_data_index["from"][ep_idx]
|
ep = self.meta.episodes[ep_idx]
|
||||||
ep_end = self.episode_data_index["to"][ep_idx]
|
ep_start = ep["dataset_from_index"]
|
||||||
|
ep_end = ep["dataset_to_index"]
|
||||||
query_indices = {
|
query_indices = {
|
||||||
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
|
||||||
for key, delta_idx in self.delta_indices.items()
|
for key, delta_idx in self.delta_indices.items()
|
||||||
}
|
}
|
||||||
padding = { # Pad values outside of current episode range
|
padding = { # Pad values outside of current episode range
|
||||||
f"{key}_is_pad": torch.BoolTensor(
|
f"{key}_is_pad": torch.BoolTensor(
|
||||||
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
|
[(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx]
|
||||||
)
|
)
|
||||||
for key, delta_idx in self.delta_indices.items()
|
for key, delta_idx in self.delta_indices.items()
|
||||||
}
|
}
|
||||||
@@ -663,7 +714,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
query_timestamps = {}
|
query_timestamps = {}
|
||||||
for key in self.meta.video_keys:
|
for key in self.meta.video_keys:
|
||||||
if query_indices is not None and key in query_indices:
|
if query_indices is not None and key in query_indices:
|
||||||
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
|
||||||
query_timestamps[key] = torch.stack(timestamps).tolist()
|
query_timestamps[key] = torch.stack(timestamps).tolist()
|
||||||
else:
|
else:
|
||||||
query_timestamps[key] = [current_ts]
|
query_timestamps[key] = [current_ts]
|
||||||
@@ -672,7 +723,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||||
return {
|
return {
|
||||||
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
key: torch.stack(self.hf_dataset[q_idx][key])
|
||||||
for key, q_idx in query_indices.items()
|
for key, q_idx in query_indices.items()
|
||||||
if key not in self.meta.video_keys
|
if key not in self.meta.video_keys
|
||||||
}
|
}
|
||||||
@@ -683,10 +734,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
|
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
|
||||||
the main process and a subprocess fails to access it.
|
the main process and a subprocess fails to access it.
|
||||||
"""
|
"""
|
||||||
|
ep = self.meta.episodes[ep_idx]
|
||||||
item = {}
|
item = {}
|
||||||
for vid_key, query_ts in query_timestamps.items():
|
for vid_key, query_ts in query_timestamps.items():
|
||||||
|
# Episodes are stored sequentially on a single mp4 to reduce the number of files.
|
||||||
|
# Thus we load the start timestamp of the episode on this mp4 and,
|
||||||
|
# shift the query timestamp accordingly.
|
||||||
|
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
|
||||||
|
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||||
|
|
||||||
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
||||||
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
|
frames = decode_video_frames(video_path, shifted_query_ts, self.tolerance_s, self.video_backend)
|
||||||
item[vid_key] = frames.squeeze(0)
|
item[vid_key] = frames.squeeze(0)
|
||||||
|
|
||||||
return item
|
return item
|
||||||
@@ -724,8 +782,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
# Add task as a string
|
# Add task as a string
|
||||||
task_idx = item["task_index"].item()
|
task_idx = item["task_index"].item()
|
||||||
item["task"] = self.meta.tasks[task_idx]
|
item["task"] = self.meta.tasks.iloc[task_idx].name
|
||||||
|
|
||||||
return item
|
return item
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@@ -755,6 +812,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
)
|
)
|
||||||
return self.root / fpath
|
return self.root / fpath
|
||||||
|
|
||||||
|
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
|
||||||
|
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
|
||||||
|
|
||||||
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
|
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
|
||||||
if self.image_writer is None:
|
if self.image_writer is None:
|
||||||
if isinstance(image, torch.Tensor):
|
if isinstance(image, torch.Tensor):
|
||||||
@@ -763,7 +823,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
else:
|
else:
|
||||||
self.image_writer.save_image(image=image, fpath=fpath)
|
self.image_writer.save_image(image=image, fpath=fpath)
|
||||||
|
|
||||||
def add_frame(self, frame: dict, task: str, timestamp: float | None = None) -> None:
|
def add_frame(self, frame: dict) -> None:
|
||||||
"""
|
"""
|
||||||
This function only adds the frame to the episode_buffer. Apart from images — which are written in a
|
This function only adds the frame to the episode_buffer. Apart from images — which are written in a
|
||||||
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
|
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
|
||||||
@@ -781,11 +841,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
# Automatically add frame_index and timestamp to episode buffer
|
# Automatically add frame_index and timestamp to episode buffer
|
||||||
frame_index = self.episode_buffer["size"]
|
frame_index = self.episode_buffer["size"]
|
||||||
if timestamp is None:
|
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
|
||||||
timestamp = frame_index / self.fps
|
|
||||||
self.episode_buffer["frame_index"].append(frame_index)
|
self.episode_buffer["frame_index"].append(frame_index)
|
||||||
self.episode_buffer["timestamp"].append(timestamp)
|
self.episode_buffer["timestamp"].append(timestamp)
|
||||||
self.episode_buffer["task"].append(task)
|
self.episode_buffer["task"].append(frame.pop("task")) # Remove task from frame after processing
|
||||||
|
|
||||||
# Add frame features to episode_buffer
|
# Add frame features to episode_buffer
|
||||||
for key in frame:
|
for key in frame:
|
||||||
@@ -830,11 +889,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
|
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
|
||||||
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
|
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
|
||||||
|
|
||||||
# Add new tasks to the tasks dictionary
|
# Update tasks and task indices with new tasks if any
|
||||||
for task in episode_tasks:
|
self.meta.save_episode_tasks(episode_tasks)
|
||||||
task_index = self.meta.get_task_index(task)
|
|
||||||
if task_index is None:
|
|
||||||
self.meta.add_task(task)
|
|
||||||
|
|
||||||
# Given tasks in natural language, find their corresponding task indices
|
# Given tasks in natural language, find their corresponding task indices
|
||||||
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
|
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
|
||||||
@@ -846,51 +902,154 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
continue
|
continue
|
||||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||||
|
|
||||||
|
# Wait for image writer to end, so that episode stats over images can be computed
|
||||||
self._wait_image_writer()
|
self._wait_image_writer()
|
||||||
self._save_episode_table(episode_buffer, episode_index)
|
|
||||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||||
|
|
||||||
if len(self.meta.video_keys) > 0:
|
ep_metadata = self._save_episode_data(episode_buffer)
|
||||||
video_paths = self.encode_episode_videos(episode_index)
|
for video_key in self.meta.video_keys:
|
||||||
for key in self.meta.video_keys:
|
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||||
episode_buffer[key] = video_paths[key]
|
|
||||||
|
|
||||||
# `meta.save_episode` be executed after encoding the videos
|
# `meta.save_episode` need to be executed after encoding the videos
|
||||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
|
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
||||||
|
|
||||||
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
|
# TODO(rcadene): remove? there is only one episode in the episode buffer, no need for ep_data_index
|
||||||
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
|
# ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
|
||||||
check_timestamps_sync(
|
# ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
|
||||||
episode_buffer["timestamp"],
|
# check_timestamps_sync(
|
||||||
episode_buffer["episode_index"],
|
# episode_buffer["timestamp"],
|
||||||
ep_data_index_np,
|
# episode_buffer["episode_index"],
|
||||||
self.fps,
|
# ep_data_index_np,
|
||||||
self.tolerance_s,
|
# self.fps,
|
||||||
)
|
# self.tolerance_s,
|
||||||
|
# )
|
||||||
video_files = list(self.root.rglob("*.mp4"))
|
|
||||||
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
|
|
||||||
|
|
||||||
parquet_files = list(self.root.rglob("*.parquet"))
|
|
||||||
assert len(parquet_files) == self.num_episodes
|
|
||||||
|
|
||||||
|
# TODO(rcadene): images are also deleted in clear_episode_buffer
|
||||||
# delete images
|
# delete images
|
||||||
img_dir = self.root / "images"
|
img_dir = self.root / "images"
|
||||||
if img_dir.is_dir():
|
if img_dir.is_dir():
|
||||||
shutil.rmtree(self.root / "images")
|
shutil.rmtree(self.root / "images")
|
||||||
|
|
||||||
if not episode_data: # Reset the buffer
|
if not episode_data:
|
||||||
|
# Reset episode buffer
|
||||||
self.episode_buffer = self.create_episode_buffer()
|
self.episode_buffer = self.create_episode_buffer()
|
||||||
|
|
||||||
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
|
def _save_episode_data(self, episode_buffer: dict) -> dict:
|
||||||
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
|
"""Save episode data to a parquet file and update the Hugging Face dataset of frames data.
|
||||||
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
|
|
||||||
|
This function processes episodes data from a buffer, converts it into a Hugging Face dataset,
|
||||||
|
and saves it as a parquet file. It handles both the creation of new parquet files and the
|
||||||
|
updating of existing ones based on size constraints. After saving the data, it reloads
|
||||||
|
the Hugging Face dataset to ensure it is up-to-date.
|
||||||
|
|
||||||
|
Notes: We both need to update parquet files and HF dataset:
|
||||||
|
- `pandas` loads parquet file in RAM
|
||||||
|
- `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk,
|
||||||
|
or loads directly from pyarrow cache.
|
||||||
|
"""
|
||||||
|
# Convert buffer into HF Dataset
|
||||||
|
ep_dict = {key: episode_buffer[key] for key in self.hf_features}
|
||||||
|
ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train")
|
||||||
ep_dataset = embed_images(ep_dataset)
|
ep_dataset = embed_images(ep_dataset)
|
||||||
self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
|
ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset)
|
||||||
self.hf_dataset.set_transform(hf_transform_to_torch)
|
ep_num_frames = len(ep_dataset)
|
||||||
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
|
df = pd.DataFrame(ep_dataset)
|
||||||
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
ep_dataset.to_parquet(ep_data_path)
|
if self.meta.episodes is None:
|
||||||
|
# Initialize indices and frame count for a new dataset made of the first episode data
|
||||||
|
chunk_idx, file_idx = 0, 0
|
||||||
|
latest_num_frames = 0
|
||||||
|
else:
|
||||||
|
# Retrieve information from the latest parquet file
|
||||||
|
latest_ep = self.meta.episodes[-1]
|
||||||
|
chunk_idx = latest_ep["data/chunk_index"]
|
||||||
|
file_idx = latest_ep["data/file_index"]
|
||||||
|
|
||||||
|
latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||||
|
latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
|
||||||
|
latest_num_frames = get_parquet_num_frames(latest_path)
|
||||||
|
|
||||||
|
# Determine if a new parquet file is needed
|
||||||
|
if latest_size_in_mb + ep_size_in_mb >= self.meta.data_files_size_in_mb:
|
||||||
|
# Size limit is reached, prepare new parquet file
|
||||||
|
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
||||||
|
latest_num_frames = 0
|
||||||
|
else:
|
||||||
|
# Update the existing parquet file with new rows
|
||||||
|
latest_df = pd.read_parquet(latest_path)
|
||||||
|
df = pd.concat([latest_df, df], ignore_index=True)
|
||||||
|
|
||||||
|
# Write the resulting dataframe from RAM to disk
|
||||||
|
path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
if len(self.meta.image_keys) > 0:
|
||||||
|
to_parquet_with_hf_images(df, path)
|
||||||
|
else:
|
||||||
|
df.to_parquet(path)
|
||||||
|
|
||||||
|
# Update the Hugging Face dataset by reloading it.
|
||||||
|
# This process should be fast because only the latest Parquet file has been modified.
|
||||||
|
# Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache.
|
||||||
|
self.hf_dataset = self.load_hf_dataset()
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"data/chunk_index": chunk_idx,
|
||||||
|
"data/file_index": file_idx,
|
||||||
|
"dataset_from_index": latest_num_frames,
|
||||||
|
"dataset_to_index": latest_num_frames + ep_num_frames,
|
||||||
|
}
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
def _save_episode_video(self, video_key: str, episode_index: int):
|
||||||
|
# Encode episode frames into a temporary video
|
||||||
|
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
|
||||||
|
ep_size_in_mb = get_video_size_in_mb(ep_path)
|
||||||
|
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||||
|
|
||||||
|
if self.meta.episodes is None:
|
||||||
|
# Initialize indices for a new dataset made of the first episode data
|
||||||
|
chunk_idx, file_idx = 0, 0
|
||||||
|
latest_duration_in_s = 0
|
||||||
|
new_path = self.root / self.meta.video_path.format(
|
||||||
|
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
|
||||||
|
)
|
||||||
|
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
shutil.move(str(ep_path), str(new_path))
|
||||||
|
else:
|
||||||
|
# Retrieve information from the latest video file
|
||||||
|
latest_ep = self.meta.episodes[-1]
|
||||||
|
chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"]
|
||||||
|
file_idx = latest_ep[f"videos/{video_key}/file_index"]
|
||||||
|
|
||||||
|
latest_path = self.root / self.meta.video_path.format(
|
||||||
|
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
|
||||||
|
)
|
||||||
|
latest_size_in_mb = get_video_size_in_mb(latest_path)
|
||||||
|
latest_duration_in_s = get_video_duration_in_s(latest_path)
|
||||||
|
|
||||||
|
if latest_size_in_mb + ep_size_in_mb >= self.meta.video_files_size_in_mb:
|
||||||
|
# Move temporary episode video to a new video file in the dataset
|
||||||
|
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
||||||
|
new_path = self.root / self.meta.video_path.format(
|
||||||
|
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
|
||||||
|
)
|
||||||
|
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
shutil.move(str(ep_path), str(new_path))
|
||||||
|
else:
|
||||||
|
# Update latest video file
|
||||||
|
concat_video_files([latest_path, ep_path], self.root, video_key, chunk_idx, file_idx)
|
||||||
|
|
||||||
|
# Remove temporary directory
|
||||||
|
shutil.rmtree(str(ep_path.parent))
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"episode_index": episode_index,
|
||||||
|
f"videos/{video_key}/chunk_index": chunk_idx,
|
||||||
|
f"videos/{video_key}/file_index": file_idx,
|
||||||
|
f"videos/{video_key}/from_timestamp": latest_duration_in_s,
|
||||||
|
f"videos/{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
|
||||||
|
}
|
||||||
|
return metadata
|
||||||
|
|
||||||
def clear_episode_buffer(self) -> None:
|
def clear_episode_buffer(self) -> None:
|
||||||
episode_index = self.episode_buffer["episode_index"]
|
episode_index = self.episode_buffer["episode_index"]
|
||||||
@@ -919,7 +1078,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
def stop_image_writer(self) -> None:
|
def stop_image_writer(self) -> None:
|
||||||
"""
|
"""
|
||||||
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
|
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
|
||||||
remove the image_writer in order for the LeRobotDataset object to be picklable and parallelized.
|
remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
|
||||||
"""
|
"""
|
||||||
if self.image_writer is not None:
|
if self.image_writer is not None:
|
||||||
self.image_writer.stop()
|
self.image_writer.stop()
|
||||||
@@ -930,34 +1089,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
if self.image_writer is not None:
|
if self.image_writer is not None:
|
||||||
self.image_writer.wait_until_done()
|
self.image_writer.wait_until_done()
|
||||||
|
|
||||||
def encode_videos(self) -> None:
|
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> dict:
|
||||||
"""
|
"""
|
||||||
Use ffmpeg to convert frames stored as png into mp4 videos.
|
Use ffmpeg to convert frames stored as png into mp4 videos.
|
||||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||||
since video encoding with ffmpeg is already using multithreading.
|
since video encoding with ffmpeg is already using multithreading.
|
||||||
"""
|
"""
|
||||||
for ep_idx in range(self.meta.total_episodes):
|
temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||||
self.encode_episode_videos(ep_idx)
|
img_dir = self._get_image_file_dir(episode_index, video_key)
|
||||||
|
encode_video_frames(img_dir, temp_path, self.fps, overwrite=True)
|
||||||
def encode_episode_videos(self, episode_index: int) -> dict:
|
return temp_path
|
||||||
"""
|
|
||||||
Use ffmpeg to convert frames stored as png into mp4 videos.
|
|
||||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
|
||||||
since video encoding with ffmpeg is already using multithreading.
|
|
||||||
"""
|
|
||||||
video_paths = {}
|
|
||||||
for key in self.meta.video_keys:
|
|
||||||
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
|
|
||||||
video_paths[key] = str(video_path)
|
|
||||||
if video_path.is_file():
|
|
||||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
|
||||||
continue
|
|
||||||
img_dir = self._get_image_file_path(
|
|
||||||
episode_index=episode_index, image_key=key, frame_index=0
|
|
||||||
).parent
|
|
||||||
encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
|
|
||||||
|
|
||||||
return video_paths
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
@@ -1000,7 +1141,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
obj.image_transforms = None
|
obj.image_transforms = None
|
||||||
obj.delta_timestamps = None
|
obj.delta_timestamps = None
|
||||||
obj.delta_indices = None
|
obj.delta_indices = None
|
||||||
obj.episode_data_index = None
|
|
||||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|||||||
@@ -337,13 +337,11 @@ def compute_sampler_weights(
|
|||||||
if len(offline_dataset) > 0:
|
if len(offline_dataset) > 0:
|
||||||
offline_data_mask_indices = []
|
offline_data_mask_indices = []
|
||||||
for start_index, end_index in zip(
|
for start_index, end_index in zip(
|
||||||
offline_dataset.episode_data_index["from"],
|
offline_dataset.meta.episodes["dataset_from_index"],
|
||||||
offline_dataset.episode_data_index["to"],
|
offline_dataset.meta.episodes["dataset_to_index"],
|
||||||
strict=True,
|
strict=True,
|
||||||
):
|
):
|
||||||
offline_data_mask_indices.extend(
|
offline_data_mask_indices.extend(range(start_index, end_index - offline_drop_n_last_frames))
|
||||||
range(start_index.item(), end_index.item() - offline_drop_n_last_frames)
|
|
||||||
)
|
|
||||||
offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
|
offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool)
|
||||||
offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
|
offline_data_mask[torch.tensor(offline_data_mask_indices)] = True
|
||||||
weights.append(
|
weights.append(
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ import torch
|
|||||||
class EpisodeAwareSampler:
|
class EpisodeAwareSampler:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
episode_data_index: dict,
|
dataset_from_indices: list[int],
|
||||||
|
dataset_to_indices: list[int],
|
||||||
episode_indices_to_use: Union[list, None] = None,
|
episode_indices_to_use: Union[list, None] = None,
|
||||||
drop_n_first_frames: int = 0,
|
drop_n_first_frames: int = 0,
|
||||||
drop_n_last_frames: int = 0,
|
drop_n_last_frames: int = 0,
|
||||||
@@ -30,7 +31,8 @@ class EpisodeAwareSampler:
|
|||||||
"""Sampler that optionally incorporates episode boundary information.
|
"""Sampler that optionally incorporates episode boundary information.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
|
dataset_from_indices: List of indices containing the start of each episode in the dataset.
|
||||||
|
dataset_to_indices: List of indices containing the end of each episode in the dataset.
|
||||||
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
|
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
|
||||||
Assumes that episodes are indexed from 0 to N-1.
|
Assumes that episodes are indexed from 0 to N-1.
|
||||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||||
@@ -39,12 +41,10 @@ class EpisodeAwareSampler:
|
|||||||
"""
|
"""
|
||||||
indices = []
|
indices = []
|
||||||
for episode_idx, (start_index, end_index) in enumerate(
|
for episode_idx, (start_index, end_index) in enumerate(
|
||||||
zip(episode_data_index["from"], episode_data_index["to"], strict=True)
|
zip(dataset_from_indices, dataset_to_indices, strict=True)
|
||||||
):
|
):
|
||||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||||
indices.extend(
|
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
|
||||||
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.indices = indices
|
self.indices = indices
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
|
|||||||
@@ -17,18 +17,23 @@ import contextlib
|
|||||||
import importlib.resources
|
import importlib.resources
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from itertools import accumulate
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import jsonlines
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import packaging.version
|
import packaging.version
|
||||||
|
import pandas
|
||||||
|
import pandas as pd
|
||||||
|
import pyarrow.parquet as pq
|
||||||
import torch
|
import torch
|
||||||
|
from datasets import Dataset, concatenate_datasets
|
||||||
from datasets.table import embed_table_storage
|
from datasets.table import embed_table_storage
|
||||||
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
||||||
from huggingface_hub.errors import RevisionNotFoundError
|
from huggingface_hub.errors import RevisionNotFoundError
|
||||||
@@ -42,19 +47,25 @@ from lerobot.common.datasets.backward_compatibility import (
|
|||||||
)
|
)
|
||||||
from lerobot.common.robots import Robot
|
from lerobot.common.robots import Robot
|
||||||
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
|
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
|
||||||
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
|
|
||||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
|
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||||
|
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||||
|
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 500 # Max size per file
|
||||||
|
|
||||||
INFO_PATH = "meta/info.json"
|
INFO_PATH = "meta/info.json"
|
||||||
EPISODES_PATH = "meta/episodes.jsonl"
|
|
||||||
STATS_PATH = "meta/stats.json"
|
STATS_PATH = "meta/stats.json"
|
||||||
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
|
||||||
TASKS_PATH = "meta/tasks.jsonl"
|
|
||||||
|
|
||||||
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
EPISODES_DIR = "meta/episodes"
|
||||||
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
DATA_DIR = "data"
|
||||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
VIDEO_DIR = "videos"
|
||||||
|
|
||||||
|
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||||
|
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||||
|
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||||
|
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||||
|
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||||
|
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||||
|
|
||||||
DATASET_CARD_TEMPLATE = """
|
DATASET_CARD_TEMPLATE = """
|
||||||
---
|
---
|
||||||
@@ -75,6 +86,115 @@ DEFAULT_FEATURES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_parquet_file_size_in_mb(parquet_path):
|
||||||
|
metadata = pq.read_metadata(parquet_path)
|
||||||
|
total_uncompressed_size = 0
|
||||||
|
for row_group in range(metadata.num_row_groups):
|
||||||
|
rg_metadata = metadata.row_group(row_group)
|
||||||
|
for column in range(rg_metadata.num_columns):
|
||||||
|
col_metadata = rg_metadata.column(column)
|
||||||
|
total_uncompressed_size += col_metadata.total_uncompressed_size
|
||||||
|
return total_uncompressed_size / (1024**2)
|
||||||
|
|
||||||
|
|
||||||
|
def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int:
|
||||||
|
return hf_ds.data.nbytes / (1024**2)
|
||||||
|
|
||||||
|
|
||||||
|
def get_pd_dataframe_size_in_mb(df: pandas.DataFrame) -> int:
|
||||||
|
# TODO(rcadene): unused?
|
||||||
|
memory_usage_bytes = df.memory_usage(deep=True).sum()
|
||||||
|
return memory_usage_bytes / (1024**2)
|
||||||
|
|
||||||
|
|
||||||
|
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int):
|
||||||
|
if file_idx == chunks_size - 1:
|
||||||
|
file_idx = 0
|
||||||
|
chunk_idx += 1
|
||||||
|
else:
|
||||||
|
file_idx += 1
|
||||||
|
return chunk_idx, file_idx
|
||||||
|
|
||||||
|
|
||||||
|
def load_nested_dataset(pq_dir: Path) -> Dataset:
|
||||||
|
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
|
||||||
|
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
|
||||||
|
Concatenate all pyarrow references to return HF Dataset format
|
||||||
|
"""
|
||||||
|
paths = sorted(pq_dir.glob("*/*.parquet"))
|
||||||
|
if len(paths) == 0:
|
||||||
|
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||||
|
|
||||||
|
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
|
||||||
|
datasets = [Dataset.from_parquet(str(path)) for path in paths]
|
||||||
|
return concatenate_datasets(datasets)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parquet_num_frames(parquet_path):
|
||||||
|
metadata = pq.read_metadata(parquet_path)
|
||||||
|
return metadata.num_rows
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_size_in_mb(mp4_path: Path):
|
||||||
|
file_size_bytes = mp4_path.stat().st_size
|
||||||
|
file_size_mb = file_size_bytes / (1024**2)
|
||||||
|
return file_size_mb
|
||||||
|
|
||||||
|
|
||||||
|
def concat_video_files(paths_to_cat: list[Path], root: Path, video_key: str, chunk_idx: int, file_idx: int):
|
||||||
|
# TODO(rcadene): move to video_utils.py
|
||||||
|
# TODO(rcadene): add docstring
|
||||||
|
tmp_dir = Path(tempfile.mkdtemp(dir=root))
|
||||||
|
# Create a text file with the list of files to concatenate
|
||||||
|
path_concat_video_files = tmp_dir / "concat_video_files.txt"
|
||||||
|
with open(path_concat_video_files, "w") as f:
|
||||||
|
for ep_path in paths_to_cat:
|
||||||
|
f.write(f"file '{str(ep_path)}'\n")
|
||||||
|
|
||||||
|
path_tmp_output = tmp_dir / "tmp_output.mp4"
|
||||||
|
command = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-y",
|
||||||
|
"-f",
|
||||||
|
"concat",
|
||||||
|
"-safe",
|
||||||
|
"0",
|
||||||
|
"-i",
|
||||||
|
str(path_concat_video_files),
|
||||||
|
"-c",
|
||||||
|
"copy",
|
||||||
|
str(path_tmp_output),
|
||||||
|
]
|
||||||
|
subprocess.run(command, check=True)
|
||||||
|
|
||||||
|
output_path = root / DEFAULT_VIDEO_PATH.format(
|
||||||
|
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
|
||||||
|
)
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
shutil.move(str(path_tmp_output), str(output_path))
|
||||||
|
shutil.rmtree(str(tmp_dir))
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_duration_in_s(mp4_file: Path):
|
||||||
|
# TODO(rcadene): move to video_utils.py
|
||||||
|
command = [
|
||||||
|
"ffprobe",
|
||||||
|
"-v",
|
||||||
|
"error",
|
||||||
|
"-show_entries",
|
||||||
|
"format=duration",
|
||||||
|
"-of",
|
||||||
|
"default=noprint_wrappers=1:nokey=1",
|
||||||
|
str(mp4_file),
|
||||||
|
]
|
||||||
|
result = subprocess.run(
|
||||||
|
command,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
)
|
||||||
|
return float(result.stdout)
|
||||||
|
|
||||||
|
|
||||||
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
|
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
|
||||||
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
||||||
|
|
||||||
@@ -107,23 +227,13 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
|||||||
return outdict
|
return outdict
|
||||||
|
|
||||||
|
|
||||||
def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
|
|
||||||
split_keys = flattened_key.split(sep)
|
|
||||||
getter = obj[split_keys[0]]
|
|
||||||
if len(split_keys) == 1:
|
|
||||||
return getter
|
|
||||||
|
|
||||||
for key in split_keys[1:]:
|
|
||||||
getter = getter[key]
|
|
||||||
|
|
||||||
return getter
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||||
serialized_dict = {}
|
serialized_dict = {}
|
||||||
for key, value in flatten_dict(stats).items():
|
for key, value in flatten_dict(stats).items():
|
||||||
if isinstance(value, (torch.Tensor, np.ndarray)):
|
if isinstance(value, (torch.Tensor, np.ndarray)):
|
||||||
serialized_dict[key] = value.tolist()
|
serialized_dict[key] = value.tolist()
|
||||||
|
elif isinstance(value, list) and isinstance(value[0], (int, float, list)):
|
||||||
|
serialized_dict[key] = value
|
||||||
elif isinstance(value, np.generic):
|
elif isinstance(value, np.generic):
|
||||||
serialized_dict[key] = value.item()
|
serialized_dict[key] = value.item()
|
||||||
elif isinstance(value, (int, float)):
|
elif isinstance(value, (int, float)):
|
||||||
@@ -153,23 +263,6 @@ def write_json(data: dict, fpath: Path) -> None:
|
|||||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
def load_jsonlines(fpath: Path) -> list[Any]:
    """Read the JSON Lines file at `fpath` and return all of its records as a list."""
    with jsonlines.open(fpath, "r") as jsonl_reader:
        records = list(jsonl_reader)
    return records
|
|
||||||
|
|
||||||
|
|
||||||
def write_jsonlines(data: dict, fpath: Path) -> None:
    """Write the items of `data` to `fpath` in JSON Lines format, overwriting the file.

    The parent directory of `fpath` is created first if it does not exist.
    """
    parent_dir = fpath.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    with jsonlines.open(fpath, "w") as jsonl_writer:
        jsonl_writer.write_all(data)
|
|
||||||
|
|
||||||
|
|
||||||
def append_jsonlines(data: dict, fpath: Path) -> None:
    """Append `data` as a single record to the JSON Lines file at `fpath`.

    The parent directory of `fpath` is created first if it does not exist.
    """
    parent_dir = fpath.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    with jsonlines.open(fpath, "a") as jsonl_writer:
        jsonl_writer.write(data)
|
|
||||||
|
|
||||||
|
|
||||||
def write_info(info: dict, local_dir: Path):
|
def write_info(info: dict, local_dir: Path):
|
||||||
write_json(info, local_dir / INFO_PATH)
|
write_json(info, local_dir / INFO_PATH)
|
||||||
|
|
||||||
@@ -198,43 +291,42 @@ def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
|
|||||||
return cast_stats_to_numpy(stats)
|
return cast_stats_to_numpy(stats)
|
||||||
|
|
||||||
|
|
||||||
def write_task(task_index: int, task: dict, local_dir: Path):
|
def write_hf_dataset(hf_dataset: Dataset, local_dir: Path):
|
||||||
task_dict = {
|
if get_hf_dataset_size_in_mb(hf_dataset) > DEFAULT_DATA_FILE_SIZE_IN_MB:
|
||||||
"task_index": task_index,
|
raise NotImplementedError("Contact a maintainer.")
|
||||||
"task": task,
|
|
||||||
}
|
path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
|
||||||
append_jsonlines(task_dict, local_dir / TASKS_PATH)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
hf_dataset.to_parquet(path)
|
||||||
|
|
||||||
|
|
||||||
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
|
def write_tasks(tasks: pandas.DataFrame, local_dir: Path):
|
||||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
path = local_dir / DEFAULT_TASKS_PATH
|
||||||
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
|
tasks.to_parquet(path)
|
||||||
return tasks, task_to_task_index
|
|
||||||
|
|
||||||
|
|
||||||
def write_episode(episode: dict, local_dir: Path):
|
def load_tasks(local_dir: Path):
|
||||||
append_jsonlines(episode, local_dir / EPISODES_PATH)
|
tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
|
||||||
|
return tasks
|
||||||
|
|
||||||
|
|
||||||
def load_episodes(local_dir: Path) -> dict:
|
def write_episodes(episodes: Dataset, local_dir: Path):
|
||||||
episodes = load_jsonlines(local_dir / EPISODES_PATH)
|
if get_hf_dataset_size_in_mb(episodes) > DEFAULT_DATA_FILE_SIZE_IN_MB:
|
||||||
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
|
raise NotImplementedError("Contact a maintainer.")
|
||||||
|
|
||||||
|
fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
|
||||||
|
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
episodes.to_parquet(fpath)
|
||||||
|
|
||||||
|
|
||||||
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
|
def load_episodes(local_dir: Path) -> datasets.Dataset:
|
||||||
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
|
episodes = load_nested_dataset(local_dir / EPISODES_DIR)
|
||||||
# is a dictionary of stats and not an integer.
|
# Select episode features/columns containing references to episode data and videos
|
||||||
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
|
# (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.)
|
||||||
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
|
# This is to speedup access to these data, instead of having to load episode stats.
|
||||||
|
episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")])
|
||||||
|
return episodes
|
||||||
def load_episodes_stats(local_dir: Path) -> dict:
    """Load per-episode statistics from `local_dir / EPISODES_STATS_PATH`.

    Returns:
        A dict mapping each episode index to its stats converted with
        `cast_stats_to_numpy`, inserted in increasing episode-index order.
    """
    records = load_jsonlines(local_dir / EPISODES_STATS_PATH)
    stats_by_episode = {}
    for record in sorted(records, key=lambda rec: rec["episode_index"]):
        stats_by_episode[record["episode_index"]] = cast_stats_to_numpy(record["stats"])
    return stats_by_episode
|
|
||||||
|
|
||||||
|
|
||||||
def backward_compatible_episodes_stats(
|
def backward_compatible_episodes_stats(
|
||||||
@@ -441,6 +533,7 @@ def build_dataset_frame(
|
|||||||
|
|
||||||
|
|
||||||
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
|
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
|
||||||
|
# TODO(rcadene): add fps for each feature
|
||||||
camera_ft = {}
|
camera_ft = {}
|
||||||
if robot.cameras:
|
if robot.cameras:
|
||||||
camera_ft = {
|
camera_ft = {
|
||||||
@@ -494,31 +587,17 @@ def create_empty_dataset_info(
|
|||||||
"total_episodes": 0,
|
"total_episodes": 0,
|
||||||
"total_frames": 0,
|
"total_frames": 0,
|
||||||
"total_tasks": 0,
|
"total_tasks": 0,
|
||||||
"total_videos": 0,
|
|
||||||
"total_chunks": 0,
|
|
||||||
"chunks_size": DEFAULT_CHUNK_SIZE,
|
"chunks_size": DEFAULT_CHUNK_SIZE,
|
||||||
|
"data_files_size_in_mb": DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
|
"video_files_size_in_mb": DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||||
"fps": fps,
|
"fps": fps,
|
||||||
"splits": {},
|
"splits": {},
|
||||||
"data_path": DEFAULT_PARQUET_PATH,
|
"data_path": DEFAULT_DATA_PATH,
|
||||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
||||||
"features": features,
|
"features": features,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_episode_data_index(
|
|
||||||
episode_dicts: dict[dict], episodes: list[int] | None = None
|
|
||||||
) -> dict[str, torch.Tensor]:
|
|
||||||
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
|
|
||||||
if episodes is not None:
|
|
||||||
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
|
|
||||||
|
|
||||||
cumulative_lengths = list(accumulate(episode_lengths.values()))
|
|
||||||
return {
|
|
||||||
"from": torch.LongTensor([0] + cumulative_lengths[:-1]),
|
|
||||||
"to": torch.LongTensor(cumulative_lengths),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def check_timestamps_sync(
|
def check_timestamps_sync(
|
||||||
timestamps: np.ndarray,
|
timestamps: np.ndarray,
|
||||||
episode_indices: np.ndarray,
|
episode_indices: np.ndarray,
|
||||||
@@ -755,10 +834,17 @@ def validate_frame(frame: dict, features: dict):
|
|||||||
expected_features = set(features) - set(DEFAULT_FEATURES)
|
expected_features = set(features) - set(DEFAULT_FEATURES)
|
||||||
actual_features = set(frame)
|
actual_features = set(frame)
|
||||||
|
|
||||||
error_message = validate_features_presence(actual_features, expected_features)
|
# task is a special required field that's not part of regular features
|
||||||
|
if "task" not in actual_features:
|
||||||
|
raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n")
|
||||||
|
|
||||||
common_features = actual_features & expected_features
|
# Remove task from actual_features for regular feature validation
|
||||||
for name in common_features - {"task"}:
|
actual_features_for_validation = actual_features - {"task"}
|
||||||
|
|
||||||
|
error_message = validate_features_presence(actual_features_for_validation, expected_features)
|
||||||
|
|
||||||
|
common_features = actual_features_for_validation & expected_features
|
||||||
|
for name in common_features:
|
||||||
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
|
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
|
||||||
|
|
||||||
if error_message:
|
if error_message:
|
||||||
@@ -858,3 +944,11 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
|
|||||||
f"In episode_buffer not in features: {buffer_keys - set(features)}"
|
f"In episode_buffer not in features: {buffer_keys - set(features)}"
|
||||||
f"In features not in episode_buffer: {set(features) - buffer_keys}"
|
f"In features not in episode_buffer: {set(features) - buffer_keys}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path):
    """Write a pandas DataFrame containing images encoded by HF datasets to a parquet file.

    The frame is routed through `datasets.Dataset` so that, once the file is loaded back
    with HF datasets, correctly formatted images are returned.
    """
    # TODO(qlhoest): replace this weird syntax by `df.to_parquet(path)` only
    columns = df.to_dict(orient="list")
    hf_dataset = datasets.Dataset.from_dict(columns)
    hf_dataset.to_parquet(path)
|
||||||
|
|||||||
@@ -121,12 +121,12 @@ from safetensors.torch import load_file
|
|||||||
|
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
DEFAULT_CHUNK_SIZE,
|
DEFAULT_CHUNK_SIZE,
|
||||||
DEFAULT_PARQUET_PATH,
|
DEFAULT_DATA_PATH,
|
||||||
DEFAULT_VIDEO_PATH,
|
DEFAULT_VIDEO_PATH,
|
||||||
EPISODES_PATH,
|
|
||||||
INFO_PATH,
|
INFO_PATH,
|
||||||
|
LEGACY_EPISODES_PATH,
|
||||||
|
LEGACY_TASKS_PATH,
|
||||||
STATS_PATH,
|
STATS_PATH,
|
||||||
TASKS_PATH,
|
|
||||||
create_branch,
|
create_branch,
|
||||||
create_lerobot_dataset_card,
|
create_lerobot_dataset_card,
|
||||||
flatten_dict,
|
flatten_dict,
|
||||||
@@ -290,14 +290,12 @@ def split_parquet_by_episodes(
|
|||||||
for ep_chunk in range(total_chunks):
|
for ep_chunk in range(total_chunks):
|
||||||
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
|
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
|
||||||
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
|
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
|
||||||
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
|
chunk_dir = "/".join(DEFAULT_DATA_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
|
||||||
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
|
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
|
||||||
for ep_idx in range(ep_chunk_start, ep_chunk_end):
|
for ep_idx in range(ep_chunk_start, ep_chunk_end):
|
||||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||||
episode_lengths.insert(ep_idx, len(ep_table))
|
episode_lengths.insert(ep_idx, len(ep_table))
|
||||||
output_file = output_dir / DEFAULT_PARQUET_PATH.format(
|
output_file = output_dir / DEFAULT_DATA_PATH.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||||
episode_chunk=ep_chunk, episode_index=ep_idx
|
|
||||||
)
|
|
||||||
pq.write_table(ep_table, output_file)
|
pq.write_table(ep_table, output_file)
|
||||||
|
|
||||||
return episode_lengths
|
return episode_lengths
|
||||||
@@ -495,7 +493,7 @@ def convert_dataset(
|
|||||||
|
|
||||||
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
|
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
|
||||||
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
|
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
|
||||||
write_jsonlines(tasks, v20_dir / TASKS_PATH)
|
write_jsonlines(tasks, v20_dir / LEGACY_TASKS_PATH)
|
||||||
features["task_index"] = {
|
features["task_index"] = {
|
||||||
"dtype": "int64",
|
"dtype": "int64",
|
||||||
"shape": (1,),
|
"shape": (1,),
|
||||||
@@ -545,7 +543,7 @@ def convert_dataset(
|
|||||||
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
|
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
|
||||||
for ep_idx in episode_indices
|
for ep_idx in episode_indices
|
||||||
]
|
]
|
||||||
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
|
write_jsonlines(episodes, v20_dir / LEGACY_EPISODES_PATH)
|
||||||
|
|
||||||
# Assemble metadata v2.0
|
# Assemble metadata v2.0
|
||||||
metadata_v2_0 = {
|
metadata_v2_0 = {
|
||||||
@@ -559,7 +557,7 @@ def convert_dataset(
|
|||||||
"chunks_size": DEFAULT_CHUNK_SIZE,
|
"chunks_size": DEFAULT_CHUNK_SIZE,
|
||||||
"fps": metadata_v1["fps"],
|
"fps": metadata_v1["fps"],
|
||||||
"splits": {"train": f"0:{total_episodes}"},
|
"splits": {"train": f"0:{total_episodes}"},
|
||||||
"data_path": DEFAULT_PARQUET_PATH,
|
"data_path": DEFAULT_DATA_PATH,
|
||||||
"video_path": DEFAULT_VIDEO_PATH if video_keys else None,
|
"video_path": DEFAULT_VIDEO_PATH if video_keys else None,
|
||||||
"features": features,
|
"features": features,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ import logging
|
|||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||||
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
from lerobot.common.datasets.utils import LEGACY_EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
||||||
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
|
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
|
||||||
|
|
||||||
V20 = "v2.0"
|
V20 = "v2.0"
|
||||||
@@ -61,8 +61,8 @@ def convert_dataset(
|
|||||||
with SuppressWarnings():
|
with SuppressWarnings():
|
||||||
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
|
||||||
|
|
||||||
if (dataset.root / EPISODES_STATS_PATH).is_file():
|
if (dataset.root / LEGACY_EPISODES_STATS_PATH).is_file():
|
||||||
(dataset.root / EPISODES_STATS_PATH).unlink()
|
(dataset.root / LEGACY_EPISODES_STATS_PATH).unlink()
|
||||||
|
|
||||||
convert_stats(dataset, num_workers=num_workers)
|
convert_stats(dataset, num_workers=num_workers)
|
||||||
ref_stats = load_stats(dataset.root)
|
ref_stats = load_stats(dataset.root)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
|
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.datasets.utils import write_episode_stats
|
from lerobot.common.datasets.utils import legacy_write_episode_stats
|
||||||
|
|
||||||
|
|
||||||
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
|
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
|
||||||
@@ -72,7 +72,7 @@ def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
|
|||||||
convert_episode_stats(dataset, ep_idx)
|
convert_episode_stats(dataset, ep_idx)
|
||||||
|
|
||||||
for ep_idx in tqdm(range(total_episodes)):
|
for ep_idx in tqdm(range(total_episodes)):
|
||||||
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
|
legacy_write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
|
||||||
|
|
||||||
|
|
||||||
def check_aggregate_stats(
|
def check_aggregate_stats(
|
||||||
|
|||||||
452
lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py
Normal file
452
lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py
Normal file
@@ -0,0 +1,452 @@
|
|||||||
|
"""
|
||||||
|
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.1 to
|
||||||
|
3.0. It will:
|
||||||
|
|
||||||
|
- Concatenate the per-episode parquet files into larger `data/chunk-XXX/file_XXX.parquet` files.
- Concatenate the per-episode videos of each camera into larger video files.
- Convert `tasks.jsonl`, `episodes.jsonl` and `episodes_stats.jsonl` into parquet metadata files.
- Update codebase_version in `info.json`.
- Push this new version to the hub and tag it with the new codebase version.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py \
|
||||||
|
--repo-id=lerobot/pusht
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import jsonlines
|
||||||
|
import pandas as pd
|
||||||
|
import pyarrow as pa
|
||||||
|
import tqdm
|
||||||
|
from datasets import Dataset, Features, Image
|
||||||
|
from huggingface_hub import HfApi, snapshot_download
|
||||||
|
from requests import HTTPError
|
||||||
|
|
||||||
|
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||||
|
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||||
|
from lerobot.common.datasets.utils import (
|
||||||
|
DEFAULT_CHUNK_SIZE,
|
||||||
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
|
DEFAULT_DATA_PATH,
|
||||||
|
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||||
|
DEFAULT_VIDEO_PATH,
|
||||||
|
cast_stats_to_numpy,
|
||||||
|
concat_video_files,
|
||||||
|
flatten_dict,
|
||||||
|
get_parquet_file_size_in_mb,
|
||||||
|
get_parquet_num_frames,
|
||||||
|
get_video_duration_in_s,
|
||||||
|
get_video_size_in_mb,
|
||||||
|
load_info,
|
||||||
|
update_chunk_file_indices,
|
||||||
|
write_episodes,
|
||||||
|
write_info,
|
||||||
|
write_stats,
|
||||||
|
write_tasks,
|
||||||
|
)
|
||||||
|
|
||||||
|
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
||||||
|
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||||
|
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
|
||||||
|
LEGACY_DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||||
|
LEGACY_DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||||
|
|
||||||
|
V21 = "v2.1"
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
-------------------------
|
||||||
|
OLD
|
||||||
|
data/chunk-000/episode_000000.parquet
|
||||||
|
|
||||||
|
NEW
|
||||||
|
data/chunk-000/file_000.parquet
|
||||||
|
-------------------------
|
||||||
|
OLD
|
||||||
|
videos/chunk-000/CAMERA/episode_000000.mp4
|
||||||
|
|
||||||
|
NEW
|
||||||
|
videos/chunk-000/file_000.mp4
|
||||||
|
-------------------------
|
||||||
|
OLD
|
||||||
|
episodes.jsonl
|
||||||
|
{"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266}
|
||||||
|
|
||||||
|
NEW
|
||||||
|
meta/episodes/chunk-000/episodes_000.parquet
|
||||||
|
episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length
|
||||||
|
-------------------------
|
||||||
|
OLD
|
||||||
|
tasks.jsonl
|
||||||
|
{"task_index": 1, "task": "Put the blue block in the green bowl"}
|
||||||
|
|
||||||
|
NEW
|
||||||
|
meta/tasks/chunk-000/file_000.parquet
|
||||||
|
task_index | task
|
||||||
|
-------------------------
|
||||||
|
OLD
|
||||||
|
episodes_stats.jsonl
|
||||||
|
|
||||||
|
NEW
|
||||||
|
meta/episodes_stats/chunk-000/file_000.parquet
|
||||||
|
episode_index | mean | std | min | max
|
||||||
|
-------------------------
|
||||||
|
UPDATE
|
||||||
|
meta/info.json
|
||||||
|
-------------------------
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def load_jsonlines(fpath: Path) -> list[Any]:
    """Read a JSON Lines file and return every record it contains, in file order."""
    with jsonlines.open(fpath, "r") as jsonl_reader:
        records = list(jsonl_reader)
    return records
|
||||||
|
|
||||||
|
|
||||||
|
def legacy_load_episodes(local_dir: Path) -> dict:
    """Load the legacy episodes file found at `local_dir / LEGACY_EPISODES_PATH`.

    Returns:
        A dict mapping each episode index to its full record, inserted in
        increasing episode-index order.
    """
    records = load_jsonlines(local_dir / LEGACY_EPISODES_PATH)
    episodes_by_index = {}
    for record in sorted(records, key=lambda rec: rec["episode_index"]):
        episodes_by_index[record["episode_index"]] = record
    return episodes_by_index
|
||||||
|
|
||||||
|
|
||||||
|
def legacy_load_episodes_stats(local_dir: Path) -> dict:
    """Load the legacy per-episode stats found at `local_dir / LEGACY_EPISODES_STATS_PATH`.

    Returns:
        A dict mapping each episode index to its stats converted with
        `cast_stats_to_numpy`, inserted in increasing episode-index order.
    """
    records = load_jsonlines(local_dir / LEGACY_EPISODES_STATS_PATH)
    stats_by_episode = {}
    for record in sorted(records, key=lambda rec: rec["episode_index"]):
        stats_by_episode[record["episode_index"]] = cast_stats_to_numpy(record["stats"])
    return stats_by_episode
|
||||||
|
|
||||||
|
|
||||||
|
def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
    """Load the legacy tasks file found at `local_dir / LEGACY_TASKS_PATH`.

    Returns:
        A pair `(tasks, task_to_task_index)`: the first maps task index to task
        string (in increasing task-index order), the second is the reverse mapping.
    """
    records = load_jsonlines(local_dir / LEGACY_TASKS_PATH)

    index_to_task = {}
    for record in sorted(records, key=lambda rec: rec["task_index"]):
        index_to_task[record["task_index"]] = record["task"]

    task_to_index = {}
    for task_index, task in index_to_task.items():
        task_to_index[task] = task_index

    return index_to_task, task_to_index
|
||||||
|
|
||||||
|
|
||||||
|
def convert_tasks(root, new_root):
    """Convert the legacy tasks of the dataset in `root` and write them under `new_root`.

    The tasks are stored as a DataFrame indexed by task string with a single
    `task_index` column, then handed to `write_tasks`.
    """
    index_to_task, _ = legacy_load_tasks(root)
    tasks_frame = pd.DataFrame(
        {"task_index": index_to_task.keys()},
        index=index_to_task.values(),
    )
    write_tasks(tasks_frame, new_root)
|
||||||
|
|
||||||
|
|
||||||
|
def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys):
    """Concatenate several parquet files row-wise into one data file under `new_root`.

    Args:
        paths_to_cat: Parquet files to concatenate, in output order.
        new_root: Root directory of the converted dataset.
        chunk_idx: Chunk index used to format `DEFAULT_DATA_PATH`.
        file_idx: File index used to format `DEFAULT_DATA_PATH`.
        image_keys: Column names to declare as HF `Image` features in the written schema.
    """
    # TODO(rcadene): to save RAM use Dataset.from_parquet(file) and concatenate_datasets
    merged_df = pd.concat([pd.read_parquet(src) for src in paths_to_cat], ignore_index=True)

    out_path = new_root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Only build an explicit schema when image columns must be tagged as `Image` features.
    arrow_schema = None
    if len(image_keys) > 0:
        hf_features = Features.from_arrow_schema(pa.Schema.from_pandas(merged_df))
        for image_key in image_keys:
            hf_features[image_key] = Image()
        arrow_schema = hf_features.arrow_schema

    merged_df.to_parquet(out_path, index=False, schema=arrow_schema)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_data(root, new_root):
    """Merge the per-episode parquet files of `root` into larger data files under `new_root`.

    Episodes are appended to the current output file until adding the next one would
    reach `DEFAULT_DATA_FILE_SIZE_IN_MB`; the current file is then written and the
    episode starts a new file.

    Args:
        root: Root directory of the legacy dataset (its `data/*/*.parquet` files are read).
        new_root: Root directory of the converted dataset.

    Returns:
        A list with one dict per episode holding `episode_index`, the
        `data/chunk_index` and `data/file_index` of the file that contains the episode,
        and its `dataset_from_index` / `dataset_to_index` frame range.
    """
    data_dir = root / "data"
    ep_paths = sorted(data_dir.glob("*/*.parquet"))

    image_keys = get_image_keys(root)

    chunk_idx = 0
    file_idx = 0
    size_in_mb = 0
    num_frames = 0
    paths_to_cat = []
    episodes_metadata = []
    for ep_idx, ep_path in enumerate(ep_paths):
        ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
        ep_num_frames = get_parquet_num_frames(ep_path)

        # Flush BEFORE recording the metadata, so that the episode is registered with
        # the chunk/file indices of the file it is actually written to. The
        # `paths_to_cat` guard prevents concatenating an empty list when a single
        # episode already reaches the size limit.
        if paths_to_cat and size_in_mb + ep_size_in_mb >= DEFAULT_DATA_FILE_SIZE_IN_MB:
            concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)
            chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
            size_in_mb = 0
            paths_to_cat = []

        # `num_frames` is never reset: frame indices keep running across output files.
        # NOTE(review): assumes readers expect dataset-wide (not per-file) indices — confirm.
        episodes_metadata.append(
            {
                "episode_index": ep_idx,
                "data/chunk_index": chunk_idx,
                "data/file_index": file_idx,
                "dataset_from_index": num_frames,
                "dataset_to_index": num_frames + ep_num_frames,
            }
        )
        paths_to_cat.append(ep_path)
        size_in_mb += ep_size_in_mb
        num_frames += ep_num_frames

    # Write remaining data if any
    if paths_to_cat:
        concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys)

    return episodes_metadata
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_keys(root):
    """Return the names of the features declared with dtype "video" in the info of `root`."""
    features = load_info(root)["features"]
    video_keys = []
    for key, ft in features.items():
        if ft["dtype"] == "video":
            video_keys.append(key)
    return video_keys
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_keys(root):
    """Return the names of the features declared with dtype "image" in the info of `root`."""
    features = load_info(root)["features"]
    image_keys = []
    for key, ft in features.items():
        if ft["dtype"] == "image":
            image_keys.append(key)
    return image_keys
|
||||||
|
|
||||||
|
|
||||||
|
def convert_videos(root: Path, new_root: Path):
    """Convert the videos of every camera and merge their per-episode metadata.

    Returns:
        `None` when the dataset has no video feature; otherwise a list with one dict
        per episode combining the metadata produced for each camera.

    Raises:
        ValueError: If the cameras do not all have the same number of episodes, or if
            the episode indices disagree across cameras for a given position.
    """
    video_keys = get_video_keys(root)
    if len(video_keys) == 0:
        return None

    # Cameras are processed in sorted order so the merge order is deterministic.
    eps_metadata_per_cam = [
        convert_videos_of_camera(root, new_root, camera) for camera in sorted(video_keys)
    ]

    num_eps_per_cam = [len(cam_eps) for cam_eps in eps_metadata_per_cam]
    if len(set(num_eps_per_cam)) != 1:
        raise ValueError(f"All cams dont have same number of episodes ({num_eps_per_cam}).")

    merged_metadata = []
    for ep_idx in range(num_eps_per_cam[0]):
        per_cam_entries = [cam_eps[ep_idx] for cam_eps in eps_metadata_per_cam]

        # Sanity check
        ep_ids = [entry["episode_index"] for entry in per_cam_entries]
        ep_ids.append(ep_idx)
        if len(set(ep_ids)) != 1:
            raise ValueError(f"All episode indices need to match ({ep_ids}).")

        ep_dict = {}
        for entry in per_cam_entries:
            ep_dict.update(entry)
        merged_metadata.append(ep_dict)

    return merged_metadata
|
||||||
|
|
||||||
|
|
||||||
|
def convert_videos_of_camera(root: Path, new_root: Path, video_key):
    """Merge the per-episode videos of one camera into larger video files under `new_root`.

    Episodes are appended to the current output file until adding the next one would
    reach `DEFAULT_VIDEO_FILE_SIZE_IN_MB`; the current file is then written and the
    episode starts a new file.

    Args:
        root: Root directory of the legacy dataset (`videos/*/{video_key}/*.mp4` is read).
        new_root: Root directory of the converted dataset.
        video_key: Name of the camera / video feature to convert.

    Returns:
        A list with one dict per episode holding `episode_index`, the chunk and file
        index of the output video containing the episode, and the `from_timestamp` /
        `to_timestamp` of the episode inside that output video.
    """
    # Access old paths to mp4
    videos_dir = root / "videos"
    ep_paths = sorted(videos_dir.glob(f"*/{video_key}/*.mp4"))

    chunk_idx = 0
    file_idx = 0
    size_in_mb = 0
    duration_in_s = 0.0
    paths_to_cat = []
    episodes_metadata = []
    for ep_idx, ep_path in enumerate(tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}")):
        ep_size_in_mb = get_video_size_in_mb(ep_path)
        ep_duration_in_s = get_video_duration_in_s(ep_path)

        # Flush BEFORE recording the metadata, so that the episode is registered with
        # the file it is actually written to and with timestamps relative to the start
        # of that file. The `paths_to_cat` guard prevents concatenating an empty list
        # when a single video already reaches the size limit.
        if paths_to_cat and size_in_mb + ep_size_in_mb >= DEFAULT_VIDEO_FILE_SIZE_IN_MB:
            concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx)
            chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
            size_in_mb = 0
            duration_in_s = 0.0
            paths_to_cat = []

        episodes_metadata.append(
            {
                "episode_index": ep_idx,
                f"videos/{video_key}/chunk_index": chunk_idx,
                f"videos/{video_key}/file_index": file_idx,
                f"videos/{video_key}/from_timestamp": duration_in_s,
                f"videos/{video_key}/to_timestamp": duration_in_s + ep_duration_in_s,
            }
        )
        paths_to_cat.append(ep_path)
        size_in_mb += ep_size_in_mb
        duration_in_s += ep_duration_in_s

    # Write remaining videos if any
    if paths_to_cat:
        concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx)

    return episodes_metadata
|
||||||
|
|
||||||
|
|
||||||
|
def generate_episode_metadata_dict(
    episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None
):
    """Yield one merged metadata dict per episode.

    Each yielded dict combines, in this order of precedence (later wins): the data-file
    metadata, the optional video metadata, the legacy episode record and the flattened
    stats under the "stats" prefix. The episodes-file chunk and file indices are both
    set to 0.

    Raises:
        ValueError: If the sources disagree on the episode index at a given position.
    """
    legacy_records = list(episodes_legacy_metadata.values())
    stats_records = list(episodes_stats.values())
    stats_indices = list(episodes_stats.keys())

    for i in range(len(episodes_metadata)):
        legacy_record = legacy_records[i]
        data_record = episodes_metadata[i]
        stats_record = stats_records[i]

        # Every source must refer to the same episode at position `i`.
        ep_ids_set = {
            legacy_record["episode_index"],
            data_record["episode_index"],
            stats_indices[i],
        }

        video_record = {}
        if episodes_videos is not None:
            video_record = episodes_videos[i]
            ep_ids_set.add(video_record["episode_index"])

        if len(ep_ids_set) != 1:
            raise ValueError(f"Number of episodes is not the same ({ep_ids_set}).")

        merged = {}
        merged.update(data_record)
        merged.update(video_record)
        merged.update(legacy_record)
        merged.update(flatten_dict({"stats": stats_record}))
        merged["meta/episodes/chunk_index"] = 0
        merged["meta/episodes/file_index"] = 0
        yield merged
|
||||||
|
|
||||||
|
|
||||||
|
def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None):
    """Write the episodes metadata and the aggregated stats of the converted dataset.

    Args:
        root: Root directory of the legacy dataset (legacy episodes and stats are read here).
        new_root: Root directory of the converted dataset.
        episodes_metadata: Per-episode data-file metadata.
        episodes_video_metadata: Optional per-episode video metadata.

    Raises:
        ValueError: If the legacy episodes, the data metadata and (when given) the
            video metadata do not have the same number of episodes.
    """
    legacy_episodes = legacy_load_episodes(root)
    per_episode_stats = legacy_load_episodes_stats(root)

    episode_counts = {len(legacy_episodes), len(episodes_metadata)}
    if episodes_video_metadata is not None:
        episode_counts.add(len(episodes_video_metadata))
    if len(episode_counts) != 1:
        raise ValueError(f"Number of episodes is not the same ({episode_counts}).")

    episodes_dataset = Dataset.from_generator(
        lambda: generate_episode_metadata_dict(
            legacy_episodes, episodes_metadata, per_episode_stats, episodes_video_metadata
        )
    )
    write_episodes(episodes_dataset, new_root)

    # Dataset-level stats are the aggregation of all per-episode stats.
    aggregated = aggregate_stats(list(per_episode_stats.values()))
    write_stats(aggregated, new_root)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_info(root, new_root):
    """Convert the info of the legacy dataset in `root` and write it under `new_root`.

    The codebase version is set to "v3.0", the `total_chunks` and `total_videos`
    entries are dropped, the file-size limits and the data/video path templates are
    set, `fps` is cast to float, and the dataset `fps` is copied to every non-video feature.
    """
    info = load_info(root)
    info["codebase_version"] = "v3.0"
    # Drop legacy entries; tolerate their absence instead of raising KeyError.
    info.pop("total_chunks", None)
    info.pop("total_videos", None)
    info["data_files_size_in_mb"] = DEFAULT_DATA_FILE_SIZE_IN_MB
    info["video_files_size_in_mb"] = DEFAULT_VIDEO_FILE_SIZE_IN_MB
    info["data_path"] = DEFAULT_DATA_PATH
    # Datasets without any video feature keep `video_path` unset.
    has_videos = any(ft["dtype"] == "video" for ft in info["features"].values())
    info["video_path"] = DEFAULT_VIDEO_PATH if has_videos else None
    info["fps"] = float(info["fps"])
    for ft in info["features"].values():
        if ft["dtype"] == "video":
            # already has fps in video_info
            continue
        ft["fps"] = info["fps"]
    write_info(info, new_root)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_dataset(
    repo_id: str,
    branch: str | None = None,
    num_workers: int = 4,
):
    """Download a dataset at revision ``V21``, convert it locally, and push the result to the Hub.

    Args:
        repo_id: Hub dataset identifier; also used as the local folder name under
            ``HF_LEROBOT_HOME``.
        branch: Revision passed to the Hub file-deletion and tag-creation calls.
        num_workers: Not used anywhere in this function body.
            NOTE(review): kept for CLI compatibility, presumably — confirm.
    """
    root = HF_LEROBOT_HOME / repo_id
    old_root = HF_LEROBOT_HOME / f"{repo_id}_old"
    new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"

    # If both a backup (`*_old`) and a current folder exist, drop the current one and
    # restore the backup in its place (presumably left over from a previous run — confirm).
    if old_root.is_dir() and root.is_dir():
        shutil.rmtree(str(root))
        shutil.move(str(old_root), str(root))

    # Start from an empty output folder.
    if new_root.is_dir():
        shutil.rmtree(new_root)

    snapshot_download(
        repo_id,
        repo_type="dataset",
        revision=V21,
        local_dir=root,
    )

    # Each step reads from `root` and writes into `new_root`.
    convert_info(root, new_root)
    convert_tasks(root, new_root)
    episodes_metadata = convert_data(root, new_root)
    episodes_videos_metadata = convert_videos(root, new_root)
    convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata)

    # Swap folders: the downloaded copy becomes the `*_old` backup and the converted
    # copy takes over the canonical location.
    shutil.move(str(root), str(old_root))
    shutil.move(str(new_root), str(root))

    hub_api = HfApi()
    try:
        hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
    except HTTPError as e:
        # Best effort: a failure here is reported and the conversion continues.
        print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
        pass
    # Remove the files matching the old layout patterns from the Hub repo.
    hub_api.delete_files(
        delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"],
        repo_id=repo_id,
        revision=branch,
        repo_type="dataset",
    )
    # NOTE(review): the tag is created before `push_to_hub()` runs below — confirm the tag
    # ends up pointing at the revision that contains the converted files.
    hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")

    LeRobotDataset(repo_id).push_to_hub()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: every parsed option is forwarded to `convert_dataset`
    # as a keyword argument.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
        "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
    )
    cli_parser.add_argument(
        "--branch",
        type=str,
        default=None,
        help="Repo branch to push your dataset. Defaults to the main branch.",
    )
    cli_parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of workers for parallelizing stats compute. Defaults to 4.",
    )

    cli_options = vars(cli_parser.parse_args())
    convert_dataset(**cli_options)
|
||||||
@@ -13,15 +13,16 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import glob
|
|
||||||
import importlib
|
import importlib
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import subprocess
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections import OrderedDict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, ClassVar
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
import av
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
@@ -101,7 +102,7 @@ def decode_video_frames_torchvision(
|
|||||||
keyframes_only = False
|
keyframes_only = False
|
||||||
torchvision.set_video_backend(backend)
|
torchvision.set_video_backend(backend)
|
||||||
if backend == "pyav":
|
if backend == "pyav":
|
||||||
keyframes_only = True # pyav doesn't support accurate seek
|
keyframes_only = True # pyav doesnt support accuracte seek
|
||||||
|
|
||||||
# set a video stream reader
|
# set a video stream reader
|
||||||
# TODO(rcadene): also load audio stream at the same time
|
# TODO(rcadene): also load audio stream at the same time
|
||||||
@@ -154,6 +155,7 @@ def decode_video_frames_torchvision(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# get closest frames to the query timestamps
|
# get closest frames to the query timestamps
|
||||||
|
# TODO(rcadene): remove torch.stack
|
||||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||||
closest_ts = loaded_ts[argmin_]
|
closest_ts = loaded_ts[argmin_]
|
||||||
|
|
||||||
@@ -251,83 +253,51 @@ def encode_video_frames(
|
|||||||
g: int | None = 2,
|
g: int | None = 2,
|
||||||
crf: int | None = 30,
|
crf: int | None = 30,
|
||||||
fast_decode: int = 0,
|
fast_decode: int = 0,
|
||||||
log_level: int | None = av.logging.ERROR,
|
log_level: str | None = "quiet",
|
||||||
overwrite: bool = False,
|
overwrite: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||||
# Check encoder availability
|
|
||||||
if vcodec not in ["h264", "hevc", "libsvtav1"]:
|
|
||||||
raise ValueError(f"Unsupported video codec: {vcodec}. Supported codecs are: h264, hevc, libsvtav1.")
|
|
||||||
|
|
||||||
video_path = Path(video_path)
|
video_path = Path(video_path)
|
||||||
imgs_dir = Path(imgs_dir)
|
imgs_dir = Path(imgs_dir)
|
||||||
|
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
video_path.parent.mkdir(parents=True, exist_ok=overwrite)
|
ffmpeg_args = OrderedDict(
|
||||||
|
[
|
||||||
# Encoders/pixel formats incompatibility check
|
("-f", "image2"),
|
||||||
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
|
("-r", str(fps)),
|
||||||
logging.warning(
|
("-i", str(imgs_dir / "frame-%06d.png")),
|
||||||
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
|
("-vcodec", vcodec),
|
||||||
)
|
("-pix_fmt", pix_fmt),
|
||||||
pix_fmt = "yuv420p"
|
]
|
||||||
|
|
||||||
# Get input frames
|
|
||||||
template = "frame_" + ("[0-9]" * 6) + ".png"
|
|
||||||
input_list = sorted(
|
|
||||||
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("_")[-1].split(".")[0])
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Define video output frame size (assuming all input frames are the same size)
|
|
||||||
if len(input_list) == 0:
|
|
||||||
raise FileNotFoundError(f"No images found in {imgs_dir}.")
|
|
||||||
dummy_image = Image.open(input_list[0])
|
|
||||||
width, height = dummy_image.size
|
|
||||||
|
|
||||||
# Define video codec options
|
|
||||||
video_options = {}
|
|
||||||
|
|
||||||
if g is not None:
|
if g is not None:
|
||||||
video_options["g"] = str(g)
|
ffmpeg_args["-g"] = str(g)
|
||||||
|
|
||||||
if crf is not None:
|
if crf is not None:
|
||||||
video_options["crf"] = str(crf)
|
ffmpeg_args["-crf"] = str(crf)
|
||||||
|
|
||||||
if fast_decode:
|
if fast_decode:
|
||||||
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
|
key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
|
||||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||||
video_options[key] = value
|
ffmpeg_args[key] = value
|
||||||
|
|
||||||
# Set logging level
|
|
||||||
if log_level is not None:
|
if log_level is not None:
|
||||||
# "While less efficient, it is generally preferable to modify logging with Python’s logging"
|
ffmpeg_args["-loglevel"] = str(log_level)
|
||||||
logging.getLogger("libav").setLevel(log_level)
|
|
||||||
|
|
||||||
# Create and open output file (overwrite by default)
|
ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
|
||||||
with av.open(str(video_path), "w") as output:
|
if overwrite:
|
||||||
output_stream = output.add_stream(vcodec, fps, options=video_options)
|
ffmpeg_args.append("-y")
|
||||||
output_stream.pix_fmt = pix_fmt
|
|
||||||
output_stream.width = width
|
|
||||||
output_stream.height = height
|
|
||||||
|
|
||||||
# Loop through input frames and encode them
|
ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
|
||||||
for input_data in input_list:
|
# redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
|
||||||
input_image = Image.open(input_data).convert("RGB")
|
subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
|
||||||
input_frame = av.VideoFrame.from_image(input_image)
|
|
||||||
packet = output_stream.encode(input_frame)
|
|
||||||
if packet:
|
|
||||||
output.mux(packet)
|
|
||||||
|
|
||||||
# Flush the encoder
|
|
||||||
packet = output_stream.encode()
|
|
||||||
if packet:
|
|
||||||
output.mux(packet)
|
|
||||||
|
|
||||||
# Reset logging level
|
|
||||||
if log_level is not None:
|
|
||||||
av.logging.restore_default_callback()
|
|
||||||
|
|
||||||
if not video_path.exists():
|
if not video_path.exists():
|
||||||
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
|
raise OSError(
|
||||||
|
f"Video encoding did not work. File not found: {video_path}. "
|
||||||
|
f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -363,68 +333,78 @@ with warnings.catch_warnings():
|
|||||||
|
|
||||||
|
|
||||||
def get_audio_info(video_path: Path | str) -> dict:
|
def get_audio_info(video_path: Path | str) -> dict:
|
||||||
# Set logging level
|
ffprobe_audio_cmd = [
|
||||||
logging.getLogger("libav").setLevel(av.logging.ERROR)
|
"ffprobe",
|
||||||
|
"-v",
|
||||||
|
"error",
|
||||||
|
"-select_streams",
|
||||||
|
"a:0",
|
||||||
|
"-show_entries",
|
||||||
|
"stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
|
||||||
|
"-of",
|
||||||
|
"json",
|
||||||
|
str(video_path),
|
||||||
|
]
|
||||||
|
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||||
|
|
||||||
# Getting audio stream information
|
info = json.loads(result.stdout)
|
||||||
audio_info = {}
|
audio_stream_info = info["streams"][0] if info.get("streams") else None
|
||||||
with av.open(str(video_path), "r") as audio_file:
|
if audio_stream_info is None:
|
||||||
try:
|
return {"has_audio": False}
|
||||||
audio_stream = audio_file.streams.audio[0]
|
|
||||||
except IndexError:
|
|
||||||
# Reset logging level
|
|
||||||
av.logging.restore_default_callback()
|
|
||||||
return {"has_audio": False}
|
|
||||||
|
|
||||||
audio_info["audio.channels"] = audio_stream.channels
|
# Return the information, defaulting to None if no audio stream is present
|
||||||
audio_info["audio.codec"] = audio_stream.codec.canonical_name
|
return {
|
||||||
# In an ideal loseless case : bit depth x sample rate x channels = bit rate.
|
"has_audio": True,
|
||||||
# In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied.
|
"audio.channels": audio_stream_info.get("channels", None),
|
||||||
audio_info["audio.bit_rate"] = audio_stream.bit_rate
|
"audio.codec": audio_stream_info.get("codec_name", None),
|
||||||
audio_info["audio.sample_rate"] = audio_stream.sample_rate # Number of samples per second
|
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
|
||||||
# In an ideal loseless case : fixed number of bits per sample.
|
"audio.sample_rate": int(audio_stream_info["sample_rate"])
|
||||||
# In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate).
|
if audio_stream_info.get("sample_rate")
|
||||||
audio_info["audio.bit_depth"] = audio_stream.format.bits
|
else None,
|
||||||
audio_info["audio.channel_layout"] = audio_stream.layout.name
|
"audio.bit_depth": audio_stream_info.get("bit_depth", None),
|
||||||
audio_info["has_audio"] = True
|
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
|
||||||
|
}
|
||||||
# Reset logging level
|
|
||||||
av.logging.restore_default_callback()
|
|
||||||
|
|
||||||
return audio_info
|
|
||||||
|
|
||||||
|
|
||||||
def get_video_info(video_path: Path | str) -> dict:
|
def get_video_info(video_path: Path | str) -> dict:
|
||||||
# Set logging level
|
ffprobe_video_cmd = [
|
||||||
logging.getLogger("libav").setLevel(av.logging.ERROR)
|
"ffprobe",
|
||||||
|
"-v",
|
||||||
|
"error",
|
||||||
|
"-select_streams",
|
||||||
|
"v:0",
|
||||||
|
"-show_entries",
|
||||||
|
"stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
|
||||||
|
"-of",
|
||||||
|
"json",
|
||||||
|
str(video_path),
|
||||||
|
]
|
||||||
|
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||||
|
if result.returncode != 0:
|
||||||
|
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||||
|
|
||||||
# Getting video stream information
|
info = json.loads(result.stdout)
|
||||||
video_info = {}
|
video_stream_info = info["streams"][0]
|
||||||
with av.open(str(video_path), "r") as video_file:
|
|
||||||
try:
|
|
||||||
video_stream = video_file.streams.video[0]
|
|
||||||
except IndexError:
|
|
||||||
# Reset logging level
|
|
||||||
av.logging.restore_default_callback()
|
|
||||||
return {}
|
|
||||||
|
|
||||||
video_info["video.height"] = video_stream.height
|
# Calculate fps from r_frame_rate
|
||||||
video_info["video.width"] = video_stream.width
|
r_frame_rate = video_stream_info["r_frame_rate"]
|
||||||
video_info["video.codec"] = video_stream.codec.canonical_name
|
num, denom = map(int, r_frame_rate.split("/"))
|
||||||
video_info["video.pix_fmt"] = video_stream.pix_fmt
|
fps = num / denom
|
||||||
video_info["video.is_depth_map"] = False
|
|
||||||
|
|
||||||
# Calculate fps from r_frame_rate
|
pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
|
||||||
video_info["video.fps"] = int(video_stream.base_rate)
|
|
||||||
|
|
||||||
pixel_channels = get_video_pixel_channels(video_stream.pix_fmt)
|
video_info = {
|
||||||
video_info["video.channels"] = pixel_channels
|
"video.fps": fps,
|
||||||
|
"video.height": video_stream_info["height"],
|
||||||
# Reset logging level
|
"video.width": video_stream_info["width"],
|
||||||
av.logging.restore_default_callback()
|
"video.channels": pixel_channels,
|
||||||
|
"video.codec": video_stream_info["codec_name"],
|
||||||
# Adding audio stream information
|
"video.pix_fmt": video_stream_info["pix_fmt"],
|
||||||
video_info.update(**get_audio_info(video_path))
|
"video.is_depth_map": False,
|
||||||
|
**get_audio_info(video_path),
|
||||||
|
}
|
||||||
|
|
||||||
return video_info
|
return video_info
|
||||||
|
|
||||||
|
|||||||
@@ -596,7 +596,8 @@ class ReplayBuffer:
|
|||||||
frame_dict[f"complementary_info.{key}"] = val
|
frame_dict[f"complementary_info.{key}"] = val
|
||||||
|
|
||||||
# Add to the dataset's buffer
|
# Add to the dataset's buffer
|
||||||
lerobot_dataset.add_frame(frame_dict, task=task_name)
|
frame_dict["task"] = task_name
|
||||||
|
lerobot_dataset.add_frame(frame_dict)
|
||||||
|
|
||||||
# Move to next frame
|
# Move to next frame
|
||||||
frame_idx_in_episode += 1
|
frame_idx_in_episode += 1
|
||||||
|
|||||||
@@ -263,6 +263,16 @@ def move_cursor_up(lines):
|
|||||||
print(f"\033[{lines}A", end="")
|
print(f"\033[{lines}A", end="")
|
||||||
|
|
||||||
|
|
||||||
|
def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float):
    """Split a duration given in seconds into ``(days, hours, minutes, seconds)``.

    Days, hours and minutes are returned as ints; ``seconds`` is whatever remains after
    removing whole minutes, so it keeps any fractional part of the input.
    """
    whole_days, remainder = divmod(elapsed_time_s, 24 * 3600)
    whole_hours, remainder = divmod(remainder, 3600)
    whole_minutes, leftover_seconds = divmod(remainder, 60)
    return int(whole_days), int(whole_hours), int(whole_minutes), leftover_seconds
|
||||||
|
|
||||||
|
|
||||||
class TimerManager:
|
class TimerManager:
|
||||||
"""
|
"""
|
||||||
Lightweight utility to measure elapsed time.
|
Lightweight utility to measure elapsed time.
|
||||||
|
|||||||
@@ -218,8 +218,8 @@ def record_loop(
|
|||||||
|
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
|
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
|
||||||
frame = {**observation_frame, **action_frame}
|
frame = {**observation_frame, **action_frame, "task": single_task}
|
||||||
dataset.add_frame(frame, task=single_task)
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
if display_data:
|
if display_data:
|
||||||
for obs, val in observation.items():
|
for obs, val in observation.items():
|
||||||
|
|||||||
@@ -227,7 +227,8 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
|||||||
|
|
||||||
new_frame[key] = value
|
new_frame[key] = value
|
||||||
|
|
||||||
new_dataset.add_frame(new_frame, task=task)
|
new_frame["task"] = task
|
||||||
|
new_dataset.add_frame(new_frame)
|
||||||
|
|
||||||
if frame["episode_index"].item() != prev_episode_index:
|
if frame["episode_index"].item() != prev_episode_index:
|
||||||
# Save the episode
|
# Save the episode
|
||||||
|
|||||||
@@ -2132,7 +2132,8 @@ def record_dataset(env, policy, cfg):
|
|||||||
frame["complementary_info.discrete_penalty"] = torch.tensor(
|
frame["complementary_info.discrete_penalty"] = torch.tensor(
|
||||||
[info.get("discrete_penalty", 0.0)], dtype=torch.float32
|
[info.get("discrete_penalty", 0.0)], dtype=torch.float32
|
||||||
)
|
)
|
||||||
dataset.add_frame(frame, task=cfg.task)
|
frame["task"] = cfg.task
|
||||||
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
# Maintain consistent timing
|
# Maintain consistent timing
|
||||||
if cfg.fps:
|
if cfg.fps:
|
||||||
|
|||||||
@@ -166,7 +166,8 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||||
shuffle = False
|
shuffle = False
|
||||||
sampler = EpisodeAwareSampler(
|
sampler = EpisodeAwareSampler(
|
||||||
dataset.episode_data_index,
|
dataset.meta.episodes["dataset_from_index"],
|
||||||
|
dataset.meta.episodes["dataset_to_index"],
|
||||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -79,8 +79,8 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|||||||
|
|
||||||
class EpisodeSampler(torch.utils.data.Sampler):
|
class EpisodeSampler(torch.utils.data.Sampler):
|
||||||
def __init__(self, dataset: LeRobotDataset, episode_index: int):
|
def __init__(self, dataset: LeRobotDataset, episode_index: int):
|
||||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||||
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||||
self.frame_ids = range(from_idx, to_idx)
|
self.frame_ids = range(from_idx, to_idx)
|
||||||
|
|
||||||
def __iter__(self) -> Iterator:
|
def __iter__(self) -> Iterator:
|
||||||
@@ -283,7 +283,7 @@ def main():
|
|||||||
tolerance_s = kwargs.pop("tolerance_s")
|
tolerance_s = kwargs.pop("tolerance_s")
|
||||||
|
|
||||||
logging.info("Loading dataset")
|
logging.info("Loading dataset")
|
||||||
dataset = LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
|
dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)
|
||||||
|
|
||||||
visualize_dataset(dataset, **vars(args))
|
visualize_dataset(dataset, **vars(args))
|
||||||
|
|
||||||
|
|||||||
@@ -271,8 +271,8 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
|
|||||||
selected_columns.insert(0, "timestamp")
|
selected_columns.insert(0, "timestamp")
|
||||||
|
|
||||||
if isinstance(dataset, LeRobotDataset):
|
if isinstance(dataset, LeRobotDataset):
|
||||||
from_idx = dataset.episode_data_index["from"][episode_index]
|
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||||
to_idx = dataset.episode_data_index["to"][episode_index]
|
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||||
data = (
|
data = (
|
||||||
dataset.hf_dataset.select(range(from_idx, to_idx))
|
dataset.hf_dataset.select(range(from_idx, to_idx))
|
||||||
.select_columns(selected_columns)
|
.select_columns(selected_columns)
|
||||||
@@ -308,7 +308,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
|
|||||||
|
|
||||||
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
|
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
|
||||||
# get first frame of episode (hack to get video_path of the episode)
|
# get first frame of episode (hack to get video_path of the episode)
|
||||||
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
|
first_frame_idx = dataset.meta.episodes["dataset_from_index"][ep_index]
|
||||||
return [
|
return [
|
||||||
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
|
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
|
||||||
for key in dataset.meta.video_keys
|
for key in dataset.meta.video_keys
|
||||||
@@ -321,7 +321,7 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# get first frame index
|
# get first frame index
|
||||||
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
|
first_frame_idx = dataset.meta.episodes["dataset_from_index"][ep_index]
|
||||||
|
|
||||||
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
|
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
|
||||||
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
|
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
|
||||||
|
|||||||
@@ -47,17 +47,23 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# save 2 first frames of first episode
|
# save 2 first frames of first episode
|
||||||
i = dataset.episode_data_index["from"][0].item()
|
i = dataset.meta.episodes["dataset_from_index"][0].item()
|
||||||
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
|
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
|
||||||
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
|
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
|
||||||
|
|
||||||
# save 2 frames at the middle of first episode
|
# save 2 frames at the middle of first episode
|
||||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
i = int(
|
||||||
|
(
|
||||||
|
dataset.meta.episodes["dataset_to_index"][0].item()
|
||||||
|
- dataset.meta.episodes["dataset_from_index"][0].item()
|
||||||
|
)
|
||||||
|
/ 2
|
||||||
|
)
|
||||||
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
|
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
|
||||||
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
|
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
|
||||||
|
|
||||||
# save 2 last frames of first episode
|
# save 2 last frames of first episode
|
||||||
i = dataset.episode_data_index["to"][0].item()
|
i = dataset.meta.episodes["dataset_to_index"][0].item()
|
||||||
save_file(dataset[i - 2], repo_dir / f"frame_{i - 2}.safetensors")
|
save_file(dataset[i - 2], repo_dir / f"frame_{i - 2}.safetensors")
|
||||||
save_file(dataset[i - 1], repo_dir / f"frame_{i - 1}.safetensors")
|
save_file(dataset[i - 1], repo_dir / f"frame_{i - 1}.safetensors")
|
||||||
|
|
||||||
@@ -65,17 +71,17 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
|
|||||||
# We currently cant because our test dataset only contains the first episode
|
# We currently cant because our test dataset only contains the first episode
|
||||||
|
|
||||||
# # save 2 first frames of second episode
|
# # save 2 first frames of second episode
|
||||||
# i = dataset.episode_data_index["from"][1].item()
|
# i = dataset.meta.episodes["dataset_from_index"][1].item()
|
||||||
# save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
|
# save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
|
||||||
# save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
|
# save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
|
||||||
|
|
||||||
# # save 2 last frames of second episode
|
# # save 2 last frames of second episode
|
||||||
# i = dataset.episode_data_index["to"][1].item()
|
# i = dataset.meta.episodes["dataset_to_index"][1].item()
|
||||||
# save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
|
# save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
|
||||||
# save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
|
# save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
|
||||||
|
|
||||||
# # save 2 last frames of last episode
|
# # save 2 last frames of last episode
|
||||||
# i = dataset.episode_data_index["to"][-1].item()
|
# i = dataset.meta.episodes["dataset_to_index"][-1].item()
|
||||||
# save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
|
# save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
|
||||||
# save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
|
# save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
|
||||||
|
|
||||||
|
|||||||
203
tests/datasets/test_aggregate.py
Normal file
203
tests/datasets/test_aggregate.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.common.datasets.aggregate import aggregate_datasets
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||||
|
|
||||||
|
|
||||||
|
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
    """Check that the aggregated dataset reports the expected episode and frame totals."""
    episode_count = aggr_ds.num_episodes
    assert episode_count == expected_episodes, f"Expected {expected_episodes} episodes, got {episode_count}"

    frame_count = aggr_ds.num_frames
    assert frame_count == expected_frames, f"Expected {expected_frames} frames, got {frame_count}"
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_items_match(aggr_item, ref_item, label, ref_name):
    """Assert that ``aggr_item`` equals ``ref_item`` on every key of ``ref_item`` except the re-indexed ones.

    ``label`` and ``ref_name`` are only used to build the assertion message.
    """
    for key in ref_item:
        # episode_index and index are expected to be updated by the aggregation
        if key in ("episode_index", "index"):
            continue
        aggr_value = aggr_item[key]
        ref_value = ref_item[key]
        message = f"{label} key '{key}' doesn't match between aggregated and {ref_name}"
        # Handle both tensor and non-tensor data
        if torch.is_tensor(aggr_value) and torch.is_tensor(ref_value):
            assert torch.allclose(aggr_value, ref_value, atol=1e-6), message
        else:
            assert aggr_value == ref_value, message


def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
    """Test that the content of both datasets is preserved correctly in the aggregated dataset.

    Four boundary frames are compared: the first and last frame of the ``ds_0`` section and
    the first and last frame of the ``ds_1`` section. All keys except ``episode_index`` and
    ``index`` must match.

    Raises:
        AssertionError: if any compared value differs.
    """
    # First part of the aggregated dataset corresponds to ds_0
    _assert_items_match(aggr_ds[0], ds_0[0], "First item", "ds_0")
    _assert_items_match(aggr_ds[len(ds_0) - 1], ds_0[-1], "Last ds_0 item", "ds_0")

    # Second part corresponds to ds_1, starting right after the ds_0 frames
    _assert_items_match(aggr_ds[len(ds_0)], ds_1[0], "First ds_1 item", "ds_1")
    _assert_items_match(aggr_ds[-1], ds_1[-1], "Last item", "ds_1")
|
||||||
|
|
||||||
|
|
||||||
|
def assert_metadata_consistency(aggr_ds, ds_0, ds_1):
    """Check that fps, robot type, features and tasks of the aggregated dataset agree with its sources."""

    def robot_type_of(dataset):
        return dataset.meta.info["robot_type"]

    # Basic info must be identical across the three datasets
    assert aggr_ds.fps == ds_0.fps == ds_1.fps, "FPS should be the same across all datasets"
    assert robot_type_of(aggr_ds) == robot_type_of(ds_0) == robot_type_of(ds_1), (
        "Robot type should be the same"
    )

    # Features must be identical as well
    assert aggr_ds.features == ds_0.features == ds_1.features, "Features should be the same"

    # The aggregated tasks are the union of the tasks of both sources
    expected_tasks = set(ds_0.meta.tasks.index) | set(ds_1.meta.tasks.index)
    actual_tasks = set(aggr_ds.meta.tasks.index)
    tasks_match = actual_tasks == expected_tasks
    assert tasks_match, f"Expected tasks {expected_tasks}, got {actual_tasks}"
|
||||||
|
|
||||||
|
|
||||||
|
def assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1):
    """Check that frames in the aggregated dataset carry episode indices in the range of their source.

    Frames in the ``ds_0`` section must have an episode index below ``ds_0.num_episodes``;
    frames in the ``ds_1`` section must have one in
    ``[ds_0.num_episodes, ds_0.num_episodes + ds_1.num_episodes)``.
    """
    ds_0_len = len(ds_0)

    # ds_0 section: episode_index 0 to ds_0.num_episodes-1
    for position in range(ds_0_len):
        episode_index = aggr_ds[position]["episode_index"]
        assert episode_index < ds_0.num_episodes, (
            f"Episode index {episode_index} at position {position} should be < {ds_0.num_episodes}"
        )

    # ds_1 section: episode_index ds_0.num_episodes to total_episodes-1
    for position in range(ds_0_len, ds_0_len + len(ds_1)):
        episode_index = aggr_ds[position]["episode_index"]
        lowest_allowed = ds_0.num_episodes
        in_ds_1_range = (episode_index >= ds_0.num_episodes) and (
            episode_index < ds_0.num_episodes + ds_1.num_episodes
        )
        assert in_ds_1_range, (
            f"Episode index {episode_index} at position {position} should be >= {lowest_allowed}"
        )
|
||||||
|
|
||||||
|
|
||||||
|
def assert_video_frames_integrity(aggr_ds, ds_0, ds_1):
    """Check that video frames survive aggregation and frame indices are rewritten.

    For every position ``p`` in the aggregated dataset, ``aggr_ds[p]["index"]``
    must equal ``p``. For every feature whose dtype is ``"video"`` in
    ``aggr_ds.meta.info["features"]``, the frame must match the corresponding
    frame of the source dataset: ``ds_0`` for the first ``len(ds_0)``
    positions, then ``ds_1``.

    Args:
        aggr_ds: Aggregated dataset (indexable, exposes ``meta.info``).
        ds_0: First source dataset.
        ds_1: Second source dataset.

    Raises:
        AssertionError: on an index mismatch or a differing video frame.
    """
    features = aggr_ds.meta.info["features"]
    video_keys = [key for key, spec in features.items() if spec["dtype"] == "video"]

    def visual_frames_equal(frame1, frame2):
        # torch.allclose broadcasts its inputs, so frames of different but
        # broadcastable shapes could compare equal; require equal shapes first.
        return frame1.shape == frame2.shape and torch.allclose(frame1, frame2)

    def check_section(src_ds, src_name, offset):
        """Compare `src_ds` against the aggregated slice starting at `offset`."""
        for src_idx in range(len(src_ds)):
            pos = offset + src_idx
            # Index each dataset once per position rather than once per key.
            aggr_item = aggr_ds[pos]
            assert aggr_item["index"] == pos, (
                f"Frame index at position {pos} should be {pos}, but got {aggr_item['index']}"
            )
            if not video_keys:
                continue
            src_item = src_ds[src_idx]
            for key in video_keys:
                assert visual_frames_equal(aggr_item[key], src_item[key]), (
                    f"Visual frames at position {pos} should be equal between aggregated and {src_name}"
                )

    # Test the section corresponding to the first dataset (ds_0)
    check_section(ds_0, "ds_0", 0)
    # Test the section corresponding to the second dataset (ds_1)
    check_section(ds_1, "ds_1", len(ds_0))
|
||||||
|
|
||||||
|
|
||||||
|
def assert_dataset_iteration_works(aggr_ds):
    """Walk the whole dataset once; any error raised while iterating propagates."""
    exhausted = object()
    frames = iter(aggr_ds)
    # Pull items until the iterator reports exhaustion; the items are discarded.
    while next(frames, exhausted) is not exhausted:
        pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
    """Aggregate two dummy datasets and validate the merged result."""
    episodes_per_dataset = 10
    frames_per_dataset = 400

    # Build the two source datasets (ds_0 first, then ds_1); both use the
    # same episode and frame counts.
    sources = [
        lerobot_dataset_factory(
            root=tmp_path / f"test_{idx}",
            repo_id=f"{DUMMY_REPO_ID}_{idx}",
            total_episodes=episodes_per_dataset,
            total_frames=frames_per_dataset,
        )
        for idx in range(2)
    ]
    ds_0, ds_1 = sources

    aggr_repo_id = f"{DUMMY_REPO_ID}_aggr"
    aggr_root = tmp_path / "test_aggr"
    aggregate_datasets(
        repo_ids=[ds.repo_id for ds in sources],
        roots=[ds.root for ds in sources],
        aggr_repo_id=aggr_repo_id,
        aggr_root=aggr_root,
    )

    # Reload the aggregated dataset from disk before checking it.
    aggr_ds = LeRobotDataset(aggr_repo_id, root=aggr_root)

    # Totals are derived from the source datasets themselves.
    expected_total_episodes = ds_0.num_episodes + ds_1.num_episodes
    expected_total_frames = ds_0.num_frames + ds_1.num_frames

    # Run every assertion helper against the aggregated dataset.
    assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames)
    assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
    assert_metadata_consistency(aggr_ds, ds_0, ds_1)
    assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
    assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
    assert_dataset_iteration_works(aggr_ds)
|
||||||
@@ -13,10 +13,8 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from copy import deepcopy
|
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -36,14 +34,15 @@ from lerobot.common.datasets.lerobot_dataset import (
|
|||||||
)
|
)
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
create_branch,
|
create_branch,
|
||||||
flatten_dict,
|
hw_to_dataset_features,
|
||||||
unflatten_dict,
|
|
||||||
)
|
)
|
||||||
from lerobot.common.envs.factory import make_env_config
|
from lerobot.common.envs.factory import make_env_config
|
||||||
from lerobot.common.policies.factory import make_policy_config
|
from lerobot.common.policies.factory import make_policy_config
|
||||||
|
from lerobot.common.robots import make_robot_from_config
|
||||||
from lerobot.configs.default import DatasetConfig
|
from lerobot.configs.default import DatasetConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||||
|
from tests.mocks.mock_robot import MockRobotConfig
|
||||||
from tests.utils import require_x86_64_kernel
|
from tests.utils import require_x86_64_kernel
|
||||||
|
|
||||||
|
|
||||||
@@ -69,12 +68,17 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
|||||||
objects have the same sets of attributes defined.
|
objects have the same sets of attributes defined.
|
||||||
"""
|
"""
|
||||||
# Instantiate both ways
|
# Instantiate both ways
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
robot = make_robot_from_config(MockRobotConfig())
|
||||||
|
action_features = hw_to_dataset_features(robot.action_features, "action", True)
|
||||||
|
obs_features = hw_to_dataset_features(robot.observation_features, "observation", True)
|
||||||
|
dataset_features = {**action_features, **obs_features}
|
||||||
root_create = tmp_path / "create"
|
root_create = tmp_path / "create"
|
||||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=features, root=root_create)
|
dataset_create = LeRobotDataset.create(
|
||||||
|
repo_id=DUMMY_REPO_ID, fps=30, features=dataset_features, root=root_create
|
||||||
|
)
|
||||||
|
|
||||||
root_init = tmp_path / "init"
|
root_init = tmp_path / "init"
|
||||||
dataset_init = lerobot_dataset_factory(root=root_init)
|
dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
|
||||||
|
|
||||||
init_attr = set(vars(dataset_init).keys())
|
init_attr = set(vars(dataset_init).keys())
|
||||||
create_attr = set(vars(dataset_create).keys())
|
create_attr = set(vars(dataset_create).keys())
|
||||||
@@ -99,13 +103,41 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
|
|||||||
assert dataset.num_frames == len(dataset)
|
assert dataset.num_frames == len(dataset)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(rcadene, aliberts): do not run LeRobotDataset.create, instead refactor LeRobotDatasetMetadata.create
|
||||||
|
# and test the small resulting function that validates the features
|
||||||
|
def test_dataset_feature_with_forward_slash_raises_error():
    """`LeRobotDataset.create` must raise `ValueError` for the feature key ``"a/b"``."""
    import shutil

    from lerobot.common.constants import HF_LEROBOT_HOME

    dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
    # Make sure the directory does not exist. `Path.rmdir()` only removes
    # empty directories and raises OSError otherwise, so a leftover from an
    # interrupted run would break this test; remove the tree recursively.
    if dataset_dir.exists():
        shutil.rmtree(dataset_dir)

    with pytest.raises(ValueError):
        LeRobotDataset.create(
            repo_id="lerobot/test/with/slash",
            fps=30,
            features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
    """A frame without the ``"task"`` key must be rejected with a feature-mismatch error."""
    state_spec = {"dtype": "float32", "shape": (1,), "names": None}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features={"state": state_spec})

    # Pattern handed to pytest.raises(match=...).
    expected_message = "Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n"
    with pytest.raises(ValueError, match=expected_message):
        dataset.add_frame({"state": torch.randn(1)})
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
|
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
|
||||||
):
|
):
|
||||||
dataset.add_frame({"wrong_feature": torch.randn(1)}, task="Dummy task")
|
dataset.add_frame({"task": "Dummy task"})
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
|
||||||
@@ -114,7 +146,7 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
|
ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
|
||||||
):
|
):
|
||||||
dataset.add_frame({"state": torch.randn(1), "extra": "dummy_extra"}, task="Dummy task")
|
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
|
||||||
@@ -123,7 +155,7 @@ def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
|
ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
|
||||||
):
|
):
|
||||||
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16)}, task="Dummy task")
|
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
|
||||||
@@ -133,7 +165,7 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
ValueError,
|
ValueError,
|
||||||
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
|
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
|
||||||
):
|
):
|
||||||
dataset.add_frame({"state": torch.randn(1)}, task="Dummy task")
|
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory):
|
||||||
@@ -145,7 +177,7 @@ def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_fact
|
|||||||
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'float'>' provided instead.\n"
|
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'float'>' provided instead.\n"
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
dataset.add_frame({"state": 1.0}, task="Dummy task")
|
dataset.add_frame({"state": 1.0, "task": "Dummy task"})
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory):
|
||||||
@@ -155,7 +187,7 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact
|
|||||||
ValueError,
|
ValueError,
|
||||||
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
|
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
|
||||||
):
|
):
|
||||||
dataset.add_frame({"state": torch.tensor(1.0)}, task="Dummy task")
|
dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"})
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory):
|
||||||
@@ -167,13 +199,13 @@ def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_fact
|
|||||||
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'numpy.float32'>' provided instead.\n"
|
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'numpy.float32'>' provided instead.\n"
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
dataset.add_frame({"state": np.float32(1.0)}, task="Dummy task")
|
dataset.add_frame({"state": np.float32(1.0), "task": "Dummy task"})
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": torch.randn(1)}, task="Dummy task")
|
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert len(dataset) == 1
|
assert len(dataset) == 1
|
||||||
@@ -185,7 +217,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": torch.randn(2)}, task="Dummy task")
|
dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2])
|
assert dataset[0]["state"].shape == torch.Size([2])
|
||||||
@@ -194,7 +226,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": torch.randn(2, 4)}, task="Dummy task")
|
dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2, 4])
|
assert dataset[0]["state"].shape == torch.Size([2, 4])
|
||||||
@@ -203,7 +235,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": torch.randn(2, 4, 3)}, task="Dummy task")
|
dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
|
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
|
||||||
@@ -212,7 +244,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": torch.randn(2, 4, 3, 5)}, task="Dummy task")
|
dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
|
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
|
||||||
@@ -221,7 +253,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1)}, task="Dummy task")
|
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
|
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
|
||||||
@@ -230,7 +262,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"state": np.array([1], dtype=np.float32)}, task="Dummy task")
|
dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["state"].ndim == 0
|
assert dataset[0]["state"].ndim == 0
|
||||||
@@ -239,7 +271,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
|
def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
|
||||||
features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
|
features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
|
||||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
dataset.add_frame({"caption": "Dummy caption"}, task="Dummy task")
|
dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["caption"] == "Dummy caption"
|
assert dataset[0]["caption"] == "Dummy caption"
|
||||||
@@ -254,7 +286,7 @@ def test_add_frame_image_wrong_shape(image_dataset):
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
c, h, w = DUMMY_CHW
|
c, h, w = DUMMY_CHW
|
||||||
dataset.add_frame({"image": torch.randn(c, w, h)}, task="Dummy task")
|
dataset.add_frame({"image": torch.randn(c, w, h), "task": "Dummy task"})
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_image_wrong_range(image_dataset):
|
def test_add_frame_image_wrong_range(image_dataset):
|
||||||
@@ -267,14 +299,14 @@ def test_add_frame_image_wrong_range(image_dataset):
|
|||||||
Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`.
|
Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`.
|
||||||
"""
|
"""
|
||||||
dataset = image_dataset
|
dataset = image_dataset
|
||||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255}, task="Dummy task")
|
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"})
|
||||||
with pytest.raises(FileNotFoundError):
|
with pytest.raises(FileNotFoundError):
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
|
|
||||||
def test_add_frame_image(image_dataset):
|
def test_add_frame_image(image_dataset):
|
||||||
dataset = image_dataset
|
dataset = image_dataset
|
||||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW)}, task="Dummy task")
|
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||||
@@ -282,7 +314,7 @@ def test_add_frame_image(image_dataset):
|
|||||||
|
|
||||||
def test_add_frame_image_h_w_c(image_dataset):
|
def test_add_frame_image_h_w_c(image_dataset):
|
||||||
dataset = image_dataset
|
dataset = image_dataset
|
||||||
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC)}, task="Dummy task")
|
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||||
@@ -291,7 +323,7 @@ def test_add_frame_image_h_w_c(image_dataset):
|
|||||||
def test_add_frame_image_uint8(image_dataset):
|
def test_add_frame_image_uint8(image_dataset):
|
||||||
dataset = image_dataset
|
dataset = image_dataset
|
||||||
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
||||||
dataset.add_frame({"image": image}, task="Dummy task")
|
dataset.add_frame({"image": image, "task": "Dummy task"})
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||||
@@ -300,7 +332,7 @@ def test_add_frame_image_uint8(image_dataset):
|
|||||||
def test_add_frame_image_pil(image_dataset):
|
def test_add_frame_image_pil(image_dataset):
|
||||||
dataset = image_dataset
|
dataset = image_dataset
|
||||||
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
||||||
dataset.add_frame({"image": Image.fromarray(image)}, task="Dummy task")
|
dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
|
|
||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||||
@@ -319,6 +351,13 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
|
|||||||
# - [ ] test push_to_hub
|
# - [ ] test push_to_hub
|
||||||
# - [ ] test smaller methods
|
# - [ ] test smaller methods
|
||||||
|
|
||||||
|
# TODO(rcadene):
|
||||||
|
# - [ ] fix code so that old test_factory + backward pass
|
||||||
|
# - [ ] write new unit tests to test save_episode + getitem
|
||||||
|
# - [ ] save_episode : case where new dataset, concatenate same file, write new file (meta/episodes, data, videos)
|
||||||
|
# - [ ]
|
||||||
|
# - [ ] remove old tests
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"env_name, repo_id, policy_name",
|
"env_name, repo_id, policy_name",
|
||||||
@@ -338,9 +377,8 @@ def test_factory(env_name, repo_id, policy_name):
|
|||||||
# TODO(rcadene, aliberts): remove dataset download
|
# TODO(rcadene, aliberts): remove dataset download
|
||||||
dataset=DatasetConfig(repo_id=repo_id, episodes=[0]),
|
dataset=DatasetConfig(repo_id=repo_id, episodes=[0]),
|
||||||
env=make_env_config(env_name),
|
env=make_env_config(env_name),
|
||||||
policy=make_policy_config(policy_name, push_to_hub=False),
|
policy=make_policy_config(policy_name),
|
||||||
)
|
)
|
||||||
cfg.validate()
|
|
||||||
|
|
||||||
dataset = make_dataset(cfg)
|
dataset = make_dataset(cfg)
|
||||||
delta_timestamps = dataset.delta_timestamps
|
delta_timestamps = dataset.delta_timestamps
|
||||||
@@ -427,30 +465,6 @@ def test_multidataset_frames():
|
|||||||
assert torch.equal(sub_dataset_item[k], dataset_item[k])
|
assert torch.equal(sub_dataset_item[k], dataset_item[k])
|
||||||
|
|
||||||
|
|
||||||
# TODO(aliberts): Move to more appropriate location
|
|
||||||
def test_flatten_unflatten_dict():
|
|
||||||
d = {
|
|
||||||
"obs": {
|
|
||||||
"min": 0,
|
|
||||||
"max": 1,
|
|
||||||
"mean": 2,
|
|
||||||
"std": 3,
|
|
||||||
},
|
|
||||||
"action": {
|
|
||||||
"min": 4,
|
|
||||||
"max": 5,
|
|
||||||
"mean": 6,
|
|
||||||
"std": 7,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
original_d = deepcopy(d)
|
|
||||||
d = unflatten_dict(flatten_dict(d))
|
|
||||||
|
|
||||||
# test equality between nested dicts
|
|
||||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"repo_id",
|
"repo_id",
|
||||||
[
|
[
|
||||||
@@ -497,17 +511,23 @@ def test_backward_compatibility(repo_id):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# test2 first frames of first episode
|
# test2 first frames of first episode
|
||||||
i = dataset.episode_data_index["from"][0].item()
|
i = dataset.meta.episodes["dataset_from_index"][0].item()
|
||||||
load_and_compare(i)
|
load_and_compare(i)
|
||||||
load_and_compare(i + 1)
|
load_and_compare(i + 1)
|
||||||
|
|
||||||
# test 2 frames at the middle of first episode
|
# test 2 frames at the middle of first episode
|
||||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
i = int(
|
||||||
|
(
|
||||||
|
dataset.meta.episodes["dataset_to_index"][0].item()
|
||||||
|
- dataset.meta.episodes["dataset_from_index"][0].item()
|
||||||
|
)
|
||||||
|
/ 2
|
||||||
|
)
|
||||||
load_and_compare(i)
|
load_and_compare(i)
|
||||||
load_and_compare(i + 1)
|
load_and_compare(i + 1)
|
||||||
|
|
||||||
# test 2 last frames of first episode
|
# test 2 last frames of first episode
|
||||||
i = dataset.episode_data_index["to"][0].item()
|
i = dataset.meta.episodes["dataset_to_index"][0].item()
|
||||||
load_and_compare(i - 2)
|
load_and_compare(i - 2)
|
||||||
load_and_compare(i - 1)
|
load_and_compare(i - 1)
|
||||||
|
|
||||||
@@ -515,17 +535,17 @@ def test_backward_compatibility(repo_id):
|
|||||||
# We currently cant because our test dataset only contains the first episode
|
# We currently cant because our test dataset only contains the first episode
|
||||||
|
|
||||||
# # test 2 first frames of second episode
|
# # test 2 first frames of second episode
|
||||||
# i = dataset.episode_data_index["from"][1].item()
|
# i = dataset.meta.episodes["dataset_from_index"][1].item()
|
||||||
# load_and_compare(i)
|
# load_and_compare(i)
|
||||||
# load_and_compare(i + 1)
|
# load_and_compare(i + 1)
|
||||||
|
|
||||||
# # test 2 last frames of second episode
|
# # test 2 last frames of second episode
|
||||||
# i = dataset.episode_data_index["to"][1].item()
|
# i = dataset.meta.episodes["dataset_to_index"][1].item()
|
||||||
# load_and_compare(i - 2)
|
# load_and_compare(i - 2)
|
||||||
# load_and_compare(i - 1)
|
# load_and_compare(i - 1)
|
||||||
|
|
||||||
# # test 2 last frames of last episode
|
# # test 2 last frames of last episode
|
||||||
# i = dataset.episode_data_index["to"][-1].item()
|
# i = dataset.meta.episodes["dataset_to_index"][-1].item()
|
||||||
# load_and_compare(i - 2)
|
# load_and_compare(i - 2)
|
||||||
# load_and_compare(i - 1)
|
# load_and_compare(i - 1)
|
||||||
|
|
||||||
@@ -554,20 +574,3 @@ def test_create_branch():
|
|||||||
|
|
||||||
# Clean
|
# Clean
|
||||||
api.delete_repo(repo_id, repo_type=repo_type)
|
api.delete_repo(repo_id, repo_type=repo_type)
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_feature_with_forward_slash_raises_error():
|
|
||||||
# make sure dir does not exist
|
|
||||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
|
||||||
|
|
||||||
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
|
|
||||||
# make sure does not exist
|
|
||||||
if dataset_dir.exists():
|
|
||||||
dataset_dir.rmdir()
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
LeRobotDataset.create(
|
|
||||||
repo_id="lerobot/test/with/slash",
|
|
||||||
fps=30,
|
|
||||||
features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ def test_drop_n_first_frames():
|
|||||||
)
|
)
|
||||||
dataset.set_transform(hf_transform_to_torch)
|
dataset.set_transform(hf_transform_to_torch)
|
||||||
episode_data_index = calculate_episode_data_index(dataset)
|
episode_data_index = calculate_episode_data_index(dataset)
|
||||||
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
|
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], drop_n_first_frames=1)
|
||||||
assert sampler.indices == [1, 4, 5]
|
assert sampler.indices == [1, 4, 5]
|
||||||
assert len(sampler) == 3
|
assert len(sampler) == 3
|
||||||
assert list(sampler) == [1, 4, 5]
|
assert list(sampler) == [1, 4, 5]
|
||||||
@@ -48,7 +48,7 @@ def test_drop_n_last_frames():
|
|||||||
)
|
)
|
||||||
dataset.set_transform(hf_transform_to_torch)
|
dataset.set_transform(hf_transform_to_torch)
|
||||||
episode_data_index = calculate_episode_data_index(dataset)
|
episode_data_index = calculate_episode_data_index(dataset)
|
||||||
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
|
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], drop_n_last_frames=1)
|
||||||
assert sampler.indices == [0, 3, 4]
|
assert sampler.indices == [0, 3, 4]
|
||||||
assert len(sampler) == 3
|
assert len(sampler) == 3
|
||||||
assert list(sampler) == [0, 3, 4]
|
assert list(sampler) == [0, 3, 4]
|
||||||
@@ -64,7 +64,9 @@ def test_episode_indices_to_use():
|
|||||||
)
|
)
|
||||||
dataset.set_transform(hf_transform_to_torch)
|
dataset.set_transform(hf_transform_to_torch)
|
||||||
episode_data_index = calculate_episode_data_index(dataset)
|
episode_data_index = calculate_episode_data_index(dataset)
|
||||||
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
|
sampler = EpisodeAwareSampler(
|
||||||
|
episode_data_index["from"], episode_data_index["to"], episode_indices_to_use=[0, 2]
|
||||||
|
)
|
||||||
assert sampler.indices == [0, 1, 3, 4, 5]
|
assert sampler.indices == [0, 1, 3, 4, 5]
|
||||||
assert len(sampler) == 5
|
assert len(sampler) == 5
|
||||||
assert list(sampler) == [0, 1, 3, 4, 5]
|
assert list(sampler) == [0, 1, 3, 4, 5]
|
||||||
@@ -80,11 +82,11 @@ def test_shuffle():
|
|||||||
)
|
)
|
||||||
dataset.set_transform(hf_transform_to_torch)
|
dataset.set_transform(hf_transform_to_torch)
|
||||||
episode_data_index = calculate_episode_data_index(dataset)
|
episode_data_index = calculate_episode_data_index(dataset)
|
||||||
sampler = EpisodeAwareSampler(episode_data_index, shuffle=False)
|
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], shuffle=False)
|
||||||
assert sampler.indices == [0, 1, 2, 3, 4, 5]
|
assert sampler.indices == [0, 1, 2, 3, 4, 5]
|
||||||
assert len(sampler) == 6
|
assert len(sampler) == 6
|
||||||
assert list(sampler) == [0, 1, 2, 3, 4, 5]
|
assert list(sampler) == [0, 1, 2, 3, 4, 5]
|
||||||
sampler = EpisodeAwareSampler(episode_data_index, shuffle=True)
|
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], shuffle=True)
|
||||||
assert sampler.indices == [0, 1, 2, 3, 4, 5]
|
assert sampler.indices == [0, 1, 2, 3, 4, 5]
|
||||||
assert len(sampler) == 6
|
assert len(sampler) == 6
|
||||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||||
|
|||||||
@@ -14,12 +14,20 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from huggingface_hub import DatasetCard
|
from huggingface_hub import DatasetCard
|
||||||
|
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||||
from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
|
from lerobot.common.datasets.utils import (
|
||||||
|
create_lerobot_dataset_card,
|
||||||
|
flatten_dict,
|
||||||
|
hf_transform_to_torch,
|
||||||
|
unflatten_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_default_parameters():
|
def test_default_parameters():
|
||||||
@@ -53,3 +61,26 @@ def test_calculate_episode_data_index():
|
|||||||
episode_data_index = calculate_episode_data_index(dataset)
|
episode_data_index = calculate_episode_data_index(dataset)
|
||||||
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
|
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
|
||||||
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
|
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_flatten_unflatten_dict():
|
||||||
|
d = {
|
||||||
|
"obs": {
|
||||||
|
"min": 0,
|
||||||
|
"max": 1,
|
||||||
|
"mean": 2,
|
||||||
|
"std": 3,
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"min": 4,
|
||||||
|
"max": 5,
|
||||||
|
"mean": 6,
|
||||||
|
"std": 7,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
original_d = deepcopy(d)
|
||||||
|
d = unflatten_dict(flatten_dict(d))
|
||||||
|
|
||||||
|
# test equality between nested dicts
|
||||||
|
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
|
||||||
|
|||||||
4
tests/fixtures/constants.py
vendored
4
tests/fixtures/constants.py
vendored
@@ -29,8 +29,8 @@ DUMMY_MOTOR_FEATURES = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
DUMMY_CAMERA_FEATURES = {
|
DUMMY_CAMERA_FEATURES = {
|
||||||
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
"laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
|
||||||
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
"phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
|
||||||
}
|
}
|
||||||
DEFAULT_FPS = 30
|
DEFAULT_FPS = 30
|
||||||
DUMMY_VIDEO_INFO = {
|
DUMMY_VIDEO_INFO = {
|
||||||
|
|||||||
312
tests/fixtures/dataset_factories.py
vendored
312
tests/fixtures/dataset_factories.py
vendored
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import random
|
import random
|
||||||
|
import shutil
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
@@ -19,19 +20,25 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from datasets import Dataset
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
DEFAULT_CHUNK_SIZE,
|
DEFAULT_CHUNK_SIZE,
|
||||||
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
|
DEFAULT_DATA_PATH,
|
||||||
DEFAULT_FEATURES,
|
DEFAULT_FEATURES,
|
||||||
DEFAULT_PARQUET_PATH,
|
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_VIDEO_PATH,
|
DEFAULT_VIDEO_PATH,
|
||||||
|
flatten_dict,
|
||||||
get_hf_features_from_features,
|
get_hf_features_from_features,
|
||||||
hf_transform_to_torch,
|
hf_transform_to_torch,
|
||||||
)
|
)
|
||||||
|
from lerobot.common.datasets.video_utils import encode_video_frames
|
||||||
from tests.fixtures.constants import (
|
from tests.fixtures.constants import (
|
||||||
DEFAULT_FPS,
|
DEFAULT_FPS,
|
||||||
DUMMY_CAMERA_FEATURES,
|
DUMMY_CAMERA_FEATURES,
|
||||||
@@ -46,10 +53,10 @@ class LeRobotDatasetFactory(Protocol):
|
|||||||
def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
|
def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
|
||||||
|
|
||||||
|
|
||||||
def get_task_index(task_dicts: dict, task: str) -> int:
|
def get_task_index(tasks: datasets.Dataset, task: str) -> int:
|
||||||
tasks = {d["task_index"]: d["task"] for d in task_dicts.values()}
|
# TODO(rcadene): a bit complicated no? ^^
|
||||||
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
|
task_idx = tasks.loc[task].task_index.item()
|
||||||
return task_to_task_index[task]
|
return task_idx
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
@@ -62,15 +69,49 @@ def img_tensor_factory():
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def img_array_factory():
|
def img_array_factory():
|
||||||
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
|
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8, content=None) -> np.ndarray:
|
||||||
if np.issubdtype(dtype, np.unsignedinteger):
|
if content is None:
|
||||||
# Int array in [0, 255] range
|
# Original random noise behavior
|
||||||
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
|
if np.issubdtype(dtype, np.unsignedinteger):
|
||||||
elif np.issubdtype(dtype, np.floating):
|
# Int array in [0, 255] range
|
||||||
# Float array in [0, 1] range
|
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
|
||||||
img_array = np.random.rand(height, width, channels).astype(dtype)
|
elif np.issubdtype(dtype, np.floating):
|
||||||
|
# Float array in [0, 1] range
|
||||||
|
img_array = np.random.rand(height, width, channels).astype(dtype)
|
||||||
|
else:
|
||||||
|
raise ValueError(dtype)
|
||||||
else:
|
else:
|
||||||
raise ValueError(dtype)
|
# Create image with text content using OpenCV
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
# Create white background
|
||||||
|
img_array = np.ones((height, width, channels), dtype=np.uint8) * 255
|
||||||
|
|
||||||
|
# Font settings
|
||||||
|
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||||
|
font_scale = max(0.5, height / 200) # Scale font with image size
|
||||||
|
font_color = (0, 0, 0) # Black text
|
||||||
|
thickness = max(1, int(height / 100))
|
||||||
|
|
||||||
|
# Get text size to center it
|
||||||
|
text_size = cv2.getTextSize(content, font, font_scale, thickness)[0]
|
||||||
|
text_x = (width - text_size[0]) // 2
|
||||||
|
text_y = (height + text_size[1]) // 2
|
||||||
|
|
||||||
|
# Put text on image
|
||||||
|
cv2.putText(img_array, content, (text_x, text_y), font, font_scale, font_color, thickness)
|
||||||
|
|
||||||
|
# Handle single channel case
|
||||||
|
if channels == 1:
|
||||||
|
img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2GRAY)
|
||||||
|
img_array = img_array[:, :, np.newaxis]
|
||||||
|
|
||||||
|
# Convert to target dtype
|
||||||
|
if np.issubdtype(dtype, np.floating):
|
||||||
|
img_array = img_array.astype(dtype) / 255.0
|
||||||
|
else:
|
||||||
|
img_array = img_array.astype(dtype)
|
||||||
|
|
||||||
return img_array
|
return img_array
|
||||||
|
|
||||||
return _create_img_array
|
return _create_img_array
|
||||||
@@ -117,9 +158,10 @@ def info_factory(features_factory):
|
|||||||
total_frames: int = 0,
|
total_frames: int = 0,
|
||||||
total_tasks: int = 0,
|
total_tasks: int = 0,
|
||||||
total_videos: int = 0,
|
total_videos: int = 0,
|
||||||
total_chunks: int = 0,
|
|
||||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||||
data_path: str = DEFAULT_PARQUET_PATH,
|
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
|
video_files_size_in_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||||
|
data_path: str = DEFAULT_DATA_PATH,
|
||||||
video_path: str = DEFAULT_VIDEO_PATH,
|
video_path: str = DEFAULT_VIDEO_PATH,
|
||||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||||
@@ -133,8 +175,9 @@ def info_factory(features_factory):
|
|||||||
"total_frames": total_frames,
|
"total_frames": total_frames,
|
||||||
"total_tasks": total_tasks,
|
"total_tasks": total_tasks,
|
||||||
"total_videos": total_videos,
|
"total_videos": total_videos,
|
||||||
"total_chunks": total_chunks,
|
|
||||||
"chunks_size": chunks_size,
|
"chunks_size": chunks_size,
|
||||||
|
"data_files_size_in_mb": data_files_size_in_mb,
|
||||||
|
"video_files_size_in_mb": video_files_size_in_mb,
|
||||||
"fps": fps,
|
"fps": fps,
|
||||||
"splits": {},
|
"splits": {},
|
||||||
"data_path": data_path,
|
"data_path": data_path,
|
||||||
@@ -175,41 +218,45 @@ def stats_factory():
|
|||||||
return _create_stats
|
return _create_stats
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
# @pytest.fixture(scope="session")
|
||||||
def episodes_stats_factory(stats_factory):
|
# def episodes_stats_factory(stats_factory):
|
||||||
def _create_episodes_stats(
|
# def _create_episodes_stats(
|
||||||
features: dict[str],
|
# features: dict[str],
|
||||||
total_episodes: int = 3,
|
# total_episodes: int = 3,
|
||||||
) -> dict:
|
# ) -> dict:
|
||||||
episodes_stats = {}
|
|
||||||
for episode_index in range(total_episodes):
|
|
||||||
episodes_stats[episode_index] = {
|
|
||||||
"episode_index": episode_index,
|
|
||||||
"stats": stats_factory(features),
|
|
||||||
}
|
|
||||||
return episodes_stats
|
|
||||||
|
|
||||||
return _create_episodes_stats
|
# def _generator(total_episodes):
|
||||||
|
# for ep_idx in range(total_episodes):
|
||||||
|
# flat_ep_stats = flatten_dict(stats_factory(features))
|
||||||
|
# flat_ep_stats["episode_index"] = ep_idx
|
||||||
|
# yield flat_ep_stats
|
||||||
|
|
||||||
|
# # Simpler to rely on generator instead of from_dict
|
||||||
|
# return Dataset.from_generator(lambda: _generator(total_episodes))
|
||||||
|
|
||||||
|
# return _create_episodes_stats
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def tasks_factory():
|
def tasks_factory():
|
||||||
def _create_tasks(total_tasks: int = 3) -> int:
|
def _create_tasks(total_tasks: int = 3) -> pd.DataFrame:
|
||||||
tasks = {}
|
ids = list(range(total_tasks))
|
||||||
for task_index in range(total_tasks):
|
tasks = [f"Perform action {i}." for i in ids]
|
||||||
task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
|
df = pd.DataFrame({"task_index": ids}, index=tasks)
|
||||||
tasks[task_index] = task_dict
|
return df
|
||||||
return tasks
|
|
||||||
|
|
||||||
return _create_tasks
|
return _create_tasks
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def episodes_factory(tasks_factory):
|
def episodes_factory(tasks_factory, stats_factory):
|
||||||
def _create_episodes(
|
def _create_episodes(
|
||||||
|
features: dict[str],
|
||||||
|
fps: int = DEFAULT_FPS,
|
||||||
total_episodes: int = 3,
|
total_episodes: int = 3,
|
||||||
total_frames: int = 400,
|
total_frames: int = 400,
|
||||||
tasks: dict | None = None,
|
video_keys: list[str] | None = None,
|
||||||
|
tasks: pd.DataFrame | None = None,
|
||||||
multi_task: bool = False,
|
multi_task: bool = False,
|
||||||
):
|
):
|
||||||
if total_episodes <= 0 or total_frames <= 0:
|
if total_episodes <= 0 or total_frames <= 0:
|
||||||
@@ -217,66 +264,142 @@ def episodes_factory(tasks_factory):
|
|||||||
if total_frames < total_episodes:
|
if total_frames < total_episodes:
|
||||||
raise ValueError("total_length must be greater than or equal to num_episodes.")
|
raise ValueError("total_length must be greater than or equal to num_episodes.")
|
||||||
|
|
||||||
if not tasks:
|
if tasks is None:
|
||||||
min_tasks = 2 if multi_task else 1
|
min_tasks = 2 if multi_task else 1
|
||||||
total_tasks = random.randint(min_tasks, total_episodes)
|
total_tasks = random.randint(min_tasks, total_episodes)
|
||||||
tasks = tasks_factory(total_tasks)
|
tasks = tasks_factory(total_tasks)
|
||||||
|
|
||||||
if total_episodes < len(tasks) and not multi_task:
|
num_tasks_available = len(tasks)
|
||||||
|
|
||||||
|
if total_episodes < num_tasks_available and not multi_task:
|
||||||
raise ValueError("The number of tasks should be less than the number of episodes.")
|
raise ValueError("The number of tasks should be less than the number of episodes.")
|
||||||
|
|
||||||
# Generate random lengths that sum up to total_length
|
# Generate random lengths that sum up to total_length
|
||||||
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
|
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
|
||||||
|
|
||||||
tasks_list = [task_dict["task"] for task_dict in tasks.values()]
|
# Create empty dictionaries with all keys
|
||||||
num_tasks_available = len(tasks_list)
|
d = {
|
||||||
|
"episode_index": [],
|
||||||
|
"meta/episodes/chunk_index": [],
|
||||||
|
"meta/episodes/file_index": [],
|
||||||
|
"data/chunk_index": [],
|
||||||
|
"data/file_index": [],
|
||||||
|
"dataset_from_index": [],
|
||||||
|
"dataset_to_index": [],
|
||||||
|
"tasks": [],
|
||||||
|
"length": [],
|
||||||
|
}
|
||||||
|
if video_keys is not None:
|
||||||
|
for video_key in video_keys:
|
||||||
|
d[f"videos/{video_key}/chunk_index"] = []
|
||||||
|
d[f"videos/{video_key}/file_index"] = []
|
||||||
|
d[f"videos/{video_key}/from_timestamp"] = []
|
||||||
|
d[f"videos/{video_key}/to_timestamp"] = []
|
||||||
|
|
||||||
episodes = {}
|
for stats_key in flatten_dict({"stats": stats_factory(features)}):
|
||||||
remaining_tasks = tasks_list.copy()
|
d[stats_key] = []
|
||||||
|
|
||||||
|
num_frames = 0
|
||||||
|
remaining_tasks = list(tasks.index)
|
||||||
for ep_idx in range(total_episodes):
|
for ep_idx in range(total_episodes):
|
||||||
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
|
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
|
||||||
tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
|
tasks_to_sample = remaining_tasks if len(remaining_tasks) > 0 else list(tasks.index)
|
||||||
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
|
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
|
||||||
if remaining_tasks:
|
if remaining_tasks:
|
||||||
for task in episode_tasks:
|
for task in episode_tasks:
|
||||||
remaining_tasks.remove(task)
|
remaining_tasks.remove(task)
|
||||||
|
|
||||||
episodes[ep_idx] = {
|
d["episode_index"].append(ep_idx)
|
||||||
"episode_index": ep_idx,
|
# TODO(rcadene): remove heuristic of only one file
|
||||||
"tasks": episode_tasks,
|
d["meta/episodes/chunk_index"].append(0)
|
||||||
"length": lengths[ep_idx],
|
d["meta/episodes/file_index"].append(0)
|
||||||
}
|
d["data/chunk_index"].append(0)
|
||||||
|
d["data/file_index"].append(0)
|
||||||
|
d["dataset_from_index"].append(num_frames)
|
||||||
|
d["dataset_to_index"].append(num_frames + lengths[ep_idx])
|
||||||
|
d["tasks"].append(episode_tasks)
|
||||||
|
d["length"].append(lengths[ep_idx])
|
||||||
|
|
||||||
return episodes
|
if video_keys is not None:
|
||||||
|
for video_key in video_keys:
|
||||||
|
d[f"videos/{video_key}/chunk_index"].append(0)
|
||||||
|
d[f"videos/{video_key}/file_index"].append(0)
|
||||||
|
d[f"videos/{video_key}/from_timestamp"].append(num_frames / fps)
|
||||||
|
d[f"videos/{video_key}/to_timestamp"].append((num_frames + lengths[ep_idx]) / fps)
|
||||||
|
|
||||||
|
# Add stats columns like "stats/action/max"
|
||||||
|
for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():
|
||||||
|
d[stats_key].append(stats)
|
||||||
|
|
||||||
|
num_frames += lengths[ep_idx]
|
||||||
|
|
||||||
|
return Dataset.from_dict(d)
|
||||||
|
|
||||||
return _create_episodes
|
return _create_episodes
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def create_videos(info_factory, img_array_factory):
|
||||||
|
def _create_video_directory(
|
||||||
|
root: Path,
|
||||||
|
info: dict | None = None,
|
||||||
|
total_episodes: int = 3,
|
||||||
|
total_frames: int = 150,
|
||||||
|
total_tasks: int = 1,
|
||||||
|
):
|
||||||
|
if info is None:
|
||||||
|
info = info_factory(
|
||||||
|
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
|
||||||
|
)
|
||||||
|
|
||||||
|
video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"}
|
||||||
|
for key, ft in video_feats.items():
|
||||||
|
# create and save images with identifiable content
|
||||||
|
tmp_dir = root / "tmp_images"
|
||||||
|
tmp_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
for frame_index in range(info["total_frames"]):
|
||||||
|
content = f"{key}-{frame_index}"
|
||||||
|
img = img_array_factory(height=ft["shape"][0], width=ft["shape"][1], content=content)
|
||||||
|
pil_img = PIL.Image.fromarray(img)
|
||||||
|
path = tmp_dir / f"frame-{frame_index:06d}.png"
|
||||||
|
pil_img.save(path)
|
||||||
|
|
||||||
|
video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0)
|
||||||
|
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
# Use the global fps from info, not video-specific fps which might not exist
|
||||||
|
encode_video_frames(tmp_dir, video_path, fps=info["fps"])
|
||||||
|
shutil.rmtree(tmp_dir)
|
||||||
|
|
||||||
|
return _create_video_directory
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
|
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
|
||||||
def _create_hf_dataset(
|
def _create_hf_dataset(
|
||||||
features: dict | None = None,
|
features: dict | None = None,
|
||||||
tasks: list[dict] | None = None,
|
tasks: pd.DataFrame | None = None,
|
||||||
episodes: list[dict] | None = None,
|
episodes: datasets.Dataset | None = None,
|
||||||
fps: int = DEFAULT_FPS,
|
fps: int = DEFAULT_FPS,
|
||||||
) -> datasets.Dataset:
|
) -> datasets.Dataset:
|
||||||
if not tasks:
|
if tasks is None:
|
||||||
tasks = tasks_factory()
|
tasks = tasks_factory()
|
||||||
if not episodes:
|
if features is None:
|
||||||
episodes = episodes_factory()
|
|
||||||
if not features:
|
|
||||||
features = features_factory()
|
features = features_factory()
|
||||||
|
if episodes is None:
|
||||||
|
episodes = episodes_factory(features, fps)
|
||||||
|
|
||||||
timestamp_col = np.array([], dtype=np.float32)
|
timestamp_col = np.array([], dtype=np.float32)
|
||||||
frame_index_col = np.array([], dtype=np.int64)
|
frame_index_col = np.array([], dtype=np.int64)
|
||||||
episode_index_col = np.array([], dtype=np.int64)
|
episode_index_col = np.array([], dtype=np.int64)
|
||||||
task_index = np.array([], dtype=np.int64)
|
task_index = np.array([], dtype=np.int64)
|
||||||
for ep_dict in episodes.values():
|
for ep_dict in episodes:
|
||||||
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
|
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
|
||||||
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
|
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
|
||||||
episode_index_col = np.concatenate(
|
episode_index_col = np.concatenate(
|
||||||
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
|
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
|
||||||
)
|
)
|
||||||
|
# Slightly incorrect, but for simplicity, we assign to all frames the first task defined in the episode metadata.
|
||||||
|
# TODO(rcadene): assign the tasks of the episode per chunks of frames
|
||||||
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
|
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
|
||||||
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
|
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
|
||||||
|
|
||||||
@@ -286,8 +409,8 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
|||||||
for key, ft in features.items():
|
for key, ft in features.items():
|
||||||
if ft["dtype"] == "image":
|
if ft["dtype"] == "image":
|
||||||
robot_cols[key] = [
|
robot_cols[key] = [
|
||||||
img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0])
|
img_array_factory(height=ft["shape"][1], width=ft["shape"][0], content=f"{key}-{i}")
|
||||||
for _ in range(len(index_col))
|
for i in range(len(index_col))
|
||||||
]
|
]
|
||||||
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
|
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
|
||||||
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
|
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
|
||||||
@@ -314,7 +437,6 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
|||||||
def lerobot_dataset_metadata_factory(
|
def lerobot_dataset_metadata_factory(
|
||||||
info_factory,
|
info_factory,
|
||||||
stats_factory,
|
stats_factory,
|
||||||
episodes_stats_factory,
|
|
||||||
tasks_factory,
|
tasks_factory,
|
||||||
episodes_factory,
|
episodes_factory,
|
||||||
mock_snapshot_download_factory,
|
mock_snapshot_download_factory,
|
||||||
@@ -324,29 +446,29 @@ def lerobot_dataset_metadata_factory(
|
|||||||
repo_id: str = DUMMY_REPO_ID,
|
repo_id: str = DUMMY_REPO_ID,
|
||||||
info: dict | None = None,
|
info: dict | None = None,
|
||||||
stats: dict | None = None,
|
stats: dict | None = None,
|
||||||
episodes_stats: list[dict] | None = None,
|
tasks: pd.DataFrame | None = None,
|
||||||
tasks: list[dict] | None = None,
|
episodes: datasets.Dataset | None = None,
|
||||||
episodes: list[dict] | None = None,
|
|
||||||
) -> LeRobotDatasetMetadata:
|
) -> LeRobotDatasetMetadata:
|
||||||
if not info:
|
if info is None:
|
||||||
info = info_factory()
|
info = info_factory()
|
||||||
if not stats:
|
if stats is None:
|
||||||
stats = stats_factory(features=info["features"])
|
stats = stats_factory(features=info["features"])
|
||||||
if not episodes_stats:
|
if tasks is None:
|
||||||
episodes_stats = episodes_stats_factory(
|
|
||||||
features=info["features"], total_episodes=info["total_episodes"]
|
|
||||||
)
|
|
||||||
if not tasks:
|
|
||||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||||
if not episodes:
|
if episodes is None:
|
||||||
|
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
|
||||||
episodes = episodes_factory(
|
episodes = episodes_factory(
|
||||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
features=info["features"],
|
||||||
|
fps=info["fps"],
|
||||||
|
total_episodes=info["total_episodes"],
|
||||||
|
total_frames=info["total_frames"],
|
||||||
|
video_keys=video_keys,
|
||||||
|
tasks=tasks,
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_snapshot_download = mock_snapshot_download_factory(
|
mock_snapshot_download = mock_snapshot_download_factory(
|
||||||
info=info,
|
info=info,
|
||||||
stats=stats,
|
stats=stats,
|
||||||
episodes_stats=episodes_stats,
|
|
||||||
tasks=tasks,
|
tasks=tasks,
|
||||||
episodes=episodes,
|
episodes=episodes,
|
||||||
)
|
)
|
||||||
@@ -368,7 +490,6 @@ def lerobot_dataset_metadata_factory(
|
|||||||
def lerobot_dataset_factory(
|
def lerobot_dataset_factory(
|
||||||
info_factory,
|
info_factory,
|
||||||
stats_factory,
|
stats_factory,
|
||||||
episodes_stats_factory,
|
|
||||||
tasks_factory,
|
tasks_factory,
|
||||||
episodes_factory,
|
episodes_factory,
|
||||||
hf_dataset_factory,
|
hf_dataset_factory,
|
||||||
@@ -382,40 +503,48 @@ def lerobot_dataset_factory(
|
|||||||
total_frames: int = 150,
|
total_frames: int = 150,
|
||||||
total_tasks: int = 1,
|
total_tasks: int = 1,
|
||||||
multi_task: bool = False,
|
multi_task: bool = False,
|
||||||
|
use_videos: bool = True,
|
||||||
info: dict | None = None,
|
info: dict | None = None,
|
||||||
stats: dict | None = None,
|
stats: dict | None = None,
|
||||||
episodes_stats: list[dict] | None = None,
|
tasks: pd.DataFrame | None = None,
|
||||||
tasks: list[dict] | None = None,
|
episodes_metadata: datasets.Dataset | None = None,
|
||||||
episode_dicts: list[dict] | None = None,
|
|
||||||
hf_dataset: datasets.Dataset | None = None,
|
hf_dataset: datasets.Dataset | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LeRobotDataset:
|
) -> LeRobotDataset:
|
||||||
if not info:
|
# Instantiate objects
|
||||||
|
if info is None:
|
||||||
info = info_factory(
|
info = info_factory(
|
||||||
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
|
total_episodes=total_episodes,
|
||||||
|
total_frames=total_frames,
|
||||||
|
total_tasks=total_tasks,
|
||||||
|
use_videos=use_videos,
|
||||||
)
|
)
|
||||||
if not stats:
|
if stats is None:
|
||||||
stats = stats_factory(features=info["features"])
|
stats = stats_factory(features=info["features"])
|
||||||
if not episodes_stats:
|
if tasks is None:
|
||||||
episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes)
|
|
||||||
if not tasks:
|
|
||||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||||
if not episode_dicts:
|
if episodes_metadata is None:
|
||||||
episode_dicts = episodes_factory(
|
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
|
||||||
|
episodes_metadata = episodes_factory(
|
||||||
|
features=info["features"],
|
||||||
|
fps=info["fps"],
|
||||||
total_episodes=info["total_episodes"],
|
total_episodes=info["total_episodes"],
|
||||||
total_frames=info["total_frames"],
|
total_frames=info["total_frames"],
|
||||||
|
video_keys=video_keys,
|
||||||
tasks=tasks,
|
tasks=tasks,
|
||||||
multi_task=multi_task,
|
multi_task=multi_task,
|
||||||
)
|
)
|
||||||
if not hf_dataset:
|
if hf_dataset is None:
|
||||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
|
hf_dataset = hf_dataset_factory(
|
||||||
|
features=info["features"], tasks=tasks, episodes=episodes_metadata, fps=info["fps"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Write data on disk
|
||||||
mock_snapshot_download = mock_snapshot_download_factory(
|
mock_snapshot_download = mock_snapshot_download_factory(
|
||||||
info=info,
|
info=info,
|
||||||
stats=stats,
|
stats=stats,
|
||||||
episodes_stats=episodes_stats,
|
|
||||||
tasks=tasks,
|
tasks=tasks,
|
||||||
episodes=episode_dicts,
|
episodes=episodes_metadata,
|
||||||
hf_dataset=hf_dataset,
|
hf_dataset=hf_dataset,
|
||||||
)
|
)
|
||||||
mock_metadata = lerobot_dataset_metadata_factory(
|
mock_metadata = lerobot_dataset_metadata_factory(
|
||||||
@@ -423,9 +552,8 @@ def lerobot_dataset_factory(
|
|||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
info=info,
|
info=info,
|
||||||
stats=stats,
|
stats=stats,
|
||||||
episodes_stats=episodes_stats,
|
|
||||||
tasks=tasks,
|
tasks=tasks,
|
||||||
episodes=episode_dicts,
|
episodes=episodes_metadata,
|
||||||
)
|
)
|
||||||
with (
|
with (
|
||||||
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
||||||
|
|||||||
110
tests/fixtures/files.py
vendored
110
tests/fixtures/files.py
vendored
@@ -11,92 +11,82 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import json
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import jsonlines
|
import pandas as pd
|
||||||
import pyarrow.compute as pc
|
import pyarrow.compute as pc
|
||||||
import pyarrow.parquet as pq
|
import pyarrow.parquet as pq
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
EPISODES_PATH,
|
write_episodes,
|
||||||
EPISODES_STATS_PATH,
|
write_hf_dataset,
|
||||||
INFO_PATH,
|
write_info,
|
||||||
STATS_PATH,
|
write_stats,
|
||||||
TASKS_PATH,
|
write_tasks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def info_path(info_factory):
|
def create_info(info_factory):
|
||||||
def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
|
def _create_info(dir: Path, info: dict | None = None):
|
||||||
if not info:
|
if info is None:
|
||||||
info = info_factory()
|
info = info_factory()
|
||||||
fpath = dir / INFO_PATH
|
write_info(info, dir)
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(fpath, "w") as f:
|
|
||||||
json.dump(info, f, indent=4, ensure_ascii=False)
|
|
||||||
return fpath
|
|
||||||
|
|
||||||
return _create_info_json_file
|
return _create_info
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def stats_path(stats_factory):
|
def create_stats(stats_factory):
|
||||||
def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
|
def _create_stats(dir: Path, stats: dict | None = None):
|
||||||
if not stats:
|
if stats is None:
|
||||||
stats = stats_factory()
|
stats = stats_factory()
|
||||||
fpath = dir / STATS_PATH
|
write_stats(stats, dir)
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(fpath, "w") as f:
|
|
||||||
json.dump(stats, f, indent=4, ensure_ascii=False)
|
|
||||||
return fpath
|
|
||||||
|
|
||||||
return _create_stats_json_file
|
return _create_stats
|
||||||
|
|
||||||
|
|
||||||
|
# @pytest.fixture(scope="session")
|
||||||
|
# def create_episodes_stats(episodes_stats_factory):
|
||||||
|
# def _create_episodes_stats(dir: Path, episodes_stats: Dataset | None = None):
|
||||||
|
# if episodes_stats is None:
|
||||||
|
# episodes_stats = episodes_stats_factory()
|
||||||
|
# write_episodes_stats(episodes_stats, dir)
|
||||||
|
|
||||||
|
# return _create_episodes_stats
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def episodes_stats_path(episodes_stats_factory):
|
def create_tasks(tasks_factory):
|
||||||
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
|
def _create_tasks(dir: Path, tasks: pd.DataFrame | None = None):
|
||||||
if not episodes_stats:
|
if tasks is None:
|
||||||
episodes_stats = episodes_stats_factory()
|
|
||||||
fpath = dir / EPISODES_STATS_PATH
|
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
with jsonlines.open(fpath, "w") as writer:
|
|
||||||
writer.write_all(episodes_stats.values())
|
|
||||||
return fpath
|
|
||||||
|
|
||||||
return _create_episodes_stats_jsonl_file
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def tasks_path(tasks_factory):
|
|
||||||
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
|
|
||||||
if not tasks:
|
|
||||||
tasks = tasks_factory()
|
tasks = tasks_factory()
|
||||||
fpath = dir / TASKS_PATH
|
write_tasks(tasks, dir)
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
with jsonlines.open(fpath, "w") as writer:
|
|
||||||
writer.write_all(tasks.values())
|
|
||||||
return fpath
|
|
||||||
|
|
||||||
return _create_tasks_jsonl_file
|
return _create_tasks
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def episode_path(episodes_factory):
|
def create_episodes(episodes_factory):
|
||||||
def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
|
def _create_episodes(dir: Path, episodes: datasets.Dataset | None = None):
|
||||||
if not episodes:
|
if episodes is None:
|
||||||
|
# TODO(rcadene): add features, fps as arguments
|
||||||
episodes = episodes_factory()
|
episodes = episodes_factory()
|
||||||
fpath = dir / EPISODES_PATH
|
write_episodes(episodes, dir)
|
||||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
with jsonlines.open(fpath, "w") as writer:
|
|
||||||
writer.write_all(episodes.values())
|
|
||||||
return fpath
|
|
||||||
|
|
||||||
return _create_episodes_jsonl_file
|
return _create_episodes
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def create_hf_dataset(hf_dataset_factory):
|
||||||
|
def _create_hf_dataset(dir: Path, hf_dataset: datasets.Dataset | None = None):
|
||||||
|
if hf_dataset is None:
|
||||||
|
hf_dataset = hf_dataset_factory()
|
||||||
|
write_hf_dataset(hf_dataset, dir)
|
||||||
|
|
||||||
|
return _create_hf_dataset
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
@@ -104,7 +94,8 @@ def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
|||||||
def _create_single_episode_parquet(
|
def _create_single_episode_parquet(
|
||||||
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||||
) -> Path:
|
) -> Path:
|
||||||
if not info:
|
raise NotImplementedError()
|
||||||
|
if info is None:
|
||||||
info = info_factory()
|
info = info_factory()
|
||||||
if hf_dataset is None:
|
if hf_dataset is None:
|
||||||
hf_dataset = hf_dataset_factory()
|
hf_dataset = hf_dataset_factory()
|
||||||
@@ -127,7 +118,8 @@ def multi_episode_parquet_path(hf_dataset_factory, info_factory):
|
|||||||
def _create_multi_episode_parquet(
|
def _create_multi_episode_parquet(
|
||||||
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||||
) -> Path:
|
) -> Path:
|
||||||
if not info:
|
raise NotImplementedError()
|
||||||
|
if info is None:
|
||||||
info = info_factory()
|
info = info_factory()
|
||||||
if hf_dataset is None:
|
if hf_dataset is None:
|
||||||
hf_dataset = hf_dataset_factory()
|
hf_dataset = hf_dataset_factory()
|
||||||
|
|||||||
128
tests/fixtures/hub.py
vendored
128
tests/fixtures/hub.py
vendored
@@ -14,15 +14,17 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
|
import pandas as pd
|
||||||
import pytest
|
import pytest
|
||||||
from huggingface_hub.utils import filter_repo_objects
|
from huggingface_hub.utils import filter_repo_objects
|
||||||
|
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
EPISODES_PATH,
|
DEFAULT_DATA_PATH,
|
||||||
EPISODES_STATS_PATH,
|
DEFAULT_EPISODES_PATH,
|
||||||
|
DEFAULT_TASKS_PATH,
|
||||||
|
DEFAULT_VIDEO_PATH,
|
||||||
INFO_PATH,
|
INFO_PATH,
|
||||||
STATS_PATH,
|
STATS_PATH,
|
||||||
TASKS_PATH,
|
|
||||||
)
|
)
|
||||||
from tests.fixtures.constants import LEROBOT_TEST_DIR
|
from tests.fixtures.constants import LEROBOT_TEST_DIR
|
||||||
|
|
||||||
@@ -30,17 +32,16 @@ from tests.fixtures.constants import LEROBOT_TEST_DIR
|
|||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def mock_snapshot_download_factory(
|
def mock_snapshot_download_factory(
|
||||||
info_factory,
|
info_factory,
|
||||||
info_path,
|
create_info,
|
||||||
stats_factory,
|
stats_factory,
|
||||||
stats_path,
|
create_stats,
|
||||||
episodes_stats_factory,
|
|
||||||
episodes_stats_path,
|
|
||||||
tasks_factory,
|
tasks_factory,
|
||||||
tasks_path,
|
create_tasks,
|
||||||
episodes_factory,
|
episodes_factory,
|
||||||
episode_path,
|
create_episodes,
|
||||||
single_episode_parquet_path,
|
|
||||||
hf_dataset_factory,
|
hf_dataset_factory,
|
||||||
|
create_hf_dataset,
|
||||||
|
create_videos,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
This factory allows to patch snapshot_download such that when called, it will create expected files rather
|
This factory allows to patch snapshot_download such that when called, it will create expected files rather
|
||||||
@@ -50,82 +51,91 @@ def mock_snapshot_download_factory(
|
|||||||
def _mock_snapshot_download_func(
|
def _mock_snapshot_download_func(
|
||||||
info: dict | None = None,
|
info: dict | None = None,
|
||||||
stats: dict | None = None,
|
stats: dict | None = None,
|
||||||
episodes_stats: list[dict] | None = None,
|
tasks: pd.DataFrame | None = None,
|
||||||
tasks: list[dict] | None = None,
|
episodes: datasets.Dataset | None = None,
|
||||||
episodes: list[dict] | None = None,
|
|
||||||
hf_dataset: datasets.Dataset | None = None,
|
hf_dataset: datasets.Dataset | None = None,
|
||||||
):
|
):
|
||||||
if not info:
|
if info is None:
|
||||||
info = info_factory()
|
info = info_factory()
|
||||||
if not stats:
|
if stats is None:
|
||||||
stats = stats_factory(features=info["features"])
|
stats = stats_factory(features=info["features"])
|
||||||
if not episodes_stats:
|
if tasks is None:
|
||||||
episodes_stats = episodes_stats_factory(
|
|
||||||
features=info["features"], total_episodes=info["total_episodes"]
|
|
||||||
)
|
|
||||||
if not tasks:
|
|
||||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||||
if not episodes:
|
if episodes is None:
|
||||||
episodes = episodes_factory(
|
episodes = episodes_factory(
|
||||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
features=info["features"],
|
||||||
|
fps=info["fps"],
|
||||||
|
total_episodes=info["total_episodes"],
|
||||||
|
total_frames=info["total_frames"],
|
||||||
|
tasks=tasks,
|
||||||
)
|
)
|
||||||
if not hf_dataset:
|
if hf_dataset is None:
|
||||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
|
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
|
||||||
|
|
||||||
def _extract_episode_index_from_path(fpath: str) -> int:
|
|
||||||
path = Path(fpath)
|
|
||||||
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
|
|
||||||
episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0
|
|
||||||
return episode_index
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _mock_snapshot_download(
|
def _mock_snapshot_download(
|
||||||
repo_id: str,
|
repo_id: str, # TODO(rcadene): repo_id should be used no?
|
||||||
local_dir: str | Path | None = None,
|
local_dir: str | Path | None = None,
|
||||||
allow_patterns: str | list[str] | None = None,
|
allow_patterns: str | list[str] | None = None,
|
||||||
ignore_patterns: str | list[str] | None = None,
|
ignore_patterns: str | list[str] | None = None,
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
if not local_dir:
|
if local_dir is None:
|
||||||
local_dir = LEROBOT_TEST_DIR
|
local_dir = LEROBOT_TEST_DIR
|
||||||
|
|
||||||
# List all possible files
|
# List all possible files
|
||||||
all_files = []
|
all_files = [
|
||||||
meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
|
INFO_PATH,
|
||||||
all_files.extend(meta_files)
|
STATS_PATH,
|
||||||
|
# TODO(rcadene): remove naive chunk 0 file 0 ?
|
||||||
|
DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0),
|
||||||
|
DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0),
|
||||||
|
DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
|
||||||
|
]
|
||||||
|
|
||||||
data_files = []
|
video_keys = [key for key, feats in info["features"].items() if feats["dtype"] == "video"]
|
||||||
for episode_dict in episodes.values():
|
for key in video_keys:
|
||||||
ep_idx = episode_dict["episode_index"]
|
all_files.append(DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0))
|
||||||
ep_chunk = ep_idx // info["chunks_size"]
|
|
||||||
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
|
||||||
data_files.append(data_path)
|
|
||||||
all_files.extend(data_files)
|
|
||||||
|
|
||||||
allowed_files = filter_repo_objects(
|
allowed_files = filter_repo_objects(
|
||||||
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
|
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create allowed files
|
request_info = False
|
||||||
|
request_tasks = False
|
||||||
|
request_episodes = False
|
||||||
|
request_stats = False
|
||||||
|
request_data = False
|
||||||
|
request_videos = False
|
||||||
for rel_path in allowed_files:
|
for rel_path in allowed_files:
|
||||||
if rel_path.startswith("data/"):
|
if rel_path.startswith("meta/info.json"):
|
||||||
episode_index = _extract_episode_index_from_path(rel_path)
|
request_info = True
|
||||||
if episode_index is not None:
|
elif rel_path.startswith("meta/stats"):
|
||||||
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
|
request_stats = True
|
||||||
if rel_path == INFO_PATH:
|
elif rel_path.startswith("meta/tasks"):
|
||||||
_ = info_path(local_dir, info)
|
request_tasks = True
|
||||||
elif rel_path == STATS_PATH:
|
elif rel_path.startswith("meta/episodes"):
|
||||||
_ = stats_path(local_dir, stats)
|
request_episodes = True
|
||||||
elif rel_path == EPISODES_STATS_PATH:
|
elif rel_path.startswith("data/"):
|
||||||
_ = episodes_stats_path(local_dir, episodes_stats)
|
request_data = True
|
||||||
elif rel_path == TASKS_PATH:
|
elif rel_path.startswith("videos/"):
|
||||||
_ = tasks_path(local_dir, tasks)
|
request_videos = True
|
||||||
elif rel_path == EPISODES_PATH:
|
|
||||||
_ = episode_path(local_dir, episodes)
|
|
||||||
else:
|
else:
|
||||||
pass
|
raise ValueError(f"{rel_path} not supported.")
|
||||||
|
|
||||||
|
if request_info:
|
||||||
|
create_info(local_dir, info)
|
||||||
|
if request_stats:
|
||||||
|
create_stats(local_dir, stats)
|
||||||
|
if request_tasks:
|
||||||
|
create_tasks(local_dir, tasks)
|
||||||
|
if request_episodes:
|
||||||
|
create_episodes(local_dir, episodes)
|
||||||
|
if request_data:
|
||||||
|
create_hf_dataset(local_dir, hf_dataset)
|
||||||
|
if request_videos:
|
||||||
|
create_videos(root=local_dir, info=info)
|
||||||
|
|
||||||
return str(local_dir)
|
return str(local_dir)
|
||||||
|
|
||||||
return _mock_snapshot_download
|
return _mock_snapshot_download
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from pathlib import Path
|
|||||||
import einops
|
import einops
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
from lerobot import available_policies
|
from lerobot import available_policies
|
||||||
@@ -69,7 +68,11 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
info = info_factory(
|
info = info_factory(
|
||||||
total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features
|
total_episodes=1,
|
||||||
|
total_frames=1,
|
||||||
|
total_tasks=1,
|
||||||
|
camera_features=camera_features,
|
||||||
|
motor_features=motor_features,
|
||||||
)
|
)
|
||||||
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
|
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
|
||||||
return ds_meta
|
return ds_meta
|
||||||
@@ -138,14 +141,14 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
|||||||
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
|
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
|
||||||
and for now we add tests as we see fit.
|
and for now we add tests as we see fit.
|
||||||
"""
|
"""
|
||||||
|
policy_kwargs["device"] = DEVICE
|
||||||
|
|
||||||
train_cfg = TrainPipelineConfig(
|
train_cfg = TrainPipelineConfig(
|
||||||
# TODO(rcadene, aliberts): remove dataset download
|
# TODO(rcadene, aliberts): remove dataset download
|
||||||
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
||||||
policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs),
|
policy=make_policy_config(policy_name, **policy_kwargs),
|
||||||
env=make_env_config(env_name, **env_kwargs),
|
env=make_env_config(env_name, **env_kwargs),
|
||||||
)
|
)
|
||||||
train_cfg.validate()
|
|
||||||
|
|
||||||
# Check that we can make the policy object.
|
# Check that we can make the policy object.
|
||||||
dataset = make_dataset(train_cfg)
|
dataset = make_dataset(train_cfg)
|
||||||
@@ -214,7 +217,7 @@ def test_act_backbone_lr():
|
|||||||
cfg = TrainPipelineConfig(
|
cfg = TrainPipelineConfig(
|
||||||
# TODO(rcadene, aliberts): remove dataset download
|
# TODO(rcadene, aliberts): remove dataset download
|
||||||
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
|
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
|
||||||
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001, push_to_hub=False),
|
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
|
||||||
)
|
)
|
||||||
cfg.validate() # Needed for auto-setting some parameters
|
cfg.validate() # Needed for auto-setting some parameters
|
||||||
|
|
||||||
@@ -410,17 +413,7 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
|
|||||||
4. Check that this test now passes.
|
4. Check that this test now passes.
|
||||||
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
||||||
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
|
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
|
||||||
|
|
||||||
NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact
|
|
||||||
is out of date. For example, some PyTorch versions have different randomness, see this PR:
|
|
||||||
https://github.com/huggingface/lerobot/pull/1127.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# NOTE: ACT policy has different randomness, after PyTorch 2.7.0
|
|
||||||
if policy_name == "act" and version.parse(torch.__version__) < version.parse("2.7.0"):
|
|
||||||
pytest.skip(f"Skipping act policy test with PyTorch {torch.__version__}. Requires PyTorch >= 2.7.0")
|
|
||||||
|
|
||||||
ds_name = ds_repo_id.split("/")[-1]
|
ds_name = ds_repo_id.split("/")[-1]
|
||||||
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
|
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
|
||||||
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
|
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
|
||||||
|
|||||||
Reference in New Issue
Block a user