Compare commits

..

6 Commits

Author SHA1 Message Date
Francesco Capuano
6eaf6a861a fix: single level loop 2025-09-24 01:06:13 +02:00
Francesco Capuano
cdd6cb606c add: inference benchmark 2025-09-23 22:34:52 +02:00
Jade Choghari
f6cd24be17 update
Signed-off-by: Jade Choghari <chogharijade@gmail.com>
2025-09-23 21:52:15 +02:00
Jade Choghari
54c6b8ae52 add file
Signed-off-by: Jade Choghari <chogharijade@gmail.com>
2025-09-23 21:52:14 +02:00
Steven Palma
c9787bd98a feat(script): add entry point for image transform viz (#2007)
* feat(Scripts): add entry point for img transform viz

* chore(style): pre-commit style
2025-09-23 18:47:36 +02:00
Steven Palma
c435d3cebc feat(script): add entry point for dataset viz (#2006)
* chore(scripts): rename script dataset viz

* feat(scripts): add entry point for dataset-viz

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
2025-09-23 18:46:27 +02:00
13 changed files with 488 additions and 742 deletions

View File

@@ -202,7 +202,7 @@ Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/
You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
```bash
python -m lerobot.scripts.visualize_dataset \
lerobot-dataset-viz \
--repo-id lerobot/pusht \
--episode-index 0
```
@@ -210,7 +210,7 @@ python -m lerobot.scripts.visualize_dataset \
or from a dataset in a local folder with the `root` option and the `--local-files-only` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
```bash
python -m lerobot.scripts.visualize_dataset \
lerobot-dataset-viz \
--repo-id lerobot/pusht \
--root ./my_local_data_dir \
--local-files-only 1 \
@@ -221,7 +221,7 @@ It will open `rerun.io` and display the camera streams, robot states and actions
https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144
Our script can also visualize datasets stored on a distant server. See `python -m lerobot.scripts.visualize_dataset --help` for more instructions.
Our script can also visualize datasets stored on a distant server. See `lerobot-dataset-viz --help` for more instructions.
### The `LeRobotDataset` format

View File

@@ -0,0 +1,378 @@
"""
Benchmark memory footprint and inference latency of a policy on arbitrary devices.
This script loads a pretrained policy directly (similar to the async inference server)
and generates dummy input data based on the policy's input_features to perform
accurate benchmarking without requiring datasets.
"""
import argparse
import os
import signal
import statistics
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
import psutil
import torch
from tqdm import tqdm
from lerobot.configs.types import FeatureType
from lerobot.policies.factory import get_policy_class
from lerobot.policies.pretrained import PreTrainedPolicy
class TimeoutException:
pass
@contextmanager
def timeout(seconds):
def signal_handler(signum, frame):
raise TimeoutException(f"Timed out after {seconds} seconds")
# On Windows, signal is not available, so we can't use this timeout mechanism
if not hasattr(signal, "SIGALRM"):
yield
return
old_handler = signal.signal(signal.SIGALRM, signal_handler)
try:
# signal.alarm expects integer seconds
# for float seconds, we can use setitimer
signal.setitimer(signal.ITIMER_REAL, seconds)
yield
finally:
signal.setitimer(signal.ITIMER_REAL, 0)
signal.signal(signal.SIGALRM, old_handler)
def bytes_to_human(n: int) -> str:
for unit in ["B", "KB", "MB", "GB", "TB"]:
if n < 1024:
return f"{n:.2f} {unit}"
n /= 1024
return f"{n:.2f} PB"
def percentile(values: list[float], p: float) -> float:
if not values:
return float("nan")
k = (len(values) - 1) * (p / 100.0)
f = int(k)
c = min(f + 1, len(values) - 1)
if f == c:
return values[f]
return values[f] + (values[c] - values[f]) * (k - f)
def generate_dummy_observation(input_features: dict, device: str = "cpu") -> dict:
"""Generate dummy observation data based on policy input features."""
dummy_obs = {}
for key, feature in input_features.items():
shape = feature.shape
if feature.type == FeatureType.VISUAL:
# Images: random values in [0, 1] range (already normalized)
dummy_obs[key] = torch.rand(shape, dtype=torch.float32, device=device)
elif feature.type in [FeatureType.STATE, FeatureType.ACTION, FeatureType.ENV]:
# State/action/env: random normal distribution
dummy_obs[key] = torch.randn(shape, dtype=torch.float32, device=device)
else:
# Default: random normal for unknown types
dummy_obs[key] = torch.randn(shape, dtype=torch.float32, device=device)
# Add batch dimension
for key in dummy_obs:
dummy_obs[key] = dummy_obs[key].unsqueeze(0)
# Add task string for language-conditioned policies
dummy_obs["task"] = ""
return dummy_obs
def main():
parser = argparse.ArgumentParser(description="Policy inference benchmark")
parser.add_argument(
"--policy-id", type=str, required=True, help="Model ID or local path to pretrained policy"
)
parser.add_argument(
"--policy-type", type=str, required=True, help="Type of policy (smolvla, act, diffusion, etc.)"
)
parser.add_argument(
"--device", type=str, default="mps", choices=["cuda", "cpu", "mps"], help="Device to run on"
)
parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument(
"--num-samples", type=int, default=100, help="Number of inference samples to benchmark"
)
parser.add_argument("--warmup", type=int, default=10, help="Number of warmup samples (not timed)")
parser.add_argument(
"--output-dir", type=str, default="outputs/benchmarks", help="Directory to save benchmark results"
)
parser.add_argument(
"--timeout",
type=float,
default=0.3,
help="Timeout for each inference pass in seconds (default: 0.3s = 300ms)",
)
args = parser.parse_args()
# Seed & deterministic-ish setup
torch.manual_seed(args.seed)
if args.device == "cuda":
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False # leave False to avoid perf cliffs
# Resolve device availability
device = args.device.lower()
if device == "cuda" and not torch.cuda.is_available():
print("[!] CUDA requested but unavailable. Falling back to CPU.")
device = "cpu"
elif device == "mps" and not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()):
print("[!] MPS requested but unavailable. Falling back to CPU.")
device = "cpu"
use_cuda = device == "cuda"
# Create output directory and log file
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
policy_name = args.policy_id.replace("/", "_").replace("\\", "_")
log_file = output_dir / f"benchmark_{args.policy_type}_{policy_name}_{device}_{timestamp}.txt"
# Load policy directly from pretrained (similar to async inference server)
print(f"Loading policy {args.policy_type} from {args.policy_id}...")
policy_class = get_policy_class(args.policy_type)
policy: PreTrainedPolicy = policy_class.from_pretrained(args.policy_id)
policy.eval()
policy.to(device)
print(f"Policy loaded on {device}")
print(f"Input features: {list(policy.config.input_features.keys())}")
print(f"Output features: {list(policy.config.output_features.keys())}")
# Generate dummy observation based on policy input features
dummy_observation = generate_dummy_observation(policy.config.input_features, device)
dummy_observation["task"] = ""
# Helper to sync for fair timings
def _sync(dev_=device):
if dev_ == "cuda" and torch.cuda.is_available():
torch.cuda.synchronize()
elif dev_ == "mps" and hasattr(torch, "mps"):
try:
torch.mps.synchronize()
except AttributeError:
pass # MPS sync not available in this PyTorch version
# Warmup (to stabilize kernels/caches)
print("Warming up...")
with torch.no_grad():
policy.reset()
for _ in range(args.warmup):
_ = policy.select_action(dummy_observation)
_sync()
# Memory footprint before timing
process = psutil.Process(os.getpid())
rss_before = process.memory_info().rss
if use_cuda:
torch.cuda.reset_peak_memory_stats()
# PyTorch timing with Event objects for more accurate GPU timing
print(f"Running benchmark: {args.num_samples} samples...")
if use_cuda:
# Use CUDA Events for precise GPU timing
start_events = []
end_events = []
timeout_count = 0
with torch.no_grad():
for forward in tqdm(range(args.num_samples), desc="Trials"):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
try:
with timeout(args.timeout):
start_event.record()
_ = policy.select_action(dummy_observation)
end_event.record()
start_events.append(start_event)
end_events.append(end_event)
except TimeoutException:
timeout_count += 1
# Add placeholder for timeout
start_events.append(None)
end_events.append(None)
print(f"\n[!] Timeout on forward {forward + 1}")
continue
# Synchronize and collect timing results
torch.cuda.synchronize()
per_forward_ms = []
for start_event, end_event in zip(start_events, end_events, strict=True):
if start_event is None:
per_forward_ms.append(args.timeout * 1000)
else:
per_forward_ms.append(start_event.elapsed_time(end_event))
if timeout_count > 0:
print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
else:
# Use simple time.perf_counter for CPU/MPS timing with timeout
import time
per_forward_ms = []
timeout_count = 0
with torch.no_grad():
for sample in tqdm(range(args.num_samples), desc="Samples"):
try:
with timeout(args.timeout):
start_time = time.perf_counter()
_ = policy.select_action(dummy_observation)
end_time = time.perf_counter()
per_forward_ms.append((end_time - start_time) * 1000) # Convert to ms
except TimeoutException:
timeout_count += 1
per_forward_ms.append(args.timeout * 1000)
print(f"\n[!] Timeout on sample {sample + 1}")
continue
if timeout_count > 0:
print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
# Memory footprint after timing
rss_after = process.memory_info().rss
rss_delta = rss_after - rss_before
cuda_peak = torch.cuda.max_memory_allocated() if use_cuda else 0
# Sort timing results for percentile calculations
per_forward_ms_sorted = sorted(per_forward_ms)
mean_ms = statistics.fmean(per_forward_ms) if per_forward_ms else float("nan")
std_ms = statistics.pstdev(per_forward_ms) if len(per_forward_ms) > 1 else 0.0
min_ms = per_forward_ms_sorted[0] if per_forward_ms_sorted else float("nan")
max_ms = per_forward_ms_sorted[-1] if per_forward_ms_sorted else float("nan")
p50_ms = percentile(per_forward_ms_sorted, 50)
p95_ms = percentile(per_forward_ms_sorted, 95)
# Model size
num_params = sum(p.numel() for p in policy.parameters())
# Prepare results for logging
results = {
"timestamp": datetime.now().isoformat(),
"policy_type": args.policy_type,
"policy_id": args.policy_id,
"device": device,
"num_trials": args.num_samples,
"forwards_per_trial": 1,
"warmup": args.warmup,
"timeout_ms": args.timeout * 1000,
"seed": args.seed,
"num_params": num_params,
"timeout_count": timeout_count,
"latency_mean_ms": mean_ms,
"latency_std_ms": std_ms,
"latency_min_ms": min_ms,
"latency_max_ms": max_ms,
"latency_p50_ms": p50_ms,
"latency_p95_ms": p95_ms,
"cpu_rss_before": rss_before,
"cpu_rss_after": rss_after,
"cpu_rss_delta": rss_delta,
"cuda_peak_alloc": cuda_peak,
"input_features": list(policy.config.input_features.keys()),
"output_features": list(policy.config.output_features.keys()),
}
# Format and write results to log file
log_content = f"""
=== LeRobot Policy Inference Benchmark ===
Timestamp: {results["timestamp"]}
Policy: {results["policy_type"]} ({results["policy_id"]})
Device: {results["device"]}
Seed: {results["seed"]}
=== Model Information ===
Parameters: {results["num_params"]:,}
Input Features: {", ".join(results["input_features"])}
Output Features: {", ".join(results["output_features"])}
=== Benchmark Configuration ===
Samples: {results["num_trials"]}
Warmup: {results["warmup"]}
Total Measurements: {len(per_forward_ms)}
Timeout: {results["timeout_ms"]:.1f}ms
Timeouts: {results["timeout_count"]} / {results["num_trials"]}
=== Latency Results (ms) ===
Mean: {results["latency_mean_ms"]:.3f}
Std Dev: {results["latency_std_ms"]:.3f}
Min: {results["latency_min_ms"]:.3f}
Max: {results["latency_max_ms"]:.3f}
P50: {results["latency_p50_ms"]:.3f}
P95: {results["latency_p95_ms"]:.3f}
=== Memory Footprint ===
CPU RSS Before: {bytes_to_human(results["cpu_rss_before"])}
CPU RSS After: {bytes_to_human(results["cpu_rss_after"])}{bytes_to_human(results["cpu_rss_delta"])})
"""
if use_cuda:
log_content += f"CUDA Peak: {bytes_to_human(results['cuda_peak_alloc'])} (reset before timing)\n"
log_content += f"""
=== Raw Timing Data (first 20 measurements, ms) ===
{", ".join(f"{t:.3f}" for t in per_forward_ms[:20])}
{"..." if len(per_forward_ms) > 20 else ""}
=== Summary Statistics ===
Timing Method: {"CUDA Events" if use_cuda else "torch.utils.benchmark.Timer"}
Device Available: {torch.cuda.is_available() if device == "cuda" else torch.backends.mps.is_available() if device == "mps" else True}
PyTorch Version: {torch.__version__}
Benchmark completed successfully at {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
"""
# Write to log file
with open(log_file, "w") as f:
f.write(log_content)
# Print to console (shorter version)
print("\n=== Inference Benchmark Results ===")
print(f"Policy: {args.policy_type} ({args.policy_id})")
print(f"Device: {device}")
print(f"Samples: {args.num_samples} | Warmup: {args.warmup}")
print(f"Model params: {num_params:,}")
print("\nLatency per forward (ms):")
print(f" mean: {mean_ms:.3f} std: {std_ms:.3f}")
print(f" min: {min_ms:.3f} max: {max_ms:.3f}")
print(f" p50: {p50_ms:.3f} p95: {p95_ms:.3f}")
print("\nMemory footprint:")
print(f" CPU RSS before: {bytes_to_human(rss_before)}")
print(f" CPU RSS after : {bytes_to_human(rss_after)}{bytes_to_human(rss_delta)})")
if use_cuda:
print(
f" CUDA peak allocated: {bytes_to_human(cuda_peak)} "
f"(reset by reset_peak_memory_stats before timing)"
)
print(f"\nResults saved to: {log_file}")
print("Benchmark completed successfully!")
if __name__ == "__main__":
main()

View File

@@ -246,7 +246,7 @@ You can also use any `torchvision.transforms.v2` transform by passing it directl
Use the visualization script to preview how transforms affect your data:
```bash
python -m lerobot.scripts.visualize_image_transforms \
lerobot-imgtransform-viz \
--repo-id=your-username/your-dataset \
--output-dir=./transform_examples \
--n-examples=5

View File

@@ -1,66 +0,0 @@
from lerobot.datasets.lerobot_dataset import MultiLeRobotDataset
REPO_A = "lerobot/pusht"
REPO_B = "lerobot/aloha_mobile_cabinet" # replace with the actual repo id
feature_keys_mapping = {
REPO_A: { # pusht (1 camera, 2-dim)
"action": "actions",
"observation.state": "obs_state",
"observation.image": "obs_image.cam_high",
},
REPO_B: { # dual arm (3 cameras, 14-dim)
"action": "actions",
"observation.state": "obs_state",
"observation.images.cam_high": "obs_image.cam_high",
"observation.images.cam_left_wrist": "obs_image.cam_left_wrist",
"observation.images.cam_right_wrist": "obs_image.cam_right_wrist",
},
}
from torchvision.transforms.v2 import Compose, ToImage, Resize
image_tf = Compose([
ToImage(), # converts to tensor if needed
Resize((224, 224)), # unify sizes across datasets (96x96 vs 480x640)
])
from torch.utils.data import DataLoader
dataset = MultiLeRobotDataset(
repo_ids=[REPO_A, REPO_B],
image_transforms=image_tf, # ensures same HxW
feature_keys_mapping=feature_keys_mapping,
train_on_all_features=True, # keep union of cameras; zero-fill missing
# optional: override if you want fixed maxima; else inferred:
# max_action_dim=14,
# max_state_dim=14,
max_action_dim=14,
max_state_dim=14,
max_image_dim=224,
ignore_keys=[
"next.*", # drop reward/done/success
"index",
"timestamp",
"videos/*", # drop all video metadata
"observation.effort", # 👈 drop effort everywhere
],
)
breakpoint()
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)
for _ in range(100):
batch = next(iter(loader))
breakpoint()
# vectors padded to maxima (pusht:2 -> 14; dual-arm:14 -> 14)
assert batch["actions"].shape[-1] == 14
assert batch["obs_state"].shape[-1] == 14
assert batch["actions_padding_mask"].shape[-1] == 14
assert batch["obs_state_padding_mask"].shape[-1] == 14
# cameras: all canonical keys exist; pusht will have wrists zero-filled
for cam in ["obs_image.cam_high", "obs_image.cam_left_wrist", "obs_image.cam_right_wrist"]:
assert cam in batch
assert f"{cam}_is_pad" in batch
# images should all be 3x224x224 (or your transforms size)
img = batch[cam]
assert img.ndim in (4, 5) # (B,C,H,W) or (B,T,C,H,W) depending on your loader

View File

@@ -1,16 +0,0 @@
# storage / caches
RAID=/raid/jade
export TRANSFORMERS_CACHE=$RAID/.cache/huggingface/transformers
export HF_HOME=$RAID/.cache/huggingface
export HF_DATASETS_CACHE=$RAID/.cache/huggingface/datasets
export HF_LEROBOT_HOME=$RAID/.cache/huggingface/lerobot
export WANDB_CACHE_DIR=$RAID/.cache/wandb
export TMPDIR=$RAID/.cache/tmp
mkdir -p $TMPDIR
export WANDB_MODE=offline
# export HF_DATASETS_OFFLINE=1
# export HF_HUB_OFFLINE=1
export TOKENIZERS_PARALLELISM=false
export MUJOCO_GL=egl
python examples/tester.py

View File

@@ -171,7 +171,9 @@ lerobot-setup-motors="lerobot.setup_motors:main"
lerobot-teleoperate="lerobot.teleoperate:main"
lerobot-eval="lerobot.scripts.eval:main"
lerobot-train="lerobot.scripts.train:main"
lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main"
lerobot-info="lerobot.scripts.lerobot_info:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.packages.find]

View File

@@ -174,79 +174,3 @@ def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np
aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
return aggregated_stats
import numpy as np
def aggregate_stats_multi(
stats_list: list[dict[str, dict]],
max_action_dim: int | None = None,
max_state_dim: int | None = None,
) -> dict[str, dict[str, np.ndarray]]:
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
Supports heterogeneous robots by padding action/state stats to the max dim.
The final stats will have the union of all data keys from each of the stats dicts.
- new_min = elementwise min across datasets
- new_max = elementwise max across datasets
- new_mean = weighted mean (by count)
- new_std = recomputed from total variance
"""
data_keys = {key for stats in stats_list for key in stats}
aggregated_stats = {key: {} for key in data_keys}
def _pad(arr: np.ndarray, target: int) -> np.ndarray:
if arr.ndim == 0: # scalar
return arr
if target is None or target <= 0 or arr.shape[-1] == target:
return arr
pad_width = [(0, 0)] * arr.ndim
pad_width[-1] = (0, target - arr.shape[-1])
return np.pad(arr, pad_width, mode="constant")
for key in data_keys:
stats_with_key = [stats[key] for stats in stats_list if key in stats]
# decide if this key should be padded
target_dim = None
if "action" in key and max_action_dim:
target_dim = max_action_dim
elif "state" in key and max_state_dim:
target_dim = max_state_dim
padded = []
counts = []
for s in stats_with_key:
mean = _pad(np.array(s["mean"]), target_dim)
std = _pad(np.array(s["std"]), target_dim)
min_ = _pad(np.array(s["min"]), target_dim)
max_ = _pad(np.array(s["max"]), target_dim)
count = s.get("count", 1)
padded.append(dict(mean=mean, std=std, min=min_, max=max_, count=count))
counts.append(count)
counts = np.array(counts, dtype=np.float64)
total_count = counts.sum()
means = np.stack([p["mean"] for p in padded])
stds = np.stack([p["std"] for p in padded])
mins = np.stack([p["min"] for p in padded])
maxs = np.stack([p["max"] for p in padded])
# weighted mean (broadcast weights properly)
new_mean = np.average(means, axis=0, weights=counts)
new_var = np.average(stds**2 + (means - new_mean)**2, axis=0, weights=counts)
new_std = np.sqrt(new_var)
aggregated_stats[key] = {
"min": mins.min(axis=0),
"max": maxs.max(axis=0),
"mean": new_mean,
"std": new_std,
"count": int(total_count),
}
return aggregated_stats

View File

@@ -31,7 +31,6 @@ import torch.utils
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError
from collections import defaultdict
from lerobot.constants import HF_LEROBOT_HOME
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
@@ -82,12 +81,7 @@ from lerobot.datasets.video_utils import (
)
CODEBASE_VERSION = "v3.0"
OBS_IMAGE = "observation.image"
OBS_IMAGE_2 = "observation.image_2"
OBS_IMAGE_3 = "observation.image_3"
OBS_STATE = "observation.state"
OBS_ENV_STATE = "observation.env_state"
ACTION = "action"
class LeRobotDatasetMetadata:
def __init__(
@@ -1328,139 +1322,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
return obj
ROBOT_TYPE_KEYS_MAPPING = {
"lerobot/stanford_hydra_dataset": "static_single_arm",
"lerobot/iamlab_cmu_pickup_insert": "static_single_arm",
"lerobot/berkeley_fanuc_manipulation": "static_single_arm",
"lerobot/toto": "static_single_arm",
"lerobot/roboturk": "static_single_arm",
"lerobot/jaco_play": "static_single_arm",
"lerobot/taco_play": "static_single_arm_7statedim",
}
class MultiLeRobotDatasetMeta:
def __init__(
self,
datasets: list[LeRobotDataset],
repo_ids: list[str],
keys_to_max_dim: dict[str, int],
train_on_all_features: bool = False,
):
self.repo_ids = repo_ids
self.keys_to_max_dim = keys_to_max_dim
self.train_on_all_features = train_on_all_features
self.robot_types = [ds.meta.info["robot_type"] for ds in datasets]
# assign robot_type if missing
for ds in datasets:
ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
ds.robot_type = ds.meta.info["robot_type"]
# step 1: compute disabled features
self.disabled_features = set()
if not self.train_on_all_features:
intersection = set(datasets[0].features)
for ds in datasets:
intersection.intersection_update(ds.features)
if not intersection:
raise RuntimeError("No common features across datasets.")
for repo_id, ds in zip(repo_ids, datasets, strict=False):
extra = set(ds.features) - intersection
logging.warning(f"Disabling {extra} for repo {repo_id}")
self.disabled_features.update(extra)
# step 2: build union_features excluding disabled
self.union_features = {}
for ds in datasets:
for k, v in ds.features.items():
if k not in self.disabled_features:
self.union_features[k] = v
# step 3: reshape feature schema
self.features = reshape_features_to_max_dim(
self.union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
)
# step 4: aggregate stats
self.stats = aggregate_stats_per_robot_type(datasets)
for robot_type_, stats_ in self.stats.items():
for feat_key, feat_stats in stats_.items():
if feat_key in [ACTION, OBS_ENV_STATE, OBS_STATE]:
for k, v in feat_stats.items():
pad_value = 0 if k in ["min", "mean"] else 1
self.stats[robot_type_][feat_key][k] = pad_tensor(
v,
max_size=self.keys_to_max_dim.get(feat_key, -1),
pad_dim=-1,
pad_value=pad_value,
)
# step 5: episodes & tasks
self.episodes = {repo_id: ds.meta.episodes for repo_id, ds in zip(repo_ids, datasets, strict=False)}
self.tasks = {repo_id: ds.meta.tasks for repo_id, ds in zip(repo_ids, datasets, strict=False)}
self.info = {repo_id: ds.meta.info for repo_id, ds in zip(repo_ids, datasets, strict=False)}
class MultiLeRobotDatasetCleaner:
def __init__(
self,
datasets: list[LeRobotDataset],
repo_ids: list[str],
sampling_weights: list[float],
datasets_repo_ids: list[str],
min_fps: int = 1,
max_fps: int = 100,
):
self.original_datasets = datasets
self.original_repo_ids = repo_ids
self.original_weights = sampling_weights
self.original_datasets_repo_ids = datasets_repo_ids
# step 1: remove datasets with invalid fps
# step 2: keep datasets with same features per robot type
consistent_datasets, keep_mask = keep_datasets_with_the_same_features_per_robot_type(
datasets
)
self.cleaned_datasets = consistent_datasets
self.keep_mask = keep_mask
self.cleaned_weights = [sampling_weights[i] for i in range(len(datasets)) if keep_mask[i]]
self.cleaned_repo_ids = [repo_ids[i] for i in range(len(datasets)) if keep_mask[i]]
self.cleaned_datasets_repo_ids = [
datasets_repo_ids[i] for i in range(len(datasets)) if keep_mask[i]
]
self.cumulative_sizes = np.array(
[0] + list(torch.cumsum(torch.tensor([len(d) for d in consistent_datasets]), dim=0))
)
self.cleaned_weights = np.array(self.cleaned_weights, dtype=np.float32)
# --- at the top of the file (same imports as before) ---
from collections import defaultdict
from typing import Callable
import copy
import numpy as np
import torch
import datasets
from pathlib import Path
# If you already have these in your codebase, reuse them
try:
from lerobot.common.constants import (
ACTION, OBS_ENV_STATE, OBS_STATE, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3
)
except Exception:
# Fallbacks if constants are already strings elsewhere
ACTION = "action"
OBS_ENV_STATE = "observation.env_state"
OBS_STATE = "observation.state"
OBS_IMAGE = "observation.image"
OBS_IMAGE_2 = "observation.image_2"
OBS_IMAGE_3 = "observation.image_3"
IGNORED_KEYS = ["observation.effort"]
class MultiLeRobotDataset(torch.utils.data.Dataset):
# ... keep your existing docstring ...
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
structure of `LeRobotDataset`.
"""
def __init__(
self,
@@ -1468,253 +1336,99 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
root: str | Path | None = None,
episodes: dict | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerances_s: dict | None = None,
download_videos: bool = True,
video_backend: str | None = None,
# --- NEW: simple add-ons ---
sampling_weights: list[float] | None = None,
feature_keys_mapping: dict[str, dict[str, str]] | None = None,
max_action_dim: int | None = None,
max_state_dim: int | None = None,
max_num_images: int | None = None,
max_image_dim: int | None = None,
train_on_all_features: bool = False,
min_fps: int = 1,
max_fps: int = 100,
ignore_keys: list[str] | None = None, # exact or glob patterns
):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else HF_LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
LeRobotDataset(
repo_id,
root=self.root / repo_id,
episodes=episodes[repo_id] if episodes else None,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
video_backend=video_backend,
)
for repo_id in repo_ids
]
# --- NEW: store mapping and simple knobs ---
self.feature_keys_mapping: dict[str, dict[str, str]] = feature_keys_mapping or {}
self.train_on_all_features = train_on_all_features
self.max_action_dim = max_action_dim
self.max_state_dim = max_state_dim
self.max_image_dim = max_image_dim
self.max_num_images = max_num_images # (optional, we dont enforce count, we enforce names)
self._ignore_patterns = list(ignore_keys or [])
# Build underlying single datasets
_datasets = []
datasets_repo_ids = []
self.sampling_weights = []
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
# restriction in future iterations of this class. For now, this is necessary at least for being able
# to use PyTorch's default DataLoader collate function.
self.disabled_features = set()
intersection_features = set(self._datasets[0].features)
for ds in self._datasets:
intersection_features.intersection_update(ds.features)
if len(intersection_features) == 0:
raise RuntimeError(
"Multiple datasets were provided but they had no keys common to all of them. "
"The multi-dataset functionality currently only keeps common keys."
)
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_features.update(extra_keys)
sampling_weights = sampling_weights if sampling_weights is not None else [1] * len(repo_ids)
assert len(sampling_weights) == len(repo_ids), (
"The number of sampling weights must match the number of datasets. "
f"Got {len(sampling_weights)} weights for {len(repo_ids)} datasets."
)
for i, repo_id in enumerate(repo_ids):
try:
_datasets.append(
LeRobotDataset(
repo_id,
root=self.root / repo_id,
episodes=episodes.get(repo_id, None) if episodes else None,
image_transforms=image_transforms, # transforms applied inside single ds
delta_timestamps=delta_timestamps.get(repo_id, None) if delta_timestamps else None,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
video_backend=video_backend,
)
)
datasets_repo_ids.append(repo_id)
self.sampling_weights.append(float(sampling_weights[i]))
except Exception as e:
print(f"Failed to load dataset: {repo_id} due to Exception: {e}")
print(
f"Finish loading {len(_datasets)} datasets, with sampling weights: "
f"{self.sampling_weights} corresponding to: {datasets_repo_ids}"
)
# Bookkeeping for mapping & canonical image inventory
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps.get(repo_id, None) if delta_timestamps else None
self._datasets = _datasets
self.datasets_repo_ids = datasets_repo_ids
# --- NEW: compute “canonical image keys” (targets across all mappings) ---
self._canonical_image_keys: set[str] = set()
self._source_keys_per_repo: dict[str, set[str]] = {}
self._target_keys_per_repo: dict[str, set[str]] = {}
for rid, mapping in self.feature_keys_mapping.items():
src_keys = set(mapping.keys())
tgt_keys = set(mapping.values())
self._source_keys_per_repo[rid] = src_keys
self._target_keys_per_repo[rid] = tgt_keys
# union of target names (we will ensure these exist at __getitem__)
self._canonical_image_keys |= {
k for k in tgt_keys if self._is_image_key_like(k)
}
# If user didnt give any mapping, fall back to native keys (no-ops)
if not self._canonical_image_keys and self.train_on_all_features:
# discover all image-like keys from raw features
for ds in self._datasets:
for k, v in ds.hf_features.items():
if isinstance(v, (datasets.Image, VideoFrame)):
self._canonical_image_keys.add(k)
# Cleaner: keep fps & consistent feature sets per robot type (unchanged)
cleaner = MultiLeRobotDatasetCleaner(
datasets=self._datasets,
repo_ids=repo_ids,
sampling_weights=self.sampling_weights,
datasets_repo_ids=self.datasets_repo_ids,
min_fps=min_fps,
max_fps=max_fps,
)
self._datasets = cleaner.cleaned_datasets
self.sampling_weights = cleaner.cleaned_weights
self.repo_ids = cleaner.cleaned_repo_ids
self.datasets_repo_ids = cleaner.cleaned_datasets_repo_ids
self.cumulative_sizes = cleaner.cumulative_sizes
# Meta (unchanged): we give it dim maxima; it will reshape/pad vectors
self.meta = MultiLeRobotDatasetMeta(
datasets=self._datasets,
repo_ids=self.repo_ids,
keys_to_max_dim={
ACTION: self.max_action_dim if self.max_action_dim is not None else -1,
OBS_ENV_STATE: self.max_state_dim if self.max_state_dim is not None else -1,
OBS_STATE: self.max_state_dim if self.max_state_dim is not None else -1,
OBS_IMAGE: self.max_image_dim if self.max_image_dim is not None else -1,
OBS_IMAGE_2: self.max_image_dim if self.max_image_dim is not None else -1,
OBS_IMAGE_3: self.max_image_dim if self.max_image_dim is not None else -1,
},
train_on_all_features=train_on_all_features,
)
# --- NEW: track dropped (source) keys so collate wont expect them
# Anything that we *rename away* should be considered disabled,
# otherwise downstream may expect them to exist.
self._dropped_keys = set()
for rid, mapping in self.feature_keys_mapping.items():
self._dropped_keys |= set(mapping.keys())
# Merge with metas disabled features
self.disabled_features = set(self.meta.disabled_features) | self._dropped_keys
self.stats = self.meta.stats
# --- NEW: cache an example image shape per canonical key (lazy, filled on first use)
self._cached_img_shape: dict[str, torch.Size] = {}
# ---------------------- NEW small helpers ----------------------
def _is_image_key_like(self, key: str) -> bool:
# A loose heuristic: rely on name OR on features later
return ("image" in key) or ("cam_" in key) or ("images." in key)
def _should_ignore(self, key: str) -> bool:
# exact or glob-style match
for pat in self._ignore_patterns:
if key == pat or fnmatch.fnmatch(key, pat):
return True
return False
def _apply_feature_mapping(self, item: dict, repo_id: str) -> dict:
"""
Rename features according to feature_keys_mapping[repo_id].
- Moves tensor/image under target key.
- Drops source key if moved.
- Adds *_is_pad=False for image targets we fill/keep.
"""
mapping = self.feature_keys_mapping.get(repo_id, {}) or {}
if not mapping:
return item
for src, tgt in mapping.items():
if src in item:
# Move value
item[tgt] = item[src]
# Drop the source to avoid duplication
del item[src]
return item
def _ensure_union_image_keys(self, item: dict) -> dict:
"""
Ensure that every canonical image key exists.
When missing, create a zero tensor matching (B,C,H,W) or (C,H,W) of an available image.
Also add boolean mask at f"{key}_is_pad".
"""
if not self.train_on_all_features or not self._canonical_image_keys:
return item
# find any existing image tensor in item to copy shape/dtype
exemplar = None
for k in list(item.keys()):
v = item[k]
if torch.is_tensor(v) and v.ndim in (3, 4, 5): # (C,H,W) or (B,C,H,W) or (B,T,C,H,W)
exemplar = v
break
# fallback to a safe 3x224x224 if nothing found
def _fallback_image():
return torch.zeros(3, 224, 224, dtype=torch.uint8)
for key in self._canonical_image_keys:
if key not in item:
img = torch.zeros_like(exemplar) if exemplar is not None else _fallback_image()
item[key] = img
item[f"{key}_is_pad"] = torch.tensor(True, dtype=torch.bool)
else:
# Add a mask saying its *not* padded
if f"{key}_is_pad" not in item:
item[f"{key}_is_pad"] = torch.tensor(False, dtype=torch.bool)
return item
# ---------------------- existing API below (mostly unchanged) ----------------------
self.delta_timestamps = delta_timestamps
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
# with multiple robots of different ranges. Instead we should have one normalization
# per robot.
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
@property
def repo_id_to_index(self):
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
"""
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
@property
def repo_index_to_id(self):
"""Return the inverse mapping if repo_id_to_index."""
return {v: k for k, v in self.repo_id_to_index}
@property
def fps(self) -> int:
"""Frames per second used during data collection.
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].meta.info["fps"]
@property
def video(self) -> bool:
"""Returns True if this dataset loads video frames from mp4 files.
Returns False if it only loads images from png files.
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].meta.info.get("video", False)
@property
def features(self) -> datasets.Features:
"""
Extend native HF features with any *target* keys introduced by mapping.
We copy the source spec for targets that didnt exist in any raw dataset.
"""
features: dict[str, datasets.features.Feature] = {}
features = {}
for dataset in self._datasets:
for k, v in dataset.hf_features.items():
if k not in self.disabled_features:
features[k] = v
# Add mapped target image specs if not present yet
for rid, mapping in self.feature_keys_mapping.items():
ds = None
# find the dataset object to read feature spec for source
for _ds, _rid in zip(self._datasets, self.repo_ids, strict=False):
if _rid == rid:
ds = _ds
break
if ds is None:
continue
for src, tgt in mapping.items():
if tgt not in features and src in ds.hf_features:
features[tgt] = ds.hf_features[src]
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
return features
@property
def camera_keys(self) -> list[str]:
"""Keys to access image and video stream from cameras."""
keys = []
for key, feats in self.features.items():
if isinstance(feats, (datasets.Image, VideoFrame)):
@@ -1723,6 +1437,12 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
@property
def video_frame_keys(self) -> list[str]:
"""Keys to access video frames that requires to be decoded into images.
Note: It is empty if the dataset contains images only,
or equal to `self.cameras` if the dataset contains videos only,
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
"""
video_frame_keys = []
for key, feats in self.features.items():
if isinstance(feats, VideoFrame):
@@ -1731,14 +1451,21 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
@property
def num_frames(self) -> int:
"""Number of samples/frames."""
return sum(d.num_frames for d in self._datasets)
@property
def num_episodes(self) -> int:
"""Number of episodes."""
return sum(d.num_episodes for d in self._datasets)
@property
def tolerance_s(self) -> float:
"""Tolerance in seconds used to discard loaded frames when their timestamps
are not close enough from the requested frames. It is only used when `delta_timestamps`
is provided or when loading video frames from mp4 files.
"""
# 1e-4 to account for possible numerical error
return 1 / self.fps - 1e-4
def __len__(self):
@@ -1747,83 +1474,22 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self):
raise IndexError(f"Index {idx} out of bounds.")
dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right").item() - 1
local_idx = (idx - self.cumulative_sizes[dataset_idx]).item()
item = self._datasets[dataset_idx][local_idx]
# Identify which repo this sample came from
repo_id = self.datasets_repo_ids[dataset_idx]
# --- NEW: apply mapping and ensure union of image keys ---
item = self._apply_feature_mapping(item, repo_id)
item = self._ensure_union_image_keys(item)
# annotate dataset index for downstream
# Determine which dataset to get an item from based on the index.
start_idx = 0
dataset_idx = 0
for dataset in self._datasets:
if idx >= start_idx + dataset.num_frames:
start_idx += dataset.num_frames
dataset_idx += 1
continue
break
else:
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
item = self._datasets[dataset_idx][idx - start_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
# Pad vector features to max dims using meta (unchanged)
item = create_padded_features(item, self.meta.features)
# Drop any disabled (including original source keys we remapped away)
for data_key in self.disabled_features:
if data_key in item:
del item[data_key]
for k in IGNORED_KEYS:
if k in item:
item.pop(k)
# Convert any datasets.Image still present to tensor
if self.image_transforms is not None:
for cam in [k for k in item.keys() if self._is_image_key_like(k)]:
val = item[cam]
if not torch.is_tensor(val):
item[cam] = self.image_transforms(val)
# 🔑 Pad actions if too short
if "actions" in item and self.max_action_dim is not None:
act = item["actions"]
if act.shape[-1] < self.max_action_dim:
pad_len = self.max_action_dim - act.shape[-1]
item["actions"] = torch.cat([act, torch.zeros(pad_len, dtype=act.dtype)], dim=-1)
item["actions_padding_mask"] = torch.cat(
[torch.zeros_like(act, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)],
dim=-1,
)
# pad obs_state if too short
if "obs_state" in item and self.max_state_dim is not None:
st = item["obs_state"]
if st.shape[-1] < self.max_state_dim:
pad_len = self.max_state_dim - st.shape[-1]
item["obs_state"] = torch.cat([st, torch.zeros(pad_len, dtype=st.dtype)], dim=-1)
item["obs_state_padding_mask"] = torch.cat(
[torch.zeros_like(st, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)],
dim=-1,
)
# actions
if "actions" in item and self.max_action_dim is not None:
act = item["actions"]
if act.shape[-1] < self.max_action_dim:
pad_len = self.max_action_dim - act.shape[-1]
item["actions"] = torch.cat([act, torch.zeros(pad_len, dtype=act.dtype)], dim=-1)
mask = torch.cat(
[torch.zeros_like(act, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)],
dim=-1,
)
else:
mask = torch.zeros(self.max_action_dim, dtype=torch.bool) # 👈 all False if no padding
item["actions_padding_mask"] = mask
# obs state
if "obs_state" in item and self.max_state_dim is not None:
st = item["obs_state"]
if st.shape[-1] < self.max_state_dim:
pad_len = self.max_state_dim - st.shape[-1]
item["obs_state"] = torch.cat([st, torch.zeros(pad_len, dtype=st.dtype)], dim=-1)
mask = torch.cat(
[torch.zeros_like(st, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)],
dim=-1,
)
else:
mask = torch.zeros(self.max_state_dim, dtype=torch.bool) # 👈 always add mask
item["obs_state_padding_mask"] = mask
return item
@@ -1840,149 +1506,3 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
f" Transformations: {self.image_transforms},\n"
f")"
)
def keep_datasets_with_the_same_features_per_robot_type(ls_datasets: list) -> list:
"""
Filters datasets to only keep those with consistent feature shapes per robot type.
Args:
ls_datasets (List): List of datasets, each with a `meta.info['robot_type']`
and `meta.episodes_stats` dictionary.
Returns:
List: Filtered list of datasets with consistent feature shapes.
"""
robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets}
datasets_to_remove = set()
for robot_type in robot_types:
# Collect all stats dicts for this robot type
stats_list = [
ep_stats
for ds in ls_datasets
if ds.meta.info["robot_type"] == robot_type
for ep_stats in episode_stats_values(ds.meta)
]
if not stats_list:
continue
# Determine the most common shape for each key
all_keys = {key for stats in stats_list for key in stats}
for ds in ls_datasets:
if ds.meta.info["robot_type"] != robot_type:
continue
for key in all_keys:
shape_counter = defaultdict(int)
for stats in stats_list:
value = stats.get(key)
if (
value and "mean" in value and isinstance(value["mean"], (torch.Tensor, np.ndarray))
): # FIXME(mshukor): check all stats; min, mean, max
shape_counter[value["mean"].shape] += 1
if not shape_counter:
continue
# Identify the most frequent shape
main_shape = max(shape_counter, key=shape_counter.get)
# Flag datasets that don't match the main shape
# for ds in ls_datasets:
first_ep_stats = next(iter(episode_stats_values(ds.meta)), None)
if not first_ep_stats:
continue
value = first_ep_stats.get(key)
if (
value
and "mean" in value
and isinstance(value["mean"], (torch.Tensor, np.ndarray))
and value["mean"].shape != main_shape
):
datasets_to_remove.add(ds)
break
# Filter out inconsistent datasets
datasets_maks = [ds not in datasets_to_remove for ds in ls_datasets]
filtered_datasets = [ds for ds in ls_datasets if ds not in datasets_to_remove]
print(
f"Keeping {len(filtered_datasets)} datasets. Removed {len(datasets_to_remove)} inconsistent ones. Inconsistent datasets:\n{datasets_to_remove}"
)
return filtered_datasets, datasets_maks
def aggregate_stats_per_robot_type(ls_datasets) -> dict[str, dict[str, torch.Tensor]]:
"""Aggregate stats of multiple LeRobot datasets into multiple set of stats per robot type.
The final stats will have the union of all data keys from each of the datasets.
The final stats will have the union of all data keys from each of the datasets. For instance:
- new_max = max(max_dataset_0, max_dataset_1, ...)
- new_min = min(min_dataset_0, min_dataset_1, ...)
- new_mean = (mean of all data)
- new_std = (std of all data)
"""
robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets}
stats = {robot_type: {} for robot_type in robot_types}
for robot_type in robot_types:
robot_type_datasets = []
for ds in ls_datasets:
if ds.meta.info["robot_type"] == robot_type:
robot_type_datasets.extend(list(episode_stats_values(ds.meta)))
# robot_type_datasets = [list(ds.episodes_stats.values()) for ds in ls_datasets if ds.meta.info["robot_type"] == robot_type]
stat = aggregate_stats(robot_type_datasets)
stats[robot_type] = stat
return stats
def reshape_features_to_max_dim(features: dict, reshape_dim: int = -1, keys_to_max_dim: dict = {}) -> dict:
"""Reshape features to have a maximum dimension of `max_dim`."""
reshaped_features = {}
for key in features:
if key in keys_to_max_dim and keys_to_max_dim[key] is not None:
reshaped_features[key] = features[key]
shape = list(features[key]["shape"])
if any([k in key for k in [OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3]]): # Assume square images
shape[-3] = keys_to_max_dim[key]
shape[-2] = keys_to_max_dim[key]
else:
shape[reshape_dim] = keys_to_max_dim[key]
reshaped_features[key]["shape"] = tuple(shape)
else:
reshaped_features[key] = features[key]
return reshaped_features
def create_padded_features(item: dict, features: dict = {}):
for key, ft in features.items():
if any([k in key for k in ["cam", "effort", "absolute"]]): # FIXME(mshukor): temporary hack
continue
shape = ft["shape"]
if len(shape) == 3: # images to torch format (C, H, W)
shape = (shape[2], shape[0], shape[1])
if len(shape) == 1 and shape[0] == 1: # ft with shape are actually tensor(ele)
shape = []
if key not in item:
dtype = str_to_torch_dtype(ft["dtype"])
item[key] = torch.zeros(shape, dtype=dtype)
item[f"{key}_padding_mask"] = torch.tensor(0, dtype=torch.int64)
if "image" in key: # FIXME(mshukor): support other observations
item[f"{key}_is_pad"] = torch.BoolTensor([False])
else:
item[f"{key}_padding_mask"] = torch.tensor(1, dtype=torch.int64)
return item
def str_to_torch_dtype(dtype_str):
"""Convert a dtype string to a torch dtype."""
mapping = {
"float32": torch.float32,
"int64": torch.int64,
"int16": torch.int16,
"bool": torch.bool,
"video": torch.float32, # Assuming video is stored as uint8 images
}
return mapping.get(dtype_str, torch.float32) # Default to float32
def episode_stats_values(meta):
episodes = meta.episodes.to_pandas().to_dict(orient="records")
return [
{k: v for k, v in ep.items() if isinstance(v, dict) and "mean" in v}
for ep in episodes
]

View File

@@ -118,7 +118,7 @@ echo ${HF_USER}/aloha_test
If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with [Rerun](https://github.com/rerun-io/rerun):
```bash
python -m lerobot.scripts.visualize_dataset \
lerobot-dataset-viz \
--repo-id ${HF_USER}/aloha_test --episode 0
```

View File

@@ -29,14 +29,14 @@ Examples:
- Visualize data stored on a local machine:
```
local$ python -m lerobot.scripts.visualize_dataset \
local$ lerobot-dataset-viz \
--repo-id lerobot/pusht \
--episode-index 0
```
- Visualize data stored on a distant machine with a local viewer:
```
distant$ python -m lerobot.scripts.visualize_dataset \
distant$ lerobot-dataset-viz \
--repo-id lerobot/pusht \
--episode-index 0 \
--save 1 \
@@ -50,7 +50,7 @@ local$ rerun lerobot_pusht_episode_0.rrd
(You need to forward the websocket port to the distant machine, with
`ssh -L 9087:localhost:9087 username@remote-host`)
```
distant$ python -m lerobot.scripts.visualize_dataset \
distant$ lerobot-dataset-viz \
--repo-id lerobot/pusht \
--episode-index 0 \
--mode distant \

View File

@@ -20,10 +20,10 @@ Additionally, each individual transform can be visualized separately as well as
Example:
```bash
python -m lerobot.scripts.visualize_image_transforms \
--repo_id=lerobot/pusht \
--episodes='[0]' \
--image_transforms.enable=True
lerobot-imgtransform-viz \
--repo_id=lerobot/pusht \
--episodes='[0]' \
--image_transforms.enable=True
```
"""
@@ -126,5 +126,9 @@ def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR
save_each_transform(cfg.image_transforms, original_frame, output_dir, n_examples)
if __name__ == "__main__":
def main():
visualize_image_transforms()
if __name__ == "__main__":
main()

View File

@@ -29,7 +29,7 @@ from lerobot.datasets.transforms import (
SharpnessJitter,
make_transform_from_config,
)
from lerobot.scripts.visualize_image_transforms import (
from lerobot.scripts.lerobot_imgtransform_viz import (
save_all_transforms,
save_each_transform,
)

View File

@@ -15,7 +15,7 @@
# limitations under the License.
import pytest
from lerobot.scripts.visualize_dataset import visualize_dataset
from lerobot.scripts.lerobot_dataset_viz import visualize_dataset
@pytest.mark.skip("TODO: add dummy videos")