mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 11:51:25 +00:00
Compare commits
7 Commits
feat/datas
...
fix/re-ena
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0dc324b81 | ||
|
|
1d275e2021 | ||
|
|
24bb2cb0ff | ||
|
|
1d414c07e2 | ||
|
|
e04e3399b9 | ||
|
|
017ff73fbf | ||
|
|
f90db58c15 |
@@ -19,6 +19,8 @@
|
|||||||
title: Multi GPU training
|
title: Multi GPU training
|
||||||
- local: peft_training
|
- local: peft_training
|
||||||
title: Training with PEFT (e.g., LoRA)
|
title: Training with PEFT (e.g., LoRA)
|
||||||
|
- local: rename_map
|
||||||
|
title: Using Rename Map and Empty Cameras
|
||||||
title: "Tutorials"
|
title: "Tutorials"
|
||||||
- sections:
|
- sections:
|
||||||
- local: lerobot-dataset-v3
|
- local: lerobot-dataset-v3
|
||||||
|
|||||||
@@ -310,4 +310,4 @@ Asynchronous inference represents a significant advancement in real-time robotic
|
|||||||
- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
|
- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
|
||||||
|
|
||||||
Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
|
Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
|
||||||
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/lerobot/lerobot/issues).
|
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/huggingface/lerobot/issues).
|
||||||
|
|||||||
114
docs/source/rename_map.mdx
Normal file
114
docs/source/rename_map.mdx
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
# Rename Map and Empty Cameras
|
||||||
|
|
||||||
|
When you train, evaluate, or record with a robot policy, your **dataset** or **environment** provides observations under one set of keys (e.g. `observation.images.front`, `observation.images.eagle`), while your **policy** expects another (e.g. `observation.images.image`, `observation.images.image2`). The **rename map** bridges that gap without changing the policy or data source.
|
||||||
|
|
||||||
|
> **Scope:** The rename map only renames **observation** keys (images and state). Action keys are not affected.
|
||||||
|
|
||||||
|
## Why observation keys don't always match
|
||||||
|
|
||||||
|
Policies have a fixed set of **input feature names** baked into their pretrained config. For example:
|
||||||
|
|
||||||
|
- [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero) expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb`.
|
||||||
|
- [xvla-base](https://huggingface.co/lerobot/xvla-base) expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`.
|
||||||
|
|
||||||
|
Your dataset might use different names entirely (e.g. `observation.images.front`, `observation.images.eagle`, `observation.images.glove`), and your eval environment might use yet another set. Rather than editing the policy config or renaming columns in the dataset, you pass a **rename map**: a JSON dictionary that maps source keys to the keys the policy expects. Renaming happens inside the preprocessor pipeline, so the policy always sees its expected keys.
|
||||||
|
|
||||||
|
## Using the rename map
|
||||||
|
|
||||||
|
Pass the mapping as a JSON string on the command line. The convention is always:
|
||||||
|
|
||||||
|
```
|
||||||
|
--rename_map='{"source_key": "policy_key", ...}'
|
||||||
|
```
|
||||||
|
|
||||||
|
where **source_key** is what the dataset or environment provides, and **policy_key** is what the policy expects.
|
||||||
|
|
||||||
|
Only listed keys are renamed; everything else passes through unchanged. Order of entries doesn't matter.
|
||||||
|
|
||||||
|
Supported policies: **PI0**, **PI05**, **PI0Fast**, **SmolVLA**, and **XVLA**.
|
||||||
|
|
||||||
|
### Training
|
||||||
|
|
||||||
|
Suppose you fine-tune [lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base) on a dataset with images under `observation.images.front`, `observation.images.eagle`, and `observation.images.glove`. XVLA expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--dataset.repo_id=YOUR_DATASET \
|
||||||
|
--output_dir=./outputs/xvla_training \
|
||||||
|
--job_name=xvla_training \
|
||||||
|
--policy.path="lerobot/xvla-base" \
|
||||||
|
--policy.repo_id="HF_USER/xvla-your-robot" \
|
||||||
|
--policy.dtype=bfloat16 \
|
||||||
|
--policy.action_mode=auto \
|
||||||
|
--steps=20000 \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--policy.freeze_vision_encoder=false \
|
||||||
|
--policy.freeze_language_encoder=false \
|
||||||
|
--policy.train_policy_transformer=true \
|
||||||
|
--policy.train_soft_prompts=true \
|
||||||
|
--rename_map='{"observation.images.front": "observation.images.image", "observation.images.eagle": "observation.images.image2", "observation.images.glove": "observation.images.image3"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Evaluation
|
||||||
|
|
||||||
|
A policy that expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb` (e.g. [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero)), but the LIBERO environment returns `observation.images.image` and `observation.images.image2`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-eval \
|
||||||
|
--policy.path=lerobot/pi0fast-libero \
|
||||||
|
--env.type=libero \
|
||||||
|
... \
|
||||||
|
--rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Recording
|
||||||
|
|
||||||
|
`lerobot-record` also supports rename maps, nested under the dataset config:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-record \ # When running inference
|
||||||
|
--policy.path="<user>/smolVLA_finetuned" \
|
||||||
|
... \
|
||||||
|
--dataset.rename_map='{"observation.images.glove2": "observation.images.image"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
## Alternative: edit the policy config directly
|
||||||
|
|
||||||
|
If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
|
||||||
|
|
||||||
|
The tradeoff: modifying the policy config ties it to one data source. A rename map keeps one policy usable across many datasets and environments.
|
||||||
|
|
||||||
|
## Empty cameras: fewer views than the policy expects
|
||||||
|
|
||||||
|
Some policies are built for a fixed number of image inputs. If your dataset has fewer cameras, you can set **`empty_cameras`** in the policy config instead of modifying the model architecture.
|
||||||
|
|
||||||
|
### How it works
|
||||||
|
|
||||||
|
Setting `empty_cameras=N` adds N placeholder image features to the policy config, named:
|
||||||
|
|
||||||
|
```
|
||||||
|
observation.images.empty_camera_0
|
||||||
|
observation.images.empty_camera_1
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
At runtime, these keys have no corresponding data in the batch. The policy fills them with masked dummy tensors (padded with `-1` for SigLIP-based vision encoders, with a zero attention mask), so the extra image slots are effectively ignored during training and inference.
|
||||||
|
|
||||||
|
### Example
|
||||||
|
|
||||||
|
XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset only has two cameras:
|
||||||
|
|
||||||
|
1. Set `--policy.empty_cameras=1`.
|
||||||
|
2. The config adds a third key: `observation.images.empty_camera_0`.
|
||||||
|
3. Use the rename map for your two real cameras as usual.
|
||||||
|
4. The third slot is masked out — no fake images needed in your dataset.
|
||||||
|
|
||||||
|
## Quick reference
|
||||||
|
|
||||||
|
| Goal | What to do |
|
||||||
|
| ----------------------------------------- | --------------------------------------------------------------------------- |
|
||||||
|
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
|
||||||
|
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
|
||||||
|
| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'`. |
|
||||||
|
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
|
||||||
|
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |
|
||||||
@@ -1,717 +0,0 @@
|
|||||||
"""
|
|
||||||
Action consistency analysis for imitation learning datasets.
|
|
||||||
|
|
||||||
Two parallel analyses per dataset:
|
|
||||||
1. State-based: KNN in joint-state space → action chunk variance
|
|
||||||
2. Image-based: KNN in SigLIP embedding space → action chunk variance
|
|
||||||
|
|
||||||
Comparing them reveals whether visual similarity and proprioceptive similarity
|
|
||||||
agree on where the data is inconsistent — and images are what the policy
|
|
||||||
primarily sees.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import av
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
from matplotlib.colors import LinearSegmentedColormap
|
|
||||||
from PIL import Image
|
|
||||||
from scipy.spatial import cKDTree
|
|
||||||
from transformers import AutoImageProcessor, AutoModel
|
|
||||||
|
|
||||||
DATASETS = [
|
|
||||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
|
|
||||||
{"repo_id": "lerobot-data-collection/level12_rac_2_2026-02-08_1", "label": "Full collection"},
|
|
||||||
]
|
|
||||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
|
||||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
MAX_FRAMES = 100_000
|
|
||||||
K_NEIGHBORS = 50
|
|
||||||
ACTION_CHUNK_SIZE = 30
|
|
||||||
CAMERA_KEY = "observation.images.base"
|
|
||||||
ENCODER_MODEL = "google/siglip-base-patch16-224"
|
|
||||||
ENCODE_BATCH_SIZE = 512
|
|
||||||
SEED = 42
|
|
||||||
DPI = 150
|
|
||||||
|
|
||||||
CONSISTENCY_CMAP = LinearSegmentedColormap.from_list(
|
|
||||||
"consistency", ["#0a2e0a", "#1a8e1a", "#88cc22", "#ffaa22", "#ff2222"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# FK chains from OpenArm bimanual URDF (same as workspace_density.py).
|
|
||||||
LEFT_CHAIN = [
|
|
||||||
((-np.pi / 2, 0, 0), (0, 0.031, 0.698), None),
|
|
||||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
|
||||||
((-np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
|
||||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
|
||||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
|
||||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
|
||||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
|
||||||
((0, 0, 0), (-0.0375, 0, 0), (0, -1, 0)),
|
|
||||||
((0, 0, 0), (0, 0, 0.1001), None),
|
|
||||||
((0, 0, 0), (0, 0, 0.08), None),
|
|
||||||
]
|
|
||||||
RIGHT_CHAIN = [
|
|
||||||
((np.pi / 2, 0, 0), (0, -0.031, 0.698), None),
|
|
||||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
|
||||||
((np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
|
||||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
|
||||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
|
||||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
|
||||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
|
||||||
((0, 0, 0), (-0.0375, 0, 0), (0, 1, 0)),
|
|
||||||
((0, 0, 0), (0, 0, 0.1001), None),
|
|
||||||
((0, 0, 0), (0, 0, 0.08), None),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# ── FK math ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _rot_x(a: float) -> np.ndarray:
|
|
||||||
c, s = np.cos(a), np.sin(a)
|
|
||||||
return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
|
|
||||||
|
|
||||||
|
|
||||||
def _rot_y(a: float) -> np.ndarray:
|
|
||||||
c, s = np.cos(a), np.sin(a)
|
|
||||||
return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
|
|
||||||
|
|
||||||
|
|
||||||
def _rot_z(a: float) -> np.ndarray:
|
|
||||||
c, s = np.cos(a), np.sin(a)
|
|
||||||
return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
|
|
||||||
|
|
||||||
|
|
||||||
def _tf(rpy: tuple, xyz: tuple) -> np.ndarray:
|
|
||||||
r, p, y = rpy
|
|
||||||
mat = np.eye(4)
|
|
||||||
mat[:3, :3] = _rot_z(y) @ _rot_y(p) @ _rot_x(r)
|
|
||||||
mat[:3, 3] = xyz
|
|
||||||
return mat
|
|
||||||
|
|
||||||
|
|
||||||
def _batch_axis_rot(axis: tuple, angles: np.ndarray) -> np.ndarray:
|
|
||||||
n = len(angles)
|
|
||||||
ax = np.asarray(axis, dtype=np.float64)
|
|
||||||
ax = ax / np.linalg.norm(ax)
|
|
||||||
x, y, z = ax
|
|
||||||
c = np.cos(angles)
|
|
||||||
s = np.sin(angles)
|
|
||||||
t = 1 - c
|
|
||||||
rot = np.zeros((n, 4, 4))
|
|
||||||
rot[:, 0, 0] = t * x * x + c
|
|
||||||
rot[:, 0, 1] = t * x * y - s * z
|
|
||||||
rot[:, 0, 2] = t * x * z + s * y
|
|
||||||
rot[:, 1, 0] = t * x * y + s * z
|
|
||||||
rot[:, 1, 1] = t * y * y + c
|
|
||||||
rot[:, 1, 2] = t * y * z - s * x
|
|
||||||
rot[:, 2, 0] = t * x * z - s * y
|
|
||||||
rot[:, 2, 1] = t * y * z + s * x
|
|
||||||
rot[:, 2, 2] = t * z * z + c
|
|
||||||
rot[:, 3, 3] = 1.0
|
|
||||||
return rot
|
|
||||||
|
|
||||||
|
|
||||||
def batch_fk(chain: list, joint_angles: np.ndarray) -> np.ndarray:
|
|
||||||
n = joint_angles.shape[0]
|
|
||||||
tf_batch = np.tile(np.eye(4), (n, 1, 1))
|
|
||||||
qi = 0
|
|
||||||
for rpy, xyz, axis in chain:
|
|
||||||
tf_batch = tf_batch @ _tf(rpy, xyz)
|
|
||||||
if axis is not None:
|
|
||||||
rot = _batch_axis_rot(axis, joint_angles[:, qi])
|
|
||||||
tf_batch = np.einsum("nij,njk->nik", tf_batch, rot)
|
|
||||||
qi += 1
|
|
||||||
return tf_batch[:, :3, 3]
|
|
||||||
|
|
||||||
|
|
||||||
# ── Data helpers ────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _flatten_names(obj: object) -> list[str]:
|
|
||||||
if isinstance(obj, dict):
|
|
||||||
out: list[str] = []
|
|
||||||
for v in obj.values():
|
|
||||||
out.extend(_flatten_names(v))
|
|
||||||
return out
|
|
||||||
if isinstance(obj, (list, tuple)):
|
|
||||||
out = []
|
|
||||||
for item in obj:
|
|
||||||
if isinstance(item, (list, tuple, dict)):
|
|
||||||
out.extend(_flatten_names(item))
|
|
||||||
else:
|
|
||||||
out.append(str(item))
|
|
||||||
return out
|
|
||||||
return [str(obj)]
|
|
||||||
|
|
||||||
|
|
||||||
def _detect_and_convert(vals: np.ndarray) -> np.ndarray:
|
|
||||||
mx = np.max(np.abs(vals))
|
|
||||||
if mx > 360:
|
|
||||||
print(f" Unit detection: servo ticks (max={mx:.0f})")
|
|
||||||
return (vals - 2048) / 2048 * np.pi
|
|
||||||
if mx > 6.3:
|
|
||||||
print(f" Unit detection: degrees (max={mx:.1f})")
|
|
||||||
return np.deg2rad(vals)
|
|
||||||
print(f" Unit detection: radians (max={mx:.3f})")
|
|
||||||
return vals.astype(np.float64)
|
|
||||||
|
|
||||||
|
|
||||||
def _find_joint_indices(features: dict, state_col: str, n_dim: int) -> tuple[list[int], list[int]]:
|
|
||||||
feat = features.get("observation.state", features.get(state_col, {}))
|
|
||||||
names = _flatten_names(feat.get("names", []))
|
|
||||||
left_idx: list[int] = []
|
|
||||||
right_idx: list[int] = []
|
|
||||||
if names and len(names) == n_dim:
|
|
||||||
names_l = [n.lower() for n in names]
|
|
||||||
print(f" Feature names: {names[:4]}…{names[-4:]}")
|
|
||||||
for j in range(1, 8):
|
|
||||||
for i, nm in enumerate(names_l):
|
|
||||||
if f"left_joint_{j}" in nm and i not in left_idx:
|
|
||||||
left_idx.append(i)
|
|
||||||
break
|
|
||||||
for i, nm in enumerate(names_l):
|
|
||||||
if f"right_joint_{j}" in nm and i not in right_idx:
|
|
||||||
right_idx.append(i)
|
|
||||||
break
|
|
||||||
if len(left_idx) == 7 and len(right_idx) == 7:
|
|
||||||
print(f" Matched by name: left={left_idx} right={right_idx}")
|
|
||||||
return left_idx, right_idx
|
|
||||||
if n_dim >= 16:
|
|
||||||
print(" Falling back to positional: [0:7]=left, [8:15]=right")
|
|
||||||
return list(range(7)), list(range(8, 15))
|
|
||||||
if n_dim >= 14:
|
|
||||||
print(" Falling back to positional: [0:7]=left, [7:14]=right")
|
|
||||||
return list(range(7)), list(range(7, 14))
|
|
||||||
raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot")
|
|
||||||
|
|
||||||
|
|
||||||
def download_data(repo_id: str, camera_key: str) -> Path:
|
|
||||||
print(f" Downloading {repo_id} (parquet + {camera_key} videos) …")
|
|
||||||
return Path(
|
|
||||||
snapshot_download(
|
|
||||||
repo_id=repo_id,
|
|
||||||
repo_type="dataset",
|
|
||||||
allow_patterns=[
|
|
||||||
"meta/**",
|
|
||||||
"data/**",
|
|
||||||
f"videos/{camera_key}/**",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Data loading ────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _build_action_chunks(
|
|
||||||
actions: np.ndarray, episode_ids: np.ndarray, chunk_size: int
|
|
||||||
) -> tuple[np.ndarray, np.ndarray]:
|
|
||||||
"""
|
|
||||||
For each frame, concatenate the next chunk_size actions from the same episode.
|
|
||||||
Returns (action_chunks, valid_mask).
|
|
||||||
"""
|
|
||||||
n = len(actions)
|
|
||||||
act_dim = actions.shape[1]
|
|
||||||
chunks = np.zeros((n, chunk_size * act_dim), dtype=np.float64)
|
|
||||||
valid = np.zeros(n, dtype=bool)
|
|
||||||
|
|
||||||
for i in range(n):
|
|
||||||
end = i + chunk_size
|
|
||||||
if end > n:
|
|
||||||
continue
|
|
||||||
if episode_ids[i] != episode_ids[end - 1]:
|
|
||||||
continue
|
|
||||||
chunks[i] = actions[i:end].ravel()
|
|
||||||
valid[i] = True
|
|
||||||
|
|
||||||
return chunks, valid
|
|
||||||
|
|
||||||
|
|
||||||
def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: np.random.Generator) -> dict:
|
|
||||||
"""
|
|
||||||
Load observation.state and action, build action chunks, subsample, normalize.
|
|
||||||
Also returns the original row indices (`chosen_idx`) for video frame mapping.
|
|
||||||
"""
|
|
||||||
info = json.loads((local / "meta" / "info.json").read_text())
|
|
||||||
features = info.get("features", {})
|
|
||||||
|
|
||||||
dfs = [pd.read_parquet(pq) for pq in sorted((local / "data").glob("**/*.parquet"))]
|
|
||||||
df = pd.concat(dfs, ignore_index=True)
|
|
||||||
n_total = len(df)
|
|
||||||
print(f" Total frames: {n_total:,}")
|
|
||||||
|
|
||||||
state_col = next((c for c in df.columns if "observation.state" in c), None)
|
|
||||||
action_col = next((c for c in df.columns if c == "action"), None)
|
|
||||||
if state_col is None:
|
|
||||||
raise RuntimeError(f"No observation.state column. Available: {list(df.columns)}")
|
|
||||||
if action_col is None:
|
|
||||||
raise RuntimeError(f"No action column. Available: {list(df.columns)}")
|
|
||||||
|
|
||||||
ep_col = next((c for c in df.columns if c == "episode_index"), None)
|
|
||||||
if ep_col is None:
|
|
||||||
raise RuntimeError(f"No episode_index column. Available: {list(df.columns)}")
|
|
||||||
|
|
||||||
state_all = np.stack(df[state_col].values).astype(np.float64)
|
|
||||||
action_all = np.stack(df[action_col].values).astype(np.float64)
|
|
||||||
episode_all = df[ep_col].values.astype(np.int64)
|
|
||||||
|
|
||||||
n_dim = state_all.shape[1]
|
|
||||||
act_dim = action_all.shape[1]
|
|
||||||
print(f" State dim: {n_dim} Action dim: {act_dim} Chunk size: {chunk_size}")
|
|
||||||
print(f" Action chunk dim: {chunk_size * act_dim}")
|
|
||||||
|
|
||||||
left_idx, right_idx = _find_joint_indices(features, state_col, n_dim)
|
|
||||||
|
|
||||||
print(" Building action chunks …")
|
|
||||||
action_chunks, valid = _build_action_chunks(action_all, episode_all, chunk_size)
|
|
||||||
valid_idx = np.where(valid)[0]
|
|
||||||
print(f" Valid frames (with full action chunk): {len(valid_idx):,} / {n_total:,}")
|
|
||||||
|
|
||||||
if len(valid_idx) > max_frames:
|
|
||||||
chosen = np.sort(rng.choice(valid_idx, max_frames, replace=False))
|
|
||||||
else:
|
|
||||||
chosen = valid_idx
|
|
||||||
print(f" Using {len(chosen):,} frames")
|
|
||||||
|
|
||||||
state_raw = state_all[chosen]
|
|
||||||
action_raw = action_chunks[chosen]
|
|
||||||
episode_ids = episode_all[chosen]
|
|
||||||
|
|
||||||
state_mean = state_raw.mean(axis=0)
|
|
||||||
state_std = state_raw.std(axis=0)
|
|
||||||
state_std[state_std < 1e-8] = 1.0
|
|
||||||
state_norm = (state_raw - state_mean) / state_std
|
|
||||||
|
|
||||||
action_mean = action_raw.mean(axis=0)
|
|
||||||
action_std = action_raw.std(axis=0)
|
|
||||||
action_std[action_std < 1e-8] = 1.0
|
|
||||||
action_norm = (action_raw - action_mean) / action_std
|
|
||||||
|
|
||||||
return {
|
|
||||||
"state_raw": state_raw,
|
|
||||||
"state_norm": state_norm,
|
|
||||||
"action_raw": action_raw,
|
|
||||||
"action_norm": action_norm,
|
|
||||||
"episode_ids": episode_ids,
|
|
||||||
"episode_all": episode_all,
|
|
||||||
"left_joint_idx": left_idx,
|
|
||||||
"right_joint_idx": right_idx,
|
|
||||||
"n_total": n_total,
|
|
||||||
"chosen_idx": chosen,
|
|
||||||
"df": df,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Video → frame extraction ──────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def build_video_lookup(local: Path, camera_key: str) -> dict:
|
|
||||||
"""
|
|
||||||
Build a mapping from episode_index → {video_path, fps, from_ts}.
|
|
||||||
"""
|
|
||||||
info = json.loads((local / "meta" / "info.json").read_text())
|
|
||||||
fps = info["fps"]
|
|
||||||
video_template = info.get(
|
|
||||||
"video_path",
|
|
||||||
"videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4",
|
|
||||||
)
|
|
||||||
|
|
||||||
ep_rows = []
|
|
||||||
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
|
||||||
ep_rows.append(pd.read_parquet(pq))
|
|
||||||
ep_df = pd.concat(ep_rows, ignore_index=True)
|
|
||||||
|
|
||||||
chunk_col = f"videos/{camera_key}/chunk_index"
|
|
||||||
file_col = f"videos/{camera_key}/file_index"
|
|
||||||
ts_from = f"videos/{camera_key}/from_timestamp"
|
|
||||||
if chunk_col not in ep_df.columns:
|
|
||||||
chunk_col = f"{camera_key}/chunk_index"
|
|
||||||
file_col = f"{camera_key}/file_index"
|
|
||||||
ts_from = f"{camera_key}/from_timestamp"
|
|
||||||
|
|
||||||
lookup: dict[int, dict] = {}
|
|
||||||
for _, row in ep_df.iterrows():
|
|
||||||
ci = int(row[chunk_col])
|
|
||||||
fi = int(row[file_col])
|
|
||||||
video_rel = video_template.format(video_key=camera_key, chunk_index=ci, file_index=fi)
|
|
||||||
lookup[int(row["episode_index"])] = {
|
|
||||||
"video_path": local / video_rel,
|
|
||||||
"from_ts": float(row[ts_from]),
|
|
||||||
"fps": fps,
|
|
||||||
}
|
|
||||||
return lookup
|
|
||||||
|
|
||||||
|
|
||||||
def _decode_video_frames(video_path: str) -> list[np.ndarray]:
|
|
||||||
"""Decode all frames from a video file using PyAV. Returns list of RGB arrays."""
|
|
||||||
container = av.open(video_path)
|
|
||||||
stream = container.streams.video[0]
|
|
||||||
stream.thread_type = "AUTO"
|
|
||||||
decoded = []
|
|
||||||
for frame in container.decode(stream):
|
|
||||||
decoded.append(frame.to_ndarray(format="rgb24"))
|
|
||||||
container.close()
|
|
||||||
return decoded
|
|
||||||
|
|
||||||
|
|
||||||
def extract_frames(
|
|
||||||
chosen_idx: np.ndarray,
|
|
||||||
episode_all: np.ndarray,
|
|
||||||
video_lookup: dict,
|
|
||||||
) -> list[np.ndarray | None]:
|
|
||||||
"""
|
|
||||||
Extract RGB frames for each chosen global index using PyAV.
|
|
||||||
Returns list of (H, W, 3) RGB arrays (or None on failure).
|
|
||||||
"""
|
|
||||||
unique_eps = np.unique(episode_all)
|
|
||||||
ep_start: dict[int, int] = {}
|
|
||||||
for ep in unique_eps:
|
|
||||||
ep_start[int(ep)] = int(np.where(episode_all == ep)[0][0])
|
|
||||||
|
|
||||||
# Build jobs: (output_index, video_path, local_frame_number)
|
|
||||||
jobs: list[tuple[int, str, int]] = []
|
|
||||||
for out_i, global_i in enumerate(chosen_idx):
|
|
||||||
ep = int(episode_all[global_i])
|
|
||||||
info = video_lookup.get(ep)
|
|
||||||
if info is None:
|
|
||||||
continue
|
|
||||||
local_frame = global_i - ep_start[ep]
|
|
||||||
jobs.append((out_i, str(info["video_path"]), local_frame))
|
|
||||||
|
|
||||||
# Group by video file, decode each video once
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
video_jobs: dict[str, list[tuple[int, int]]] = defaultdict(list)
|
|
||||||
for out_i, vpath, local_frame in jobs:
|
|
||||||
video_jobs[vpath].append((out_i, local_frame))
|
|
||||||
|
|
||||||
frames: list[np.ndarray | None] = [None] * len(chosen_idx)
|
|
||||||
extracted = 0
|
|
||||||
n_videos = len(video_jobs)
|
|
||||||
for vi, (vpath, frame_requests) in enumerate(video_jobs.items()):
|
|
||||||
if not Path(vpath).exists():
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
decoded = _decode_video_frames(vpath)
|
|
||||||
except Exception as exc:
|
|
||||||
print(f" Warning: failed to decode {Path(vpath).name}: {exc}")
|
|
||||||
continue
|
|
||||||
for out_i, local_frame in frame_requests:
|
|
||||||
if 0 <= local_frame < len(decoded):
|
|
||||||
frames[out_i] = decoded[local_frame]
|
|
||||||
extracted += 1
|
|
||||||
if (vi + 1) % 50 == 0 or (vi + 1) == n_videos:
|
|
||||||
print(f" Decoded {vi + 1}/{n_videos} videos ({extracted:,} frames so far)")
|
|
||||||
del decoded
|
|
||||||
|
|
||||||
print(f" Extracted {extracted:,} / {len(chosen_idx):,} frames from video")
|
|
||||||
return frames
|
|
||||||
|
|
||||||
|
|
||||||
# ── SigLIP encoding ─────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def encode_frames_siglip(
|
|
||||||
frames: list[np.ndarray | None],
|
|
||||||
model_name: str,
|
|
||||||
batch_size: int,
|
|
||||||
device: torch.device,
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Encode RGB frames through SigLIP vision encoder.
|
|
||||||
Returns (N, embed_dim) float32 array. Frames that are None get a zero vector.
|
|
||||||
"""
|
|
||||||
print(f" Loading SigLIP model: {model_name} …")
|
|
||||||
processor = AutoImageProcessor.from_pretrained(model_name)
|
|
||||||
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
|
||||||
embed_dim = model.config.vision_config.hidden_size
|
|
||||||
|
|
||||||
n = len(frames)
|
|
||||||
embeddings = np.zeros((n, embed_dim), dtype=np.float32)
|
|
||||||
|
|
||||||
valid_indices = [i for i, f in enumerate(frames) if f is not None]
|
|
||||||
print(f" Encoding {len(valid_indices):,} valid frames in batches of {batch_size} …")
|
|
||||||
|
|
||||||
for batch_start in range(0, len(valid_indices), batch_size):
|
|
||||||
batch_idx = valid_indices[batch_start : batch_start + batch_size]
|
|
||||||
pil_images = [Image.fromarray(frames[i]) for i in batch_idx]
|
|
||||||
|
|
||||||
inputs = processor(images=pil_images, return_tensors="pt").to(device)
|
|
||||||
with torch.no_grad():
|
|
||||||
image_features = model.get_image_features(**inputs)
|
|
||||||
image_features = torch.nn.functional.normalize(image_features, dim=-1)
|
|
||||||
embeddings[batch_idx] = image_features.cpu().numpy()
|
|
||||||
|
|
||||||
done = min(batch_start + batch_size, len(valid_indices))
|
|
||||||
if done % (batch_size * 10) == 0 or done == len(valid_indices):
|
|
||||||
print(f" {done:,} / {len(valid_indices):,} encoded")
|
|
||||||
|
|
||||||
del model, processor
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
|
|
||||||
# ── KNN consistency ─────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def compute_consistency(
|
|
||||||
features: np.ndarray,
|
|
||||||
action_norm: np.ndarray,
|
|
||||||
episode_ids: np.ndarray,
|
|
||||||
k: int,
|
|
||||||
label: str = "",
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
For each frame, find K nearest neighbors in feature space from other episodes.
|
|
||||||
Return per-frame action variance (mean across action dims).
|
|
||||||
"""
|
|
||||||
n = len(features)
|
|
||||||
print(f" Building KD-tree on {n:,} vectors ({label}) …")
|
|
||||||
tree = cKDTree(features)
|
|
||||||
|
|
||||||
k_query = min(k * 3, n - 1)
|
|
||||||
print(f" Querying {k_query} neighbors per frame …")
|
|
||||||
_dists, indices = tree.query(features, k=k_query + 1)
|
|
||||||
indices = indices[:, 1:]
|
|
||||||
|
|
||||||
print(f" Computing cross-episode action variance ({label}) …")
|
|
||||||
variance = np.zeros(n)
|
|
||||||
for i in range(n):
|
|
||||||
ep_i = episode_ids[i]
|
|
||||||
neighbors = indices[i]
|
|
||||||
cross_ep = neighbors[episode_ids[neighbors] != ep_i][:k]
|
|
||||||
if len(cross_ep) < 2:
|
|
||||||
variance[i] = 0.0
|
|
||||||
continue
|
|
||||||
neighbor_actions = action_norm[cross_ep]
|
|
||||||
variance[i] = np.mean(np.var(neighbor_actions, axis=0))
|
|
||||||
|
|
||||||
return variance
|
|
||||||
|
|
||||||
|
|
||||||
# ── Visualization ───────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _style_ax(ax: plt.Axes) -> None:
|
|
||||||
ax.set_facecolor("#0d1117")
|
|
||||||
ax.tick_params(colors="#555", labelsize=8)
|
|
||||||
for spine in ax.spines.values():
|
|
||||||
spine.set_color("#333")
|
|
||||||
|
|
||||||
|
|
||||||
def _plot_histogram(ax: plt.Axes, variance: np.ndarray, title: str, color: str) -> None:
|
|
||||||
_style_ax(ax)
|
|
||||||
median_var = np.median(variance)
|
|
||||||
mean_var = np.mean(variance)
|
|
||||||
nonzero = variance[variance > 0]
|
|
||||||
if len(nonzero) > 0:
|
|
||||||
bins = np.logspace(np.log10(nonzero.min().clip(1e-6)), np.log10(nonzero.max()), 60)
|
|
||||||
ax.hist(nonzero, bins=bins, color=color, alpha=0.8, edgecolor="#222")
|
|
||||||
ax.set_xscale("log")
|
|
||||||
ax.axvline(median_var, color="#ff6600", linewidth=2, label=f"median={median_var:.3f}")
|
|
||||||
ax.axvline(mean_var, color="#ff2222", linewidth=2, linestyle="--", label=f"mean={mean_var:.3f}")
|
|
||||||
ax.set_xlabel("Action variance (log scale)", color="#888", fontsize=10)
|
|
||||||
ax.set_ylabel("Frame count", color="#888", fontsize=10)
|
|
||||||
ax.set_title(title, color="white", fontsize=11, pad=10)
|
|
||||||
ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
|
|
||||||
|
|
||||||
|
|
||||||
def _plot_episode_curves(
|
|
||||||
ax: plt.Axes,
|
|
||||||
var_state: np.ndarray,
|
|
||||||
var_image: np.ndarray,
|
|
||||||
episode_ids: np.ndarray,
|
|
||||||
title: str,
|
|
||||||
) -> None:
|
|
||||||
_style_ax(ax)
|
|
||||||
unique_eps = np.unique(episode_ids)
|
|
||||||
|
|
||||||
ep_means_s = np.array([var_state[episode_ids == ep].mean() for ep in unique_eps])
|
|
||||||
ep_means_i = np.array([var_image[episode_ids == ep].mean() for ep in unique_eps])
|
|
||||||
|
|
||||||
sorted_s = np.sort(ep_means_s)[::-1]
|
|
||||||
sorted_i = np.sort(ep_means_i)[::-1]
|
|
||||||
ep_x = np.arange(len(unique_eps))
|
|
||||||
|
|
||||||
ax.fill_between(ep_x, sorted_s, alpha=0.2, color="#4363d8")
|
|
||||||
ax.plot(ep_x, sorted_s, color="#4363d8", linewidth=1.2, label=f"State (med={np.median(ep_means_s):.3f})")
|
|
||||||
ax.fill_between(ep_x, sorted_i, alpha=0.2, color="#e6194b")
|
|
||||||
ax.plot(ep_x, sorted_i, color="#e6194b", linewidth=1.2, label=f"Image (med={np.median(ep_means_i):.3f})")
|
|
||||||
|
|
||||||
ax.set_xlabel("Episode rank (worst → best)", color="#888", fontsize=10)
|
|
||||||
ax.set_ylabel("Mean action variance", color="#888", fontsize=10)
|
|
||||||
ax.set_title(title, color="white", fontsize=11, pad=10)
|
|
||||||
ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
|
|
||||||
|
|
||||||
|
|
||||||
def _plot_heatmap(
|
|
||||||
ax: plt.Axes, fig: plt.Figure, tcp_xz: np.ndarray, variance: np.ndarray, title: str
|
|
||||||
) -> None:
|
|
||||||
_style_ax(ax)
|
|
||||||
order = np.argsort(variance)
|
|
||||||
pts = tcp_xz[order]
|
|
||||||
var_sorted = variance[order]
|
|
||||||
vmin = np.percentile(variance[variance > 0], 5) if np.any(variance > 0) else 0
|
|
||||||
vmax = np.percentile(variance[variance > 0], 95) if np.any(variance > 0) else 1
|
|
||||||
sc = ax.scatter(
|
|
||||||
pts[:, 0],
|
|
||||||
pts[:, 1],
|
|
||||||
c=var_sorted,
|
|
||||||
cmap=CONSISTENCY_CMAP,
|
|
||||||
s=0.5,
|
|
||||||
alpha=0.6,
|
|
||||||
vmin=vmin,
|
|
||||||
vmax=vmax,
|
|
||||||
rasterized=True,
|
|
||||||
)
|
|
||||||
ax.set_xlabel("X (m)", color="#888", fontsize=10)
|
|
||||||
ax.set_ylabel("Z (m)", color="#888", fontsize=10)
|
|
||||||
ax.set_title(title, color="white", fontsize=11, pad=10)
|
|
||||||
ax.set_aspect("equal")
|
|
||||||
cbar = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
|
|
||||||
cbar.set_label("Action variance", color="white", fontsize=9)
|
|
||||||
cbar.ax.tick_params(colors="#aaa", labelsize=7)
|
|
||||||
|
|
||||||
|
|
||||||
def render(results: list[dict], out_path: Path) -> None:
|
|
||||||
"""
|
|
||||||
4-row x N-column figure:
|
|
||||||
Row 0: State-based variance histogram
|
|
||||||
Row 1: Image-based variance histogram
|
|
||||||
Row 2: Per-episode curves (both overlaid)
|
|
||||||
Row 3: Spatial heatmap (image-based variance)
|
|
||||||
"""
|
|
||||||
n_ds = len(results)
|
|
||||||
fig, axes = plt.subplots(4, n_ds, figsize=(9 * n_ds, 24), facecolor="#0d1117")
|
|
||||||
if n_ds == 1:
|
|
||||||
axes = axes[:, np.newaxis]
|
|
||||||
|
|
||||||
headline_parts = []
|
|
||||||
for col, r in enumerate(results):
|
|
||||||
label = r["label"]
|
|
||||||
var_s = r["var_state"]
|
|
||||||
var_i = r["var_image"]
|
|
||||||
tcp_xz = r["tcp_xz"]
|
|
||||||
episode_ids = r["episode_ids"]
|
|
||||||
|
|
||||||
med_s = np.median(var_s)
|
|
||||||
med_i = np.median(var_i)
|
|
||||||
headline_parts.append(f"{label}: state={med_s:.3f}, image={med_i:.3f}")
|
|
||||||
|
|
||||||
_plot_histogram(axes[0, col], var_s, f"{label}\nState-based variance (K={K_NEIGHBORS})", "#4363d8")
|
|
||||||
_plot_histogram(
|
|
||||||
axes[1, col], var_i, f"{label}\nImage-based variance (SigLIP, K={K_NEIGHBORS})", "#e6194b"
|
|
||||||
)
|
|
||||||
_plot_episode_curves(
|
|
||||||
axes[2, col],
|
|
||||||
var_s,
|
|
||||||
var_i,
|
|
||||||
episode_ids,
|
|
||||||
f"{label}\nPer-episode inconsistency ({len(np.unique(episode_ids)):,} episodes)",
|
|
||||||
)
|
|
||||||
_plot_heatmap(
|
|
||||||
axes[3, col],
|
|
||||||
fig,
|
|
||||||
tcp_xz,
|
|
||||||
var_i,
|
|
||||||
f"{label}\nImage-based variance by TCP position (XZ)",
|
|
||||||
)
|
|
||||||
|
|
||||||
fig.suptitle(
|
|
||||||
f"Action Consistency: State vs Image (chunk={ACTION_CHUNK_SIZE}, K={K_NEIGHBORS})\n"
|
|
||||||
+ " | ".join(headline_parts),
|
|
||||||
color="white",
|
|
||||||
fontsize=15,
|
|
||||||
y=0.99,
|
|
||||||
)
|
|
||||||
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
|
||||||
plt.savefig(out_path, dpi=DPI, bbox_inches="tight", facecolor=fig.get_facecolor())
|
|
||||||
plt.close()
|
|
||||||
print(f"\n✓ Saved: {out_path}")
|
|
||||||
|
|
||||||
|
|
||||||
# ── Main ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
print(f"Device: {device}")
|
|
||||||
rng = np.random.default_rng(SEED)
|
|
||||||
results = []
|
|
||||||
|
|
||||||
for ds in DATASETS:
|
|
||||||
repo_id, label = ds["repo_id"], ds["label"]
|
|
||||||
print(f"\n{'=' * 60}")
|
|
||||||
print(f" {label}: {repo_id}")
|
|
||||||
print(f"{'=' * 60}")
|
|
||||||
|
|
||||||
local = download_data(repo_id, CAMERA_KEY)
|
|
||||||
data = load_state_action_data(local, MAX_FRAMES, ACTION_CHUNK_SIZE, rng)
|
|
||||||
|
|
||||||
# --- State-based KNN ---
|
|
||||||
var_state = compute_consistency(
|
|
||||||
data["state_norm"], data["action_norm"], data["episode_ids"], K_NEIGHBORS, "state"
|
|
||||||
)
|
|
||||||
print(
|
|
||||||
f" State variance: median={np.median(var_state):.4f} "
|
|
||||||
f"mean={np.mean(var_state):.4f} p90={np.percentile(var_state, 90):.4f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Image-based KNN ---
|
|
||||||
print("\n Preparing image embeddings …")
|
|
||||||
video_lookup = build_video_lookup(local, CAMERA_KEY)
|
|
||||||
frames = extract_frames(data["chosen_idx"], data["episode_all"], video_lookup)
|
|
||||||
embeddings = encode_frames_siglip(frames, ENCODER_MODEL, ENCODE_BATCH_SIZE, device)
|
|
||||||
del frames # free memory
|
|
||||||
|
|
||||||
var_image = compute_consistency(
|
|
||||||
embeddings, data["action_norm"], data["episode_ids"], K_NEIGHBORS, "image"
|
|
||||||
)
|
|
||||||
print(
|
|
||||||
f" Image variance: median={np.median(var_image):.4f} "
|
|
||||||
f"mean={np.mean(var_image):.4f} p90={np.percentile(var_image, 90):.4f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# FK for spatial heatmap
|
|
||||||
print(" Computing FK for spatial heatmap …")
|
|
||||||
left_raw = data["state_raw"][:, data["left_joint_idx"]]
|
|
||||||
left_rad = _detect_and_convert(left_raw)
|
|
||||||
left_tcp = batch_fk(LEFT_CHAIN, left_rad)
|
|
||||||
tcp_xz = left_tcp[:, [0, 2]]
|
|
||||||
|
|
||||||
results.append(
|
|
||||||
{
|
|
||||||
"label": label,
|
|
||||||
"var_state": var_state,
|
|
||||||
"var_image": var_image,
|
|
||||||
"episode_ids": data["episode_ids"],
|
|
||||||
"tcp_xz": tcp_xz,
|
|
||||||
"n_total": data["n_total"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
out = OUTPUT_DIR / "action_consistency_comparison.jpg"
|
|
||||||
render(results, out)
|
|
||||||
|
|
||||||
# Save worst-episodes summary (image-based, since that's the stronger signal)
|
|
||||||
worst_summary = {}
|
|
||||||
for r in results:
|
|
||||||
unique_eps = np.unique(r["episode_ids"])
|
|
||||||
ep_means = {int(ep): float(r["var_image"][r["episode_ids"] == ep].mean()) for ep in unique_eps}
|
|
||||||
ranked = sorted(ep_means.items(), key=lambda x: x[1], reverse=True)[:50]
|
|
||||||
worst_summary[r["label"]] = [{"episode": ep, "mean_variance": v} for ep, v in ranked]
|
|
||||||
worst_path = OUTPUT_DIR / "action_consistency_worst_episodes.json"
|
|
||||||
worst_path.write_text(json.dumps(worst_summary, indent=2))
|
|
||||||
print(f"✓ Saved worst episodes: {worst_path}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,178 +0,0 @@
|
|||||||
"""
|
|
||||||
Create a JPG grid of random frames sampled from a LeRobot video dataset.
|
|
||||||
Downloads metadata + video chunks from HuggingFace, picks random frames,
|
|
||||||
decodes them, and tiles into a single image.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import random
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
REPO_ID = "lerobot-data-collection/level2_final_quality3"
|
|
||||||
CAMERA_KEY = "observation.images.base"
|
|
||||||
GRID_COLS = 15
|
|
||||||
GRID_ROWS = 10
|
|
||||||
THUMB_WIDTH = 160
|
|
||||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
|
||||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
|
||||||
SEED = 1
|
|
||||||
|
|
||||||
|
|
||||||
def download_metadata(repo_id: str) -> Path:
|
|
||||||
"""Download only metadata (no videos yet)."""
|
|
||||||
print(f"[1/3] Downloading metadata for {repo_id} …")
|
|
||||||
return Path(
|
|
||||||
snapshot_download(
|
|
||||||
repo_id=repo_id,
|
|
||||||
repo_type="dataset",
|
|
||||||
allow_patterns=["meta/**"],
|
|
||||||
ignore_patterns=["*.mp4"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_video_info(local: Path) -> tuple[str, list[dict], int]:
|
|
||||||
"""Parse info.json and episode parquets. Returns (camera_key, episode_rows, fps)."""
|
|
||||||
info = json.loads((local / "meta" / "info.json").read_text())
|
|
||||||
fps = info["fps"]
|
|
||||||
features = info["features"]
|
|
||||||
|
|
||||||
video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
|
|
||||||
if not video_keys:
|
|
||||||
raise RuntimeError("No video keys found in dataset features")
|
|
||||||
|
|
||||||
if CAMERA_KEY is not None:
|
|
||||||
if CAMERA_KEY not in video_keys:
|
|
||||||
raise RuntimeError(f"CAMERA_KEY='{CAMERA_KEY}' not found. Available: {video_keys}")
|
|
||||||
cam = CAMERA_KEY
|
|
||||||
else:
|
|
||||||
cam = video_keys[0]
|
|
||||||
print(f" camera='{cam}' all_cams={video_keys} fps={fps}")
|
|
||||||
|
|
||||||
ep_rows = []
|
|
||||||
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
|
||||||
ep_rows.append(pd.read_parquet(pq))
|
|
||||||
ep_df = pd.concat(ep_rows, ignore_index=True)
|
|
||||||
|
|
||||||
video_template = info.get(
|
|
||||||
"video_path",
|
|
||||||
"videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4",
|
|
||||||
)
|
|
||||||
|
|
||||||
chunk_col = f"videos/{cam}/chunk_index"
|
|
||||||
file_col = f"videos/{cam}/file_index"
|
|
||||||
ts_from = f"videos/{cam}/from_timestamp"
|
|
||||||
ts_to = f"videos/{cam}/to_timestamp"
|
|
||||||
if chunk_col not in ep_df.columns:
|
|
||||||
chunk_col = f"{cam}/chunk_index"
|
|
||||||
file_col = f"{cam}/file_index"
|
|
||||||
ts_from = f"{cam}/from_timestamp"
|
|
||||||
ts_to = f"{cam}/to_timestamp"
|
|
||||||
|
|
||||||
episodes = []
|
|
||||||
for _, row in ep_df.iterrows():
|
|
||||||
ci = int(row[chunk_col])
|
|
||||||
fi = int(row[file_col])
|
|
||||||
episodes.append(
|
|
||||||
{
|
|
||||||
"episode_index": int(row["episode_index"]),
|
|
||||||
"chunk_index": ci,
|
|
||||||
"file_index": fi,
|
|
||||||
"from_ts": float(row[ts_from]),
|
|
||||||
"to_ts": float(row[ts_to]),
|
|
||||||
"video_rel": video_template.format(video_key=cam, chunk_index=ci, file_index=fi),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return cam, episodes, fps
|
|
||||||
|
|
||||||
|
|
||||||
def pick_random_frames(episodes: list[dict], fps: int, n: int, rng: random.Random) -> list[dict]:
|
|
||||||
"""Pick n random (episode, timestamp) pairs, return sorted by video file for efficient access."""
|
|
||||||
picks = []
|
|
||||||
for _ in range(n):
|
|
||||||
ep = rng.choice(episodes)
|
|
||||||
duration = ep["to_ts"] - ep["from_ts"]
|
|
||||||
if duration <= 0:
|
|
||||||
continue
|
|
||||||
t = ep["from_ts"] + rng.random() * duration
|
|
||||||
picks.append({**ep, "seek_ts": t})
|
|
||||||
picks.sort(key=lambda p: (p["video_rel"], p["seek_ts"]))
|
|
||||||
return picks
|
|
||||||
|
|
||||||
|
|
||||||
def download_video_files(repo_id: str, local: Path, picks: list[dict]) -> None:
|
|
||||||
"""Download only the video files we need."""
|
|
||||||
needed = sorted({p["video_rel"] for p in picks})
|
|
||||||
print(f"[2/3] Downloading {len(needed)} video file(s) …")
|
|
||||||
snapshot_download(
|
|
||||||
repo_id=repo_id,
|
|
||||||
repo_type="dataset",
|
|
||||||
local_dir=str(local),
|
|
||||||
allow_patterns=needed,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_frame(video_path: Path, seek_ts: float) -> np.ndarray | None:
|
|
||||||
"""Decode a single frame at the given timestamp."""
|
|
||||||
cap = cv2.VideoCapture(str(video_path))
|
|
||||||
cap.set(cv2.CAP_PROP_POS_MSEC, seek_ts * 1000.0)
|
|
||||||
ret, frame = cap.read()
|
|
||||||
cap.release()
|
|
||||||
return frame if ret else None
|
|
||||||
|
|
||||||
|
|
||||||
def build_grid(frames: list[np.ndarray], cols: int, thumb_w: int) -> np.ndarray:
|
|
||||||
"""Resize frames to uniform thumbnails and tile into a grid."""
|
|
||||||
if not frames:
|
|
||||||
raise RuntimeError("No frames decoded")
|
|
||||||
|
|
||||||
h0, w0 = frames[0].shape[:2]
|
|
||||||
thumb_h = int(thumb_w * h0 / w0)
|
|
||||||
|
|
||||||
thumbs = [cv2.resize(f, (thumb_w, thumb_h), interpolation=cv2.INTER_AREA) for f in frames]
|
|
||||||
|
|
||||||
rows = []
|
|
||||||
for i in range(0, len(thumbs), cols):
|
|
||||||
row_thumbs = thumbs[i : i + cols]
|
|
||||||
while len(row_thumbs) < cols:
|
|
||||||
row_thumbs.append(np.zeros_like(row_thumbs[0]))
|
|
||||||
rows.append(np.hstack(row_thumbs))
|
|
||||||
return np.vstack(rows)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
rng = random.Random(SEED)
|
|
||||||
n_frames = GRID_COLS * GRID_ROWS
|
|
||||||
|
|
||||||
local = download_metadata(REPO_ID)
|
|
||||||
cam, episodes, fps = load_video_info(local)
|
|
||||||
picks = pick_random_frames(episodes, fps, n_frames, rng)
|
|
||||||
download_video_files(REPO_ID, local, picks)
|
|
||||||
|
|
||||||
print(f"[3/3] Decoding {n_frames} frames …")
|
|
||||||
frames: list[np.ndarray] = []
|
|
||||||
for p in picks:
|
|
||||||
vp = local / p["video_rel"]
|
|
||||||
if not vp.exists():
|
|
||||||
print(f" SKIP: {p['video_rel']} not found")
|
|
||||||
continue
|
|
||||||
frame = extract_frame(vp, p["seek_ts"])
|
|
||||||
if frame is not None:
|
|
||||||
frames.append(frame)
|
|
||||||
|
|
||||||
print(f" Decoded {len(frames)}/{n_frames} frames")
|
|
||||||
grid = build_grid(frames, GRID_COLS, THUMB_WIDTH)
|
|
||||||
|
|
||||||
safe_name = REPO_ID.replace("/", "_")
|
|
||||||
out_path = OUTPUT_DIR / f"{safe_name}_grid_{GRID_COLS}x{GRID_ROWS}.jpg"
|
|
||||||
cv2.imwrite(str(out_path), grid, [cv2.IMWRITE_JPEG_QUALITY, 92])
|
|
||||||
print(f"\n✓ Saved: {out_path} ({grid.shape[1]}×{grid.shape[0]})")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,526 +0,0 @@
|
|||||||
"""
|
|
||||||
Create MP4 videos with sarm_progress overlay for specified episodes.
|
|
||||||
Downloads datasets from HuggingFace, extracts episode video + progress data,
|
|
||||||
and draws the progress line directly on each frame (no panel, no axes).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import subprocess
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
DATASETS = [
|
|
||||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "episode": 250},
|
|
||||||
]
|
|
||||||
CAMERA_KEY = (
|
|
||||||
"observation.images.base" # None = auto-select first camera, or set e.g. "observation.images.top"
|
|
||||||
)
|
|
||||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
|
||||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
# Progress line spans the full video height
|
|
||||||
GRAPH_Y_TOP_FRAC = 0.01
|
|
||||||
GRAPH_Y_BOT_FRAC = 0.99
|
|
||||||
LINE_THICKNESS = 3
|
|
||||||
SHADOW_THICKNESS = 6 # white edge thickness
|
|
||||||
REF_ALPHA = 0.45 # opacity of the 1.0 reference line
|
|
||||||
FILL_ALPHA = 0.55 # opacity of the grey fill under the line
|
|
||||||
SCORE_FONT_SCALE = 0.8
|
|
||||||
TASK_FONT_SCALE = 0.55
|
|
||||||
|
|
||||||
|
|
||||||
def download_episode(repo_id: str, episode: int) -> Path:
|
|
||||||
"""Download only the files needed for this episode."""
|
|
||||||
# We need: meta/, sarm_progress.parquet, and the relevant video/data chunks.
|
|
||||||
# We'll download meta + sarm first, then figure out chunks.
|
|
||||||
print(f"\n[1/5] Downloading metadata for {repo_id} …")
|
|
||||||
local = Path(
|
|
||||||
snapshot_download(
|
|
||||||
repo_id=repo_id,
|
|
||||||
repo_type="dataset",
|
|
||||||
allow_patterns=["meta/**", "sarm_progress.parquet"],
|
|
||||||
ignore_patterns=["*.mp4"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return local
|
|
||||||
|
|
||||||
|
|
||||||
def load_episode_meta(local: Path, episode: int) -> dict:
|
|
||||||
"""Read info.json + episode-level parquet to get fps, video paths, timestamps."""
|
|
||||||
info = json.loads((local / "meta" / "info.json").read_text())
|
|
||||||
fps = info["fps"]
|
|
||||||
features = info["features"]
|
|
||||||
|
|
||||||
# Find video keys (keys whose dtype=="video")
|
|
||||||
video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
|
|
||||||
if not video_keys:
|
|
||||||
raise RuntimeError("No video keys found in dataset features")
|
|
||||||
if CAMERA_KEY is not None:
|
|
||||||
if CAMERA_KEY not in video_keys:
|
|
||||||
raise RuntimeError(f"CAMERA_KEY='{CAMERA_KEY}' not found. Available: {video_keys}")
|
|
||||||
first_cam = CAMERA_KEY
|
|
||||||
else:
|
|
||||||
first_cam = video_keys[0]
|
|
||||||
print(f" fps={fps} camera='{first_cam}' all_cams={video_keys}")
|
|
||||||
|
|
||||||
# Load all episode-meta parquet files and find our episode
|
|
||||||
ep_rows = []
|
|
||||||
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
|
||||||
df = pd.read_parquet(pq)
|
|
||||||
ep_rows.append(df)
|
|
||||||
ep_df = pd.concat(ep_rows, ignore_index=True)
|
|
||||||
row = ep_df[ep_df["episode_index"] == episode]
|
|
||||||
if row.empty:
|
|
||||||
raise RuntimeError(f"Episode {episode} not found in episode metadata")
|
|
||||||
row = row.iloc[0]
|
|
||||||
|
|
||||||
# Extract video chunk/file index for first camera
|
|
||||||
# Try both dot and slash variants of the key
|
|
||||||
chunk_col = f"videos/{first_cam}/chunk_index"
|
|
||||||
file_col = f"videos/{first_cam}/file_index"
|
|
||||||
ts_col = f"videos/{first_cam}/from_timestamp"
|
|
||||||
to_col = f"videos/{first_cam}/to_timestamp"
|
|
||||||
|
|
||||||
# Some datasets use different column naming
|
|
||||||
if chunk_col not in row.index:
|
|
||||||
# Try without the 'videos/' prefix
|
|
||||||
chunk_col = f"{first_cam}/chunk_index"
|
|
||||||
file_col = f"{first_cam}/file_index"
|
|
||||||
ts_col = f"{first_cam}/from_timestamp"
|
|
||||||
to_col = f"{first_cam}/to_timestamp"
|
|
||||||
if chunk_col not in row.index:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Cannot find video metadata columns for {first_cam}.\nAvailable: {list(row.index)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
chunk_idx = int(row[chunk_col])
|
|
||||||
file_idx = int(row[file_col])
|
|
||||||
from_ts = float(row[ts_col])
|
|
||||||
to_ts = float(row[to_col])
|
|
||||||
|
|
||||||
video_template = info.get(
|
|
||||||
"video_path", "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4"
|
|
||||||
)
|
|
||||||
video_rel = video_template.format(
|
|
||||||
video_key=first_cam,
|
|
||||||
chunk_index=chunk_idx,
|
|
||||||
file_index=file_idx,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load task name for this episode
|
|
||||||
# tasks.parquet uses the task string as the row index; task_index column holds the int id
|
|
||||||
task_name = ""
|
|
||||||
try:
|
|
||||||
# Prefer the 'tasks' list directly on the episode row
|
|
||||||
if "tasks" in row.index and row["tasks"] is not None:
|
|
||||||
tasks_val = row["tasks"]
|
|
||||||
if isinstance(tasks_val, (list, tuple, np.ndarray)) and len(tasks_val) > 0:
|
|
||||||
task_name = str(tasks_val[0])
|
|
||||||
else:
|
|
||||||
task_name = str(tasks_val).strip("[]'")
|
|
||||||
else:
|
|
||||||
tasks_pq = local / "meta" / "tasks.parquet"
|
|
||||||
if tasks_pq.exists():
|
|
||||||
tasks_df = pd.read_parquet(tasks_pq)
|
|
||||||
# Row index is the task string; task_index column is the int
|
|
||||||
task_idx = int(row.get("task_index", 0)) if "task_index" in row.index else 0
|
|
||||||
match = tasks_df[tasks_df["task_index"] == task_idx]
|
|
||||||
if not match.empty:
|
|
||||||
task_name = str(match.index[0])
|
|
||||||
print(f" Task name: '{task_name}'")
|
|
||||||
except Exception as e:
|
|
||||||
print(f" WARNING: could not load task name: {e}")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"fps": fps,
|
|
||||||
"first_cam": first_cam,
|
|
||||||
"video_rel": video_rel,
|
|
||||||
"chunk_index": chunk_idx,
|
|
||||||
"file_index": file_idx,
|
|
||||||
"from_ts": from_ts,
|
|
||||||
"to_ts": to_ts,
|
|
||||||
"task_name": task_name,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def download_video(repo_id: str, local: Path, video_rel: str) -> Path:
|
|
||||||
"""Download the specific video file if not already present."""
|
|
||||||
video_path = local / video_rel
|
|
||||||
if video_path.exists():
|
|
||||||
print(f" Video already cached: {video_path}")
|
|
||||||
return video_path
|
|
||||||
print(f"[2/5] Downloading video file {video_rel} …")
|
|
||||||
snapshot_download(
|
|
||||||
repo_id=repo_id,
|
|
||||||
repo_type="dataset",
|
|
||||||
local_dir=str(local),
|
|
||||||
allow_patterns=[video_rel],
|
|
||||||
)
|
|
||||||
if not video_path.exists():
|
|
||||||
raise RuntimeError(f"Video not found after download: {video_path}")
|
|
||||||
return video_path
|
|
||||||
|
|
||||||
|
|
||||||
def load_progress(local: Path, episode: int) -> np.ndarray | None:
|
|
||||||
"""Load sarm_progress values for this episode. Returns sorted array of (frame_index, progress)."""
|
|
||||||
pq_path = local / "sarm_progress.parquet"
|
|
||||||
if not pq_path.exists():
|
|
||||||
print(" WARNING: sarm_progress.parquet not found, trying data parquet …")
|
|
||||||
return None
|
|
||||||
df = pd.read_parquet(pq_path)
|
|
||||||
print(f" sarm_progress.parquet columns: {list(df.columns)}")
|
|
||||||
ep_df = df[df["episode_index"] == episode].copy()
|
|
||||||
if ep_df.empty:
|
|
||||||
print(f" WARNING: No sarm_progress rows for episode {episode}")
|
|
||||||
return None
|
|
||||||
ep_df = ep_df.sort_values("frame_index")
|
|
||||||
|
|
||||||
# Prefer dense, fall back to sparse
|
|
||||||
if "progress_dense" in ep_df.columns and ep_df["progress_dense"].notna().any():
|
|
||||||
prog_col = "progress_dense"
|
|
||||||
elif "progress_sparse" in ep_df.columns:
|
|
||||||
prog_col = "progress_sparse"
|
|
||||||
else:
|
|
||||||
# Last resort: any column with 'progress' in the name
|
|
||||||
prog_cols = [c for c in ep_df.columns if "progress" in c.lower()]
|
|
||||||
if not prog_cols:
|
|
||||||
return None
|
|
||||||
prog_col = prog_cols[0]
|
|
||||||
|
|
||||||
print(f" Using progress column: '{prog_col}'")
|
|
||||||
return ep_df[["frame_index", prog_col]].rename(columns={prog_col: "progress"}).values
|
|
||||||
|
|
||||||
|
|
||||||
def extract_episode_clip(video_path: Path, from_ts: float, to_ts: float, out_path: Path) -> Path:
|
|
||||||
"""Use ffmpeg to cut the episode segment from the combined video file."""
|
|
||||||
duration = to_ts - from_ts
|
|
||||||
print(f"[3/5] Extracting clip [{from_ts:.3f}s → {to_ts:.3f}s] ({duration:.2f}s) …")
|
|
||||||
cmd = [
|
|
||||||
"ffmpeg",
|
|
||||||
"-y",
|
|
||||||
"-ss",
|
|
||||||
str(from_ts),
|
|
||||||
"-i",
|
|
||||||
str(video_path),
|
|
||||||
"-t",
|
|
||||||
str(duration),
|
|
||||||
"-c:v",
|
|
||||||
"libx264",
|
|
||||||
"-preset",
|
|
||||||
"fast",
|
|
||||||
"-crf",
|
|
||||||
"18",
|
|
||||||
"-an",
|
|
||||||
str(out_path),
|
|
||||||
]
|
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
|
||||||
if result.returncode != 0:
|
|
||||||
raise RuntimeError(f"ffmpeg clip extraction failed:\n{result.stderr}")
|
|
||||||
return out_path
|
|
||||||
|
|
||||||
|
|
||||||
def precompute_pixels(
|
|
||||||
progress_data: np.ndarray,
|
|
||||||
n_frames: int,
|
|
||||||
frame_w: int,
|
|
||||||
frame_h: int,
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Map each progress sample to pixel coordinates.
|
|
||||||
Returns array of shape (N, 2) with (x, y) in pixel space.
|
|
||||||
x spans full video width; y maps progress [0,1] to graph band.
|
|
||||||
"""
|
|
||||||
frame_indices = progress_data[:, 0].astype(float)
|
|
||||||
progress_vals = np.clip(progress_data[:, 1].astype(float), 0.0, 1.0)
|
|
||||||
|
|
||||||
y_top = int(frame_h * GRAPH_Y_TOP_FRAC)
|
|
||||||
y_bot = int(frame_h * GRAPH_Y_BOT_FRAC)
|
|
||||||
graph_h = y_bot - y_top
|
|
||||||
|
|
||||||
xs = (frame_indices / (n_frames - 1) * (frame_w - 1)).astype(int)
|
|
||||||
# progress=1 → y_top, progress=0 → y_bot
|
|
||||||
ys = (y_bot - progress_vals * graph_h).astype(int)
|
|
||||||
|
|
||||||
return np.stack([xs, ys], axis=1) # (N, 2)
|
|
||||||
|
|
||||||
|
|
||||||
def progress_color(t: float) -> tuple[int, int, int]:
|
|
||||||
"""Interpolate BGR color red→green based on normalised position t in [0,1]."""
|
|
||||||
r = int(255 * (1.0 - t))
|
|
||||||
g = int(255 * t)
|
|
||||||
return (0, g, r) # BGR
|
|
||||||
|
|
||||||
|
|
||||||
def prerender_fill(
|
|
||||||
pixels: np.ndarray,
|
|
||||||
frame_w: int,
|
|
||||||
frame_h: int,
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""Pre-render the full grey fill polygon under the curve as a BGRA image."""
|
|
||||||
y_bot = int(frame_h * GRAPH_Y_BOT_FRAC)
|
|
||||||
fill_img = np.zeros((frame_h, frame_w, 4), dtype=np.uint8)
|
|
||||||
poly = np.concatenate(
|
|
||||||
[
|
|
||||||
pixels,
|
|
||||||
[[pixels[-1][0], y_bot], [pixels[0][0], y_bot]],
|
|
||||||
],
|
|
||||||
axis=0,
|
|
||||||
).astype(np.int32)
|
|
||||||
cv2.fillPoly(fill_img, [poly], color=(128, 128, 128, int(255 * FILL_ALPHA)))
|
|
||||||
return fill_img
|
|
||||||
|
|
||||||
|
|
||||||
def alpha_composite(base: np.ndarray, overlay_bgra: np.ndarray, x_max: int) -> None:
|
|
||||||
"""Blend overlay onto base in-place, but only for x < x_max."""
|
|
||||||
if x_max <= 0:
|
|
||||||
return
|
|
||||||
roi_b = base[:, :x_max]
|
|
||||||
roi_o = overlay_bgra[:, :x_max]
|
|
||||||
alpha = roi_o[:, :, 3:4].astype(np.float32) / 255.0
|
|
||||||
roi_b[:] = np.clip(
|
|
||||||
roi_o[:, :, :3].astype(np.float32) * alpha + roi_b.astype(np.float32) * (1.0 - alpha),
|
|
||||||
0,
|
|
||||||
255,
|
|
||||||
).astype(np.uint8)
|
|
||||||
|
|
||||||
|
|
||||||
def draw_text_outlined(
|
|
||||||
frame: np.ndarray,
|
|
||||||
text: str,
|
|
||||||
pos: tuple[int, int],
|
|
||||||
font_scale: float,
|
|
||||||
thickness: int = 1,
|
|
||||||
) -> None:
|
|
||||||
"""Draw text with a dark outline for readability on any background."""
|
|
||||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
|
||||||
cv2.putText(frame, text, pos, font, font_scale, (0, 0, 0), thickness + 2, cv2.LINE_AA)
|
|
||||||
cv2.putText(frame, text, pos, font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
|
|
||||||
|
|
||||||
|
|
||||||
def composite_video(
|
|
||||||
clip_path: Path,
|
|
||||||
progress_data: np.ndarray,
|
|
||||||
out_path: Path,
|
|
||||||
fps: float,
|
|
||||||
frame_h: int,
|
|
||||||
frame_w: int,
|
|
||||||
task_name: str = "",
|
|
||||||
) -> Path:
|
|
||||||
"""Read clip frames, draw gradient progress line with fill + labels, export as GIF."""
|
|
||||||
n_total = int(cv2.VideoCapture(str(clip_path)).get(cv2.CAP_PROP_FRAME_COUNT))
|
|
||||||
pixels = precompute_pixels(progress_data, n_total, frame_w, frame_h)
|
|
||||||
|
|
||||||
y_ref = int(frame_h * GRAPH_Y_TOP_FRAC)
|
|
||||||
|
|
||||||
# Pre-render fill polygon (line is drawn per-frame with live color)
|
|
||||||
fill_img = prerender_fill(pixels, frame_w, frame_h)
|
|
||||||
|
|
||||||
# 1.0 reference line overlay (full width, drawn once)
|
|
||||||
ref_img = np.zeros((frame_h, frame_w, 4), dtype=np.uint8)
|
|
||||||
cv2.line(ref_img, (0, y_ref), (frame_w - 1, y_ref), (200, 200, 200, int(255 * REF_ALPHA)), 1, cv2.LINE_AA)
|
|
||||||
|
|
||||||
frame_indices = progress_data[:, 0].astype(int)
|
|
||||||
progress_vals = progress_data[:, 1].astype(float)
|
|
||||||
|
|
||||||
print(f"[4/4] Compositing {n_total} frames …")
|
|
||||||
cap = cv2.VideoCapture(str(clip_path))
|
|
||||||
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
|
||||||
tmp_path = out_path.parent / (out_path.stem + "_tmp.mp4")
|
|
||||||
writer = cv2.VideoWriter(str(tmp_path), fourcc, fps, (frame_w, frame_h))
|
|
||||||
|
|
||||||
fi = 0
|
|
||||||
while True:
|
|
||||||
ret, frame = cap.read()
|
|
||||||
if not ret:
|
|
||||||
break
|
|
||||||
|
|
||||||
n_drawn = int(np.searchsorted(frame_indices, fi, side="right"))
|
|
||||||
x_cur = int(pixels[min(n_drawn, len(pixels)) - 1][0]) + 1 if n_drawn > 0 else 0
|
|
||||||
|
|
||||||
# 1. reference line (full width, always)
|
|
||||||
alpha_composite(frame, ref_img, frame_w)
|
|
||||||
|
|
||||||
# 2. grey fill under curve up to current x
|
|
||||||
alpha_composite(frame, fill_img, x_cur)
|
|
||||||
|
|
||||||
# 3. progress line — single color that transitions red→green over time
|
|
||||||
if n_drawn >= 2:
|
|
||||||
t_cur = (n_drawn - 1) / max(len(progress_vals) - 1, 1)
|
|
||||||
line_col = progress_color(t_cur)
|
|
||||||
pts = pixels[:n_drawn].reshape(-1, 1, 2).astype(np.int32)
|
|
||||||
cv2.polylines(
|
|
||||||
frame,
|
|
||||||
[pts],
|
|
||||||
isClosed=False,
|
|
||||||
color=(255, 255, 255),
|
|
||||||
thickness=SHADOW_THICKNESS,
|
|
||||||
lineType=cv2.LINE_AA,
|
|
||||||
)
|
|
||||||
cv2.polylines(
|
|
||||||
frame, [pts], isClosed=False, color=line_col, thickness=LINE_THICKNESS, lineType=cv2.LINE_AA
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. score — bottom right
|
|
||||||
if n_drawn > 0:
|
|
||||||
score = float(progress_vals[min(n_drawn, len(progress_vals)) - 1])
|
|
||||||
score_text = f"{score:.2f}"
|
|
||||||
(tw, th), _ = cv2.getTextSize(score_text, cv2.FONT_HERSHEY_SIMPLEX, SCORE_FONT_SCALE, 2)
|
|
||||||
sx = frame_w - tw - 12
|
|
||||||
sy = frame_h - 12
|
|
||||||
# coloured score matching current gradient position
|
|
||||||
t_cur = (n_drawn - 1) / max(len(progress_vals) - 1, 1)
|
|
||||||
score_col = progress_color(t_cur)
|
|
||||||
cv2.putText(
|
|
||||||
frame,
|
|
||||||
score_text,
|
|
||||||
(sx, sy),
|
|
||||||
cv2.FONT_HERSHEY_SIMPLEX,
|
|
||||||
SCORE_FONT_SCALE,
|
|
||||||
(0, 0, 0),
|
|
||||||
4,
|
|
||||||
cv2.LINE_AA,
|
|
||||||
)
|
|
||||||
cv2.putText(
|
|
||||||
frame,
|
|
||||||
score_text,
|
|
||||||
(sx, sy),
|
|
||||||
cv2.FONT_HERSHEY_SIMPLEX,
|
|
||||||
SCORE_FONT_SCALE,
|
|
||||||
score_col,
|
|
||||||
2,
|
|
||||||
cv2.LINE_AA,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. task name — top centre
|
|
||||||
if task_name:
|
|
||||||
(tw, _), _ = cv2.getTextSize(task_name, cv2.FONT_HERSHEY_SIMPLEX, TASK_FONT_SCALE, 1)
|
|
||||||
tx = max((frame_w - tw) // 2, 4)
|
|
||||||
draw_text_outlined(frame, task_name, (tx, 22), TASK_FONT_SCALE)
|
|
||||||
|
|
||||||
writer.write(frame)
|
|
||||||
fi += 1
|
|
||||||
if fi % 100 == 0:
|
|
||||||
print(f" Frame {fi}/{n_total} …", end="\r")
|
|
||||||
|
|
||||||
cap.release()
|
|
||||||
writer.release()
|
|
||||||
print()
|
|
||||||
|
|
||||||
# Convert to GIF: full resolution, 12fps, 128-color diff palette (<40MB)
|
|
||||||
gif_path = out_path.with_suffix(".gif")
|
|
||||||
palette = out_path.parent / "_palette.png"
|
|
||||||
r1 = subprocess.run( # nosec B607
|
|
||||||
[
|
|
||||||
"ffmpeg",
|
|
||||||
"-y",
|
|
||||||
"-i",
|
|
||||||
str(tmp_path),
|
|
||||||
"-vf",
|
|
||||||
f"fps=10,scale={frame_w}:-1:flags=lanczos,palettegen=max_colors=128:stats_mode=diff",
|
|
||||||
"-update",
|
|
||||||
"1",
|
|
||||||
str(palette),
|
|
||||||
],
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
)
|
|
||||||
if r1.returncode != 0:
|
|
||||||
print(f" WARNING: palettegen failed:\n{r1.stderr[-500:]}")
|
|
||||||
r2 = subprocess.run( # nosec B607
|
|
||||||
[
|
|
||||||
"ffmpeg",
|
|
||||||
"-y",
|
|
||||||
"-i",
|
|
||||||
str(tmp_path),
|
|
||||||
"-i",
|
|
||||||
str(palette),
|
|
||||||
"-filter_complex",
|
|
||||||
f"fps=10,scale={frame_w}:-1:flags=lanczos[v];[v][1:v]paletteuse=dither=bayer:bayer_scale=3",
|
|
||||||
str(gif_path),
|
|
||||||
],
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
)
|
|
||||||
if r2.returncode != 0:
|
|
||||||
print(f" WARNING: gif encode failed:\n{r2.stderr[-500:]}")
|
|
||||||
tmp_path.unlink(missing_ok=True)
|
|
||||||
palette.unlink(missing_ok=True)
|
|
||||||
return gif_path
|
|
||||||
|
|
||||||
|
|
||||||
def process_dataset(repo_id: str, episode: int):
|
|
||||||
safe_name = repo_id.replace("/", "_")
|
|
||||||
print(f"\n{'=' * 60}")
|
|
||||||
print(f"Processing: {repo_id} | episode {episode}")
|
|
||||||
print(f"{'=' * 60}")
|
|
||||||
|
|
||||||
# 1. Download metadata
|
|
||||||
local = download_episode(repo_id, episode)
|
|
||||||
print(f" Local cache: {local}")
|
|
||||||
|
|
||||||
# 2. Read episode metadata
|
|
||||||
ep_meta = load_episode_meta(local, episode)
|
|
||||||
print(f" Episode meta: {ep_meta}")
|
|
||||||
|
|
||||||
# 3. Download video file
|
|
||||||
video_path = download_video(repo_id, local, ep_meta["video_rel"])
|
|
||||||
|
|
||||||
# 4. Extract clip
|
|
||||||
clip_path = OUTPUT_DIR / f"{safe_name}_ep{episode}_clip.mp4"
|
|
||||||
extract_episode_clip(video_path, ep_meta["from_ts"], ep_meta["to_ts"], clip_path)
|
|
||||||
|
|
||||||
# 5. Load progress data
|
|
||||||
progress_data = load_progress(local, episode)
|
|
||||||
if progress_data is None:
|
|
||||||
print(" ERROR: Could not load sarm_progress data. Skipping overlay.")
|
|
||||||
return
|
|
||||||
|
|
||||||
n_progress = len(progress_data)
|
|
||||||
print(f" Progress frames: {n_progress}")
|
|
||||||
|
|
||||||
# 6. Get clip dimensions
|
|
||||||
cap = cv2.VideoCapture(str(clip_path))
|
|
||||||
frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
|
||||||
frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
|
||||||
n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
||||||
actual_fps = cap.get(cv2.CAP_PROP_FPS) or ep_meta["fps"]
|
|
||||||
cap.release()
|
|
||||||
print(f" Clip: {frame_w}×{frame_h} {n_frames} frames @ {actual_fps:.1f}fps")
|
|
||||||
|
|
||||||
# 7. Composite (draw line directly on frames)
|
|
||||||
out_path = OUTPUT_DIR / f"{safe_name}_ep{episode}_progress.mp4"
|
|
||||||
final = composite_video(
|
|
||||||
clip_path,
|
|
||||||
progress_data,
|
|
||||||
out_path,
|
|
||||||
actual_fps,
|
|
||||||
frame_h,
|
|
||||||
frame_w,
|
|
||||||
task_name=ep_meta.get("task_name", ""),
|
|
||||||
)
|
|
||||||
clip_path.unlink(missing_ok=True)
|
|
||||||
print(f"\n✓ Done: {final}")
|
|
||||||
return final
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
results = []
|
|
||||||
for cfg in DATASETS:
|
|
||||||
try:
|
|
||||||
out = process_dataset(cfg["repo_id"], cfg["episode"])
|
|
||||||
if out:
|
|
||||||
results.append(out)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\nERROR processing {cfg['repo_id']}: {e}")
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("Output files:")
|
|
||||||
for r in results:
|
|
||||||
print(f" {r}")
|
|
||||||
@@ -1,496 +0,0 @@
|
|||||||
"""
|
|
||||||
Visualize end-effector workspace density and trajectory clusters for OpenArm datasets.
|
|
||||||
Downloads joint position data (no videos) from HuggingFace, computes forward
|
|
||||||
kinematics per episode, clusters trajectories with K-means, and renders
|
|
||||||
2D projections comparing dataset coverage and multimodality.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
from sklearn.cluster import KMeans
|
|
||||||
|
|
||||||
DATASETS = [
|
|
||||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
|
|
||||||
{"repo_id": "lerobot-data-collection/level12_rac_2_2026-02-08_1", "label": "Full collection"},
|
|
||||||
]
|
|
||||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
|
||||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
N_CLUSTERS = 10
|
|
||||||
WAYPOINTS = 50
|
|
||||||
SEED = 42
|
|
||||||
DPI = 180
|
|
||||||
|
|
||||||
CLUSTER_COLORS = [
|
|
||||||
"#e6194b",
|
|
||||||
"#3cb44b",
|
|
||||||
"#4363d8",
|
|
||||||
"#f58231",
|
|
||||||
"#911eb4",
|
|
||||||
"#42d4f4",
|
|
||||||
"#f032e6",
|
|
||||||
"#bfef45",
|
|
||||||
"#fabed4",
|
|
||||||
"#dcbeff",
|
|
||||||
"#9a6324",
|
|
||||||
"#fffac8",
|
|
||||||
"#800000",
|
|
||||||
"#aaffc3",
|
|
||||||
"#808000",
|
|
||||||
"#ffd8b1",
|
|
||||||
"#000075",
|
|
||||||
"#a9a9a9",
|
|
||||||
]
|
|
||||||
|
|
||||||
# FK chains extracted from OpenArm bimanual URDF.
|
|
||||||
# Each entry: (rpy, xyz, revolute_axis_or_None).
|
|
||||||
LEFT_CHAIN = [
|
|
||||||
((-np.pi / 2, 0, 0), (0, 0.031, 0.698), None),
|
|
||||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
|
||||||
((-np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
|
||||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
|
||||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
|
||||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
|
||||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
|
||||||
((0, 0, 0), (-0.0375, 0, 0), (0, -1, 0)),
|
|
||||||
((0, 0, 0), (0, 0, 0.1001), None),
|
|
||||||
((0, 0, 0), (0, 0, 0.08), None),
|
|
||||||
]
|
|
||||||
RIGHT_CHAIN = [
|
|
||||||
((np.pi / 2, 0, 0), (0, -0.031, 0.698), None),
|
|
||||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
|
||||||
((np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
|
||||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
|
||||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
|
||||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
|
||||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
|
||||||
((0, 0, 0), (-0.0375, 0, 0), (0, 1, 0)),
|
|
||||||
((0, 0, 0), (0, 0, 0.1001), None),
|
|
||||||
((0, 0, 0), (0, 0, 0.08), None),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# ── FK math ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _rot_x(a: float) -> np.ndarray:
|
|
||||||
c, s = np.cos(a), np.sin(a)
|
|
||||||
return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
|
|
||||||
|
|
||||||
|
|
||||||
def _rot_y(a: float) -> np.ndarray:
|
|
||||||
c, s = np.cos(a), np.sin(a)
|
|
||||||
return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
|
|
||||||
|
|
||||||
|
|
||||||
def _rot_z(a: float) -> np.ndarray:
|
|
||||||
c, s = np.cos(a), np.sin(a)
|
|
||||||
return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
|
|
||||||
|
|
||||||
|
|
||||||
def _tf(rpy: tuple, xyz: tuple) -> np.ndarray:
|
|
||||||
"""Build a 4x4 homogeneous transform from URDF rpy + xyz."""
|
|
||||||
r, p, y = rpy
|
|
||||||
mat = np.eye(4)
|
|
||||||
mat[:3, :3] = _rot_z(y) @ _rot_y(p) @ _rot_x(r)
|
|
||||||
mat[:3, 3] = xyz
|
|
||||||
return mat
|
|
||||||
|
|
||||||
|
|
||||||
def _batch_axis_rot(axis: tuple, angles: np.ndarray) -> np.ndarray:
|
|
||||||
"""Batched Rodrigues rotation: (n,) angles around a fixed axis → (n, 4, 4)."""
|
|
||||||
n = len(angles)
|
|
||||||
ax = np.asarray(axis, dtype=np.float64)
|
|
||||||
ax = ax / np.linalg.norm(ax)
|
|
||||||
x, y, z = ax
|
|
||||||
c = np.cos(angles)
|
|
||||||
s = np.sin(angles)
|
|
||||||
t = 1 - c
|
|
||||||
rot = np.zeros((n, 4, 4))
|
|
||||||
rot[:, 0, 0] = t * x * x + c
|
|
||||||
rot[:, 0, 1] = t * x * y - s * z
|
|
||||||
rot[:, 0, 2] = t * x * z + s * y
|
|
||||||
rot[:, 1, 0] = t * x * y + s * z
|
|
||||||
rot[:, 1, 1] = t * y * y + c
|
|
||||||
rot[:, 1, 2] = t * y * z - s * x
|
|
||||||
rot[:, 2, 0] = t * x * z - s * y
|
|
||||||
rot[:, 2, 1] = t * y * z + s * x
|
|
||||||
rot[:, 2, 2] = t * z * z + c
|
|
||||||
rot[:, 3, 3] = 1.0
|
|
||||||
return rot
|
|
||||||
|
|
||||||
|
|
||||||
def batch_fk(chain: list, joint_angles: np.ndarray) -> np.ndarray:
|
|
||||||
"""Vectorized FK: (n, 7) radians → (n, 3) TCP positions in world frame."""
|
|
||||||
n = joint_angles.shape[0]
|
|
||||||
tf_batch = np.tile(np.eye(4), (n, 1, 1))
|
|
||||||
qi = 0
|
|
||||||
for rpy, xyz, axis in chain:
|
|
||||||
tf_batch = tf_batch @ _tf(rpy, xyz)
|
|
||||||
if axis is not None:
|
|
||||||
rot = _batch_axis_rot(axis, joint_angles[:, qi])
|
|
||||||
tf_batch = np.einsum("nij,njk->nik", tf_batch, rot)
|
|
||||||
qi += 1
|
|
||||||
return tf_batch[:, :3, 3]
|
|
||||||
|
|
||||||
|
|
||||||
# ── Data loading ────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _flatten_names(obj: object) -> list[str]:
|
|
||||||
"""Recursively flatten a names structure (list, dict, or nested) into a flat string list."""
|
|
||||||
if isinstance(obj, dict):
|
|
||||||
out: list[str] = []
|
|
||||||
for v in obj.values():
|
|
||||||
out.extend(_flatten_names(v))
|
|
||||||
return out
|
|
||||||
if isinstance(obj, (list, tuple)):
|
|
||||||
out = []
|
|
||||||
for item in obj:
|
|
||||||
if isinstance(item, (list, tuple, dict)):
|
|
||||||
out.extend(_flatten_names(item))
|
|
||||||
else:
|
|
||||||
out.append(str(item))
|
|
||||||
return out
|
|
||||||
return [str(obj)]
|
|
||||||
|
|
||||||
|
|
||||||
def _detect_and_convert(vals: np.ndarray) -> np.ndarray:
|
|
||||||
"""Auto-detect servo ticks / degrees / radians and convert to radians."""
|
|
||||||
mx = np.max(np.abs(vals))
|
|
||||||
if mx > 360:
|
|
||||||
print(f" Unit detection: servo ticks (max={mx:.0f})")
|
|
||||||
return (vals - 2048) / 2048 * np.pi
|
|
||||||
if mx > 6.3:
|
|
||||||
print(f" Unit detection: degrees (max={mx:.1f})")
|
|
||||||
return np.deg2rad(vals)
|
|
||||||
print(f" Unit detection: radians (max={mx:.3f})")
|
|
||||||
return vals.astype(np.float64)
|
|
||||||
|
|
||||||
|
|
||||||
def _find_joint_indices(features: dict, state_col: str, n_dim: int) -> tuple[list[int], list[int]]:
|
|
||||||
"""Try to find left/right joint indices from info.json feature names."""
|
|
||||||
feat = features.get("observation.state", features.get(state_col, {}))
|
|
||||||
names = _flatten_names(feat.get("names", []))
|
|
||||||
|
|
||||||
left_idx: list[int] = []
|
|
||||||
right_idx: list[int] = []
|
|
||||||
if names and len(names) == n_dim:
|
|
||||||
names_l = [n.lower() for n in names]
|
|
||||||
print(f" Feature names: {names[:4]}…{names[-4:]}")
|
|
||||||
for j in range(1, 8):
|
|
||||||
for i, nm in enumerate(names_l):
|
|
||||||
if f"left_joint_{j}" in nm and i not in left_idx:
|
|
||||||
left_idx.append(i)
|
|
||||||
break
|
|
||||||
for i, nm in enumerate(names_l):
|
|
||||||
if f"right_joint_{j}" in nm and i not in right_idx:
|
|
||||||
right_idx.append(i)
|
|
||||||
break
|
|
||||||
|
|
||||||
if len(left_idx) == 7 and len(right_idx) == 7:
|
|
||||||
print(f" Matched by name: left={left_idx} right={right_idx}")
|
|
||||||
return left_idx, right_idx
|
|
||||||
if n_dim >= 16:
|
|
||||||
print(" Falling back to positional: [0:7]=left, [8:15]=right")
|
|
||||||
return list(range(7)), list(range(8, 15))
|
|
||||||
if n_dim >= 14:
|
|
||||||
print(" Falling back to positional: [0:7]=left, [7:14]=right")
|
|
||||||
return list(range(7)), list(range(7, 14))
|
|
||||||
raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot")
|
|
||||||
|
|
||||||
|
|
||||||
def download_data(repo_id: str) -> Path:
|
|
||||||
print(f" Downloading {repo_id} (parquet only) …")
|
|
||||||
return Path(
|
|
||||||
snapshot_download(
|
|
||||||
repo_id=repo_id,
|
|
||||||
repo_type="dataset",
|
|
||||||
allow_patterns=["meta/**", "data/**"],
|
|
||||||
ignore_patterns=["*.mp4", "videos/**"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def resample_trajectory(traj: np.ndarray, n_waypoints: int) -> np.ndarray:
|
|
||||||
"""Resample a (F, 3) trajectory to exactly n_waypoints via linear interpolation."""
|
|
||||||
f = traj.shape[0]
|
|
||||||
if f == n_waypoints:
|
|
||||||
return traj
|
|
||||||
old_t = np.linspace(0, 1, f)
|
|
||||||
new_t = np.linspace(0, 1, n_waypoints)
|
|
||||||
return np.column_stack([np.interp(new_t, old_t, traj[:, d]) for d in range(3)])
|
|
||||||
|
|
||||||
|
|
||||||
def load_episode_trajectories(local: Path) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Load per-episode joint data, compute FK, return list of trajectory dicts.
|
|
||||||
Each dict: {"left_tcp": (F,3), "right_tcp": (F,3), "episode_index": int}.
|
|
||||||
Uses all episodes in the dataset for a fair comparison.
|
|
||||||
"""
|
|
||||||
info = json.loads((local / "meta" / "info.json").read_text())
|
|
||||||
features = info.get("features", {})
|
|
||||||
|
|
||||||
dfs = [pd.read_parquet(pq) for pq in sorted((local / "data").glob("**/*.parquet"))]
|
|
||||||
df = pd.concat(dfs, ignore_index=True)
|
|
||||||
print(f" Total frames: {len(df):,}")
|
|
||||||
|
|
||||||
state_col = next((c for c in df.columns if "observation.state" in c), None)
|
|
||||||
if state_col is None:
|
|
||||||
raise RuntimeError(f"No observation.state column. Available: {list(df.columns)}")
|
|
||||||
|
|
||||||
first = df[state_col].iloc[0]
|
|
||||||
if not hasattr(first, "__len__"):
|
|
||||||
raise RuntimeError(f"observation.state is scalar ({type(first)}), expected array")
|
|
||||||
|
|
||||||
state = np.stack(df[state_col].values).astype(np.float64)
|
|
||||||
n_dim = state.shape[1]
|
|
||||||
print(f" State dim: {n_dim} max|val|: {np.max(np.abs(state)):.1f}")
|
|
||||||
|
|
||||||
left_idx, right_idx = _find_joint_indices(features, state_col, n_dim)
|
|
||||||
|
|
||||||
ep_col = next((c for c in df.columns if c == "episode_index"), None)
|
|
||||||
if ep_col is None:
|
|
||||||
raise RuntimeError(f"No episode_index column. Available: {list(df.columns)}")
|
|
||||||
|
|
||||||
episode_ids = df[ep_col].values
|
|
||||||
unique_eps = np.unique(episode_ids)
|
|
||||||
print(f" Episodes: {len(unique_eps):,}")
|
|
||||||
|
|
||||||
left_raw = state[:, left_idx]
|
|
||||||
right_raw = state[:, right_idx]
|
|
||||||
left_all = _detect_and_convert(left_raw)
|
|
||||||
right_all = _detect_and_convert(right_raw)
|
|
||||||
|
|
||||||
print(" Computing FK per episode …")
|
|
||||||
trajectories = []
|
|
||||||
for ep_id in unique_eps:
|
|
||||||
mask = episode_ids == ep_id
|
|
||||||
left_tcp = batch_fk(LEFT_CHAIN, left_all[mask])
|
|
||||||
right_tcp = batch_fk(RIGHT_CHAIN, right_all[mask])
|
|
||||||
if len(left_tcp) < 3:
|
|
||||||
continue
|
|
||||||
trajectories.append({"left_tcp": left_tcp, "right_tcp": right_tcp, "episode_index": int(ep_id)})
|
|
||||||
|
|
||||||
print(f" Valid trajectories: {len(trajectories):,}")
|
|
||||||
return trajectories
|
|
||||||
|
|
||||||
|
|
||||||
# ── Clustering ──────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def cluster_trajectories(
|
|
||||||
trajectories: list[dict], n_clusters: int, n_waypoints: int
|
|
||||||
) -> tuple[np.ndarray, np.ndarray]:
|
|
||||||
"""
|
|
||||||
K-means on resampled trajectory features.
|
|
||||||
Combines left+right TCP into a single feature vector per episode.
|
|
||||||
Returns (labels, centroid_trajs (k, waypoints, 6), spread_per_cluster (k,) in metres).
|
|
||||||
Spread = mean per-waypoint Euclidean distance from each trajectory to its centroid.
|
|
||||||
"""
|
|
||||||
feat_vecs = []
|
|
||||||
for t in trajectories:
|
|
||||||
left_rs = resample_trajectory(t["left_tcp"], n_waypoints)
|
|
||||||
right_rs = resample_trajectory(t["right_tcp"], n_waypoints)
|
|
||||||
feat_vecs.append(np.concatenate([left_rs.ravel(), right_rs.ravel()]))
|
|
||||||
feat_matrix = np.array(feat_vecs)
|
|
||||||
|
|
||||||
k = min(n_clusters, len(feat_vecs))
|
|
||||||
km = KMeans(n_clusters=k, n_init=10, random_state=SEED)
|
|
||||||
labels = km.fit_predict(feat_matrix)
|
|
||||||
|
|
||||||
centroids_flat = km.cluster_centers_
|
|
||||||
centroid_trajs = np.zeros((k, n_waypoints, 6))
|
|
||||||
for ci in range(k):
|
|
||||||
left_flat = centroids_flat[ci, : n_waypoints * 3]
|
|
||||||
right_flat = centroids_flat[ci, n_waypoints * 3 :]
|
|
||||||
centroid_trajs[ci, :, :3] = left_flat.reshape(n_waypoints, 3)
|
|
||||||
centroid_trajs[ci, :, 3:] = right_flat.reshape(n_waypoints, 3)
|
|
||||||
|
|
||||||
# Mean per-waypoint distance to centroid (in metres) for each cluster
|
|
||||||
spread = np.zeros(k)
|
|
||||||
for ci in range(k):
|
|
||||||
members = np.where(labels == ci)[0]
|
|
||||||
if len(members) == 0:
|
|
||||||
continue
|
|
||||||
centroid_left = centroid_trajs[ci, :, :3]
|
|
||||||
centroid_right = centroid_trajs[ci, :, 3:]
|
|
||||||
dists = []
|
|
||||||
for mi in members:
|
|
||||||
t = trajectories[mi]
|
|
||||||
left_rs = resample_trajectory(t["left_tcp"], n_waypoints)
|
|
||||||
right_rs = resample_trajectory(t["right_tcp"], n_waypoints)
|
|
||||||
d_left = np.linalg.norm(left_rs - centroid_left, axis=1).mean()
|
|
||||||
d_right = np.linalg.norm(right_rs - centroid_right, axis=1).mean()
|
|
||||||
dists.append((d_left + d_right) / 2)
|
|
||||||
spread[ci] = np.mean(dists)
|
|
||||||
|
|
||||||
return labels, centroid_trajs, spread
|
|
||||||
|
|
||||||
|
|
||||||
# ── Visualization ───────────────────────────────────────
|
|
||||||
|
|
||||||
PROJ_VIEWS = [
|
|
||||||
("XZ (side)", 0, 2, "X (m)", "Z (m)"),
|
|
||||||
("XY (top)", 0, 1, "X (m)", "Y (m)"),
|
|
||||||
("YZ (front)", 1, 2, "Y (m)", "Z (m)"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def render(results: list[dict], out_path: Path) -> None:
|
|
||||||
"""
|
|
||||||
2-row × 3-col grid per dataset (3 projections × 2 datasets).
|
|
||||||
Trajectory lines colored by cluster, centroid trajectories drawn thick.
|
|
||||||
"""
|
|
||||||
n_ds = len(results)
|
|
||||||
n_proj = len(PROJ_VIEWS)
|
|
||||||
fig, axes = plt.subplots(n_ds, n_proj, figsize=(7 * n_proj, 7 * n_ds), facecolor="#0d1117")
|
|
||||||
if n_ds == 1:
|
|
||||||
axes = axes[np.newaxis, :]
|
|
||||||
|
|
||||||
for row, r in enumerate(results):
|
|
||||||
trajectories = r["trajectories"]
|
|
||||||
labels = r["labels"]
|
|
||||||
centroids = r["centroids"]
|
|
||||||
k = centroids.shape[0]
|
|
||||||
|
|
||||||
cluster_sizes = np.bincount(labels, minlength=k)
|
|
||||||
size_order = np.argsort(-cluster_sizes)
|
|
||||||
pcts = cluster_sizes / len(labels) * 100
|
|
||||||
spread = r["spread"]
|
|
||||||
|
|
||||||
for col, (view_name, dim_a, dim_b, xlabel, ylabel) in enumerate(PROJ_VIEWS):
|
|
||||||
ax = axes[row, col]
|
|
||||||
ax.set_facecolor("#0d1117")
|
|
||||||
|
|
||||||
for ti, traj in enumerate(trajectories):
|
|
||||||
color = CLUSTER_COLORS[labels[ti] % len(CLUSTER_COLORS)]
|
|
||||||
for tcp_key in ("left_tcp", "right_tcp"):
|
|
||||||
pts = traj[tcp_key]
|
|
||||||
ax.plot(pts[:, dim_a], pts[:, dim_b], color=color, alpha=0.12, linewidth=0.4)
|
|
||||||
|
|
||||||
for ci in range(k):
|
|
||||||
color = CLUSTER_COLORS[ci % len(CLUSTER_COLORS)]
|
|
||||||
left_c = centroids[ci, :, :3]
|
|
||||||
right_c = centroids[ci, :, 3:]
|
|
||||||
lw = 1.5 + 2.0 * cluster_sizes[ci] / cluster_sizes.max()
|
|
||||||
for c_pts in (left_c, right_c):
|
|
||||||
ax.plot(
|
|
||||||
c_pts[:, dim_a],
|
|
||||||
c_pts[:, dim_b],
|
|
||||||
color=color,
|
|
||||||
linewidth=lw,
|
|
||||||
alpha=0.95,
|
|
||||||
zorder=10,
|
|
||||||
)
|
|
||||||
ax.plot(
|
|
||||||
c_pts[0, dim_a],
|
|
||||||
c_pts[0, dim_b],
|
|
||||||
"o",
|
|
||||||
color=color,
|
|
||||||
markersize=4,
|
|
||||||
zorder=11,
|
|
||||||
)
|
|
||||||
ax.plot(
|
|
||||||
c_pts[-1, dim_a],
|
|
||||||
c_pts[-1, dim_b],
|
|
||||||
"s",
|
|
||||||
color=color,
|
|
||||||
markersize=4,
|
|
||||||
zorder=11,
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.set_xlabel(xlabel, color="#888", fontsize=9)
|
|
||||||
ax.set_ylabel(ylabel, color="#888", fontsize=9)
|
|
||||||
ax.tick_params(colors="#555", labelsize=7)
|
|
||||||
for spine in ax.spines.values():
|
|
||||||
spine.set_color("#333")
|
|
||||||
ax.set_aspect("equal")
|
|
||||||
|
|
||||||
mean_spread_cm = np.average(spread, weights=cluster_sizes) * 100
|
|
||||||
if col == 0:
|
|
||||||
ax.set_title(
|
|
||||||
f"{r['label']} ({r['n_episodes']:,} episodes, {k} clusters, "
|
|
||||||
f"avg spread {mean_spread_cm:.1f}cm)",
|
|
||||||
color="white",
|
|
||||||
fontsize=11,
|
|
||||||
pad=10,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ax.set_title(view_name, color="#aaa", fontsize=10, pad=8)
|
|
||||||
|
|
||||||
# Cluster size + spread legend on the rightmost panel
|
|
||||||
legend_ax = axes[row, -1]
|
|
||||||
for ci in size_order:
|
|
||||||
color = CLUSTER_COLORS[ci % len(CLUSTER_COLORS)]
|
|
||||||
spread_cm = spread[ci] * 100
|
|
||||||
label = f"C{ci}: {cluster_sizes[ci]} eps ({pcts[ci]:.0f}%) ±{spread_cm:.1f}cm"
|
|
||||||
legend_ax.plot([], [], color=color, linewidth=3, label=label)
|
|
||||||
legend_ax.legend(
|
|
||||||
loc="upper right",
|
|
||||||
fontsize=7,
|
|
||||||
frameon=True,
|
|
||||||
facecolor="#1a1a2e",
|
|
||||||
edgecolor="#333",
|
|
||||||
labelcolor="white",
|
|
||||||
handlelength=1.5,
|
|
||||||
)
|
|
||||||
|
|
||||||
fig.suptitle(
|
|
||||||
"End-Effector Trajectory Clusters (FK · K-means)",
|
|
||||||
color="white",
|
|
||||||
fontsize=16,
|
|
||||||
y=0.98,
|
|
||||||
)
|
|
||||||
plt.tight_layout(rect=[0, 0, 1, 0.95])
|
|
||||||
plt.savefig(out_path, dpi=DPI, bbox_inches="tight", facecolor=fig.get_facecolor())
|
|
||||||
plt.close()
|
|
||||||
print(f"\n✓ Saved: {out_path}")
|
|
||||||
|
|
||||||
|
|
||||||
# ── Main ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
results = []
|
|
||||||
|
|
||||||
for ds in DATASETS:
|
|
||||||
repo_id, label = ds["repo_id"], ds["label"]
|
|
||||||
print(f"\n{'=' * 60}")
|
|
||||||
print(f" {label}: {repo_id}")
|
|
||||||
print(f"{'=' * 60}")
|
|
||||||
|
|
||||||
local = download_data(repo_id)
|
|
||||||
trajectories = load_episode_trajectories(local)
|
|
||||||
labels, centroids, spread = cluster_trajectories(trajectories, N_CLUSTERS, WAYPOINTS)
|
|
||||||
|
|
||||||
cluster_sizes = np.bincount(labels, minlength=centroids.shape[0])
|
|
||||||
print(f" Cluster sizes: {sorted(cluster_sizes, reverse=True)}")
|
|
||||||
for ci in np.argsort(-cluster_sizes):
|
|
||||||
print(
|
|
||||||
f" C{ci}: {cluster_sizes[ci]} eps ({cluster_sizes[ci] / len(labels) * 100:.0f}%) "
|
|
||||||
f"spread ±{spread[ci] * 100:.1f}cm"
|
|
||||||
)
|
|
||||||
|
|
||||||
results.append(
|
|
||||||
{
|
|
||||||
"label": label,
|
|
||||||
"trajectories": trajectories,
|
|
||||||
"labels": labels,
|
|
||||||
"centroids": centroids,
|
|
||||||
"spread": spread,
|
|
||||||
"n_episodes": len(trajectories),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
out = OUTPUT_DIR / "workspace_trajectory_clusters.jpg"
|
|
||||||
render(results, out)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -131,6 +131,15 @@ class _NormalizationMixin:
|
|||||||
if self.dtype is None:
|
if self.dtype is None:
|
||||||
self.dtype = torch.float32
|
self.dtype = torch.float32
|
||||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||||
|
self._reshape_visual_stats()
|
||||||
|
|
||||||
|
def _reshape_visual_stats(self) -> None:
|
||||||
|
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
|
||||||
|
for key, feature in self.features.items():
|
||||||
|
if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
|
||||||
|
for stat_name, stat_tensor in self._tensor_stats[key].items():
|
||||||
|
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
|
||||||
|
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
|
||||||
|
|
||||||
def to(
|
def to(
|
||||||
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
|
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
|
||||||
@@ -149,6 +158,7 @@ class _NormalizationMixin:
|
|||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||||
|
self._reshape_visual_stats()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def state_dict(self) -> dict[str, Tensor]:
|
def state_dict(self) -> dict[str, Tensor]:
|
||||||
@@ -198,6 +208,7 @@ class _NormalizationMixin:
|
|||||||
# Don't load from state_dict, keep the explicitly provided stats
|
# Don't load from state_dict, keep the explicitly provided stats
|
||||||
# But ensure _tensor_stats is properly initialized
|
# But ensure _tensor_stats is properly initialized
|
||||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
|
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
|
||||||
|
self._reshape_visual_stats()
|
||||||
return
|
return
|
||||||
|
|
||||||
# Normal behavior: load stats from state_dict
|
# Normal behavior: load stats from state_dict
|
||||||
@@ -209,6 +220,8 @@ class _NormalizationMixin:
|
|||||||
dtype=torch.float32, device=self.device
|
dtype=torch.float32, device=self.device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._reshape_visual_stats()
|
||||||
|
|
||||||
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
|
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
|
||||||
# and other functions that rely on self.stats
|
# and other functions that rely on self.stats
|
||||||
self.stats = {}
|
self.stats = {}
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ from lerobot.configs import parser
|
|||||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||||
from lerobot.policies.factory import make_policy
|
from lerobot.policies.factory import make_policy
|
||||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||||
|
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
|
||||||
from lerobot.rl.process import ProcessSignalHandler
|
from lerobot.rl.process import ProcessSignalHandler
|
||||||
from lerobot.rl.queue import get_last_item_from_queue
|
from lerobot.rl.queue import get_last_item_from_queue
|
||||||
from lerobot.robots import so_follower # noqa: F401
|
from lerobot.robots import so_follower # noqa: F401
|
||||||
@@ -258,6 +259,11 @@ def act_with_policy(
|
|||||||
policy = policy.eval()
|
policy = policy.eval()
|
||||||
assert isinstance(policy, nn.Module)
|
assert isinstance(policy, nn.Module)
|
||||||
|
|
||||||
|
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||||
|
config=cfg.policy,
|
||||||
|
dataset_stats=cfg.policy.dataset_stats,
|
||||||
|
)
|
||||||
|
|
||||||
obs, info = online_env.reset()
|
obs, info = online_env.reset()
|
||||||
env_processor.reset()
|
env_processor.reset()
|
||||||
action_processor.reset()
|
action_processor.reset()
|
||||||
@@ -289,7 +295,9 @@ def act_with_policy(
|
|||||||
# Time policy inference and check if it meets FPS requirement
|
# Time policy inference and check if it meets FPS requirement
|
||||||
with policy_timer:
|
with policy_timer:
|
||||||
# Extract observation from transition for policy
|
# Extract observation from transition for policy
|
||||||
action = policy.select_action(batch=observation)
|
normalized_observation = preprocessor.process_observation(observation)
|
||||||
|
action = policy.select_action(batch=normalized_observation)
|
||||||
|
# action = postprocessor.process_action(action)
|
||||||
policy_fps = policy_timer.fps_last
|
policy_fps = policy_timer.fps_last
|
||||||
|
|
||||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||||
|
|||||||
@@ -66,6 +66,7 @@ from lerobot.datasets.factory import make_dataset
|
|||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.policies.factory import make_policy
|
from lerobot.policies.factory import make_policy
|
||||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||||
|
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
|
||||||
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
|
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
|
||||||
from lerobot.rl.process import ProcessSignalHandler
|
from lerobot.rl.process import ProcessSignalHandler
|
||||||
from lerobot.rl.wandb_utils import WandBLogger
|
from lerobot.rl.wandb_utils import WandBLogger
|
||||||
@@ -313,6 +314,11 @@ def add_actor_information_and_train(
|
|||||||
|
|
||||||
assert isinstance(policy, nn.Module)
|
assert isinstance(policy, nn.Module)
|
||||||
|
|
||||||
|
preprocessor, _ = make_sac_pre_post_processors(
|
||||||
|
config=cfg.policy,
|
||||||
|
dataset_stats=cfg.policy.dataset_stats,
|
||||||
|
)
|
||||||
|
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||||
@@ -408,6 +414,9 @@ def add_actor_information_and_train(
|
|||||||
done = batch["done"]
|
done = batch["done"]
|
||||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||||
|
|
||||||
|
observations = preprocessor.process_observation(observations)
|
||||||
|
next_observations = preprocessor.process_observation(next_observations)
|
||||||
|
|
||||||
observation_features, next_observation_features = get_observation_features(
|
observation_features, next_observation_features = get_observation_features(
|
||||||
policy=policy, observations=observations, next_observations=next_observations
|
policy=policy, observations=observations, next_observations=next_observations
|
||||||
)
|
)
|
||||||
@@ -467,6 +476,9 @@ def add_actor_information_and_train(
|
|||||||
|
|
||||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||||
|
|
||||||
|
observations = preprocessor.process_observation(observations)
|
||||||
|
next_observations = preprocessor.process_observation(next_observations)
|
||||||
|
|
||||||
observation_features, next_observation_features = get_observation_features(
|
observation_features, next_observation_features = get_observation_features(
|
||||||
policy=policy, observations=observations, next_observations=next_observations
|
policy=policy, observations=observations, next_observations=next_observations
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -23,65 +23,46 @@ class InputController:
|
|||||||
"""Base class for input controllers that generate motion deltas."""
|
"""Base class for input controllers that generate motion deltas."""
|
||||||
|
|
||||||
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0):
|
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0):
|
||||||
"""
|
|
||||||
Initialize the controller.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x_step_size: Base movement step size in meters
|
|
||||||
y_step_size: Base movement step size in meters
|
|
||||||
z_step_size: Base movement step size in meters
|
|
||||||
"""
|
|
||||||
self.x_step_size = x_step_size
|
self.x_step_size = x_step_size
|
||||||
self.y_step_size = y_step_size
|
self.y_step_size = y_step_size
|
||||||
self.z_step_size = z_step_size
|
self.z_step_size = z_step_size
|
||||||
self.running = True
|
self.running = True
|
||||||
self.episode_end_status = None # None, "success", or "failure"
|
self.episode_end_status = None
|
||||||
self.intervention_flag = False
|
self.intervention_flag = False
|
||||||
self.open_gripper_command = False
|
self.open_gripper_command = False
|
||||||
self.close_gripper_command = False
|
self.close_gripper_command = False
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""Start the controller and initialize resources."""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""Stop the controller and release resources."""
|
pass
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_deltas(self):
|
def get_deltas(self):
|
||||||
"""Get the current movement deltas (dx, dy, dz) in meters."""
|
|
||||||
return 0.0, 0.0, 0.0
|
return 0.0, 0.0, 0.0
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
"""Update controller state - call this once per frame."""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
"""Support for use in 'with' statements."""
|
|
||||||
self.start()
|
self.start()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
"""Ensure resources are released when exiting 'with' block."""
|
|
||||||
self.stop()
|
self.stop()
|
||||||
|
|
||||||
def get_episode_end_status(self):
|
def get_episode_end_status(self):
|
||||||
"""
|
|
||||||
Get the current episode end status.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None if episode should continue, "success" or "failure" otherwise
|
|
||||||
"""
|
|
||||||
status = self.episode_end_status
|
status = self.episode_end_status
|
||||||
self.episode_end_status = None # Reset after reading
|
self.episode_end_status = None
|
||||||
return status
|
return status
|
||||||
|
|
||||||
def should_intervene(self):
|
def should_intervene(self):
|
||||||
"""Return True if intervention flag was set."""
|
|
||||||
return self.intervention_flag
|
return self.intervention_flag
|
||||||
|
|
||||||
def gripper_command(self):
|
def gripper_command(self):
|
||||||
"""Return the current gripper command."""
|
|
||||||
if self.open_gripper_command == self.close_gripper_command:
|
if self.open_gripper_command == self.close_gripper_command:
|
||||||
return "stay"
|
return "stay"
|
||||||
elif self.open_gripper_command:
|
elif self.open_gripper_command:
|
||||||
@@ -102,14 +83,14 @@ class KeyboardController(InputController):
|
|||||||
"backward_y": False,
|
"backward_y": False,
|
||||||
"forward_z": False,
|
"forward_z": False,
|
||||||
"backward_z": False,
|
"backward_z": False,
|
||||||
"quit": False,
|
|
||||||
"success": False,
|
"success": False,
|
||||||
"failure": False,
|
"failure": False,
|
||||||
|
"intervention": False,
|
||||||
|
"rerecord": False,
|
||||||
}
|
}
|
||||||
self.listener = None
|
self.listener = None
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""Start the keyboard listener."""
|
|
||||||
from pynput import keyboard
|
from pynput import keyboard
|
||||||
|
|
||||||
def on_press(key):
|
def on_press(key):
|
||||||
@@ -126,16 +107,21 @@ class KeyboardController(InputController):
|
|||||||
self.key_states["backward_z"] = True
|
self.key_states["backward_z"] = True
|
||||||
elif key == keyboard.Key.shift_r:
|
elif key == keyboard.Key.shift_r:
|
||||||
self.key_states["forward_z"] = True
|
self.key_states["forward_z"] = True
|
||||||
elif key == keyboard.Key.esc:
|
elif key == keyboard.Key.ctrl_r:
|
||||||
self.key_states["quit"] = True
|
self.open_gripper_command = True
|
||||||
self.running = False
|
elif key == keyboard.Key.ctrl_l:
|
||||||
return False
|
self.close_gripper_command = True
|
||||||
elif key == keyboard.Key.enter:
|
elif key == keyboard.Key.enter:
|
||||||
self.key_states["success"] = True
|
self.key_states["success"] = True
|
||||||
self.episode_end_status = TeleopEvents.SUCCESS
|
self.episode_end_status = TeleopEvents.SUCCESS
|
||||||
elif key == keyboard.Key.backspace:
|
elif key == keyboard.Key.esc:
|
||||||
self.key_states["failure"] = True
|
self.key_states["failure"] = True
|
||||||
self.episode_end_status = TeleopEvents.FAILURE
|
self.episode_end_status = TeleopEvents.FAILURE
|
||||||
|
elif key == keyboard.Key.space:
|
||||||
|
self.key_states["intervention"] = not self.key_states["intervention"]
|
||||||
|
elif hasattr(key, "char") and key.char == "r":
|
||||||
|
self.key_states["rerecord"] = True
|
||||||
|
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -153,10 +139,10 @@ class KeyboardController(InputController):
|
|||||||
self.key_states["backward_z"] = False
|
self.key_states["backward_z"] = False
|
||||||
elif key == keyboard.Key.shift_r:
|
elif key == keyboard.Key.shift_r:
|
||||||
self.key_states["forward_z"] = False
|
self.key_states["forward_z"] = False
|
||||||
elif key == keyboard.Key.enter:
|
elif key == keyboard.Key.ctrl_r:
|
||||||
self.key_states["success"] = False
|
self.open_gripper_command = False
|
||||||
elif key == keyboard.Key.backspace:
|
elif key == keyboard.Key.ctrl_l:
|
||||||
self.key_states["failure"] = False
|
self.close_gripper_command = False
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -165,18 +151,18 @@ class KeyboardController(InputController):
|
|||||||
|
|
||||||
print("Keyboard controls:")
|
print("Keyboard controls:")
|
||||||
print(" Arrow keys: Move in X-Y plane")
|
print(" Arrow keys: Move in X-Y plane")
|
||||||
print(" Shift and Shift_R: Move in Z axis")
|
print(" Shift / Shift_R: Move in Z axis")
|
||||||
|
print(" Ctrl_R / Ctrl_L: Open / Close gripper")
|
||||||
|
print(" Space: Toggle intervention")
|
||||||
print(" Enter: End episode with SUCCESS")
|
print(" Enter: End episode with SUCCESS")
|
||||||
print(" Backspace: End episode with FAILURE")
|
print(" Esc: End episode with FAILURE")
|
||||||
print(" ESC: Exit")
|
print(" R: Rerecord episode")
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""Stop the keyboard listener."""
|
|
||||||
if self.listener and self.listener.is_alive():
|
if self.listener and self.listener.is_alive():
|
||||||
self.listener.stop()
|
self.listener.stop()
|
||||||
|
|
||||||
def get_deltas(self):
|
def get_deltas(self):
|
||||||
"""Get the current movement deltas from keyboard state."""
|
|
||||||
delta_x = delta_y = delta_z = 0.0
|
delta_x = delta_y = delta_z = 0.0
|
||||||
|
|
||||||
if self.key_states["forward_x"]:
|
if self.key_states["forward_x"]:
|
||||||
@@ -194,18 +180,58 @@ class KeyboardController(InputController):
|
|||||||
|
|
||||||
return delta_x, delta_y, delta_z
|
return delta_x, delta_y, delta_z
|
||||||
|
|
||||||
|
def should_intervene(self):
|
||||||
|
return self.key_states["intervention"]
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
for key in self.key_states:
|
||||||
|
self.key_states[key] = False
|
||||||
|
|
||||||
|
|
||||||
class GamepadController(InputController):
|
class GamepadController(InputController):
|
||||||
"""Generate motion deltas from gamepad input."""
|
"""Generate motion deltas from gamepad input using pygame.
|
||||||
|
|
||||||
|
Matches gym-hil button/axis conventions for Linux gamepads, including
|
||||||
|
Xbox mappings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Face buttons (same across most controllers on Linux)
|
||||||
|
BUTTON_A = 0
|
||||||
|
BUTTON_B = 1
|
||||||
|
BUTTON_X = 2
|
||||||
|
BUTTON_Y = 3
|
||||||
|
BUTTON_LB = 4
|
||||||
|
BUTTON_RB = 5
|
||||||
|
# Stick axes
|
||||||
|
AXIS_LEFT_X = 0
|
||||||
|
AXIS_LEFT_Y = 1
|
||||||
|
AXIS_RIGHT_X = 2
|
||||||
|
AXIS_RIGHT_Y = 3
|
||||||
|
|
||||||
|
# Default trigger buttons
|
||||||
|
BUTTON_LT = 6
|
||||||
|
BUTTON_RT = 7
|
||||||
|
|
||||||
|
# Xbox (gym-hil mapping on Linux)
|
||||||
|
XBOX_BUTTON_LT = 9
|
||||||
|
XBOX_BUTTON_RT = 10
|
||||||
|
|
||||||
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0, deadzone=0.1):
|
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0, deadzone=0.1):
|
||||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||||
self.deadzone = deadzone
|
self.deadzone = deadzone
|
||||||
self.joystick = None
|
self.joystick = None
|
||||||
self.intervention_flag = False
|
self.intervention_flag = False
|
||||||
|
self.is_xbox = False
|
||||||
|
self._xbox360_profile = False
|
||||||
|
self._invert_left_x = False
|
||||||
|
self._invert_left_y = True
|
||||||
|
self._invert_right_y = True
|
||||||
|
|
||||||
|
def _detect_xbox(self, name):
|
||||||
|
name_lower = name.lower()
|
||||||
|
return any(tag in name_lower for tag in ["xbox", "microsoft", "x-box"])
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""Initialize pygame and the gamepad."""
|
|
||||||
import pygame
|
import pygame
|
||||||
|
|
||||||
pygame.init()
|
pygame.init()
|
||||||
@@ -218,18 +244,35 @@ class GamepadController(InputController):
|
|||||||
|
|
||||||
self.joystick = pygame.joystick.Joystick(0)
|
self.joystick = pygame.joystick.Joystick(0)
|
||||||
self.joystick.init()
|
self.joystick.init()
|
||||||
logging.info(f"Initialized gamepad: {self.joystick.get_name()}")
|
joystick_name = self.joystick.get_name()
|
||||||
|
self.is_xbox = self._detect_xbox(joystick_name)
|
||||||
|
self._xbox360_profile = joystick_name == "Xbox 360 Controller"
|
||||||
|
if self._xbox360_profile:
|
||||||
|
# gym-hil "Xbox 360 Controller" profile
|
||||||
|
self.AXIS_RIGHT_X = 3
|
||||||
|
self.AXIS_RIGHT_Y = 4
|
||||||
|
self.BUTTON_LT = self.XBOX_BUTTON_LT
|
||||||
|
self.BUTTON_RT = self.XBOX_BUTTON_RT
|
||||||
|
self._invert_left_x = True
|
||||||
|
else:
|
||||||
|
# gym-hil default profile
|
||||||
|
self.AXIS_RIGHT_X = 2
|
||||||
|
self.AXIS_RIGHT_Y = 3
|
||||||
|
self.BUTTON_LT = 6
|
||||||
|
self.BUTTON_RT = 7
|
||||||
|
self._invert_left_x = False
|
||||||
|
logging.info(f"Initialized gamepad: {joystick_name} (xbox={self.is_xbox})")
|
||||||
|
|
||||||
print("Gamepad controls:")
|
print("Gamepad controls:")
|
||||||
print(" Left analog stick: Move in X-Y plane")
|
print(" Left analog stick: Move in X-Y plane")
|
||||||
print(" Right analog stick (vertical): Move in Z axis")
|
print(" Right analog stick (vertical): Move in Z axis")
|
||||||
print(" B/Circle button: Exit")
|
print(" RB: Intervention toggle")
|
||||||
print(" Y/Triangle button: End episode with SUCCESS")
|
print(" LT / RT: Close / Open gripper")
|
||||||
print(" A/Cross button: End episode with FAILURE")
|
print(" Y: End episode with SUCCESS")
|
||||||
print(" X/Square button: Rerecord episode")
|
print(" A: End episode with FAILURE")
|
||||||
|
print(" X: Rerecord episode")
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""Clean up pygame resources."""
|
|
||||||
import pygame
|
import pygame
|
||||||
|
|
||||||
if pygame.joystick.get_init():
|
if pygame.joystick.get_init():
|
||||||
@@ -239,67 +282,56 @@ class GamepadController(InputController):
|
|||||||
pygame.quit()
|
pygame.quit()
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
"""Process pygame events to get fresh gamepad readings."""
|
|
||||||
import pygame
|
import pygame
|
||||||
|
|
||||||
for event in pygame.event.get():
|
for event in pygame.event.get():
|
||||||
if event.type == pygame.JOYBUTTONDOWN:
|
if event.type == pygame.JOYBUTTONDOWN:
|
||||||
if event.button == 3:
|
if event.button == self.BUTTON_Y:
|
||||||
self.episode_end_status = TeleopEvents.SUCCESS
|
self.episode_end_status = TeleopEvents.SUCCESS
|
||||||
# A button (1) for failure
|
elif event.button == self.BUTTON_A:
|
||||||
elif event.button == 1:
|
|
||||||
self.episode_end_status = TeleopEvents.FAILURE
|
self.episode_end_status = TeleopEvents.FAILURE
|
||||||
# X button (0) for rerecord
|
elif event.button == self.BUTTON_X:
|
||||||
elif event.button == 0:
|
|
||||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||||
|
elif event.button == self.BUTTON_LT:
|
||||||
# RB button (6) for closing gripper
|
|
||||||
elif event.button == 6:
|
|
||||||
self.close_gripper_command = True
|
self.close_gripper_command = True
|
||||||
|
elif event.button == self.BUTTON_RT:
|
||||||
# LT button (7) for opening gripper
|
|
||||||
elif event.button == 7:
|
|
||||||
self.open_gripper_command = True
|
self.open_gripper_command = True
|
||||||
|
|
||||||
# Reset episode status on button release
|
|
||||||
elif event.type == pygame.JOYBUTTONUP:
|
elif event.type == pygame.JOYBUTTONUP:
|
||||||
if event.button in [0, 2, 3]:
|
if event.button in [self.BUTTON_Y, self.BUTTON_A, self.BUTTON_X]:
|
||||||
self.episode_end_status = None
|
self.episode_end_status = None
|
||||||
|
elif event.button == self.BUTTON_LT:
|
||||||
elif event.button == 6:
|
|
||||||
self.close_gripper_command = False
|
self.close_gripper_command = False
|
||||||
|
elif event.button == self.BUTTON_RT:
|
||||||
elif event.button == 7:
|
|
||||||
self.open_gripper_command = False
|
self.open_gripper_command = False
|
||||||
|
|
||||||
# Check for RB button (typically button 5) for intervention flag
|
if self.joystick.get_button(self.BUTTON_RB):
|
||||||
if self.joystick.get_button(5):
|
|
||||||
self.intervention_flag = True
|
self.intervention_flag = True
|
||||||
else:
|
else:
|
||||||
self.intervention_flag = False
|
self.intervention_flag = False
|
||||||
|
|
||||||
def get_deltas(self):
|
def get_deltas(self):
|
||||||
"""Get the current movement deltas from gamepad state."""
|
|
||||||
import pygame
|
import pygame
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Read joystick axes
|
x_input = self.joystick.get_axis(self.AXIS_LEFT_X)
|
||||||
# Left stick X and Y (typically axes 0 and 1)
|
y_input = self.joystick.get_axis(self.AXIS_LEFT_Y)
|
||||||
y_input = self.joystick.get_axis(0) # Up/Down (often inverted)
|
z_input = self.joystick.get_axis(self.AXIS_RIGHT_Y)
|
||||||
x_input = self.joystick.get_axis(1) # Left/Right
|
|
||||||
|
|
||||||
# Right stick Y (typically axis 3 or 4)
|
|
||||||
z_input = self.joystick.get_axis(3) # Up/Down for Z
|
|
||||||
|
|
||||||
# Apply deadzone to avoid drift
|
|
||||||
x_input = 0 if abs(x_input) < self.deadzone else x_input
|
x_input = 0 if abs(x_input) < self.deadzone else x_input
|
||||||
y_input = 0 if abs(y_input) < self.deadzone else y_input
|
y_input = 0 if abs(y_input) < self.deadzone else y_input
|
||||||
z_input = 0 if abs(z_input) < self.deadzone else z_input
|
z_input = 0 if abs(z_input) < self.deadzone else z_input
|
||||||
|
|
||||||
# Calculate deltas (note: may need to invert axes depending on controller)
|
if self._invert_left_x:
|
||||||
delta_x = -x_input * self.x_step_size # Forward/backward
|
x_input = -x_input
|
||||||
delta_y = -y_input * self.y_step_size # Left/right
|
if self._invert_left_y:
|
||||||
delta_z = -z_input * self.z_step_size # Up/down
|
y_input = -y_input
|
||||||
|
if self._invert_right_y:
|
||||||
|
z_input = -z_input
|
||||||
|
|
||||||
|
delta_x = y_input * self.y_step_size
|
||||||
|
delta_y = x_input * self.x_step_size
|
||||||
|
delta_z = z_input * self.z_step_size
|
||||||
|
|
||||||
return delta_x, delta_y, delta_z
|
return delta_x, delta_y, delta_z
|
||||||
|
|
||||||
@@ -309,7 +341,15 @@ class GamepadController(InputController):
|
|||||||
|
|
||||||
|
|
||||||
class GamepadControllerHID(InputController):
|
class GamepadControllerHID(InputController):
|
||||||
"""Generate motion deltas from gamepad input using HIDAPI."""
|
"""Generate motion deltas from gamepad input using HIDAPI.
|
||||||
|
|
||||||
|
Supports auto-detection of controller type for correct HID report parsing.
|
||||||
|
Currently supported: Logitech RumblePad 2, 8BitDo Ultimate 2C Wireless.
|
||||||
|
"""
|
||||||
|
|
||||||
|
CONTROLLER_LOGITECH = "logitech"
|
||||||
|
CONTROLLER_8BITDO = "8bitdo"
|
||||||
|
CONTROLLER_UNKNOWN = "unknown"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -318,36 +358,26 @@ class GamepadControllerHID(InputController):
|
|||||||
z_step_size=1.0,
|
z_step_size=1.0,
|
||||||
deadzone=0.1,
|
deadzone=0.1,
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
Initialize the HID gamepad controller.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
step_size: Base movement step size in meters
|
|
||||||
z_scale: Scaling factor for Z-axis movement
|
|
||||||
deadzone: Joystick deadzone to prevent drift
|
|
||||||
"""
|
|
||||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||||
self.deadzone = deadzone
|
self.deadzone = deadzone
|
||||||
self.device = None
|
self.device = None
|
||||||
self.device_info = None
|
self.device_info = None
|
||||||
|
self.controller_type = self.CONTROLLER_UNKNOWN
|
||||||
|
|
||||||
# Movement values (normalized from -1.0 to 1.0)
|
|
||||||
self.left_x = 0.0
|
self.left_x = 0.0
|
||||||
self.left_y = 0.0
|
self.left_y = 0.0
|
||||||
self.right_x = 0.0
|
self.right_x = 0.0
|
||||||
self.right_y = 0.0
|
self.right_y = 0.0
|
||||||
|
|
||||||
# Button states
|
|
||||||
self.buttons = {}
|
self.buttons = {}
|
||||||
|
|
||||||
def find_device(self):
|
def find_device(self):
|
||||||
"""Look for the gamepad device by vendor and product ID."""
|
|
||||||
import hid
|
import hid
|
||||||
|
|
||||||
devices = hid.enumerate()
|
devices = hid.enumerate()
|
||||||
for device in devices:
|
for device in devices:
|
||||||
device_name = device["product_string"]
|
device_name = device["product_string"]
|
||||||
if any(controller in device_name for controller in ["Logitech", "Xbox", "PS4", "PS5"]):
|
if any(controller in device_name for controller in ["Logitech", "Xbox", "PS4", "PS5", "8BitDo"]):
|
||||||
return device
|
return device
|
||||||
|
|
||||||
logging.error(
|
logging.error(
|
||||||
@@ -355,8 +385,15 @@ class GamepadControllerHID(InputController):
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _detect_controller_type(self, product_string):
|
||||||
|
product = product_string.lower() if product_string else ""
|
||||||
|
if "8bitdo" in product:
|
||||||
|
return self.CONTROLLER_8BITDO
|
||||||
|
elif "logitech" in product:
|
||||||
|
return self.CONTROLLER_LOGITECH
|
||||||
|
return self.CONTROLLER_UNKNOWN
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""Connect to the gamepad using HIDAPI."""
|
|
||||||
import hid
|
import hid
|
||||||
|
|
||||||
self.device_info = self.find_device()
|
self.device_info = self.find_device()
|
||||||
@@ -374,12 +411,22 @@ class GamepadControllerHID(InputController):
|
|||||||
product = self.device.get_product_string()
|
product = self.device.get_product_string()
|
||||||
logging.info(f"Connected to {manufacturer} {product}")
|
logging.info(f"Connected to {manufacturer} {product}")
|
||||||
|
|
||||||
logging.info("Gamepad controls (HID mode):")
|
self.controller_type = self._detect_controller_type(product)
|
||||||
logging.info(" Left analog stick: Move in X-Y plane")
|
logging.info(f"Detected controller type: {self.controller_type}")
|
||||||
logging.info(" Right analog stick: Move in Z axis (vertical)")
|
|
||||||
logging.info(" Button 1/B/Circle: Exit")
|
print("Gamepad controls (HID mode):")
|
||||||
logging.info(" Button 2/A/Cross: End episode with SUCCESS")
|
print(" Left analog stick: Move in X-Y plane")
|
||||||
logging.info(" Button 3/X/Square: End episode with FAILURE")
|
print(" Right analog stick: Move in Z axis (vertical)")
|
||||||
|
print(" RB: Intervention toggle")
|
||||||
|
if self.controller_type == self.CONTROLLER_8BITDO:
|
||||||
|
print(" L3 (left stick click): Close gripper")
|
||||||
|
print(" R3 (right stick click): Open gripper")
|
||||||
|
else:
|
||||||
|
print(" LT: Close gripper")
|
||||||
|
print(" RT: Open gripper")
|
||||||
|
print(" Y: End episode with SUCCESS")
|
||||||
|
print(" X: End episode with FAILURE")
|
||||||
|
print(" A: Rerecord episode")
|
||||||
|
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
logging.error(f"Error opening gamepad: {e}")
|
logging.error(f"Error opening gamepad: {e}")
|
||||||
@@ -387,74 +434,124 @@ class GamepadControllerHID(InputController):
|
|||||||
self.running = False
|
self.running = False
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""Close the HID device connection."""
|
|
||||||
if self.device:
|
if self.device:
|
||||||
self.device.close()
|
self.device.close()
|
||||||
self.device = None
|
self.device = None
|
||||||
|
|
||||||
def update(self):
|
def update(self):
|
||||||
"""
|
"""Read the device several times to drain the HID buffer and get a stable reading."""
|
||||||
Read and process the latest gamepad data.
|
|
||||||
Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading
|
|
||||||
"""
|
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
self._update()
|
self._update()
|
||||||
|
|
||||||
def _update(self):
|
def _update(self):
|
||||||
"""Read and process the latest gamepad data."""
|
|
||||||
if not self.device or not self.running:
|
if not self.device or not self.running:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Read data from the gamepad
|
|
||||||
data = self.device.read(64)
|
data = self.device.read(64)
|
||||||
# Interpret gamepad data - this will vary by controller model
|
if not data:
|
||||||
# These offsets are for the Logitech RumblePad 2
|
return
|
||||||
if data and len(data) >= 8:
|
|
||||||
# Normalize joystick values from 0-255 to -1.0-1.0
|
|
||||||
self.left_y = (data[1] - 128) / 128.0
|
|
||||||
self.left_x = (data[2] - 128) / 128.0
|
|
||||||
self.right_x = (data[3] - 128) / 128.0
|
|
||||||
self.right_y = (data[4] - 128) / 128.0
|
|
||||||
|
|
||||||
# Apply deadzone
|
if self.controller_type == self.CONTROLLER_8BITDO:
|
||||||
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
|
self._parse_8bitdo(data)
|
||||||
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
|
else:
|
||||||
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
|
self._parse_logitech(data)
|
||||||
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
|
|
||||||
|
|
||||||
# Parse button states (byte 5 in the Logitech RumblePad 2)
|
|
||||||
buttons = data[5]
|
|
||||||
|
|
||||||
# Check if RB is pressed then the intervention flag should be set
|
|
||||||
self.intervention_flag = data[6] in [2, 6, 10, 14]
|
|
||||||
|
|
||||||
# Check if RT is pressed
|
|
||||||
self.open_gripper_command = data[6] in [8, 10, 12]
|
|
||||||
|
|
||||||
# Check if LT is pressed
|
|
||||||
self.close_gripper_command = data[6] in [4, 6, 12]
|
|
||||||
|
|
||||||
# Check if Y/Triangle button (bit 7) is pressed for saving
|
|
||||||
# Check if X/Square button (bit 5) is pressed for failure
|
|
||||||
# Check if A/Cross button (bit 4) is pressed for rerecording
|
|
||||||
if buttons & 1 << 7:
|
|
||||||
self.episode_end_status = TeleopEvents.SUCCESS
|
|
||||||
elif buttons & 1 << 5:
|
|
||||||
self.episode_end_status = TeleopEvents.FAILURE
|
|
||||||
elif buttons & 1 << 4:
|
|
||||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
|
||||||
else:
|
|
||||||
self.episode_end_status = None
|
|
||||||
|
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
logging.error(f"Error reading from gamepad: {e}")
|
logging.error(f"Error reading from gamepad: {e}")
|
||||||
|
|
||||||
|
def _apply_deadzone(self):
|
||||||
|
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
|
||||||
|
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
|
||||||
|
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
|
||||||
|
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
|
||||||
|
|
||||||
|
def _parse_8bitdo(self, data):
|
||||||
|
"""Parse HID report from 8BitDo Ultimate 2C Wireless (Bluetooth on macOS).
|
||||||
|
|
||||||
|
11-byte report layout:
|
||||||
|
byte[0]: Report ID (0x01)
|
||||||
|
byte[1]: D-pad hat switch (0=N, 2=E, 5=S, 6=W, 15=neutral)
|
||||||
|
byte[2]: Left Stick X (0=left, 127=center, 255=right)
|
||||||
|
byte[3]: Left Stick Y (0=up, 127=center, 255=down)
|
||||||
|
byte[4]: Right Stick X (inverted: 255=left, 0=right)
|
||||||
|
byte[5]: Right Stick Y (0=up, 127=center, 255=down)
|
||||||
|
byte[6]: RT analog trigger (0-255)
|
||||||
|
byte[7]: LT analog trigger (0-255)
|
||||||
|
byte[8]: Buttons -- bit0=A, bit1=B, bit3=X, bit4=Y, bit6=LB, bit7=RB
|
||||||
|
byte[9]: System -- bit0=LT(digital), bit1=RT(digital), bit3=Select,
|
||||||
|
bit4=Start, bit5=L3, bit6=R3
|
||||||
|
byte[10]: Unused
|
||||||
|
"""
|
||||||
|
if len(data) < 11:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.left_x = (data[2] - 127) / 128.0
|
||||||
|
self.left_y = (data[3] - 127) / 128.0
|
||||||
|
self.right_x = -(data[4] - 127) / 128.0
|
||||||
|
self.right_y = (data[5] - 127) / 128.0
|
||||||
|
|
||||||
|
self._apply_deadzone()
|
||||||
|
|
||||||
|
buttons = data[8]
|
||||||
|
|
||||||
|
# RB (bit 7) = intervention
|
||||||
|
self.intervention_flag = bool(buttons & 0x80)
|
||||||
|
|
||||||
|
# Stick clicks for gripper: R3 (byte[9] bit6) = open, L3 (byte[9] bit5) = close
|
||||||
|
system = data[9]
|
||||||
|
self.open_gripper_command = bool(system & 0x40) # R3
|
||||||
|
self.close_gripper_command = bool(system & 0x20) # L3
|
||||||
|
|
||||||
|
# Y (bit 4) = success, X (bit 3) = failure, A (bit 0) = rerecord
|
||||||
|
if buttons & 0x10:
|
||||||
|
self.episode_end_status = TeleopEvents.SUCCESS
|
||||||
|
elif buttons & 0x08:
|
||||||
|
self.episode_end_status = TeleopEvents.FAILURE
|
||||||
|
elif buttons & 0x01:
|
||||||
|
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||||
|
else:
|
||||||
|
self.episode_end_status = None
|
||||||
|
|
||||||
|
def _parse_logitech(self, data):
|
||||||
|
"""Parse HID report from Logitech RumblePad 2 (and similar Logitech gamepads).
|
||||||
|
|
||||||
|
Report layout (8+ bytes):
|
||||||
|
byte[1]: Left Stick X (0-255, center=128)
|
||||||
|
byte[2]: Left Stick Y (0-255, center=128)
|
||||||
|
byte[3]: Right Stick X (0-255, center=128)
|
||||||
|
byte[4]: Right Stick Y (0-255, center=128)
|
||||||
|
byte[5]: Face buttons bitmask
|
||||||
|
byte[6]: Shoulder/trigger buttons bitmask
|
||||||
|
"""
|
||||||
|
if len(data) < 8:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.left_x = (data[1] - 128) / 128.0
|
||||||
|
self.left_y = (data[2] - 128) / 128.0
|
||||||
|
self.right_x = (data[3] - 128) / 128.0
|
||||||
|
self.right_y = (data[4] - 128) / 128.0
|
||||||
|
|
||||||
|
self._apply_deadzone()
|
||||||
|
|
||||||
|
buttons = data[5]
|
||||||
|
|
||||||
|
self.intervention_flag = data[6] in [2, 6, 10, 14]
|
||||||
|
self.open_gripper_command = data[6] in [8, 10, 12]
|
||||||
|
self.close_gripper_command = data[6] in [4, 6, 12]
|
||||||
|
|
||||||
|
if buttons & 1 << 7:
|
||||||
|
self.episode_end_status = TeleopEvents.SUCCESS
|
||||||
|
elif buttons & 1 << 5:
|
||||||
|
self.episode_end_status = TeleopEvents.FAILURE
|
||||||
|
elif buttons & 1 << 4:
|
||||||
|
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||||
|
else:
|
||||||
|
self.episode_end_status = None
|
||||||
|
|
||||||
def get_deltas(self):
|
def get_deltas(self):
|
||||||
"""Get the current movement deltas from gamepad state."""
|
delta_x = -self.left_y * self.x_step_size
|
||||||
# Calculate deltas - invert as needed based on controller orientation
|
delta_y = -self.left_x * self.y_step_size
|
||||||
delta_x = -self.left_x * self.x_step_size # Forward/backward
|
delta_z = -self.right_y * self.z_step_size
|
||||||
delta_y = -self.left_y * self.y_step_size # Left/right
|
|
||||||
delta_z = -self.right_y * self.z_step_size # Up/down
|
|
||||||
|
|
||||||
return delta_x, delta_y, delta_z
|
return delta_x, delta_y, delta_z
|
||||||
|
|||||||
Reference in New Issue
Block a user