mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
Compare commits
6 Commits
feat/add-m
...
feat/add-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6eaf6a861a | ||
|
|
cdd6cb606c | ||
|
|
f6cd24be17 | ||
|
|
54c6b8ae52 | ||
|
|
c9787bd98a | ||
|
|
c435d3cebc |
@@ -202,7 +202,7 @@ Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/
|
||||
You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.visualize_dataset \
|
||||
lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0
|
||||
```
|
||||
@@ -210,7 +210,7 @@ python -m lerobot.scripts.visualize_dataset \
|
||||
or from a dataset in a local folder with the `--root` option and the `--local-files-only` flag (in the following case, the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`):
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.visualize_dataset \
|
||||
lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--root ./my_local_data_dir \
|
||||
--local-files-only 1 \
|
||||
@@ -221,7 +221,7 @@ It will open `rerun.io` and display the camera streams, robot states and actions
|
||||
|
||||
https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144
|
||||
|
||||
Our script can also visualize datasets stored on a distant server. See `python -m lerobot.scripts.visualize_dataset --help` for more instructions.
|
||||
Our script can also visualize datasets stored on a remote server. See `lerobot-dataset-viz --help` for more instructions.
|
||||
|
||||
### The `LeRobotDataset` format
|
||||
|
||||
|
||||
378
benchmarks/policies/inference.py
Normal file
378
benchmarks/policies/inference.py
Normal file
@@ -0,0 +1,378 @@
|
||||
"""
|
||||
Benchmark memory footprint and inference latency of a policy on arbitrary devices.
|
||||
|
||||
This script loads a pretrained policy directly (similar to the async inference server)
|
||||
and generates dummy input data based on the policy's input_features to perform
|
||||
accurate benchmarking without requiring datasets.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import signal
|
||||
import statistics
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.policies.factory import get_policy_class
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
class TimeoutException(Exception):
    """Raised when a single inference pass exceeds its time budget.

    Must derive from ``Exception``: it is instantiated with a message,
    raised from a signal handler, and caught with ``except TimeoutException``.
    """
|
||||
|
||||
|
||||
@contextmanager
def timeout(seconds):
    """Run the wrapped block under a wall-clock limit of ``seconds``.

    When the limit expires, ``TimeoutException`` is raised from the
    ``SIGALRM`` handler. ``signal.setitimer`` is used instead of
    ``signal.alarm`` so fractional seconds are supported. On platforms
    without ``SIGALRM`` (e.g. Windows) the block runs with no limit.
    """
    if hasattr(signal, "SIGALRM"):

        def _raise_on_alarm(signum, frame):
            raise TimeoutException(f"Timed out after {seconds} seconds")

        previous_handler = signal.signal(signal.SIGALRM, _raise_on_alarm)
        try:
            signal.setitimer(signal.ITIMER_REAL, seconds)
            yield
        finally:
            # Always disarm the timer and put the previous handler back.
            signal.setitimer(signal.ITIMER_REAL, 0)
            signal.signal(signal.SIGALRM, previous_handler)
    else:
        # No SIGALRM on this platform: this timeout mechanism is unavailable.
        yield
|
||||
|
||||
|
||||
def bytes_to_human(n: int) -> str:
    """Format a byte count as a human-readable string using 1024-based units.

    Negative values (for example a memory delta that shrank) are scaled by
    magnitude and keep their sign.

    Args:
        n: Number of bytes; may be negative.

    Returns:
        The value with two decimals followed by its unit, e.g. ``"1.50 KB"``.
    """
    value = float(n)
    for unit in ("B", "KB", "MB", "GB", "TB"):
        # Compare on magnitude so negative deltas are scaled like positive ones.
        if abs(value) < 1024:
            return f"{value:.2f} {unit}"
        value /= 1024
    return f"{value:.2f} PB"
|
||||
|
||||
|
||||
def percentile(values: list[float], p: float) -> float:
    """Return the ``p``-th percentile of ``values`` by linear interpolation.

    Args:
        values: Samples sorted in ascending order (the function does not sort).
        p: Percentile in the closed range [0, 100].

    Returns:
        The interpolated percentile, or ``nan`` when ``values`` is empty.

    Raises:
        ValueError: If ``p`` is outside [0, 100].
    """
    if not 0 <= p <= 100:
        # Out-of-range p would otherwise index past the end (p > 100) or
        # silently wrap around through negative indices (p < 0).
        raise ValueError(f"p must be in [0, 100], got {p}")
    if not values:
        return float("nan")

    last = len(values) - 1
    rank = last * (p / 100.0)  # fractional position in the sorted list
    lower = int(rank)
    upper = min(lower + 1, last)
    if lower == upper:
        return values[lower]
    return values[lower] + (values[upper] - values[lower]) * (rank - lower)
|
||||
|
||||
|
||||
def generate_dummy_observation(input_features: dict, device: str = "cpu") -> dict:
    """Build a random, batched observation matching a policy's input features.

    Visual features are drawn with ``torch.rand`` (values in [0, 1)); every
    other feature type is drawn with ``torch.randn``. Each tensor gets a
    leading batch dimension of size 1, and an empty ``"task"`` string is
    added for language-conditioned policies.
    """
    samples = {}
    for name, spec in input_features.items():
        # Images are treated as already normalized; state/action/env and
        # unknown feature types all share the standard-normal sampler.
        sampler = torch.rand if spec.type == FeatureType.VISUAL else torch.randn
        samples[name] = sampler(spec.shape, dtype=torch.float32, device=device)

    # Add batch dimension
    observation = {name: tensor.unsqueeze(0) for name, tensor in samples.items()}

    # Add task string for language-conditioned policies
    observation["task"] = ""
    return observation
|
||||
|
||||
|
||||
def main():
    """Benchmark inference latency and memory footprint of a pretrained policy.

    Parses the command line, loads the policy, runs untimed warmup calls and
    then timed ``select_action`` calls on a dummy observation. A summary is
    printed and a text report is written under ``--output-dir``.

    A pass that exceeds ``--timeout`` is counted as a timeout and recorded
    with the timeout value as its latency.
    """
    parser = argparse.ArgumentParser(description="Policy inference benchmark")
    parser.add_argument(
        "--policy-id", type=str, required=True, help="Model ID or local path to pretrained policy"
    )
    parser.add_argument(
        "--policy-type", type=str, required=True, help="Type of policy (smolvla, act, diffusion, etc.)"
    )
    parser.add_argument(
        "--device", type=str, default="mps", choices=["cuda", "cpu", "mps"], help="Device to run on"
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--num-samples", type=int, default=100, help="Number of inference samples to benchmark"
    )
    parser.add_argument("--warmup", type=int, default=10, help="Number of warmup samples (not timed)")
    parser.add_argument(
        "--output-dir", type=str, default="outputs/benchmarks", help="Directory to save benchmark results"
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=0.3,
        help="Timeout for each inference pass in seconds (default: 0.3s = 300ms)",
    )
    args = parser.parse_args()

    # Seed & deterministic-ish setup
    torch.manual_seed(args.seed)
    if args.device == "cuda":
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = False  # leave False to avoid perf cliffs

    # Resolve device availability
    device = args.device.lower()
    if device == "cuda" and not torch.cuda.is_available():
        print("[!] CUDA requested but unavailable. Falling back to CPU.")
        device = "cpu"
    elif device == "mps" and not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()):
        print("[!] MPS requested but unavailable. Falling back to CPU.")
        device = "cpu"

    use_cuda = device == "cuda"

    # Create output directory and log file
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    policy_name = args.policy_id.replace("/", "_").replace("\\", "_")
    log_file = output_dir / f"benchmark_{args.policy_type}_{policy_name}_{device}_{timestamp}.txt"

    # Load policy directly from pretrained (similar to async inference server)
    print(f"Loading policy {args.policy_type} from {args.policy_id}...")
    policy_class = get_policy_class(args.policy_type)
    policy: PreTrainedPolicy = policy_class.from_pretrained(args.policy_id)
    policy.eval()
    policy.to(device)

    print(f"Policy loaded on {device}")
    print(f"Input features: {list(policy.config.input_features.keys())}")
    print(f"Output features: {list(policy.config.output_features.keys())}")

    # Generate dummy observation based on policy input features
    # (generate_dummy_observation already adds the empty "task" entry).
    dummy_observation = generate_dummy_observation(policy.config.input_features, device)

    # Helper to sync for fair timings
    def _sync(dev_=device):
        if dev_ == "cuda" and torch.cuda.is_available():
            torch.cuda.synchronize()
        elif dev_ == "mps" and hasattr(torch, "mps"):
            try:
                torch.mps.synchronize()
            except AttributeError:
                pass  # MPS sync not available in this PyTorch version

    # Warmup (to stabilize kernels/caches)
    print("Warming up...")
    with torch.no_grad():
        policy.reset()
        for _ in range(args.warmup):
            _ = policy.select_action(dummy_observation)
        _sync()

    # Memory footprint before timing
    process = psutil.Process(os.getpid())
    rss_before = process.memory_info().rss
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()

    print(f"Running benchmark: {args.num_samples} samples...")

    timeout_ms = args.timeout * 1000
    timeout_count = 0

    if use_cuda:
        # Use CUDA Events for precise GPU timing
        start_events = []
        end_events = []

        with torch.no_grad():
            for forward in tqdm(range(args.num_samples), desc="Trials"):
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                try:
                    with timeout(args.timeout):
                        start_event.record()
                        _ = policy.select_action(dummy_observation)
                        end_event.record()

                    start_events.append(start_event)
                    end_events.append(end_event)
                except TimeoutException:
                    timeout_count += 1
                    # Add placeholder for timeout
                    start_events.append(None)
                    end_events.append(None)
                    print(f"\n[!] Timeout on forward {forward + 1}")
                    continue

        # Synchronize and collect timing results
        torch.cuda.synchronize()
        per_forward_ms = []
        for start_event, end_event in zip(start_events, end_events, strict=True):
            if start_event is None:
                per_forward_ms.append(timeout_ms)
            else:
                per_forward_ms.append(start_event.elapsed_time(end_event))
    else:
        # Use simple time.perf_counter for CPU/MPS timing with timeout
        import time

        per_forward_ms = []

        with torch.no_grad():
            for sample in tqdm(range(args.num_samples), desc="Samples"):
                try:
                    with timeout(args.timeout):
                        # Drain pending device work so it is not billed to this sample.
                        _sync()
                        start_time = time.perf_counter()
                        _ = policy.select_action(dummy_observation)
                        # Wait for the device to finish before stopping the clock;
                        # this is a no-op on CPU.
                        _sync()
                        end_time = time.perf_counter()

                    per_forward_ms.append((end_time - start_time) * 1000)  # Convert to ms
                except TimeoutException:
                    timeout_count += 1
                    per_forward_ms.append(timeout_ms)
                    print(f"\n[!] Timeout on sample {sample + 1}")
                    continue

    if timeout_count > 0:
        print(f"[!] {timeout_count} inference passes timed out (>{timeout_ms:.1f}ms)")

    # Memory footprint after timing
    rss_after = process.memory_info().rss
    rss_delta = rss_after - rss_before
    cuda_peak = torch.cuda.max_memory_allocated() if use_cuda else 0

    # Sort timing results for percentile calculations
    per_forward_ms_sorted = sorted(per_forward_ms)

    mean_ms = statistics.fmean(per_forward_ms) if per_forward_ms else float("nan")
    std_ms = statistics.pstdev(per_forward_ms) if len(per_forward_ms) > 1 else 0.0
    min_ms = per_forward_ms_sorted[0] if per_forward_ms_sorted else float("nan")
    max_ms = per_forward_ms_sorted[-1] if per_forward_ms_sorted else float("nan")
    p50_ms = percentile(per_forward_ms_sorted, 50)
    p95_ms = percentile(per_forward_ms_sorted, 95)

    # Model size
    num_params = sum(p.numel() for p in policy.parameters())

    # Prepare results for logging
    results = {
        "timestamp": datetime.now().isoformat(),
        "policy_type": args.policy_type,
        "policy_id": args.policy_id,
        "device": device,
        "num_trials": args.num_samples,
        "forwards_per_trial": 1,
        "warmup": args.warmup,
        "timeout_ms": timeout_ms,
        "seed": args.seed,
        "num_params": num_params,
        "timeout_count": timeout_count,
        "latency_mean_ms": mean_ms,
        "latency_std_ms": std_ms,
        "latency_min_ms": min_ms,
        "latency_max_ms": max_ms,
        "latency_p50_ms": p50_ms,
        "latency_p95_ms": p95_ms,
        "cpu_rss_before": rss_before,
        "cpu_rss_after": rss_after,
        "cpu_rss_delta": rss_delta,
        "cuda_peak_alloc": cuda_peak,
        "input_features": list(policy.config.input_features.keys()),
        "output_features": list(policy.config.output_features.keys()),
    }

    # Values precomputed for the report to keep the template readable
    input_feature_names = ", ".join(results["input_features"])
    output_feature_names = ", ".join(results["output_features"])
    raw_preview = ", ".join(f"{t:.3f}" for t in per_forward_ms[:20])
    raw_ellipsis = "..." if len(per_forward_ms) > 20 else ""
    # Report the clock that was actually used for this run.
    timing_method = "CUDA Events" if use_cuda else "time.perf_counter"
    if device == "cuda":
        device_available = torch.cuda.is_available()
    elif device == "mps":
        device_available = torch.backends.mps.is_available()
    else:
        device_available = True

    # Format and write results to log file
    log_content = f"""
=== LeRobot Policy Inference Benchmark ===
Timestamp: {results['timestamp']}
Policy: {results['policy_type']} ({results['policy_id']})
Device: {results['device']}
Seed: {results['seed']}

=== Model Information ===
Parameters: {results['num_params']:,}
Input Features: {input_feature_names}
Output Features: {output_feature_names}

=== Benchmark Configuration ===
Samples: {results['num_trials']}
Warmup: {results['warmup']}
Total Measurements: {len(per_forward_ms)}
Timeout: {results['timeout_ms']:.1f}ms
Timeouts: {results['timeout_count']} / {results['num_trials']}

=== Latency Results (ms) ===
Mean: {results['latency_mean_ms']:.3f}
Std Dev: {results['latency_std_ms']:.3f}
Min: {results['latency_min_ms']:.3f}
Max: {results['latency_max_ms']:.3f}
P50: {results['latency_p50_ms']:.3f}
P95: {results['latency_p95_ms']:.3f}

=== Memory Footprint ===
CPU RSS Before: {bytes_to_human(results['cpu_rss_before'])}
CPU RSS After: {bytes_to_human(results['cpu_rss_after'])} (Δ {bytes_to_human(results['cpu_rss_delta'])})
"""

    if use_cuda:
        log_content += f"CUDA Peak: {bytes_to_human(results['cuda_peak_alloc'])} (reset before timing)\n"

    log_content += f"""
=== Raw Timing Data (first 20 measurements, ms) ===
{raw_preview}
{raw_ellipsis}

=== Summary Statistics ===
Timing Method: {timing_method}
Device Available: {device_available}
PyTorch Version: {torch.__version__}

Benchmark completed successfully at {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
"""

    # Write to log file. The report contains non-ASCII text ("Δ"), so the
    # encoding is pinned instead of relying on the platform default.
    log_file.write_text(log_content, encoding="utf-8")

    # Print to console (shorter version)
    print("\n=== Inference Benchmark Results ===")
    print(f"Policy: {args.policy_type} ({args.policy_id})")
    print(f"Device: {device}")
    print(f"Samples: {args.num_samples} | Warmup: {args.warmup}")
    print(f"Model params: {num_params:,}")

    print("\nLatency per forward (ms):")
    print(f" mean: {mean_ms:.3f} std: {std_ms:.3f}")
    print(f" min: {min_ms:.3f} max: {max_ms:.3f}")
    print(f" p50: {p50_ms:.3f} p95: {p95_ms:.3f}")

    print("\nMemory footprint:")
    print(f" CPU RSS before: {bytes_to_human(rss_before)}")
    print(f" CPU RSS after : {bytes_to_human(rss_after)} (Δ {bytes_to_human(rss_delta)})")
    if use_cuda:
        print(
            f" CUDA peak allocated: {bytes_to_human(cuda_peak)} "
            f"(reset by reset_peak_memory_stats before timing)"
        )

    print(f"\nResults saved to: {log_file}")
    print("Benchmark completed successfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -246,7 +246,7 @@ You can also use any `torchvision.transforms.v2` transform by passing it directl
|
||||
Use the visualization script to preview how transforms affect your data:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.visualize_image_transforms \
|
||||
lerobot-imgtransform-viz \
|
||||
--repo-id=your-username/your-dataset \
|
||||
--output-dir=./transform_examples \
|
||||
--n-examples=5
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
"""Smoke test for ``MultiLeRobotDataset`` mixing two datasets with different layouts.

Builds a combined dataset from a single-camera 2-dim dataset and a
three-camera 14-dim dataset, pulls a number of batches through a
``DataLoader`` and checks that vectors and camera keys in the last batch
have the expected unified layout.
"""

import itertools

from torch.utils.data import DataLoader
from torchvision.transforms.v2 import Compose, Resize, ToImage

from lerobot.datasets.lerobot_dataset import MultiLeRobotDataset

REPO_A = "lerobot/pusht"
REPO_B = "lerobot/aloha_mobile_cabinet"  # replace with the actual repo id

# Number of batches pulled through the loader before checking shapes.
NUM_BATCHES = 100

# Per-repo mapping from original feature keys to the shared canonical keys.
FEATURE_KEYS_MAPPING = {
    REPO_A: {  # pusht (1 camera, 2-dim)
        "action": "actions",
        "observation.state": "obs_state",
        "observation.image": "obs_image.cam_high",
    },
    REPO_B: {  # dual arm (3 cameras, 14-dim)
        "action": "actions",
        "observation.state": "obs_state",
        "observation.images.cam_high": "obs_image.cam_high",
        "observation.images.cam_left_wrist": "obs_image.cam_left_wrist",
        "observation.images.cam_right_wrist": "obs_image.cam_right_wrist",
    },
}

CAMERA_KEYS = ["obs_image.cam_high", "obs_image.cam_left_wrist", "obs_image.cam_right_wrist"]


def main():
    """Build the multi-dataset, load batches and assert the unified layout."""
    image_tf = Compose(
        [
            ToImage(),  # converts to tensor if needed
            Resize((224, 224)),  # unify sizes across datasets (96x96 vs 480x640)
        ]
    )

    dataset = MultiLeRobotDataset(
        repo_ids=[REPO_A, REPO_B],
        image_transforms=image_tf,  # ensures same HxW
        feature_keys_mapping=FEATURE_KEYS_MAPPING,
        train_on_all_features=True,  # keep union of cameras; zero-fill missing
        # fixed maxima for the padded action/state vectors
        max_action_dim=14,
        max_state_dim=14,
        max_image_dim=224,
        ignore_keys=[
            "next.*",  # drop reward/done/success
            "index",
            "timestamp",
            "videos/*",  # drop all video metadata
            "observation.effort",  # drop effort everywhere
        ],
    )

    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0, pin_memory=True)

    # Iterate one loader pass instead of rebuilding the iterator for every batch.
    batch = None
    for batch in itertools.islice(loader, NUM_BATCHES):
        pass
    if batch is None:
        raise RuntimeError("DataLoader produced no batches.")

    # vectors padded to maxima (pusht:2 -> 14; dual-arm:14 -> 14)
    assert batch["actions"].shape[-1] == 14
    assert batch["obs_state"].shape[-1] == 14
    assert batch["actions_padding_mask"].shape[-1] == 14
    assert batch["obs_state_padding_mask"].shape[-1] == 14

    # cameras: all canonical keys exist; pusht will have wrists zero-filled
    for cam in CAMERA_KEYS:
        assert cam in batch
        assert f"{cam}_is_pad" in batch
        # images should all be 3x224x224 (or your transform's size)
        img = batch[cam]
        assert img.ndim in (4, 5)  # (B,C,H,W) or (B,T,C,H,W) depending on your loader


if __name__ == "__main__":
    main()
|
||||
@@ -1,16 +0,0 @@
|
||||
# storage / caches
# RAID can be overridden from the environment to point at another volume.
RAID="${RAID:-/raid/jade}"
export TRANSFORMERS_CACHE="$RAID/.cache/huggingface/transformers"
export HF_HOME="$RAID/.cache/huggingface"
export HF_DATASETS_CACHE="$RAID/.cache/huggingface/datasets"
export HF_LEROBOT_HOME="$RAID/.cache/huggingface/lerobot"
export WANDB_CACHE_DIR="$RAID/.cache/wandb"
export TMPDIR="$RAID/.cache/tmp"
# Quoted so paths containing spaces are not split into several arguments.
mkdir -p "$TMPDIR"
export WANDB_MODE=offline
# export HF_DATASETS_OFFLINE=1
# export HF_HUB_OFFLINE=1
export TOKENIZERS_PARALLELISM=false
export MUJOCO_GL=egl

python examples/tester.py
|
||||
@@ -171,7 +171,9 @@ lerobot-setup-motors="lerobot.setup_motors:main"
|
||||
lerobot-teleoperate="lerobot.teleoperate:main"
|
||||
lerobot-eval="lerobot.scripts.eval:main"
|
||||
lerobot-train="lerobot.scripts.train:main"
|
||||
lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main"
|
||||
lerobot-info="lerobot.scripts.lerobot_info:main"
|
||||
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.packages.find]
|
||||
|
||||
@@ -174,79 +174,3 @@ def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np
|
||||
aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
|
||||
|
||||
return aggregated_stats
|
||||
|
||||
import numpy as np
|
||||
|
||||
def aggregate_stats_multi(
|
||||
stats_list: list[dict[str, dict]],
|
||||
max_action_dim: int | None = None,
|
||||
max_state_dim: int | None = None,
|
||||
) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
|
||||
|
||||
Supports heterogeneous robots by padding action/state stats to the max dim.
|
||||
The final stats will have the union of all data keys from each of the stats dicts.
|
||||
|
||||
- new_min = elementwise min across datasets
|
||||
- new_max = elementwise max across datasets
|
||||
- new_mean = weighted mean (by count)
|
||||
- new_std = recomputed from total variance
|
||||
"""
|
||||
|
||||
data_keys = {key for stats in stats_list for key in stats}
|
||||
aggregated_stats = {key: {} for key in data_keys}
|
||||
|
||||
def _pad(arr: np.ndarray, target: int) -> np.ndarray:
|
||||
if arr.ndim == 0: # scalar
|
||||
return arr
|
||||
if target is None or target <= 0 or arr.shape[-1] == target:
|
||||
return arr
|
||||
pad_width = [(0, 0)] * arr.ndim
|
||||
pad_width[-1] = (0, target - arr.shape[-1])
|
||||
return np.pad(arr, pad_width, mode="constant")
|
||||
|
||||
for key in data_keys:
|
||||
stats_with_key = [stats[key] for stats in stats_list if key in stats]
|
||||
|
||||
# decide if this key should be padded
|
||||
target_dim = None
|
||||
if "action" in key and max_action_dim:
|
||||
target_dim = max_action_dim
|
||||
elif "state" in key and max_state_dim:
|
||||
target_dim = max_state_dim
|
||||
|
||||
padded = []
|
||||
counts = []
|
||||
for s in stats_with_key:
|
||||
mean = _pad(np.array(s["mean"]), target_dim)
|
||||
std = _pad(np.array(s["std"]), target_dim)
|
||||
min_ = _pad(np.array(s["min"]), target_dim)
|
||||
max_ = _pad(np.array(s["max"]), target_dim)
|
||||
count = s.get("count", 1)
|
||||
|
||||
padded.append(dict(mean=mean, std=std, min=min_, max=max_, count=count))
|
||||
counts.append(count)
|
||||
|
||||
counts = np.array(counts, dtype=np.float64)
|
||||
total_count = counts.sum()
|
||||
|
||||
means = np.stack([p["mean"] for p in padded])
|
||||
stds = np.stack([p["std"] for p in padded])
|
||||
mins = np.stack([p["min"] for p in padded])
|
||||
maxs = np.stack([p["max"] for p in padded])
|
||||
|
||||
# weighted mean (broadcast weights properly)
|
||||
new_mean = np.average(means, axis=0, weights=counts)
|
||||
new_var = np.average(stds**2 + (means - new_mean)**2, axis=0, weights=counts)
|
||||
|
||||
new_std = np.sqrt(new_var)
|
||||
|
||||
aggregated_stats[key] = {
|
||||
"min": mins.min(axis=0),
|
||||
"max": maxs.max(axis=0),
|
||||
"mean": new_mean,
|
||||
"std": new_std,
|
||||
"count": int(total_count),
|
||||
}
|
||||
|
||||
return aggregated_stats
|
||||
|
||||
@@ -31,7 +31,6 @@ import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from collections import defaultdict
|
||||
from lerobot.constants import HF_LEROBOT_HOME
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
|
||||
@@ -82,12 +81,7 @@ from lerobot.datasets.video_utils import (
|
||||
)
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGE_2 = "observation.image_2"
|
||||
OBS_IMAGE_3 = "observation.image_3"
|
||||
OBS_STATE = "observation.state"
|
||||
OBS_ENV_STATE = "observation.env_state"
|
||||
ACTION = "action"
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
def __init__(
|
||||
@@ -1328,139 +1322,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
return obj
|
||||
|
||||
ROBOT_TYPE_KEYS_MAPPING = {
|
||||
"lerobot/stanford_hydra_dataset": "static_single_arm",
|
||||
"lerobot/iamlab_cmu_pickup_insert": "static_single_arm",
|
||||
"lerobot/berkeley_fanuc_manipulation": "static_single_arm",
|
||||
"lerobot/toto": "static_single_arm",
|
||||
"lerobot/roboturk": "static_single_arm",
|
||||
"lerobot/jaco_play": "static_single_arm",
|
||||
"lerobot/taco_play": "static_single_arm_7statedim",
|
||||
}
|
||||
class MultiLeRobotDatasetMeta:
    """Merged metadata for several ``LeRobotDataset`` instances.

    Exposes the combined feature schema (``features``), per-robot-type stats
    (``stats``), and per-repo ``episodes``, ``tasks`` and ``info``.
    """

    def __init__(
        self,
        datasets: list[LeRobotDataset],
        repo_ids: list[str],
        keys_to_max_dim: dict[str, int],
        train_on_all_features: bool = False,
    ):
        """Build the merged metadata.

        Args:
            datasets: Underlying datasets; their ``meta.info["robot_type"]`` and
                ``robot_type`` are updated in place from ``ROBOT_TYPE_KEYS_MAPPING``.
            repo_ids: Repo ids matching ``datasets`` position by position.
            keys_to_max_dim: Target size per feature key, used to reshape the
                feature schema and to pad action/state stats.
            train_on_all_features: When False, features not shared by every
                dataset are disabled; when True, the union is kept.

        Raises:
            RuntimeError: If features are restricted to the common set and the
                datasets share none.
        """
        self.repo_ids = repo_ids
        self.keys_to_max_dim = keys_to_max_dim
        self.train_on_all_features = train_on_all_features

        # Apply robot_type overrides first so that robot_types (below) reflects
        # the same values that are written onto the datasets.
        for ds in datasets:
            ds.meta.info["robot_type"] = ROBOT_TYPE_KEYS_MAPPING.get(ds.repo_id, ds.meta.info["robot_type"])
            ds.robot_type = ds.meta.info["robot_type"]
        self.robot_types = [ds.meta.info["robot_type"] for ds in datasets]

        # step 1: compute disabled features
        self.disabled_features = set()
        if not self.train_on_all_features:
            intersection = set(datasets[0].features)
            for ds in datasets:
                intersection.intersection_update(ds.features)
            if not intersection:
                raise RuntimeError("No common features across datasets.")
            for repo_id, ds in zip(repo_ids, datasets, strict=False):
                extra = set(ds.features) - intersection
                if extra:  # only warn when something is actually dropped
                    logging.warning(f"Disabling {extra} for repo {repo_id}")
                    self.disabled_features.update(extra)

        # step 2: build union_features excluding disabled
        # (for a key present in several datasets, the last dataset's spec wins)
        self.union_features = {}
        for ds in datasets:
            for k, v in ds.features.items():
                if k not in self.disabled_features:
                    self.union_features[k] = v

        # step 3: reshape feature schema
        self.features = reshape_features_to_max_dim(
            self.union_features, reshape_dim=-1, keys_to_max_dim=self.keys_to_max_dim
        )

        # step 4: aggregate stats, then pad action/state stats on the last axis
        self.stats = aggregate_stats_per_robot_type(datasets)
        padded_keys = (ACTION, OBS_ENV_STATE, OBS_STATE)
        for robot_stats in self.stats.values():
            for feat_key, feat_stats in robot_stats.items():
                if feat_key not in padded_keys:
                    continue
                max_size = self.keys_to_max_dim.get(feat_key, -1)
                for stat_name, value in feat_stats.items():
                    # min/mean are padded with 0, every other stat with 1
                    pad_value = 0 if stat_name in ("min", "mean") else 1
                    feat_stats[stat_name] = pad_tensor(
                        value,
                        max_size=max_size,
                        pad_dim=-1,
                        pad_value=pad_value,
                    )

        # step 5: episodes & tasks
        self.episodes = {repo_id: ds.meta.episodes for repo_id, ds in zip(repo_ids, datasets, strict=False)}
        self.tasks = {repo_id: ds.meta.tasks for repo_id, ds in zip(repo_ids, datasets, strict=False)}
        self.info = {repo_id: ds.meta.info for repo_id, ds in zip(repo_ids, datasets, strict=False)}
|
||||
|
||||
|
||||
class MultiLeRobotDatasetCleaner:
    """Filter datasets down to those with consistent features per robot type.

    Keeps the unfiltered inputs as ``original_*`` attributes and exposes the
    retained datasets, repo ids and sampling weights as ``cleaned_*``, along
    with ``keep_mask`` and ``cumulative_sizes``.
    """

    def __init__(
        self,
        datasets: list[LeRobotDataset],
        repo_ids: list[str],
        sampling_weights: list[float],
        datasets_repo_ids: list[str],
        min_fps: int = 1,
        max_fps: int = 100,
    ):
        # Remember the inputs exactly as given.
        self.original_datasets = datasets
        self.original_repo_ids = repo_ids
        self.original_weights = sampling_weights
        self.original_datasets_repo_ids = datasets_repo_ids

        # step 1: remove datasets with invalid fps
        # NOTE(review): no fps filtering happens here; min_fps and max_fps are
        # accepted but unused in this block — confirm whether that is intended.

        # step 2: keep datasets with same features per robot type
        kept_datasets, mask = keep_datasets_with_the_same_features_per_robot_type(datasets)
        kept_indices = [idx for idx in range(len(datasets)) if mask[idx]]

        self.cleaned_datasets = kept_datasets
        self.keep_mask = mask
        weights_of_kept = [sampling_weights[idx] for idx in kept_indices]
        self.cleaned_repo_ids = [repo_ids[idx] for idx in kept_indices]
        self.cleaned_datasets_repo_ids = [datasets_repo_ids[idx] for idx in kept_indices]

        # Running totals of dataset lengths, with a leading 0.
        lengths = torch.tensor([len(d) for d in kept_datasets])
        self.cumulative_sizes = np.array([0] + list(torch.cumsum(lengths, dim=0)))
        self.cleaned_weights = np.array(weights_of_kept, dtype=np.float32)
|
||||
|
||||
# --- at the top of the file (same imports as before) ---
|
||||
from collections import defaultdict
|
||||
from typing import Callable
|
||||
import copy
|
||||
import numpy as np
|
||||
import torch
|
||||
import datasets
|
||||
from pathlib import Path
|
||||
|
||||
# If you already have these in your codebase, reuse them
|
||||
try:
|
||||
from lerobot.common.constants import (
|
||||
ACTION, OBS_ENV_STATE, OBS_STATE, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3
|
||||
)
|
||||
except Exception:
|
||||
# Fallbacks if constants are already strings elsewhere
|
||||
ACTION = "action"
|
||||
OBS_ENV_STATE = "observation.env_state"
|
||||
OBS_STATE = "observation.state"
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGE_2 = "observation.image_2"
|
||||
OBS_IMAGE_3 = "observation.image_3"
|
||||
|
||||
IGNORED_KEYS = ["observation.effort"]
|
||||
class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
# ... keep your existing docstring ...
|
||||
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
|
||||
|
||||
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
|
||||
structure of `LeRobotDataset`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1468,253 +1336,99 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
root: str | Path | None = None,
|
||||
episodes: dict | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
delta_timestamps: dict[str, list[float]] | None = None,
|
||||
tolerances_s: dict | None = None,
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
# --- NEW: simple add-ons ---
|
||||
sampling_weights: list[float] | None = None,
|
||||
feature_keys_mapping: dict[str, dict[str, str]] | None = None,
|
||||
max_action_dim: int | None = None,
|
||||
max_state_dim: int | None = None,
|
||||
max_num_images: int | None = None,
|
||||
max_image_dim: int | None = None,
|
||||
train_on_all_features: bool = False,
|
||||
min_fps: int = 1,
|
||||
max_fps: int = 100,
|
||||
ignore_keys: list[str] | None = None, # exact or glob patterns
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME
|
||||
self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001)
|
||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||
# are handled by this class.
|
||||
self._datasets = [
|
||||
LeRobotDataset(
|
||||
repo_id,
|
||||
root=self.root / repo_id,
|
||||
episodes=episodes[repo_id] if episodes else None,
|
||||
image_transforms=image_transforms,
|
||||
delta_timestamps=delta_timestamps,
|
||||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
video_backend=video_backend,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
]
|
||||
|
||||
# --- NEW: store mapping and simple knobs ---
|
||||
self.feature_keys_mapping: dict[str, dict[str, str]] = feature_keys_mapping or {}
|
||||
self.train_on_all_features = train_on_all_features
|
||||
self.max_action_dim = max_action_dim
|
||||
self.max_state_dim = max_state_dim
|
||||
self.max_image_dim = max_image_dim
|
||||
self.max_num_images = max_num_images # (optional, we don’t enforce count, we enforce names)
|
||||
self._ignore_patterns = list(ignore_keys or [])
|
||||
# Build underlying single datasets
|
||||
_datasets = []
|
||||
datasets_repo_ids = []
|
||||
self.sampling_weights = []
|
||||
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
|
||||
# restriction in future iterations of this class. For now, this is necessary at least for being able
|
||||
# to use PyTorch's default DataLoader collate function.
|
||||
self.disabled_features = set()
|
||||
intersection_features = set(self._datasets[0].features)
|
||||
for ds in self._datasets:
|
||||
intersection_features.intersection_update(ds.features)
|
||||
if len(intersection_features) == 0:
|
||||
raise RuntimeError(
|
||||
"Multiple datasets were provided but they had no keys common to all of them. "
|
||||
"The multi-dataset functionality currently only keeps common keys."
|
||||
)
|
||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
||||
extra_keys = set(ds.features).difference(intersection_features)
|
||||
logging.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
|
||||
sampling_weights = sampling_weights if sampling_weights is not None else [1] * len(repo_ids)
|
||||
assert len(sampling_weights) == len(repo_ids), (
|
||||
"The number of sampling weights must match the number of datasets. "
|
||||
f"Got {len(sampling_weights)} weights for {len(repo_ids)} datasets."
|
||||
)
|
||||
for i, repo_id in enumerate(repo_ids):
|
||||
try:
|
||||
_datasets.append(
|
||||
LeRobotDataset(
|
||||
repo_id,
|
||||
root=self.root / repo_id,
|
||||
episodes=episodes.get(repo_id, None) if episodes else None,
|
||||
image_transforms=image_transforms, # transforms applied inside single ds
|
||||
delta_timestamps=delta_timestamps.get(repo_id, None) if delta_timestamps else None,
|
||||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
video_backend=video_backend,
|
||||
)
|
||||
)
|
||||
datasets_repo_ids.append(repo_id)
|
||||
self.sampling_weights.append(float(sampling_weights[i]))
|
||||
except Exception as e:
|
||||
print(f"Failed to load dataset: {repo_id} due to Exception: {e}")
|
||||
|
||||
print(
|
||||
f"Finish loading {len(_datasets)} datasets, with sampling weights: "
|
||||
f"{self.sampling_weights} corresponding to: {datasets_repo_ids}"
|
||||
)
|
||||
|
||||
# Bookkeeping for mapping & canonical image inventory
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps.get(repo_id, None) if delta_timestamps else None
|
||||
self._datasets = _datasets
|
||||
self.datasets_repo_ids = datasets_repo_ids
|
||||
|
||||
# --- NEW: compute “canonical image keys” (targets across all mappings) ---
|
||||
self._canonical_image_keys: set[str] = set()
|
||||
self._source_keys_per_repo: dict[str, set[str]] = {}
|
||||
self._target_keys_per_repo: dict[str, set[str]] = {}
|
||||
for rid, mapping in self.feature_keys_mapping.items():
|
||||
src_keys = set(mapping.keys())
|
||||
tgt_keys = set(mapping.values())
|
||||
self._source_keys_per_repo[rid] = src_keys
|
||||
self._target_keys_per_repo[rid] = tgt_keys
|
||||
# union of target names (we will ensure these exist at __getitem__)
|
||||
self._canonical_image_keys |= {
|
||||
k for k in tgt_keys if self._is_image_key_like(k)
|
||||
}
|
||||
|
||||
# If user didn’t give any mapping, fall back to native keys (no-ops)
|
||||
if not self._canonical_image_keys and self.train_on_all_features:
|
||||
# discover all image-like keys from raw features
|
||||
for ds in self._datasets:
|
||||
for k, v in ds.hf_features.items():
|
||||
if isinstance(v, (datasets.Image, VideoFrame)):
|
||||
self._canonical_image_keys.add(k)
|
||||
|
||||
# Cleaner: keep fps & consistent feature sets per robot type (unchanged)
|
||||
cleaner = MultiLeRobotDatasetCleaner(
|
||||
datasets=self._datasets,
|
||||
repo_ids=repo_ids,
|
||||
sampling_weights=self.sampling_weights,
|
||||
datasets_repo_ids=self.datasets_repo_ids,
|
||||
min_fps=min_fps,
|
||||
max_fps=max_fps,
|
||||
)
|
||||
self._datasets = cleaner.cleaned_datasets
|
||||
self.sampling_weights = cleaner.cleaned_weights
|
||||
self.repo_ids = cleaner.cleaned_repo_ids
|
||||
self.datasets_repo_ids = cleaner.cleaned_datasets_repo_ids
|
||||
self.cumulative_sizes = cleaner.cumulative_sizes
|
||||
|
||||
# Meta (unchanged): we give it dim maxima; it will reshape/pad vectors
|
||||
self.meta = MultiLeRobotDatasetMeta(
|
||||
datasets=self._datasets,
|
||||
repo_ids=self.repo_ids,
|
||||
keys_to_max_dim={
|
||||
ACTION: self.max_action_dim if self.max_action_dim is not None else -1,
|
||||
OBS_ENV_STATE: self.max_state_dim if self.max_state_dim is not None else -1,
|
||||
OBS_STATE: self.max_state_dim if self.max_state_dim is not None else -1,
|
||||
OBS_IMAGE: self.max_image_dim if self.max_image_dim is not None else -1,
|
||||
OBS_IMAGE_2: self.max_image_dim if self.max_image_dim is not None else -1,
|
||||
OBS_IMAGE_3: self.max_image_dim if self.max_image_dim is not None else -1,
|
||||
},
|
||||
train_on_all_features=train_on_all_features,
|
||||
)
|
||||
|
||||
# --- NEW: track dropped (source) keys so collate won’t expect them
|
||||
# Anything that we *rename away* should be considered disabled,
|
||||
# otherwise downstream may expect them to exist.
|
||||
self._dropped_keys = set()
|
||||
for rid, mapping in self.feature_keys_mapping.items():
|
||||
self._dropped_keys |= set(mapping.keys())
|
||||
|
||||
# Merge with meta’s disabled features
|
||||
self.disabled_features = set(self.meta.disabled_features) | self._dropped_keys
|
||||
|
||||
self.stats = self.meta.stats
|
||||
|
||||
# --- NEW: cache an example image shape per canonical key (lazy, filled on first use)
|
||||
self._cached_img_shape: dict[str, torch.Size] = {}
|
||||
|
||||
# ---------------------- NEW small helpers ----------------------
|
||||
|
||||
def _is_image_key_like(self, key: str) -> bool:
    """Guess, from its name alone, whether ``key`` designates an image stream.

    This is a pure substring heuristic: the key qualifies as soon as it contains
    ``"image"``, ``"cam_"`` or ``"images."``. No feature metadata is consulted.
    """
    image_markers = ("image", "cam_", "images.")
    return any(marker in key for marker in image_markers)
|
||||
|
||||
def _should_ignore(self, key: str) -> bool:
    """Tell whether ``key`` is covered by one of the configured ignore patterns.

    A pattern from ``self._ignore_patterns`` matches either when it is equal to
    ``key`` or when ``fnmatch.fnmatch`` accepts ``key`` against it as a glob.
    """
    return any(
        key == pattern or fnmatch.fnmatch(key, pattern)
        for pattern in self._ignore_patterns
    )
|
||||
def _apply_feature_mapping(self, item: dict, repo_id: str) -> dict:
    """Rename the keys of ``item`` according to ``feature_keys_mapping[repo_id]``.

    Every source key that is present in ``item`` is removed and its value is
    stored under the mapped target key; source keys absent from ``item`` are
    ignored. All renames are applied as one simultaneous step, so an identity
    entry (``{"a": "a"}``) keeps its value, and overlapping entries (a swap, or
    a chain such as ``{"a": "b", "b": "c"}``) do not clobber one another.

    Args:
        item: Sample dictionary. It is modified in place.
        repo_id: Key used to look up the mapping of the originating dataset.

    Returns:
        The same ``item`` object, with its keys renamed.
    """
    mapping = self.feature_keys_mapping.get(repo_id) or {}
    if not mapping:
        return item

    # Detach every source value first and only then attach the targets:
    # assigning and deleting one pair at a time would drop the value whenever
    # a target name is also (still) a source name.
    renamed = {tgt: item.pop(src) for src, tgt in mapping.items() if src in item}
    item.update(renamed)
    return item
|
||||
|
||||
def _ensure_union_image_keys(self, item: dict) -> dict:
    """Make sure every canonical image key is present in ``item``.

    Nothing is done unless ``self.train_on_all_features`` is set and
    ``self._canonical_image_keys`` is non-empty. Otherwise, for each canonical key:

    - if the key is missing, a zero tensor is inserted and
      ``f"{key}_is_pad"`` is set to ``True``;
    - if the key is present and has no ``f"{key}_is_pad"`` entry yet, that
      entry is set to ``False``.

    The zero tensor copies shape and dtype from a template tensor of rank 3, 4
    or 5 found in ``item``. Tensors stored under canonical image keys are
    preferred as template so that an unrelated tensor of the same rank cannot
    dictate the placeholder's shape; any other tensor of suitable rank is used
    only as a fallback. Without any template a ``uint8`` tensor of shape
    ``(3, 224, 224)`` is used.

    Args:
        item: Sample dictionary. It is modified in place.

    Returns:
        The same ``item`` object.
    """
    if not self.train_on_all_features or not self._canonical_image_keys:
        return item

    def _has_image_rank(value) -> bool:
        # Rank 3/4/5 is the only criterion the placeholder logic relies on.
        return torch.is_tensor(value) and value.ndim in (3, 4, 5)

    # First choice: a tensor already stored under a canonical image key
    # (sorted so the choice does not depend on set iteration order).
    template = next(
        (
            item[name]
            for name in sorted(self._canonical_image_keys)
            if name in item and _has_image_rank(item[name])
        ),
        None,
    )
    if template is None:
        # Fallback: any tensor of suitable rank, in item order.
        template = next((value for value in item.values() if _has_image_rank(value)), None)

    for key in self._canonical_image_keys:
        mask_key = f"{key}_is_pad"
        if key not in item:
            if template is not None:
                item[key] = torch.zeros_like(template)
            else:
                item[key] = torch.zeros(3, 224, 224, dtype=torch.uint8)
            item[mask_key] = torch.tensor(True, dtype=torch.bool)
        elif mask_key not in item:
            # The image is real data: flag it as *not* padded.
            item[mask_key] = torch.tensor(False, dtype=torch.bool)
    return item
|
||||
|
||||
# ---------------------- existing API below (mostly unchanged) ----------------------
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
|
||||
# with multiple robots of different ranges. Instead we should have one normalization
|
||||
# per robot.
|
||||
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
|
||||
|
||||
@property
def repo_id_to_index(self):
    """Map each dataset repo_id to the integer index this class assigns to it.

    The index of a repo_id is its position in ``self.repo_ids``. This index is
    incorporated as a data key in the dictionary returned by `__getitem__`.
    """
    index_by_repo_id = {}
    for position, repo_id in enumerate(self.repo_ids):
        index_by_repo_id[repo_id] = position
    return index_by_repo_id
|
||||
|
||||
@property
def repo_index_to_id(self):
    """Return the inverse mapping of `repo_id_to_index` (dataset index -> repo_id)."""
    # `.items()` is required here: iterating the dict itself yields only its
    # keys, which cannot be unpacked into (repo_id, index) pairs.
    return {index: repo_id for repo_id, index in self.repo_id_to_index.items()}
|
||||
|
||||
@property
def fps(self) -> int:
    """Frames per second, as recorded in the first sub-dataset's metadata.

    NOTE: Only the first sub-dataset is consulted, so the value describes the
    whole collection only if all sub-datasets share the same fps.
    """
    first_dataset = self._datasets[0]
    return first_dataset.meta.info["fps"]
|
||||
|
||||
@property
def video(self) -> bool:
    """Value of the ``"video"`` entry in the first sub-dataset's info, ``False`` if absent.

    A true value is meant to indicate that frames are loaded from mp4 video
    files, a false one that only png images are loaded.

    NOTE: Only the first sub-dataset is consulted, so the value describes the
    whole collection only if all sub-datasets share the same info.
    """
    info = self._datasets[0].meta.info
    return info.get("video", False)
|
||||
|
||||
@property
def features(self) -> datasets.Features:
    """Collect the enabled feature specs of all sub-datasets, plus mapped target keys.

    Every key of each sub-dataset's ``hf_features`` that is not listed in
    ``self.disabled_features`` is kept; when several sub-datasets define the
    same key, the spec of the last one wins. Then, for every
    ``source -> target`` entry of ``self.feature_keys_mapping`` whose target is
    not known yet, the spec of the source key is copied under the target name.

    Returns:
        A plain ``dict`` from feature name to feature spec.
        NOTE(review): the return annotation says ``datasets.Features`` but no
        such object is built here - confirm what callers expect.
    """
    features = {}
    for dataset in self._datasets:
        for k, v in dataset.hf_features.items():
            if k not in self.disabled_features:
                features[k] = v

    # Add mapped target specs that no sub-dataset provides natively.
    for rid, mapping in self.feature_keys_mapping.items():
        # First sub-dataset paired with this repo id, if any.
        # NOTE(review): assumes `self.repo_ids` is index-aligned with
        # `self._datasets` - confirm against the constructor.
        ds = next(
            (_ds for _ds, _rid in zip(self._datasets, self.repo_ids, strict=False) if _rid == rid),
            None,
        )
        if ds is None:
            continue
        for src, tgt in mapping.items():
            if tgt not in features and src in ds.hf_features:
                features[tgt] = ds.hf_features[src]

    return features
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access image and video stream from cameras."""
|
||||
keys = []
|
||||
for key, feats in self.features.items():
|
||||
if isinstance(feats, (datasets.Image, VideoFrame)):
|
||||
@@ -1723,6 +1437,12 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
@property
|
||||
def video_frame_keys(self) -> list[str]:
|
||||
"""Keys to access video frames that requires to be decoded into images.
|
||||
|
||||
Note: It is empty if the dataset contains images only,
|
||||
or equal to `self.cameras` if the dataset contains videos only,
|
||||
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
|
||||
"""
|
||||
video_frame_keys = []
|
||||
for key, feats in self.features.items():
|
||||
if isinstance(feats, VideoFrame):
|
||||
@@ -1731,14 +1451,21 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
@property
def num_frames(self) -> int:
    """Total number of samples/frames over all sub-datasets."""
    total = 0
    for dataset in self._datasets:
        total += dataset.num_frames
    return total
|
||||
|
||||
@property
def num_episodes(self) -> int:
    """Total number of episodes over all sub-datasets."""
    total = 0
    for dataset in self._datasets:
        total += dataset.num_episodes
    return total
|
||||
|
||||
@property
def tolerance_s(self) -> float:
    """Timestamp tolerance in seconds: one frame period minus a small margin.

    The tolerance is meant for discarding loaded frames whose timestamps are
    not close enough to the requested ones (relevant when `delta_timestamps`
    is provided or when video frames are loaded from mp4 files).
    """
    frame_period = 1 / self.fps
    # The 1e-4 margin accounts for possible numerical error.
    return frame_period - 1e-4
|
||||
|
||||
def __len__(self):
|
||||
@@ -1747,83 +1474,22 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
if idx >= len(self):
|
||||
raise IndexError(f"Index {idx} out of bounds.")
|
||||
dataset_idx = np.searchsorted(self.cumulative_sizes, idx, side="right").item() - 1
|
||||
local_idx = (idx - self.cumulative_sizes[dataset_idx]).item()
|
||||
item = self._datasets[dataset_idx][local_idx]
|
||||
|
||||
# Identify which repo this sample came from
|
||||
repo_id = self.datasets_repo_ids[dataset_idx]
|
||||
|
||||
# --- NEW: apply mapping and ensure union of image keys ---
|
||||
item = self._apply_feature_mapping(item, repo_id)
|
||||
item = self._ensure_union_image_keys(item)
|
||||
|
||||
# annotate dataset index for downstream
|
||||
# Determine which dataset to get an item from based on the index.
|
||||
start_idx = 0
|
||||
dataset_idx = 0
|
||||
for dataset in self._datasets:
|
||||
if idx >= start_idx + dataset.num_frames:
|
||||
start_idx += dataset.num_frames
|
||||
dataset_idx += 1
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
|
||||
item = self._datasets[dataset_idx][idx - start_idx]
|
||||
item["dataset_index"] = torch.tensor(dataset_idx)
|
||||
|
||||
# Pad vector features to max dims using meta (unchanged)
|
||||
item = create_padded_features(item, self.meta.features)
|
||||
|
||||
# Drop any disabled (including original source keys we remapped away)
|
||||
for data_key in self.disabled_features:
|
||||
if data_key in item:
|
||||
del item[data_key]
|
||||
for k in IGNORED_KEYS:
|
||||
if k in item:
|
||||
item.pop(k)
|
||||
# Convert any datasets.Image still present to tensor
|
||||
if self.image_transforms is not None:
|
||||
for cam in [k for k in item.keys() if self._is_image_key_like(k)]:
|
||||
val = item[cam]
|
||||
if not torch.is_tensor(val):
|
||||
item[cam] = self.image_transforms(val)
|
||||
# 🔑 Pad actions if too short
|
||||
if "actions" in item and self.max_action_dim is not None:
|
||||
act = item["actions"]
|
||||
if act.shape[-1] < self.max_action_dim:
|
||||
pad_len = self.max_action_dim - act.shape[-1]
|
||||
item["actions"] = torch.cat([act, torch.zeros(pad_len, dtype=act.dtype)], dim=-1)
|
||||
item["actions_padding_mask"] = torch.cat(
|
||||
[torch.zeros_like(act, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# pad obs_state if too short
|
||||
if "obs_state" in item and self.max_state_dim is not None:
|
||||
st = item["obs_state"]
|
||||
if st.shape[-1] < self.max_state_dim:
|
||||
pad_len = self.max_state_dim - st.shape[-1]
|
||||
item["obs_state"] = torch.cat([st, torch.zeros(pad_len, dtype=st.dtype)], dim=-1)
|
||||
item["obs_state_padding_mask"] = torch.cat(
|
||||
[torch.zeros_like(st, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)],
|
||||
dim=-1,
|
||||
)
|
||||
# actions
|
||||
if "actions" in item and self.max_action_dim is not None:
|
||||
act = item["actions"]
|
||||
if act.shape[-1] < self.max_action_dim:
|
||||
pad_len = self.max_action_dim - act.shape[-1]
|
||||
item["actions"] = torch.cat([act, torch.zeros(pad_len, dtype=act.dtype)], dim=-1)
|
||||
mask = torch.cat(
|
||||
[torch.zeros_like(act, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)],
|
||||
dim=-1,
|
||||
)
|
||||
else:
|
||||
mask = torch.zeros(self.max_action_dim, dtype=torch.bool) # 👈 all False if no padding
|
||||
item["actions_padding_mask"] = mask
|
||||
# obs state
|
||||
if "obs_state" in item and self.max_state_dim is not None:
|
||||
st = item["obs_state"]
|
||||
if st.shape[-1] < self.max_state_dim:
|
||||
pad_len = self.max_state_dim - st.shape[-1]
|
||||
item["obs_state"] = torch.cat([st, torch.zeros(pad_len, dtype=st.dtype)], dim=-1)
|
||||
mask = torch.cat(
|
||||
[torch.zeros_like(st, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)],
|
||||
dim=-1,
|
||||
)
|
||||
else:
|
||||
mask = torch.zeros(self.max_state_dim, dtype=torch.bool) # 👈 always add mask
|
||||
item["obs_state_padding_mask"] = mask
|
||||
|
||||
return item
|
||||
|
||||
@@ -1840,149 +1506,3 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
f" Transformations: {self.image_transforms},\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
def keep_datasets_with_the_same_features_per_robot_type(ls_datasets: list) -> tuple[list, list[bool]]:
    """Keep only the datasets whose feature shapes agree with their robot type's majority.

    Datasets are grouped by ``meta.info["robot_type"]``. Within a group, the
    most frequent shape of the ``"mean"`` statistic is determined for every
    stats key over all episodes of all datasets of the group. A dataset is
    discarded when the ``"mean"`` of one of these keys in its *first* episode
    has a different shape. Only array-valued means (``torch.Tensor`` or
    ``np.ndarray``) take part in the comparison.

    Args:
        ls_datasets: Datasets exposing ``meta.info["robot_type"]`` and a
            ``meta`` object accepted by ``episode_stats_values``. They must be
            hashable, since rejected datasets are collected in a set.

    Returns:
        A pair ``(filtered_datasets, datasets_mask)``: the kept datasets in
        their original order, and one boolean per input dataset telling
        whether it was kept.
    """

    def _mean_shape(feature_stats):
        """Return the shape of an array-valued ``"mean"`` entry, else ``None``."""
        # FIXME(mshukor): check all stats; min, mean, max
        if (
            feature_stats
            and "mean" in feature_stats
            and isinstance(feature_stats["mean"], (torch.Tensor, np.ndarray))
        ):
            return feature_stats["mean"].shape
        return None

    robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets}
    datasets_to_remove = set()

    for robot_type in robot_types:
        same_robot = [ds for ds in ls_datasets if ds.meta.info["robot_type"] == robot_type]
        # Per-episode stats are read once per dataset and reused below.
        per_dataset_stats = [episode_stats_values(ds.meta) for ds in same_robot]
        stats_list = [ep_stats for episodes in per_dataset_stats for ep_stats in episodes]
        if not stats_list:
            continue

        # Most frequent "mean" shape per key; it does not depend on the
        # dataset being checked, so it is computed once per robot type.
        main_shapes = {}
        for key in {key for stats in stats_list for key in stats}:
            shape_counter = defaultdict(int)
            for stats in stats_list:
                shape = _mean_shape(stats.get(key))
                if shape is not None:
                    shape_counter[shape] += 1
            if shape_counter:
                main_shapes[key] = max(shape_counter, key=shape_counter.get)

        # Flag datasets whose first episode deviates from a majority shape.
        for ds, episodes in zip(same_robot, per_dataset_stats):
            first_ep_stats = episodes[0] if episodes else None
            if not first_ep_stats:
                continue
            for key, main_shape in main_shapes.items():
                shape = _mean_shape(first_ep_stats.get(key))
                if shape is not None and shape != main_shape:
                    datasets_to_remove.add(ds)
                    break

    # Filter out inconsistent datasets
    datasets_mask = [ds not in datasets_to_remove for ds in ls_datasets]
    filtered_datasets = [ds for ds, keep in zip(ls_datasets, datasets_mask) if keep]
    print(
        f"Keeping {len(filtered_datasets)} datasets. Removed {len(datasets_to_remove)} inconsistent ones. "
        f"Inconsistent datasets:\n{datasets_to_remove}"
    )
    return filtered_datasets, datasets_mask
|
||||
|
||||
|
||||
def aggregate_stats_per_robot_type(ls_datasets) -> dict[str, dict[str, torch.Tensor]]:
    """Aggregate the stats of several LeRobot datasets into one set of stats per robot type.

    Datasets are grouped by ``meta.info["robot_type"]``. For each group, the
    per-episode stats of all its datasets (as returned by
    ``episode_stats_values``) are gathered into one list and handed to
    ``aggregate_stats``; how min, max, mean and std are merged is defined by
    that helper.

    Returns:
        A dictionary mapping every robot type to its aggregated stats.
    """
    robot_types = {ds.meta.info["robot_type"] for ds in ls_datasets}
    return {
        robot_type: aggregate_stats(
            [
                episode_stats
                for ds in ls_datasets
                if ds.meta.info["robot_type"] == robot_type
                for episode_stats in episode_stats_values(ds.meta)
            ]
        )
        for robot_type in robot_types
    }
|
||||
|
||||
def reshape_features_to_max_dim(features: dict, reshape_dim: int = -1, keys_to_max_dim: dict = {}) -> dict:
|
||||
"""Reshape features to have a maximum dimension of `max_dim`."""
|
||||
reshaped_features = {}
|
||||
for key in features:
|
||||
if key in keys_to_max_dim and keys_to_max_dim[key] is not None:
|
||||
reshaped_features[key] = features[key]
|
||||
shape = list(features[key]["shape"])
|
||||
if any([k in key for k in [OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGE_3]]): # Assume square images
|
||||
shape[-3] = keys_to_max_dim[key]
|
||||
shape[-2] = keys_to_max_dim[key]
|
||||
else:
|
||||
shape[reshape_dim] = keys_to_max_dim[key]
|
||||
reshaped_features[key]["shape"] = tuple(shape)
|
||||
else:
|
||||
reshaped_features[key] = features[key]
|
||||
return reshaped_features
|
||||
|
||||
def create_padded_features(item: dict, features: dict = {}):
|
||||
for key, ft in features.items():
|
||||
if any([k in key for k in ["cam", "effort", "absolute"]]): # FIXME(mshukor): temporary hack
|
||||
continue
|
||||
shape = ft["shape"]
|
||||
if len(shape) == 3: # images to torch format (C, H, W)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
if len(shape) == 1 and shape[0] == 1: # ft with shape are actually tensor(ele)
|
||||
shape = []
|
||||
if key not in item:
|
||||
dtype = str_to_torch_dtype(ft["dtype"])
|
||||
item[key] = torch.zeros(shape, dtype=dtype)
|
||||
item[f"{key}_padding_mask"] = torch.tensor(0, dtype=torch.int64)
|
||||
if "image" in key: # FIXME(mshukor): support other observations
|
||||
item[f"{key}_is_pad"] = torch.BoolTensor([False])
|
||||
else:
|
||||
item[f"{key}_padding_mask"] = torch.tensor(1, dtype=torch.int64)
|
||||
return item
|
||||
|
||||
def str_to_torch_dtype(dtype_str):
    """Translate a dtype name into the corresponding torch dtype.

    Recognised names are ``"float32"``, ``"int64"``, ``"int16"``, ``"bool"``
    and ``"video"``; any other name falls back to ``torch.float32``.
    """
    known_dtypes = {
        "float32": torch.float32,
        "int64": torch.int64,
        "int16": torch.int16,
        "bool": torch.bool,
        # Video features are represented as float32 tensors here.
        "video": torch.float32,
    }
    try:
        return known_dtypes[dtype_str]
    except KeyError:
        # Unknown names default to float32.
        return torch.float32
|
||||
|
||||
def episode_stats_values(meta):
    """Extract the per-episode statistics stored in ``meta.episodes``.

    ``meta.episodes`` is converted to one record (dict) per episode via
    pandas. From each record only the entries whose value is a dict holding a
    ``"mean"`` key are kept.

    Returns:
        A list with one dictionary per episode, mapping feature name to its
        stats dict.
    """
    records = meta.episodes.to_pandas().to_dict(orient="records")
    per_episode_stats = []
    for record in records:
        stats = {}
        for name, value in record.items():
            if isinstance(value, dict) and "mean" in value:
                stats[name] = value
        per_episode_stats.append(stats)
    return per_episode_stats
|
||||
|
||||
@@ -118,7 +118,7 @@ echo ${HF_USER}/aloha_test
|
||||
If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with [Rerun](https://github.com/rerun-io/rerun):
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.visualize_dataset \
|
||||
lerobot-dataset-viz \
|
||||
--repo-id ${HF_USER}/aloha_test --episode 0
|
||||
```
|
||||
|
||||
|
||||
@@ -29,14 +29,14 @@ Examples:
|
||||
|
||||
- Visualize data stored on a local machine:
|
||||
```
|
||||
local$ python -m lerobot.scripts.visualize_dataset \
|
||||
local$ lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
- Visualize data stored on a distant machine with a local viewer:
|
||||
```
|
||||
distant$ python -m lerobot.scripts.visualize_dataset \
|
||||
distant$ lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0 \
|
||||
--save 1 \
|
||||
@@ -50,7 +50,7 @@ local$ rerun lerobot_pusht_episode_0.rrd
|
||||
(You need to forward the websocket port to the distant machine, with
|
||||
`ssh -L 9087:localhost:9087 username@remote-host`)
|
||||
```
|
||||
distant$ python -m lerobot.scripts.visualize_dataset \
|
||||
distant$ lerobot-dataset-viz \
|
||||
--repo-id lerobot/pusht \
|
||||
--episode-index 0 \
|
||||
--mode distant \
|
||||
@@ -20,10 +20,10 @@ Additionally, each individual transform can be visualized separately as well as
|
||||
|
||||
Example:
|
||||
```bash
|
||||
python -m lerobot.scripts.visualize_image_transforms \
|
||||
--repo_id=lerobot/pusht \
|
||||
--episodes='[0]' \
|
||||
--image_transforms.enable=True
|
||||
lerobot-imgtransform-viz \
|
||||
--repo_id=lerobot/pusht \
|
||||
--episodes='[0]' \
|
||||
--image_transforms.enable=True
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -126,5 +126,9 @@ def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR
|
||||
save_each_transform(cfg.image_transforms, original_frame, output_dir, n_examples)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
visualize_image_transforms()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -29,7 +29,7 @@ from lerobot.datasets.transforms import (
|
||||
SharpnessJitter,
|
||||
make_transform_from_config,
|
||||
)
|
||||
from lerobot.scripts.visualize_image_transforms import (
|
||||
from lerobot.scripts.lerobot_imgtransform_viz import (
|
||||
save_all_transforms,
|
||||
save_each_transform,
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
|
||||
from lerobot.scripts.visualize_dataset import visualize_dataset
|
||||
from lerobot.scripts.lerobot_dataset_viz import visualize_dataset
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO: add dummy videos")
|
||||
|
||||
Reference in New Issue
Block a user