mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
Compare commits
11 Commits
feat/unitr
...
feat/datas
| Author | SHA1 | Date | |
|---|---|---|---|
| | 026e4c937d | | |
| | efe8c09fca | | |
| | 58eecad8a4 | | |
| | c7fd1f47d1 | | |
| | 6370949e5c | | |
| | 46b97da168 | | |
| | e69be57a66 | | |
| | c3a6ddb668 | | |
| | dad661012d | | |
| | 219c08ccb8 | | |
| | 06385902df | | |
@@ -19,8 +19,6 @@
|
||||
title: Multi GPU training
|
||||
- local: peft_training
|
||||
title: Training with PEFT (e.g., LoRA)
|
||||
- local: rename_map
|
||||
title: Using Rename Map and Empty Cameras
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: lerobot-dataset-v3
|
||||
|
||||
@@ -310,4 +310,4 @@ Asynchronous inference represents a significant advancement in real-time robotic
|
||||
- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
|
||||
|
||||
Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
|
||||
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/huggingface/lerobot/issues).
|
||||
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/lerobot/lerobot/issues).
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
# Rename Map and Empty Cameras
|
||||
|
||||
When you train, evaluate, or record with a robot policy, your **dataset** or **environment** provides observations under one set of keys (e.g. `observation.images.front`, `observation.images.eagle`), while your **policy** expects another (e.g. `observation.images.image`, `observation.images.image2`). The **rename map** bridges that gap without changing the policy or data source.
|
||||
|
||||
> **Scope:** The rename map only renames **observation** keys (images and state). Action keys are not affected.
|
||||
|
||||
## Why observation keys don't always match
|
||||
|
||||
Policies have a fixed set of **input feature names** baked into their pretrained config. For example:
|
||||
|
||||
- [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero) expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb`.
|
||||
- [xvla-base](https://huggingface.co/lerobot/xvla-base) expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`.
|
||||
|
||||
Your dataset might use different names entirely (e.g. `observation.images.front`, `observation.images.eagle`, `observation.images.glove`), and your eval environment might use yet another set. Rather than editing the policy config or renaming columns in the dataset, you pass a **rename map**: a JSON dictionary that maps source keys to the keys the policy expects. Renaming happens inside the preprocessor pipeline, so the policy always sees its expected keys.
|
||||
|
||||
## Using the rename map
|
||||
|
||||
Pass the mapping as a JSON string on the command line. The convention is always:
|
||||
|
||||
```
|
||||
--rename_map='{"source_key": "policy_key", ...}'
|
||||
```
|
||||
|
||||
where **source_key** is what the dataset or environment provides, and **policy_key** is what the policy expects.
|
||||
|
||||
Only listed keys are renamed; everything else passes through unchanged. Order of entries doesn't matter.
|
||||
|
||||
Supported policies: **PI0**, **PI05**, **PI0Fast**, **SmolVLA**, and **XVLA**.
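
The renaming rule described above can be sketched in a few lines of Python. This is only an illustration of the semantics (listed keys are renamed, everything else passes through); `apply_rename_map` is a hypothetical helper and not LeRobot's actual preprocessor step, and the observation values are placeholders standing in for tensors.

```python
import json


def apply_rename_map(observation: dict, rename_map: dict[str, str]) -> dict:
    """Return a copy of `observation` with the listed keys renamed.

    Hypothetical helper for illustration; not the LeRobot implementation.
    Keys missing from `rename_map` are passed through unchanged.
    """
    return {rename_map.get(key, key): value for key, value in observation.items()}


# The same JSON string you would pass to --rename_map on the command line.
rename_map = json.loads(
    '{"observation.images.front": "observation.images.image",'
    ' "observation.images.eagle": "observation.images.image2"}'
)

# Placeholder values stand in for image tensors and state vectors.
observation = {
    "observation.images.front": "front_frame",
    "observation.images.eagle": "eagle_frame",
    "observation.state": [0.0, 0.1],
}

renamed = apply_rename_map(observation, rename_map)
print(sorted(renamed))
```

Parsing the string with `json.loads` first is also a quick way to catch quoting mistakes before launching a long training run.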
|
||||
|
||||
### Training
|
||||
|
||||
Suppose you fine-tune [lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base) on a dataset with images under `observation.images.front`, `observation.images.eagle`, and `observation.images.glove`. XVLA expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=YOUR_DATASET \
|
||||
--output_dir=./outputs/xvla_training \
|
||||
--job_name=xvla_training \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
--policy.repo_id="HF_USER/xvla-your-robot" \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.action_mode=auto \
|
||||
--steps=20000 \
|
||||
--policy.device=cuda \
|
||||
--policy.freeze_vision_encoder=false \
|
||||
--policy.freeze_language_encoder=false \
|
||||
--policy.train_policy_transformer=true \
|
||||
--policy.train_soft_prompts=true \
|
||||
--rename_map='{"observation.images.front": "observation.images.image", "observation.images.eagle": "observation.images.image2", "observation.images.glove": "observation.images.image3"}'
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
|
||||
Suppose a policy expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb` (e.g. [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero)), but the LIBERO environment returns `observation.images.image` and `observation.images.image2`. The rename map goes from the environment keys to the policy keys:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/pi0fast-libero \
|
||||
--env.type=libero \
|
||||
... \
|
||||
--rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
|
||||
```
|
||||
|
||||
### Recording
|
||||
|
||||
`lerobot-record` also supports rename maps, nested under the dataset config:
|
||||
|
||||
```bash
|
||||
# When running inference
lerobot-record \
|
||||
--policy.path="<user>/smolVLA_finetuned" \
|
||||
... \
|
||||
--dataset.rename_map='{"observation.images.glove2": "observation.images.image"}'
|
||||
```
|
||||
|
||||
## Alternative: edit the policy config directly
|
||||
|
||||
If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
|
||||
|
||||
The tradeoff: modifying the policy config ties it to one data source. A rename map keeps one policy usable across many datasets and environments.
|
||||
|
||||
## Empty cameras: fewer views than the policy expects
|
||||
|
||||
Some policies are built for a fixed number of image inputs. If your dataset has fewer cameras, you can set **`empty_cameras`** in the policy config instead of modifying the model architecture.
|
||||
|
||||
### How it works
|
||||
|
||||
Setting `empty_cameras=N` adds N placeholder image features to the policy config, named:
|
||||
|
||||
```
|
||||
observation.images.empty_camera_0
|
||||
observation.images.empty_camera_1
|
||||
...
|
||||
```
|
||||
|
||||
At runtime, these keys have no corresponding data in the batch. The policy fills them with masked dummy tensors (padded with `-1` for SigLIP-based vision encoders, with a zero attention mask), so the extra image slots are effectively ignored during training and inference.
|
||||
|
||||
### Example
|
||||
|
||||
XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset only has two cameras:
|
||||
|
||||
1. Set `--policy.empty_cameras=1`.
2. The config adds a third key: `observation.images.empty_camera_0`.
3. Use the rename map for your two real cameras as usual.
4. The third slot is masked out, so no fake images are needed in your dataset.
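
As a rough sanity check for the steps above, the sketch below lists the image keys the policy would hold after adding placeholders and verifies that every non-placeholder key is produced by the rename map. Both helpers are hypothetical and written only for this example; the placeholder naming follows the `observation.images.empty_camera_N` pattern described earlier, and the assumption that the two real slots are `image` and `image2` comes from the XVLA example, not from inspecting the config.

```python
def expected_image_keys(real_keys: list[str], empty_cameras: int) -> list[str]:
    """List the image keys a policy config would hold after adding placeholders.

    Hypothetical helper: placeholder names follow the documented pattern.
    """
    placeholders = [f"observation.images.empty_camera_{i}" for i in range(empty_cameras)]
    return real_keys + placeholders


def unmapped_policy_keys(policy_keys: list[str], rename_map: dict[str, str]) -> list[str]:
    """Policy image keys that neither the rename map nor a placeholder covers."""
    provided = set(rename_map.values())
    return [k for k in policy_keys if k not in provided and "empty_camera_" not in k]


# Two real cameras plus one placeholder slot (empty_cameras=1).
policy_keys = expected_image_keys(
    ["observation.images.image", "observation.images.image2"], empty_cameras=1
)
rename_map = {
    "observation.images.front": "observation.images.image",
    "observation.images.eagle": "observation.images.image2",
}

print(policy_keys)
print(unmapped_policy_keys(policy_keys, rename_map))
```

An empty list from `unmapped_policy_keys` means every real image slot has a source camera; any key it returns still needs an entry in the rename map.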
|
||||
|
||||
## Quick reference
|
||||
|
||||
| Goal                                      | What to do                                                                  |
| ----------------------------------------- | --------------------------------------------------------------------------- |
| Dataset keys ≠ policy keys                | `--rename_map='{"dataset_key": "policy_key", ...}'`                         |
| Env keys ≠ policy keys (eval)             | `--rename_map='{"env_key": "policy_key", ...}'`                             |
| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'`                  |
| Fewer cameras than policy expects         | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
| Avoid passing a rename map                | Edit the policy's `config.json` so its keys match your data source          |
|
||||
717
examples/dataset/visualization_tools/action_consistency.py
Normal file
@@ -0,0 +1,717 @@
|
||||
"""
|
||||
Action consistency analysis for imitation learning datasets.
|
||||
|
||||
Two parallel analyses per dataset:
|
||||
1. State-based: KNN in joint-state space → action chunk variance
|
||||
2. Image-based: KNN in SigLIP embedding space → action chunk variance
|
||||
|
||||
Comparing them reveals whether visual similarity and proprioceptive similarity
|
||||
agree on where the data is inconsistent — and images are what the policy
|
||||
primarily sees.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
from PIL import Image
|
||||
from scipy.spatial import cKDTree
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
DATASETS = [
|
||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
|
||||
{"repo_id": "lerobot-data-collection/level12_rac_2_2026-02-08_1", "label": "Full collection"},
|
||||
]
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
MAX_FRAMES = 100_000
|
||||
K_NEIGHBORS = 50
|
||||
ACTION_CHUNK_SIZE = 30
|
||||
CAMERA_KEY = "observation.images.base"
|
||||
ENCODER_MODEL = "google/siglip-base-patch16-224"
|
||||
ENCODE_BATCH_SIZE = 512
|
||||
SEED = 42
|
||||
DPI = 150
|
||||
|
||||
CONSISTENCY_CMAP = LinearSegmentedColormap.from_list(
|
||||
"consistency", ["#0a2e0a", "#1a8e1a", "#88cc22", "#ffaa22", "#ff2222"]
|
||||
)
|
||||
|
||||
# FK chains from OpenArm bimanual URDF (same as workspace_density.py).
|
||||
LEFT_CHAIN = [
|
||||
((-np.pi / 2, 0, 0), (0, 0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((-np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, -1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
RIGHT_CHAIN = [
|
||||
((np.pi / 2, 0, 0), (0, -0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, 1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
|
||||
|
||||
# ── FK math ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def _rot_x(a: float) -> np.ndarray:
    c, s = np.cos(a), np.sin(a)
    return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])


def _rot_y(a: float) -> np.ndarray:
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])


def _rot_z(a: float) -> np.ndarray:
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])


def _tf(rpy: tuple, xyz: tuple) -> np.ndarray:
    r, p, y = rpy
    mat = np.eye(4)
    mat[:3, :3] = _rot_z(y) @ _rot_y(p) @ _rot_x(r)
    mat[:3, 3] = xyz
    return mat


def _batch_axis_rot(axis: tuple, angles: np.ndarray) -> np.ndarray:
    n = len(angles)
    ax = np.asarray(axis, dtype=np.float64)
    ax = ax / np.linalg.norm(ax)
    x, y, z = ax
    c = np.cos(angles)
    s = np.sin(angles)
    t = 1 - c
    rot = np.zeros((n, 4, 4))
    rot[:, 0, 0] = t * x * x + c
    rot[:, 0, 1] = t * x * y - s * z
    rot[:, 0, 2] = t * x * z + s * y
    rot[:, 1, 0] = t * x * y + s * z
    rot[:, 1, 1] = t * y * y + c
    rot[:, 1, 2] = t * y * z - s * x
    rot[:, 2, 0] = t * x * z - s * y
    rot[:, 2, 1] = t * y * z + s * x
    rot[:, 2, 2] = t * z * z + c
    rot[:, 3, 3] = 1.0
    return rot


def batch_fk(chain: list, joint_angles: np.ndarray) -> np.ndarray:
    n = joint_angles.shape[0]
    tf_batch = np.tile(np.eye(4), (n, 1, 1))
    qi = 0
    for rpy, xyz, axis in chain:
        tf_batch = tf_batch @ _tf(rpy, xyz)
        if axis is not None:
            rot = _batch_axis_rot(axis, joint_angles[:, qi])
            tf_batch = np.einsum("nij,njk->nik", tf_batch, rot)
            qi += 1
    return tf_batch[:, :3, 3]
|
||||
|
||||
|
||||
# ── Data helpers ────────────────────────────────────────
|
||||
|
||||
|
||||
def _flatten_names(obj: object) -> list[str]:
|
||||
if isinstance(obj, dict):
|
||||
out: list[str] = []
|
||||
for v in obj.values():
|
||||
out.extend(_flatten_names(v))
|
||||
return out
|
||||
if isinstance(obj, (list, tuple)):
|
||||
out = []
|
||||
for item in obj:
|
||||
if isinstance(item, (list, tuple, dict)):
|
||||
out.extend(_flatten_names(item))
|
||||
else:
|
||||
out.append(str(item))
|
||||
return out
|
||||
return [str(obj)]
|
||||
|
||||
|
||||
def _detect_and_convert(vals: np.ndarray) -> np.ndarray:
    mx = np.max(np.abs(vals))
    if mx > 360:
        print(f" Unit detection: servo ticks (max={mx:.0f})")
        return (vals - 2048) / 2048 * np.pi
    if mx > 6.3:
        print(f" Unit detection: degrees (max={mx:.1f})")
        return np.deg2rad(vals)
    print(f" Unit detection: radians (max={mx:.3f})")
    return vals.astype(np.float64)
|
||||
|
||||
|
||||
def _find_joint_indices(features: dict, state_col: str, n_dim: int) -> tuple[list[int], list[int]]:
|
||||
feat = features.get("observation.state", features.get(state_col, {}))
|
||||
names = _flatten_names(feat.get("names", []))
|
||||
left_idx: list[int] = []
|
||||
right_idx: list[int] = []
|
||||
if names and len(names) == n_dim:
|
||||
names_l = [n.lower() for n in names]
|
||||
print(f" Feature names: {names[:4]}…{names[-4:]}")
|
||||
for j in range(1, 8):
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"left_joint_{j}" in nm and i not in left_idx:
|
||||
left_idx.append(i)
|
||||
break
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"right_joint_{j}" in nm and i not in right_idx:
|
||||
right_idx.append(i)
|
||||
break
|
||||
if len(left_idx) == 7 and len(right_idx) == 7:
|
||||
print(f" Matched by name: left={left_idx} right={right_idx}")
|
||||
return left_idx, right_idx
|
||||
if n_dim >= 16:
|
||||
print(" Falling back to positional: [0:7]=left, [8:15]=right")
|
||||
return list(range(7)), list(range(8, 15))
|
||||
if n_dim >= 14:
|
||||
print(" Falling back to positional: [0:7]=left, [7:14]=right")
|
||||
return list(range(7)), list(range(7, 14))
|
||||
raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot")
|
||||
|
||||
|
||||
def download_data(repo_id: str, camera_key: str) -> Path:
|
||||
print(f" Downloading {repo_id} (parquet + {camera_key} videos) …")
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=[
|
||||
"meta/**",
|
||||
"data/**",
|
||||
f"videos/{camera_key}/**",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ── Data loading ────────────────────────────────────────
|
||||
|
||||
|
||||
def _build_action_chunks(
    actions: np.ndarray, episode_ids: np.ndarray, chunk_size: int
) -> tuple[np.ndarray, np.ndarray]:
    """
    For each frame, concatenate the next chunk_size actions from the same episode.
    Returns (action_chunks, valid_mask).
    """
    n = len(actions)
    act_dim = actions.shape[1]
    chunks = np.zeros((n, chunk_size * act_dim), dtype=np.float64)
    valid = np.zeros(n, dtype=bool)

    for i in range(n):
        end = i + chunk_size
        if end > n:
            continue
        if episode_ids[i] != episode_ids[end - 1]:
            continue
        chunks[i] = actions[i:end].ravel()
        valid[i] = True

    return chunks, valid
|
||||
|
||||
|
||||
def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: np.random.Generator) -> dict:
|
||||
"""
|
||||
Load observation.state and action, build action chunks, subsample, normalize.
|
||||
Also returns the original row indices (`chosen_idx`) for video frame mapping.
|
||||
"""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
features = info.get("features", {})
|
||||
|
||||
dfs = [pd.read_parquet(pq) for pq in sorted((local / "data").glob("**/*.parquet"))]
|
||||
df = pd.concat(dfs, ignore_index=True)
|
||||
n_total = len(df)
|
||||
print(f" Total frames: {n_total:,}")
|
||||
|
||||
state_col = next((c for c in df.columns if "observation.state" in c), None)
|
||||
action_col = next((c for c in df.columns if c == "action"), None)
|
||||
if state_col is None:
|
||||
raise RuntimeError(f"No observation.state column. Available: {list(df.columns)}")
|
||||
if action_col is None:
|
||||
raise RuntimeError(f"No action column. Available: {list(df.columns)}")
|
||||
|
||||
ep_col = next((c for c in df.columns if c == "episode_index"), None)
|
||||
if ep_col is None:
|
||||
raise RuntimeError(f"No episode_index column. Available: {list(df.columns)}")
|
||||
|
||||
state_all = np.stack(df[state_col].values).astype(np.float64)
|
||||
action_all = np.stack(df[action_col].values).astype(np.float64)
|
||||
episode_all = df[ep_col].values.astype(np.int64)
|
||||
|
||||
n_dim = state_all.shape[1]
|
||||
act_dim = action_all.shape[1]
|
||||
print(f" State dim: {n_dim} Action dim: {act_dim} Chunk size: {chunk_size}")
|
||||
print(f" Action chunk dim: {chunk_size * act_dim}")
|
||||
|
||||
left_idx, right_idx = _find_joint_indices(features, state_col, n_dim)
|
||||
|
||||
print(" Building action chunks …")
|
||||
action_chunks, valid = _build_action_chunks(action_all, episode_all, chunk_size)
|
||||
valid_idx = np.where(valid)[0]
|
||||
print(f" Valid frames (with full action chunk): {len(valid_idx):,} / {n_total:,}")
|
||||
|
||||
if len(valid_idx) > max_frames:
|
||||
chosen = np.sort(rng.choice(valid_idx, max_frames, replace=False))
|
||||
else:
|
||||
chosen = valid_idx
|
||||
print(f" Using {len(chosen):,} frames")
|
||||
|
||||
state_raw = state_all[chosen]
|
||||
action_raw = action_chunks[chosen]
|
||||
episode_ids = episode_all[chosen]
|
||||
|
||||
state_mean = state_raw.mean(axis=0)
|
||||
state_std = state_raw.std(axis=0)
|
||||
state_std[state_std < 1e-8] = 1.0
|
||||
state_norm = (state_raw - state_mean) / state_std
|
||||
|
||||
action_mean = action_raw.mean(axis=0)
|
||||
action_std = action_raw.std(axis=0)
|
||||
action_std[action_std < 1e-8] = 1.0
|
||||
action_norm = (action_raw - action_mean) / action_std
|
||||
|
||||
return {
|
||||
"state_raw": state_raw,
|
||||
"state_norm": state_norm,
|
||||
"action_raw": action_raw,
|
||||
"action_norm": action_norm,
|
||||
"episode_ids": episode_ids,
|
||||
"episode_all": episode_all,
|
||||
"left_joint_idx": left_idx,
|
||||
"right_joint_idx": right_idx,
|
||||
"n_total": n_total,
|
||||
"chosen_idx": chosen,
|
||||
"df": df,
|
||||
}
|
||||
|
||||
|
||||
# ── Video → frame extraction ──────────────────────────────
|
||||
|
||||
|
||||
def build_video_lookup(local: Path, camera_key: str) -> dict:
|
||||
"""
|
||||
Build a mapping from episode_index → {video_path, fps, from_ts}.
|
||||
"""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
fps = info["fps"]
|
||||
video_template = info.get(
|
||||
"video_path",
|
||||
"videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4",
|
||||
)
|
||||
|
||||
ep_rows = []
|
||||
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
||||
ep_rows.append(pd.read_parquet(pq))
|
||||
ep_df = pd.concat(ep_rows, ignore_index=True)
|
||||
|
||||
chunk_col = f"videos/{camera_key}/chunk_index"
|
||||
file_col = f"videos/{camera_key}/file_index"
|
||||
ts_from = f"videos/{camera_key}/from_timestamp"
|
||||
if chunk_col not in ep_df.columns:
|
||||
chunk_col = f"{camera_key}/chunk_index"
|
||||
file_col = f"{camera_key}/file_index"
|
||||
ts_from = f"{camera_key}/from_timestamp"
|
||||
|
||||
lookup: dict[int, dict] = {}
|
||||
for _, row in ep_df.iterrows():
|
||||
ci = int(row[chunk_col])
|
||||
fi = int(row[file_col])
|
||||
video_rel = video_template.format(video_key=camera_key, chunk_index=ci, file_index=fi)
|
||||
lookup[int(row["episode_index"])] = {
|
||||
"video_path": local / video_rel,
|
||||
"from_ts": float(row[ts_from]),
|
||||
"fps": fps,
|
||||
}
|
||||
return lookup
|
||||
|
||||
|
||||
def _decode_video_frames(video_path: str) -> list[np.ndarray]:
    """Decode all frames from a video file using PyAV. Returns list of RGB arrays."""
    # The context manager closes the container even if decoding raises.
    with av.open(video_path) as container:
        stream = container.streams.video[0]
        stream.thread_type = "AUTO"
        return [frame.to_ndarray(format="rgb24") for frame in container.decode(stream)]
|
||||
|
||||
|
||||
def extract_frames(
    chosen_idx: np.ndarray,
    episode_all: np.ndarray,
    video_lookup: dict,
) -> list[np.ndarray | None]:
    """
    Extract RGB frames for each chosen global index using PyAV.
    Returns list of (H, W, 3) RGB arrays (or None on failure).
    """
    unique_eps = np.unique(episode_all)
    ep_start: dict[int, int] = {}
    for ep in unique_eps:
        ep_start[int(ep)] = int(np.where(episode_all == ep)[0][0])

    # Build jobs: (output_index, video_path, frame_number_within_video_file)
    jobs: list[tuple[int, str, int]] = []
    for out_i, global_i in enumerate(chosen_idx):
        ep = int(episode_all[global_i])
        info = video_lookup.get(ep)
        if info is None:
            continue
        # A video file may hold several episodes back to back, so offset the
        # in-episode frame number by the episode's start time in that file.
        # With one episode per file, from_ts is 0 and this is a no-op.
        ep_offset = int(round(info["from_ts"] * info["fps"]))
        local_frame = int(global_i - ep_start[ep]) + ep_offset
        jobs.append((out_i, str(info["video_path"]), local_frame))

    # Group by video file, decode each video once
    from collections import defaultdict

    video_jobs: dict[str, list[tuple[int, int]]] = defaultdict(list)
    for out_i, vpath, local_frame in jobs:
        video_jobs[vpath].append((out_i, local_frame))

    frames: list[np.ndarray | None] = [None] * len(chosen_idx)
    extracted = 0
    n_videos = len(video_jobs)
    for vi, (vpath, frame_requests) in enumerate(video_jobs.items()):
        if not Path(vpath).exists():
            continue
        try:
            decoded = _decode_video_frames(vpath)
        except Exception as exc:
            print(f" Warning: failed to decode {Path(vpath).name}: {exc}")
            continue
        for out_i, local_frame in frame_requests:
            if 0 <= local_frame < len(decoded):
                frames[out_i] = decoded[local_frame]
                extracted += 1
        if (vi + 1) % 50 == 0 or (vi + 1) == n_videos:
            print(f" Decoded {vi + 1}/{n_videos} videos ({extracted:,} frames so far)")
        del decoded

    print(f" Extracted {extracted:,} / {len(chosen_idx):,} frames from video")
    return frames
|
||||
|
||||
|
||||
# ── SigLIP encoding ─────────────────────────────────────
|
||||
|
||||
|
||||
def encode_frames_siglip(
|
||||
frames: list[np.ndarray | None],
|
||||
model_name: str,
|
||||
batch_size: int,
|
||||
device: torch.device,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Encode RGB frames through SigLIP vision encoder.
|
||||
Returns (N, embed_dim) float32 array. Frames that are None get a zero vector.
|
||||
"""
|
||||
print(f" Loading SigLIP model: {model_name} …")
|
||||
processor = AutoImageProcessor.from_pretrained(model_name)
|
||||
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
||||
embed_dim = model.config.vision_config.hidden_size
|
||||
|
||||
n = len(frames)
|
||||
embeddings = np.zeros((n, embed_dim), dtype=np.float32)
|
||||
|
||||
valid_indices = [i for i, f in enumerate(frames) if f is not None]
|
||||
print(f" Encoding {len(valid_indices):,} valid frames in batches of {batch_size} …")
|
||||
|
||||
for batch_start in range(0, len(valid_indices), batch_size):
|
||||
batch_idx = valid_indices[batch_start : batch_start + batch_size]
|
||||
pil_images = [Image.fromarray(frames[i]) for i in batch_idx]
|
||||
|
||||
inputs = processor(images=pil_images, return_tensors="pt").to(device)
|
||||
with torch.no_grad():
|
||||
image_features = model.get_image_features(**inputs)
|
||||
image_features = torch.nn.functional.normalize(image_features, dim=-1)
|
||||
embeddings[batch_idx] = image_features.cpu().numpy()
|
||||
|
||||
done = min(batch_start + batch_size, len(valid_indices))
|
||||
if done % (batch_size * 10) == 0 or done == len(valid_indices):
|
||||
print(f" {done:,} / {len(valid_indices):,} encoded")
|
||||
|
||||
del model, processor
|
||||
torch.cuda.empty_cache()
|
||||
return embeddings
|
||||
|
||||
|
||||
# ── KNN consistency ─────────────────────────────────────
|
||||
|
||||
|
||||
def compute_consistency(
    features: np.ndarray,
    action_norm: np.ndarray,
    episode_ids: np.ndarray,
    k: int,
    label: str = "",
) -> np.ndarray:
    """
    For each frame, find K nearest neighbors in feature space from other episodes.
    Return per-frame action variance (mean across action dims).
    """
    n = len(features)
    print(f" Building KD-tree on {n:,} vectors ({label}) …")
    tree = cKDTree(features)

    # Over-query (3x K) so enough neighbors remain after dropping same-episode hits.
    k_query = min(k * 3, n - 1)
    print(f" Querying {k_query} neighbors per frame …")
    _dists, indices = tree.query(features, k=k_query + 1)
    # Column 0 is the query point itself (distance 0), so skip it.
    indices = indices[:, 1:]

    print(f" Computing cross-episode action variance ({label}) …")
    variance = np.zeros(n)
    for i in range(n):
        ep_i = episode_ids[i]
        neighbors = indices[i]
        cross_ep = neighbors[episode_ids[neighbors] != ep_i][:k]
        if len(cross_ep) < 2:
            variance[i] = 0.0
            continue
        neighbor_actions = action_norm[cross_ep]
        variance[i] = np.mean(np.var(neighbor_actions, axis=0))

    return variance
|
||||
|
||||
|
||||
# ── Visualization ───────────────────────────────────────
|
||||
|
||||
|
||||
def _style_ax(ax: plt.Axes) -> None:
|
||||
ax.set_facecolor("#0d1117")
|
||||
ax.tick_params(colors="#555", labelsize=8)
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color("#333")
|
||||
|
||||
|
||||
def _plot_histogram(ax: plt.Axes, variance: np.ndarray, title: str, color: str) -> None:
|
||||
_style_ax(ax)
|
||||
median_var = np.median(variance)
|
||||
mean_var = np.mean(variance)
|
||||
nonzero = variance[variance > 0]
|
||||
if len(nonzero) > 0:
|
||||
bins = np.logspace(np.log10(nonzero.min().clip(1e-6)), np.log10(nonzero.max()), 60)
|
||||
ax.hist(nonzero, bins=bins, color=color, alpha=0.8, edgecolor="#222")
|
||||
ax.set_xscale("log")
|
||||
ax.axvline(median_var, color="#ff6600", linewidth=2, label=f"median={median_var:.3f}")
|
||||
ax.axvline(mean_var, color="#ff2222", linewidth=2, linestyle="--", label=f"mean={mean_var:.3f}")
|
||||
ax.set_xlabel("Action variance (log scale)", color="#888", fontsize=10)
|
||||
ax.set_ylabel("Frame count", color="#888", fontsize=10)
|
||||
ax.set_title(title, color="white", fontsize=11, pad=10)
|
||||
ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
|
||||
|
||||
|
||||
def _plot_episode_curves(
|
||||
ax: plt.Axes,
|
||||
var_state: np.ndarray,
|
||||
var_image: np.ndarray,
|
||||
episode_ids: np.ndarray,
|
||||
title: str,
|
||||
) -> None:
|
||||
_style_ax(ax)
|
||||
unique_eps = np.unique(episode_ids)
|
||||
|
||||
ep_means_s = np.array([var_state[episode_ids == ep].mean() for ep in unique_eps])
|
||||
ep_means_i = np.array([var_image[episode_ids == ep].mean() for ep in unique_eps])
|
||||
|
||||
sorted_s = np.sort(ep_means_s)[::-1]
|
||||
sorted_i = np.sort(ep_means_i)[::-1]
|
||||
ep_x = np.arange(len(unique_eps))
|
||||
|
||||
ax.fill_between(ep_x, sorted_s, alpha=0.2, color="#4363d8")
|
||||
ax.plot(ep_x, sorted_s, color="#4363d8", linewidth=1.2, label=f"State (med={np.median(ep_means_s):.3f})")
|
||||
ax.fill_between(ep_x, sorted_i, alpha=0.2, color="#e6194b")
|
||||
ax.plot(ep_x, sorted_i, color="#e6194b", linewidth=1.2, label=f"Image (med={np.median(ep_means_i):.3f})")
|
||||
|
||||
ax.set_xlabel("Episode rank (worst → best)", color="#888", fontsize=10)
|
||||
ax.set_ylabel("Mean action variance", color="#888", fontsize=10)
|
||||
ax.set_title(title, color="white", fontsize=11, pad=10)
|
||||
ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
|
||||
|
||||
|
||||
def _plot_heatmap(
|
||||
ax: plt.Axes, fig: plt.Figure, tcp_xz: np.ndarray, variance: np.ndarray, title: str
|
||||
) -> None:
|
||||
_style_ax(ax)
|
||||
order = np.argsort(variance)
|
||||
pts = tcp_xz[order]
|
||||
var_sorted = variance[order]
|
||||
vmin = np.percentile(variance[variance > 0], 5) if np.any(variance > 0) else 0
|
||||
vmax = np.percentile(variance[variance > 0], 95) if np.any(variance > 0) else 1
|
||||
sc = ax.scatter(
|
||||
pts[:, 0],
|
||||
pts[:, 1],
|
||||
c=var_sorted,
|
||||
cmap=CONSISTENCY_CMAP,
|
||||
s=0.5,
|
||||
alpha=0.6,
|
||||
vmin=vmin,
|
||||
vmax=vmax,
|
||||
rasterized=True,
|
||||
)
|
||||
ax.set_xlabel("X (m)", color="#888", fontsize=10)
|
||||
ax.set_ylabel("Z (m)", color="#888", fontsize=10)
|
||||
ax.set_title(title, color="white", fontsize=11, pad=10)
|
||||
ax.set_aspect("equal")
|
||||
cbar = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
|
||||
cbar.set_label("Action variance", color="white", fontsize=9)
|
||||
cbar.ax.tick_params(colors="#aaa", labelsize=7)
|
||||
|
||||
|
||||
def render(results: list[dict], out_path: Path) -> None:
|
||||
"""
|
||||
4-row x N-column figure:
|
||||
Row 0: State-based variance histogram
|
||||
Row 1: Image-based variance histogram
|
||||
Row 2: Per-episode curves (both overlaid)
|
||||
Row 3: Spatial heatmap (image-based variance)
|
||||
"""
|
||||
n_ds = len(results)
|
||||
fig, axes = plt.subplots(4, n_ds, figsize=(9 * n_ds, 24), facecolor="#0d1117")
|
||||
if n_ds == 1:
|
||||
axes = axes[:, np.newaxis]
|
||||
|
||||
headline_parts = []
|
||||
for col, r in enumerate(results):
|
||||
label = r["label"]
|
||||
var_s = r["var_state"]
|
||||
var_i = r["var_image"]
|
||||
tcp_xz = r["tcp_xz"]
|
||||
episode_ids = r["episode_ids"]
|
||||
|
||||
med_s = np.median(var_s)
|
||||
med_i = np.median(var_i)
|
||||
headline_parts.append(f"{label}: state={med_s:.3f}, image={med_i:.3f}")
|
||||
|
||||
_plot_histogram(axes[0, col], var_s, f"{label}\nState-based variance (K={K_NEIGHBORS})", "#4363d8")
|
||||
_plot_histogram(
|
||||
axes[1, col], var_i, f"{label}\nImage-based variance (SigLIP, K={K_NEIGHBORS})", "#e6194b"
|
||||
)
|
||||
_plot_episode_curves(
|
||||
axes[2, col],
|
||||
var_s,
|
||||
var_i,
|
||||
episode_ids,
|
||||
f"{label}\nPer-episode inconsistency ({len(np.unique(episode_ids)):,} episodes)",
|
||||
)
|
||||
_plot_heatmap(
|
||||
axes[3, col],
|
||||
fig,
|
||||
tcp_xz,
|
||||
var_i,
|
||||
f"{label}\nImage-based variance by TCP position (XZ)",
|
||||
)
|
||||
|
||||
fig.suptitle(
|
||||
f"Action Consistency: State vs Image (chunk={ACTION_CHUNK_SIZE}, K={K_NEIGHBORS})\n"
|
||||
+ " | ".join(headline_parts),
|
||||
color="white",
|
||||
fontsize=15,
|
||||
y=0.99,
|
||||
)
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
||||
plt.savefig(out_path, dpi=DPI, bbox_inches="tight", facecolor=fig.get_facecolor())
|
||||
plt.close()
|
||||
print(f"\n✓ Saved: {out_path}")
|
||||
|
||||
|
||||
# ── Main ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Device: {device}")
|
||||
rng = np.random.default_rng(SEED)
|
||||
results = []
|
||||
|
||||
for ds in DATASETS:
|
||||
repo_id, label = ds["repo_id"], ds["label"]
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f" {label}: {repo_id}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
local = download_data(repo_id, CAMERA_KEY)
|
||||
data = load_state_action_data(local, MAX_FRAMES, ACTION_CHUNK_SIZE, rng)
|
||||
|
||||
# --- State-based KNN ---
|
||||
var_state = compute_consistency(
|
||||
data["state_norm"], data["action_norm"], data["episode_ids"], K_NEIGHBORS, "state"
|
||||
)
|
||||
print(
|
||||
f" State variance: median={np.median(var_state):.4f} "
|
||||
f"mean={np.mean(var_state):.4f} p90={np.percentile(var_state, 90):.4f}"
|
||||
)
|
||||
|
||||
# --- Image-based KNN ---
|
||||
print("\n Preparing image embeddings …")
|
||||
video_lookup = build_video_lookup(local, CAMERA_KEY)
|
||||
frames = extract_frames(data["chosen_idx"], data["episode_all"], video_lookup)
|
||||
embeddings = encode_frames_siglip(frames, ENCODER_MODEL, ENCODE_BATCH_SIZE, device)
|
||||
del frames # free memory
|
||||
|
||||
var_image = compute_consistency(
|
||||
embeddings, data["action_norm"], data["episode_ids"], K_NEIGHBORS, "image"
|
||||
)
|
||||
print(
|
||||
f" Image variance: median={np.median(var_image):.4f} "
|
||||
f"mean={np.mean(var_image):.4f} p90={np.percentile(var_image, 90):.4f}"
|
||||
)
|
||||
|
||||
# FK for spatial heatmap
|
||||
print(" Computing FK for spatial heatmap …")
|
||||
left_raw = data["state_raw"][:, data["left_joint_idx"]]
|
||||
left_rad = _detect_and_convert(left_raw)
|
||||
left_tcp = batch_fk(LEFT_CHAIN, left_rad)
|
||||
tcp_xz = left_tcp[:, [0, 2]]
|
||||
|
||||
results.append(
|
||||
{
|
||||
"label": label,
|
||||
"var_state": var_state,
|
||||
"var_image": var_image,
|
||||
"episode_ids": data["episode_ids"],
|
||||
"tcp_xz": tcp_xz,
|
||||
"n_total": data["n_total"],
|
||||
}
|
||||
)
|
||||
|
||||
out = OUTPUT_DIR / "action_consistency_comparison.jpg"
|
||||
render(results, out)
|
||||
|
||||
# Save worst-episodes summary (image-based, since that's the stronger signal)
|
||||
worst_summary = {}
|
||||
for r in results:
|
||||
unique_eps = np.unique(r["episode_ids"])
|
||||
ep_means = {int(ep): float(r["var_image"][r["episode_ids"] == ep].mean()) for ep in unique_eps}
|
||||
ranked = sorted(ep_means.items(), key=lambda x: x[1], reverse=True)[:50]
|
||||
worst_summary[r["label"]] = [{"episode": ep, "mean_variance": v} for ep, v in ranked]
|
||||
worst_path = OUTPUT_DIR / "action_consistency_worst_episodes.json"
|
||||
worst_path.write_text(json.dumps(worst_summary, indent=2))
|
||||
print(f"✓ Saved worst episodes: {worst_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
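The worst-episodes summary at the end of `main()` groups per-frame variance by episode, averages it, and keeps the 50 highest means. A minimal pure-Python sketch of that ranking step follows; the function name `rank_worst_episodes` and the toy values are illustrative assumptions and are not part of the script, which uses NumPy boolean masks instead.

```python
from collections import defaultdict


def rank_worst_episodes(episode_ids, variances, top_n=50):
    """Return [(episode, mean_variance), ...] sorted from highest to lowest mean.

    Hypothetical helper: mirrors the per-episode mean + sort + truncate logic
    of the summary step, without NumPy.
    """
    sums = defaultdict(float)
    counts = defaultdict(int)
    for ep, var in zip(episode_ids, variances):
        sums[int(ep)] += float(var)
        counts[int(ep)] += 1
    means = {ep: sums[ep] / counts[ep] for ep in sums}
    # Highest mean variance first, i.e. the least consistent episodes lead the list.
    return sorted(means.items(), key=lambda kv: kv[1], reverse=True)[:top_n]


if __name__ == "__main__":
    # Toy data: three episodes with per-frame variance values.
    episode_ids = [0, 0, 1, 1, 2]
    variances = [0.25, 0.75, 0.5, 1.0, 0.125]
    for ep, mean_var in rank_worst_episodes(episode_ids, variances, top_n=2):
        print(f"episode {ep}: mean variance {mean_var:.3f}")
```

Averaging per episode rather than taking the maximum keeps a single noisy frame from pushing an otherwise consistent episode to the top of the list.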
178
examples/dataset/visualization_tools/create_frame_grid.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Create a JPG grid of random frames sampled from a LeRobot video dataset.
|
||||
Downloads metadata + video chunks from HuggingFace, picks random frames,
|
||||
decodes them, and tiles into a single image.
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
REPO_ID = "lerobot-data-collection/level2_final_quality3"
|
||||
CAMERA_KEY = "observation.images.base"
|
||||
GRID_COLS = 15
|
||||
GRID_ROWS = 10
|
||||
THUMB_WIDTH = 160
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
SEED = 1
|
||||
|
||||
|
||||
def download_metadata(repo_id: str) -> Path:
|
||||
"""Download only metadata (no videos yet)."""
|
||||
print(f"[1/3] Downloading metadata for {repo_id} …")
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**"],
|
||||
ignore_patterns=["*.mp4"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def load_video_info(local: Path) -> tuple[str, list[dict], int]:
|
||||
"""Parse info.json and episode parquets. Returns (camera_key, episode_rows, fps)."""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
fps = info["fps"]
|
||||
features = info["features"]
|
||||
|
||||
video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
|
||||
if not video_keys:
|
||||
raise RuntimeError("No video keys found in dataset features")
|
||||
|
||||
if CAMERA_KEY is not None:
|
||||
if CAMERA_KEY not in video_keys:
|
||||
raise RuntimeError(f"CAMERA_KEY='{CAMERA_KEY}' not found. Available: {video_keys}")
|
||||
cam = CAMERA_KEY
|
||||
else:
|
||||
cam = video_keys[0]
|
||||
print(f" camera='{cam}' all_cams={video_keys} fps={fps}")
|
||||
|
||||
ep_rows = []
|
||||
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
||||
ep_rows.append(pd.read_parquet(pq))
|
||||
ep_df = pd.concat(ep_rows, ignore_index=True)
|
||||
|
||||
video_template = info.get(
|
||||
"video_path",
|
||||
"videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4",
|
||||
)
|
||||
|
||||
chunk_col = f"videos/{cam}/chunk_index"
|
||||
file_col = f"videos/{cam}/file_index"
|
||||
ts_from = f"videos/{cam}/from_timestamp"
|
||||
ts_to = f"videos/{cam}/to_timestamp"
|
||||
if chunk_col not in ep_df.columns:
|
||||
chunk_col = f"{cam}/chunk_index"
|
||||
file_col = f"{cam}/file_index"
|
||||
ts_from = f"{cam}/from_timestamp"
|
||||
ts_to = f"{cam}/to_timestamp"
|
||||
|
||||
episodes = []
|
||||
for _, row in ep_df.iterrows():
|
||||
ci = int(row[chunk_col])
|
||||
fi = int(row[file_col])
|
||||
episodes.append(
|
||||
{
|
||||
"episode_index": int(row["episode_index"]),
|
||||
"chunk_index": ci,
|
||||
"file_index": fi,
|
||||
"from_ts": float(row[ts_from]),
|
||||
"to_ts": float(row[ts_to]),
|
||||
"video_rel": video_template.format(video_key=cam, chunk_index=ci, file_index=fi),
|
||||
}
|
||||
)
|
||||
return cam, episodes, fps
|
||||
|
||||
|
||||
def pick_random_frames(episodes: list[dict], fps: int, n: int, rng: random.Random) -> list[dict]:
|
||||
"""Pick n random (episode, timestamp) pairs, return sorted by video file for efficient access."""
|
||||
picks = []
|
||||
for _ in range(n):
|
||||
ep = rng.choice(episodes)
|
||||
duration = ep["to_ts"] - ep["from_ts"]
|
||||
if duration <= 0:
|
||||
continue
|
||||
t = ep["from_ts"] + rng.random() * duration
|
||||
picks.append({**ep, "seek_ts": t})
|
||||
picks.sort(key=lambda p: (p["video_rel"], p["seek_ts"]))
|
||||
return picks
|
||||
|
||||
|
||||
def download_video_files(repo_id: str, local: Path, picks: list[dict]) -> None:
|
||||
"""Download only the video files we need."""
|
||||
needed = sorted({p["video_rel"] for p in picks})
|
||||
print(f"[2/3] Downloading {len(needed)} video file(s) …")
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
local_dir=str(local),
|
||||
allow_patterns=needed,
|
||||
)
|
||||
|
||||
|
||||
def extract_frame(video_path: Path, seek_ts: float) -> np.ndarray | None:
|
||||
"""Decode a single frame at the given timestamp."""
|
||||
cap = cv2.VideoCapture(str(video_path))
|
||||
cap.set(cv2.CAP_PROP_POS_MSEC, seek_ts * 1000.0)
|
||||
ret, frame = cap.read()
|
||||
cap.release()
|
||||
return frame if ret else None
|
||||
|
||||
|
||||
def build_grid(frames: list[np.ndarray], cols: int, thumb_w: int) -> np.ndarray:
|
||||
"""Resize frames to uniform thumbnails and tile into a grid."""
|
||||
if not frames:
|
||||
raise RuntimeError("No frames decoded")
|
||||
|
||||
h0, w0 = frames[0].shape[:2]
|
||||
thumb_h = int(thumb_w * h0 / w0)
|
||||
|
||||
thumbs = [cv2.resize(f, (thumb_w, thumb_h), interpolation=cv2.INTER_AREA) for f in frames]
|
||||
|
||||
rows = []
|
||||
for i in range(0, len(thumbs), cols):
|
||||
row_thumbs = thumbs[i : i + cols]
|
||||
while len(row_thumbs) < cols:
|
||||
row_thumbs.append(np.zeros_like(row_thumbs[0]))
|
||||
rows.append(np.hstack(row_thumbs))
|
||||
return np.vstack(rows)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
rng = random.Random(SEED)
|
||||
n_frames = GRID_COLS * GRID_ROWS
|
||||
|
||||
local = download_metadata(REPO_ID)
|
||||
cam, episodes, fps = load_video_info(local)
|
||||
picks = pick_random_frames(episodes, fps, n_frames, rng)
|
||||
download_video_files(REPO_ID, local, picks)
|
||||
|
||||
print(f"[3/3] Decoding {n_frames} frames …")
|
||||
frames: list[np.ndarray] = []
|
||||
for p in picks:
|
||||
vp = local / p["video_rel"]
|
||||
if not vp.exists():
|
||||
print(f" SKIP: {p['video_rel']} not found")
|
||||
continue
|
||||
frame = extract_frame(vp, p["seek_ts"])
|
||||
if frame is not None:
|
||||
frames.append(frame)
|
||||
|
||||
print(f" Decoded {len(frames)}/{n_frames} frames")
|
||||
grid = build_grid(frames, GRID_COLS, THUMB_WIDTH)
|
||||
|
||||
safe_name = REPO_ID.replace("/", "_")
|
||||
out_path = OUTPUT_DIR / f"{safe_name}_grid_{GRID_COLS}x{GRID_ROWS}.jpg"
|
||||
cv2.imwrite(str(out_path), grid, [cv2.IMWRITE_JPEG_QUALITY, 92])
|
||||
print(f"\n✓ Saved: {out_path} ({grid.shape[1]}×{grid.shape[0]})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
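`build_grid` derives the thumbnail height from the first frame's aspect ratio, pads the last row with black tiles, and stacks the rows. The sketch below computes only the resulting layout, so the output size can be checked before decoding any video. The helper name `grid_layout` and the example numbers are assumptions for illustration.

```python
def grid_layout(n_frames, cols, frame_w, frame_h, thumb_w):
    """Return (rows, n_padding, grid_w, grid_h) for a tiled thumbnail grid.

    Hypothetical helper: reproduces the arithmetic of the tiling step
    (aspect-preserving thumbnail height, last row padded to `cols` tiles).
    """
    if n_frames <= 0:
        raise ValueError("No frames decoded")
    # Same rounding as the script: truncate the scaled height to an int.
    thumb_h = int(thumb_w * frame_h / frame_w)
    rows = -(-n_frames // cols)  # ceiling division
    n_padding = rows * cols - n_frames  # black tiles appended to the last row
    return rows, n_padding, cols * thumb_w, rows * thumb_h


if __name__ == "__main__":
    # Example: 148 of 150 requested frames decoded, 640x480 source, 160 px thumbnails.
    rows, pad, grid_w, grid_h = grid_layout(148, 15, 640, 480, 160)
    print(f"{rows} rows, {pad} padding tiles, grid {grid_w}x{grid_h}")
```

Because frames that fail to decode are skipped, the grid can have fewer rows than `GRID_ROWS`; the padding count shows how many black tiles to expect in the final row.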
526
examples/dataset/visualization_tools/create_progress_videos.py
Normal file
@@ -0,0 +1,526 @@
|
||||
"""
|
||||
Create MP4 videos with sarm_progress overlay for specified episodes.
|
||||
Downloads datasets from HuggingFace, extracts episode video + progress data,
|
||||
and draws the progress line directly on each frame (no panel, no axes).
|
||||
"""
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
DATASETS = [
|
||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "episode": 250},
|
||||
]
|
||||
CAMERA_KEY = (
|
||||
"observation.images.base" # None = auto-select first camera, or set e.g. "observation.images.top"
|
||||
)
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# Progress line spans the full video height
|
||||
GRAPH_Y_TOP_FRAC = 0.01
|
||||
GRAPH_Y_BOT_FRAC = 0.99
|
||||
LINE_THICKNESS = 3
|
||||
SHADOW_THICKNESS = 6 # white edge thickness
|
||||
REF_ALPHA = 0.45 # opacity of the 1.0 reference line
|
||||
FILL_ALPHA = 0.55 # opacity of the grey fill under the line
|
||||
SCORE_FONT_SCALE = 0.8
|
||||
TASK_FONT_SCALE = 0.55
|
||||
|
||||
|
||||
def download_episode(repo_id: str, episode: int) -> Path:
|
||||
"""Download only the files needed for this episode."""
|
||||
# We need: meta/, sarm_progress.parquet, and the relevant video/data chunks.
|
||||
# We'll download meta + sarm first, then figure out chunks.
|
||||
print(f"\n[1/5] Downloading metadata for {repo_id} …")
|
||||
local = Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**", "sarm_progress.parquet"],
|
||||
ignore_patterns=["*.mp4"],
|
||||
)
|
||||
)
|
||||
return local
|
||||
|
||||
|
||||
def load_episode_meta(local: Path, episode: int) -> dict:
|
||||
"""Read info.json + episode-level parquet to get fps, video paths, timestamps."""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
fps = info["fps"]
|
||||
features = info["features"]
|
||||
|
||||
# Find video keys (keys whose dtype=="video")
|
||||
video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
|
||||
if not video_keys:
|
||||
raise RuntimeError("No video keys found in dataset features")
|
||||
if CAMERA_KEY is not None:
|
||||
if CAMERA_KEY not in video_keys:
|
||||
raise RuntimeError(f"CAMERA_KEY='{CAMERA_KEY}' not found. Available: {video_keys}")
|
||||
first_cam = CAMERA_KEY
|
||||
else:
|
||||
first_cam = video_keys[0]
|
||||
print(f" fps={fps} camera='{first_cam}' all_cams={video_keys}")
|
||||
|
||||
# Load all episode-meta parquet files and find our episode
|
||||
ep_rows = []
|
||||
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
||||
df = pd.read_parquet(pq)
|
||||
ep_rows.append(df)
|
||||
ep_df = pd.concat(ep_rows, ignore_index=True)
|
||||
row = ep_df[ep_df["episode_index"] == episode]
|
||||
if row.empty:
|
||||
raise RuntimeError(f"Episode {episode} not found in episode metadata")
|
||||
row = row.iloc[0]
|
||||
|
||||
# Extract video chunk/file index for first camera
|
||||
# Try column names with and without the 'videos/' prefix
|
||||
chunk_col = f"videos/{first_cam}/chunk_index"
|
||||
file_col = f"videos/{first_cam}/file_index"
|
||||
ts_col = f"videos/{first_cam}/from_timestamp"
|
||||
to_col = f"videos/{first_cam}/to_timestamp"
|
||||
|
||||
# Some datasets use different column naming
|
||||
if chunk_col not in row.index:
|
||||
# Try without the 'videos/' prefix
|
||||
chunk_col = f"{first_cam}/chunk_index"
|
||||
file_col = f"{first_cam}/file_index"
|
||||
ts_col = f"{first_cam}/from_timestamp"
|
||||
to_col = f"{first_cam}/to_timestamp"
|
||||
if chunk_col not in row.index:
|
||||
raise RuntimeError(
|
||||
f"Cannot find video metadata columns for {first_cam}.\nAvailable: {list(row.index)}"
|
||||
)
|
||||
|
||||
chunk_idx = int(row[chunk_col])
|
||||
file_idx = int(row[file_col])
|
||||
from_ts = float(row[ts_col])
|
||||
to_ts = float(row[to_col])
|
||||
|
||||
video_template = info.get(
|
||||
"video_path", "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4"
|
||||
)
|
||||
video_rel = video_template.format(
|
||||
video_key=first_cam,
|
||||
chunk_index=chunk_idx,
|
||||
file_index=file_idx,
|
||||
)
|
||||
|
||||
# Load task name for this episode
|
||||
# tasks.parquet uses the task string as the row index; task_index column holds the int id
|
||||
task_name = ""
|
||||
try:
|
||||
# Prefer the 'tasks' list directly on the episode row
|
||||
if "tasks" in row.index and row["tasks"] is not None:
|
||||
tasks_val = row["tasks"]
|
||||
if isinstance(tasks_val, (list, tuple, np.ndarray)) and len(tasks_val) > 0:
|
||||
task_name = str(tasks_val[0])
|
||||
else:
|
||||
task_name = str(tasks_val).strip("[]'")
|
||||
else:
|
||||
tasks_pq = local / "meta" / "tasks.parquet"
|
||||
if tasks_pq.exists():
|
||||
tasks_df = pd.read_parquet(tasks_pq)
|
||||
# Row index is the task string; task_index column is the int
|
||||
task_idx = int(row.get("task_index", 0)) if "task_index" in row.index else 0
|
||||
match = tasks_df[tasks_df["task_index"] == task_idx]
|
||||
if not match.empty:
|
||||
task_name = str(match.index[0])
|
||||
print(f" Task name: '{task_name}'")
|
||||
except Exception as e:
|
||||
print(f" WARNING: could not load task name: {e}")
|
||||
|
||||
return {
|
||||
"fps": fps,
|
||||
"first_cam": first_cam,
|
||||
"video_rel": video_rel,
|
||||
"chunk_index": chunk_idx,
|
||||
"file_index": file_idx,
|
||||
"from_ts": from_ts,
|
||||
"to_ts": to_ts,
|
||||
"task_name": task_name,
|
||||
}
|
||||
|
||||
|
||||
def download_video(repo_id: str, local: Path, video_rel: str) -> Path:
|
||||
"""Download the specific video file if not already present."""
|
||||
video_path = local / video_rel
|
||||
if video_path.exists():
|
||||
print(f" Video already cached: {video_path}")
|
||||
return video_path
|
||||
print(f"[2/5] Downloading video file {video_rel} …")
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
local_dir=str(local),
|
||||
allow_patterns=[video_rel],
|
||||
)
|
||||
if not video_path.exists():
|
||||
raise RuntimeError(f"Video not found after download: {video_path}")
|
||||
return video_path
|
||||
|
||||
|
||||
def load_progress(local: Path, episode: int) -> np.ndarray | None:
|
||||
"""Load sarm_progress values for this episode. Returns sorted array of (frame_index, progress)."""
|
||||
pq_path = local / "sarm_progress.parquet"
|
||||
if not pq_path.exists():
|
||||
print(" WARNING: sarm_progress.parquet not found")
|
||||
return None
|
||||
df = pd.read_parquet(pq_path)
|
||||
print(f" sarm_progress.parquet columns: {list(df.columns)}")
|
||||
ep_df = df[df["episode_index"] == episode].copy()
|
||||
if ep_df.empty:
|
||||
print(f" WARNING: No sarm_progress rows for episode {episode}")
|
||||
return None
|
||||
ep_df = ep_df.sort_values("frame_index")
|
||||
|
||||
# Prefer dense, fall back to sparse
|
||||
if "progress_dense" in ep_df.columns and ep_df["progress_dense"].notna().any():
|
||||
prog_col = "progress_dense"
|
||||
elif "progress_sparse" in ep_df.columns:
|
||||
prog_col = "progress_sparse"
|
||||
else:
|
||||
# Last resort: any column with 'progress' in the name
|
||||
prog_cols = [c for c in ep_df.columns if "progress" in c.lower()]
|
||||
if not prog_cols:
|
||||
return None
|
||||
prog_col = prog_cols[0]
|
||||
|
||||
print(f" Using progress column: '{prog_col}'")
|
||||
return ep_df[["frame_index", prog_col]].rename(columns={prog_col: "progress"}).values
|
||||
|
||||
|
||||
def extract_episode_clip(video_path: Path, from_ts: float, to_ts: float, out_path: Path) -> Path:
|
||||
"""Use ffmpeg to cut the episode segment from the combined video file."""
|
||||
duration = to_ts - from_ts
|
||||
print(f"[3/5] Extracting clip [{from_ts:.3f}s → {to_ts:.3f}s] ({duration:.2f}s) …")
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-ss",
|
||||
str(from_ts),
|
||||
"-i",
|
||||
str(video_path),
|
||||
"-t",
|
||||
str(duration),
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-preset",
|
||||
"fast",
|
||||
"-crf",
|
||||
"18",
|
||||
"-an",
|
||||
str(out_path),
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"ffmpeg clip extraction failed:\n{result.stderr}")
|
||||
return out_path
|
||||
|
||||
|
||||
def precompute_pixels(
|
||||
progress_data: np.ndarray,
|
||||
n_frames: int,
|
||||
frame_w: int,
|
||||
frame_h: int,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Map each progress sample to pixel coordinates.
|
||||
Returns array of shape (N, 2) with (x, y) in pixel space.
|
||||
x spans full video width; y maps progress [0,1] to graph band.
|
||||
"""
|
||||
frame_indices = progress_data[:, 0].astype(float)
|
||||
progress_vals = np.clip(progress_data[:, 1].astype(float), 0.0, 1.0)
|
||||
|
||||
y_top = int(frame_h * GRAPH_Y_TOP_FRAC)
|
||||
y_bot = int(frame_h * GRAPH_Y_BOT_FRAC)
|
||||
graph_h = y_bot - y_top
|
||||
|
||||
xs = (frame_indices / max(n_frames - 1, 1) * (frame_w - 1)).astype(int)  # guard against single-frame clips
|
||||
# progress=1 → y_top, progress=0 → y_bot
|
||||
ys = (y_bot - progress_vals * graph_h).astype(int)
|
||||
|
||||
return np.stack([xs, ys], axis=1) # (N, 2)
|
||||
|
||||
|
||||
def progress_color(t: float) -> tuple[int, int, int]:
|
||||
"""Interpolate BGR color red→green based on normalised position t in [0,1]."""
|
||||
r = int(255 * (1.0 - t))
|
||||
g = int(255 * t)
|
||||
return (0, g, r) # BGR
|
||||
|
||||
|
||||
def prerender_fill(
|
||||
pixels: np.ndarray,
|
||||
frame_w: int,
|
||||
frame_h: int,
|
||||
) -> np.ndarray:
|
||||
"""Pre-render the full grey fill polygon under the curve as a BGRA image."""
|
||||
y_bot = int(frame_h * GRAPH_Y_BOT_FRAC)
|
||||
fill_img = np.zeros((frame_h, frame_w, 4), dtype=np.uint8)
|
||||
poly = np.concatenate(
|
||||
[
|
||||
pixels,
|
||||
[[pixels[-1][0], y_bot], [pixels[0][0], y_bot]],
|
||||
],
|
||||
axis=0,
|
||||
).astype(np.int32)
|
||||
cv2.fillPoly(fill_img, [poly], color=(128, 128, 128, int(255 * FILL_ALPHA)))
|
||||
return fill_img
|
||||
|
||||
|
||||
def alpha_composite(base: np.ndarray, overlay_bgra: np.ndarray, x_max: int) -> None:
|
||||
"""Blend overlay onto base in-place, but only for x < x_max."""
|
||||
if x_max <= 0:
|
||||
return
|
||||
roi_b = base[:, :x_max]
|
||||
roi_o = overlay_bgra[:, :x_max]
|
||||
alpha = roi_o[:, :, 3:4].astype(np.float32) / 255.0
|
||||
roi_b[:] = np.clip(
|
||||
roi_o[:, :, :3].astype(np.float32) * alpha + roi_b.astype(np.float32) * (1.0 - alpha),
|
||||
0,
|
||||
255,
|
||||
).astype(np.uint8)
|
||||
|
||||
|
||||
def draw_text_outlined(
|
||||
frame: np.ndarray,
|
||||
text: str,
|
||||
pos: tuple[int, int],
|
||||
font_scale: float,
|
||||
thickness: int = 1,
|
||||
) -> None:
|
||||
"""Draw text with a dark outline for readability on any background."""
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
cv2.putText(frame, text, pos, font, font_scale, (0, 0, 0), thickness + 2, cv2.LINE_AA)
|
||||
cv2.putText(frame, text, pos, font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
|
||||
|
||||
|
||||
def composite_video(
|
||||
clip_path: Path,
|
||||
progress_data: np.ndarray,
|
||||
out_path: Path,
|
||||
fps: float,
|
||||
frame_h: int,
|
||||
frame_w: int,
|
||||
task_name: str = "",
|
||||
) -> Path:
|
||||
"""Read clip frames, draw gradient progress line with fill + labels, export as GIF."""
|
||||
n_total = int(cv2.VideoCapture(str(clip_path)).get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
pixels = precompute_pixels(progress_data, n_total, frame_w, frame_h)
|
||||
|
||||
y_ref = int(frame_h * GRAPH_Y_TOP_FRAC)
|
||||
|
||||
# Pre-render fill polygon (line is drawn per-frame with live color)
|
||||
fill_img = prerender_fill(pixels, frame_w, frame_h)
|
||||
|
||||
# 1.0 reference line overlay (full width, drawn once)
|
||||
ref_img = np.zeros((frame_h, frame_w, 4), dtype=np.uint8)
|
||||
cv2.line(ref_img, (0, y_ref), (frame_w - 1, y_ref), (200, 200, 200, int(255 * REF_ALPHA)), 1, cv2.LINE_AA)
|
||||
|
||||
frame_indices = progress_data[:, 0].astype(int)
|
||||
progress_vals = progress_data[:, 1].astype(float)
|
||||
|
||||
print(f"[4/5] Compositing {n_total} frames …")
|
||||
cap = cv2.VideoCapture(str(clip_path))
|
||||
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
||||
tmp_path = out_path.parent / (out_path.stem + "_tmp.mp4")
|
||||
writer = cv2.VideoWriter(str(tmp_path), fourcc, fps, (frame_w, frame_h))
|
||||
|
||||
fi = 0
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
n_drawn = int(np.searchsorted(frame_indices, fi, side="right"))
|
||||
x_cur = int(pixels[min(n_drawn, len(pixels)) - 1][0]) + 1 if n_drawn > 0 else 0
|
||||
|
||||
# 1. reference line (full width, always)
|
||||
alpha_composite(frame, ref_img, frame_w)
|
||||
|
||||
# 2. grey fill under curve up to current x
|
||||
alpha_composite(frame, fill_img, x_cur)
|
||||
|
||||
# 3. progress line — single color that transitions red→green over time
|
||||
if n_drawn >= 2:
|
||||
t_cur = (n_drawn - 1) / max(len(progress_vals) - 1, 1)
|
||||
line_col = progress_color(t_cur)
|
||||
pts = pixels[:n_drawn].reshape(-1, 1, 2).astype(np.int32)
|
||||
cv2.polylines(
|
||||
frame,
|
||||
[pts],
|
||||
isClosed=False,
|
||||
color=(255, 255, 255),
|
||||
thickness=SHADOW_THICKNESS,
|
||||
lineType=cv2.LINE_AA,
|
||||
)
|
||||
cv2.polylines(
|
||||
frame, [pts], isClosed=False, color=line_col, thickness=LINE_THICKNESS, lineType=cv2.LINE_AA
|
||||
)
|
||||
|
||||
# 4. score — bottom right
|
||||
if n_drawn > 0:
|
||||
score = float(progress_vals[min(n_drawn, len(progress_vals)) - 1])
|
||||
score_text = f"{score:.2f}"
|
||||
(tw, th), _ = cv2.getTextSize(score_text, cv2.FONT_HERSHEY_SIMPLEX, SCORE_FONT_SCALE, 2)
|
||||
sx = frame_w - tw - 12
|
||||
sy = frame_h - 12
|
||||
# coloured score matching current gradient position
|
||||
t_cur = (n_drawn - 1) / max(len(progress_vals) - 1, 1)
|
||||
score_col = progress_color(t_cur)
|
||||
cv2.putText(
|
||||
frame,
|
||||
score_text,
|
||||
(sx, sy),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
SCORE_FONT_SCALE,
|
||||
(0, 0, 0),
|
||||
4,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
cv2.putText(
|
||||
frame,
|
||||
score_text,
|
||||
(sx, sy),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
SCORE_FONT_SCALE,
|
||||
score_col,
|
||||
2,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
|
||||
# 5. task name — top centre
|
||||
if task_name:
|
||||
(tw, _), _ = cv2.getTextSize(task_name, cv2.FONT_HERSHEY_SIMPLEX, TASK_FONT_SCALE, 1)
|
||||
tx = max((frame_w - tw) // 2, 4)
|
||||
draw_text_outlined(frame, task_name, (tx, 22), TASK_FONT_SCALE)
|
||||
|
||||
writer.write(frame)
|
||||
fi += 1
|
||||
if fi % 100 == 0:
|
||||
print(f" Frame {fi}/{n_total} …", end="\r")
|
||||
|
||||
cap.release()
|
||||
writer.release()
|
||||
print()
|
||||
|
||||
# Convert to GIF: full resolution, 10fps, 128-color diff palette (<40MB)
|
||||
gif_path = out_path.with_suffix(".gif")
|
||||
palette = out_path.parent / "_palette.png"
|
||||
r1 = subprocess.run( # nosec B607
|
||||
[
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
str(tmp_path),
|
||||
"-vf",
|
||||
f"fps=10,scale={frame_w}:-1:flags=lanczos,palettegen=max_colors=128:stats_mode=diff",
|
||||
"-update",
|
||||
"1",
|
||||
str(palette),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if r1.returncode != 0:
|
||||
print(f" WARNING: palettegen failed:\n{r1.stderr[-500:]}")
|
||||
r2 = subprocess.run( # nosec B607
|
||||
[
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
str(tmp_path),
|
||||
"-i",
|
||||
str(palette),
|
||||
"-filter_complex",
|
||||
f"fps=10,scale={frame_w}:-1:flags=lanczos[v];[v][1:v]paletteuse=dither=bayer:bayer_scale=3",
|
||||
str(gif_path),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if r2.returncode != 0:
|
||||
print(f" WARNING: gif encode failed:\n{r2.stderr[-500:]}")
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
palette.unlink(missing_ok=True)
|
||||
return gif_path
|
||||
|
||||
|
||||
def process_dataset(repo_id: str, episode: int):
|
||||
safe_name = repo_id.replace("/", "_")
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Processing: {repo_id} | episode {episode}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
# 1. Download metadata
|
||||
local = download_episode(repo_id, episode)
|
||||
print(f" Local cache: {local}")
|
||||
|
||||
# 2. Read episode metadata
|
||||
ep_meta = load_episode_meta(local, episode)
|
||||
print(f" Episode meta: {ep_meta}")
|
||||
|
||||
# 3. Download video file
|
||||
video_path = download_video(repo_id, local, ep_meta["video_rel"])
|
||||
|
||||
# 4. Extract clip
|
||||
clip_path = OUTPUT_DIR / f"{safe_name}_ep{episode}_clip.mp4"
|
||||
extract_episode_clip(video_path, ep_meta["from_ts"], ep_meta["to_ts"], clip_path)
|
||||
|
||||
# 5. Load progress data
|
||||
progress_data = load_progress(local, episode)
|
||||
if progress_data is None:
|
||||
print(" ERROR: Could not load sarm_progress data. Skipping overlay.")
|
||||
return
|
||||
|
||||
n_progress = len(progress_data)
|
||||
print(f" Progress frames: {n_progress}")
|
||||
|
||||
# 6. Get clip dimensions
|
||||
cap = cv2.VideoCapture(str(clip_path))
|
||||
frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
actual_fps = cap.get(cv2.CAP_PROP_FPS) or ep_meta["fps"]
|
||||
cap.release()
|
||||
print(f" Clip: {frame_w}×{frame_h} {n_frames} frames @ {actual_fps:.1f}fps")
|
||||
|
||||
# 7. Composite (draw line directly on frames)
|
||||
out_path = OUTPUT_DIR / f"{safe_name}_ep{episode}_progress.mp4"
|
||||
final = composite_video(
|
||||
clip_path,
|
||||
progress_data,
|
||||
out_path,
|
||||
actual_fps,
|
||||
frame_h,
|
||||
frame_w,
|
||||
task_name=ep_meta.get("task_name", ""),
|
||||
)
|
||||
clip_path.unlink(missing_ok=True)
|
||||
print(f"\n✓ Done: {final}")
|
||||
return final
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
results = []
|
||||
for cfg in DATASETS:
|
||||
try:
|
||||
out = process_dataset(cfg["repo_id"], cfg["episode"])
|
||||
if out:
|
||||
results.append(out)
|
||||
except Exception as e:
|
||||
print(f"\nERROR processing {cfg['repo_id']}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Output files:")
|
||||
for r in results:
|
||||
print(f" {r}")
|
||||
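`precompute_pixels` maps each `(frame_index, progress)` sample to a pixel: x scales the frame index across the full width, and y places progress 1.0 at the top of the graph band and 0.0 at the bottom. A minimal pure-Python sketch of that mapping is shown below. The function name `progress_to_pixels` is an assumption, and the `max(n_frames - 1, 1)` guard for single-frame clips is an addition that the script's version may not have.

```python
def progress_to_pixels(samples, n_frames, frame_w, frame_h, top_frac=0.01, bot_frac=0.99):
    """Map (frame_index, progress) samples to (x, y) pixel coordinates.

    Hypothetical helper: x spans the full frame width, y maps progress in
    [0, 1] onto the band between top_frac and bot_frac of the frame height.
    """
    y_top = int(frame_h * top_frac)
    y_bot = int(frame_h * bot_frac)
    graph_h = y_bot - y_top
    denom = max(n_frames - 1, 1)  # assumption: avoid division by zero for 1-frame clips
    pixels = []
    for frame_index, progress in samples:
        p = min(max(float(progress), 0.0), 1.0)  # clip to [0, 1]
        x = int(frame_index / denom * (frame_w - 1))
        y = int(y_bot - p * graph_h)  # image y grows downward, so higher progress is a smaller y
        pixels.append((x, y))
    return pixels


if __name__ == "__main__":
    # 11-frame clip, 101x100 px; the last sample exceeds 1.0 and is clipped.
    for point in progress_to_pixels([(0, 0.0), (5, 0.5), (10, 1.5)], 11, 101, 100):
        print(point)
```

Precomputing the pixel list once keeps the per-frame loop cheap: each frame only needs a slice of this list up to the current frame index.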
496
examples/dataset/visualization_tools/workspace_density.py
Normal file
@@ -0,0 +1,496 @@
|
||||
"""
|
||||
Visualize end-effector workspace density and trajectory clusters for OpenArm datasets.
|
||||
Downloads joint position data (no videos) from HuggingFace, computes forward
|
||||
kinematics per episode, clusters trajectories with K-means, and renders
|
||||
2D projections comparing dataset coverage and multimodality.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from huggingface_hub import snapshot_download
|
||||
from sklearn.cluster import KMeans
|
||||
|
||||
DATASETS = [
|
||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
|
||||
{"repo_id": "lerobot-data-collection/level12_rac_2_2026-02-08_1", "label": "Full collection"},
|
||||
]
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
N_CLUSTERS = 10
|
||||
WAYPOINTS = 50
|
||||
SEED = 42
|
||||
DPI = 180
|
||||
|
||||
CLUSTER_COLORS = [
|
||||
"#e6194b",
|
||||
"#3cb44b",
|
||||
"#4363d8",
|
||||
"#f58231",
|
||||
"#911eb4",
|
||||
"#42d4f4",
|
||||
"#f032e6",
|
||||
"#bfef45",
|
||||
"#fabed4",
|
||||
"#dcbeff",
|
||||
"#9a6324",
|
||||
"#fffac8",
|
||||
"#800000",
|
||||
"#aaffc3",
|
||||
"#808000",
|
||||
"#ffd8b1",
|
||||
"#000075",
|
||||
"#a9a9a9",
|
||||
]
|
||||
|
||||
# FK chains extracted from OpenArm bimanual URDF.
|
||||
# Each entry: (rpy, xyz, revolute_axis_or_None).
|
||||
LEFT_CHAIN = [
|
||||
((-np.pi / 2, 0, 0), (0, 0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((-np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, -1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
RIGHT_CHAIN = [
|
||||
((np.pi / 2, 0, 0), (0, -0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, 1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
|
||||
|
||||
# ── FK math ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def _rot_x(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
|
||||
|
||||
|
||||
def _rot_y(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
|
||||
|
||||
|
||||
def _rot_z(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
|
||||
|
||||
|
||||
def _tf(rpy: tuple, xyz: tuple) -> np.ndarray:
|
||||
"""Build a 4x4 homogeneous transform from URDF rpy + xyz."""
|
||||
r, p, y = rpy
|
||||
mat = np.eye(4)
|
||||
mat[:3, :3] = _rot_z(y) @ _rot_y(p) @ _rot_x(r)
|
||||
mat[:3, 3] = xyz
|
||||
return mat
|
||||
|
||||
|
||||
def _batch_axis_rot(axis: tuple, angles: np.ndarray) -> np.ndarray:
    """Batched Rodrigues rotation: (n,) angles around a fixed axis → (n, 4, 4)."""
    n = len(angles)
    ax = np.asarray(axis, dtype=np.float64)
    ax = ax / np.linalg.norm(ax)  # Rodrigues' formula assumes a unit axis
    x, y, z = ax
    c = np.cos(angles)
    s = np.sin(angles)
    t = 1 - c
    rot = np.zeros((n, 4, 4))
    rot[:, 0, 0] = t * x * x + c
    rot[:, 0, 1] = t * x * y - s * z
    rot[:, 0, 2] = t * x * z + s * y
    rot[:, 1, 0] = t * x * y + s * z
    rot[:, 1, 1] = t * y * y + c
    rot[:, 1, 2] = t * y * z - s * x
    rot[:, 2, 0] = t * x * z - s * y
    rot[:, 2, 1] = t * y * z + s * x
    rot[:, 2, 2] = t * z * z + c
    rot[:, 3, 3] = 1.0  # homogeneous row; translation stays zero
    return rot
|
||||
|
||||
|
||||
def batch_fk(chain: list, joint_angles: np.ndarray) -> np.ndarray:
    """Vectorized FK: (n, 7) radians → (n, 3) TCP positions in world frame."""
    n = joint_angles.shape[0]
    tf_batch = np.tile(np.eye(4), (n, 1, 1))
    qi = 0  # index of the next joint-angle column to consume
    for rpy, xyz, axis in chain:
        # Fixed offset to the next frame, shared by every sample in the batch.
        tf_batch = tf_batch @ _tf(rpy, xyz)
        if axis is not None:
            # Actuated joint: apply the per-sample rotation about its axis.
            rot = _batch_axis_rot(axis, joint_angles[:, qi])
            tf_batch = np.einsum("nij,njk->nik", tf_batch, rot)
            qi += 1
    return tf_batch[:, :3, 3]
|
||||
|
||||
|
||||
# ── Data loading ────────────────────────────────────────
|
||||
|
||||
|
||||
def _flatten_names(obj: object) -> list[str]:
|
||||
"""Recursively flatten a names structure (list, dict, or nested) into a flat string list."""
|
||||
if isinstance(obj, dict):
|
||||
out: list[str] = []
|
||||
for v in obj.values():
|
||||
out.extend(_flatten_names(v))
|
||||
return out
|
||||
if isinstance(obj, (list, tuple)):
|
||||
out = []
|
||||
for item in obj:
|
||||
if isinstance(item, (list, tuple, dict)):
|
||||
out.extend(_flatten_names(item))
|
||||
else:
|
||||
out.append(str(item))
|
||||
return out
|
||||
return [str(obj)]
|
||||
|
||||
|
||||
def _detect_and_convert(vals: np.ndarray) -> np.ndarray:
    """Auto-detect servo ticks / degrees / radians and convert to radians."""
    mx = np.max(np.abs(vals))
    # Heuristic on the largest magnitude: above 360 cannot be degrees, so treat
    # the values as raw ticks centred at 2048 with 2048 ticks per π radians.
    if mx > 360:
        print(f" Unit detection: servo ticks (max={mx:.0f})")
        return (vals - 2048) / 2048 * np.pi
    # Above roughly 2π the values cannot be radians, so treat them as degrees.
    if mx > 6.3:
        print(f" Unit detection: degrees (max={mx:.1f})")
        return np.deg2rad(vals)
    print(f" Unit detection: radians (max={mx:.3f})")
    return vals.astype(np.float64)
|
||||
|
||||
|
||||
def _find_joint_indices(features: dict, state_col: str, n_dim: int) -> tuple[list[int], list[int]]:
|
||||
"""Try to find left/right joint indices from info.json feature names."""
|
||||
feat = features.get("observation.state", features.get(state_col, {}))
|
||||
names = _flatten_names(feat.get("names", []))
|
||||
|
||||
left_idx: list[int] = []
|
||||
right_idx: list[int] = []
|
||||
if names and len(names) == n_dim:
|
||||
names_l = [n.lower() for n in names]
|
||||
print(f" Feature names: {names[:4]}…{names[-4:]}")
|
||||
for j in range(1, 8):
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"left_joint_{j}" in nm and i not in left_idx:
|
||||
left_idx.append(i)
|
||||
break
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"right_joint_{j}" in nm and i not in right_idx:
|
||||
right_idx.append(i)
|
||||
break
|
||||
|
||||
if len(left_idx) == 7 and len(right_idx) == 7:
|
||||
print(f" Matched by name: left={left_idx} right={right_idx}")
|
||||
return left_idx, right_idx
|
||||
if n_dim >= 16:
|
||||
print(" Falling back to positional: [0:7]=left, [8:15]=right")
|
||||
return list(range(7)), list(range(8, 15))
|
||||
if n_dim >= 14:
|
||||
print(" Falling back to positional: [0:7]=left, [7:14]=right")
|
||||
return list(range(7)), list(range(7, 14))
|
||||
raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot")
|
||||
|
||||
|
||||
def download_data(repo_id: str) -> Path:
|
||||
print(f" Downloading {repo_id} (parquet only) …")
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**", "data/**"],
|
||||
ignore_patterns=["*.mp4", "videos/**"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def resample_trajectory(traj: np.ndarray, n_waypoints: int) -> np.ndarray:
    """Resample a (F, 3) trajectory to exactly n_waypoints via linear interpolation."""
    f = traj.shape[0]
    if f == n_waypoints:
        return traj
    old_t = np.linspace(0, 1, f)
    new_t = np.linspace(0, 1, n_waypoints)
    return np.column_stack([np.interp(new_t, old_t, traj[:, d]) for d in range(3)])
|
||||
|
||||
|
||||
def load_episode_trajectories(local: Path) -> list[dict]:
|
||||
"""
|
||||
Load per-episode joint data, compute FK, return list of trajectory dicts.
|
||||
Each dict: {"left_tcp": (F,3), "right_tcp": (F,3), "episode_index": int}.
|
||||
Uses all episodes in the dataset for a fair comparison; episodes with fewer than 3 frames are skipped.
|
||||
"""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
features = info.get("features", {})
|
||||
|
||||
dfs = [pd.read_parquet(pq) for pq in sorted((local / "data").glob("**/*.parquet"))]
|
||||
df = pd.concat(dfs, ignore_index=True)
|
||||
print(f" Total frames: {len(df):,}")
|
||||
|
||||
state_col = next((c for c in df.columns if "observation.state" in c), None)
|
||||
if state_col is None:
|
||||
raise RuntimeError(f"No observation.state column. Available: {list(df.columns)}")
|
||||
|
||||
first = df[state_col].iloc[0]
|
||||
if not hasattr(first, "__len__"):
|
||||
raise RuntimeError(f"observation.state is scalar ({type(first)}), expected array")
|
||||
|
||||
state = np.stack(df[state_col].values).astype(np.float64)
|
||||
n_dim = state.shape[1]
|
||||
print(f" State dim: {n_dim} max|val|: {np.max(np.abs(state)):.1f}")
|
||||
|
||||
left_idx, right_idx = _find_joint_indices(features, state_col, n_dim)
|
||||
|
||||
ep_col = next((c for c in df.columns if c == "episode_index"), None)
|
||||
if ep_col is None:
|
||||
raise RuntimeError(f"No episode_index column. Available: {list(df.columns)}")
|
||||
|
||||
episode_ids = df[ep_col].values
|
||||
unique_eps = np.unique(episode_ids)
|
||||
print(f" Episodes: {len(unique_eps):,}")
|
||||
|
||||
left_raw = state[:, left_idx]
|
||||
right_raw = state[:, right_idx]
|
||||
left_all = _detect_and_convert(left_raw)
|
||||
right_all = _detect_and_convert(right_raw)
|
||||
|
||||
print(" Computing FK per episode …")
|
||||
trajectories = []
|
||||
for ep_id in unique_eps:
|
||||
mask = episode_ids == ep_id
|
||||
left_tcp = batch_fk(LEFT_CHAIN, left_all[mask])
|
||||
right_tcp = batch_fk(RIGHT_CHAIN, right_all[mask])
|
||||
if len(left_tcp) < 3:
|
||||
continue
|
||||
trajectories.append({"left_tcp": left_tcp, "right_tcp": right_tcp, "episode_index": int(ep_id)})
|
||||
|
||||
print(f" Valid trajectories: {len(trajectories):,}")
|
||||
return trajectories
|
||||
|
||||
|
||||
# ── Clustering ──────────────────────────────────────────
|
||||
|
||||
|
||||
def cluster_trajectories(
|
||||
trajectories: list[dict], n_clusters: int, n_waypoints: int
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
K-means on resampled trajectory features.
|
||||
Combines left+right TCP into a single feature vector per episode.
|
||||
Returns (labels, centroid_trajs (k, waypoints, 6), spread_per_cluster (k,) in metres).
|
||||
Spread = mean per-waypoint Euclidean distance from each trajectory to its centroid.
|
||||
"""
|
||||
feat_vecs = []
|
||||
for t in trajectories:
|
||||
left_rs = resample_trajectory(t["left_tcp"], n_waypoints)
|
||||
right_rs = resample_trajectory(t["right_tcp"], n_waypoints)
|
||||
feat_vecs.append(np.concatenate([left_rs.ravel(), right_rs.ravel()]))
|
||||
feat_matrix = np.array(feat_vecs)
|
||||
|
||||
k = min(n_clusters, len(feat_vecs))
|
||||
km = KMeans(n_clusters=k, n_init=10, random_state=SEED)
|
||||
labels = km.fit_predict(feat_matrix)
|
||||
|
||||
centroids_flat = km.cluster_centers_
|
||||
centroid_trajs = np.zeros((k, n_waypoints, 6))
|
||||
for ci in range(k):
|
||||
left_flat = centroids_flat[ci, : n_waypoints * 3]
|
||||
right_flat = centroids_flat[ci, n_waypoints * 3 :]
|
||||
centroid_trajs[ci, :, :3] = left_flat.reshape(n_waypoints, 3)
|
||||
centroid_trajs[ci, :, 3:] = right_flat.reshape(n_waypoints, 3)
|
||||
|
||||
# Mean per-waypoint distance to centroid (in metres) for each cluster
|
||||
spread = np.zeros(k)
|
||||
for ci in range(k):
|
||||
members = np.where(labels == ci)[0]
|
||||
if len(members) == 0:
|
||||
continue
|
||||
centroid_left = centroid_trajs[ci, :, :3]
|
||||
centroid_right = centroid_trajs[ci, :, 3:]
|
||||
dists = []
|
||||
for mi in members:
|
||||
t = trajectories[mi]
|
||||
left_rs = resample_trajectory(t["left_tcp"], n_waypoints)
|
||||
right_rs = resample_trajectory(t["right_tcp"], n_waypoints)
|
||||
d_left = np.linalg.norm(left_rs - centroid_left, axis=1).mean()
|
||||
d_right = np.linalg.norm(right_rs - centroid_right, axis=1).mean()
|
||||
dists.append((d_left + d_right) / 2)
|
||||
spread[ci] = np.mean(dists)
|
||||
|
||||
return labels, centroid_trajs, spread
|
||||
|
||||
|
||||
# ── Visualization ───────────────────────────────────────
|
||||
|
||||
PROJ_VIEWS = [
|
||||
("XZ (side)", 0, 2, "X (m)", "Z (m)"),
|
||||
("XY (top)", 0, 1, "X (m)", "Y (m)"),
|
||||
("YZ (front)", 1, 2, "Y (m)", "Z (m)"),
|
||||
]
|
||||
|
||||
|
||||
def render(results: list[dict], out_path: Path) -> None:
|
||||
"""
|
||||
One row per dataset and one column per projection (2 rows × 3 columns for two datasets).
|
||||
Trajectory lines colored by cluster, centroid trajectories drawn thick.
|
||||
"""
|
||||
n_ds = len(results)
|
||||
n_proj = len(PROJ_VIEWS)
|
||||
fig, axes = plt.subplots(n_ds, n_proj, figsize=(7 * n_proj, 7 * n_ds), facecolor="#0d1117")
|
||||
if n_ds == 1:
|
||||
axes = axes[np.newaxis, :]
|
||||
|
||||
for row, r in enumerate(results):
|
||||
trajectories = r["trajectories"]
|
||||
labels = r["labels"]
|
||||
centroids = r["centroids"]
|
||||
k = centroids.shape[0]
|
||||
|
||||
cluster_sizes = np.bincount(labels, minlength=k)
|
||||
size_order = np.argsort(-cluster_sizes)
|
||||
pcts = cluster_sizes / len(labels) * 100
|
||||
spread = r["spread"]
|
||||
|
||||
for col, (view_name, dim_a, dim_b, xlabel, ylabel) in enumerate(PROJ_VIEWS):
|
||||
ax = axes[row, col]
|
||||
ax.set_facecolor("#0d1117")
|
||||
|
||||
for ti, traj in enumerate(trajectories):
|
||||
color = CLUSTER_COLORS[labels[ti] % len(CLUSTER_COLORS)]
|
||||
for tcp_key in ("left_tcp", "right_tcp"):
|
||||
pts = traj[tcp_key]
|
||||
ax.plot(pts[:, dim_a], pts[:, dim_b], color=color, alpha=0.12, linewidth=0.4)
|
||||
|
||||
for ci in range(k):
|
||||
color = CLUSTER_COLORS[ci % len(CLUSTER_COLORS)]
|
||||
left_c = centroids[ci, :, :3]
|
||||
right_c = centroids[ci, :, 3:]
|
||||
lw = 1.5 + 2.0 * cluster_sizes[ci] / cluster_sizes.max()
|
||||
for c_pts in (left_c, right_c):
|
||||
ax.plot(
|
||||
c_pts[:, dim_a],
|
||||
c_pts[:, dim_b],
|
||||
color=color,
|
||||
linewidth=lw,
|
||||
alpha=0.95,
|
||||
zorder=10,
|
||||
)
|
||||
ax.plot(
|
||||
c_pts[0, dim_a],
|
||||
c_pts[0, dim_b],
|
||||
"o",
|
||||
color=color,
|
||||
markersize=4,
|
||||
zorder=11,
|
||||
)
|
||||
ax.plot(
|
||||
c_pts[-1, dim_a],
|
||||
c_pts[-1, dim_b],
|
||||
"s",
|
||||
color=color,
|
||||
markersize=4,
|
||||
zorder=11,
|
||||
)
|
||||
|
||||
ax.set_xlabel(xlabel, color="#888", fontsize=9)
|
||||
ax.set_ylabel(ylabel, color="#888", fontsize=9)
|
||||
ax.tick_params(colors="#555", labelsize=7)
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color("#333")
|
||||
ax.set_aspect("equal")
|
||||
|
||||
mean_spread_cm = np.average(spread, weights=cluster_sizes) * 100
|
||||
if col == 0:
|
||||
ax.set_title(
|
||||
f"{r['label']} ({r['n_episodes']:,} episodes, {k} clusters, "
|
||||
f"avg spread {mean_spread_cm:.1f}cm)",
|
||||
color="white",
|
||||
fontsize=11,
|
||||
pad=10,
|
||||
)
|
||||
else:
|
||||
ax.set_title(view_name, color="#aaa", fontsize=10, pad=8)
|
||||
|
||||
# Cluster size + spread legend on the rightmost panel
|
||||
legend_ax = axes[row, -1]
|
||||
for ci in size_order:
|
||||
color = CLUSTER_COLORS[ci % len(CLUSTER_COLORS)]
|
||||
spread_cm = spread[ci] * 100
|
||||
label = f"C{ci}: {cluster_sizes[ci]} eps ({pcts[ci]:.0f}%) ±{spread_cm:.1f}cm"
|
||||
legend_ax.plot([], [], color=color, linewidth=3, label=label)
|
||||
legend_ax.legend(
|
||||
loc="upper right",
|
||||
fontsize=7,
|
||||
frameon=True,
|
||||
facecolor="#1a1a2e",
|
||||
edgecolor="#333",
|
||||
labelcolor="white",
|
||||
handlelength=1.5,
|
||||
)
|
||||
|
||||
fig.suptitle(
|
||||
"End-Effector Trajectory Clusters (FK · K-means)",
|
||||
color="white",
|
||||
fontsize=16,
|
||||
y=0.98,
|
||||
)
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.95])
|
||||
plt.savefig(out_path, dpi=DPI, bbox_inches="tight", facecolor=fig.get_facecolor())
|
||||
plt.close()
|
||||
print(f"\n✓ Saved: {out_path}")
|
||||
|
||||
|
||||
# ── Main ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
|
||||
results = []
|
||||
|
||||
for ds in DATASETS:
|
||||
repo_id, label = ds["repo_id"], ds["label"]
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f" {label}: {repo_id}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
local = download_data(repo_id)
|
||||
trajectories = load_episode_trajectories(local)
|
||||
labels, centroids, spread = cluster_trajectories(trajectories, N_CLUSTERS, WAYPOINTS)
|
||||
|
||||
cluster_sizes = np.bincount(labels, minlength=centroids.shape[0])
|
||||
print(f" Cluster sizes: {sorted(cluster_sizes, reverse=True)}")
|
||||
for ci in np.argsort(-cluster_sizes):
|
||||
print(
|
||||
f" C{ci}: {cluster_sizes[ci]} eps ({cluster_sizes[ci] / len(labels) * 100:.0f}%) "
|
||||
f"spread ±{spread[ci] * 100:.1f}cm"
|
||||
)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"label": label,
|
||||
"trajectories": trajectories,
|
||||
"labels": labels,
|
||||
"centroids": centroids,
|
||||
"spread": spread,
|
||||
"n_episodes": len(trajectories),
|
||||
}
|
||||
)
|
||||
|
||||
out = OUTPUT_DIR / "workspace_trajectory_clusters.jpg"
|
||||
render(results, out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
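
The FK helpers in this script chain a fixed offset with a Rodrigues rotation per actuated joint. The block below is a minimal sketch of that idea on a hypothetical two-link planar arm, where the expected tool positions are easy to verify by hand. The names `axis_rot`, `toy_fk`, and `PLANAR_CHAIN` are illustrative and not part of the script, and the toy chain omits the rpy offsets that the real chains carry.

```python
import numpy as np


def axis_rot(axis, angles):
    """Rodrigues rotation matrices, shape (n, 3, 3), for (n,) angles about a fixed axis."""
    ax = np.asarray(axis, dtype=np.float64)
    ax = ax / np.linalg.norm(ax)
    # Skew-symmetric cross-product matrix K of the unit axis.
    kx = np.array([[0, -ax[2], ax[1]], [ax[2], 0, -ax[0]], [-ax[1], ax[0], 0]])
    s = np.sin(angles)[:, None, None]
    c = np.cos(angles)[:, None, None]
    # R = I + sin(theta) K + (1 - cos(theta)) K^2
    return np.eye(3) + s * kx + (1 - c) * (kx @ kx)


def toy_fk(chain, q):
    """FK for a chain of (xyz_offset, axis_or_None) links; q has shape (n, n_joints)."""
    n = q.shape[0]
    rot = np.tile(np.eye(3), (n, 1, 1))
    pos = np.zeros((n, 3))
    qi = 0
    for xyz, axis in chain:
        # Move along the fixed offset, expressed in the current frame.
        pos = pos + rot @ np.asarray(xyz, dtype=np.float64)
        if axis is not None:
            rot = rot @ axis_rot(axis, q[:, qi])
            qi += 1
    return pos


# Hypothetical planar arm: two unit-length links, both joints rotating about z.
PLANAR_CHAIN = [
    ((0, 0, 0), (0, 0, 1)),  # shoulder joint at the origin
    ((1, 0, 0), (0, 0, 1)),  # elbow joint one unit along x
    ((1, 0, 0), None),       # fixed tool offset one unit along x
]

q = np.array([[0.0, 0.0], [np.pi / 2, 0.0], [np.pi / 2, -np.pi / 2]])
tcp = toy_fk(PLANAR_CHAIN, q)
print(np.round(tcp, 6))
```

With both joints at zero the tool sits at (2, 0, 0); rotating the shoulder by 90° moves it to (0, 2, 0); bending the elbow back by 90° gives (1, 1, 0).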
|
||||
@@ -1,956 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
SONIC planner with full mode control.
|
||||
|
||||
Keyboard controls:
|
||||
N / P - next / previous motion set
|
||||
1-8 - select mode within current set
|
||||
WASD - movement direction
|
||||
Q / E - rotate facing left / right
|
||||
9 / 0 - decrease / increase speed
|
||||
- / = - decrease / increase height
|
||||
R - force replan
|
||||
Space - emergency stop -> IDLE
|
||||
Esc - quit
|
||||
|
||||
Gamepad controls (Unitree wireless controller):
|
||||
Left stick Y - speed (forward = fast, back = stop)
|
||||
Left stick X - movement direction (offset from facing)
|
||||
Right stick X - facing direction (incremental rotation)
|
||||
Right stick Y - height (up = tall 0.8m, down = low 0.1m)
|
||||
Buttons - unused (mode selection is keyboard-only)
|
||||
"""
|
||||
|
||||
import argparse, gc, math, select, sys, termios, tty
|
||||
import multiprocessing as mp
|
||||
import threading, time
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
|
||||
# ── Constants ────────────────────────────────────────────────────────────────
|
||||
|
||||
DEFAULT_ANGLES = np.array([
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0,
|
||||
-0.312, 0.0, 0.0, 0.669, -0.363, 0.0,
|
||||
0.0, 0.0, 0.0,
|
||||
0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0,
|
||||
0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0,
|
||||
], dtype=np.float32)
|
||||
|
||||
NATURAL_FREQ = 10.0 * 2.0 * np.pi
|
||||
ARMATURE = {"5020": 0.003609725, "7520_14": 0.010177520, "7520_22": 0.025101925, "4010": 0.00425}
|
||||
EFFORT = {"5020": 25.0, "7520_14": 88.0, "7520_22": 139.0, "4010": 5.0}
|
||||
|
||||
def _action_scale(k):
|
||||
return 0.25 * EFFORT[k] / (ARMATURE[k] * NATURAL_FREQ**2)
|
||||
|
||||
_J = ["7520_22","7520_22","7520_14","7520_22","5020","5020"] * 2 + \
|
||||
["7520_14","5020","5020"] + \
|
||||
["5020","5020","5020","5020","5020","4010","4010"] * 2
|
||||
ACTION_SCALE = np.array([_action_scale(k) for k in _J], dtype=np.float32)
|
||||
|
||||
CONTROL_DT = 0.02
|
||||
DEFAULT_HEIGHT = 0.788740
|
||||
TOKEN_DIM = 64
|
||||
ENCODER_UPDATE_EVERY = 5
|
||||
DEBUG_PRINT_EVERY = 100
|
||||
MOTION_LOOK_AHEAD_STEPS = 2
|
||||
INITIAL_RANDOM_SEED = 1234
|
||||
MIN_TOKENS, MAX_TOKENS = 6, 16
|
||||
K = MAX_TOKENS - MIN_TOKENS + 1
|
||||
DEADZONE = 0.05
|
||||
BLEND_FRAMES = 8
|
||||
|
||||
REPLAN_INTERVAL = {
|
||||
"running": 0.1, "crawling": 0.2, "boxing": 1.0, "default": 1.0
|
||||
}
|
||||
|
||||
ISAACLAB_TO_MUJOCO = np.array([
|
||||
0, 3, 6, 9, 13, 17, 1, 4, 7, 10, 14, 18, 2, 5, 8,
|
||||
11, 15, 19, 21, 23, 25, 27, 12, 16, 20, 22, 24, 26, 28
|
||||
], dtype=np.int32)
|
||||
|
||||
MUJOCO_TO_ISAACLAB = np.array([
|
||||
0, 6, 12, 1, 7, 13, 2, 8, 14, 3, 9, 15, 22, 4, 10,
|
||||
16, 23, 5, 11, 17, 24, 18, 25, 19, 26, 20, 27, 21, 28
|
||||
], dtype=np.int32)
|
||||
|
||||
def _to_mujoco(a): return a[MUJOCO_TO_ISAACLAB]
|
||||
def _to_runtime(a): r = np.zeros(29, np.float32); r[MUJOCO_TO_ISAACLAB] = a; return r
|
||||
|
||||
DEFAULT_ANGLES_MUJOCO = _to_mujoco(DEFAULT_ANGLES)
|
||||
ENCODER_STANDING_REF = DEFAULT_ANGLES.copy()
|
||||
|
||||
LOWER_BODY_IL = np.array([0,3,6,9,13,17,1,4,7,10,14,18], dtype=np.int32)
|
||||
WRIST_IL = np.array([23,24,25,26,27,28], dtype=np.int32)
|
||||
VR_TARGET_DEF = np.zeros(9, dtype=np.float32)
|
||||
VR_ORN_DEF = np.array([1,0,0,0,1,0,0,0,1,0,0,0], dtype=np.float32)
|
||||
SMPL_DEF = np.zeros(720, dtype=np.float32)
|
||||
|
||||
# ── PD gains ─────────────────────────────────────────────────────────────────
|
||||
|
||||
def _kp_kd():
|
||||
s = lambda k: ARMATURE[k] * NATURAL_FREQ**2
|
||||
d = lambda k: 2.0 * 2.0 * ARMATURE[k] * NATURAL_FREQ
|
||||
_kp_keys = ["7520_22","7520_22","7520_14","7520_22","5020","5020"] * 2 + \
|
||||
["7520_14","5020","5020"] + \
|
||||
["5020","5020","5020","5020","5020","4010","4010"] * 2
|
||||
_kd_keys = _kp_keys
|
||||
_double = {4,5,10,11,13,14} # ankle + waist indices with factor 2
|
||||
kp = np.array([2*s(k) if i in _double else s(k) for i,k in enumerate(_kp_keys)], dtype=np.float32)
|
||||
kd = np.array([2*d(k) if i in _double else d(k) for i,k in enumerate(_kd_keys)], dtype=np.float32)
|
||||
return kp, kd
|
||||
|
||||
# ── Quaternion helpers ────────────────────────────────────────────────────────
|
||||
|
||||
def quat_conj(q):
|
||||
return np.array([q[0], -q[1], -q[2], -q[3]], dtype=np.float32)
|
||||
|
||||
def quat_mul(q1, q2):
|
||||
w1,x1,y1,z1 = q1; w2,x2,y2,z2 = q2
|
||||
return np.array([
|
||||
w1*w2 - x1*x2 - y1*y2 - z1*z2,
|
||||
w1*x2 + x1*w2 + y1*z2 - z1*y2,
|
||||
w1*y2 - x1*z2 + y1*w2 + z1*x2,
|
||||
w1*z2 + x1*y2 - y1*x2 + z1*w2,
|
||||
], dtype=np.float32)
|
||||
|
||||
def gravity_dir(q):
|
||||
q = q / (np.linalg.norm(q) + 1e-8)
|
||||
qv = np.array([0, 0, 0, -1], dtype=np.float32)
|
||||
return quat_mul(quat_mul(quat_conj(q), qv), q)[1:]
|
||||
|
||||
def quat_to_6d(q):
|
||||
w,x,y,z = q
|
||||
return np.array([
|
||||
1-2*(y*y+z*z), 2*(x*y-z*w),
|
||||
2*(x*y+z*w), 1-2*(x*x+z*z),
|
||||
2*(x*z-y*w), 2*(y*z+x*w),
|
||||
], dtype=np.float32)
|
||||
|
||||
def calc_heading(q):
|
||||
w,x,y,z = q
|
||||
return float(np.arctan2(2*(x*y + w*z), 1-2*(y*y+z*z)))
|
||||
|
||||
def heading_quat(q, sign=1.0):
|
||||
a = sign * calc_heading(q) / 2.0
|
||||
return np.array([np.cos(a), 0, 0, np.sin(a)], dtype=np.float64)
|
||||
|
||||
heading_quat_inv = lambda q: heading_quat(q, -1.0)
|
||||
|
||||
def quat_slerp(q0, q1, t):
    q0 = q0 / (np.linalg.norm(q0) + 1e-12)
    q1 = q1 / (np.linalg.norm(q1) + 1e-12)
    dot = float(np.dot(q0, q1))
    # q and -q encode the same rotation; flip q1 to follow the shorter arc.
    if dot < 0:
        q1, dot = -q1, -dot
    dot = min(dot, 1.0)
    # Nearly parallel inputs: use normalized linear interpolation to avoid dividing by sin(th) ≈ 0.
    if dot > 0.9995:
        r = q0 + t * (q1 - q0)
        return r / (np.linalg.norm(r) + 1e-12)
    th = np.arccos(dot)
    st = np.sin(th)
    return (np.sin((1 - t) * th) / st) * q0 + (np.sin(t * th) / st) * q1
|
||||
|
||||
def quat_slerp_batch(q0, q1, t):
|
||||
q0 = q0 / (np.linalg.norm(q0,axis=1,keepdims=True)+1e-12)
|
||||
q1 = q1 / (np.linalg.norm(q1,axis=1,keepdims=True)+1e-12)
|
||||
dot = np.sum(q0*q1, axis=1); neg = dot<0
|
||||
q1=q1.copy(); q1[neg]=-q1[neg]; dot[neg]=-dot[neg]; dot=np.clip(dot,-1,1)
|
||||
lin = dot>0.9995; th=np.arccos(dot); st=np.where(np.sin(th)==0,1,np.sin(th))
|
||||
c0=np.sin((1-t)*th)/st; c1=np.sin(t*th)/st
|
||||
c0[lin]=1-t[lin]; c1[lin]=t[lin]
|
||||
r = c0[:,None]*q0 + c1[:,None]*q1
|
||||
return r / (np.linalg.norm(r,axis=1,keepdims=True)+1e-12)
|
||||
|
||||
# ── Locomotion modes ──────────────────────────────────────────────────────────
|
||||
|
||||
class LocomotionMode(IntEnum):
|
||||
IDLE=0; SLOW_WALK=1; WALK=2; RUN=3; SQUAT=4; KNEEL_TWO_LEGS=5; KNEEL=6
|
||||
LYING_FACE_DOWN=7; CRAWLING=8; IDLE_BOXING=9; WALK_BOXING=10
|
||||
LEFT_PUNCH=11; RIGHT_PUNCH=12; RANDOM_PUNCH=13; ELBOW_CRAWLING=14
|
||||
LEFT_HOOK=15; RIGHT_HOOK=16; FORWARD_JUMP=17; STEALTH_WALK=18
|
||||
INJURED_WALK=19; LEDGE_WALKING=20; OBJECT_CARRYING=21; STEALTH_WALK_2=22
|
||||
HAPPY_DANCE_WALK=23; ZOMBIE_WALK=24; GUN_WALK=25; SCARE_WALK=26
|
||||
|
||||
LM = LocomotionMode
|
||||
|
||||
MOTION_SETS = [
|
||||
("Standing", [LM.SLOW_WALK, LM.WALK, LM.RUN, LM.FORWARD_JUMP, LM.STEALTH_WALK, LM.INJURED_WALK]),
|
||||
("Squat / Low", [LM.SQUAT, LM.KNEEL_TWO_LEGS, LM.KNEEL, LM.CRAWLING, LM.ELBOW_CRAWLING]),
|
||||
("Boxing", [LM.IDLE_BOXING, LM.WALK_BOXING, LM.LEFT_PUNCH, LM.RIGHT_PUNCH,
|
||||
LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK]),
|
||||
("Styled Walks", [LM.LEDGE_WALKING, LM.OBJECT_CARRYING, LM.STEALTH_WALK_2,
|
||||
LM.HAPPY_DANCE_WALK, LM.ZOMBIE_WALK, LM.GUN_WALK, LM.SCARE_WALK]),
|
||||
]
|
||||
|
||||
STATIC_MODES = {LM.IDLE, LM.SQUAT, LM.KNEEL_TWO_LEGS, LM.KNEEL, LM.LYING_FACE_DOWN, LM.IDLE_BOXING}
|
||||
STANDING_MODES = {LM.IDLE, LM.SLOW_WALK, LM.WALK, LM.RUN, LM.IDLE_BOXING, LM.WALK_BOXING,
|
||||
LM.LEFT_PUNCH, LM.RIGHT_PUNCH, LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK,
|
||||
LM.FORWARD_JUMP, LM.STEALTH_WALK, LM.INJURED_WALK, LM.LEDGE_WALKING,
|
||||
LM.OBJECT_CARRYING, LM.STEALTH_WALK_2, LM.HAPPY_DANCE_WALK,
|
||||
LM.ZOMBIE_WALK, LM.GUN_WALK, LM.SCARE_WALK}
|
||||
BOXING_MODES = {LM.WALK_BOXING, LM.LEFT_PUNCH, LM.RIGHT_PUNCH,
|
||||
LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK}
|
||||
SPEED_RANGES = {LM.SLOW_WALK:(0.2,0.8), LM.WALK:(0.8,1.5), LM.RUN:(1.5,3.0),
|
||||
LM.CRAWLING:(0.4,1.0), LM.ELBOW_CRAWLING:(0.7,1.0)}
|
||||
|
||||
def clamp_mode_params(ms):
|
||||
m = LM(ms.mode)
|
||||
ms.height = -1.0 if m in STANDING_MODES else max(0.1, min(0.8, ms.height if ms.height>=0 else 0.2))
|
||||
if m in STATIC_MODES:
|
||||
ms.speed = -1.0
|
||||
elif m in SPEED_RANGES:
|
||||
lo, hi = SPEED_RANGES[m]
|
||||
ms.speed = max(lo, min(hi, ms.speed if ms.speed>=0 else lo))
|
||||
elif m in BOXING_MODES:
|
||||
ms.speed = max(0.7, min(1.5, ms.speed if ms.speed>=0 else 0.7))
|
||||
else:
|
||||
ms.speed = -1.0
|
||||
|
||||
def replan_interval(mode):
|
||||
m = LM(mode)
|
||||
if m == LM.RUN: return REPLAN_INTERVAL["running"]
|
||||
if m == LM.CRAWLING: return REPLAN_INTERVAL["crawling"]
|
||||
if m in {LM.LEFT_PUNCH, LM.RIGHT_PUNCH, LM.RANDOM_PUNCH, LM.LEFT_HOOK, LM.RIGHT_HOOK}:
|
||||
return REPLAN_INTERVAL["boxing"]
|
||||
return REPLAN_INTERVAL["default"]
|
||||
|
||||
# ── Movement state ────────────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class MovementState:
|
||||
mode: int = 0
|
||||
speed: float = -1.0
|
||||
height: float = -1.0
|
||||
facing_angle: float = 0.0
|
||||
movement_angle: float = 0.0
|
||||
has_movement: bool = False
|
||||
motion_set_idx: int = 0
|
||||
needs_replan: bool = False
|
||||
|
||||
@property
|
||||
def movement_direction(self):
|
||||
if not self.has_movement: return (0.0, 0.0, 0.0)
|
||||
return (math.cos(self.movement_angle), math.sin(self.movement_angle), 0.0)
|
||||
|
||||
@property
|
||||
def facing_direction(self):
|
||||
return (math.cos(self.facing_angle), math.sin(self.facing_angle), 0.0)
|
||||
|
||||
def status_line(self):
|
||||
return (f"[{MOTION_SETS[self.motion_set_idx][0]}] mode={self.mode}({LM(self.mode).name}) "
|
||||
f"spd={'default' if self.speed<0 else f'{self.speed:.1f}'} "
|
||||
f"hgt={'default' if self.height<0 else f'{self.height:.2f}'} "
|
||||
f"facing={math.degrees(self.facing_angle):.0f}° "
|
||||
f"{'moving' if self.has_movement else 'still'}")
|
||||
|
||||
# ── Encoder / Decoder ─────────────────────────────────────────────────────────
|
||||
|
||||
class StandingEncoderDecoder:
|
||||
def __init__(self, encoder, decoder):
|
||||
self.encoder, self.decoder = encoder, decoder
|
||||
self.encoder_input = encoder.get_inputs()[0].name
|
||||
self.decoder_input = decoder.get_inputs()[0].name
|
||||
enc_dim = int(encoder.get_inputs()[0].shape[1])
|
||||
dec_dim = int(decoder.get_inputs()[0].shape[1])
|
||||
if enc_dim != 1762 or dec_dim != 994:
|
||||
raise RuntimeError(f"Unexpected dims encoder={enc_dim}, decoder={dec_dim}")
|
||||
self.token = np.zeros(TOKEN_DIM, np.float32)
|
||||
self.last_action_mj = np.zeros(29, np.float32)
|
||||
self.h_q_mj = [np.zeros(29, np.float32)] * 10
|
||||
self.h_dq_mj = [np.zeros(29, np.float32)] * 10
|
||||
self.h_ang = [np.zeros(3, np.float32)] * 10
|
||||
self.h_act_mj = [np.zeros(29, np.float32)] * 10
|
||||
self.h_quat = [np.array([1,0,0,0], np.float32)] * 10
|
||||
self.init_base_quat = np.array([1,0,0,0], np.float32)
|
||||
self.init_ref_quat = np.array([1,0,0,0], np.float32)
|
||||
self._heading_init = False
|
||||
self.encode_mode = 0
|
||||
self.vr_3point_local_target = VR_TARGET_DEF.copy()
|
||||
self.vr_3point_local_orn_target = VR_ORN_DEF.copy()
|
||||
self.smpl_joints_10frame_step1 = SMPL_DEF.copy()
|
||||
self.set_zero_reference()
|
||||
|
||||
def update_history(self, q, dq, ang, quat):
|
||||
quat = quat / (np.linalg.norm(quat)+1e-8)
|
||||
q_mj = _to_mujoco(q); dq_mj = _to_mujoco(dq)
|
||||
self.h_q_mj = [q_mj - DEFAULT_ANGLES_MUJOCO] + self.h_q_mj[:-1]
|
||||
self.h_dq_mj = [dq_mj] + self.h_dq_mj[:-1]
|
||||
self.h_ang = [ang.copy()] + self.h_ang[:-1]
|
||||
self.h_act_mj = [self.last_action_mj.copy()] + self.h_act_mj[:-1]
|
||||
self.h_quat = [quat.copy()] + self.h_quat[:-1]
|
||||
if not self._heading_init:
|
||||
self.init_base_quat = quat.copy(); self._heading_init = True
|
||||
|
||||
def _heading_quat(self, q):
|
||||
h = calc_heading(q) / 2.0
|
||||
return np.array([np.cos(h), 0, 0, np.sin(h)], np.float32)
|
||||
|
||||
def _heading_quat_inv(self, q):
|
||||
h = calc_heading(q) / 2.0
|
||||
return np.array([np.cos(-h), 0, 0, np.sin(-h)], np.float32)
|
||||
|
||||
def _anchor_6d(self, base_quat, ref_quat=None):
|
||||
if ref_quat is None: ref_quat = self.init_ref_quat
|
||||
delta = quat_mul(self._heading_quat(self.init_base_quat), self._heading_quat_inv(self.init_ref_quat))
|
||||
new_ref = quat_mul(delta, ref_quat)
|
||||
return quat_to_6d(quat_mul(quat_conj(base_quat), new_ref))
|
||||
|
||||
def set_zero_reference(self):
|
||||
self.motion_joint_positions = [ENCODER_STANDING_REF.copy()]
|
||||
self.motion_joint_velocities = [np.zeros(29, np.float32)]
|
||||
self.motion_body_quats = [np.array([1,0,0,0], np.float32)]
|
||||
self.motion_body_z = [DEFAULT_HEIGHT]
|
||||
self.motion_timesteps = 1
|
||||
self.freeze_ref_frame = 0
|
||||
self.init_ref_quat = self.motion_body_quats[0].copy()
|
||||
|
||||
def build_encoder_obs(self):
|
||||
obs = np.zeros(1762, np.float32)
|
||||
obs[0] = float(self.encode_mode)
|
||||
rf = min(self.freeze_ref_frame, self.motion_timesteps - 1)
|
||||
ref_pos, ref_quat = self.motion_joint_positions[rf], self.motion_body_quats[rf]
|
||||
if self.encode_mode == 0:
|
||||
for f in range(10):
|
||||
obs[4+29*f:4+29*(f+1)] = ref_pos
|
||||
obs[601+6*f:601+6*(f+1)] = self._anchor_6d(self.h_quat[0], ref_quat)
|
||||
elif self.encode_mode == 1:
|
||||
ref_lower = ref_pos[LOWER_BODY_IL]
|
||||
for f in range(10):
|
||||
obs[661+12*f:661+12*(f+1)] = ref_lower
|
||||
obs[901:910] = self.vr_3point_local_target
|
||||
obs[910:922] = self.vr_3point_local_orn_target
|
||||
obs[595:601] = self._anchor_6d(self.h_quat[0], ref_quat)
|
||||
elif self.encode_mode == 2:
|
||||
obs[922:1642] = self.smpl_joints_10frame_step1
|
||||
for f in range(10):
|
||||
obs[1642+6*f:1642+6*(f+1)] = self._anchor_6d(self.h_quat[0], ref_quat)
|
||||
obs[1702+6*f:1702+6*(f+1)] = ref_pos[WRIST_IL]
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported encoder mode: {self.encode_mode}")
|
||||
return obs
|
||||
|
||||
def build_decoder_obs(self):
|
||||
obs = np.zeros(994, np.float32); off = 0
|
||||
obs[off:off+64] = self.token; off += 64
|
||||
for h, sz in [(list(reversed(self.h_ang)),3), (list(reversed(self.h_q_mj)),29),
|
||||
(list(reversed(self.h_dq_mj)),29), (list(reversed(self.h_act_mj)),29)]:
|
||||
for f in range(10): obs[off:off+sz] = h[f]; off += sz
|
||||
for q in reversed(self.h_quat):
|
||||
obs[off:off+3] = gravity_dir(q); off += 3
|
||||
assert off == 994, f"Decoder obs mismatch: {off}"
|
||||
return obs
|
||||
|
||||
def run_encoder(self):
|
||||
return self.encoder.run(None, {self.encoder_input: self.build_encoder_obs().reshape(1,-1)})[0].squeeze().astype(np.float32)
|
||||
|
||||
def step(self, robot_obs, update_encoder, debug=False):
|
||||
jnames = [m.name for m in G1_29_JointIndex]
|
||||
q = np.array([robot_obs.get(f"{n}.q", DEFAULT_ANGLES[m.value]) for m,n in zip(G1_29_JointIndex,jnames)], np.float32)
|
||||
dq = np.array([robot_obs.get(f"{n}.dq", 0.0) for n in jnames], np.float32)
|
||||
quat = np.array([robot_obs.get("imu.quat.w",1), robot_obs.get("imu.quat.x",0),
|
||||
robot_obs.get("imu.quat.y",0), robot_obs.get("imu.quat.z",0)], np.float32)
|
||||
ang = np.array([robot_obs.get(f"imu.gyro.{a}",0) for a in "xyz"], np.float32)
|
||||
self.update_history(q, dq, ang, quat)
|
||||
if update_encoder: self.token = self.run_encoder()
|
||||
action_mj = self.decoder.run(None, {self.decoder_input: self.build_decoder_obs().reshape(1,-1)})[0].squeeze().astype(np.float32)
|
||||
        self.last_action_mj = action_mj.copy()
        target = DEFAULT_ANGLES + action_mj[ISAACLAB_TO_MUJOCO] * ACTION_SCALE
        if debug:
            delta = target - q
            print(f"token_norm={np.linalg.norm(self.token):.4f} action_norm={np.linalg.norm(action_mj):.4f} "
                  f"delta_max={np.max(np.abs(delta)):.4f} delta_rms={np.sqrt(np.mean(delta**2)):.4f}")
        return {f"{m.name}.q": float(target[m.value]) for m in G1_29_JointIndex}

    def print_input_diagnostics(self):
        print("\n[Diag] Standing reference checks")
        names = {0: "g1", 1: "teleop", 2: "smpl"}
        print(f" encoder mode: {self.encode_mode} ({names.get(self.encode_mode,'unknown')})")
        print(f" DEFAULT_ANGLES range: [{DEFAULT_ANGLES.min():+.4f}, {DEFAULT_ANGLES.max():+.4f}]")
        print(f" anchor_6d(identity): {self._anchor_6d(np.array([1,0,0,0],np.float32), np.array([1,0,0,0],np.float32))}")
        print(f" gravity(identity): {gravity_dir(np.array([1,0,0,0],np.float32))} (expect [0,0,-1])")
        dec0 = self.build_decoder_obs()
        print(f" decoder q-delta max: {np.max(np.abs(dec0[94:384])):.6f}")
        print(f" decoder dq max: {np.max(np.abs(dec0[384:674])):.6f}")


# ── Planner motion buffer ─────────────────────────────────────────────────────


class PlannerMotion:
    def __init__(self, max_frames=1500):
        self.timesteps = 0
        self.joint_positions = np.zeros((max_frames, 29), np.float64)
        self.joint_velocities = np.zeros((max_frames, 29), np.float64)
        self.body_positions = np.zeros((max_frames, 3), np.float64)
        self.body_quaternions = np.zeros((max_frames, 4), np.float64)
        self.body_quaternions[:, 0] = 1.0  # identity quaternion (w, x, y, z)


# ── Subprocess planner ────────────────────────────────────────────────────────


def _resample_30_to_50(qpos, n30):
    # Number of 50 Hz frames covering the same duration as n30 frames at 30 Hz.
    t50 = int(np.floor(n30 / 30.0 * 50))
    # Fractional 30 Hz index of every 50 Hz frame.
    f30 = np.arange(t50) / 50.0 * 30.0
    f0 = np.floor(f30).astype(int)
    f1 = np.minimum(f0+1, n30-1)
    frac = (f30-f0).astype(np.float64)
    w0 = 1.0 - frac
    jp = (w0[:,None]*qpos[f0,7:36] + frac[:,None]*qpos[f1,7:36])[:,MUJOCO_TO_ISAACLAB]
    # Finite-difference velocities at 50 Hz; the last frame repeats the previous one.
    jv = np.zeros_like(jp)
    if t50 >= 2:
        jv[:t50-1] = (jp[1:] - jp[:-1]) * 50.0
        jv[-1] = jv[-2]
    return {
        "timesteps": t50,
        "joint_positions": jp,
        "joint_velocities": jv,
        "body_positions": w0[:,None]*qpos[f0,:3] + frac[:,None]*qpos[f1,:3],
        "body_quaternions": quat_slerp_batch(qpos[f0,3:7], qpos[f1,3:7], frac),
    }
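The index arithmetic in `_resample_30_to_50` can be illustrated on a single scalar channel. The following is a minimal pure-Python sketch, not the file's implementation: `resample_linear` and the ramp data are hypothetical names introduced for illustration, and quaternion channels (which the code above interpolates with `quat_slerp_batch`) are deliberately left out.

```python
import math


def resample_linear(samples, src_hz=30.0, dst_hz=50.0):
    """Linearly resample a list of floats from src_hz to dst_hz (illustrative helper)."""
    n_src = len(samples)
    n_dst = int(math.floor(n_src / src_hz * dst_hz))
    out = []
    for i in range(n_dst):
        pos = i / dst_hz * src_hz                      # fractional index into the source
        i0 = min(int(math.floor(pos)), n_src - 1)      # left neighbour, clamped
        i1 = min(i0 + 1, n_src - 1)                    # right neighbour, clamped at the end
        frac = pos - i0
        out.append((1.0 - frac) * samples[i0] + frac * samples[i1])
    return out


ramp_30hz = [float(k) for k in range(6)]   # 6 frames at 30 Hz
ramp_50hz = resample_linear(ramp_30hz)     # 10 frames at 50 Hz
```

Because the right neighbour is clamped to the last source frame, the tail of the output holds the final value instead of extrapolating.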


def _build_planner_inputs(ctx, ms_dict, version, seed):
    inp = {
        "context_mujoco_qpos": ctx.astype(np.float32).reshape(1,4,36),
        "target_vel": np.array([ms_dict["speed"]], np.float32),
        "mode": np.array([ms_dict["mode"]], np.int64),
        "movement_direction": np.array(ms_dict["movement_direction"], np.float32).reshape(1,3),
        "facing_direction": np.array(ms_dict["facing_direction"], np.float32).reshape(1,3),
        "random_seed": np.array([seed], np.int64),
    }
    if version >= 1:
        allowed = np.zeros((1,K), np.int64)
        allowed[0,:6] = 1
        inp.update({
            "height": np.array([ms_dict["height"]], np.float32),
            "has_specific_target": np.array([[0]], np.int64),
            "specific_target_positions": np.zeros((1,4,3), np.float32),
            "specific_target_headings": np.zeros((1,4), np.float32),
            "allowed_pred_num_tokens": allowed,
        })
    return inp


def _planner_worker(path, req_q, res_q, stop_evt, version, seed):
    so = ort.SessionOptions()
    so.log_severity_level = 3
    sess = ort.InferenceSession(path, sess_options=so, providers=["CPUExecutionProvider"])
    while not stop_evt.is_set():
        # Poll with a short timeout so the stop event is checked regularly.
        try:
            ctx, gf, ms_dict = req_q.get(timeout=0.05)
        except Exception:
            continue
        try:
            inp = _build_planner_inputs(ctx, ms_dict, version, seed)
            t0 = time.time()
            qpos_out, num_pred = sess.run(None, inp)
            t_inf = time.time()
            n = int(num_pred.flat[0])
            qpos = qpos_out[0,:n]
            if np.any(np.isnan(qpos)): continue
            motion = _resample_30_to_50(qpos, n)
            motion["gen_frame"] = gf
            print(f"[Planner] inf={1000*(t_inf-t0):.1f}ms total={1000*(time.time()-t0):.1f}ms frames={n}", flush=True)
            # Drop any unread result so the consumer only ever sees the newest motion.
            while not res_q.empty():
                try:
                    res_q.get_nowait()
                except Exception:
                    break
            res_q.put(motion)
        except Exception as e:
            print(f"[Planner] Error: {e}", flush=True)


# ── SonicPlanner ──────────────────────────────────────────────────────────────


class SonicPlanner:
    def __init__(self, session, planner_path):
        self.session = session
        self.planner_path = planner_path
        self.gen_frame = 0
        self.random_seed = INITIAL_RANDOM_SEED
        self.version = 1 if len(session.get_inputs()) >= 11 else 0
        self.motion_50hz = PlannerMotion()
        self._snapshot = PlannerMotion()
        self._req_q = self._res_q = self._stop_evt = self._proc = None
        self._ctrl = None

    def _build_inputs(self, ctx, ms):
        return _build_planner_inputs(
            ctx,
            {"mode": ms.mode, "speed": ms.speed, "height": ms.height,
             "movement_direction": list(ms.movement_direction),
             "facing_direction": list(ms.facing_direction)},
            self.version, self.random_seed)

    @staticmethod
    def build_initial_context(joint_positions):
        ctx = np.zeros((4,36), np.float32)
        for n in range(4):
            ctx[n,2] = DEFAULT_HEIGHT
            ctx[n,3] = 1.0
            ctx[n,7:36] = joint_positions.astype(np.float32)
        return ctx

    def _context_from_controller(self, current_frame):
        ctrl = self._ctrl
        gen_frame = current_frame + MOTION_LOOK_AHEAD_STEPS
        # Four context frames spaced at 30 Hz, expressed as 50 Hz frame indices.
        t_arr = gen_frame/50.0 + np.arange(4)/30.0
        f50 = t_arr * 50.0
        with ctrl.motion_lock:
            ts = ctrl.motion_timesteps
            bp = ctrl.motion_body_pos[:ts].copy()
            bq = ctrl.motion_body_quats[:ts].copy()
            jp = ctrl.motion_joint_positions[:ts].copy()
        f0 = np.minimum(np.floor(f50).astype(int), ts-1)
        f1 = np.minimum(f0+1, ts-1)
        frac = f50-f0
        w0 = 1.0-frac
        ctx = np.zeros((4,36), np.float32)
        ctx[:,0:3] = w0[:,None]*bp[f0] + frac[:,None]*bp[f1]
        ctx[:,3:7] = quat_slerp_batch(bq[f0], bq[f1], frac)
        ij = w0[:,None]*jp[f0] + frac[:,None]*jp[f1]
        ctx[:,7:36] = ij[:,ISAACLAB_TO_MUJOCO]
        self.gen_frame = gen_frame
        return ctx

    def _load_motion_in_place(self, qpos, n30, target=None):
        if target is None: target = self.motion_50hz
        r = _resample_30_to_50(qpos, n30)
        n = r["timesteps"]
        target.timesteps = n
        target.joint_positions[:n] = r["joint_positions"]
        target.joint_velocities[:n] = r["joint_velocities"]
        target.body_positions[:n] = r["body_positions"]
        target.body_quaternions[:n] = r["body_quaternions"]
        return target

    def initialize(self, joint_positions, ms):
        ctx = self.build_initial_context(joint_positions)
        qpos_out, num_pred = self.session.run(None, self._build_inputs(ctx, ms))
        n = int(num_pred.flat[0])
        qpos = qpos_out[0,:n]
        if np.any(np.isnan(qpos)): raise RuntimeError("Planner initial output contains NaN")
        print(f"[Planner] Init: {n} frames @ 30 Hz")
        self._load_motion_in_place(qpos, n)
        print(f"[Planner] Resampled to {self.motion_50hz.timesteps} frames @ 50 Hz")
        return self.motion_50hz

    def request_replan(self, cursor, ms):
        if self._req_q is None: return
        ctx = self._context_from_controller(cursor)
        ms_dict = {"mode": ms.mode, "speed": ms.speed, "height": ms.height,
                   "movement_direction": list(ms.movement_direction),
                   "facing_direction": list(ms.facing_direction)}
        # Discard stale requests so the worker always plans from the latest state.
        while not self._req_q.empty():
            try:
                self._req_q.get_nowait()
            except Exception:
                break
        self._req_q.put((ctx, self.gen_frame, ms_dict))

    def try_get_new_motion(self):
        if self._res_q is None: return None
        result = None
        # Keep only the newest result if several are queued.
        while not self._res_q.empty():
            try:
                result = self._res_q.get_nowait()
            except Exception:
                break
        if result is None: return None
        n, gf = result["timesteps"], result["gen_frame"]
        s = self._snapshot
        s.timesteps = n
        s.joint_positions[:n] = result["joint_positions"]
        s.joint_velocities[:n] = result["joint_velocities"]
        s.body_positions[:n] = result["body_positions"]
        s.body_quaternions[:n] = result["body_quaternions"]
        return s, gf

    def start_subprocess(self, controller):
        self._ctrl = controller
        self._req_q, self._res_q, self._stop_evt = mp.Queue(), mp.Queue(), mp.Event()
        self._proc = mp.Process(
            target=_planner_worker,
            args=(self.planner_path, self._req_q, self._res_q,
                  self._stop_evt, self.version, self.random_seed),
            daemon=True)
        self._proc.start()
        print(f"[Planner] Background process started (PID={self._proc.pid})")

    def stop_subprocess(self):
        if self._stop_evt: self._stop_evt.set()
        if self._proc:
            self._proc.join(timeout=3.0)
            if self._proc.is_alive(): self._proc.terminate()
            print("[Planner] Background process stopped")
        for q in (self._req_q, self._res_q):
            if q: q.close()
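`_planner_worker`, `request_replan`, and `try_get_new_motion` all drain a queue before writing or reading it, so only the most recent request or result is ever acted on. A minimal single-process sketch of that "latest wins" pattern follows. It uses the standard-library `queue.Queue` as a stand-in for `mp.Queue`, and `put_latest` and `get_latest` are hypothetical helper names, not functions from this file. It catches `queue.Empty` specifically, which is also the exception `multiprocessing.Queue.get_nowait()` raises.

```python
import queue


def put_latest(q, item):
    """Discard any unread entries, then enqueue item (illustrative helper)."""
    while True:
        try:
            q.get_nowait()
        except queue.Empty:
            break
    q.put(item)


def get_latest(q):
    """Return the newest queued item, or None if nothing is waiting."""
    latest = None
    while True:
        try:
            latest = q.get_nowait()
        except queue.Empty:
            return latest


requests = queue.Queue()
for frame in (10, 20, 30):
    put_latest(requests, {"gen_frame": frame})   # older requests are dropped
newest = get_latest(requests)
nothing_left = get_latest(requests)
```

The trade-off is that intermediate requests are silently lost, which is acceptable here because a newer plan supersedes an older one.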


# ── PlannerController ─────────────────────────────────────────────────────────


class PlannerController(StandingEncoderDecoder):
    def __init__(self, planner, encoder, decoder):
        super().__init__(encoder, decoder)
        self.planner = planner
        self.ref_cursor = 0
        self.motion_timesteps = 0
        self.motion_joint_positions = np.zeros((1500,29), np.float64)
        self.motion_joint_velocities = np.zeros((1500,29), np.float64)
        self.motion_body_quats = np.zeros((1500,4), np.float64)
        self.motion_body_quats[:,0] = 1.0
        self.motion_body_pos = np.zeros((1500,3), np.float64)
        self.init_ref_quat = np.array([1,0,0,0], np.float64)
        self.heading_init_base_quat = np.array([1,0,0,0], np.float64)
        self.delta_heading = 0.0
        self.reinit_heading = False
        self.playing = self.first_motion = False
        self.motion_lock = threading.Lock()

    def load_initial_motion(self, motion):
        with self.motion_lock:
            n = motion.timesteps
            self.motion_timesteps = n
            self.motion_joint_positions[:n] = motion.joint_positions[:n]
            self.motion_joint_velocities[:n] = motion.joint_velocities[:n]
            self.motion_body_quats[:n] = motion.body_quaternions[:n]
            self.motion_body_pos[:n] = motion.body_positions[:n]
            self.init_ref_quat = motion.body_quaternions[0].copy()
            self.ref_cursor = 0
            self.first_motion = True
            self.playing = True
            self.delta_heading = 0.0

    def blend_new_motion(self, new_motion, gen_frame):
        with self.motion_lock:
            cur = self.ref_cursor
            new_len = gen_frame - cur + new_motion.timesteps
            if new_len <= 0: return
            f_arr = np.arange(new_len)
            # Source indices into the current buffer and into the new motion.
            f_old = np.minimum(f_arr + cur, self.motion_timesteps - 1)
            f_new = np.clip(f_arr + cur - gen_frame, 0, new_motion.timesteps - 1)
            blend_start = max(0, gen_frame - cur)
            # Weight of the new motion: 0 before blend_start, ramping to 1 over BLEND_FRAMES.
            w_new = np.clip((f_arr - blend_start) / BLEND_FRAMES if BLEND_FRAMES > 0
                            else np.ones(new_len), 0.0, 1.0)
            w_old = 1.0 - w_new
            self.motion_joint_positions[:new_len] = w_old[:,None]*self.motion_joint_positions[f_old] + w_new[:,None]*new_motion.joint_positions[f_new]
            self.motion_joint_velocities[:new_len] = w_old[:,None]*self.motion_joint_velocities[f_old] + w_new[:,None]*new_motion.joint_velocities[f_new]
            self.motion_body_pos[:new_len] = w_old[:,None]*self.motion_body_pos[f_old] + w_new[:,None]*new_motion.body_positions[f_new]
            self.motion_body_quats[:new_len] = quat_slerp_batch(self.motion_body_quats[f_old], new_motion.body_quaternions[f_new], w_new)
            self.motion_timesteps = new_len
            self.first_motion = False
            self.ref_cursor = 0
            self.init_ref_quat = self.motion_body_quats[0].copy()

    def _heading_apply_delta(self):
        delta = quat_mul(heading_quat(self.heading_init_base_quat).astype(np.float32),
                         heading_quat_inv(self.init_ref_quat).astype(np.float32))
        if self.delta_heading:
            # Extra yaw about z, as a (w, x, y, z) quaternion with half-angle h.
            h = self.delta_heading / 2.0
            delta = quat_mul(np.array([np.cos(h),0,0,np.sin(h)], np.float32), delta)
        return delta

    def _anchor_6d(self, base_quat, ref_quat=None):
        if ref_quat is None: ref_quat = self.init_ref_quat
        new_ref = quat_mul(self._heading_apply_delta(), ref_quat.astype(np.float32))
        return quat_to_6d(quat_mul(quat_conj(base_quat.astype(np.float32)), new_ref))

    def build_encoder_obs(self):
        obs = np.zeros(1762, np.float32)
        obs[0] = float(self.encode_mode)
        with self.motion_lock:
            for f in range(10):
                # Ten reference frames, 5 steps apart while playing; the current frame otherwise.
                tf = min(self.ref_cursor + f*5 if self.playing else self.ref_cursor,
                         self.motion_timesteps - 1)
                obs[4+29*f:4+29*(f+1)] = self.motion_joint_positions[tf].astype(np.float32)
                if self.playing:
                    obs[294+29*f:294+29*(f+1)] = self.motion_joint_velocities[tf].astype(np.float32)
                obs[601+6*f:601+6*(f+1)] = self._anchor_6d(
                    self.h_quat[0], self.motion_body_quats[tf].astype(np.float32))
        return obs

    def step(self, robot_obs, update_encoder, debug=False):
        if robot_obs and (self.first_motion or self.reinit_heading):
            q = robot_obs.get("imu.quaternion")
            if q is not None:
                self.heading_init_base_quat = np.array(q, np.float64)
                with self.motion_lock:
                    rf = min(self.ref_cursor, self.motion_timesteps - 1)
                    self.init_ref_quat = self.motion_body_quats[rf].copy()
                self.delta_heading = 0.0
                self.first_motion = False
                self.reinit_heading = False
                print(f"[Heading] init quat: {self.heading_init_base_quat}")
        return super().step(robot_obs, update_encoder=update_encoder, debug=debug)

    def advance_cursor(self, wall_dt):
        if not self.playing: return
        frames = max(1, round(wall_dt / CONTROL_DT))
        with self.motion_lock:
            self.ref_cursor = min(self.ref_cursor + frames, self.motion_timesteps - 1)
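The weights computed in `blend_new_motion` form a linear crossfade: the new trajectory has weight 0 until `blend_start`, then ramps to 1 over `BLEND_FRAMES` frames. The sketch below reproduces that ramp in pure Python for one scalar channel. `crossfade_weights`, `crossfade`, and the numbers are hypothetical, chosen only for illustration; the real code applies the same weights to joint arrays and uses slerp for quaternions.

```python
def crossfade_weights(n_frames, blend_start, blend_frames):
    """Per-frame weight of the NEW trajectory, clipped to [0, 1] (illustrative helper)."""
    if blend_frames <= 0:
        return [1.0] * n_frames            # no ramp: switch to the new trajectory at once
    return [min(1.0, max(0.0, (f - blend_start) / blend_frames))
            for f in range(n_frames)]


def crossfade(old, new, w_new):
    """Blend two equally long scalar sequences with the given new-trajectory weights."""
    return [(1.0 - w) * o + w * n for o, n, w in zip(old, new, w_new)]


w = crossfade_weights(n_frames=8, blend_start=2, blend_frames=4)
blended = crossfade([0.0] * 8, [1.0] * 8, w)
```

With these inputs the first three frames follow the old trajectory, the next three mix the two, and the remaining frames follow the new one.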


# ── Keyboard ──────────────────────────────────────────────────────────────────


class RawKeyboard:
    def __init__(self):
        self.fd = sys.stdin.fileno()
        self.old = termios.tcgetattr(self.fd)

    def __enter__(self):
        tty.setcbreak(self.fd)
        return self

    def __exit__(self, *_):
        termios.tcsetattr(self.fd, termios.TCSADRAIN, self.old)

    def get_key(self):
        # Non-blocking read: return one character if available, else None.
        return sys.stdin.read(1) if select.select([sys.stdin],[],[],0)[0] else None


def process_keyboard(key, ms, controller=None):
    """Apply one key press to the movement state. Returns True only when Esc requests quit."""
    if key is None: return False
    if key == '\x1b': return True
    if key == ' ':
        ms.mode = LM.IDLE
        ms.speed = ms.height = -1.0
        ms.has_movement = False
        ms.needs_replan = True
        if controller:
            controller.playing = False
            controller.reinit_heading = True
        print("\n >> EMERGENCY STOP -> IDLE")
        return False
    if key in ('r','R'):
        ms.needs_replan = True
        print("\n >> Manual replan")
        return False
    if key in ('n','N','p','P'):
        ms.motion_set_idx = (ms.motion_set_idx + (1 if key in ('n','N') else -1)) % len(MOTION_SETS)
        name, modes = MOTION_SETS[ms.motion_set_idx]
        print(f"\n >> Motion set: {name}")
        for i, m in enumerate(modes):
            print(f" {i+1}: {m.name}")
        return False
    if key.isdigit() and key not in ('9','0'):
        idx = int(key) - 1
        modes = MOTION_SETS[ms.motion_set_idx][1]
        if 0 <= idx < len(modes):
            ms.mode = modes[idx]
            ms.needs_replan = True
            if controller:
                controller.playing = True
                controller.reinit_heading = True
            print(f"\n >> Mode: {LM(ms.mode).name} ({ms.mode}) [replanning...]")
        return False
    if key == '9':
        ms.speed = max(0.0, (ms.speed if ms.speed>=0 else 1.0) - 0.1)
        print(f"\n >> Speed: {ms.speed:.1f}")
        return False
    if key == '0':
        ms.speed = min(5.0, (ms.speed if ms.speed>=0 else 1.0) + 0.1)
        print(f"\n >> Speed: {ms.speed:.1f}")
        return False
    if key == '-':
        ms.height = max(0.2, (ms.height if ms.height>=0 else DEFAULT_HEIGHT) - 0.02)
        print(f"\n >> Height: {ms.height:.2f}")
        return False
    if key == '=':
        ms.height = min(1.0, (ms.height if ms.height>=0 else DEFAULT_HEIGHT) + 0.02)
        print(f"\n >> Height: {ms.height:.2f}")
        return False
    if key.lower() == 'w': ms.movement_angle = ms.facing_angle
    elif key.lower() == 's': ms.movement_angle = ms.facing_angle + math.pi
    elif key.lower() == 'a': ms.movement_angle = ms.facing_angle + math.pi/2
    elif key.lower() == 'd': ms.movement_angle = ms.facing_angle - math.pi/2
    if key.lower() in ('w','s','a','d'):
        ms.has_movement = ms.needs_replan = True
    elif key.lower() == 'q':
        ms.facing_angle += 0.1
        if controller: controller.delta_heading += 0.1
        print(f"\n >> Facing: {math.degrees(ms.facing_angle):.0f}°")
    elif key.lower() == 'e':
        ms.facing_angle -= 0.1
        if controller: controller.delta_heading -= 0.1
        print(f"\n >> Facing: {math.degrees(ms.facing_angle):.0f}°")
    return False


_joy_prev_active = False


def _parse_wireless(wr):
    """Parse wireless_remote (bytes or int-array) into (lx, ly, rx, ry)."""
    import struct as _st
    if not isinstance(wr, (bytes, bytearray)):
        wr = bytes(wr)
    if len(wr) < 24:
        return None
    lx = _st.unpack("f", wr[4:8])[0]
    rx = _st.unpack("f", wr[8:12])[0]
    ry = _st.unpack("f", wr[12:16])[0]
    ly = _st.unpack("f", wr[20:24])[0]
    return lx, ly, rx, ry
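The byte offsets read by `_parse_wireless` can be exercised without a robot by packing a synthetic payload with `struct`. The sketch below is an assumption-laden illustration: `parse_sticks` is a hypothetical re-implementation, the 40-byte payload length is made up, and it pins little-endian byte order with `"<"`, whereas the function above uses the platform's native order.

```python
import struct


def parse_sticks(payload):
    """Read four float32 stick axes at the offsets used above (illustrative helper)."""
    if len(payload) < 24:
        return None
    lx, rx, ry = struct.unpack_from("<3f", payload, 4)   # bytes 4..15
    (ly,) = struct.unpack_from("<f", payload, 20)        # bytes 20..23
    return lx, ly, rx, ry


# Build a synthetic payload; the values are exactly representable as float32.
payload = bytearray(40)
struct.pack_into("<3f", payload, 4, 0.5, -0.25, 0.75)    # lx, rx, ry
struct.pack_into("<f", payload, 20, -1.0)                # ly
sticks = parse_sticks(bytes(payload))
too_short = parse_sticks(b"\x00" * 10)
```

Note that `ly` is not stored next to the other three axes, which is why it needs its own offset.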


def process_joystick(obs, ms, controller=None):
    """Joystick mirrors keyboard: left stick=WASD, right stick X=Q/E, right stick Y=height."""
    global _joy_prev_active
    wr = obs.get("wireless_remote")
    if wr is None:
        return
    parsed = _parse_wireless(wr)
    if parsed is None:
        return
    lx, ly, rx, ry = parsed

    # Dead zone + negate both Y axes (bridge already flips them once)
    lx = 0.0 if abs(lx) < DEADZONE else lx
    ly = 0.0 if abs(ly) < DEADZONE else -ly
    rx = 0.0 if abs(rx) < DEADZONE else rx
    ry = 0.0 if abs(ry) < DEADZONE else -ry

    left_active = abs(lx) > 0 or abs(ly) > 0

    # Left stick → WASD (movement direction relative to facing)
    if left_active:
        ms.movement_angle = ms.facing_angle + math.atan2(-lx, -ly)
        ms.has_movement = True
        if not _joy_prev_active:
            ms.needs_replan = True
        _joy_prev_active = True
    elif _joy_prev_active and not (abs(rx) > 0 or abs(ry) > 0):
        _joy_prev_active = False
        ms.has_movement = False

    # Right stick X → Q/E (facing rotation, ~1 rad/s at full deflection)
    if abs(rx) > 0:
        delta = -0.02 * rx
        ms.facing_angle += delta
        if controller:
            controller.delta_heading += delta

    # Right stick Y → -/= (height adjustment, ~0.25/s at full deflection)
    if abs(ry) > 0:
        step = -0.005 * ry
        ms.height = max(0.1, min(1.0, (ms.height if ms.height >= 0 else DEFAULT_HEIGHT) + step))
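The left-stick mapping in `process_joystick` combines a dead zone with `math.atan2` to turn two axis values into a movement angle relative to the facing direction. The following standalone sketch isolates that arithmetic. `stick_to_movement_angle` is a hypothetical helper and the `DEADZONE` value of 0.1 is an assumption, since the constant is defined outside this excerpt.

```python
import math

DEADZONE = 0.1  # assumed value for illustration; the real constant is defined elsewhere


def stick_to_movement_angle(lx, ly, facing_angle, deadzone=DEADZONE):
    """Return the movement angle for a left-stick reading, or None inside the dead zone."""
    lx = 0.0 if abs(lx) < deadzone else lx
    ly = 0.0 if abs(ly) < deadzone else -ly     # same Y sign flip as the code above
    if lx == 0.0 and ly == 0.0:
        return None
    return facing_angle + math.atan2(-lx, -ly)


idle = stick_to_movement_angle(0.05, 0.05, 0.0)      # both axes inside the dead zone
forward = stick_to_movement_angle(0.0, 1.0, 0.0)     # raw ly = +1
plus_lx = stick_to_movement_angle(1.0, 0.0, 0.0)     # raw lx = +1
minus_lx = stick_to_movement_angle(-1.0, 0.0, 0.0)   # raw lx = -1
```

Under these assumptions, a raw `lx` of +1 yields an offset of -π/2 and a raw `lx` of -1 yields +π/2, the same offsets the keyboard handler assigns to the D and A keys.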


# ── Main ──────────────────────────────────────────────────────────────────────


def main():
    parser = argparse.ArgumentParser(description="SONIC planner with keyboard + gamepad control")
    parser.add_argument("--ip", type=str, default=None,
                        help="Robot IP for real hardware (e.g. 192.168.123.164). "
                             "Omit for simulation.")
    args = parser.parse_args()

    print("=" * 60)
    print("SONIC planner - full mode control")
    print(" N/P cycle sets | 1-8 select mode | WASD move")
    print(" Q/E rotate | 9/0 speed | -/= height")
    print(" R replan | Space IDLE | Esc quit")
    if args.ip:
        print(f" Robot IP: {args.ip}")
    else:
        print(" Mode: simulation")
    print("=" * 60 + "\n")

    planner_path = hf_hub_download(repo_id="nvidia/GEAR-SONIC", filename="planner_sonic.onnx")
    encoder_path = hf_hub_download(repo_id="nvidia/GEAR-SONIC", filename="model_encoder.onnx")
    decoder_path = hf_hub_download(repo_id="nvidia/GEAR-SONIC", filename="model_decoder.onnx")

    providers = ort.get_available_providers()
    use_gpu = "CUDAExecutionProvider" in providers
    gpu_ep = (["CUDAExecutionProvider","CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"])
    so = ort.SessionOptions()
    so.log_severity_level = 3

    print(f"[ONNX] enc/dec={'GPU' if use_gpu else 'CPU'}, planner=CPU")
    planner_sess = ort.InferenceSession(planner_path, sess_options=so, providers=["CPUExecutionProvider"])
    encoder_sess = ort.InferenceSession(encoder_path, sess_options=so, providers=gpu_ep)
    decoder_sess = ort.InferenceSession(decoder_path, sess_options=so, providers=gpu_ep)
    print(f"[Planner] version={'v1+' if len(planner_sess.get_inputs())>=11 else 'v0'}")

    cfg = UnitreeG1Config()
    if args.ip:
        cfg.is_simulation = False
        cfg.robot_ip = args.ip
    robot = UnitreeG1(cfg)
    robot.connect()
    kp, kd = _kp_kd()
    robot.kp = kp.copy()
    robot.kd = kd.copy()

    ms = MovementState()
    planner = SonicPlanner(planner_sess, planner_path)
    controller = PlannerController(planner, encoder_sess, decoder_sess)

    motion = planner.initialize(DEFAULT_ANGLES, ms)
    controller.load_initial_motion(motion)
    controller.print_input_diagnostics()
    planner.start_subprocess(controller)

    print(f"\nStarting: {MOTION_SETS[0][0]}")
    for i, m in enumerate(MOTION_SETS[0][1]):
        print(f" {i+1}: {m.name}")

    with RawKeyboard() as kb:
        try:
            gc.disable()
            gc_timer = 0.0
            robot.reset(CONTROL_DT, DEFAULT_ANGLES)
            time.sleep(1.0)

            step = 0
            last_status = replan_timer = 0.0
            # One list per timer. A chained `a = b = []` would bind every name to
            # the same list object and mix all timings together.
            loop_t, enc_t, dec_t, obs_t, act_t = [], [], [], [], []
            slow_n = blend_n = 0
            stall_src = ""
            did_blend = False
            prev_end = time.time()
            t_start = time.time()

            log_path = "/tmp/sonic_pose_log.csv"
            jnames = [m.name for m in G1_29_JointIndex]
            with open(log_path, "w") as log_f:
                log_f.write("t,step,cursor,ts,blend,mode," +
                            ",".join(f"q{i}" for i in range(29)) + "," +
                            ",".join(f"ref{i}" for i in range(29)) + "," +
                            ",".join(f"act{i}" for i in range(29)) +
                            ",delta_max,action_norm,token_norm\n")

                while not robot._shutdown_event.is_set():
                    t0 = time.time()
                    if process_keyboard(kb.get_key(), ms, controller): break

                    obs = robot.get_observation()
                    t_obs = time.time()
                    obs_t.append(1000*(t_obs - t0))
                    if not obs:
                        step += 1
                        prev_end = time.time()
                        time.sleep(max(0.0, CONTROL_DT-(time.time()-t0)))
                        continue

                    process_joystick(obs, ms, controller)
                    clamp_mode_params(ms)

                    # Replan on explicit request, or periodically while moving.
                    is_static = LM(ms.mode) in STATIC_MODES
                    do_req = ms.needs_replan and step > 0
                    if do_req:
                        ms.needs_replan = False
                        replan_timer = 0.0
                    elif not is_static and step > 0 and ms.speed != 0:
                        replan_timer += CONTROL_DT
                        if replan_timer >= replan_interval(ms.mode):
                            do_req = True
                            replan_timer = 0.0
                    if do_req: planner.request_replan(controller.ref_cursor, ms)

                    do_enc = (step % ENCODER_UPDATE_EVERY == 0)
                    t_step = time.time()
                    action = controller.step(obs, update_encoder=do_enc, debug=(step % DEBUG_PRINT_EVERY == 0))
                    step_ms = 1000*(time.time()-t_step)
                    (enc_t if do_enc else dec_t).append(step_ms)

                    t_act = time.time()
                    robot.send_action(action)
                    act_t.append(1000*(time.time()-t_act))

                    result = planner.try_get_new_motion()
                    t_blend = time.time()
                    if result:
                        controller.blend_new_motion(*result)
                        blend_ms = 1000*(time.time()-t_blend)
                        blend_n += 1
                        did_blend = True
                    else:
                        blend_ms = 0.0

                    # Log every fifth step.
                    if step % 5 == 0:
                        t_rel = time.time() - t_start
                        q_r = np.array([obs.get(f"{n}.q", 0) for n in jnames])
                        a_v = np.array([action.get(f"{n}.q", 0) for n in jnames])
                        cur, ts = controller.ref_cursor, controller.motion_timesteps
                        q_ref = controller.motion_joint_positions[min(cur,ts-1)] if ts > 0 else np.zeros(29)
                        log_f.write(f"{t_rel:.4f},{step},{cur},{ts},{int(did_blend)},{ms.mode}," +
                                    ",".join(f"{v:.6f}" for v in q_r) + "," +
                                    ",".join(f"{v:.6f}" for v in q_ref) + "," +
                                    ",".join(f"{v:.6f}" for v in a_v) + "," +
                                    f"{np.max(np.abs(a_v-q_r)):.6f},"
                                    f"{np.linalg.norm(a_v):.6f},"
                                    f"{np.linalg.norm(controller.token):.6f}\n")
                        did_blend = False

                    now = time.time()
                    loop_ms = 1000*(now-t0)
                    wall_dt = now - prev_end
                    loop_t.append(loop_ms)
                    if loop_ms > 50:
                        stall_src = (f"[STALL] {loop_ms:.0f}ms: "
                                     f"obs={obs_t[-1]:.0f} blend={blend_ms:.0f} step={step_ms:.0f} act={act_t[-1]:.0f}")
                    if loop_ms > CONTROL_DT*1500: slow_n += 1

                    controller.advance_cursor(wall_dt)

                    if now - last_status > 2.0:
                        def _avg(xs): return sum(xs)/len(xs) if xs else 0
                        hz = 1000/_avg(loop_t) if _avg(loop_t) else 0
                        print(f"\r {ms.status_line()} step={step} ref={controller.ref_cursor}/{controller.motion_timesteps} "
                              f"loop={_avg(loop_t):.1f}ms(max={max(loop_t,default=0):.1f}) hz={hz:.0f} "
                              f"enc={_avg(enc_t):.1f} dec={_avg(dec_t):.1f} obs={_avg(obs_t):.1f} "
                              f"slow={slow_n} blends={blend_n}", end="", flush=True)
                        if stall_src:
                            print(f"\n {stall_src}")
                            stall_src = ""
                        last_status = now
                        # Reset the statistics window with fresh, separate lists.
                        loop_t, enc_t, dec_t, obs_t, act_t = [], [], [], [], []
                        slow_n = blend_n = 0

                    prev_end = time.time()
                    gc_timer += CONTROL_DT
                    if gc_timer >= 10.0:
                        gc.collect()
                        gc_timer = 0.0
                    step += 1
                    time.sleep(max(0.0, CONTROL_DT-(time.time()-t0)))

        except KeyboardInterrupt:
            pass
        finally:
            gc.enable()
            print(f"\n[Log] Saved to {log_path}")
            planner.stop_subprocess()
            print("\nStopping...")
            if robot.is_connected: robot.disconnect()
            print("Done.")


if __name__ == "__main__":
    main()