Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-31 02:41:24 +00:00)
Compare commits: feat/umi ... feat/datas (11 commits)
| Author | SHA1 | Date | |
|---|---|---|---|
| | 026e4c937d | | |
| | efe8c09fca | | |
| | 58eecad8a4 | | |
| | c7fd1f47d1 | | |
| | 6370949e5c | | |
| | 46b97da168 | | |
| | e69be57a66 | | |
| | c3a6ddb668 | | |
| | dad661012d | | |
| | 219c08ccb8 | | |
| | 06385902df | | |
2
.gitignore
vendored
@@ -173,5 +173,7 @@ outputs/
|
||||
|
||||
# Dev folders
|
||||
.cache/*
|
||||
*.stl
|
||||
*.urdf
|
||||
*.xml
|
||||
*.part
|
||||
|
||||
10
README.md
@@ -100,11 +100,11 @@ lerobot-train \
|
||||
--dataset.repo_id=lerobot/aloha_mobile_cabinet
|
||||
```
|
||||
|
||||
| Category | Models |
|
||||
| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
| Category | Models |
|
||||
| -------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) |
|
||||
| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) |
|
||||
| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) |
|
||||
|
||||
As with hardware, you can implement your own policy, leverage LeRobot's data collection, training, and visualization tools, and share your model on the HF Hub.
|
||||
|
||||
|
||||
@@ -19,10 +19,6 @@
|
||||
title: Multi GPU training
|
||||
- local: peft_training
|
||||
title: Training with PEFT (e.g., LoRA)
|
||||
- local: rename_map
|
||||
title: Using Rename Map and Empty Cameras
|
||||
- local: umi_pi0_relative_ee
|
||||
title: UMI Data with pi0 Relative EE Actions
|
||||
title: "Tutorials"
|
||||
- sections:
|
||||
- local: lerobot-dataset-v3
|
||||
@@ -51,8 +47,6 @@
|
||||
title: NVIDIA GR00T N1.5
|
||||
- local: xvla
|
||||
title: X-VLA
|
||||
- local: multi_task_dit
|
||||
title: Multitask DiT Policy
|
||||
- local: walloss
|
||||
title: WALL-OSS
|
||||
title: "Policies"
|
||||
@@ -89,8 +83,6 @@
|
||||
title: Processors for Robots and Teleoperators
|
||||
- local: env_processor
|
||||
title: Environment Processors
|
||||
- local: action_representations
|
||||
title: Action Representations
|
||||
title: "Robot Processors"
|
||||
- sections:
|
||||
- local: so101
|
||||
|
||||
@@ -1,238 +0,0 @@
|
||||
# Action Representations
|
||||
|
||||
This guide explains the different ways robot actions can be represented in LeRobot, how they relate to each other, and when to use each one.
|
||||
|
||||
## Joint Space vs End-Effector Space
|
||||
|
||||
Before discussing action representations, it helps to understand the two coordinate spaces actions can live in.
|
||||
|
||||
### Joint Space
|
||||
|
||||
Joint-space actions directly specify target positions for each motor. For a 6-DOF arm with a gripper, a joint-space action might look like:
|
||||
|
||||
```
|
||||
action = [shoulder_pan: 45.0, shoulder_lift: -20.0, elbow: -30.0, wrist_pitch: 10.0, wrist_roll: 0.0, wrist_yaw: 5.0, gripper: 0.8]
|
||||
```
|
||||
|
||||
Joint space is the default in LeRobot. It is simple, requires no kinematics model, and maps directly to motor commands. Most beginner setups (SO-100, Koch) use joint-space actions.
|
||||
|
||||
### End-Effector (EE) Space
|
||||
|
||||
End-effector-space actions specify the desired position and orientation of the robot's tool tip (gripper) in Cartesian coordinates:
|
||||
|
||||
```
|
||||
action = [x: 0.25, y: -0.10, z: 0.15, wx: 0.0, wy: 0.0, wz: 0.1, gripper: 0.8]
|
||||
```
|
||||
|
||||
EE space is more intuitive for tasks like pick-and-place because it directly describes where the gripper should go, but it requires a kinematics model (URDF) to convert between EE poses and joint angles.
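To make the joint/EE relationship concrete, here is a minimal sketch of forward and inverse kinematics for a toy planar arm with two links. It is not LeRobot's `RobotKinematics` and uses no URDF; the link lengths, function names, and the choice of the `q2 >= 0` solution branch are all assumptions made for illustration.

```python
import math

# Toy planar arm with two revolute joints. Link lengths are illustrative
# assumptions, not values taken from any real robot or URDF.
LINK_1 = 0.12  # metres
LINK_2 = 0.10  # metres


def forward_kinematics(q1: float, q2: float) -> tuple[float, float]:
    """Joint angles in radians -> tool-tip position (x, y)."""
    x = LINK_1 * math.cos(q1) + LINK_2 * math.cos(q1 + q2)
    y = LINK_1 * math.sin(q1) + LINK_2 * math.sin(q1 + q2)
    return x, y


def inverse_kinematics(x: float, y: float) -> tuple[float, float]:
    """Tool-tip position -> one joint solution (the branch with q2 >= 0)."""
    cos_q2 = (x * x + y * y - LINK_1**2 - LINK_2**2) / (2 * LINK_1 * LINK_2)
    if not -1.0 <= cos_q2 <= 1.0:
        raise ValueError("target is outside the reachable workspace")
    q2 = math.acos(cos_q2)
    q1 = math.atan2(y, x) - math.atan2(
        LINK_2 * math.sin(q2), LINK_1 + LINK_2 * math.cos(q2)
    )
    return q1, q2


joints = (math.radians(30.0), math.radians(45.0))
ee_position = forward_kinematics(*joints)        # "where is my gripper?"
recovered = inverse_kinematics(*ee_position)     # "which joints put it there?"
```

A real arm has more joints than the pose has constraints, so IK generally has several solutions or none, which is why a full kinematics model is needed in practice.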
|
||||
|
||||
### Converting Between Spaces
|
||||
|
||||
LeRobot provides processor steps for converting between joint and EE spaces using forward and inverse kinematics. These are built on top of `RobotKinematics`, which loads a URDF model of your robot.
|
||||
|
||||
```python
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.robots.so_follower.robot_kinematic_processor import (
|
||||
ForwardKinematicsJointsToEE,
|
||||
InverseKinematicsEEToJoints,
|
||||
)
|
||||
|
||||
kinematics = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=["shoulder", "elbow", "wrist_pitch", "wrist_roll", "wrist_yaw"],
|
||||
)
|
||||
|
||||
# Joints → EE (for observations: "where is my gripper?")
|
||||
fk_step = ForwardKinematicsJointsToEE(kinematics=kinematics, motor_names=[...])
|
||||
|
||||
# EE → Joints (for actions: "move my gripper here")
|
||||
ik_step = InverseKinematicsEEToJoints(kinematics=kinematics, motor_names=[...])
|
||||
```
|
||||
|
||||
See [`examples/so100_to_so100_EE/`](https://github.com/huggingface/lerobot/tree/main/examples/so100_to_so100_EE) for a complete working example of recording, replaying, and evaluating with EE-space actions on an SO-100 arm.
|
||||
|
||||
## Absolute, Relative, and Delta Actions
|
||||
|
||||
Regardless of whether you work in joint space or EE space, the action values can be expressed in three different ways. The terminology follows [UMI (Chi et al., 2024)](https://arxiv.org/abs/2402.10329).
|
||||
|
||||
### Absolute Actions (LeRobot default)
|
||||
|
||||
Each action specifies the target position directly.
|
||||
|
||||
**Example** (joint space, chunk of 4):
|
||||
|
||||
```
|
||||
current_state = [45.0, -30.0, 10.0]
|
||||
|
||||
action_chunk = [
|
||||
[46.0, -29.0, 11.0], # go to 46, -29, 11
|
||||
[47.5, -27.0, 12.0], # go to 47.5, -27, 12
|
||||
[49.0, -25.0, 13.5], # go to 49, -25, 13.5
|
||||
[50.0, -24.0, 15.0], # go to 50, -24, 15
|
||||
]
|
||||
```
|
||||
|
||||
Each value is a target position in the robot's coordinate frame. Simple and direct, but requires a consistent global coordinate frame. This is the default in LeRobot.
|
||||
|
||||
### Relative Actions (used by OpenPI / pi0)
|
||||
|
||||
Each action in the chunk is an offset from the **current state at the moment of prediction**. All actions in the chunk share the same reference point:
|
||||
|
||||
```
|
||||
current_state = [45.0, -30.0, 10.0]
|
||||
|
||||
relative_chunk = [
|
||||
[1.0, 1.0, 1.0], # +1 from current → target 46, -29, 11
|
||||
[2.5, 3.0, 2.0], # +2.5 from current → target 47.5, -27, 12
|
||||
[4.0, 5.0, 3.5], # +4 from current → target 49, -25, 13.5
|
||||
[5.0, 6.0, 5.0], # +5 from current → target 50, -24, 15
|
||||
]
|
||||
```
|
||||
|
||||
The conversion is straightforward: `relative = absolute - current_state`. To recover absolute: `absolute = relative + current_state`.
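The conversion can be sketched in a few lines of NumPy, reusing the numbers from the example above. The helper names `to_relative` and `to_absolute` are hypothetical and are not LeRobot functions; the sketch only illustrates the arithmetic.

```python
import numpy as np

current_state = np.array([45.0, -30.0, 10.0])
absolute_chunk = np.array([
    [46.0, -29.0, 11.0],
    [47.5, -27.0, 12.0],
    [49.0, -25.0, 13.5],
    [50.0, -24.0, 15.0],
])


def to_relative(absolute: np.ndarray, state: np.ndarray) -> np.ndarray:
    """Subtract the same reference state from every action in the chunk."""
    return absolute - state  # state broadcasts over the chunk dimension


def to_absolute(relative: np.ndarray, state: np.ndarray) -> np.ndarray:
    """Add the reference state back to recover absolute targets."""
    return relative + state


relative_chunk = to_relative(absolute_chunk, current_state)
restored_chunk = to_absolute(relative_chunk, current_state)
```

Because the reference state is shared by the whole chunk, the round trip needs only that one state vector and no history of earlier actions.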
|
||||
|
||||
**Why use relative actions?** The model learns to predict offsets centered around zero, which is easier to normalize and leads to more stable training. Because every chunk references the same current state, there is no error accumulation across chunks.
|
||||
|
||||
### Delta Actions (sequential differences)
|
||||
|
||||
Each action is an offset from the **previous action** (or from the current state for the first step):
|
||||
|
||||
```
|
||||
current_state = [45.0, -30.0, 10.0]
|
||||
|
||||
delta_chunk = [
|
||||
[1.0, 1.0, 1.0], # current → 46, -29, 11
|
||||
[1.5, 2.0, 1.0], # previous action → 47.5, -27, 12
|
||||
[1.5, 2.0, 1.5], # previous action → 49, -25, 13.5
|
||||
[1.0, 1.0, 1.5], # previous action → 50, -24, 15
|
||||
]
|
||||
```
|
||||
|
||||
Here each step is relative to the one before it. To recover absolute positions you must sum all previous deltas, which means errors accumulate over time. UMI explicitly argues against this representation for this reason.
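The accumulation effect can be illustrated with a short NumPy sketch. The helper names are hypothetical, and the constant prediction bias of 0.1 is an assumption chosen only to make the drift visible.

```python
import numpy as np

current_state = np.array([45.0, -30.0, 10.0])
absolute_chunk = np.array([
    [46.0, -29.0, 11.0],
    [47.5, -27.0, 12.0],
    [49.0, -25.0, 13.5],
    [50.0, -24.0, 15.0],
])


def to_delta(absolute: np.ndarray, state: np.ndarray) -> np.ndarray:
    """Each step minus the previous one (the first step minus the state)."""
    previous = np.vstack([state, absolute[:-1]])
    return absolute - previous


def from_delta(delta: np.ndarray, state: np.ndarray) -> np.ndarray:
    """Recover absolute targets by summing all deltas up to each step."""
    return state + np.cumsum(delta, axis=0)


delta_chunk = to_delta(absolute_chunk, current_state)
restored_chunk = from_delta(delta_chunk, current_state)

# Assume every predicted value is off by the same small bias.
bias = 0.1
relative_chunk = absolute_chunk - current_state
drift_delta = from_delta(delta_chunk + bias, current_state) - absolute_chunk
drift_relative = (relative_chunk + bias + current_state) - absolute_chunk
```

Under this assumed bias, the error in `drift_delta` grows with each step of the chunk, while `drift_relative` stays at the size of the bias for every step.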
|
||||
|
||||
### Visual Comparison
|
||||
|
||||
The figure below (based on a figure from [UMI, Chi et al., 2024](https://arxiv.org/abs/2402.10329)) illustrates the key difference. With **relative trajectory**, every action in the chunk points back to the same origin (current state), so a new inference step cleanly resets the reference. With **delta**, each action depends on the previous one, so errors accumulate. **Absolute** actions require a consistent global coordinate frame.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/action_representations_umi.png"
|
||||
alt="Relative Trajectory as Action Representation (UMI, Chi et al., 2024)"
|
||||
width="85%"
|
||||
/>
|
||||
|
||||
## Using Relative Actions in LeRobot
|
||||
|
||||
LeRobot provides `RelativeActionsProcessorStep` to convert between absolute and relative actions inside the processor pipeline. This is how pi0, pi0.5, and pi0_fast support relative actions.
|
||||
|
||||
> **Note:** All pi models (pi0, pi0.5, pi0_fast) apply relative conversion _before_ normalization (`relative → normalize`), so the normalizer always sees relative values. This means **relative action stats are required** for all of them when training with `use_relative_actions=true`. In pi0_fast, the `RelativeActionsProcessorStep` only modifies the action; the state observation is unchanged, so `NormalizerProcessorStep` still runs before the state tokenizer and the tokenizer continues to receive normalized state as expected.
|
||||
|
||||
### How it works
|
||||
|
||||
During **training** (preprocessing), actions are converted from absolute to relative before the model sees them:
|
||||
|
||||
```
|
||||
raw absolute action → RelativeActionsProcessorStep → normalize → model
|
||||
```
|
||||
|
||||
During **inference** (postprocessing), model predictions are converted back to absolute before being sent to the robot:
|
||||
|
||||
```
|
||||
model output → unnormalize → AbsoluteActionsProcessorStep → robot
|
||||
```
|
||||
|
||||
The `AbsoluteActionsProcessorStep` reads the cached current state from its paired `RelativeActionsProcessorStep`, so the two must be wired together (handled automatically by the policy factory).
|
||||
|
||||
### Enabling relative actions for the pi family (pi0, pi0.5, pi0_fast)
|
||||
|
||||
**Step 1**: Precompute relative action statistics for your dataset:
|
||||
|
||||
```bash
|
||||
lerobot-edit-dataset \
|
||||
--repo_id your_dataset \
|
||||
--operation.type recompute_stats \
|
||||
--operation.relative_action true \
|
||||
--operation.chunk_size 50 \
|
||||
--operation.relative_exclude_joints "['gripper']"
|
||||
```
|
||||
|
||||
**Step 2**: Train with relative actions enabled:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi0 \
|
||||
--policy.use_relative_actions=true \
|
||||
--policy.relative_exclude_joints='["gripper"]'
|
||||
```
|
||||
|
||||
The `relative_exclude_joints` parameter specifies joints that should remain in absolute space. For example, gripper commands are typically binary (open/close) and don't benefit from relative encoding.
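One way to picture the exclusion is as a boolean mask over action dimensions. The sketch below is an assumption about the general idea, not the `RelativeActionsProcessorStep` implementation: the joint names, the exact-name matching, and the helper functions are all hypothetical.

```python
import numpy as np

# Illustrative action layout; a real dataset defines its own names.
joint_names = ["shoulder_pan", "elbow", "gripper"]
excluded = ["gripper"]


def relative_mask(names: list[str], excluded_names: list[str]) -> np.ndarray:
    """True for dimensions that are converted to relative offsets."""
    return np.array([name not in excluded_names for name in names])


def to_relative(absolute: np.ndarray, state: np.ndarray, mask: np.ndarray) -> np.ndarray:
    # Excluded dimensions subtract 0, so they stay absolute.
    return absolute - np.where(mask, state, 0.0)


def to_absolute(relative: np.ndarray, state: np.ndarray, mask: np.ndarray) -> np.ndarray:
    return relative + np.where(mask, state, 0.0)


mask = relative_mask(joint_names, excluded)
current_state = np.array([45.0, -30.0, 0.8])
absolute_chunk = np.array([
    [46.0, -29.0, 1.0],   # gripper fully open
    [47.5, -27.0, 0.0],   # gripper fully closed
])

relative_chunk = to_relative(absolute_chunk, current_state, mask)
restored_chunk = to_absolute(relative_chunk, current_state, mask)
```

In this sketch the arm joints become offsets from the current state, while the gripper column passes through unchanged in both directions.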
|
||||
|
||||
### Combining relative actions with RTC
|
||||
|
||||
[RTC](https://arxiv.org/abs/2506.07339) runs policy inference at high frequency and sends actions to the robot as they are predicted rather than waiting for a full chunk. Relative actions and RTC are fully compatible: because every chunk in relative mode references the **same** current state (captured at the start of inference), each predicted action in the chunk remains a valid offset even if the robot has already moved. No special handling is needed — `RelativeActionsProcessorStep` caches the state once per inference call and `AbsoluteActionsProcessorStep` applies it to every action in the streamed output.
|
||||
|
||||
### Combining relative actions with EE space
|
||||
|
||||
Relative actions work in both joint space and EE space. For example, if your dataset stores EE actions, relative encoding converts them to offsets from the current EE pose:
|
||||
|
||||
```
|
||||
current_ee_state = [x: 0.25, y: -0.10, z: 0.15, gripper: 0.8]
|
||||
|
||||
absolute_ee_chunk = [
|
||||
[0.26, -0.09, 0.16, 0.8],
|
||||
[0.28, -0.07, 0.18, 0.8],
|
||||
]
|
||||
|
||||
relative_ee_chunk = [
|
||||
[0.01, 0.01, 0.01, 0.0], # offset from current EE pose
|
||||
[0.03, 0.03, 0.03, 0.0], # offset from current EE pose
|
||||
]
|
||||
```
|
||||
|
||||
## Processing Pipeline Summary
|
||||
|
||||
Here is how the different processors compose. Each arrow is a processor step, and they can be chained in a `RobotProcessorPipeline` or `PolicyProcessorPipeline`:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────┐
|
||||
Action Space │ Joint Space ←──IK──→ EE Space │
|
||||
│ ForwardKinematicsJointsToEE │
|
||||
│ InverseKinematicsEEToJoints │
|
||||
└─────────────────────────────────────────┘
|
||||
|
||||
┌─────────────────────────────────────────┐
|
||||
State Derivation │ Action column ────→ State + Action │
|
||||
│ DeriveStateFromActionStep (pre only) │
|
||||
│ (UMI-style: state from action chunk) │
|
||||
└─────────────────────────────────────────┘
|
||||
|
||||
┌─────────────────────────────────────────┐
|
||||
Action Repr. │ Absolute ←────→ Relative │
|
||||
│ RelativeActionsProcessorStep (pre) │
|
||||
│ AbsoluteActionsProcessorStep (post) │
|
||||
└─────────────────────────────────────────┘
|
||||
|
||||
┌─────────────────────────────────────────┐
|
||||
State Repr. │ Absolute ────→ Relative │
|
||||
│ RelativeStateProcessorStep (pre only) │
|
||||
└─────────────────────────────────────────┘
|
||||
|
||||
┌─────────────────────────────────────────┐
|
||||
Normalization │ Raw ←────→ Normalized │
|
||||
│ NormalizerProcessorStep (pre) │
|
||||
│ UnnormalizerProcessorStep (post) │
|
||||
└─────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
A typical training preprocessor might chain: `raw absolute joint actions → relative → normalize`. A typical inference postprocessor: `unnormalize → absolute → (optionally IK to joints)`.
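The ordering of these two chains can be sketched as a round trip in NumPy. This is a simplified illustration, not the LeRobot processor classes: the function names are hypothetical, and mean/std normalization is assumed here even though a policy may be configured with a different normalization mode.

```python
import numpy as np


def preprocess(absolute: np.ndarray, state: np.ndarray,
               mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Training side: absolute -> relative -> normalize."""
    relative = absolute - state
    return (relative - mean) / std


def postprocess(model_output: np.ndarray, state: np.ndarray,
                mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Inference side: unnormalize -> absolute (the inverse, in reverse order)."""
    relative = model_output * std + mean
    return relative + state


current_state = np.array([45.0, -30.0, 10.0])
absolute_chunk = np.array([
    [46.0, -29.0, 11.0],
    [47.5, -27.0, 12.0],
    [49.0, -25.0, 13.5],
    [50.0, -24.0, 15.0],
])

# The statistics are computed on relative values, because that is what
# the normalizer sees in this ordering.
relative_chunk = absolute_chunk - current_state
mean = relative_chunk.mean(axis=0)
std = relative_chunk.std(axis=0)

normalized = preprocess(absolute_chunk, current_state, mean, std)
restored_chunk = postprocess(normalized, current_state, mean, std)
```

Using statistics from absolute actions in this chain would normalize the offsets with the wrong mean and scale, which is why the relative statistics have to be precomputed.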
|
||||
|
||||
With UMI-style relative proprioception (`use_relative_state=True`), the preprocessor also converts observation.state to offsets from the current timestep via `RelativeStateProcessorStep` before normalization. This is a pre-processing-only step (state is an input, not an output).
|
||||
|
||||
With `derive_state_from_action=True`, the preprocessor first runs `DeriveStateFromActionStep` to extract a 2-step state from the extended action chunk. This enables full UMI-style training without a separate `observation.state` column. See the [UMI pi0 guide](umi_pi0_relative_ee) for details.
|
||||
|
||||
## References
|
||||
|
||||
- [Universal Manipulation Interface (UMI)](https://arxiv.org/abs/2402.10329) - Chi et al., 2024. Defines the relative trajectory action representation and compares it with absolute and delta actions.
|
||||
- [Introduction to Processors](./introduction_processors) - How processor pipelines work in LeRobot.
|
||||
- [`examples/so100_to_so100_EE/`](https://github.com/huggingface/lerobot/tree/main/examples/so100_to_so100_EE) - Complete example of recording and evaluating with EE-space actions.
|
||||
@@ -310,4 +310,4 @@ Asynchronous inference represents a significant advancement in real-time robotic
|
||||
- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA
|
||||
|
||||
Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case.
|
||||
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/huggingface/lerobot/issues).
|
||||
If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/lerobot/lerobot/issues).
|
||||
|
||||
@@ -41,15 +41,13 @@ requires = # your-build-system
|
||||
|
||||
## Step 2: Define the Policy Configuration
|
||||
|
||||
Create a configuration class that inherits from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and registers your policy type:
|
||||
Here is a template to get you started; customize the parameters and methods as needed for your policy's architecture and training requirements.
|
||||
Create a configuration class that inherits from `PreTrainedConfig` and registers your policy type:
|
||||
|
||||
```python
|
||||
# configuration_my_custom_policy.py
|
||||
from dataclasses import dataclass, field
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
@PreTrainedConfig.register_subclass("my_custom_policy")
|
||||
@dataclass
|
||||
@@ -63,56 +61,22 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
||||
hidden_dim: Hidden dimension for the policy network
|
||||
# Add your policy-specific parameters here
|
||||
"""
|
||||
|
||||
horizon: int = 50
|
||||
n_action_steps: int = 50
|
||||
hidden_dim: int = 256
|
||||
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_weight_decay: float = 1e-4
|
||||
# ...PreTrainedConfig fields...
|
||||
pass
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.n_action_steps > self.horizon:
|
||||
raise ValueError("n_action_steps cannot exceed horizon")
|
||||
# Add any validation logic here
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate input/output feature compatibility."""
|
||||
if not self.image_features:
|
||||
raise ValueError("MyCustomPolicy requires at least one image feature.")
|
||||
if self.action_feature is None:
|
||||
raise ValueError("MyCustomPolicy requires 'action' in output_features.")
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int] | None:
|
||||
"""Relative timestep offsets the dataset loader provides per observation.
|
||||
|
||||
Return `None` for single-frame policies. For temporal policies that consume
|
||||
multiple past or future frames, return a list of offsets, e.g. `[-20, -10, 0, 10]` for
|
||||
3 past frames at stride 10 and 1 future frame at stride 10.
|
||||
"""
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
"""Relative timestep offsets for the action chunk the dataset loader returns.
|
||||
"""
|
||||
return list(range(self.horizon))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
# Implement validation logic for your policy's requirements
|
||||
pass
|
||||
```
|
||||
|
||||
## Step 3: Implement the Policy Class
|
||||
|
||||
Create your policy implementation by inheriting from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py):
|
||||
Create your policy implementation by inheriting from LeRobot's base `PreTrainedPolicy` class:
|
||||
|
||||
```python
|
||||
# modeling_my_custom_policy.py
|
||||
@@ -121,73 +85,37 @@ import torch.nn as nn
|
||||
from typing import Any
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION
|
||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||
|
||||
class MyCustomPolicy(PreTrainedPolicy):
|
||||
config_class = MyCustomPolicyConfig # must match the string in @register_subclass
|
||||
config_class = MyCustomPolicyConfig
|
||||
name = "my_custom_policy"
|
||||
|
||||
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
|
||||
super().__init__(config, dataset_stats)
|
||||
config.validate_features() # not called automatically by the base class
|
||||
self.config = config
|
||||
self.model = ... # your nn.Module here
|
||||
|
||||
def reset(self):
|
||||
"""Reset episode state."""
|
||||
...
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
"""Return parameters to pass to the optimizer (e.g. with per-group lr/wd)."""
|
||||
return {"params": self.parameters()}
|
||||
|
||||
def predict_action_chunk(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
|
||||
"""Return the full action chunk (B, chunk_size, action_dim) for the current observation."""
|
||||
...
|
||||
|
||||
def select_action(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
|
||||
"""Return a single action for the current timestep (called at inference)."""
|
||||
...
|
||||
|
||||
def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
"""Compute the training loss.
|
||||
|
||||
`batch["action_is_pad"]` is a bool mask of shape (B, horizon) that marks
|
||||
timesteps padded because the episode ended before `horizon` steps; you
|
||||
can exclude those from your loss.
|
||||
"""
|
||||
actions = batch[ACTION]
|
||||
action_is_pad = batch.get("action_is_pad")
|
||||
...
|
||||
return {"loss": ...}
|
||||
```
|
||||
|
||||
## Step 4: Add Data Processors
|
||||
|
||||
Create processor functions. For a concrete reference, see [processor_act.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [processor_diffusion.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
|
||||
Create processor functions:
|
||||
|
||||
```python
|
||||
# processor_my_custom_policy.py
|
||||
from typing import Any
|
||||
import torch
|
||||
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
|
||||
|
||||
def make_my_custom_policy_pre_post_processors(
|
||||
config,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
preprocessor = ... # build your PolicyProcessorPipeline for inputs
|
||||
postprocessor = ... # build your PolicyProcessorPipeline for outputs
|
||||
return preprocessor, postprocessor
|
||||
```
|
||||
"""Create preprocessing and postprocessing functions for your policy."""
|
||||
pass # Define your preprocessing and postprocessing logic here
|
||||
|
||||
**Important - function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
|
||||
```
|
||||
|
||||
## Step 5: Package Initialization
|
||||
|
||||
|
||||
@@ -424,7 +424,7 @@ robot = SO100Follower(robot_config)
|
||||
robot.connect()
|
||||
|
||||
dataset = LeRobotDataset("<hf_username>/<dataset_repo_id>", episodes=[episode_idx])
|
||||
actions = dataset.select_columns("action")
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
|
||||
log_say(f"Replaying episode {episode_idx}")
|
||||
for idx in range(dataset.num_frames):
|
||||
|
||||
@@ -1,340 +0,0 @@
|
||||
# Multitask DiT Policy
|
||||
|
||||
Multitask Diffusion Transformer (DiT) Policy is an evolution of the original Diffusion Policy architecture, which leverages a large DiT with text and vision conditioning for multitask robot learning. This implementation supports both diffusion and flow matching objectives for action generation, enabling robots to perform diverse manipulation tasks conditioned on language instructions.
|
||||
|
||||
## Model Overview
|
||||
|
||||
The model uses:
|
||||
|
||||
- **CLIP Vision Encoder**: Processes RGB images from multiple camera views
|
||||
- **CLIP Text Encoder**: Encodes language task instructions (frozen weights with learnable projection)
|
||||
- **Diffusion Transformer**: Predicts action sequences conditioned on observations and language
|
||||
- **Two Objectives**: Supports both diffusion (DDPM/DDIM) and flow matching for action generation
|
||||
|
||||
This model is exciting because you can achieve extremely high dexterity, competitive with multi-billion parameter
|
||||
VLAs, with only ~450M parameters and significantly less training.
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
Multitask DiT Policy has additional dependencies. Install it with:
|
||||
|
||||
```bash
|
||||
pip install lerobot[multi_task_dit]
|
||||
```
|
||||
|
||||
This will install all necessary dependencies including the HuggingFace Transformers library for CLIP models.
|
||||
|
||||
## Usage
|
||||
|
||||
To use Multitask DiT in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
```python
|
||||
policy.type=multi_task_dit
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Basic Training Command
|
||||
|
||||
Here's a complete training command for training Multitask DiT on your dataset:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=YOUR_DATASET \
|
||||
--output_dir=./outputs/multitask_dit_training \
|
||||
--batch_size=32 \
|
||||
--steps=5000 \
|
||||
--save_freq=500 \
|
||||
--log_freq=100 \
|
||||
--policy.type=multi_task_dit \
|
||||
--policy.device=cuda \
|
||||
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
### Recommended Hyperparameters and Dataset Details (30Hz Control Frequency)
|
||||
|
||||
For reliable performance, start with these suggested default hyperparameters:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=YOUR_DATASET \
|
||||
--output_dir=./outputs/multitask_dit_training \
|
||||
--batch_size=320 \
|
||||
--steps=30000 \
|
||||
--policy.type=multi_task_dit \
|
||||
--policy.device=cuda \
|
||||
--policy.horizon=32 \
|
||||
--policy.n_action_steps=24 \
|
||||
--policy.objective=diffusion \
|
||||
--policy.noise_scheduler_type=DDPM \
|
||||
--policy.num_train_timesteps=100 \
|
||||
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
**Key Parameters:**
|
||||
|
||||
- **Batch Size**: 192-320 - If you have access to a GPU that can support this, you will get the best training dynamics
|
||||
- **Horizon**: 32 - number of action steps to predict, ~1.0 sec at 30Hz
|
||||
- **n_action_steps**: 24 - ~0.8 seconds at 30Hz
|
||||
- **Objective**: `diffusion` - start with diffusion and experiment with flow matching if generation quality is poor
|
||||
- **Training Steps**: >30k steps recommended for a single task
|
||||
|
||||
### Training Configuration Parameters
|
||||
|
||||
#### Objective Selection
|
||||
|
||||
Choose between diffusion and flow matching:
|
||||
|
||||
```bash
|
||||
# Diffusion objective (default)
|
||||
--policy.objective=diffusion \
|
||||
--policy.noise_scheduler_type=DDPM \ # or "DDIM"
|
||||
--policy.num_train_timesteps=100 \
|
||||
--policy.num_inference_steps=10 \ # For faster inference
|
||||
--policy.beta_schedule=squaredcos_cap_v2 \ # Noise schedule type
|
||||
--policy.prediction_type=epsilon \ # "epsilon" (predict noise) or "sample" (predict clean)
|
||||
--policy.clip_sample=true \ # Clip samples during denoising
|
||||
--policy.clip_sample_range=1.0 # Clipping range [-x, x]
|
||||
|
||||
# Flow matching objective
|
||||
--policy.objective=flow_matching \
|
||||
--policy.timestep_sampling_strategy=beta \ # or "uniform" | the beta sampling strategy performance appears much better in practice
|
||||
--policy.num_integration_steps=100 \
|
||||
--policy.integration_method=euler \ # or "rk4"
|
||||
--policy.sigma_min=0.0 # Minimum noise in flow interpolation path
|
||||
```
|
||||
|
||||
#### Transformer Architecture
|
||||
|
||||
Adjust model capacity based on dataset size:
|
||||
|
||||
```bash
|
||||
# Small datasets (< 100 examples)
|
||||
--policy.num_layers=4 \
|
||||
--policy.hidden_dim=512 \
|
||||
--policy.num_heads=8 # should ideally be hidden_dim // 64
|
||||
|
||||
# Medium datasets (100-5k examples) - default
|
||||
--policy.num_layers=6 \
|
||||
--policy.hidden_dim=512 \
|
||||
--policy.num_heads=8 # should ideally be hidden_dim // 64
|
||||
|
||||
# Large datasets (> 5k examples)
|
||||
--policy.num_layers=8 \
|
||||
--policy.hidden_dim=512 \
|
||||
--policy.num_heads=8 # should ideally be hidden_dim // 64
|
||||
```
|
||||
|
||||
**Positional Encoding Options:**
|
||||
|
||||
The model supports two positional encoding methods for action sequences:
|
||||
|
||||
```bash
|
||||
# Rotary Position Embedding (RoPE) - default, recommended
|
||||
--policy.use_rope=true \
|
||||
--policy.rope_base=10000.0 # Base frequency for RoPE
|
||||
|
||||
# Absolute positional encoding
|
||||
--policy.use_positional_encoding=true # Disables RoPE when true
|
||||
```
|
||||
|
||||
**Other Transformer Parameters:**
|
||||
|
||||
```bash
|
||||
--policy.dropout=0.1 # Dropout rate for DiT blocks (0.0-1.0)
|
||||
--policy.timestep_embed_dim=256 # Timestep embedding dimension
|
||||
```
|
||||
|
||||
#### Vision Encoder Configuration
|
||||
|
||||
```bash
|
||||
# Use different CLIP model for more expressivity at the cost of inference time
|
||||
# experiment with larger or smaller models depending on the complexity of your tasks and size of dataset
|
||||
--policy.vision_encoder_name=openai/clip-vit-large-patch14
|
||||
|
||||
# Use separate vision encoder per camera
|
||||
# This may be useful when cameras have significantly different characteristics, but
|
||||
# be wary of increased VRAM footprint.
|
||||
--policy.use_separate_rgb_encoder_per_camera=true
|
||||
|
||||
# Image preprocessing
|
||||
--policy.image_resize_shape=[XXX,YYY] \ # you may need to resize your images for inference speed ups
|
||||
--policy.image_crop_shape=[224,224] \
|
||||
--policy.image_crop_is_random=true # Random during training, center at inference
|
||||
```
|
||||
|
||||
#### Text Encoder Configuration
|
||||
|
||||
```bash
|
||||
# Use different CLIP text encoder model
|
||||
# same as vision: experiment with larger or smaller models depending on the
|
||||
# complexity of your tasks and size of dataset
|
||||
--policy.text_encoder_name=openai/clip-vit-large-patch14
|
||||
```
|
||||
|
||||
#### Learning Rate Configuration
|
||||
|
||||
The vision encoder uses a separate learning rate multiplier; 1/10th of the base learning rate is the suggested starting point:
|
||||
|
||||
```bash
|
||||
--policy.optimizer_lr=2e-5 \
|
||||
--policy.vision_encoder_lr_multiplier=0.1 # Vision encoder LR = 0.1 * optimizer_lr
|
||||
```
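The effect of the multiplier can be sketched as two optimizer parameter groups. This is an assumption about how such a multiplier is typically applied, not the policy's actual code: the function name, the `vision_encoder.` prefix, and the parameter names are hypothetical, and strings stand in for parameter tensors so the sketch runs without PyTorch.

```python
def build_param_groups(
    param_names: list[str],
    base_lr: float,
    vision_lr_multiplier: float,
    vision_prefix: str = "vision_encoder.",  # hypothetical module prefix
) -> list[dict]:
    """Split parameters into two groups with different learning rates."""
    vision = [name for name in param_names if name.startswith(vision_prefix)]
    other = [name for name in param_names if not name.startswith(vision_prefix)]
    return [
        {"params": other, "lr": base_lr},
        {"params": vision, "lr": base_lr * vision_lr_multiplier},
    ]


names = [
    "vision_encoder.layer0.weight",
    "transformer.block0.weight",
    "action_head.bias",
]
groups = build_param_groups(names, base_lr=2e-5, vision_lr_multiplier=0.1)
```

The returned list follows the shape that `torch.optim` optimizers accept for per-group options (a list of dicts with `params` and `lr`); in real use the `params` entries would be the parameter tensors rather than their names.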
|
||||
|
||||
### Training Tuning Guidelines
|
||||
|
||||
#### 1. Flow Matching with Beta Sampling
|
||||
|
||||
The original diffusion implementation here is based on the work described in [TRI's LBM paper](https://arxiv.org/abs/2507.05331).
|
||||
|
||||
Additionally, we have implemented a flow-matching objective, which is described at a high level in the [Boston Dynamics blog post](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/).
|
||||
|
||||
Consider testing the flow-matching objective and evaluating performance differences for your task:
|
||||
|
||||
```bash
|
||||
--policy.objective=flow_matching \
|
||||
--policy.timestep_sampling_strategy=beta \
|
||||
--policy.timestep_sampling_alpha=1.5 \
|
||||
--policy.timestep_sampling_beta=1.0 \
|
||||
--policy.timestep_sampling_s=0.999
|
||||
```
|
||||
|
||||
This has not proven to be a silver bullet for every use case, but it occasionally produces smoother and more consistent actions.
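For intuition about what the beta sampling flags control, here is a minimal sketch of one common scheme: draw a Beta-distributed sample, flip it, and scale it by `s`. Whether this policy uses exactly this mapping is an assumption, and `sample_timestep` is a hypothetical helper rather than part of LeRobot.

```python
import random


def sample_timestep(alpha=1.5, beta=1.0, s=0.999, rng=random):
    """Draw one flow-matching time in [0, s] (assumed mapping: t = s * (1 - u))."""
    u = rng.betavariate(alpha, beta)  # u ~ Beta(alpha, beta), in [0, 1]
    return s * (1.0 - u)


rng = random.Random(0)
samples = [sample_timestep(rng=rng) for _ in range(10_000)]
mean = sum(samples) / len(samples)
print(f"min={min(samples):.3f} max={max(samples):.3f} mean={mean:.3f}")
```

Under this assumed mapping, `alpha=1.5` and `beta=1.0` skew the sampled times toward the low end of `[0, s]`, and `s` slightly below 1 keeps samples away from the upper endpoint.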
|
||||
|
||||
#### 2. Number of Transformer Layers
|
||||
|
||||
Match model capacity to your dataset size:
|
||||
|
||||
- **Small datasets** (< 100 examples): Reduce to 4 layers
|
||||
- **Large datasets** (> 5k examples): Increase to 8 layers
|
||||
|
||||
#### 3. `horizon` Tuning
|
||||
|
||||
The model can be sensitive to the horizon you choose. Start with a horizon of about 1 second, based on your control frequency:
|
||||
|
||||
- **30 Hz frequency**: `horizon=30`
|
||||
- **10 Hz frequency**: `horizon=10`
|
||||
|
||||
Then experiment with increasing from there. The horizon determines how far into the future the model predicts actions.
|
||||
|
||||
#### 4. `n_action_steps` Sensitivity
|
||||
|
||||
The model can also be very sensitive to `n_action_steps`. Start with a value covering about 0.8 seconds at your control frequency and tune from there:
|
||||
|
||||
- **Lower values**: More reactive but potentially less stable for long-horizon tasks
|
||||
- **Higher values**: Better for long-horizon execution, but recovery from failures is limited because more steps run open-loop
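The starting points for `horizon` (about 1 second) and `n_action_steps` (about 0.8 seconds) can be converted into step counts with a small helper. This is a sketch: `suggest_chunk_settings` is a hypothetical function, and clamping `n_action_steps` to the horizon is an assumption.

```python
def suggest_chunk_settings(control_hz, horizon_s=1.0, action_s=0.8):
    """Turn durations in seconds into step counts for a given control frequency."""
    horizon = max(1, round(horizon_s * control_hz))
    # Assumption: it makes no sense to execute more steps than were predicted.
    n_action_steps = max(1, min(horizon, round(action_s * control_hz)))
    return horizon, n_action_steps


for hz in (10, 30):
    horizon, n_action_steps = suggest_chunk_settings(hz)
    print(f"{hz} Hz -> horizon={horizon}, n_action_steps={n_action_steps}")
```

Treat the results as initial values only; both settings still need tuning per task.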
|
||||
|
||||
### Inference Tuning
|
||||
|
||||
For faster inference, use DDIM with fewer sampling steps:
|
||||
|
||||
```bash
|
||||
--policy.noise_scheduler_type=DDIM \
|
||||
--policy.num_inference_steps=10
|
||||
```
|
||||
|
||||
### Resuming Training
|
||||
|
||||
To resume training from a checkpoint:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--config_path=./outputs/mutitask_dit_training/checkpoints/last/pretrained_model/train_config.json \
|
||||
--resume=true
|
||||
```
|
||||
|
||||
The checkpoint directory should contain `model.safetensors` and `config.json` files (saved automatically during training). When resuming, the configuration is loaded from the checkpoint, so you don't need to specify other parameters.
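Before resuming, you can check that the expected files are present. The sketch below assumes that the directory in question is the `pretrained_model` folder from the command above; `missing_checkpoint_files` is a hypothetical helper.

```python
from pathlib import Path

REQUIRED_FILES = ("model.safetensors", "config.json")


def missing_checkpoint_files(checkpoint_dir):
    """Return the required file names that are absent from checkpoint_dir."""
    directory = Path(checkpoint_dir)
    return [name for name in REQUIRED_FILES if not (directory / name).is_file()]


# Path taken from the resume command above; adjust it to your own output_dir.
checkpoint = "./outputs/mutitask_dit_training/checkpoints/last/pretrained_model"
missing = missing_checkpoint_files(checkpoint)
print("ready to resume" if not missing else f"missing: {missing}")
```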
|
||||
|
||||
## Common Failure Modes and Debugging
|
||||
|
||||
Training these models can be finicky. Here are common failure modes and debugging approaches:
|
||||
|
||||
### Idling / No Motion
|
||||
|
||||
The model may "collapse" during inference, resulting in static or no motion. This can occur when:
|
||||
|
||||
1. **Insufficient training data**: If you only have 20-50 examples, try to roughly double your dataset size. Once you have above 300 examples, if you're still seeing this, the task may be too complex.
|
||||
|
||||
2. **Multiple similar tasks**: When your dataset contains multiple similar tasks (e.g., picking up 2 different objects), the model may rely too heavily on language conditioning which might not be rich enough.
|
||||
|
||||
**Debugging tips:**
|
||||
|
||||
- Increase dataset size (double until you get to over 300 examples)
|
||||
- Train for longer, up to 100k steps, even when the loss flatlines
|
||||
- Check that the model is receiving proper language instructions, or increase the diversity of instructions
|
||||
|
||||
### Executing the Wrong Task
|
||||
|
||||
Sometimes the robot will completely ignore your instruction and perform some other task. This generally only happens if you have trained on multiple tasks.
|
||||
|
||||
**Potential causes:**
|
||||
|
||||
- Language instruction ambiguity
|
||||
- Insufficient task-specific training data
|
||||
- Model confusion between similar tasks in the multitask dataset
|
||||
|
||||
**Debugging tips:**
|
||||
|
||||
- Verify language instruction specificity, especially if descriptions are similar between multiple tasks
|
||||
- Check task distribution in your training dataset and add weighting to the failing/ignored task
|
||||
- Consider task-specific fine-tuning
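To check the task distribution, count how often each instruction appears and derive inverse-frequency weights. The sketch below works on a plain list of per-episode task strings; the labels and the helper `task_sampling_weights` are hypothetical, and how the weights are applied (for example through a weighted sampler) depends on your training setup and is not shown.

```python
from collections import Counter


def task_sampling_weights(episode_tasks):
    """Return per-task weights inversely proportional to each task's frequency."""
    counts = Counter(episode_tasks)
    total = sum(counts.values())
    return {task: total / (len(counts) * n) for task, n in counts.items()}


# Hypothetical per-episode task labels.
episode_tasks = ["pick up the cup"] * 6 + ["pick up the bowl"] * 2
weights = task_sampling_weights(episode_tasks)
print(Counter(episode_tasks))
print(weights)
```

A weight above 1 marks an under-represented task; here the bowl task appears in a quarter of the episodes and gets the larger weight.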
|
||||
|
||||
### Training Instability
|
||||
|
||||
If training loss is unstable or diverging:
|
||||
|
||||
- Try adjusting the learning rate between `1e-5` and `3e-4`
|
||||
- Increase batch size if possible
|
||||
- Check that your dataset normalization is correct
|
||||
- Verify image preprocessing is working correctly
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### GPU Requirements
|
||||
|
||||
- **Inference**: At least an RTX 5070 Ti (or equivalent GPU) is recommended for reasonable inference speed
|
||||
- **Training**: A GPU with enough VRAM for batch sizes above 64 is ideal; the requirement varies with the number of image observations and other factors
|
||||
|
||||
### Batch Size Recommendations
|
||||
|
||||
- **Minimum**: 64 (less than this may result in unstable training)
|
||||
- **Recommended**: 256-320 (best performance, requires larger GPU)
|
||||
|
||||
## Example: Training on Custom Dataset
|
||||
|
||||
Here's a complete example of training on a custom dataset:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=YOUR_DATASET \
|
||||
--output_dir=./outputs/mutitask_dit_training \
|
||||
--batch_size=320 \
|
||||
--steps=30000 \
|
||||
--save_freq=1000 \
|
||||
--log_freq=100 \
|
||||
--eval_freq=1000 \
|
||||
--policy.type=multi_task_dit \
|
||||
--policy.device=cuda \
|
||||
--policy.horizon=32 \
|
||||
--policy.n_action_steps=24 \
|
||||
--policy.objective=diffusion \
|
||||
--policy.noise_scheduler_type=DDPM \
|
||||
--policy.num_layers=6 \
|
||||
--policy.hidden_dim=512 \
|
||||
--policy.vision_encoder_name=openai/clip-vit-base-patch16 \
|
||||
--policy.image_resize_shape=[320,240] \
|
||||
--policy.image_crop_shape=[224,224] \
|
||||
--policy.repo_id="HF_USER/multitask-dit-your-robot" \
|
||||
--wandb.enable=true \
|
||||
--wandb.project=multitask_dit
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
For more details on the technical implementation and architecture, see:
|
||||
|
||||
- [A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation](https://arxiv.org/abs/2507.05331)
|
||||
- [Large Behavior Models and Atlas Find New Footing](https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/)
|
||||
- [Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy](https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy)
|
||||
@@ -91,46 +91,6 @@ lerobot-train \
|
||||
|
||||
**💡 Tip**: Setting `train_expert_only=true` freezes the VLM and trains only the action expert and projections, allowing finetuning with reduced memory usage.
|
||||
|
||||
## Relative Actions
|
||||
|
||||
By default, π₀ predicts absolute actions. You can enable **relative actions** so the model predicts offsets relative to the current robot state. This can improve training stability for certain setups.
|
||||
|
||||
To use relative actions, first recompute your dataset stats in relative space via the CLI:
|
||||
|
||||
```bash
|
||||
lerobot-edit-dataset \
|
||||
--repo_id your_dataset \
|
||||
--operation.type recompute_stats \
|
||||
--operation.relative_action true \
|
||||
--operation.chunk_size 50 \
|
||||
--operation.relative_exclude_joints "['gripper']" \
|
||||
--push_to_hub true
|
||||
```
|
||||
|
||||
Or equivalently in Python:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.dataset_tools import recompute_stats
|
||||
|
||||
dataset = LeRobotDataset("your_dataset")
|
||||
recompute_stats(dataset, relative_action=True, chunk_size=50, relative_exclude_joints=["gripper"])
|
||||
dataset.push_to_hub()
|
||||
```
|
||||
|
||||
The `chunk_size` should match your policy's `chunk_size` (default 50 for π₀). `relative_exclude_joints` lists joint names that should remain in absolute space (e.g. gripper commands). Use `--push_to_hub true` to upload the updated stats to the Hub.
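The following sketch illustrates what converting to relative space means when some joints are excluded. It is not the LeRobot implementation: `to_relative` and `to_absolute` are hypothetical helpers, the joint names are made up, and matching excluded joints by exact name is an assumption.

```python
def to_relative(actions, state, joint_names, exclude=()):
    """Subtract the current state from each action step, except excluded joints."""
    keep_absolute = [name in exclude for name in joint_names]
    return [
        [a if skip else a - s for a, s, skip in zip(step, state, keep_absolute)]
        for step in actions
    ]


def to_absolute(relative, state, joint_names, exclude=()):
    """Inverse of to_relative: add the same state back to non-excluded joints."""
    keep_absolute = [name in exclude for name in joint_names]
    return [
        [r if skip else r + s for r, s, skip in zip(step, state, keep_absolute)]
        for step in relative
    ]


joint_names = ["shoulder", "elbow", "gripper"]  # hypothetical joint names
state = [10, 20, 1]
chunk = [[11, 22, 0], [13, 25, 1]]  # absolute actions for two future steps

relative = to_relative(chunk, state, joint_names, exclude=["gripper"])
print(relative)
print(to_absolute(relative, state, joint_names, exclude=["gripper"]) == chunk)
```

The arm joints become offsets from the current state while the gripper column is left untouched, which is why the statistics must be recomputed in the same space the policy trains in.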
|
||||
|
||||
Then train with relative actions enabled:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi0 \
|
||||
--policy.use_relative_actions=true \
|
||||
--policy.relative_exclude_joints='["gripper"]' \
|
||||
...
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||
|
||||
@@ -97,46 +97,6 @@ python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
|
||||
|
||||
Or train pi05 with this normalization mapping: `--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}'`
|
||||
|
||||
## Relative Actions
|
||||
|
||||
By default, π₀.₅ predicts absolute actions. You can enable **relative actions** so the model predicts offsets relative to the current robot state. This can improve training stability for certain setups.
|
||||
|
||||
To use relative actions, first recompute your dataset stats in relative space via the CLI:
|
||||
|
||||
```bash
|
||||
lerobot-edit-dataset \
|
||||
--repo_id your_dataset \
|
||||
--operation.type recompute_stats \
|
||||
--operation.relative_action true \
|
||||
--operation.chunk_size 50 \
|
||||
--operation.relative_exclude_joints "['gripper']" \
|
||||
--push_to_hub true
|
||||
```
|
||||
|
||||
Or equivalently in Python:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.dataset_tools import recompute_stats
|
||||
|
||||
dataset = LeRobotDataset("your_dataset")
|
||||
recompute_stats(dataset, relative_action=True, chunk_size=50, relative_exclude_joints=["gripper"])
|
||||
dataset.push_to_hub()
|
||||
```
|
||||
|
||||
The `chunk_size` should match your policy's `chunk_size` (default 50 for π₀.₅). `relative_exclude_joints` lists joint names that should remain in absolute space (e.g. gripper commands). Use `--push_to_hub true` to upload the updated stats to the Hub.
|
||||
|
||||
Then train with relative actions enabled:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=your_dataset \
|
||||
--policy.type=pi05 \
|
||||
--policy.use_relative_actions=true \
|
||||
--policy.relative_exclude_joints='["gripper"]' \
|
||||
...
|
||||
```
|
||||
|
||||
## Performance Results
|
||||
|
||||
### Libero Benchmark Results
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
# Multitask DiT Policy
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this work, please cite the following works:
|
||||
|
||||
```bibtex
|
||||
@misc{jones2025multitaskditpolicy,
|
||||
author = {Bryson Jones},
|
||||
title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy},
|
||||
year = {2025},
|
||||
url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy},
|
||||
note = {Blog post}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
|
||||
author = {TRI LBM Team},
|
||||
title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
|
||||
year = {2025},
|
||||
eprint = {2507.05331},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.RO},
|
||||
url = {https://arxiv.org/abs/2507.05331}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{bostondynamics2025largebehaviormodelsatlas,
|
||||
author = {Boston Dynamics and TRI Research Team},
|
||||
title = {Large Behavior Models and Atlas Find New Footing},
|
||||
year = {2025},
|
||||
url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/},
|
||||
note = {Blog post}
|
||||
}
|
||||
```
|
||||
@@ -1,114 +0,0 @@
|
||||
# Rename Map and Empty Cameras
|
||||
|
||||
When you train, evaluate, or record with a robot policy, your **dataset** or **environment** provides observations under one set of keys (e.g. `observation.images.front`, `observation.images.eagle`), while your **policy** expects another (e.g. `observation.images.image`, `observation.images.image2`). The **rename map** bridges that gap without changing the policy or data source.
|
||||
|
||||
> **Scope:** The rename map only renames **observation** keys (images and state). Action keys are not affected.
|
||||
|
||||
## Why observation keys don't always match
|
||||
|
||||
Policies have a fixed set of **input feature names** baked into their pretrained config. For example:
|
||||
|
||||
- [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero) expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb`.
|
||||
- [xvla-base](https://huggingface.co/lerobot/xvla-base) expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`.
|
||||
|
||||
Your dataset might use different names entirely (e.g. `observation.images.front`, `observation.images.eagle`, `observation.images.glove`), and your eval environment might use yet another set. Rather than editing the policy config or renaming columns in the dataset, you pass a **rename map**: a JSON dictionary that maps source keys to the keys the policy expects. Renaming happens inside the preprocessor pipeline, so the policy always sees its expected keys.
|
||||
|
||||
## Using the rename map
|
||||
|
||||
Pass the mapping as a JSON string on the command line. The convention is always:
|
||||
|
||||
```
|
||||
--rename_map='{"source_key": "policy_key", ...}'
|
||||
```
|
||||
|
||||
where **source_key** is what the dataset or environment provides, and **policy_key** is what the policy expects.
|
||||
|
||||
Only listed keys are renamed; everything else passes through unchanged. Order of entries doesn't matter.
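The pass-through behaviour can be pictured with a few lines of Python. This is a sketch of the idea only, not the preprocessor's actual code; `apply_rename_map` is a hypothetical helper and the observation values are placeholders.

```python
import json


def apply_rename_map(observation, rename_map):
    """Return a copy of the observation with listed keys renamed; others pass through."""
    return {rename_map.get(key, key): value for key, value in observation.items()}


rename_map = json.loads(
    '{"observation.images.front": "observation.images.image",'
    ' "observation.images.eagle": "observation.images.image2"}'
)
observation = {
    "observation.images.front": "front_frame",  # placeholder for an image tensor
    "observation.images.eagle": "eagle_frame",
    "observation.state": [0.0, 0.1],
}
renamed = apply_rename_map(observation, rename_map)
print(sorted(renamed))
```

`observation.state` is not listed in the map, so it keeps its name.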
|
||||
|
||||
Supported policies: **PI0**, **PI05**, **PI0Fast**, **SmolVLA**, and **XVLA**.
|
||||
|
||||
### Training
|
||||
|
||||
Suppose you fine-tune [lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base) on a dataset with images under `observation.images.front`, `observation.images.eagle`, and `observation.images.glove`. XVLA expects `observation.images.image`, `observation.images.image2`, and `observation.images.image3`:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=YOUR_DATASET \
|
||||
--output_dir=./outputs/xvla_training \
|
||||
--job_name=xvla_training \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
--policy.repo_id="HF_USER/xvla-your-robot" \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.action_mode=auto \
|
||||
--steps=20000 \
|
||||
--policy.device=cuda \
|
||||
--policy.freeze_vision_encoder=false \
|
||||
--policy.freeze_language_encoder=false \
|
||||
--policy.train_policy_transformer=true \
|
||||
--policy.train_soft_prompts=true \
|
||||
--rename_map='{"observation.images.front": "observation.images.image", "observation.images.eagle": "observation.images.image2", "observation.images.glove": "observation.images.image3"}'
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
|
||||
Suppose a policy expects `observation.images.base_0_rgb` and `observation.images.left_wrist_0_rgb` (e.g. [pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero)), but the LIBERO environment returns `observation.images.image` and `observation.images.image2`:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/pi0fast-libero \
|
||||
--env.type=libero \
|
||||
... \
|
||||
--rename_map='{"observation.images.image": "observation.images.base_0_rgb", "observation.images.image2": "observation.images.left_wrist_0_rgb"}'
|
||||
```
|
||||
|
||||
### Recording
|
||||
|
||||
`lerobot-record` also supports rename maps, nested under the dataset config:
|
||||
|
||||
```bash
|
||||
# When running inference
lerobot-record \
|
||||
--policy.path="<user>/smolVLA_finetuned" \
|
||||
... \
|
||||
--dataset.rename_map='{"observation.images.glove2": "observation.images.image"}'
|
||||
```
|
||||
|
||||
## Alternative: edit the policy config directly
|
||||
|
||||
If you always use the same dataset or environment, you can **edit the policy's `config.json`** so its observation keys match your data source. Then no rename map is needed.
|
||||
|
||||
The tradeoff: modifying the policy config ties it to one data source. A rename map keeps one policy usable across many datasets and environments.
|
||||
|
||||
## Empty cameras: fewer views than the policy expects
|
||||
|
||||
Some policies are built for a fixed number of image inputs. If your dataset has fewer cameras, you can set **`empty_cameras`** in the policy config instead of modifying the model architecture.
|
||||
|
||||
### How it works
|
||||
|
||||
Setting `empty_cameras=N` adds N placeholder image features to the policy config, named:
|
||||
|
||||
```
|
||||
observation.images.empty_camera_0
|
||||
observation.images.empty_camera_1
|
||||
...
|
||||
```
|
||||
|
||||
At runtime, these keys have no corresponding data in the batch. The policy fills them with masked dummy tensors (padded with `-1` for SigLIP-based vision encoders, with a zero attention mask), so the extra image slots are effectively ignored during training and inference.
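The placeholder names follow a simple pattern, sketched below. `empty_camera_keys` is a hypothetical helper that reproduces only the naming scheme shown above; the masking itself happens inside the policy and is not modelled here.

```python
def empty_camera_keys(empty_cameras):
    """Return the placeholder feature names added for empty_cameras=N."""
    return [f"observation.images.empty_camera_{i}" for i in range(empty_cameras)]


real_keys = ["observation.images.image", "observation.images.image2"]
policy_image_keys = real_keys + empty_camera_keys(1)
print(policy_image_keys)
```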
|
||||
|
||||
### Example
|
||||
|
||||
XVLA-base has three visual inputs and `empty_cameras=0` by default. Your dataset only has two cameras:
|
||||
|
||||
1. Set `--policy.empty_cameras=1`.
|
||||
2. The config adds a third key: `observation.images.empty_camera_0`.
|
||||
3. Use the rename map for your two real cameras as usual.
|
||||
4. The third slot is masked out — no fake images needed in your dataset.
|
||||
|
||||
## Quick reference
|
||||
|
||||
| Goal | What to do |
|
||||
| ----------------------------------------- | --------------------------------------------------------------------------- |
|
||||
| Dataset keys ≠ policy keys | `--rename_map='{"dataset_key": "policy_key", ...}'` |
|
||||
| Env keys ≠ policy keys (eval) | `--rename_map='{"env_key": "policy_key", ...}'` |
|
||||
| Recording with different keys (inference) | `--dataset.rename_map='{"source_key": "policy_key", ...}'` |
|
||||
| Fewer cameras than policy expects | `--policy.empty_cameras=N` (supported by PI0, PI05, PI0Fast, SmolVLA, XVLA) |
|
||||
| Avoid passing a rename map | Edit the policy's `config.json` so its keys match your data source |
|
||||
@@ -236,10 +236,10 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
|
||||
|
||||
### Joint 1
|
||||
|
||||
- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn.
|
||||
- Place the first motor into the base.
|
||||
- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom.
|
||||
- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
|
||||
- Install both motor horns, securing the top horn with a M3x6mm screw.
|
||||
- Attach the shoulder part.
|
||||
- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
|
||||
- Add the shoulder motor holder.
|
||||
@@ -255,9 +255,9 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
|
||||
|
||||
### Joint 2
|
||||
|
||||
- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn.
|
||||
- Slide the second motor in from the top.
|
||||
- Fasten the second motor with 4 M2x6mm screws.
|
||||
- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
|
||||
- Attach the upper arm with 4 M3x6mm screws on each side.
|
||||
|
||||
<div class="video-container">
|
||||
@@ -271,8 +271,8 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
|
||||
|
||||
### Joint 3
|
||||
|
||||
- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn.
|
||||
- Insert motor 3 and fasten using 4 M2x6mm screws.
|
||||
- Insert motor 3 and fasten using 4 M2x6mm screws
|
||||
- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
|
||||
- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
|
||||
|
||||
<div class="video-container">
|
||||
@@ -286,10 +286,9 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
|
||||
|
||||
### Joint 4
|
||||
|
||||
- Install both motor horns. Secure the top horn with a M3x6mm screw. No screws are required for the bottom horn.
|
||||
- Slide over motor holder 4.
|
||||
- Slide in motor 4.
|
||||
- Fasten motor 4 with 4 M2x6mm screws.
|
||||
- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
@@ -322,7 +321,7 @@ It is advisable to install one 3-pin cable in the motor after placing them befor
|
||||
|
||||
- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
|
||||
- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
|
||||
- Install both motor horns on the gripper motor. Secure the top horn with a M3x6mm screw; no screws are required for the bottom horn.
|
||||
- Attach the motor horns and again use a M3x6mm horn screw.
|
||||
- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
|
||||
|
||||
<div class="video-container">
|
||||
|
||||
@@ -1,227 +0,0 @@
|
||||
# UMI Data with pi0 Relative EE Actions
|
||||
|
||||
This guide explains how to train a pi0 policy with UMI-style relative end-effector (EE) actions and deploy it on a real OpenArm robot.
|
||||
|
||||
**What we will do:**
|
||||
|
||||
1. Prepare the dataset (EE pose + gripper in the action column).
|
||||
2. Recompute statistics for relative actions.
|
||||
3. Train pi0 with `derive_state_from_action=true`.
|
||||
4. Evaluate the trained policy on a real robot.
|
||||
|
||||
## Background
|
||||
|
||||
[UMI (Universal Manipulation Interface)](https://umi-gripper.github.io) collects manipulation data with hand-held grippers, recovering 6-DoF EE poses via SLAM. The key insight from UMI (Chi et al., 2024) is that the action space must include **both EE trajectory and gripper width**, and actions should be expressed as **relative trajectories** (offsets from the current pose).
|
||||
|
||||
### Dataset layout
|
||||
|
||||
The dataset should have this structure:
|
||||
|
||||
| Feature | Shape | Content |
|
||||
| ------------------------- | --------- | -------------------------------------------------------- |
|
||||
| `observation.images.cam0` | `[3,H,W]` | Wrist camera image |
|
||||
| `action` | `[8]` | `[x, y, z, ax, ay, az, proximal, distal]` (EE + gripper) |
|
||||
|
||||
No separate `observation.pose` or `observation.joints` columns are needed — the model derives its proprioception state directly from the action column (`derive_state_from_action=true`).
|
||||
|
||||
### Why relative actions?
|
||||
|
||||
With relative actions, each action in a chunk is an **offset from the current state** rather than an absolute target:
|
||||
|
||||
```
|
||||
relative_action[i] = absolute_action[t + i] − state[t]
|
||||
```
|
||||
|
||||
UMI ablations show this is critical: absolute actions achieve only 25% success vs 100% for relative trajectory on the cup arrangement task. Compared to delta actions (each step relative to the previous), relative trajectory avoids error accumulation. See the [Action Representations](action_representations) guide for details.
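A one-dimensional toy example makes the difference between the two encodings concrete. The helpers below are hypothetical, and treating the first delta as an offset from the current state is an assumption.

```python
from itertools import accumulate


def relative_trajectory(state, absolute_chunk):
    """Each step is an offset from the state at the start of the chunk."""
    return [a - state for a in absolute_chunk]


def delta_actions(state, absolute_chunk):
    """Each step is an offset from the previous step (the first from the state)."""
    previous = [state] + absolute_chunk[:-1]
    return [a - p for a, p in zip(absolute_chunk, previous)]


state = 10
absolute_chunk = [11, 13, 16]
relative = relative_trajectory(state, absolute_chunk)
delta = delta_actions(state, absolute_chunk)
print(relative, delta)

# Recovering absolute targets: one addition per step vs. a running sum.
assert [state + r for r in relative] == absolute_chunk
assert list(accumulate(delta, initial=state))[1:] == absolute_chunk
```

Decoding deltas requires a running sum, so an error in one step carries into every later step, whereas each relative-trajectory step depends only on the state at the start of the chunk.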
|
||||
|
||||
### `derive_state_from_action`
|
||||
|
||||
When `derive_state_from_action=true`, pi0 derives `observation.state` from the action column during training — no separate state column needed. Under the hood:
|
||||
|
||||
- `action_delta_indices` extends to `[-1, 0, 1, ..., chunk_size-1]` (one extra leading timestep).
|
||||
- `DeriveStateFromActionStep` extracts `[action[t-1], action[t]]` as a 2-step state and strips the extra timestep from the action chunk.
|
||||
- `RelativeActionsProcessorStep` converts actions to offsets from `state[t]`.
|
||||
- `RelativeStateProcessorStep` converts the 2-step state to relative proprioception (velocity + zeros) and flattens.
|
||||
|
||||
This implies `use_relative_state=true` and `state_obs_steps=2`.
|
||||
|
||||
During **inference**, `DeriveStateFromActionStep` is a no-op — state comes from the robot via forward kinematics. `RelativeStateProcessorStep` buffers the previous state and applies the same conversion automatically.
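The training-time steps listed above can be traced on a toy window of actions. This is a plain-Python sketch for a single sample, not the actual processor code: the function names are hypothetical, the sizes are shrunk for readability, and excluded joints are ignored.

```python
def derive_state_and_actions(action_window):
    """Split chunk_size + 1 actions into a 2-step state and the action chunk.

    action_window[0] is action[t-1] and action_window[1] is action[t].
    """
    state = action_window[:2]    # [action[t-1], action[t]]
    actions = action_window[1:]  # drop the extra leading timestep
    return state, actions


def relative_actions(actions, current):
    """Express every action step as an offset from the current state."""
    return [[a - c for a, c in zip(step, current)] for step in actions]


def relative_state(state):
    """Offsets from the current step, flattened: [prev - current, zeros]."""
    current = state[-1]
    return [s - c for step in state for s, c in zip(step, current)]


chunk_size = 3  # toy value; the guide uses 30
window = [[t, 10 * t] for t in range(chunk_size + 1)]  # 4 steps, 2 dims each
state, actions = derive_state_and_actions(window)
print(relative_state(state))
print(relative_actions(actions, state[-1]))
```

The second half of the flattened state is always zero, so the only information it carries is the offset between the previous and the current step, which acts as a velocity signal.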
|
||||
|
||||
## Step 1: Recompute Stats
|
||||
|
||||
After preparing the dataset with EE pose in the action column, recompute statistics with `derive_state_from_action=true`. This computes relative action and state stats so the normalizer sees offset distributions:
|
||||
|
||||
```bash
|
||||
lerobot-edit-dataset \
|
||||
--repo-id=glannuzel/grabette-dataset \
|
||||
--operation=recompute_stats \
|
||||
--operation.relative_action=true \
|
||||
--operation.relative_exclude_joints='["proximal", "distal"]' \
|
||||
--operation.derive_state_from_action=true \
|
||||
--operation.chunk_size=30 \
|
||||
--push_to_hub=true
|
||||
```
|
||||
|
||||
| Flag | Purpose |
|
||||
| ------------------------------- | ------------------------------------------------------------------------------- |
|
||||
| `relative_action=true` | Compute stats on `action − state` (relative actions) |
|
||||
| `relative_exclude_joints` | Keep gripper dims absolute (they don't benefit from relative encoding) |
|
||||
| `derive_state_from_action=true` | Derive state from action column (implies `relative_state`, `state_obs_steps=2`) |
|
||||
| `chunk_size=30` | Must match training chunk size |
|
||||
|
||||
## Step 2: Train
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:${LD_LIBRARY_PATH:-}
|
||||
|
||||
DATASET="glannuzel/grabette-dataset"
|
||||
NUM_PROCESSES=8
|
||||
|
||||
echo "=== Training pi0 on $DATASET (UMI relative EE, ${NUM_PROCESSES} GPUs) ==="
|
||||
accelerate launch --multi_gpu --num_processes=$NUM_PROCESSES \
|
||||
-m lerobot.scripts.lerobot_train \
|
||||
--dataset.repo_id="$DATASET" \
|
||||
--dataset.video_backend=pyav \
|
||||
--policy.type=pi0 \
|
||||
--policy.pretrained_path=lerobot/pi0_base \
|
||||
--policy.repo_id=pepijn/grabette-umi-pi0 \
|
||||
--policy.chunk_size=30 \
|
||||
--policy.n_action_steps=30 \
|
||||
--policy.derive_state_from_action=true \
|
||||
--policy.use_relative_actions=true \
|
||||
--policy.relative_exclude_joints='["proximal", "distal"]' \
|
||||
--batch_size=32 \
|
||||
--steps=5000 \
|
||||
--policy.scheduler_decay_steps=5000 \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.compile_model=false \
|
||||
--policy.gradient_checkpointing=true \
|
||||
--policy.device=cuda \
|
||||
--output_dir=/fsx/pepijn/outputs/grabette-umi \
|
||||
--job_name=grabette-umi-v2 \
|
||||
--wandb.enable=true \
|
||||
--wandb.disable_artifact=true \
|
||||
--wandb.project=grabette-umi \
|
||||
--log_freq=100 \
|
||||
--save_freq=5000
|
||||
```
|
||||
|
||||
Key flags:
|
||||
|
||||
| Flag | Purpose |
|
||||
| ------------------------------- | ---------------------------------------------------------------------- |
|
||||
| `derive_state_from_action=true` | Derive proprioception from action column (full UMI mode) |
|
||||
| `use_relative_actions=true` | Actions are offsets from current state |
|
||||
| `relative_exclude_joints` | `["proximal", "distal"]` — gripper stays absolute, EE pose is relative |
|
||||
| `chunk_size=30` | Action horizon: 30 steps (~0.65s at 46 FPS) |
|
||||
| `n_action_steps=30` | Execute full chunk before replanning |
|
||||
|
||||
Note: `derive_state_from_action=true` automatically implies `use_relative_state=true` and `state_obs_steps=2`. No `rename_map` is needed since there are no separate observation columns to rename.
|
||||
|
||||
## Step 3: Evaluate
|
||||
|
||||
The evaluation script in `examples/umi_pi0_relative_ee/evaluate.py` runs inference on a real OpenArm robot:
|
||||
|
||||
```bash
|
||||
python examples/umi_pi0_relative_ee/evaluate.py
|
||||
```
|
||||
|
||||
Edit `HF_MODEL_ID`, camera index, and robot configuration at the top of the file.
|
||||
|
||||
### How inference works
|
||||
|
||||
The training dataset has no `observation.state` column because state was derived from actions. At inference, the evaluation script instead provides `observation.state` from the robot via forward kinematics:
|
||||
|
||||
1. **Robot → FK** — Arm joint positions → EE pose `[x,y,z,ax,ay,az]`, gripper → `[proximal, distal]`. Combined into `observation.state` (8D).
|
||||
2. **Preprocessor** (loaded from checkpoint) — `DeriveStateFromActionStep` is a no-op. `RelativeStateProcessorStep` buffers previous state, stacks `[prev, current]`, subtracts current → velocity info. `RelativeActionsProcessorStep` caches state. `NormalizerProcessorStep` normalizes.
|
||||
3. **pi0 inference** — Predicts normalized relative action chunk (30 steps).
|
||||
4. **Postprocessor** — `UnnormalizerProcessorStep` unnormalizes, `AbsoluteActionsProcessorStep` adds cached state → absolute EE targets.
|
||||
5. **IK → Robot** — Absolute `[x,y,z,ax,ay,az]` → arm joint targets with full 6-DOF IK (orientation weight = 1.0). `[proximal, distal]` → direct gripper position commands.
|
||||
|
||||
### Latency compensation
|
||||
|
||||
Set `LATENCY_SKIP_STEPS` to skip the first few predicted action steps, compensating for system latency:
|
||||
|
||||
```python
|
||||
LATENCY_SKIP_STEPS = 7 # ceil(total_latency_ms / (1000 / FPS))
|
||||
```
|
||||
|
||||
At 46 FPS (~22ms/step) with ~150ms total latency: `ceil(150/22) = 7`. Start with 0 for a safe first test.
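The formula from the comment above can be wrapped in a small function so the value follows from your own measurements. `latency_skip_steps` is a hypothetical helper, not part of the evaluation script.

```python
import math


def latency_skip_steps(total_latency_ms, fps):
    """Number of leading action steps to skip: ceil(latency / step duration)."""
    step_ms = 1000 / fps
    return math.ceil(total_latency_ms / step_ms)


print(latency_skip_steps(150, 46))
```

Measure the total latency (camera, inference, and command transport) on your own setup rather than reusing the 150 ms figure.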
|
||||
|
||||
## Replay Viewer
|
||||
|
||||
Visualize any dataset episode in a browser-based 3D viewer before running on hardware. The viewer shows the EE trajectory overlaid on the OpenArm URDF model.
|
||||
|
||||
### Quick start
|
||||
|
||||
```bash
|
||||
python examples/umi_pi0_relative_ee/replay.py
|
||||
```
|
||||
|
||||
### Options
|
||||
|
||||
| Flag | Default | Description |
|
||||
| ----------- | ---------------------------- | ------------------------------------ |
|
||||
| `--repo-id` | `glannuzel/grabette-dataset` | HuggingFace dataset repo to load |
|
||||
| `--episode` | `0` | Episode index to replay |
|
||||
| `--port` | `8765` | HTTP server port |
|
||||
| `--force` | off | Re-extract trajectory even if cached |
|
||||
|
||||
### Viewer controls

The panel in the top-left corner shows live EE coordinates and gripper state. Transport controls:

- **Play / Pause**: toggle automatic playback.
- **Step buttons** (◀ ▶): rewind or advance one frame.
- **Reset** (⟳): jump to frame 0.
- **Scrubber**: drag to seek.
- **Speed selector**: 0.25× to 4× playback speed.

### Color legend

| Color              | Meaning                                       |
| ------------------ | --------------------------------------------- |
| Red sphere         | Current EE position                           |
| Yellow trail       | Past trajectory                               |
| Dark trail         | Future trajectory                             |
| Orange ring + axes | URDF `ee_target` frame (zero-joint reference) |

## How the Pieces Fit Together

```
Training (derive_state_from_action=true):
  DataLoader loads action: [B, 31, 8]  (chunk_size=30 + 1 leading)
    → DeriveStateFromActionStep
        state  = action[:, :2, :]  → [B, 2, 8]
        action = action[:, 1:, :]  → [B, 30, 8]
    → RelativeActionsProcessorStep  (action -= state[:, -1, :])
    → RelativeStateProcessorStep    (state offsets from current, flatten → [B, 16])
    → NormalizerProcessorStep → pi0 model

Inference:
  arm joints → FK → observation.state [8D: x,y,z,ax,ay,az,prox,dist]
      ↓
  DeriveStateFromActionStep (no-op)
      ↓
  RelativeActionsProcessorStep (caches state)
      ↓
  RelativeStateProcessorStep (buffers prev, stacks, subtracts, flattens)
      ↓
  NormalizerProcessorStep → pi0 model → relative action chunk [30, 8]
      ↓
  UnnormalizerProcessorStep
      ↓
  AbsoluteActionsProcessorStep (+ cached state → absolute EE)
      ↓
  IK → joint targets → robot
```
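
The training-time shapes in the diagram can be checked with a small NumPy sketch. This is an illustration under stated assumptions, not the processor implementation: the function names `derive_state_and_relative` and `to_absolute` are hypothetical, the subtraction is applied element-wise to all 8 dimensions exactly as the diagram writes it, and normalization is left out.

```python
import numpy as np


def derive_state_and_relative(action_chunk: np.ndarray):
    """Mimic the training path of the diagram on a [B, 31, 8] action tensor.

    Hypothetical helper; the real processor steps may differ in detail.
    """
    state = action_chunk[:, :2, :]   # [B, 2, 8]: previous and current pose
    action = action_chunk[:, 1:, :]  # [B, 30, 8]: the chunk to predict
    current = state[:, -1, :]        # [B, 8]: reference pose

    rel_action = action - current[:, None, :]  # action -= state[:, -1, :]
    # State offsets from the current pose, flattened to [B, 16]
    rel_state = (state - current[:, None, :]).reshape(len(state), -1)
    return rel_action, rel_state, current


def to_absolute(rel_action: np.ndarray, current: np.ndarray) -> np.ndarray:
    """Inverse step used at inference: add the cached state back."""
    return rel_action + current[:, None, :]


rng = np.random.default_rng(0)
chunk = rng.normal(size=(4, 31, 8))  # B=4 random stand-in for dataset actions
rel_action, rel_state, current = derive_state_and_relative(chunk)
print(rel_action.shape, rel_state.shape)
```

Adding the cached reference pose back with `to_absolute` recovers the original absolute chunk, which is the round trip the inference path relies on.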

## References

- [UMI: Universal Manipulation Interface](https://umi-gripper.github.io) (Chi et al., 2024). Defines relative trajectory actions.
- [Action Representations](action_representations): LeRobot guide comparing absolute, relative, and delta actions.
- [pi0 documentation](pi0): Full pi0 configuration including `use_relative_actions`.
- [`examples/so100_to_so100_EE/`](https://github.com/huggingface/lerobot/tree/main/examples/so100_to_so100_EE): EE-space evaluation example this builds on.

@@ -78,7 +78,7 @@ def replay(cfg: ReplayConfig):
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
|
||||
actions = dataset.select_columns(ACTION)
|
||||
actions = dataset.hf_dataset.select_columns(ACTION)
|
||||
robot.connect()
|
||||
|
||||
try:
|
||||
|
||||
@@ -88,8 +88,9 @@ def main():
|
||||
# The previous metadata class is contained in the 'meta' attribute of the dataset:
|
||||
print(dataset.meta)
|
||||
|
||||
# You can inspect the dataset using its repr:
|
||||
print(dataset)
|
||||
# LeRobotDataset actually wraps an underlying Hugging Face dataset
|
||||
# (see https://huggingface.co/docs/datasets for more information).
|
||||
print(dataset.hf_dataset)
|
||||
|
||||
# LeRobot datasets also subclass PyTorch datasets, so you can do everything you know and love from working
|
||||
# with the latter, like iterating through the dataset.
|
||||
|
||||
717
examples/dataset/visualization_tools/action_consistency.py
Normal file
@@ -0,0 +1,717 @@
|
||||
"""
|
||||
Action consistency analysis for imitation learning datasets.
|
||||
|
||||
Two parallel analyses per dataset:
|
||||
1. State-based: KNN in joint-state space → action chunk variance
|
||||
2. Image-based: KNN in SigLIP embedding space → action chunk variance
|
||||
|
||||
Comparing them reveals whether visual similarity and proprioceptive similarity
|
||||
agree on where the data is inconsistent — and images are what the policy
|
||||
primarily sees.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from matplotlib.colors import LinearSegmentedColormap
|
||||
from PIL import Image
|
||||
from scipy.spatial import cKDTree
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
DATASETS = [
|
||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
|
||||
{"repo_id": "lerobot-data-collection/level12_rac_2_2026-02-08_1", "label": "Full collection"},
|
||||
]
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
MAX_FRAMES = 100_000
|
||||
K_NEIGHBORS = 50
|
||||
ACTION_CHUNK_SIZE = 30
|
||||
CAMERA_KEY = "observation.images.base"
|
||||
ENCODER_MODEL = "google/siglip-base-patch16-224"
|
||||
ENCODE_BATCH_SIZE = 512
|
||||
SEED = 42
|
||||
DPI = 150
|
||||
|
||||
CONSISTENCY_CMAP = LinearSegmentedColormap.from_list(
|
||||
"consistency", ["#0a2e0a", "#1a8e1a", "#88cc22", "#ffaa22", "#ff2222"]
|
||||
)
|
||||
|
||||
# FK chains from OpenArm bimanual URDF (same as workspace_density.py).
|
||||
LEFT_CHAIN = [
|
||||
((-np.pi / 2, 0, 0), (0, 0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((-np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, -1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
RIGHT_CHAIN = [
|
||||
((np.pi / 2, 0, 0), (0, -0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, 1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
|
||||
|
||||
# ── FK math ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def _rot_x(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
|
||||
|
||||
|
||||
def _rot_y(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
|
||||
|
||||
|
||||
def _rot_z(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
|
||||
|
||||
|
||||
def _tf(rpy: tuple, xyz: tuple) -> np.ndarray:
|
||||
r, p, y = rpy
|
||||
mat = np.eye(4)
|
||||
mat[:3, :3] = _rot_z(y) @ _rot_y(p) @ _rot_x(r)
|
||||
mat[:3, 3] = xyz
|
||||
return mat
|
||||
|
||||
|
||||
def _batch_axis_rot(axis: tuple, angles: np.ndarray) -> np.ndarray:
|
||||
n = len(angles)
|
||||
ax = np.asarray(axis, dtype=np.float64)
|
||||
ax = ax / np.linalg.norm(ax)
|
||||
x, y, z = ax
|
||||
c = np.cos(angles)
|
||||
s = np.sin(angles)
|
||||
t = 1 - c
|
||||
rot = np.zeros((n, 4, 4))
|
||||
rot[:, 0, 0] = t * x * x + c
|
||||
rot[:, 0, 1] = t * x * y - s * z
|
||||
rot[:, 0, 2] = t * x * z + s * y
|
||||
rot[:, 1, 0] = t * x * y + s * z
|
||||
rot[:, 1, 1] = t * y * y + c
|
||||
rot[:, 1, 2] = t * y * z - s * x
|
||||
rot[:, 2, 0] = t * x * z - s * y
|
||||
rot[:, 2, 1] = t * y * z + s * x
|
||||
rot[:, 2, 2] = t * z * z + c
|
||||
rot[:, 3, 3] = 1.0
|
||||
return rot
|
||||
|
||||
|
||||
def batch_fk(chain: list, joint_angles: np.ndarray) -> np.ndarray:
|
||||
n = joint_angles.shape[0]
|
||||
tf_batch = np.tile(np.eye(4), (n, 1, 1))
|
||||
qi = 0
|
||||
for rpy, xyz, axis in chain:
|
||||
tf_batch = tf_batch @ _tf(rpy, xyz)
|
||||
if axis is not None:
|
||||
rot = _batch_axis_rot(axis, joint_angles[:, qi])
|
||||
tf_batch = np.einsum("nij,njk->nik", tf_batch, rot)
|
||||
qi += 1
|
||||
return tf_batch[:, :3, 3]
|
||||
|
||||
|
||||
# ── Data helpers ────────────────────────────────────────
|
||||
|
||||
|
||||
def _flatten_names(obj: object) -> list[str]:
|
||||
if isinstance(obj, dict):
|
||||
out: list[str] = []
|
||||
for v in obj.values():
|
||||
out.extend(_flatten_names(v))
|
||||
return out
|
||||
if isinstance(obj, (list, tuple)):
|
||||
out = []
|
||||
for item in obj:
|
||||
if isinstance(item, (list, tuple, dict)):
|
||||
out.extend(_flatten_names(item))
|
||||
else:
|
||||
out.append(str(item))
|
||||
return out
|
||||
return [str(obj)]
|
||||
|
||||
|
||||
def _detect_and_convert(vals: np.ndarray) -> np.ndarray:
|
||||
mx = np.max(np.abs(vals))
|
||||
if mx > 360:
|
||||
print(f" Unit detection: servo ticks (max={mx:.0f})")
|
||||
return (vals - 2048) / 2048 * np.pi
|
||||
if mx > 6.3:
|
||||
print(f" Unit detection: degrees (max={mx:.1f})")
|
||||
return np.deg2rad(vals)
|
||||
print(f" Unit detection: radians (max={mx:.3f})")
|
||||
return vals.astype(np.float64)
|
||||
|
||||
|
||||
def _find_joint_indices(features: dict, state_col: str, n_dim: int) -> tuple[list[int], list[int]]:
|
||||
feat = features.get("observation.state", features.get(state_col, {}))
|
||||
names = _flatten_names(feat.get("names", []))
|
||||
left_idx: list[int] = []
|
||||
right_idx: list[int] = []
|
||||
if names and len(names) == n_dim:
|
||||
names_l = [n.lower() for n in names]
|
||||
print(f" Feature names: {names[:4]}…{names[-4:]}")
|
||||
for j in range(1, 8):
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"left_joint_{j}" in nm and i not in left_idx:
|
||||
left_idx.append(i)
|
||||
break
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"right_joint_{j}" in nm and i not in right_idx:
|
||||
right_idx.append(i)
|
||||
break
|
||||
if len(left_idx) == 7 and len(right_idx) == 7:
|
||||
print(f" Matched by name: left={left_idx} right={right_idx}")
|
||||
return left_idx, right_idx
|
||||
if n_dim >= 16:
|
||||
print(" Falling back to positional: [0:7]=left, [8:15]=right")
|
||||
return list(range(7)), list(range(8, 15))
|
||||
if n_dim >= 14:
|
||||
print(" Falling back to positional: [0:7]=left, [7:14]=right")
|
||||
return list(range(7)), list(range(7, 14))
|
||||
raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot")
|
||||
|
||||
|
||||
def download_data(repo_id: str, camera_key: str) -> Path:
|
||||
print(f" Downloading {repo_id} (parquet + {camera_key} videos) …")
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=[
|
||||
"meta/**",
|
||||
"data/**",
|
||||
f"videos/{camera_key}/**",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ── Data loading ────────────────────────────────────────
|
||||
|
||||
|
||||
def _build_action_chunks(
|
||||
actions: np.ndarray, episode_ids: np.ndarray, chunk_size: int
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
For each frame, concatenate the next chunk_size actions from the same episode.
|
||||
Returns (action_chunks, valid_mask).
|
||||
"""
|
||||
n = len(actions)
|
||||
act_dim = actions.shape[1]
|
||||
chunks = np.zeros((n, chunk_size * act_dim), dtype=np.float64)
|
||||
valid = np.zeros(n, dtype=bool)
|
||||
|
||||
for i in range(n):
|
||||
end = i + chunk_size
|
||||
if end > n:
|
||||
continue
|
||||
if episode_ids[i] != episode_ids[end - 1]:
|
||||
continue
|
||||
chunks[i] = actions[i:end].ravel()
|
||||
valid[i] = True
|
||||
|
||||
return chunks, valid
|
||||
|
||||
|
||||
def load_state_action_data(local: Path, max_frames: int, chunk_size: int, rng: np.random.Generator) -> dict:
|
||||
"""
|
||||
Load observation.state and action, build action chunks, subsample, normalize.
|
||||
Also returns the original row indices (`chosen_idx`) for video frame mapping.
|
||||
"""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
features = info.get("features", {})
|
||||
|
||||
dfs = [pd.read_parquet(pq) for pq in sorted((local / "data").glob("**/*.parquet"))]
|
||||
df = pd.concat(dfs, ignore_index=True)
|
||||
n_total = len(df)
|
||||
print(f" Total frames: {n_total:,}")
|
||||
|
||||
state_col = next((c for c in df.columns if "observation.state" in c), None)
|
||||
action_col = next((c for c in df.columns if c == "action"), None)
|
||||
if state_col is None:
|
||||
raise RuntimeError(f"No observation.state column. Available: {list(df.columns)}")
|
||||
if action_col is None:
|
||||
raise RuntimeError(f"No action column. Available: {list(df.columns)}")
|
||||
|
||||
ep_col = next((c for c in df.columns if c == "episode_index"), None)
|
||||
if ep_col is None:
|
||||
raise RuntimeError(f"No episode_index column. Available: {list(df.columns)}")
|
||||
|
||||
state_all = np.stack(df[state_col].values).astype(np.float64)
|
||||
action_all = np.stack(df[action_col].values).astype(np.float64)
|
||||
episode_all = df[ep_col].values.astype(np.int64)
|
||||
|
||||
n_dim = state_all.shape[1]
|
||||
act_dim = action_all.shape[1]
|
||||
print(f" State dim: {n_dim} Action dim: {act_dim} Chunk size: {chunk_size}")
|
||||
print(f" Action chunk dim: {chunk_size * act_dim}")
|
||||
|
||||
left_idx, right_idx = _find_joint_indices(features, state_col, n_dim)
|
||||
|
||||
print(" Building action chunks …")
|
||||
action_chunks, valid = _build_action_chunks(action_all, episode_all, chunk_size)
|
||||
valid_idx = np.where(valid)[0]
|
||||
print(f" Valid frames (with full action chunk): {len(valid_idx):,} / {n_total:,}")
|
||||
|
||||
if len(valid_idx) > max_frames:
|
||||
chosen = np.sort(rng.choice(valid_idx, max_frames, replace=False))
|
||||
else:
|
||||
chosen = valid_idx
|
||||
print(f" Using {len(chosen):,} frames")
|
||||
|
||||
state_raw = state_all[chosen]
|
||||
action_raw = action_chunks[chosen]
|
||||
episode_ids = episode_all[chosen]
|
||||
|
||||
state_mean = state_raw.mean(axis=0)
|
||||
state_std = state_raw.std(axis=0)
|
||||
state_std[state_std < 1e-8] = 1.0
|
||||
state_norm = (state_raw - state_mean) / state_std
|
||||
|
||||
action_mean = action_raw.mean(axis=0)
|
||||
action_std = action_raw.std(axis=0)
|
||||
action_std[action_std < 1e-8] = 1.0
|
||||
action_norm = (action_raw - action_mean) / action_std
|
||||
|
||||
return {
|
||||
"state_raw": state_raw,
|
||||
"state_norm": state_norm,
|
||||
"action_raw": action_raw,
|
||||
"action_norm": action_norm,
|
||||
"episode_ids": episode_ids,
|
||||
"episode_all": episode_all,
|
||||
"left_joint_idx": left_idx,
|
||||
"right_joint_idx": right_idx,
|
||||
"n_total": n_total,
|
||||
"chosen_idx": chosen,
|
||||
"df": df,
|
||||
}
|
||||
|
||||
|
||||
# ── Video → frame extraction ──────────────────────────────
|
||||
|
||||
|
||||
def build_video_lookup(local: Path, camera_key: str) -> dict:
|
||||
"""
|
||||
Build a mapping from episode_index → {video_path, fps, from_ts}.
|
||||
"""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
fps = info["fps"]
|
||||
video_template = info.get(
|
||||
"video_path",
|
||||
"videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4",
|
||||
)
|
||||
|
||||
ep_rows = []
|
||||
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
||||
ep_rows.append(pd.read_parquet(pq))
|
||||
ep_df = pd.concat(ep_rows, ignore_index=True)
|
||||
|
||||
chunk_col = f"videos/{camera_key}/chunk_index"
|
||||
file_col = f"videos/{camera_key}/file_index"
|
||||
ts_from = f"videos/{camera_key}/from_timestamp"
|
||||
if chunk_col not in ep_df.columns:
|
||||
chunk_col = f"{camera_key}/chunk_index"
|
||||
file_col = f"{camera_key}/file_index"
|
||||
ts_from = f"{camera_key}/from_timestamp"
|
||||
|
||||
lookup: dict[int, dict] = {}
|
||||
for _, row in ep_df.iterrows():
|
||||
ci = int(row[chunk_col])
|
||||
fi = int(row[file_col])
|
||||
video_rel = video_template.format(video_key=camera_key, chunk_index=ci, file_index=fi)
|
||||
lookup[int(row["episode_index"])] = {
|
||||
"video_path": local / video_rel,
|
||||
"from_ts": float(row[ts_from]),
|
||||
"fps": fps,
|
||||
}
|
||||
return lookup
|
||||
|
||||
|
||||
def _decode_video_frames(video_path: str) -> list[np.ndarray]:
|
||||
"""Decode all frames from a video file using PyAV. Returns list of RGB arrays."""
|
||||
container = av.open(video_path)
|
||||
stream = container.streams.video[0]
|
||||
stream.thread_type = "AUTO"
|
||||
decoded = []
|
||||
for frame in container.decode(stream):
|
||||
decoded.append(frame.to_ndarray(format="rgb24"))
|
||||
container.close()
|
||||
return decoded
|
||||
|
||||
|
||||
def extract_frames(
    chosen_idx: np.ndarray,
    episode_all: np.ndarray,
    video_lookup: dict,
) -> list[np.ndarray | None]:
    """
    Extract RGB frames for each chosen global index using PyAV.
    Returns list of (H, W, 3) RGB arrays (or None on failure).
    """
    unique_eps = np.unique(episode_all)
    ep_start: dict[int, int] = {}
    for ep in unique_eps:
        ep_start[int(ep)] = int(np.where(episode_all == ep)[0][0])

    # Build jobs: (output_index, video_path, frame_number_within_video_file)
    jobs: list[tuple[int, str, int]] = []
    for out_i, global_i in enumerate(chosen_idx):
        ep = int(episode_all[global_i])
        info = video_lookup.get(ep)
        if info is None:
            continue
        # A video file may hold several episodes back to back, so shift the
        # episode-relative frame by the episode's start time in that file.
        ep_offset = int(round(info["from_ts"] * info["fps"]))
        local_frame = int(global_i) - ep_start[ep] + ep_offset
        jobs.append((out_i, str(info["video_path"]), local_frame))

    # Group by video file, decode each video once
    from collections import defaultdict

    video_jobs: dict[str, list[tuple[int, int]]] = defaultdict(list)
    for out_i, vpath, local_frame in jobs:
        video_jobs[vpath].append((out_i, local_frame))

    frames: list[np.ndarray | None] = [None] * len(chosen_idx)
    extracted = 0
    n_videos = len(video_jobs)
    for vi, (vpath, frame_requests) in enumerate(video_jobs.items()):
        if not Path(vpath).exists():
            continue
        try:
            decoded = _decode_video_frames(vpath)
        except Exception as exc:
            print(f" Warning: failed to decode {Path(vpath).name}: {exc}")
            continue
        for out_i, local_frame in frame_requests:
            if 0 <= local_frame < len(decoded):
                frames[out_i] = decoded[local_frame]
                extracted += 1
        if (vi + 1) % 50 == 0 or (vi + 1) == n_videos:
            print(f" Decoded {vi + 1}/{n_videos} videos ({extracted:,} frames so far)")
        del decoded  # release the decoded frames before the next video

    print(f" Extracted {extracted:,} / {len(chosen_idx):,} frames from video")
    return frames
|
||||
|
||||
|
||||
# ── SigLIP encoding ─────────────────────────────────────
|
||||
|
||||
|
||||
def encode_frames_siglip(
|
||||
frames: list[np.ndarray | None],
|
||||
model_name: str,
|
||||
batch_size: int,
|
||||
device: torch.device,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Encode RGB frames through SigLIP vision encoder.
|
||||
Returns (N, embed_dim) float32 array. Frames that are None get a zero vector.
|
||||
"""
|
||||
print(f" Loading SigLIP model: {model_name} …")
|
||||
processor = AutoImageProcessor.from_pretrained(model_name)
|
||||
model = AutoModel.from_pretrained(model_name).to(device).eval()
|
||||
embed_dim = model.config.vision_config.hidden_size
|
||||
|
||||
n = len(frames)
|
||||
embeddings = np.zeros((n, embed_dim), dtype=np.float32)
|
||||
|
||||
valid_indices = [i for i, f in enumerate(frames) if f is not None]
|
||||
print(f" Encoding {len(valid_indices):,} valid frames in batches of {batch_size} …")
|
||||
|
||||
for batch_start in range(0, len(valid_indices), batch_size):
|
||||
batch_idx = valid_indices[batch_start : batch_start + batch_size]
|
||||
pil_images = [Image.fromarray(frames[i]) for i in batch_idx]
|
||||
|
||||
inputs = processor(images=pil_images, return_tensors="pt").to(device)
|
||||
with torch.no_grad():
|
||||
image_features = model.get_image_features(**inputs)
|
||||
image_features = torch.nn.functional.normalize(image_features, dim=-1)
|
||||
embeddings[batch_idx] = image_features.cpu().numpy()
|
||||
|
||||
done = min(batch_start + batch_size, len(valid_indices))
|
||||
if done % (batch_size * 10) == 0 or done == len(valid_indices):
|
||||
print(f" {done:,} / {len(valid_indices):,} encoded")
|
||||
|
||||
del model, processor
|
||||
torch.cuda.empty_cache()
|
||||
return embeddings
|
||||
|
||||
|
||||
# ── KNN consistency ─────────────────────────────────────
|
||||
|
||||
|
||||
def compute_consistency(
|
||||
features: np.ndarray,
|
||||
action_norm: np.ndarray,
|
||||
episode_ids: np.ndarray,
|
||||
k: int,
|
||||
label: str = "",
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
For each frame, find K nearest neighbors in feature space from other episodes.
|
||||
Return per-frame action variance (mean across action dims).
|
||||
"""
|
||||
n = len(features)
|
||||
print(f" Building KD-tree on {n:,} vectors ({label}) …")
|
||||
tree = cKDTree(features)
|
||||
|
||||
k_query = min(k * 3, n - 1)
|
||||
print(f" Querying {k_query} neighbors per frame …")
|
||||
_dists, indices = tree.query(features, k=k_query + 1)
|
||||
indices = indices[:, 1:]
|
||||
|
||||
print(f" Computing cross-episode action variance ({label}) …")
|
||||
variance = np.zeros(n)
|
||||
for i in range(n):
|
||||
ep_i = episode_ids[i]
|
||||
neighbors = indices[i]
|
||||
cross_ep = neighbors[episode_ids[neighbors] != ep_i][:k]
|
||||
if len(cross_ep) < 2:
|
||||
variance[i] = 0.0
|
||||
continue
|
||||
neighbor_actions = action_norm[cross_ep]
|
||||
variance[i] = np.mean(np.var(neighbor_actions, axis=0))
|
||||
|
||||
return variance
|
||||
|
||||
|
||||
# ── Visualization ───────────────────────────────────────
|
||||
|
||||
|
||||
def _style_ax(ax: plt.Axes) -> None:
|
||||
ax.set_facecolor("#0d1117")
|
||||
ax.tick_params(colors="#555", labelsize=8)
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color("#333")
|
||||
|
||||
|
||||
def _plot_histogram(ax: plt.Axes, variance: np.ndarray, title: str, color: str) -> None:
|
||||
_style_ax(ax)
|
||||
median_var = np.median(variance)
|
||||
mean_var = np.mean(variance)
|
||||
nonzero = variance[variance > 0]
|
||||
if len(nonzero) > 0:
|
||||
bins = np.logspace(np.log10(nonzero.min().clip(1e-6)), np.log10(nonzero.max()), 60)
|
||||
ax.hist(nonzero, bins=bins, color=color, alpha=0.8, edgecolor="#222")
|
||||
ax.set_xscale("log")
|
||||
ax.axvline(median_var, color="#ff6600", linewidth=2, label=f"median={median_var:.3f}")
|
||||
ax.axvline(mean_var, color="#ff2222", linewidth=2, linestyle="--", label=f"mean={mean_var:.3f}")
|
||||
ax.set_xlabel("Action variance (log scale)", color="#888", fontsize=10)
|
||||
ax.set_ylabel("Frame count", color="#888", fontsize=10)
|
||||
ax.set_title(title, color="white", fontsize=11, pad=10)
|
||||
ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
|
||||
|
||||
|
||||
def _plot_episode_curves(
|
||||
ax: plt.Axes,
|
||||
var_state: np.ndarray,
|
||||
var_image: np.ndarray,
|
||||
episode_ids: np.ndarray,
|
||||
title: str,
|
||||
) -> None:
|
||||
_style_ax(ax)
|
||||
unique_eps = np.unique(episode_ids)
|
||||
|
||||
ep_means_s = np.array([var_state[episode_ids == ep].mean() for ep in unique_eps])
|
||||
ep_means_i = np.array([var_image[episode_ids == ep].mean() for ep in unique_eps])
|
||||
|
||||
sorted_s = np.sort(ep_means_s)[::-1]
|
||||
sorted_i = np.sort(ep_means_i)[::-1]
|
||||
ep_x = np.arange(len(unique_eps))
|
||||
|
||||
ax.fill_between(ep_x, sorted_s, alpha=0.2, color="#4363d8")
|
||||
ax.plot(ep_x, sorted_s, color="#4363d8", linewidth=1.2, label=f"State (med={np.median(ep_means_s):.3f})")
|
||||
ax.fill_between(ep_x, sorted_i, alpha=0.2, color="#e6194b")
|
||||
ax.plot(ep_x, sorted_i, color="#e6194b", linewidth=1.2, label=f"Image (med={np.median(ep_means_i):.3f})")
|
||||
|
||||
ax.set_xlabel("Episode rank (worst → best)", color="#888", fontsize=10)
|
||||
ax.set_ylabel("Mean action variance", color="#888", fontsize=10)
|
||||
ax.set_title(title, color="white", fontsize=11, pad=10)
|
||||
ax.legend(fontsize=8, facecolor="#1a1a2e", edgecolor="#333", labelcolor="white")
|
||||
|
||||
|
||||
def _plot_heatmap(
|
||||
ax: plt.Axes, fig: plt.Figure, tcp_xz: np.ndarray, variance: np.ndarray, title: str
|
||||
) -> None:
|
||||
_style_ax(ax)
|
||||
order = np.argsort(variance)
|
||||
pts = tcp_xz[order]
|
||||
var_sorted = variance[order]
|
||||
vmin = np.percentile(variance[variance > 0], 5) if np.any(variance > 0) else 0
|
||||
vmax = np.percentile(variance[variance > 0], 95) if np.any(variance > 0) else 1
|
||||
sc = ax.scatter(
|
||||
pts[:, 0],
|
||||
pts[:, 1],
|
||||
c=var_sorted,
|
||||
cmap=CONSISTENCY_CMAP,
|
||||
s=0.5,
|
||||
alpha=0.6,
|
||||
vmin=vmin,
|
||||
vmax=vmax,
|
||||
rasterized=True,
|
||||
)
|
||||
ax.set_xlabel("X (m)", color="#888", fontsize=10)
|
||||
ax.set_ylabel("Z (m)", color="#888", fontsize=10)
|
||||
ax.set_title(title, color="white", fontsize=11, pad=10)
|
||||
ax.set_aspect("equal")
|
||||
cbar = fig.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
|
||||
cbar.set_label("Action variance", color="white", fontsize=9)
|
||||
cbar.ax.tick_params(colors="#aaa", labelsize=7)
|
||||
|
||||
|
||||
def render(results: list[dict], out_path: Path) -> None:
|
||||
"""
|
||||
4-row x N-column figure:
|
||||
Row 0: State-based variance histogram
|
||||
Row 1: Image-based variance histogram
|
||||
Row 2: Per-episode curves (both overlaid)
|
||||
Row 3: Spatial heatmap (image-based variance)
|
||||
"""
|
||||
n_ds = len(results)
|
||||
fig, axes = plt.subplots(4, n_ds, figsize=(9 * n_ds, 24), facecolor="#0d1117")
|
||||
if n_ds == 1:
|
||||
axes = axes[:, np.newaxis]
|
||||
|
||||
headline_parts = []
|
||||
for col, r in enumerate(results):
|
||||
label = r["label"]
|
||||
var_s = r["var_state"]
|
||||
var_i = r["var_image"]
|
||||
tcp_xz = r["tcp_xz"]
|
||||
episode_ids = r["episode_ids"]
|
||||
|
||||
med_s = np.median(var_s)
|
||||
med_i = np.median(var_i)
|
||||
headline_parts.append(f"{label}: state={med_s:.3f}, image={med_i:.3f}")
|
||||
|
||||
_plot_histogram(axes[0, col], var_s, f"{label}\nState-based variance (K={K_NEIGHBORS})", "#4363d8")
|
||||
_plot_histogram(
|
||||
axes[1, col], var_i, f"{label}\nImage-based variance (SigLIP, K={K_NEIGHBORS})", "#e6194b"
|
||||
)
|
||||
_plot_episode_curves(
|
||||
axes[2, col],
|
||||
var_s,
|
||||
var_i,
|
||||
episode_ids,
|
||||
f"{label}\nPer-episode inconsistency ({len(np.unique(episode_ids)):,} episodes)",
|
||||
)
|
||||
_plot_heatmap(
|
||||
axes[3, col],
|
||||
fig,
|
||||
tcp_xz,
|
||||
var_i,
|
||||
f"{label}\nImage-based variance by TCP position (XZ)",
|
||||
)
|
||||
|
||||
fig.suptitle(
|
||||
f"Action Consistency: State vs Image (chunk={ACTION_CHUNK_SIZE}, K={K_NEIGHBORS})\n"
|
||||
+ " | ".join(headline_parts),
|
||||
color="white",
|
||||
fontsize=15,
|
||||
y=0.99,
|
||||
)
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
||||
plt.savefig(out_path, dpi=DPI, bbox_inches="tight", facecolor=fig.get_facecolor())
|
||||
plt.close()
|
||||
print(f"\n✓ Saved: {out_path}")
|
||||
|
||||
|
||||
# ── Main ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Device: {device}")
|
||||
rng = np.random.default_rng(SEED)
|
||||
results = []
|
||||
|
||||
for ds in DATASETS:
|
||||
repo_id, label = ds["repo_id"], ds["label"]
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f" {label}: {repo_id}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
local = download_data(repo_id, CAMERA_KEY)
|
||||
data = load_state_action_data(local, MAX_FRAMES, ACTION_CHUNK_SIZE, rng)
|
||||
|
||||
# --- State-based KNN ---
|
||||
var_state = compute_consistency(
|
||||
data["state_norm"], data["action_norm"], data["episode_ids"], K_NEIGHBORS, "state"
|
||||
)
|
||||
print(
|
||||
f" State variance: median={np.median(var_state):.4f} "
|
||||
f"mean={np.mean(var_state):.4f} p90={np.percentile(var_state, 90):.4f}"
|
||||
)
|
||||
|
||||
# --- Image-based KNN ---
|
||||
print("\n Preparing image embeddings …")
|
||||
video_lookup = build_video_lookup(local, CAMERA_KEY)
|
||||
frames = extract_frames(data["chosen_idx"], data["episode_all"], video_lookup)
|
||||
embeddings = encode_frames_siglip(frames, ENCODER_MODEL, ENCODE_BATCH_SIZE, device)
|
||||
del frames # free memory
|
||||
|
||||
var_image = compute_consistency(
|
||||
embeddings, data["action_norm"], data["episode_ids"], K_NEIGHBORS, "image"
|
||||
)
|
||||
print(
|
||||
f" Image variance: median={np.median(var_image):.4f} "
|
||||
f"mean={np.mean(var_image):.4f} p90={np.percentile(var_image, 90):.4f}"
|
||||
)
|
||||
|
||||
# FK for spatial heatmap
|
||||
print(" Computing FK for spatial heatmap …")
|
||||
left_raw = data["state_raw"][:, data["left_joint_idx"]]
|
||||
left_rad = _detect_and_convert(left_raw)
|
||||
left_tcp = batch_fk(LEFT_CHAIN, left_rad)
|
||||
tcp_xz = left_tcp[:, [0, 2]]
|
||||
|
||||
results.append(
|
||||
{
|
||||
"label": label,
|
||||
"var_state": var_state,
|
||||
"var_image": var_image,
|
||||
"episode_ids": data["episode_ids"],
|
||||
"tcp_xz": tcp_xz,
|
||||
"n_total": data["n_total"],
|
||||
}
|
||||
)
|
||||
|
||||
out = OUTPUT_DIR / "action_consistency_comparison.jpg"
|
||||
render(results, out)
|
||||
|
||||
# Save worst-episodes summary (image-based, since that's the stronger signal)
|
||||
worst_summary = {}
|
||||
for r in results:
|
||||
unique_eps = np.unique(r["episode_ids"])
|
||||
ep_means = {int(ep): float(r["var_image"][r["episode_ids"] == ep].mean()) for ep in unique_eps}
|
||||
ranked = sorted(ep_means.items(), key=lambda x: x[1], reverse=True)[:50]
|
||||
worst_summary[r["label"]] = [{"episode": ep, "mean_variance": v} for ep, v in ranked]
|
||||
worst_path = OUTPUT_DIR / "action_consistency_worst_episodes.json"
|
||||
worst_path.write_text(json.dumps(worst_summary, indent=2))
|
||||
print(f"✓ Saved worst episodes: {worst_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
178
examples/dataset/visualization_tools/create_frame_grid.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
Create a JPG grid of random frames sampled from a LeRobot video dataset.
|
||||
Downloads metadata + video chunks from HuggingFace, picks random frames,
|
||||
decodes them, and tiles into a single image.
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
REPO_ID = "lerobot-data-collection/level2_final_quality3"
|
||||
CAMERA_KEY = "observation.images.base"
|
||||
GRID_COLS = 15
|
||||
GRID_ROWS = 10
|
||||
THUMB_WIDTH = 160
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
SEED = 1
|
||||
|
||||
|
||||
def download_metadata(repo_id: str) -> Path:
|
||||
"""Download only metadata (no videos yet)."""
|
||||
print(f"[1/3] Downloading metadata for {repo_id} …")
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**"],
|
||||
ignore_patterns=["*.mp4"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def load_video_info(local: Path) -> tuple[str, list[dict], int]:
|
||||
"""Parse info.json and episode parquets. Returns (camera_key, episode_rows, fps)."""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
fps = info["fps"]
|
||||
features = info["features"]
|
||||
|
||||
video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
|
||||
if not video_keys:
|
||||
raise RuntimeError("No video keys found in dataset features")
|
||||
|
||||
if CAMERA_KEY is not None:
|
||||
if CAMERA_KEY not in video_keys:
|
||||
raise RuntimeError(f"CAMERA_KEY='{CAMERA_KEY}' not found. Available: {video_keys}")
|
||||
cam = CAMERA_KEY
|
||||
else:
|
||||
cam = video_keys[0]
|
||||
print(f" camera='{cam}' all_cams={video_keys} fps={fps}")
|
||||
|
||||
ep_rows = []
|
||||
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
||||
ep_rows.append(pd.read_parquet(pq))
|
||||
ep_df = pd.concat(ep_rows, ignore_index=True)
|
||||
|
||||
video_template = info.get(
|
||||
"video_path",
|
||||
"videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4",
|
||||
)
|
||||
|
||||
chunk_col = f"videos/{cam}/chunk_index"
|
||||
file_col = f"videos/{cam}/file_index"
|
||||
ts_from = f"videos/{cam}/from_timestamp"
|
||||
ts_to = f"videos/{cam}/to_timestamp"
|
||||
if chunk_col not in ep_df.columns:
|
||||
chunk_col = f"{cam}/chunk_index"
|
||||
file_col = f"{cam}/file_index"
|
||||
ts_from = f"{cam}/from_timestamp"
|
||||
ts_to = f"{cam}/to_timestamp"
|
||||
|
||||
episodes = []
|
||||
for _, row in ep_df.iterrows():
|
||||
ci = int(row[chunk_col])
|
||||
fi = int(row[file_col])
|
||||
episodes.append(
|
||||
{
|
||||
"episode_index": int(row["episode_index"]),
|
||||
"chunk_index": ci,
|
||||
"file_index": fi,
|
||||
"from_ts": float(row[ts_from]),
|
||||
"to_ts": float(row[ts_to]),
|
||||
"video_rel": video_template.format(video_key=cam, chunk_index=ci, file_index=fi),
|
||||
}
|
||||
)
|
||||
return cam, episodes, fps
|
||||
|
||||
|
||||
def pick_random_frames(episodes: list[dict], fps: int, n: int, rng: random.Random) -> list[dict]:
|
||||
"""Pick n random (episode, timestamp) pairs, return sorted by video file for efficient access."""
|
||||
picks = []
|
||||
for _ in range(n):
|
||||
ep = rng.choice(episodes)
|
||||
duration = ep["to_ts"] - ep["from_ts"]
|
||||
if duration <= 0:
|
||||
continue
|
||||
t = ep["from_ts"] + rng.random() * duration
|
||||
picks.append({**ep, "seek_ts": t})
|
||||
picks.sort(key=lambda p: (p["video_rel"], p["seek_ts"]))
|
||||
return picks
|
||||
|
||||
|
||||
def download_video_files(repo_id: str, local: Path, picks: list[dict]) -> None:
|
||||
"""Download only the video files we need."""
|
||||
needed = sorted({p["video_rel"] for p in picks})
|
||||
print(f"[2/3] Downloading {len(needed)} video file(s) …")
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
local_dir=str(local),
|
||||
allow_patterns=needed,
|
||||
)
|
||||
|
||||
|
||||
def extract_frame(video_path: Path, seek_ts: float) -> np.ndarray | None:
|
||||
"""Decode a single frame at the given timestamp."""
|
||||
cap = cv2.VideoCapture(str(video_path))
|
||||
cap.set(cv2.CAP_PROP_POS_MSEC, seek_ts * 1000.0)
|
||||
ret, frame = cap.read()
|
||||
cap.release()
|
||||
return frame if ret else None
|
||||
|
||||
|
||||
def build_grid(frames: list[np.ndarray], cols: int, thumb_w: int) -> np.ndarray:
|
||||
"""Resize frames to uniform thumbnails and tile into a grid."""
|
||||
if not frames:
|
||||
raise RuntimeError("No frames decoded")
|
||||
|
||||
h0, w0 = frames[0].shape[:2]
|
||||
thumb_h = int(thumb_w * h0 / w0)
|
||||
|
||||
thumbs = [cv2.resize(f, (thumb_w, thumb_h), interpolation=cv2.INTER_AREA) for f in frames]
|
||||
|
||||
rows = []
|
||||
for i in range(0, len(thumbs), cols):
|
||||
row_thumbs = thumbs[i : i + cols]
|
||||
while len(row_thumbs) < cols:
|
||||
row_thumbs.append(np.zeros_like(row_thumbs[0]))
|
||||
rows.append(np.hstack(row_thumbs))
|
||||
return np.vstack(rows)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
rng = random.Random(SEED)
|
||||
n_frames = GRID_COLS * GRID_ROWS
|
||||
|
||||
local = download_metadata(REPO_ID)
|
||||
cam, episodes, fps = load_video_info(local)
|
||||
picks = pick_random_frames(episodes, fps, n_frames, rng)
|
||||
download_video_files(REPO_ID, local, picks)
|
||||
|
||||
print(f"[3/3] Decoding {n_frames} frames …")
|
||||
frames: list[np.ndarray] = []
|
||||
for p in picks:
|
||||
vp = local / p["video_rel"]
|
||||
if not vp.exists():
|
||||
print(f" SKIP: {p['video_rel']} not found")
|
||||
continue
|
||||
frame = extract_frame(vp, p["seek_ts"])
|
||||
if frame is not None:
|
||||
frames.append(frame)
|
||||
|
||||
print(f" Decoded {len(frames)}/{n_frames} frames")
|
||||
grid = build_grid(frames, GRID_COLS, THUMB_WIDTH)
|
||||
|
||||
safe_name = REPO_ID.replace("/", "_")
|
||||
out_path = OUTPUT_DIR / f"{safe_name}_grid_{GRID_COLS}x{GRID_ROWS}.jpg"
|
||||
cv2.imwrite(str(out_path), grid, [cv2.IMWRITE_JPEG_QUALITY, 92])
|
||||
print(f"\n✓ Saved: {out_path} ({grid.shape[1]}×{grid.shape[0]})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
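The tiling step in `build_grid` above can be checked without downloading any video. The following is a minimal sketch, not part of the script: it reproduces the pad-and-stack logic with NumPy only, and `tile_grid` is a hypothetical name introduced for illustration. It assumes all thumbnails already share one shape, which `build_grid` guarantees by resizing first.

```python
import numpy as np


def tile_grid(thumbs: list[np.ndarray], cols: int) -> np.ndarray:
    """Tile equally sized HxWx3 thumbnails row by row (hypothetical helper).

    The last row is padded with black tiles so every row has `cols` entries,
    mirroring the while-loop padding in build_grid.
    """
    if not thumbs:
        raise ValueError("no thumbnails to tile")
    rows = []
    for i in range(0, len(thumbs), cols):
        row = list(thumbs[i : i + cols])
        # Pad an incomplete final row with black tiles of the same shape.
        row += [np.zeros_like(thumbs[0])] * (cols - len(row))
        rows.append(np.hstack(row))
    return np.vstack(rows)


# Five 4x6 thumbnails, each filled with its own index, tiled three per row.
thumbs = [np.full((4, 6, 3), i, dtype=np.uint8) for i in range(5)]
grid = tile_grid(thumbs, cols=3)
print(grid.shape)
```

Five thumbnails in three columns give two rows, with the last cell of the second row left black.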
526
examples/dataset/visualization_tools/create_progress_videos.py
Normal file
@@ -0,0 +1,526 @@
|
||||
"""
|
||||
Create GIF animations with a sarm_progress overlay for specified episodes (an intermediate MP4 is rendered, then converted and deleted).
|
||||
Downloads datasets from HuggingFace, extracts episode video + progress data,
|
||||
and draws the progress line directly on each frame (no panel, no axes).
|
||||
"""
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
DATASETS = [
|
||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "episode": 250},
|
||||
]
|
||||
CAMERA_KEY = (
|
||||
"observation.images.base" # None = auto-select first camera, or set e.g. "observation.images.top"
|
||||
)
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
# Progress line spans the full video height
|
||||
GRAPH_Y_TOP_FRAC = 0.01
|
||||
GRAPH_Y_BOT_FRAC = 0.99
|
||||
LINE_THICKNESS = 3
|
||||
SHADOW_THICKNESS = 6 # white edge thickness
|
||||
REF_ALPHA = 0.45 # opacity of the 1.0 reference line
|
||||
FILL_ALPHA = 0.55 # opacity of the grey fill under the line
|
||||
SCORE_FONT_SCALE = 0.8
|
||||
TASK_FONT_SCALE = 0.55
|
||||
|
||||
|
||||
def download_episode(repo_id: str, episode: int) -> Path:
|
||||
"""Download only the files needed for this episode."""
|
||||
# We need: meta/, sarm_progress.parquet, and the relevant video/data chunks.
|
||||
# We'll download meta + sarm first, then figure out chunks.
|
||||
print(f"\n[1/5] Downloading metadata for {repo_id} …")
|
||||
local = Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**", "sarm_progress.parquet"],
|
||||
ignore_patterns=["*.mp4"],
|
||||
)
|
||||
)
|
||||
return local
|
||||
|
||||
|
||||
def load_episode_meta(local: Path, episode: int) -> dict:
|
||||
"""Read info.json + episode-level parquet to get fps, video paths, timestamps."""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
fps = info["fps"]
|
||||
features = info["features"]
|
||||
|
||||
# Find video keys (keys whose dtype=="video")
|
||||
video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
|
||||
if not video_keys:
|
||||
raise RuntimeError("No video keys found in dataset features")
|
||||
if CAMERA_KEY is not None:
|
||||
if CAMERA_KEY not in video_keys:
|
||||
raise RuntimeError(f"CAMERA_KEY='{CAMERA_KEY}' not found. Available: {video_keys}")
|
||||
first_cam = CAMERA_KEY
|
||||
else:
|
||||
first_cam = video_keys[0]
|
||||
print(f" fps={fps} camera='{first_cam}' all_cams={video_keys}")
|
||||
|
||||
# Load all episode-meta parquet files and find our episode
|
||||
ep_rows = []
|
||||
for pq in sorted((local / "meta" / "episodes").glob("**/*.parquet")):
|
||||
df = pd.read_parquet(pq)
|
||||
ep_rows.append(df)
|
||||
ep_df = pd.concat(ep_rows, ignore_index=True)
|
||||
row = ep_df[ep_df["episode_index"] == episode]
|
||||
if row.empty:
|
||||
raise RuntimeError(f"Episode {episode} not found in episode metadata")
|
||||
row = row.iloc[0]
|
||||
|
||||
# Extract video chunk/file index for first camera
|
||||
# Try the column names with and without the 'videos/' prefix
|
||||
chunk_col = f"videos/{first_cam}/chunk_index"
|
||||
file_col = f"videos/{first_cam}/file_index"
|
||||
ts_col = f"videos/{first_cam}/from_timestamp"
|
||||
to_col = f"videos/{first_cam}/to_timestamp"
|
||||
|
||||
# Some datasets use different column naming
|
||||
if chunk_col not in row.index:
|
||||
# Try without the 'videos/' prefix
|
||||
chunk_col = f"{first_cam}/chunk_index"
|
||||
file_col = f"{first_cam}/file_index"
|
||||
ts_col = f"{first_cam}/from_timestamp"
|
||||
to_col = f"{first_cam}/to_timestamp"
|
||||
if chunk_col not in row.index:
|
||||
raise RuntimeError(
|
||||
f"Cannot find video metadata columns for {first_cam}.\nAvailable: {list(row.index)}"
|
||||
)
|
||||
|
||||
chunk_idx = int(row[chunk_col])
|
||||
file_idx = int(row[file_col])
|
||||
from_ts = float(row[ts_col])
|
||||
to_ts = float(row[to_col])
|
||||
|
||||
video_template = info.get(
|
||||
"video_path", "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4"
|
||||
)
|
||||
video_rel = video_template.format(
|
||||
video_key=first_cam,
|
||||
chunk_index=chunk_idx,
|
||||
file_index=file_idx,
|
||||
)
|
||||
|
||||
# Load task name for this episode
|
||||
# tasks.parquet uses the task string as the row index; task_index column holds the int id
|
||||
task_name = ""
|
||||
try:
|
||||
# Prefer the 'tasks' list directly on the episode row
|
||||
if "tasks" in row.index and row["tasks"] is not None:
|
||||
tasks_val = row["tasks"]
|
||||
if isinstance(tasks_val, (list, tuple, np.ndarray)) and len(tasks_val) > 0:
|
||||
task_name = str(tasks_val[0])
|
||||
else:
|
||||
task_name = str(tasks_val).strip("[]'")
|
||||
else:
|
||||
tasks_pq = local / "meta" / "tasks.parquet"
|
||||
if tasks_pq.exists():
|
||||
tasks_df = pd.read_parquet(tasks_pq)
|
||||
# Row index is the task string; task_index column is the int
|
||||
task_idx = int(row.get("task_index", 0)) if "task_index" in row.index else 0
|
||||
match = tasks_df[tasks_df["task_index"] == task_idx]
|
||||
if not match.empty:
|
||||
task_name = str(match.index[0])
|
||||
print(f" Task name: '{task_name}'")
|
||||
except Exception as e:
|
||||
print(f" WARNING: could not load task name: {e}")
|
||||
|
||||
return {
|
||||
"fps": fps,
|
||||
"first_cam": first_cam,
|
||||
"video_rel": video_rel,
|
||||
"chunk_index": chunk_idx,
|
||||
"file_index": file_idx,
|
||||
"from_ts": from_ts,
|
||||
"to_ts": to_ts,
|
||||
"task_name": task_name,
|
||||
}
|
||||
|
||||
|
||||
def download_video(repo_id: str, local: Path, video_rel: str) -> Path:
|
||||
"""Download the specific video file if not already present."""
|
||||
video_path = local / video_rel
|
||||
if video_path.exists():
|
||||
print(f" Video already cached: {video_path}")
|
||||
return video_path
|
||||
print(f"[2/5] Downloading video file {video_rel} …")
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
local_dir=str(local),
|
||||
allow_patterns=[video_rel],
|
||||
)
|
||||
if not video_path.exists():
|
||||
raise RuntimeError(f"Video not found after download: {video_path}")
|
||||
return video_path
|
||||
|
||||
|
||||
def load_progress(local: Path, episode: int) -> np.ndarray | None:
|
||||
"""Load sarm_progress values for this episode. Returns sorted array of (frame_index, progress)."""
|
||||
pq_path = local / "sarm_progress.parquet"
|
||||
if not pq_path.exists():
|
||||
print(" WARNING: sarm_progress.parquet not found")
|
||||
return None
|
||||
df = pd.read_parquet(pq_path)
|
||||
print(f" sarm_progress.parquet columns: {list(df.columns)}")
|
||||
ep_df = df[df["episode_index"] == episode].copy()
|
||||
if ep_df.empty:
|
||||
print(f" WARNING: No sarm_progress rows for episode {episode}")
|
||||
return None
|
||||
ep_df = ep_df.sort_values("frame_index")
|
||||
|
||||
# Prefer dense, fall back to sparse
|
||||
if "progress_dense" in ep_df.columns and ep_df["progress_dense"].notna().any():
|
||||
prog_col = "progress_dense"
|
||||
elif "progress_sparse" in ep_df.columns:
|
||||
prog_col = "progress_sparse"
|
||||
else:
|
||||
# Last resort: any column with 'progress' in the name
|
||||
prog_cols = [c for c in ep_df.columns if "progress" in c.lower()]
|
||||
if not prog_cols:
|
||||
return None
|
||||
prog_col = prog_cols[0]
|
||||
|
||||
print(f" Using progress column: '{prog_col}'")
|
||||
return ep_df[["frame_index", prog_col]].rename(columns={prog_col: "progress"}).values
|
||||
|
||||
|
||||
def extract_episode_clip(video_path: Path, from_ts: float, to_ts: float, out_path: Path) -> Path:
|
||||
"""Use ffmpeg to cut the episode segment from the combined video file."""
|
||||
duration = to_ts - from_ts
|
||||
print(f"[3/5] Extracting clip [{from_ts:.3f}s → {to_ts:.3f}s] ({duration:.2f}s) …")
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-ss",
|
||||
str(from_ts),
|
||||
"-i",
|
||||
str(video_path),
|
||||
"-t",
|
||||
str(duration),
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-preset",
|
||||
"fast",
|
||||
"-crf",
|
||||
"18",
|
||||
"-an",
|
||||
str(out_path),
|
||||
]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"ffmpeg clip extraction failed:\n{result.stderr}")
|
||||
return out_path
|
||||
|
||||
|
||||
def precompute_pixels(
|
||||
progress_data: np.ndarray,
|
||||
n_frames: int,
|
||||
frame_w: int,
|
||||
frame_h: int,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Map each progress sample to pixel coordinates.
|
||||
Returns array of shape (N, 2) with (x, y) in pixel space.
|
||||
x spans full video width; y maps progress [0,1] to graph band.
|
||||
"""
|
||||
frame_indices = progress_data[:, 0].astype(float)
|
||||
progress_vals = np.clip(progress_data[:, 1].astype(float), 0.0, 1.0)
|
||||
|
||||
y_top = int(frame_h * GRAPH_Y_TOP_FRAC)
|
||||
y_bot = int(frame_h * GRAPH_Y_BOT_FRAC)
|
||||
graph_h = y_bot - y_top
|
||||
|
||||
xs = (frame_indices / max(n_frames - 1, 1) * (frame_w - 1)).astype(int)  # guard against single-frame clips
|
||||
# progress=1 → y_top, progress=0 → y_bot
|
||||
ys = (y_bot - progress_vals * graph_h).astype(int)
|
||||
|
||||
return np.stack([xs, ys], axis=1) # (N, 2)
|
||||
|
||||
|
||||
def progress_color(t: float) -> tuple[int, int, int]:
|
||||
"""Interpolate BGR color red→green based on normalised position t in [0,1]."""
|
||||
r = int(255 * (1.0 - t))
|
||||
g = int(255 * t)
|
||||
return (0, g, r) # BGR
|
||||
|
||||
|
||||
def prerender_fill(
|
||||
pixels: np.ndarray,
|
||||
frame_w: int,
|
||||
frame_h: int,
|
||||
) -> np.ndarray:
|
||||
"""Pre-render the full grey fill polygon under the curve as a BGRA image."""
|
||||
y_bot = int(frame_h * GRAPH_Y_BOT_FRAC)
|
||||
fill_img = np.zeros((frame_h, frame_w, 4), dtype=np.uint8)
|
||||
poly = np.concatenate(
|
||||
[
|
||||
pixels,
|
||||
[[pixels[-1][0], y_bot], [pixels[0][0], y_bot]],
|
||||
],
|
||||
axis=0,
|
||||
).astype(np.int32)
|
||||
cv2.fillPoly(fill_img, [poly], color=(128, 128, 128, int(255 * FILL_ALPHA)))
|
||||
return fill_img
|
||||
|
||||
|
||||
def alpha_composite(base: np.ndarray, overlay_bgra: np.ndarray, x_max: int) -> None:
|
||||
"""Blend overlay onto base in-place, but only for x < x_max."""
|
||||
if x_max <= 0:
|
||||
return
|
||||
roi_b = base[:, :x_max]
|
||||
roi_o = overlay_bgra[:, :x_max]
|
||||
alpha = roi_o[:, :, 3:4].astype(np.float32) / 255.0
|
||||
roi_b[:] = np.clip(
|
||||
roi_o[:, :, :3].astype(np.float32) * alpha + roi_b.astype(np.float32) * (1.0 - alpha),
|
||||
0,
|
||||
255,
|
||||
).astype(np.uint8)
|
||||
|
||||
|
||||
def draw_text_outlined(
|
||||
frame: np.ndarray,
|
||||
text: str,
|
||||
pos: tuple[int, int],
|
||||
font_scale: float,
|
||||
thickness: int = 1,
|
||||
) -> None:
|
||||
"""Draw text with a dark outline for readability on any background."""
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
cv2.putText(frame, text, pos, font, font_scale, (0, 0, 0), thickness + 2, cv2.LINE_AA)
|
||||
cv2.putText(frame, text, pos, font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
|
||||
|
||||
|
||||
def composite_video(
|
||||
clip_path: Path,
|
||||
progress_data: np.ndarray,
|
||||
out_path: Path,
|
||||
fps: float,
|
||||
frame_h: int,
|
||||
frame_w: int,
|
||||
task_name: str = "",
|
||||
) -> Path:
|
||||
"""Read clip frames, draw gradient progress line with fill + labels, export as GIF."""
|
||||
# Probe the frame count with a capture that is released right away
probe = cv2.VideoCapture(str(clip_path))
n_total = int(probe.get(cv2.CAP_PROP_FRAME_COUNT))
probe.release()
|
||||
pixels = precompute_pixels(progress_data, n_total, frame_w, frame_h)
|
||||
|
||||
y_ref = int(frame_h * GRAPH_Y_TOP_FRAC)
|
||||
|
||||
# Pre-render fill polygon (line is drawn per-frame with live color)
|
||||
fill_img = prerender_fill(pixels, frame_w, frame_h)
|
||||
|
||||
# 1.0 reference line overlay (full width, drawn once)
|
||||
ref_img = np.zeros((frame_h, frame_w, 4), dtype=np.uint8)
|
||||
cv2.line(ref_img, (0, y_ref), (frame_w - 1, y_ref), (200, 200, 200, int(255 * REF_ALPHA)), 1, cv2.LINE_AA)
|
||||
|
||||
frame_indices = progress_data[:, 0].astype(int)
|
||||
progress_vals = progress_data[:, 1].astype(float)
|
||||
|
||||
print(f"[4/5] Compositing {n_total} frames …")
|
||||
cap = cv2.VideoCapture(str(clip_path))
|
||||
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
||||
tmp_path = out_path.parent / (out_path.stem + "_tmp.mp4")
|
||||
writer = cv2.VideoWriter(str(tmp_path), fourcc, fps, (frame_w, frame_h))
|
||||
|
||||
fi = 0
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
|
||||
n_drawn = int(np.searchsorted(frame_indices, fi, side="right"))
|
||||
x_cur = int(pixels[min(n_drawn, len(pixels)) - 1][0]) + 1 if n_drawn > 0 else 0
|
||||
|
||||
# 1. reference line (full width, always)
|
||||
alpha_composite(frame, ref_img, frame_w)
|
||||
|
||||
# 2. grey fill under curve up to current x
|
||||
alpha_composite(frame, fill_img, x_cur)
|
||||
|
||||
# 3. progress line — single color that transitions red→green over time
|
||||
if n_drawn >= 2:
|
||||
t_cur = (n_drawn - 1) / max(len(progress_vals) - 1, 1)
|
||||
line_col = progress_color(t_cur)
|
||||
pts = pixels[:n_drawn].reshape(-1, 1, 2).astype(np.int32)
|
||||
cv2.polylines(
|
||||
frame,
|
||||
[pts],
|
||||
isClosed=False,
|
||||
color=(255, 255, 255),
|
||||
thickness=SHADOW_THICKNESS,
|
||||
lineType=cv2.LINE_AA,
|
||||
)
|
||||
cv2.polylines(
|
||||
frame, [pts], isClosed=False, color=line_col, thickness=LINE_THICKNESS, lineType=cv2.LINE_AA
|
||||
)
|
||||
|
||||
# 4. score — bottom right
|
||||
if n_drawn > 0:
|
||||
score = float(progress_vals[min(n_drawn, len(progress_vals)) - 1])
|
||||
score_text = f"{score:.2f}"
|
||||
(tw, th), _ = cv2.getTextSize(score_text, cv2.FONT_HERSHEY_SIMPLEX, SCORE_FONT_SCALE, 2)
|
||||
sx = frame_w - tw - 12
|
||||
sy = frame_h - 12
|
||||
# coloured score matching current gradient position
|
||||
t_cur = (n_drawn - 1) / max(len(progress_vals) - 1, 1)
|
||||
score_col = progress_color(t_cur)
|
||||
cv2.putText(
|
||||
frame,
|
||||
score_text,
|
||||
(sx, sy),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
SCORE_FONT_SCALE,
|
||||
(0, 0, 0),
|
||||
4,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
cv2.putText(
|
||||
frame,
|
||||
score_text,
|
||||
(sx, sy),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
SCORE_FONT_SCALE,
|
||||
score_col,
|
||||
2,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
|
||||
# 5. task name — top centre
|
||||
if task_name:
|
||||
(tw, _), _ = cv2.getTextSize(task_name, cv2.FONT_HERSHEY_SIMPLEX, TASK_FONT_SCALE, 1)
|
||||
tx = max((frame_w - tw) // 2, 4)
|
||||
draw_text_outlined(frame, task_name, (tx, 22), TASK_FONT_SCALE)
|
||||
|
||||
writer.write(frame)
|
||||
fi += 1
|
||||
if fi % 100 == 0:
|
||||
print(f" Frame {fi}/{n_total} …", end="\r")
|
||||
|
||||
cap.release()
|
||||
writer.release()
|
||||
print()
|
||||
|
||||
# Convert to GIF: full resolution, 10fps, 128-color diff palette (<40MB)
|
||||
gif_path = out_path.with_suffix(".gif")
|
||||
palette = out_path.parent / "_palette.png"
|
||||
r1 = subprocess.run( # nosec B607
|
||||
[
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
str(tmp_path),
|
||||
"-vf",
|
||||
f"fps=10,scale={frame_w}:-1:flags=lanczos,palettegen=max_colors=128:stats_mode=diff",
|
||||
"-update",
|
||||
"1",
|
||||
str(palette),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if r1.returncode != 0:
|
||||
print(f" WARNING: palettegen failed:\n{r1.stderr[-500:]}")
|
||||
r2 = subprocess.run( # nosec B607
|
||||
[
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
str(tmp_path),
|
||||
"-i",
|
||||
str(palette),
|
||||
"-filter_complex",
|
||||
f"fps=10,scale={frame_w}:-1:flags=lanczos[v];[v][1:v]paletteuse=dither=bayer:bayer_scale=3",
|
||||
str(gif_path),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if r2.returncode != 0:
|
||||
print(f" WARNING: gif encode failed:\n{r2.stderr[-500:]}")
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
palette.unlink(missing_ok=True)
|
||||
return gif_path
|
||||
|
||||
|
||||
def process_dataset(repo_id: str, episode: int):
|
||||
safe_name = repo_id.replace("/", "_")
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Processing: {repo_id} | episode {episode}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
# 1. Download metadata
|
||||
local = download_episode(repo_id, episode)
|
||||
print(f" Local cache: {local}")
|
||||
|
||||
# 2. Read episode metadata
|
||||
ep_meta = load_episode_meta(local, episode)
|
||||
print(f" Episode meta: {ep_meta}")
|
||||
|
||||
# 3. Download video file
|
||||
video_path = download_video(repo_id, local, ep_meta["video_rel"])
|
||||
|
||||
# 4. Extract clip
|
||||
clip_path = OUTPUT_DIR / f"{safe_name}_ep{episode}_clip.mp4"
|
||||
extract_episode_clip(video_path, ep_meta["from_ts"], ep_meta["to_ts"], clip_path)
|
||||
|
||||
# 5. Load progress data
|
||||
progress_data = load_progress(local, episode)
|
||||
if progress_data is None:
|
||||
print(" ERROR: Could not load sarm_progress data. Skipping overlay.")
|
||||
return
|
||||
|
||||
n_progress = len(progress_data)
|
||||
print(f" Progress frames: {n_progress}")
|
||||
|
||||
# 6. Get clip dimensions
|
||||
cap = cv2.VideoCapture(str(clip_path))
|
||||
frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
actual_fps = cap.get(cv2.CAP_PROP_FPS) or ep_meta["fps"]
|
||||
cap.release()
|
||||
print(f" Clip: {frame_w}×{frame_h} {n_frames} frames @ {actual_fps:.1f}fps")
|
||||
|
||||
# 7. Composite (draw line directly on frames)
|
||||
out_path = OUTPUT_DIR / f"{safe_name}_ep{episode}_progress.mp4"
|
||||
final = composite_video(
|
||||
clip_path,
|
||||
progress_data,
|
||||
out_path,
|
||||
actual_fps,
|
||||
frame_h,
|
||||
frame_w,
|
||||
task_name=ep_meta.get("task_name", ""),
|
||||
)
|
||||
clip_path.unlink(missing_ok=True)
|
||||
print(f"\n✓ Done: {final}")
|
||||
return final
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
results = []
|
||||
for cfg in DATASETS:
|
||||
try:
|
||||
out = process_dataset(cfg["repo_id"], cfg["episode"])
|
||||
if out:
|
||||
results.append(out)
|
||||
except Exception as e:
|
||||
print(f"\nERROR processing {cfg['repo_id']}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Output files:")
|
||||
for r in results:
|
||||
print(f" {r}")
|
||||
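The coordinate mapping used by `precompute_pixels` above can be tried on its own. The sketch below is an illustration under stated assumptions rather than the script's code: `progress_to_pixels` is a hypothetical name, the band fractions are copied from `GRAPH_Y_TOP_FRAC` and `GRAPH_Y_BOT_FRAC`, and the `max(..., 1)` guard for single-frame clips is an addition.

```python
import numpy as np

# Same values as GRAPH_Y_TOP_FRAC / GRAPH_Y_BOT_FRAC in the script.
Y_TOP_FRAC = 0.01
Y_BOT_FRAC = 0.99


def progress_to_pixels(frame_indices, progress, n_frames, frame_w, frame_h):
    """Map (frame_index, progress) samples to integer (x, y) pixel coordinates.

    x spans the full frame width; progress=1 maps to the top of the graph
    band and progress=0 to the bottom (image y grows downwards).
    """
    frame_indices = np.asarray(frame_indices, dtype=float)
    progress = np.clip(np.asarray(progress, dtype=float), 0.0, 1.0)
    y_top = int(frame_h * Y_TOP_FRAC)
    y_bot = int(frame_h * Y_BOT_FRAC)
    # Assumption: guard the divisor so a one-frame clip does not divide by zero.
    xs = (frame_indices / max(n_frames - 1, 1) * (frame_w - 1)).astype(int)
    ys = (y_bot - progress * (y_bot - y_top)).astype(int)
    return np.stack([xs, ys], axis=1)


pixels = progress_to_pixels(
    [0, 50, 100], [0.0, 0.5, 1.0], n_frames=101, frame_w=640, frame_h=480
)
print(pixels)
```

Because y is inverted, the first sample lands near the bottom-left corner and the last near the top-right.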
496
examples/dataset/visualization_tools/workspace_density.py
Normal file
@@ -0,0 +1,496 @@
|
||||
"""
|
||||
Visualize end-effector workspace density and trajectory clusters for OpenArm datasets.
|
||||
Downloads joint position data (no videos) from HuggingFace, computes forward
|
||||
kinematics per episode, clusters trajectories with K-means, and renders
|
||||
2D projections comparing dataset coverage and multimodality.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from huggingface_hub import snapshot_download
|
||||
from sklearn.cluster import KMeans
|
||||
|
||||
DATASETS = [
|
||||
{"repo_id": "lerobot-data-collection/level2_final_quality3", "label": "HQ curated"},
|
||||
{"repo_id": "lerobot-data-collection/level12_rac_2_2026-02-08_1", "label": "Full collection"},
|
||||
]
|
||||
OUTPUT_DIR = Path(__file__).resolve().parent / "outputs"
|
||||
OUTPUT_DIR.mkdir(exist_ok=True)
|
||||
|
||||
N_CLUSTERS = 10
|
||||
WAYPOINTS = 50
|
||||
SEED = 42
|
||||
DPI = 180
|
||||
|
||||
CLUSTER_COLORS = [
|
||||
"#e6194b",
|
||||
"#3cb44b",
|
||||
"#4363d8",
|
||||
"#f58231",
|
||||
"#911eb4",
|
||||
"#42d4f4",
|
||||
"#f032e6",
|
||||
"#bfef45",
|
||||
"#fabed4",
|
||||
"#dcbeff",
|
||||
"#9a6324",
|
||||
"#fffac8",
|
||||
"#800000",
|
||||
"#aaffc3",
|
||||
"#808000",
|
||||
"#ffd8b1",
|
||||
"#000075",
|
||||
"#a9a9a9",
|
||||
]
|
||||
|
||||
# FK chains extracted from OpenArm bimanual URDF.
|
||||
# Each entry: (rpy, xyz, revolute_axis_or_None).
|
||||
LEFT_CHAIN = [
|
||||
((-np.pi / 2, 0, 0), (0, 0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((-np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, -1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
RIGHT_CHAIN = [
|
||||
((np.pi / 2, 0, 0), (0, -0.031, 0.698), None),
|
||||
((0, 0, 0), (0, 0, 0.0625), (0, 0, 1)),
|
||||
((np.pi / 2, 0, 0), (-0.0301, 0, 0.06), (-1, 0, 0)),
|
||||
((0, 0, 0), (0.0301, 0, 0.06625), (0, 0, 1)),
|
||||
((0, 0, 0), (0, 0.0315, 0.15375), (0, 1, 0)),
|
||||
((0, 0, 0), (0, -0.0315, 0.0955), (0, 0, 1)),
|
||||
((0, 0, 0), (0.0375, 0, 0.1205), (1, 0, 0)),
|
||||
((0, 0, 0), (-0.0375, 0, 0), (0, 1, 0)),
|
||||
((0, 0, 0), (0, 0, 0.1001), None),
|
||||
((0, 0, 0), (0, 0, 0.08), None),
|
||||
]
|
||||
|
||||
|
||||
# ── FK math ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def _rot_x(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
|
||||
|
||||
|
||||
def _rot_y(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
|
||||
|
||||
|
||||
def _rot_z(a: float) -> np.ndarray:
|
||||
c, s = np.cos(a), np.sin(a)
|
||||
return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
|
||||
|
||||
|
||||
def _tf(rpy: tuple, xyz: tuple) -> np.ndarray:
|
||||
"""Build a 4x4 homogeneous transform from URDF rpy + xyz."""
|
||||
r, p, y = rpy
|
||||
mat = np.eye(4)
|
||||
mat[:3, :3] = _rot_z(y) @ _rot_y(p) @ _rot_x(r)
|
||||
mat[:3, 3] = xyz
|
||||
return mat
|
||||
|
||||
|
||||
def _batch_axis_rot(axis: tuple, angles: np.ndarray) -> np.ndarray:
|
||||
"""Batched Rodrigues rotation: (n,) angles around a fixed axis → (n, 4, 4)."""
|
||||
n = len(angles)
|
||||
ax = np.asarray(axis, dtype=np.float64)
|
||||
ax = ax / np.linalg.norm(ax)
|
||||
x, y, z = ax
|
||||
c = np.cos(angles)
|
||||
s = np.sin(angles)
|
||||
t = 1 - c
|
||||
rot = np.zeros((n, 4, 4))
|
||||
rot[:, 0, 0] = t * x * x + c
|
||||
rot[:, 0, 1] = t * x * y - s * z
|
||||
rot[:, 0, 2] = t * x * z + s * y
|
||||
rot[:, 1, 0] = t * x * y + s * z
|
||||
rot[:, 1, 1] = t * y * y + c
|
||||
rot[:, 1, 2] = t * y * z - s * x
|
||||
rot[:, 2, 0] = t * x * z - s * y
|
||||
rot[:, 2, 1] = t * y * z + s * x
|
||||
rot[:, 2, 2] = t * z * z + c
|
||||
rot[:, 3, 3] = 1.0
|
||||
return rot
|
||||
|
||||
|
||||
def batch_fk(chain: list, joint_angles: np.ndarray) -> np.ndarray:
|
||||
"""Vectorized FK: (n, 7) radians → (n, 3) TCP positions in world frame."""
|
||||
n = joint_angles.shape[0]
|
||||
tf_batch = np.tile(np.eye(4), (n, 1, 1))
|
||||
qi = 0
|
||||
for rpy, xyz, axis in chain:
|
||||
tf_batch = tf_batch @ _tf(rpy, xyz)
|
||||
if axis is not None:
|
||||
rot = _batch_axis_rot(axis, joint_angles[:, qi])
|
||||
tf_batch = np.einsum("nij,njk->nik", tf_batch, rot)
|
||||
qi += 1
|
||||
return tf_batch[:, :3, 3]
|
||||
|
||||
|
||||
# ── Data loading ────────────────────────────────────────
|
||||
|
||||
|
||||
def _flatten_names(obj: object) -> list[str]:
|
||||
"""Recursively flatten a names structure (list, dict, or nested) into a flat string list."""
|
||||
if isinstance(obj, dict):
|
||||
out: list[str] = []
|
||||
for v in obj.values():
|
||||
out.extend(_flatten_names(v))
|
||||
return out
|
||||
if isinstance(obj, (list, tuple)):
|
||||
out = []
|
||||
for item in obj:
|
||||
if isinstance(item, (list, tuple, dict)):
|
||||
out.extend(_flatten_names(item))
|
||||
else:
|
||||
out.append(str(item))
|
||||
return out
|
||||
return [str(obj)]
|
||||
|
||||
|
||||
def _detect_and_convert(vals: np.ndarray) -> np.ndarray:
|
||||
"""Auto-detect servo ticks / degrees / radians and convert to radians."""
|
||||
mx = np.max(np.abs(vals))
|
||||
if mx > 360:
|
||||
print(f" Unit detection: servo ticks (max={mx:.0f})")
|
||||
return (vals - 2048) / 2048 * np.pi
|
||||
if mx > 6.3:
|
||||
print(f" Unit detection: degrees (max={mx:.1f})")
|
||||
return np.deg2rad(vals)
|
||||
print(f" Unit detection: radians (max={mx:.3f})")
|
||||
return vals.astype(np.float64)
|
||||
|
||||
|
||||
def _find_joint_indices(features: dict, state_col: str, n_dim: int) -> tuple[list[int], list[int]]:
|
||||
"""Try to find left/right joint indices from info.json feature names."""
|
||||
feat = features.get("observation.state", features.get(state_col, {}))
|
||||
names = _flatten_names(feat.get("names", []))
|
||||
|
||||
left_idx: list[int] = []
|
||||
right_idx: list[int] = []
|
||||
if names and len(names) == n_dim:
|
||||
names_l = [n.lower() for n in names]
|
||||
print(f" Feature names: {names[:4]}…{names[-4:]}")
|
||||
for j in range(1, 8):
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"left_joint_{j}" in nm and i not in left_idx:
|
||||
left_idx.append(i)
|
||||
break
|
||||
for i, nm in enumerate(names_l):
|
||||
if f"right_joint_{j}" in nm and i not in right_idx:
|
||||
right_idx.append(i)
|
||||
break
|
||||
|
||||
if len(left_idx) == 7 and len(right_idx) == 7:
|
||||
print(f" Matched by name: left={left_idx} right={right_idx}")
|
||||
return left_idx, right_idx
|
||||
if n_dim >= 16:
|
||||
print(" Falling back to positional: [0:7]=left, [8:15]=right")
|
||||
return list(range(7)), list(range(8, 15))
|
||||
if n_dim >= 14:
|
||||
print(" Falling back to positional: [0:7]=left, [7:14]=right")
|
||||
return list(range(7)), list(range(7, 14))
|
||||
raise RuntimeError(f"State dim {n_dim} too small for bimanual 7-DOF robot")
|
||||
|
||||
|
||||
def download_data(repo_id: str) -> Path:
|
||||
print(f" Downloading {repo_id} (parquet only) …")
|
||||
return Path(
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
allow_patterns=["meta/**", "data/**"],
|
||||
ignore_patterns=["*.mp4", "videos/**"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def resample_trajectory(traj: np.ndarray, n_waypoints: int) -> np.ndarray:
|
||||
"""Resample a (F, 3) trajectory to exactly n_waypoints via linear interpolation."""
|
||||
f = traj.shape[0]
|
||||
if f == n_waypoints:
|
||||
return traj
|
||||
old_t = np.linspace(0, 1, f)
|
||||
new_t = np.linspace(0, 1, n_waypoints)
|
||||
return np.column_stack([np.interp(new_t, old_t, traj[:, d]) for d in range(3)])
|
||||
|
||||
|
||||
def load_episode_trajectories(local: Path) -> list[dict]:
|
||||
"""
|
||||
Load per-episode joint data, compute FK, return list of trajectory dicts.
|
||||
Each dict: {"left_tcp": (F,3), "right_tcp": (F,3), "episode_index": int}.
|
||||
Uses all episodes in the dataset for a fair comparison.
|
||||
"""
|
||||
info = json.loads((local / "meta" / "info.json").read_text())
|
||||
features = info.get("features", {})
|
||||
|
||||
dfs = [pd.read_parquet(pq) for pq in sorted((local / "data").glob("**/*.parquet"))]
|
||||
df = pd.concat(dfs, ignore_index=True)
|
||||
print(f" Total frames: {len(df):,}")
|
||||
|
||||
state_col = next((c for c in df.columns if "observation.state" in c), None)
|
||||
if state_col is None:
|
||||
raise RuntimeError(f"No observation.state column. Available: {list(df.columns)}")
|
||||
|
||||
first = df[state_col].iloc[0]
|
||||
if not hasattr(first, "__len__"):
|
||||
raise RuntimeError(f"observation.state is scalar ({type(first)}), expected array")
|
||||
|
||||
state = np.stack(df[state_col].values).astype(np.float64)
|
||||
n_dim = state.shape[1]
|
||||
print(f" State dim: {n_dim} max|val|: {np.max(np.abs(state)):.1f}")
|
||||
|
||||
left_idx, right_idx = _find_joint_indices(features, state_col, n_dim)
|
||||
|
||||
ep_col = next((c for c in df.columns if c == "episode_index"), None)
|
||||
if ep_col is None:
|
||||
raise RuntimeError(f"No episode_index column. Available: {list(df.columns)}")
|
||||
|
||||
episode_ids = df[ep_col].values
|
||||
unique_eps = np.unique(episode_ids)
|
||||
print(f" Episodes: {len(unique_eps):,}")
|
||||
|
||||
left_raw = state[:, left_idx]
|
||||
right_raw = state[:, right_idx]
|
||||
left_all = _detect_and_convert(left_raw)
|
||||
right_all = _detect_and_convert(right_raw)
|
||||
|
||||
print(" Computing FK per episode …")
|
||||
trajectories = []
|
||||
for ep_id in unique_eps:
|
||||
mask = episode_ids == ep_id
|
||||
left_tcp = batch_fk(LEFT_CHAIN, left_all[mask])
|
||||
right_tcp = batch_fk(RIGHT_CHAIN, right_all[mask])
|
||||
if len(left_tcp) < 3:
|
||||
continue
|
||||
trajectories.append({"left_tcp": left_tcp, "right_tcp": right_tcp, "episode_index": int(ep_id)})
|
||||
|
||||
print(f" Valid trajectories: {len(trajectories):,}")
|
||||
return trajectories
|
||||
|
||||
|
||||
# ── Clustering ──────────────────────────────────────────
|
||||
|
||||
|
||||
def cluster_trajectories(
|
||||
trajectories: list[dict], n_clusters: int, n_waypoints: int
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
K-means on resampled trajectory features.
|
||||
Combines left+right TCP into a single feature vector per episode.
|
||||
Returns (labels, centroid_trajs (k, waypoints, 6), spread_per_cluster (k,) in metres).
|
||||
Spread = mean per-waypoint Euclidean distance from each trajectory to its centroid.
|
||||
"""
|
||||
feat_vecs = []
|
||||
for t in trajectories:
|
||||
left_rs = resample_trajectory(t["left_tcp"], n_waypoints)
|
||||
right_rs = resample_trajectory(t["right_tcp"], n_waypoints)
|
||||
feat_vecs.append(np.concatenate([left_rs.ravel(), right_rs.ravel()]))
|
||||
feat_matrix = np.array(feat_vecs)
|
||||
|
||||
k = min(n_clusters, len(feat_vecs))
|
||||
km = KMeans(n_clusters=k, n_init=10, random_state=SEED)
|
||||
labels = km.fit_predict(feat_matrix)
|
||||
|
||||
centroids_flat = km.cluster_centers_
|
||||
centroid_trajs = np.zeros((k, n_waypoints, 6))
|
||||
for ci in range(k):
|
||||
left_flat = centroids_flat[ci, : n_waypoints * 3]
|
||||
right_flat = centroids_flat[ci, n_waypoints * 3 :]
|
||||
centroid_trajs[ci, :, :3] = left_flat.reshape(n_waypoints, 3)
|
||||
centroid_trajs[ci, :, 3:] = right_flat.reshape(n_waypoints, 3)
|
||||
|
||||
# Mean per-waypoint distance to centroid (in metres) for each cluster
|
||||
spread = np.zeros(k)
|
||||
for ci in range(k):
|
||||
members = np.where(labels == ci)[0]
|
||||
if len(members) == 0:
|
||||
continue
|
||||
centroid_left = centroid_trajs[ci, :, :3]
|
||||
centroid_right = centroid_trajs[ci, :, 3:]
|
||||
dists = []
|
||||
for mi in members:
|
||||
t = trajectories[mi]
|
||||
left_rs = resample_trajectory(t["left_tcp"], n_waypoints)
|
||||
right_rs = resample_trajectory(t["right_tcp"], n_waypoints)
|
||||
d_left = np.linalg.norm(left_rs - centroid_left, axis=1).mean()
|
||||
d_right = np.linalg.norm(right_rs - centroid_right, axis=1).mean()
|
||||
dists.append((d_left + d_right) / 2)
|
||||
spread[ci] = np.mean(dists)
|
||||
|
||||
return labels, centroid_trajs, spread
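Note on the spread metric above: for one arm it is the Euclidean distance to the centroid at each waypoint, averaged over waypoints; the per-cluster value then averages over arms and members. A minimal sketch with toy numbers follows; the helper name `mean_waypoint_distance` and all values are illustrative assumptions.

```python
import numpy as np


def mean_waypoint_distance(traj: np.ndarray, centroid: np.ndarray) -> float:
    """Mean Euclidean distance between matching waypoints of two (W, 3) paths."""
    return float(np.linalg.norm(traj - centroid, axis=1).mean())


# Toy data: a centroid at the origin and a member offset by (3 cm, 4 cm, 0)
# at every waypoint, i.e. exactly 5 cm away throughout.
centroid = np.zeros((4, 3))
member = np.tile([0.03, 0.04, 0.0], (4, 1))
spread_m = mean_waypoint_distance(member, centroid)

# Size-weighted mean across clusters, in the style of the figure title:
# two clusters with spreads of 2 cm and 4 cm holding 3 and 1 episodes.
weighted_m = float(np.average([0.02, 0.04], weights=[3, 1]))

print(f"{spread_m * 100:.1f} cm, {weighted_m * 100:.1f} cm")
```

Weighting by cluster size keeps a single outlier cluster with few episodes from dominating the summary figure.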
|
||||
|
||||
|
||||
# ── Visualization ───────────────────────────────────────
|
||||
|
||||
PROJ_VIEWS = [
|
||||
("XZ (side)", 0, 2, "X (m)", "Z (m)"),
|
||||
("XY (top)", 0, 1, "X (m)", "Y (m)"),
|
||||
("YZ (front)", 1, 2, "Y (m)", "Z (m)"),
|
||||
]
|
||||
|
||||
|
||||
def render(results: list[dict], out_path: Path) -> None:
|
||||
"""
|
||||
One row per dataset and one column per projection (a 2 × 3 grid for two datasets).
|
||||
Trajectory lines colored by cluster, centroid trajectories drawn thick.
|
||||
"""
|
||||
n_ds = len(results)
|
||||
n_proj = len(PROJ_VIEWS)
|
||||
fig, axes = plt.subplots(n_ds, n_proj, figsize=(7 * n_proj, 7 * n_ds), facecolor="#0d1117")
|
||||
if n_ds == 1:
|
||||
axes = axes[np.newaxis, :]
|
||||
|
||||
for row, r in enumerate(results):
|
||||
trajectories = r["trajectories"]
|
||||
labels = r["labels"]
|
||||
centroids = r["centroids"]
|
||||
k = centroids.shape[0]
|
||||
|
||||
cluster_sizes = np.bincount(labels, minlength=k)
|
||||
size_order = np.argsort(-cluster_sizes)
|
||||
pcts = cluster_sizes / len(labels) * 100
|
||||
spread = r["spread"]
|
||||
|
||||
for col, (view_name, dim_a, dim_b, xlabel, ylabel) in enumerate(PROJ_VIEWS):
|
||||
ax = axes[row, col]
|
||||
ax.set_facecolor("#0d1117")
|
||||
|
||||
for ti, traj in enumerate(trajectories):
|
||||
color = CLUSTER_COLORS[labels[ti] % len(CLUSTER_COLORS)]
|
||||
for tcp_key in ("left_tcp", "right_tcp"):
|
||||
pts = traj[tcp_key]
|
||||
ax.plot(pts[:, dim_a], pts[:, dim_b], color=color, alpha=0.12, linewidth=0.4)
|
||||
|
||||
for ci in range(k):
|
||||
color = CLUSTER_COLORS[ci % len(CLUSTER_COLORS)]
|
||||
left_c = centroids[ci, :, :3]
|
||||
right_c = centroids[ci, :, 3:]
|
||||
lw = 1.5 + 2.0 * cluster_sizes[ci] / cluster_sizes.max()
|
||||
for c_pts in (left_c, right_c):
|
||||
ax.plot(
|
||||
c_pts[:, dim_a],
|
||||
c_pts[:, dim_b],
|
||||
color=color,
|
||||
linewidth=lw,
|
||||
alpha=0.95,
|
||||
zorder=10,
|
||||
)
|
||||
ax.plot(
|
||||
c_pts[0, dim_a],
|
||||
c_pts[0, dim_b],
|
||||
"o",
|
||||
color=color,
|
||||
markersize=4,
|
||||
zorder=11,
|
||||
)
|
||||
ax.plot(
|
||||
c_pts[-1, dim_a],
|
||||
c_pts[-1, dim_b],
|
||||
"s",
|
||||
color=color,
|
||||
markersize=4,
|
||||
zorder=11,
|
||||
)
|
||||
|
||||
ax.set_xlabel(xlabel, color="#888", fontsize=9)
|
||||
ax.set_ylabel(ylabel, color="#888", fontsize=9)
|
||||
ax.tick_params(colors="#555", labelsize=7)
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color("#333")
|
||||
ax.set_aspect("equal")
|
||||
|
||||
mean_spread_cm = np.average(spread, weights=cluster_sizes) * 100
|
||||
if col == 0:
|
||||
ax.set_title(
|
||||
f"{r['label']} ({r['n_episodes']:,} episodes, {k} clusters, "
|
||||
f"avg spread {mean_spread_cm:.1f}cm)",
|
||||
color="white",
|
||||
fontsize=11,
|
||||
pad=10,
|
||||
)
|
||||
else:
|
||||
ax.set_title(view_name, color="#aaa", fontsize=10, pad=8)
|
||||
|
||||
# Cluster size + spread legend on the rightmost panel
|
||||
legend_ax = axes[row, -1]
|
||||
for ci in size_order:
|
||||
color = CLUSTER_COLORS[ci % len(CLUSTER_COLORS)]
|
||||
spread_cm = spread[ci] * 100
|
||||
label = f"C{ci}: {cluster_sizes[ci]} eps ({pcts[ci]:.0f}%) ±{spread_cm:.1f}cm"
|
||||
legend_ax.plot([], [], color=color, linewidth=3, label=label)
|
||||
legend_ax.legend(
|
||||
loc="upper right",
|
||||
fontsize=7,
|
||||
frameon=True,
|
||||
facecolor="#1a1a2e",
|
||||
edgecolor="#333",
|
||||
labelcolor="white",
|
||||
handlelength=1.5,
|
||||
)
|
||||
|
||||
fig.suptitle(
|
||||
"End-Effector Trajectory Clusters (FK · K-means)",
|
||||
color="white",
|
||||
fontsize=16,
|
||||
y=0.98,
|
||||
)
|
||||
plt.tight_layout(rect=[0, 0, 1, 0.95])
|
||||
plt.savefig(out_path, dpi=DPI, bbox_inches="tight", facecolor=fig.get_facecolor())
|
||||
plt.close()
|
||||
print(f"\n✓ Saved: {out_path}")
|
||||
|
||||
|
||||
# ── Main ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main() -> None:
|
||||
results = []
|
||||
|
||||
for ds in DATASETS:
|
||||
repo_id, label = ds["repo_id"], ds["label"]
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f" {label}: {repo_id}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
local = download_data(repo_id)
|
||||
trajectories = load_episode_trajectories(local)
|
||||
labels, centroids, spread = cluster_trajectories(trajectories, N_CLUSTERS, WAYPOINTS)
|
||||
|
||||
cluster_sizes = np.bincount(labels, minlength=centroids.shape[0])
|
||||
print(f" Cluster sizes: {sorted(cluster_sizes, reverse=True)}")
|
||||
for ci in np.argsort(-cluster_sizes):
|
||||
print(
|
||||
f" C{ci}: {cluster_sizes[ci]} eps ({cluster_sizes[ci] / len(labels) * 100:.0f}%) "
|
||||
f"spread ±{spread[ci] * 100:.1f}cm"
|
||||
)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"label": label,
|
||||
"trajectories": trajectories,
|
||||
"labels": labels,
|
||||
"centroids": centroids,
|
||||
"spread": spread,
|
||||
"n_episodes": len(trajectories),
|
||||
}
|
||||
)
|
||||
|
||||
out = OUTPUT_DIR / "workspace_trajectory_clusters.jpg"
|
||||
render(results, out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -35,7 +35,9 @@ def main():
|
||||
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset("<hf_username>/<dataset_repo_id>", episodes=[EPISODE_IDX])
|
||||
actions = dataset.select_columns(ACTION)
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
@@ -46,7 +48,7 @@ def main():
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(dataset.num_frames):
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get recorded action from dataset
|
||||
|
||||
@@ -67,7 +67,9 @@ def main():
|
||||
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
actions = dataset.select_columns(ACTION)
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
@@ -78,7 +80,7 @@ def main():
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(dataset.num_frames):
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get recorded action from dataset
|
||||
|
||||
@@ -63,26 +63,6 @@ Usage:
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
|
||||
# Run RTC with bi_openarm_follower (dual-arm OpenArms) and pi0.5 policy
|
||||
python examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=lerobot-data-collection/folding_final \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}}' \
|
||||
--robot.left_arm_config.port=can1 \
|
||||
--robot.left_arm_config.side=left \
|
||||
--robot.left_arm_config.can_interface=socketcan \
|
||||
--robot.right_arm_config.port=can0 \
|
||||
--robot.right_arm_config.side=right \
|
||||
--robot.right_arm_config.can_interface=socketcan \
|
||||
--task="Fold the T-shirt properly" \
|
||||
--fps=30 \
|
||||
--duration=2000 \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--rtc.max_guidance_weight=5.0 \
|
||||
--rtc.prefix_attention_schedule=LINEAR \
|
||||
--device=cuda
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -107,29 +87,21 @@ from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.processor import (
|
||||
NormalizerProcessorStep,
|
||||
RelativeActionsProcessorStep,
|
||||
TransitionKey,
|
||||
create_transition,
|
||||
)
|
||||
from lerobot.processor.factory import (
|
||||
make_default_robot_action_processor,
|
||||
make_default_robot_observation_processor,
|
||||
)
|
||||
from lerobot.processor.relative_action_processor import to_relative_actions
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_openarm_follower,
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
so_follower,
|
||||
unitree_g1,
|
||||
)
|
||||
from lerobot.robots.utils import make_robot_from_config
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
@@ -240,35 +212,6 @@ def is_image_key(k: str) -> bool:
|
||||
return k.startswith(OBS_IMAGES)
|
||||
|
||||
|
||||
def _reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute: Tensor,
|
||||
current_state: Tensor,
|
||||
relative_step: RelativeActionsProcessorStep,
|
||||
normalizer_step: NormalizerProcessorStep | None,
|
||||
policy_device: torch.device | str,
|
||||
) -> Tensor:
|
||||
"""Convert absolute leftovers into model-space for relative-action RTC policies.
|
||||
|
||||
When a policy uses relative actions, the RTC prefix (leftover actions from
|
||||
the previous chunk) is stored in absolute space. Before feeding it back to
|
||||
the policy we need to re-express it relative to the *current* robot state
|
||||
and then re-normalize.
|
||||
"""
|
||||
state = current_state.detach().cpu()
|
||||
if state.dim() == 1:
|
||||
state = state.unsqueeze(0)
|
||||
|
||||
action_cpu = prev_actions_absolute.detach().cpu()
|
||||
mask = relative_step._build_mask(action_cpu.shape[-1])
|
||||
relative_actions = to_relative_actions(action_cpu, state, mask)
|
||||
|
||||
transition = create_transition(action=relative_actions)
|
||||
if normalizer_step is not None:
|
||||
transition = normalizer_step(transition)
|
||||
|
||||
return transition[TransitionKey.ACTION].to(policy_device)
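Note on the re-anchoring step above: leftover actions stored in absolute space are re-expressed relative to the newest robot state before being fed back to the policy. A minimal NumPy sketch of that idea follows, assuming "relative" means subtracting the current state on the masked dimensions only. The helper `to_relative` here is a hypothetical stand-in and may not match the behavior or signature of `to_relative_actions`; normalization is omitted.

```python
import numpy as np


def to_relative(actions_abs: np.ndarray, state: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Subtract the state from masked action dims; leave the others unchanged.

    actions_abs: (T, D) absolute actions, state: (D,), mask: (D,) booleans.
    """
    offset = np.where(mask, state, 0.0)  # zero offset for unmasked dims
    return actions_abs - offset          # broadcasts (T, D) - (D,)


# Toy data: two joint dims handled as relative, one gripper dim kept absolute.
mask = np.array([True, True, False])
leftover_abs = np.array([[1.1, 2.2, 0.3], [1.2, 2.4, 0.4]])

old_state = np.array([1.0, 2.0, 0.5])   # state when the chunk was predicted
new_state = np.array([1.05, 2.1, 0.6])  # state when the next chunk is requested

rel_old = to_relative(leftover_abs, old_state, mask)
rel_new = to_relative(leftover_abs, new_state, mask)
print(rel_old)
print(rel_new)
```

The same absolute leftovers yield different relative values once the robot has moved, which is why the prefix cannot simply be reused in the form the policy originally produced it.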
|
||||
|
||||
|
||||
def get_actions(
|
||||
policy,
|
||||
robot: RobotWrapper,
|
||||
@@ -294,15 +237,7 @@ def get_actions(
|
||||
fps = cfg.fps
|
||||
time_per_chunk = 1.0 / fps
|
||||
|
||||
# Only keep .pos joints + camera streams if the policy was trained on positions,
|
||||
# not the full pos/vel/torque state the robot exposes.
|
||||
observation_features_hw = {
|
||||
key: value
|
||||
for key, value in robot.observation_features().items()
|
||||
if key.endswith(".pos") or isinstance(value, tuple)
|
||||
}
|
||||
|
||||
dataset_features = hw_to_dataset_features(observation_features_hw, "observation")
|
||||
dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
|
||||
policy_device = policy.config.device
|
||||
|
||||
# Load preprocessor and postprocessor from pretrained files
|
||||
@@ -320,25 +255,6 @@ def get_actions(
|
||||
|
||||
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
|
||||
|
||||
relative_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
|
||||
None,
|
||||
)
|
||||
normalizer_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, NormalizerProcessorStep)),
|
||||
None,
|
||||
)
|
||||
if relative_step is not None:
|
||||
if relative_step.action_names is None:
|
||||
cfg_names = getattr(cfg.policy, "action_feature_names", None)
|
||||
if cfg_names:
|
||||
relative_step.action_names = list(cfg_names)
|
||||
else:
|
||||
relative_step.action_names = [
|
||||
k for k in robot.robot.action_features if k.endswith(".pos")
|
||||
]
|
||||
logger.info("[GET_ACTIONS] Relative actions enabled: will re-anchor RTC prefix")
|
||||
|
||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
||||
|
||||
if not cfg.rtc.enabled:
|
||||
@@ -381,28 +297,6 @@ def get_actions(
|
||||
|
||||
preproceseded_obs = preprocessor(obs_with_policy_features)
|
||||
|
||||
# Re-anchor leftover actions for relative-action policies.
|
||||
# We need the *postprocessed* (absolute) leftover, not the original
|
||||
# (normalized/relative) one that get_left_over() returns.
|
||||
if (
|
||||
prev_actions is not None
|
||||
and relative_step is not None
|
||||
and OBS_STATE in obs_with_policy_features
|
||||
):
|
||||
with action_queue.lock:
|
||||
if action_queue.queue is not None:
|
||||
prev_actions_abs = action_queue.queue[action_queue.last_index :].clone()
|
||||
else:
|
||||
prev_actions_abs = None
|
||||
if prev_actions_abs is not None and prev_actions_abs.numel() > 0:
|
||||
prev_actions = _reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=prev_actions_abs,
|
||||
current_state=obs_with_policy_features[OBS_STATE],
|
||||
relative_step=relative_step,
|
||||
normalizer_step=normalizer_step,
|
||||
policy_device=policy_device,
|
||||
)
|
||||
|
||||
# Generate actions WITH RTC
|
||||
actions = policy.predict_action_chunk(
|
||||
preproceseded_obs,
|
||||
@@ -458,8 +352,6 @@ def actor_control(
|
||||
try:
|
||||
logger.info("[ACTOR] Starting actor thread")
|
||||
|
||||
action_keys = [k for k in robot.action_features() if k.endswith(".pos")]
|
||||
|
||||
action_count = 0
|
||||
action_interval = 1.0 / cfg.fps
|
||||
|
||||
@@ -471,7 +363,7 @@ def actor_control(
|
||||
|
||||
if action is not None:
|
||||
action = action.cpu()
|
||||
action_dict = {key: action[i].item() for i, key in enumerate(action_keys)}
|
||||
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
|
||||
action_processed = robot_action_processor((action_dict, None))
|
||||
robot.send_action(action_processed)
|
||||
|
||||
|
||||
@@ -68,7 +68,9 @@ def main():
|
||||
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
actions = dataset.select_columns(ACTION)
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
@@ -79,7 +81,7 @@ def main():
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(dataset.num_frames):
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get recorded action from dataset
|
||||
|
||||
@@ -1,297 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Inference script for a pi0 model trained with UMI-style relative EE actions
|
||||
on an OpenArm robot (single right arm, one wrist camera).
|
||||
|
||||
Training dataset layout:
|
||||
observation.images.cam0 [3, 720, 960]
|
||||
action [x, y, z, ax, ay, az, proximal, distal] (shape 8)
|
||||
|
||||
The model uses ``derive_state_from_action=true``, so observation.state is
|
||||
derived from the action column during training. At inference the state must
|
||||
be provided by the robot — this script uses FK to compute the current EE
|
||||
pose and gripper position, which it exposes as ``observation.state``.
|
||||
|
||||
Pipeline:
|
||||
1. Read arm joints from robot → FK → observation.state [x,y,z,ax,ay,az,prox,dist]
|
||||
2. Read camera image → observation.images.cam0
|
||||
3. pi0 preprocessor (loaded from checkpoint):
|
||||
- DeriveStateFromActionStep: no-op at inference (state from robot)
|
||||
- RelativeActionsProcessorStep: caches current state
|
||||
- RelativeStateProcessorStep: buffers prev state, stacks [prev,cur],
|
||||
subtracts current → velocity info, flattens
|
||||
- NormalizerProcessorStep: normalizes
|
||||
4. pi0 predicts relative action chunk (30 steps)
|
||||
5. pi0 postprocessor: unnormalize, add cached state → absolute EE
|
||||
6. IK: absolute EE [x,y,z,ax,ay,az] → arm joint targets
|
||||
7. Gripper [proximal, distal] → gripper motor targets
|
||||
8. Send to robot
|
||||
|
||||
Usage:
|
||||
python evaluate.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
|
||||
from lerobot.processor import RelativeStateProcessorStep
|
||||
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
||||
from lerobot.scripts.lerobot_record import record_loop
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration — adapt these to your setup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
FPS = 46
|
||||
EPISODE_TIME_SEC = 60
|
||||
TASK_DESCRIPTION = "red cube"
|
||||
|
||||
HF_MODEL_ID = "pepijn223/grabette-umi-pi0"
|
||||
|
||||
# Latency compensation: skip this many predicted action steps to account for
|
||||
# camera + inference + execution latency. Formula: ceil(total_ms / (1000/FPS)).
|
||||
# At 46 FPS (~21.7ms/step) with ~150ms total latency: ceil(150/21.7) = 7.
|
||||
# Start with 0 for a safe first test, then increase to match measured latency.
|
||||
LATENCY_SKIP_STEPS = 0
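Note on the latency formula above: the number of steps to skip is the total latency divided by the control period, rounded up. A minimal sketch of that arithmetic follows; the function name `latency_skip_steps` is an illustrative assumption, not an API of the library.

```python
import math


def latency_skip_steps(total_latency_ms: float, fps: float) -> int:
    """Number of predicted action steps that elapse during the given latency."""
    step_ms = 1000.0 / fps  # duration of one control step in milliseconds
    return math.ceil(total_latency_ms / step_ms)


# Values from the comment above: 46 FPS and roughly 150 ms of total latency.
skip = latency_skip_steps(150, 46)
print(skip)
```

Rounding up errs on the side of skipping one step too many rather than executing an action that is already stale.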
|
||||
|
||||
URDF_PATH = "src/lerobot/robots/openarm_follower/urdf/openarm_bimanual_pybullet.urdf"
|
||||
URDF_EE_FRAME = "openarm_right_ee_target"
|
||||
|
||||
IK_POSITION_WEIGHT = 1.0
|
||||
IK_ORIENTATION_WEIGHT = 1.0
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataset features for inference
|
||||
#
|
||||
# The training dataset has only observation.images.cam0 and action.
|
||||
# observation.state is derived from action during training
|
||||
# (derive_state_from_action=true) but must be supplied by the robot at
|
||||
# inference. We define it here so build_dataset_frame can map FK output
|
||||
# to the right feature.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DATASET_FEATURES: dict = {
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": [8],
|
||||
"names": ["x", "y", "z", "ax", "ay", "az", "proximal", "distal"],
|
||||
},
|
||||
"observation.images.cam0": {
|
||||
"dtype": "video",
|
||||
"shape": [3, 720, 960],
|
||||
"names": ["channels", "height", "width"],
|
||||
"info": {
|
||||
"video.height": 720,
|
||||
"video.width": 960,
|
||||
"video.codec": "h264",
|
||||
"video.pix_fmt": "yuv420p",
|
||||
"video.is_depth_map": False,
|
||||
"video.fps": FPS,
|
||||
"video.channels": 3,
|
||||
"has_audio": False,
|
||||
},
|
||||
},
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": [8],
|
||||
"names": ["x", "y", "z", "ax", "ay", "az", "proximal", "distal"],
|
||||
},
|
||||
"timestamp": {"dtype": "float32", "shape": [1], "names": None},
|
||||
"frame_index": {"dtype": "int64", "shape": [1], "names": None},
|
||||
"episode_index": {"dtype": "int64", "shape": [1], "names": None},
|
||||
"index": {"dtype": "int64", "shape": [1], "names": None},
|
||||
"task_index": {"dtype": "int64", "shape": [1], "names": None},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FK / IK callables
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JointsToEE:
|
||||
"""FK: raw robot observation → flat dict matching observation.state names.
|
||||
|
||||
Arm joint positions → EE pose [x,y,z,ax,ay,az] via forward kinematics.
|
||||
Gripper motor positions → [proximal, distal].
|
||||
Camera images pass through unchanged.
|
||||
"""
|
||||
|
||||
def __init__(self, kinematics: RobotKinematics, arm_motor_names: list[str]):
|
||||
self.kin = kinematics
|
||||
self.arm = arm_motor_names
|
||||
|
||||
def __call__(self, obs: RobotObservation) -> RobotObservation:
|
||||
q = np.array([float(obs[f"{m}.pos"]) for m in self.arm])
|
||||
t = self.kin.forward_kinematics(q)
|
||||
rot = Rotation.from_matrix(t[:3, :3]).as_rotvec()
|
||||
|
||||
out: dict = {
|
||||
"x": float(t[0, 3]),
|
||||
"y": float(t[1, 3]),
|
||||
"z": float(t[2, 3]),
|
||||
"ax": float(rot[0]),
|
||||
"ay": float(rot[1]),
|
||||
"az": float(rot[2]),
|
||||
"proximal": float(obs["proximal.pos"]),
|
||||
"distal": float(obs["distal.pos"]),
|
||||
}
|
||||
for k, v in obs.items():
|
||||
if not k.endswith((".pos", ".vel", ".torque")):
|
||||
out[k] = v
|
||||
return out
|
||||
|
||||
|
||||
class EEToJoints:
|
||||
"""IK: policy action dict → motor position dict for the robot.
|
||||
|
||||
Reads [x,y,z,ax,ay,az] from the action, runs IK for arm joint targets.
|
||||
Passes [proximal, distal] as direct gripper position commands.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kinematics: RobotKinematics,
|
||||
arm_motor_names: list[str],
|
||||
position_weight: float = 1.0,
|
||||
orientation_weight: float = 1.0,
|
||||
):
|
||||
self.kin = kinematics
|
||||
self.arm = arm_motor_names
|
||||
self.pw = position_weight
|
||||
self.ow = orientation_weight
|
||||
self.q_curr: np.ndarray | None = None
|
||||
|
||||
def __call__(self, args: tuple[RobotAction, RobotObservation]) -> RobotAction:
|
||||
action, obs = args
|
||||
|
||||
q_raw = np.array([float(obs[f"{m}.pos"]) for m in self.arm])
|
||||
if self.q_curr is None:
|
||||
self.q_curr = q_raw
|
||||
|
||||
t_des = np.eye(4)
|
||||
t_des[:3, :3] = Rotation.from_rotvec([action["ax"], action["ay"], action["az"]]).as_matrix()
|
||||
t_des[:3, 3] = [action["x"], action["y"], action["z"]]
|
||||
|
||||
q_target = self.kin.inverse_kinematics(
|
||||
self.q_curr, t_des, position_weight=self.pw, orientation_weight=self.ow
|
||||
)
|
||||
self.q_curr = q_target
|
||||
|
||||
out: dict = {f"{m}.pos": float(q_target[i]) for i, m in enumerate(self.arm)}
|
||||
out["proximal.pos"] = float(action["proximal"])
|
||||
out["distal.pos"] = float(action["distal"])
|
||||
return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main():
|
||||
camera_config = {
|
||||
"cam0": OpenCVCameraConfig(index_or_path=0, width=960, height=720, fps=FPS),
|
||||
}
|
||||
robot_config = OpenArmFollowerConfig(
|
||||
port="can0",
|
||||
id="right_openarm",
|
||||
side="right",
|
||||
cameras=camera_config,
|
||||
max_relative_target=8.0,
|
||||
gripper_port="/dev/ttyUSB0",
|
||||
)
|
||||
robot = OpenArmFollower(robot_config)
|
||||
|
||||
policy = PI0Policy.from_pretrained(HF_MODEL_ID)
|
||||
policy.config.latency_skip_steps = LATENCY_SKIP_STEPS
|
||||
|
||||
arm_motor_names = list(robot.bus.motors.keys())
|
||||
|
||||
kinematics = RobotKinematics(
|
||||
urdf_path=URDF_PATH,
|
||||
target_frame_name=URDF_EE_FRAME,
|
||||
joint_names=arm_motor_names,
|
||||
)
|
||||
|
||||
fk = JointsToEE(kinematics, arm_motor_names)
|
||||
ik = EEToJoints(kinematics, arm_motor_names, IK_POSITION_WEIGHT, IK_ORIENTATION_WEIGHT)
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="tmp/openarm_eval_scratch",
|
||||
fps=FPS,
|
||||
features=DATASET_FEATURES,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
relative_state_steps = [s for s in preprocessor.steps if isinstance(s, RelativeStateProcessorStep)]
|
||||
|
||||
robot.connect()
|
||||
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="openarm_umi_pi0_relative_ee_evaluate")
|
||||
|
||||
try:
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
log_say("Starting policy execution")
|
||||
for step in relative_state_steps:
|
||||
step.reset()
|
||||
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
robot_action_processor=ik,
|
||||
robot_observation_processor=fk,
|
||||
)
|
||||
finally:
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,113 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Replay a dataset episode in EE frame using a browser-based URDF viewer.
|
||||
|
||||
Extracts ``observation.pose`` from the dataset, saves a trajectory JSON file,
|
||||
then launches a local HTTP server and opens the replay viewer. The trajectory
|
||||
is re-centered so frame 0 starts at the OpenArm ``openarm_right_ee_target``
|
||||
EE tip (zero-joint pose).
|
||||
|
||||
Usage:
|
||||
python replay.py
|
||||
python replay.py --episode 3 --repo-id myuser/mydata
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import http.server
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
|
||||
VIEWER_DIR = Path(__file__).resolve().parents[2] / "src/lerobot/robots/openarm_follower/urdf"
|
||||
TRAJECTORY_FILENAME = "trajectory_ep0.json"
|
||||
|
||||
|
||||
def extract_trajectory(repo_id: str, episode: int, output_path: Path) -> dict:
    from lerobot.datasets.lerobot_dataset import LeRobotDataset

    dataset = LeRobotDataset(repo_id, episodes=[episode])
    poses = dataset.select_columns("observation.pose")
    actions = dataset.select_columns("action")

    frames = []
    for i in range(dataset.num_frames):
        p = poses[i]["observation.pose"]
        a = actions[i]["action"]
        frames.append(
            {
                "x": float(p[0]),
                "y": float(p[1]),
                "z": float(p[2]),
                "ax": float(p[3]),
                "ay": float(p[4]),
                "az": float(p[5]),
                "proximal": float(a[0]),
                "distal": float(a[1]),
            }
        )
    payload = {"fps": dataset.fps, "num_frames": dataset.num_frames, "frames": frames}
    with open(output_path, "w") as f:
        json.dump(payload, f)
    print(f"Extracted {dataset.num_frames} frames at {dataset.fps} FPS → {output_path}")
    return payload
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Viewer mode
# ---------------------------------------------------------------------------


def serve_and_open(directory: Path, port: int = 8765):
    os.chdir(directory)
    handler = http.server.SimpleHTTPRequestHandler
    httpd = http.server.HTTPServer(("", port), handler)
    url = f"http://localhost:{port}/replay_viewer.html"
    print(f"Serving at {url}")
    threading.Thread(target=lambda: webbrowser.open(url), daemon=True).start()
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        print("\nServer stopped.")
        httpd.server_close()
|
||||
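`serve_and_open` changes the process working directory with `os.chdir` and binds to every interface (`""`). The following is a minimal alternative sketch, not the script's implementation. It assumes Python 3.7 or later, where `SimpleHTTPRequestHandler` accepts a `directory` argument; `make_handler` and `make_server` are hypothetical names introduced for illustration.

```python
import functools
import http.server
from pathlib import Path


def make_handler(directory: Path):
    """Hypothetical helper: handler factory bound to `directory`, so no os.chdir is needed."""
    return functools.partial(http.server.SimpleHTTPRequestHandler, directory=str(directory))


def make_server(directory: Path, port: int = 8765) -> http.server.HTTPServer:
    """Hypothetical helper: create (but do not start) a server listening on loopback only."""
    return http.server.HTTPServer(("127.0.0.1", port), make_handler(directory))


handler = make_handler(Path("."))
print(handler.keywords["directory"])
```

Binding to `127.0.0.1` keeps the viewer reachable only from the local machine, which is probably what a replay tool wants; whether the original's all-interface binding is intentional is not stated in the diff.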
|
||||
|
||||
def run_viewer(args):
    trajectory_path = VIEWER_DIR / TRAJECTORY_FILENAME
    if not trajectory_path.exists() or args.force:
        extract_trajectory(args.repo_id, args.episode, trajectory_path)
    else:
        print(f"Using cached trajectory at {trajectory_path} (pass --force to re-extract)")
    serve_and_open(VIEWER_DIR, args.port)


def main():
    parser = argparse.ArgumentParser(description="Replay a dataset episode in EE frame (URDF viewer)")
    parser.add_argument("--repo-id", default="glannuzel/grabette-dataset")
    parser.add_argument("--episode", type=int, default=0)
    parser.add_argument("--port", type=int, default=8765)
    parser.add_argument("--force", action="store_true", help="Re-extract trajectory even if cached")
    args = parser.parse_args()
    run_viewer(args)


if __name__ == "__main__":
    main()
|
||||
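Note on the cache in `run_viewer`: the trajectory is always written to the fixed name `trajectory_ep0.json`, so a file cached for one episode is reused when `--episode` changes unless `--force` is passed. A minimal sketch of a per-episode cache name follows. `trajectory_path` and `needs_extraction` are hypothetical helpers, and the sketch assumes the viewer page can be pointed at a different file name, which this diff does not show.

```python
from pathlib import Path


def trajectory_path(viewer_dir: Path, episode: int) -> Path:
    """Hypothetical helper: one cache file per episode instead of a single fixed name."""
    return viewer_dir / f"trajectory_ep{episode}.json"


def needs_extraction(path: Path, force: bool) -> bool:
    """Same decision as the cache check in run_viewer: extract when missing or forced."""
    return force or not path.exists()


print(trajectory_path(Path("urdf"), 3).name)
```

If `replay_viewer.html` loads a hard-coded file name, the simpler fix would be to keep the fixed name and always re-extract when the requested episode differs from the cached one.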
@@ -99,7 +99,7 @@ dependencies = [
|
||||
# Common
|
||||
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.9.17"]
|
||||
transformers-dep = ["transformers==5.3.0"] # TODO(Steven): https://github.com/huggingface/lerobot/pull/3249
|
||||
transformers-dep = ["transformers>=5.3.0,<6.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
peft-dep = ["peft>=0.18.0,<1.0.0"]
|
||||
@@ -145,7 +145,6 @@ wallx = [
|
||||
]
|
||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
||||
multi_task_dit = ["lerobot[transformers-dep]"]
|
||||
groot = [
|
||||
"lerobot[transformers-dep]",
|
||||
"lerobot[peft]",
|
||||
@@ -306,8 +305,7 @@ default.extend-ignore-identifiers-re = [
|
||||
"thw",
|
||||
"inpt",
|
||||
"ROBOTIS",
|
||||
"OT_VALUE",
|
||||
"metalness",
|
||||
"OT_VALUE"
|
||||
]
|
||||
|
||||
# TODO: Uncomment when ready to use
|
||||
|
||||
@@ -27,8 +27,7 @@ class DatasetConfig:
|
||||
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
|
||||
# datasets are provided.
|
||||
repo_id: str
|
||||
# Root directory for a concrete local dataset tree (e.g. 'dataset/path'). If None, local datasets are
|
||||
# looked up under $HF_LEROBOT_HOME/repo_id and Hub downloads use a revision-safe cache under $HF_LEROBOT_HOME/hub.
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | None = None
|
||||
episodes: list[int] | None = None
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
|
||||
@@ -115,17 +115,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
def reward_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def state_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
"""Delta indices specifically for observation.state.
|
||||
|
||||
When not None, overrides ``observation_delta_indices`` for the
|
||||
``observation.state`` key only. Useful for loading state history
|
||||
(e.g. ``[-1, 0]`` for UMI-style relative proprioception) without
|
||||
also loading multiple image timesteps.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig
|
||||
|
||||
__all__ = [
|
||||
"EpisodeAwareSampler",
|
||||
"ImageTransforms",
|
||||
"ImageTransformsConfig",
|
||||
"LeRobotDataset",
|
||||
"LeRobotDatasetMetadata",
|
||||
"MultiLeRobotDataset",
|
||||
"StreamingLeRobotDataset",
|
||||
]
|
||||
@@ -13,14 +13,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.io_utils import load_image_as_numpy
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
|
||||
|
||||
@@ -629,232 +624,3 @@ def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np
|
||||
aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
|
||||
|
||||
return aggregated_stats
|
||||
|
||||
|
||||
def _get_valid_chunk_starts(episode_indices: np.ndarray, chunk_size: int) -> np.ndarray:
    """Return all start indices where a chunk of ``chunk_size`` stays within one episode."""
    total = len(episode_indices)
    if total < chunk_size:
        return np.array([], dtype=np.int64)
    max_start = total - chunk_size
    starts = np.arange(max_start + 1)
    valid = episode_indices[starts] == episode_indices[starts + chunk_size - 1]
    return starts[valid]
|
||||
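The endpoint comparison in `_get_valid_chunk_starts` can be illustrated with a dependency-free sketch. It restates the same rule with plain lists rather than NumPy and assumes, as the original appears to, that frames of one episode are stored contiguously, so matching first and last episode indices imply the whole chunk lies in one episode. `valid_chunk_starts` is a name chosen for this illustration.

```python
def valid_chunk_starts(episode_indices, chunk_size):
    """Return start positions whose chunk begins and ends in the same episode."""
    total = len(episode_indices)
    if total < chunk_size:
        return []
    return [
        start
        for start in range(total - chunk_size + 1)
        if episode_indices[start] == episode_indices[start + chunk_size - 1]
    ]


# Episode 0 has 3 frames, episode 1 has 4 frames.
episode_indices = [0, 0, 0, 1, 1, 1, 1]
starts = valid_chunk_starts(episode_indices, 3)
print(starts)  # [0, 3, 4]
```

Starts 1 and 2 are rejected because their chunks would cross from episode 0 into episode 1.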
|
||||
|
||||
def _compute_relative_chunk_batch(
    start_indices: np.ndarray,
    all_actions: np.ndarray,
    all_states: np.ndarray,
    chunk_size: int,
    relative_mask: np.ndarray,
) -> np.ndarray:
    """Vectorised relative-action computation for a batch of start indices.

    Returns an ``(N * chunk_size, action_dim)`` float32 array.
    """
    if len(start_indices) == 0:
        return np.empty((0, all_actions.shape[1]), dtype=np.float32)
    offsets = np.arange(chunk_size)
    frame_idx = start_indices[:, None] + offsets[None, :]
    chunks = all_actions[frame_idx].copy()
    states = all_states[start_indices]
    mask_dim = len(relative_mask)
    chunks[:, :, :mask_dim] -= states[:, None, :mask_dim] * relative_mask[None, None, :]
    return chunks.reshape(-1, all_actions.shape[1])
|
||||
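The masked subtraction in `_compute_relative_chunk_batch` can be shown on a single chunk. The sketch below uses plain lists and a hypothetical `relative_chunk` helper. It assumes the mask has one entry per action dimension, with 1.0 meaning "make relative to the state at the chunk start" and 0.0 meaning "keep absolute".

```python
def relative_chunk(actions, start_state, mask):
    """Subtract the chunk-start state from every action row on masked dimensions."""
    return [
        [a - s * m for a, s, m in zip(row, start_state, mask)]
        for row in actions
    ]


actions = [[1.0, 10.0], [2.0, 11.0]]  # chunk_size = 2, action_dim = 2
start_state = [0.5, 9.0]              # observation.state at the chunk start
mask = [1.0, 0.0]                     # dim 0 relative, dim 1 (e.g. a gripper) absolute

result = relative_chunk(actions, start_state, mask)
print(result)  # [[0.5, 10.0], [1.5, 11.0]]
```

Every row of the chunk is offset by the same state, the one at the chunk's first frame, which is why the original indexes `all_states[start_indices]` once per chunk rather than once per frame.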
|
||||
|
||||
def compute_relative_action_stats(
    hf_dataset,
    features: dict,
    chunk_size: int,
    exclude_joints: list[str] | None = None,
    num_workers: int = 0,
) -> dict[str, np.ndarray]:
    """Compute normalization statistics for relative actions over the full dataset.

    Iterates *all* valid action chunks (within single episodes), converts them to
    relative actions (action − current_state), and computes per-dimension
    statistics suitable for normalization.

    Args:
        hf_dataset: The underlying HuggingFace dataset with "action",
            "observation.state", and "episode_index" columns.
        features: Dataset feature metadata (must contain "action" with "shape"
            and optionally "names").
        chunk_size: Number of consecutive frames per action chunk.
        exclude_joints: Joint names whose dimensions should remain absolute
            (not converted to relative actions).
        num_workers: Number of parallel threads for computation. Values ≤1
            mean single-threaded. Numpy releases the GIL so threads give
            real parallelism here.

    Returns:
        Statistics dict with keys "mean", "std", "min", "max", "q01", …, "q99".

    Raises:
        RuntimeError: If no valid (single-episode) chunks are found, including
            when the dataset has fewer frames than ``chunk_size``.
    """
    from lerobot.processor.relative_action_processor import RelativeActionsProcessorStep

    if exclude_joints is None:
        exclude_joints = []

    action_dim = features[ACTION]["shape"][0]
    action_names = features.get(ACTION, {}).get("names")
    mask_step = RelativeActionsProcessorStep(
        enabled=True,
        exclude_joints=exclude_joints,
        action_names=action_names,
    )
    relative_mask = np.array(mask_step._build_mask(action_dim), dtype=np.float32)

    logging.info("Loading action/state data for relative action stats...")
    all_actions = np.array(hf_dataset[ACTION], dtype=np.float32)
    all_states = np.array(hf_dataset[OBS_STATE], dtype=np.float32)
    episode_indices = np.array(hf_dataset["episode_index"])

    valid_starts = _get_valid_chunk_starts(episode_indices, chunk_size)
    if len(valid_starts) == 0:
        raise RuntimeError(
            f"No valid chunks found (total_frames={len(episode_indices)}, chunk_size={chunk_size})"
        )

    effective_workers = max(num_workers, 1)
    logging.info(
        f"Computing relative action stats from {len(valid_starts)} chunks "
        f"(chunk_size={chunk_size}, workers={effective_workers})"
    )

    batch_size = 50_000
    batches = [valid_starts[i : i + batch_size] for i in range(0, len(valid_starts), batch_size)]

    running_stats = RunningQuantileStats()

    if num_workers > 1:
        from concurrent.futures import ThreadPoolExecutor, as_completed

        with ThreadPoolExecutor(max_workers=num_workers) as pool:
            futures = [
                pool.submit(
                    _compute_relative_chunk_batch,
                    batch,
                    all_actions,
                    all_states,
                    chunk_size,
                    relative_mask,
                )
                for batch in batches
            ]
            for future in as_completed(futures):
                running_stats.update(future.result())
    else:
        for batch in batches:
            running_stats.update(
                _compute_relative_chunk_batch(batch, all_actions, all_states, chunk_size, relative_mask)
            )

    stats = running_stats.get_statistics()

    excluded_dims = int(len(relative_mask) - relative_mask.sum())
    total_frames = len(valid_starts) * chunk_size
    logging.info(
        f"Relative action stats ({len(valid_starts)} chunks, {total_frames} frames): "
        f"relative_dims={int(relative_mask.sum())}/{len(relative_mask)} (excluded={excluded_dims}), "
        f"mean={np.abs(stats['mean']).mean():.4f}, std={stats['std'].mean():.4f}, "
        f"q01={stats['q01'].mean():.4f}, q99={stats['q99'].mean():.4f}"
    )

    return stats
|
||||
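The batching pattern in `compute_relative_action_stats` (split the work into fixed-size batches, compute each batch in a worker thread, merge results as they complete) can be sketched with the standard library alone. The running sum and count below are a stand-in for `RunningQuantileStats`, whose implementation is not shown in this diff; `batch_sum` is a hypothetical worker.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def batch_sum(batch):
    """Hypothetical worker: reduce one batch to a (sum, count) pair."""
    return sum(batch), len(batch)


values = list(range(10))
batch_size = 4
batches = [values[i : i + batch_size] for i in range(0, len(values), batch_size)]

total, count = 0, 0
with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(batch_sum, batch) for batch in batches]
    # as_completed yields futures in completion order, not submission order,
    # so the merge step must not depend on batch order.
    for future in as_completed(futures):
        partial_sum, partial_count = future.result()
        total += partial_sum
        count += partial_count

mean = total / count
print(mean)  # 4.5
```

A mean is order-independent, so the threaded and single-threaded paths agree exactly. Whether the quantile estimates of `RunningQuantileStats` are equally insensitive to update order cannot be determined from this diff.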
|
||||
|
||||
def compute_relative_state_stats(
    hf_dataset,
    features: dict,
    state_obs_steps: int = 2,
    exclude_joints: list[str] | None = None,
    source_key: str = OBS_STATE,
) -> dict[str, np.ndarray]:
    """Compute normalization statistics for observation.state after relative conversion.

    For UMI-style relative proprioception with ``state_obs_steps`` timesteps,
    each state observation becomes a stack of offsets from the current timestep:
    ``state[t-k] - state[t]`` for k in ``range(state_obs_steps-1, -1, -1)``.

    The stats are computed over the flattened ``[state_obs_steps * state_dim]``
    vector that the model actually sees after ``prepare_state`` flattening.

    Args:
        hf_dataset: The HuggingFace dataset with the source column and
            "episode_index" columns.
        features: Dataset feature metadata.
        state_obs_steps: Number of observation timesteps (must be >= 2).
        exclude_joints: State dimension names to keep absolute.
        source_key: Column to read data from. Defaults to "observation.state".
            When ``derive_state_from_action=True``, pass ``ACTION`` to read
            from the action column instead.

    Returns:
        Statistics dict with keys "mean", "std", "min", "max", "q01", …, "q99".
    """
    from lerobot.processor.relative_action_processor import RelativeStateProcessorStep

    if exclude_joints is None:
        exclude_joints = []

    state_dim = features[source_key]["shape"][0]
    state_names = features.get(source_key, {}).get("names")
    mask_step = RelativeStateProcessorStep(
        enabled=True,
        exclude_joints=exclude_joints,
        state_names=state_names,
    )
    relative_mask = np.array(mask_step._build_mask(state_dim), dtype=np.float32)

    logging.info(f"Loading data from '{source_key}' for relative state stats...")
    all_states = np.array(hf_dataset[source_key], dtype=np.float32)
    episode_indices = np.array(hf_dataset["episode_index"])

    # Build all valid windows of length state_obs_steps within each episode
    n = len(all_states)
    if n < state_obs_steps:
        raise ValueError(f"Dataset has {n} frames but state_obs_steps={state_obs_steps}")

    max_start = n - state_obs_steps
    starts = np.arange(max_start + 1)
    valid = episode_indices[starts] == episode_indices[starts + state_obs_steps - 1]
    valid_starts = starts[valid]

    if len(valid_starts) == 0:
        raise RuntimeError("No valid state windows found within single episodes")

    offsets = np.arange(state_obs_steps)
    mask_dim = len(relative_mask)

    running_stats = RunningQuantileStats()

    batch_size = 50_000
    for i in range(0, len(valid_starts), batch_size):
        batch_starts = valid_starts[i : i + batch_size]
        frame_idx = batch_starts[:, None] + offsets[None, :]  # [N, state_obs_steps]
        windows = all_states[frame_idx].copy()  # [N, state_obs_steps, state_dim]

        # Subtract current (last) timestep from all timesteps for masked dims
        current = windows[:, -1:, :]  # [N, 1, state_dim]
        windows[:, :, :mask_dim] -= current[:, :, :mask_dim] * relative_mask[None, None, :]

        # Flatten to [N, state_obs_steps * state_dim] (same as prepare_state)
        flattened = windows.reshape(len(batch_starts), -1)
        running_stats.update(flattened)

    stats = running_stats.get_statistics()

    excluded_dims = int(mask_dim - relative_mask.sum())
    logging.info(
        f"Relative state stats ({len(valid_starts)} windows, obs_steps={state_obs_steps}): "
        f"relative_dims={int(relative_mask.sum())}/{mask_dim} (excluded={excluded_dims}), "
        f"mean={np.abs(stats['mean']).mean():.4f}, std={stats['std'].mean():.4f}"
    )

    return stats
|
||||
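A small worked example may help when reading `compute_relative_state_stats`. The sketch below uses plain lists and a hypothetical `relative_state_window` helper to apply the same two steps to each window: subtract the last (current) timestep on masked dimensions, then flatten. It assumes all frames belong to a single episode.

```python
def relative_state_window(window, mask):
    """Subtract the last (current) timestep on masked dims, then flatten the window."""
    current = window[-1]
    shifted = [[v - c * m for v, c, m in zip(row, current, mask)] for row in window]
    return [v for row in shifted for v in row]


states = [[1.0, 5.0], [2.0, 6.0], [4.0, 7.0]]  # one episode, state_dim = 2
mask = [1.0, 0.0]                              # dim 0 relative, dim 1 kept absolute
state_obs_steps = 2

windows = [states[s : s + state_obs_steps] for s in range(len(states) - state_obs_steps + 1)]
flattened = [relative_state_window(w, mask) for w in windows]
print(flattened)  # [[-1.0, 5.0, 0.0, 6.0], [-2.0, 6.0, 0.0, 7.0]]
```

For every relative dimension, the slot belonging to the current timestep is always exactly zero (index 2 in each flattened vector here), so its standard deviation over the dataset is zero. How the downstream normalizer treats a zero-variance slot is not shown in this diff and is worth checking.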
|
||||
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -44,24 +43,16 @@ from lerobot.datasets.utils import (
|
||||
check_version_compatibility,
|
||||
flatten_dict,
|
||||
get_safe_version,
|
||||
has_legacy_hub_download_metadata,
|
||||
is_valid_version,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from lerobot.datasets.video_utils import get_video_info
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, HF_LEROBOT_HUB_CACHE
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
"""Metadata container for a LeRobot dataset.
|
||||
|
||||
Manages the ``info.json``, ``stats.json``, ``tasks.parquet``, and
|
||||
``episodes/`` parquet files that describe a dataset's structure, content,
|
||||
and statistics.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
@@ -70,57 +61,33 @@ class LeRobotDatasetMetadata:
|
||||
force_cache_sync: bool = False,
|
||||
metadata_buffer_size: int = 10,
|
||||
):
|
||||
"""Load or download metadata for an existing LeRobot dataset.
|
||||
|
||||
Attempts to load metadata from local disk. If files are missing or
|
||||
``force_cache_sync`` is ``True``, downloads the ``meta/`` directory from
|
||||
the Hub.
|
||||
|
||||
Args:
|
||||
repo_id: Repository identifier (e.g. ``'lerobot/aloha_sim'``).
|
||||
root: Local directory for the dataset. When provided, Hub downloads
|
||||
are materialized directly into this directory. When omitted,
|
||||
existing local datasets are still looked up under
|
||||
``$HF_LEROBOT_HOME/{repo_id}``, but Hub downloads use a
|
||||
revision-safe snapshot cache under
|
||||
``$HF_LEROBOT_HOME/hub``.
|
||||
revision: Git revision (branch, tag, or commit hash). Defaults to
|
||||
the current codebase version.
|
||||
force_cache_sync: If ``True``, re-download metadata from the Hub
|
||||
even when local files exist.
|
||||
metadata_buffer_size: Number of episode metadata records to buffer
|
||||
in memory before flushing to parquet.
|
||||
"""
|
||||
self.repo_id = repo_id
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self._requested_root = Path(root) if root is not None else None
|
||||
self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id
|
||||
self._pq_writer = None
|
||||
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
self.writer = None
|
||||
self.latest_episode = None
|
||||
self._metadata_buffer: list[dict] = []
|
||||
self._metadata_buffer_size = metadata_buffer_size
|
||||
self._finalized = False
|
||||
self.metadata_buffer: list[dict] = []
|
||||
self.metadata_buffer_size = metadata_buffer_size
|
||||
|
||||
try:
|
||||
if force_cache_sync or (
|
||||
self._requested_root is None and has_legacy_hub_download_metadata(self.root)
|
||||
):
|
||||
if force_cache_sync:
|
||||
raise FileNotFoundError
|
||||
self._load_metadata()
|
||||
self.load_metadata()
|
||||
except (FileNotFoundError, NotADirectoryError):
|
||||
if is_valid_version(self.revision):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
|
||||
self._pull_from_repo(allow_patterns="meta/")
|
||||
self._load_metadata()
|
||||
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.load_metadata()
|
||||
|
||||
def _flush_metadata_buffer(self) -> None:
|
||||
"""Write all buffered episode metadata to parquet file."""
|
||||
if not hasattr(self, "_metadata_buffer") or len(self._metadata_buffer) == 0:
|
||||
if not hasattr(self, "metadata_buffer") or len(self.metadata_buffer) == 0:
|
||||
return
|
||||
|
||||
combined_dict = {}
|
||||
for episode_dict in self._metadata_buffer:
|
||||
for episode_dict in self.metadata_buffer:
|
||||
for key, value in episode_dict.items():
|
||||
if key not in combined_dict:
|
||||
combined_dict[key] = []
|
||||
@@ -129,50 +96,40 @@ class LeRobotDatasetMetadata:
|
||||
val = value[0] if isinstance(value, list) else value
|
||||
combined_dict[key].append(val.tolist() if isinstance(val, np.ndarray) else val)
|
||||
|
||||
first_ep = self._metadata_buffer[0]
|
||||
first_ep = self.metadata_buffer[0]
|
||||
chunk_idx = first_ep["meta/episodes/chunk_index"][0]
|
||||
file_idx = first_ep["meta/episodes/file_index"][0]
|
||||
|
||||
table = pa.Table.from_pydict(combined_dict)
|
||||
|
||||
if not self._pq_writer:
|
||||
if not self.writer:
|
||||
path = Path(self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx))
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._pq_writer = pq.ParquetWriter(
|
||||
self.writer = pq.ParquetWriter(
|
||||
path, schema=table.schema, compression="snappy", use_dictionary=True
|
||||
)
|
||||
|
||||
self._pq_writer.write_table(table)
|
||||
self.writer.write_table(table)
|
||||
|
||||
self.latest_episode = self._metadata_buffer[-1]
|
||||
self._metadata_buffer.clear()
|
||||
self.latest_episode = self.metadata_buffer[-1]
|
||||
self.metadata_buffer.clear()
|
||||
|
||||
def _close_writer(self) -> None:
|
||||
"""Close and cleanup the parquet writer if it exists."""
|
||||
self._flush_metadata_buffer()
|
||||
|
||||
writer = getattr(self, "_pq_writer", None)
|
||||
writer = getattr(self, "writer", None)
|
||||
if writer is not None:
|
||||
writer.close()
|
||||
self._pq_writer = None
|
||||
|
||||
def finalize(self) -> None:
|
||||
"""Flush metadata buffer and close the parquet writer.
|
||||
|
||||
Idempotent — safe to call multiple times.
|
||||
"""
|
||||
if getattr(self, "_finalized", False):
|
||||
return
|
||||
self._close_writer()
|
||||
self._finalized = True
|
||||
self.writer = None
|
||||
|
||||
def __del__(self):
|
||||
"""Safety net: flush and close parquet writer on garbage collection."""
|
||||
# During interpreter shutdown, referenced objects may already be collected.
|
||||
with contextlib.suppress(Exception):
|
||||
self.finalize()
|
||||
"""
|
||||
Trust the user to call .finalize() but as an added safety check call the parquet writer to stop when calling the destructor
|
||||
"""
|
||||
self._close_writer()
|
||||
|
||||
def _load_metadata(self):
|
||||
def load_metadata(self):
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
@@ -180,38 +137,22 @@ class LeRobotDatasetMetadata:
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
def _pull_from_repo(
|
||||
def pull_from_repo(
|
||||
self,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
ignore_patterns: list[str] | str | None = None,
|
||||
) -> None:
|
||||
if self._requested_root is None:
|
||||
self.root = Path(
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self.revision,
|
||||
cache_dir=HF_LEROBOT_HUB_CACHE,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
self._requested_root.mkdir(exist_ok=True, parents=True)
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self.revision,
|
||||
local_dir=self._requested_root,
|
||||
local_dir=self.root,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
self.root = self._requested_root
|
||||
|
||||
@property
|
||||
def url_root(self) -> str:
|
||||
"""Hugging Face Hub URL root for this dataset."""
|
||||
return f"hf://datasets/{self.repo_id}"
|
||||
|
||||
@property
|
||||
@@ -220,17 +161,6 @@ class LeRobotDatasetMetadata:
|
||||
return packaging.version.parse(self.info["codebase_version"])
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
"""Return the relative parquet file path for the given episode index.
|
||||
|
||||
Args:
|
||||
ep_index: Zero-based episode index.
|
||||
|
||||
Returns:
|
||||
Path to the parquet file containing this episode's data.
|
||||
|
||||
Raises:
|
||||
IndexError: If ``ep_index`` is out of range.
|
||||
"""
|
||||
if self.episodes is None:
|
||||
self.episodes = load_episodes(self.root)
|
||||
if ep_index >= len(self.episodes):
|
||||
@@ -244,19 +174,6 @@ class LeRobotDatasetMetadata:
|
||||
return Path(fpath)
|
||||
|
||||
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
||||
"""Return the relative video file path for the given episode and video key.
|
||||
|
||||
Args:
|
||||
ep_index: Zero-based episode index.
|
||||
vid_key: Feature key identifying the video stream
|
||||
(e.g. ``'observation.images.laptop'``).
|
||||
|
||||
Returns:
|
||||
Path to the video file containing this episode's frames.
|
||||
|
||||
Raises:
|
||||
IndexError: If ``ep_index`` is out of range.
|
||||
"""
|
||||
if self.episodes is None:
|
||||
self.episodes = load_episodes(self.root)
|
||||
if ep_index >= len(self.episodes):
|
||||
@@ -360,17 +277,6 @@ class LeRobotDatasetMetadata:
|
||||
return None
|
||||
|
||||
def save_episode_tasks(self, tasks: list[str]):
|
||||
"""Register tasks for the current episode and persist to disk.
|
||||
|
||||
New tasks that do not already exist in the dataset are assigned
|
||||
sequential task indices and appended to the tasks parquet file.
|
||||
|
||||
Args:
|
||||
tasks: List of unique task descriptions in natural language.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``tasks`` contains duplicates.
|
||||
"""
|
||||
if len(set(tasks)) != len(tasks):
|
||||
raise ValueError(f"Tasks are not unique: {tasks}")
|
||||
|
||||
@@ -430,8 +336,8 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
latest_path = (
|
||||
self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
if self._pq_writer is None
|
||||
else self._pq_writer.where
|
||||
if self.writer is None
|
||||
else self.writer.where
|
||||
)
|
||||
|
||||
if Path(latest_path).exists():
|
||||
@@ -453,10 +359,10 @@ class LeRobotDatasetMetadata:
|
||||
episode_dict["dataset_to_index"] = [self.latest_episode["dataset_to_index"][0] + num_frames]
|
||||
|
||||
# Add to buffer
|
||||
self._metadata_buffer.append(episode_dict)
|
||||
self.metadata_buffer.append(episode_dict)
|
||||
self.latest_episode = episode_dict
|
||||
|
||||
if len(self._metadata_buffer) >= self._metadata_buffer_size:
|
||||
if len(self.metadata_buffer) >= self.metadata_buffer_size:
|
||||
self._flush_metadata_buffer()
|
||||
|
||||
def save_episode(
|
||||
@@ -467,20 +373,6 @@ class LeRobotDatasetMetadata:
|
||||
episode_stats: dict[str, dict],
|
||||
episode_metadata: dict,
|
||||
) -> None:
|
||||
"""Persist episode metadata, update dataset info, and aggregate stats.
|
||||
|
||||
Writes the episode's metadata to the buffered parquet writer, increments
|
||||
the total episode/frame counters in ``info.json``, and merges the
|
||||
episode's statistics into the running dataset statistics.
|
||||
|
||||
Args:
|
||||
episode_index: Zero-based index of the episode being saved.
|
||||
episode_length: Number of frames in this episode.
|
||||
episode_tasks: List of task descriptions for this episode.
|
||||
episode_stats: Per-feature statistics for this episode.
|
||||
episode_metadata: Additional metadata (chunk/file indices, frame
|
||||
ranges, video timestamps, etc.).
|
||||
"""
|
||||
episode_dict = {
|
||||
"episode_index": episode_index,
|
||||
"tasks": episode_tasks,
|
||||
@@ -587,36 +479,10 @@ class LeRobotDatasetMetadata:
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
) -> "LeRobotDatasetMetadata":
|
||||
"""Create metadata for a new LeRobot dataset from scratch.
|
||||
|
||||
Initializes the ``info.json`` file on disk with the provided feature
|
||||
schema and dataset settings. No episode data is written yet.
|
||||
|
||||
Args:
|
||||
repo_id: Repository identifier (e.g. ``'user/my_dataset'``).
|
||||
fps: Frames per second used during data collection.
|
||||
features: Feature specification dict mapping feature names to their
|
||||
type/shape metadata.
|
||||
robot_type: Optional robot type string stored in metadata.
|
||||
root: Local directory for the dataset. Defaults to
|
||||
``$HF_LEROBOT_HOME/{repo_id}``. Must not already exist.
|
||||
use_videos: If ``True``, visual modalities are encoded as MP4 videos.
|
||||
metadata_buffer_size: Number of episode metadata records to buffer
|
||||
before flushing to parquet.
|
||||
chunks_size: Max number of files per chunk directory. ``None`` uses
|
||||
the default.
|
||||
data_files_size_in_mb: Max parquet file size in MB. ``None`` uses the
|
||||
default.
|
||||
video_files_size_in_mb: Max video file size in MB. ``None`` uses the
|
||||
default.
|
||||
|
||||
Returns:
|
||||
A new :class:`LeRobotDatasetMetadata` instance.
|
||||
"""
|
||||
"""Creates metadata for a LeRobotDataset."""
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj._requested_root = Path(root) if root is not None else None
|
||||
obj.root = obj._requested_root if obj._requested_root is not None else HF_LEROBOT_HOME / repo_id
|
||||
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
@@ -644,9 +510,8 @@ class LeRobotDatasetMetadata:
|
||||
)
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
obj.revision = None
|
||||
obj._pq_writer = None
|
||||
obj.writer = None
|
||||
obj.latest_episode = None
|
||||
obj._metadata_buffer = []
|
||||
obj._metadata_buffer_size = metadata_buffer_size
|
||||
obj._finalized = False
|
||||
obj.metadata_buffer = []
|
||||
obj.metadata_buffer_size = metadata_buffer_size
|
||||
return obj
|
||||
|
||||
@@ -1,288 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Private reader component for LeRobotDataset. Handles random-access reading (HF dataset, delta indices, video decoding)."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import (
|
||||
check_delta_timestamps,
|
||||
get_delta_indices,
|
||||
get_hf_features_from_features,
|
||||
)
|
||||
from lerobot.datasets.io_utils import (
|
||||
hf_transform_to_torch,
|
||||
load_nested_dataset,
|
||||
)
|
||||
from lerobot.datasets.video_utils import decode_video_frames
|
||||
|
||||
|
||||
class DatasetReader:
|
||||
"""Encapsulates read-side state and methods for LeRobotDataset.
|
||||
|
||||
Owns: hf_dataset, _absolute_to_relative_idx, delta_indices.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
meta: LeRobotDatasetMetadata,
|
||||
root: Path,
|
||||
episodes: list[int] | None,
|
||||
tolerance_s: float,
|
||||
video_backend: str,
|
||||
delta_timestamps: dict[str, list[float]] | None,
|
||||
image_transforms: Callable | None,
|
||||
):
|
||||
"""Initialize the reader with metadata, filtering, and transform config.
|
||||
|
||||
The HF dataset is not loaded here — call :meth:`try_load` or
|
||||
:meth:`load_and_activate` afterward.
|
||||
|
||||
Args:
|
||||
meta: Dataset metadata instance.
|
||||
root: Local dataset root directory.
|
||||
episodes: Optional list of episode indices to select. ``None``
|
||||
means all episodes.
|
||||
tolerance_s: Timestamp synchronization tolerance in seconds.
|
||||
video_backend: Video decoding backend identifier.
|
||||
delta_timestamps: Optional dict mapping feature keys to lists of
|
||||
relative timestamp offsets for temporal context windows.
|
||||
image_transforms: Optional torchvision v2 transform applied to
|
||||
visual features.
|
||||
"""
|
||||
self._meta = meta
|
||||
self.root = root
|
||||
self.episodes = episodes
|
||||
self._tolerance_s = tolerance_s
|
||||
self._video_backend = video_backend
|
||||
self._image_transforms = image_transforms
|
||||
|
||||
self.hf_dataset: datasets.Dataset | None = None
|
||||
self._absolute_to_relative_idx: dict[int, int] | None = None
|
||||
|
||||
# Set up delta_indices (doesn't depend on hf_dataset)
|
||||
self.delta_indices = None
|
||||
if delta_timestamps is not None:
|
||||
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
||||
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
||||
|
||||
def try_load(self) -> bool:
|
||||
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
||||
try:
|
||||
self.hf_dataset = self._load_hf_dataset()
|
||||
except (FileNotFoundError, NotADirectoryError):
|
||||
self.hf_dataset = None
|
||||
return False
|
||||
if not self._check_cached_episodes_sufficient():
|
||||
self.hf_dataset = None
|
||||
return False
|
||||
self._build_index_mapping()
|
||||
return True
|
||||
|
||||
def load_and_activate(self) -> None:
|
||||
"""Load HF dataset from disk and build index mapping. Call after data is on disk."""
|
||||
self.hf_dataset = self._load_hf_dataset()
|
||||
self._build_index_mapping()
|
||||
|
||||
def _build_index_mapping(self) -> None:
|
||||
"""Build absolute-to-relative index mapping from loaded hf_dataset."""
|
||||
self._absolute_to_relative_idx = None
|
||||
if self.episodes is not None and self.hf_dataset is not None:
|
||||
self._absolute_to_relative_idx = {
|
||||
abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx
|
||||
for rel_idx, abs_idx in enumerate(self.hf_dataset["index"])
|
||||
}
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
"""Number of frames in selected episodes."""
|
||||
if self.episodes is not None and self.hf_dataset is not None:
|
||||
return len(self.hf_dataset)
|
||||
return self._meta.total_frames
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
"""Number of episodes selected."""
|
||||
return len(self.episodes) if self.episodes is not None else self._meta.total_episodes
|
||||
|
||||
def _load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
features = get_hf_features_from_features(self._meta.features)
|
||||
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
def _check_cached_episodes_sufficient(self) -> bool:
|
||||
"""Check if the cached dataset contains all requested episodes and their video files."""
|
||||
if self.hf_dataset is None or len(self.hf_dataset) == 0:
|
||||
return False
|
||||
|
||||
available_episodes = {
|
||||
ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx
|
||||
for ep_idx in self.hf_dataset.unique("episode_index")
|
||||
}
|
||||
|
||||
if self.episodes is None:
|
||||
requested_episodes = set(range(self._meta.total_episodes))
|
||||
else:
|
||||
requested_episodes = set(self.episodes)
|
||||
|
||||
if not requested_episodes.issubset(available_episodes):
|
||||
return False
|
||||
|
||||
if len(self._meta.video_keys) > 0:
|
||||
for ep_idx in requested_episodes:
|
||||
for vid_key in self._meta.video_keys:
|
||||
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
|
||||
if not video_path.exists():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_episodes_file_paths(self) -> list[Path]:
|
||||
"""Return deduplicated file paths (data + video) for selected episodes.
|
||||
|
||||
Used to build the ``allow_patterns`` list for ``snapshot_download``.
|
||||
"""
|
||||
episodes = self.episodes if self.episodes is not None else list(range(self._meta.total_episodes))
|
||||
fpaths = [str(self._meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
|
||||
if len(self._meta.video_keys) > 0:
|
||||
video_files = [
|
||||
str(self._meta.get_video_file_path(ep_idx, vid_key))
|
||||
for vid_key in self._meta.video_keys
|
||||
for ep_idx in episodes
|
||||
]
|
||||
fpaths += video_files
|
||||
# episodes are stored in the same files, so we return unique paths only
|
||||
fpaths = list(set(fpaths))
|
||||
return fpaths
|
||||
|
||||
def _get_query_indices(
|
||||
self, abs_idx: int, ep_idx: int
|
||||
) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]:
|
||||
"""Compute query indices for delta timestamps."""
|
||||
ep = self._meta.episodes[ep_idx]
|
||||
ep_start = ep["dataset_from_index"]
|
||||
ep_end = ep["dataset_to_index"]
|
||||
query_indices = {
|
||||
key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
padding = {
|
||||
f"{key}_is_pad": torch.BoolTensor(
|
||||
[(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx]
|
||||
)
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
return query_indices, padding
|
||||
|
||||
def _get_query_timestamps(
|
||||
self,
|
||||
current_ts: float,
|
||||
query_indices: dict[str, list[int]] | None = None,
|
||||
) -> dict[str, list[float]]:
|
||||
query_timestamps = {}
|
||||
for key in self._meta.video_keys:
|
||||
if query_indices is not None and key in query_indices:
|
||||
if self._absolute_to_relative_idx is not None:
|
||||
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
|
||||
timestamps = self.hf_dataset[relative_indices]["timestamp"]
|
||||
else:
|
||||
timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
|
||||
query_timestamps[key] = torch.stack(timestamps).tolist()
|
||||
else:
|
||||
query_timestamps[key] = [current_ts]
|
||||
|
||||
return query_timestamps
|
||||
|
||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||
"""Query dataset for indices across keys, skipping video keys."""
|
||||
result: dict = {}
|
||||
for key, q_idx in query_indices.items():
|
||||
if key in self._meta.video_keys:
|
||||
continue
|
||||
relative_indices = (
|
||||
q_idx
|
||||
if self._absolute_to_relative_idx is None
|
||||
else [self._absolute_to_relative_idx[idx] for idx in q_idx]
|
||||
)
|
||||
try:
|
||||
result[key] = torch.stack(self.hf_dataset[key][relative_indices])
|
||||
except (KeyError, TypeError, IndexError):
|
||||
result[key] = torch.stack(self.hf_dataset[relative_indices][key])
|
||||
return result
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
in the main process (e.g. by using a second DataLoader with num_workers=0). It will result in a
|
||||
Segmentation Fault.
|
||||
"""
|
||||
ep = self._meta.episodes[ep_idx]
|
||||
item = {}
|
||||
for vid_key, query_ts in query_timestamps.items():
|
||||
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
|
||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||
|
||||
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
|
||||
frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend)
|
||||
item[vid_key] = frames.squeeze(0)
|
||||
|
||||
return item
|
||||
|
||||
def get_item(self, idx) -> dict:
|
||||
"""Core __getitem__ logic. Assumes hf_dataset is loaded.
|
||||
|
||||
``idx`` is a *relative* index into the (possibly episode-filtered)
|
||||
HF dataset, **not** the absolute frame index stored in the ``index``
|
||||
column. The absolute index is retrieved from the row itself.
|
||||
"""
|
||||
item = self.hf_dataset[idx]
|
||||
ep_idx = item["episode_index"].item()
|
||||
abs_idx = item["index"].item()
|
||||
|
||||
query_indices = None
|
||||
if self.delta_indices is not None:
|
||||
query_indices, padding = self._get_query_indices(abs_idx, ep_idx)
|
||||
query_result = self._query_hf_dataset(query_indices)
|
||||
item = {**item, **padding}
|
||||
for key, val in query_result.items():
|
||||
item[key] = val
|
||||
|
||||
if len(self._meta.video_keys) > 0:
|
||||
current_ts = item["timestamp"].item()
|
||||
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
item = {**video_frames, **item}
|
||||
|
||||
if self._image_transforms is not None:
|
||||
image_keys = self._meta.camera_keys
|
||||
for cam in image_keys:
|
||||
item[cam] = self._image_transforms(item[cam])
|
||||
|
||||
# Add task as a string
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self._meta.tasks.iloc[task_idx].name
|
||||
|
||||
# add subtask information if available
|
||||
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
|
||||
subtask_idx = item["subtask_index"].item()
|
||||
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
return item
|
||||
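The removed `DatasetReader` above resolves temporal context windows in two steps: `_get_query_indices` clamps each `abs_idx + delta` to the episode bounds and records which offsets fell outside the episode, and `_build_index_mapping` translates absolute frame indices into row positions of the episode-filtered dataset. The following is a minimal sketch of that logic in plain Python, assuming lists instead of tensors; the function names `query_indices_with_padding` and `build_abs_to_rel` are hypothetical and do not exist in the repository.

```python
def query_indices_with_padding(abs_idx, ep_start, ep_end, deltas):
    """Clamp abs_idx + delta into [ep_start, ep_end - 1] and flag padded offsets.

    Hypothetical helper mirroring the list comprehensions in _get_query_indices;
    ep_end is exclusive, as with "dataset_to_index" in the diff above.
    """
    indices = [max(ep_start, min(ep_end - 1, abs_idx + d)) for d in deltas]
    is_pad = [(abs_idx + d < ep_start) or (abs_idx + d >= ep_end) for d in deltas]
    return indices, is_pad


def build_abs_to_rel(absolute_indices):
    """Map each absolute frame index to its row position in the filtered dataset.

    Hypothetical helper mirroring _build_index_mapping, which enumerates the
    "index" column of the loaded dataset.
    """
    return {abs_idx: rel_idx for rel_idx, abs_idx in enumerate(absolute_indices)}


# An episode occupying absolute frames 100..109 (ep_end = 110 is exclusive).
indices, is_pad = query_indices_with_padding(101, 100, 110, [-2, -1, 0, 1])
# Offset -2 points before the episode, so it is clamped to 100 and marked as padding.

# With only this episode selected, rows 0..9 of the filtered dataset hold frames 100..109.
abs_to_rel = build_abs_to_rel(range(100, 110))
relative_rows = [abs_to_rel[i] for i in indices]
```

Clamping keeps every query inside the current episode, so a padded position repeats the nearest valid frame, and the `*_is_pad` flags let downstream code mask those positions out.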
@@ -37,12 +37,7 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.compute_stats import (
|
||||
aggregate_stats,
|
||||
compute_episode_stats,
|
||||
compute_relative_action_stats,
|
||||
compute_relative_state_stats,
|
||||
)
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.io_utils import (
|
||||
get_parquet_file_size_in_mb,
|
||||
@@ -61,7 +56,7 @@ from lerobot.datasets.utils import (
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
|
||||
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
|
||||
|
||||
|
||||
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
|
||||
@@ -896,7 +891,7 @@ def _copy_and_reindex_episodes_metadata(
|
||||
|
||||
total_frames += src_episode["length"]
|
||||
|
||||
dst_meta.finalize()
|
||||
dst_meta._close_writer()
|
||||
|
||||
dst_meta.info.update(
|
||||
{
|
||||
@@ -1538,147 +1533,6 @@ def modify_tasks(
|
||||
return dataset
|
||||
|
||||
|
||||
def recompute_stats(
|
||||
dataset: LeRobotDataset,
|
||||
skip_image_video: bool = True,
|
||||
relative_action: bool = False,
|
||||
relative_exclude_joints: list[str] | None = None,
|
||||
chunk_size: int = 50,
|
||||
num_workers: int = 0,
|
||||
relative_state: bool = False,
|
||||
relative_exclude_state_joints: list[str] | None = None,
|
||||
state_obs_steps: int = 2,
|
||||
derive_state_from_action: bool = False,
|
||||
) -> LeRobotDataset:
|
||||
"""Recompute stats.json from scratch by iterating all episodes.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobotDataset to recompute stats for.
|
||||
skip_image_video: If True (default), only recompute stats for numeric features
|
||||
(action, state, etc.) and keep existing image/video stats unchanged.
|
||||
relative_action: If True, compute action stats in relative space by
|
||||
iterating all valid action chunks and subtracting the current state.
|
||||
This matches the normalization distribution the model sees during
|
||||
training with ``use_relative_actions=True``.
|
||||
relative_exclude_joints: Joint names to exclude from relative conversion when
|
||||
relative_action=True. These dims keep absolute stats.
|
||||
chunk_size: Action chunk size used for relative stats computation. Should match
|
||||
``policy.chunk_size``. Only used when ``relative_action=True``.
|
||||
num_workers: Number of parallel threads for relative action stats computation.
|
||||
Values ≤1 mean single-threaded. Only used when ``relative_action=True``.
|
||||
relative_state: If True, compute observation.state stats in relative space
|
||||
(multi-timestep offsets from current). This matches the normalization
|
||||
the model sees during training with ``use_relative_state=True``.
|
||||
relative_exclude_state_joints: State dim names to exclude from relative conversion.
|
||||
state_obs_steps: Number of observation timesteps for relative state stats.
|
||||
Should match ``policy.state_obs_steps``. Only used when ``relative_state=True``.
|
||||
derive_state_from_action: If True, compute relative state stats from the
|
||||
action column instead of observation.state. Implies ``relative_state=True``
|
||||
and ``state_obs_steps=2``.
|
||||
|
||||
Returns:
|
||||
The same dataset with updated stats.
|
||||
"""
|
||||
if derive_state_from_action:
|
||||
relative_state = True
|
||||
state_obs_steps = 2
|
||||
features = dataset.meta.features
|
||||
meta_keys = {"index", "episode_index", "task_index", "frame_index", "timestamp"}
|
||||
numeric_features = {
|
||||
k: v
|
||||
for k, v in features.items()
|
||||
if v["dtype"] not in ["image", "video", "string"] and k not in meta_keys
|
||||
}
|
||||
|
||||
if skip_image_video:
|
||||
features_to_compute = numeric_features
|
||||
else:
|
||||
features_to_compute = {
|
||||
k: v for k, v in features.items() if v["dtype"] != "string" and k not in meta_keys
|
||||
}
|
||||
|
||||
# When relative_action is enabled, compute action stats via chunk-based sampling
|
||||
# (matching what the model sees during training) and skip action in the
|
||||
# per-episode pass below.
|
||||
relative_action_stats = None
|
||||
if relative_action and ACTION in features and OBS_STATE in features:
|
||||
if relative_exclude_joints is None:
|
||||
relative_exclude_joints = ["gripper"]
|
||||
relative_action_stats = compute_relative_action_stats(
|
||||
hf_dataset=dataset.hf_dataset,
|
||||
features=features,
|
||||
chunk_size=chunk_size,
|
||||
exclude_joints=relative_exclude_joints,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
features_to_compute.pop(ACTION, None)
|
||||
|
||||
# When relative_state is enabled, compute state stats over the flattened
|
||||
# multi-timestep relative representation (matching what the model sees).
|
||||
relative_state_stats = None
|
||||
if relative_state and (OBS_STATE in features or derive_state_from_action):
|
||||
source_key = ACTION if derive_state_from_action else OBS_STATE
|
||||
relative_state_stats = compute_relative_state_stats(
|
||||
hf_dataset=dataset.hf_dataset,
|
||||
features=features,
|
||||
state_obs_steps=state_obs_steps,
|
||||
exclude_joints=relative_exclude_state_joints,
|
||||
source_key=source_key,
|
||||
)
|
||||
features_to_compute.pop(OBS_STATE, None)
|
||||
|
||||
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
|
||||
|
||||
data_dir = dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
if not parquet_files:
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
all_episode_stats = []
|
||||
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
|
||||
|
||||
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
|
||||
df = pd.read_parquet(parquet_path)
|
||||
|
||||
for ep_idx in sorted(df["episode_index"].unique()):
|
||||
ep_df = df[df["episode_index"] == ep_idx]
|
||||
episode_data = {}
|
||||
for key in numeric_keys:
|
||||
if key in ep_df.columns:
|
||||
values = ep_df[key].values
|
||||
if hasattr(values[0], "__len__"):
|
||||
episode_data[key] = np.stack(values)
|
||||
else:
|
||||
episode_data[key] = np.array(values)
|
||||
|
||||
ep_stats = compute_episode_stats(episode_data, features_to_compute)
|
||||
all_episode_stats.append(ep_stats)
|
||||
|
||||
if features_to_compute and not all_episode_stats:
|
||||
logging.warning("No episode stats computed")
|
||||
return dataset
|
||||
|
||||
new_stats = aggregate_stats(all_episode_stats) if all_episode_stats else {}
|
||||
|
||||
if relative_action_stats is not None:
|
||||
new_stats[ACTION] = relative_action_stats
|
||||
|
||||
if relative_state_stats is not None:
|
||||
new_stats[OBS_STATE] = relative_state_stats
|
||||
|
||||
# Merge: keep existing stats for features we didn't recompute
|
||||
if dataset.meta.stats:
|
||||
for key, value in dataset.meta.stats.items():
|
||||
if key not in new_stats:
|
||||
new_stats[key] = value
|
||||
|
||||
write_stats(new_stats, dataset.root)
|
||||
dataset.meta.stats = new_stats
|
||||
|
||||
logging.info("Stats recomputed successfully")
|
||||
return dataset
|
||||
|
||||
|
||||
def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path | None = None,
|
||||
|
||||
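The removed `recompute_stats` above first selects which features to recompute (excluding bookkeeping columns and, by default, image and video features) and finally merges the new values with the stats already stored, keeping the old entry for any feature that was not recomputed. A minimal sketch of those two steps, assuming plain dictionaries; `select_features` and `merge_stats` are hypothetical names used only for illustration.

```python
META_KEYS = {"index", "episode_index", "task_index", "frame_index", "timestamp"}


def select_features(features, skip_image_video=True):
    """Return the features whose stats should be recomputed.

    Hypothetical helper: string features and bookkeeping columns are always
    excluded; image and video features are excluded when skip_image_video is True.
    """
    skipped = {"image", "video", "string"} if skip_image_video else {"string"}
    return {
        key: ft
        for key, ft in features.items()
        if ft["dtype"] not in skipped and key not in META_KEYS
    }


def merge_stats(new_stats, existing_stats):
    """Prefer freshly computed stats and keep existing ones for untouched keys."""
    merged = dict(new_stats)
    for key, value in (existing_stats or {}).items():
        merged.setdefault(key, value)
    return merged


features = {
    "action": {"dtype": "float32"},
    "observation.images.top": {"dtype": "video"},
    "index": {"dtype": "int64"},
}
to_compute = select_features(features)

existing = {"action": {"mean": 9.9}, "observation.images.top": {"mean": 0.5}}
merged = merge_stats({"action": {"mean": 0.0}}, existing)
```

Because existing entries are only used as a fallback, a run with `skip_image_video=True` leaves the image and video stats exactly as they were on disk.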
@@ -1,634 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Private writer component for LeRobotDataset. Handles sequential recording (episode buffer, ParquetWriter, image writer, video encoding)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
import contextlib
|
||||
import logging
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import PIL.Image
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.compute_stats import compute_episode_stats
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import (
|
||||
get_hf_features_from_features,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
)
|
||||
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.datasets.io_utils import (
|
||||
embed_images,
|
||||
get_file_size_in_mb,
|
||||
load_episodes,
|
||||
write_info,
|
||||
)
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from lerobot.datasets.video_utils import (
|
||||
StreamingVideoEncoder,
|
||||
concatenate_video_files,
|
||||
encode_video_frames,
|
||||
get_video_duration_in_s,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _encode_video_worker(
|
||||
video_key: str,
|
||||
episode_index: int,
|
||||
root: Path,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
encoder_threads: int | None = None,
|
||||
) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(
|
||||
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
|
||||
)
|
||||
shutil.rmtree(img_dir)
|
||||
return temp_path
|
||||
|
||||
|
||||
class DatasetWriter:
|
||||
"""Encapsulates write-side state and methods for LeRobotDataset.
|
||||
|
||||
Owns: episode_buffer, image_writer, _pq_writer (ParquetWriter), _latest_episode,
|
||||
_current_file_start_frame, _streaming_encoder, _episodes_since_last_encoding, _recorded_frames.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
meta: LeRobotDatasetMetadata,
|
||||
root: Path,
|
||||
vcodec: str,
|
||||
encoder_threads: int | None,
|
||||
batch_encoding_size: int,
|
||||
streaming_encoder: StreamingVideoEncoder | None = None,
|
||||
initial_frames: int = 0,
|
||||
):
|
||||
"""Initialize the writer with metadata, codec, and encoding config.
|
||||
|
||||
Args:
|
||||
meta: Dataset metadata instance (used for feature schema, chunk
|
||||
settings, and episode persistence).
|
||||
root: Local dataset root directory.
|
||||
vcodec: Video codec for encoding (e.g. ``'libsvtav1'``, ``'h264'``).
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos.
|
||||
streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder`
|
||||
for real-time encoding. ``None`` disables streaming mode.
|
||||
initial_frames: Starting frame count (non-zero when resuming).
|
||||
"""
|
||||
self._meta = meta
|
||||
self._root = root
|
||||
self._vcodec = vcodec
|
||||
self._encoder_threads = encoder_threads
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._streaming_encoder = streaming_encoder
|
||||
|
||||
# Writer state
|
||||
self.image_writer: AsyncImageWriter | None = None
|
||||
self.episode_buffer: dict = self._create_episode_buffer()
|
||||
self._pq_writer: pq.ParquetWriter | None = None
|
||||
self._latest_episode: dict | None = None
|
||||
self._current_file_start_frame: int | None = None
|
||||
self._episodes_since_last_encoding: int = 0
|
||||
self._recorded_frames: int = initial_frames
|
||||
self._finalized = False
|
||||
|
||||
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
current_ep_idx = self._meta.total_episodes if episode_index is None else episode_index
|
||||
ep_buffer = {}
|
||||
ep_buffer["size"] = 0
|
||||
ep_buffer["task"] = []
|
||||
for key in self._meta.features:
|
||||
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
|
||||
return ep_buffer
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self._root / fpath
|
||||
|
||||
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
|
||||
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
|
||||
|
||||
def _save_image(
|
||||
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1
|
||||
) -> None:
|
||||
if self.image_writer is None:
|
||||
if isinstance(image, torch.Tensor):
|
||||
image = image.cpu().numpy()
|
||||
write_image(image, fpath, compress_level=compress_level)
|
||||
else:
|
||||
self.image_writer.save_image(image=image, fpath=fpath, compress_level=compress_level)
|
||||
|
||||
def add_frame(self, frame: dict) -> None:
|
||||
"""
|
||||
Add a single frame to the current episode buffer.
|
||||
|
||||
Apart from images written to a temporary directory, nothing is written to disk
|
||||
until ``save_episode()`` is called.
|
||||
|
||||
The caller must provide all user-defined features plus ``"task"``, and must
|
||||
not provide ``"timestamp"`` or ``"frame_index"``; those are computed
|
||||
automatically.
|
||||
"""
|
||||
# Convert torch to numpy if needed
|
||||
for name in frame:
|
||||
if isinstance(frame[name], torch.Tensor):
|
||||
frame[name] = frame[name].numpy()
|
||||
|
||||
validate_frame(frame, self._meta.features)
|
||||
|
||||
if self.episode_buffer is None:
|
||||
self.episode_buffer = self._create_episode_buffer()
|
||||
|
||||
# Automatically add frame_index and timestamp to episode buffer
|
||||
frame_index = self.episode_buffer["size"]
|
||||
timestamp = frame_index / self._meta.fps
|
||||
self.episode_buffer["frame_index"].append(frame_index)
|
||||
self.episode_buffer["timestamp"].append(timestamp)
|
||||
self.episode_buffer["task"].append(frame.pop("task"))
|
||||
|
||||
# Start streaming encoder on first frame of episode
|
||||
if frame_index == 0 and self._streaming_encoder is not None:
|
||||
self._streaming_encoder.start_episode(
|
||||
video_keys=list(self._meta.video_keys),
|
||||
temp_dir=self._root,
|
||||
)
|
||||
|
||||
# Add frame features to episode_buffer
|
||||
for key in frame:
|
||||
if key not in self._meta.features:
|
||||
raise ValueError(
|
||||
f"An element of the frame is not in the features. '{key}' not in '{self._meta.features.keys()}'."
|
||||
)
|
||||
|
||||
if self._meta.features[key]["dtype"] == "video" and self._streaming_encoder is not None:
|
||||
self._streaming_encoder.feed_frame(key, frame[key])
|
||||
self.episode_buffer[key].append(None)
|
||||
elif self._meta.features[key]["dtype"] in ["image", "video"]:
|
||||
img_path = self._get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
|
||||
)
|
||||
if frame_index == 0:
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
compress_level = 1 if self._meta.features[key]["dtype"] == "video" else 6
|
||||
self._save_image(frame[key], img_path, compress_level)
|
||||
self.episode_buffer[key].append(str(img_path))
|
||||
else:
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
def save_episode(
|
||||
self,
|
||||
episode_data: dict | None = None,
|
||||
parallel_encoding: bool = True,
|
||||
) -> None:
|
||||
"""Save the current episode in self.episode_buffer to disk."""
|
||||
episode_buffer = episode_data if episode_data is not None else self.episode_buffer
|
||||
|
||||
validate_episode_buffer(episode_buffer, self._meta.total_episodes, self._meta.features)
|
||||
|
||||
# size and task are special cases that won't be added to hf_dataset
|
||||
episode_length = episode_buffer.pop("size")
|
||||
tasks = episode_buffer.pop("task")
|
||||
episode_tasks = list(set(tasks))
|
||||
episode_index = episode_buffer["episode_index"]
|
||||
|
||||
episode_buffer["index"] = np.arange(self._meta.total_frames, self._meta.total_frames + episode_length)
|
||||
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
|
||||
|
||||
# Update tasks and task indices with new tasks if any
|
||||
self._meta.save_episode_tasks(episode_tasks)
|
||||
|
||||
# Given tasks in natural language, find their corresponding task indices
|
||||
episode_buffer["task_index"] = np.array([self._meta.get_task_index(task) for task in tasks])
|
||||
|
||||
for key, ft in self._meta.features.items():
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
|
||||
# Wait for the image writer to finish, so that episode stats over images can be computed
|
||||
self._wait_image_writer()
|
||||
|
||||
has_video_keys = len(self._meta.video_keys) > 0
|
||||
use_streaming = self._streaming_encoder is not None and has_video_keys
|
||||
use_batched_encoding = self._batch_encoding_size > 1
|
||||
|
||||
if use_streaming:
|
||||
non_video_buffer = {
|
||||
k: v
|
||||
for k, v in episode_buffer.items()
|
||||
if self._meta.features.get(k, {}).get("dtype") not in ("video",)
|
||||
}
|
||||
non_video_features = {k: v for k, v in self._meta.features.items() if v["dtype"] != "video"}
|
||||
ep_stats = compute_episode_stats(non_video_buffer, non_video_features)
|
||||
else:
|
||||
ep_stats = compute_episode_stats(episode_buffer, self._meta.features)
|
||||
|
||||
ep_metadata = self._save_episode_data(episode_buffer)
|
||||
|
||||
if use_streaming:
|
||||
streaming_results = self._streaming_encoder.finish_episode()
|
||||
for video_key in self._meta.video_keys:
|
||||
temp_path, video_stats = streaming_results[video_key]
|
||||
if video_stats is not None:
|
||||
ep_stats[video_key] = {
|
||||
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0)
|
||||
for k, v in video_stats.items()
|
||||
}
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
|
||||
elif has_video_keys and not use_batched_encoding:
|
||||
num_cameras = len(self._meta.video_keys)
|
||||
if parallel_encoding and num_cameras > 1:
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor:
|
||||
future_to_key = {
|
||||
executor.submit(
|
||||
_encode_video_worker,
|
||||
video_key,
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._vcodec,
|
||||
self._encoder_threads,
|
||||
): video_key
|
||||
for video_key in self._meta.video_keys
|
||||
}
|
||||
|
||||
results = {}
|
||||
for future in concurrent.futures.as_completed(future_to_key):
|
||||
video_key = future_to_key[future]
|
||||
try:
|
||||
temp_path = future.result()
|
||||
results[video_key] = temp_path
|
||||
except Exception as exc:
|
||||
logger.error(f"Video encoding failed for {video_key}: {exc}")
|
||||
raise exc
|
||||
|
||||
for video_key in self._meta.video_keys:
|
||||
temp_path = results[video_key]
|
||||
ep_metadata.update(
|
||||
self._save_episode_video(video_key, episode_index, temp_path=temp_path)
|
||||
)
|
||||
else:
|
||||
for video_key in self._meta.video_keys:
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||
|
||||
# `meta.save_episode` needs to be executed after encoding the videos
|
||||
self._meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
||||
|
||||
if has_video_keys and use_batched_encoding:
|
||||
self._episodes_since_last_encoding += 1
|
||||
if self._episodes_since_last_encoding == self._batch_encoding_size:
|
||||
start_ep = self._meta.total_episodes - self._batch_encoding_size
|
||||
end_ep = self._meta.total_episodes
|
||||
self._batch_save_episode_video(start_ep, end_ep)
|
||||
self._episodes_since_last_encoding = 0
|
||||
|
||||
if episode_data is None:
|
||||
self.clear_episode_buffer(delete_images=len(self._meta.image_keys) > 0)
|
||||
|
||||
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
|
||||
"""Batch save videos for multiple episodes."""
|
||||
if end_episode is None:
|
||||
end_episode = self._meta.total_episodes
|
||||
|
||||
logger.info(
|
||||
f"Batch encoding {self._batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}"
|
||||
)
|
||||
|
||||
chunk_idx = self._meta.episodes[start_episode]["data/chunk_index"]
|
||||
file_idx = self._meta.episodes[start_episode]["data/file_index"]
|
||||
episode_df_path = self._root / DEFAULT_EPISODES_PATH.format(
|
||||
chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
episode_df = pd.read_parquet(episode_df_path)
|
||||
|
||||
for ep_idx in range(start_episode, end_episode):
|
||||
logger.info(f"Encoding videos for episode {ep_idx}")
|
||||
|
||||
if (
|
||||
self._meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
|
||||
or self._meta.episodes[ep_idx]["data/file_index"] != file_idx
|
||||
):
|
||||
episode_df.to_parquet(episode_df_path)
|
||||
self._meta.episodes = load_episodes(self._root)
|
||||
|
||||
chunk_idx = self._meta.episodes[ep_idx]["data/chunk_index"]
|
||||
file_idx = self._meta.episodes[ep_idx]["data/file_index"]
|
||||
episode_df_path = self._root / DEFAULT_EPISODES_PATH.format(
|
||||
chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
episode_df = pd.read_parquet(episode_df_path)
|
||||
|
||||
video_ep_metadata = {}
|
||||
for video_key in self._meta.video_keys:
|
||||
video_ep_metadata.update(self._save_episode_video(video_key, ep_idx))
|
||||
video_ep_metadata.pop("episode_index")
|
||||
video_ep_df = pd.DataFrame(video_ep_metadata, index=[ep_idx]).convert_dtypes(
|
||||
dtype_backend="pyarrow"
|
||||
)
|
||||
|
||||
episode_df = episode_df.combine_first(video_ep_df)
|
||||
episode_df.to_parquet(episode_df_path)
|
||||
self._meta.episodes = load_episodes(self._root)
|
||||
|
||||
def _save_episode_data(self, episode_buffer: dict) -> dict:
|
||||
"""Save episode data to a parquet file."""
|
||||
# Use metadata features as the authoritative schema
|
||||
hf_features = get_hf_features_from_features(self._meta.features)
|
||||
ep_dict = {key: episode_buffer[key] for key in hf_features}
|
||||
ep_dataset = datasets.Dataset.from_dict(ep_dict, features=hf_features, split="train")
|
||||
ep_dataset = embed_images(ep_dataset)
|
||||
ep_num_frames = len(ep_dataset)
|
||||
|
||||
if self._latest_episode is None:
|
||||
chunk_idx, file_idx = 0, 0
|
||||
global_frame_index = 0
|
||||
self._current_file_start_frame = 0
|
||||
if self._meta.episodes is not None and len(self._meta.episodes) > 0:
|
||||
latest_ep = self._meta.episodes[-1]
|
||||
global_frame_index = latest_ep["dataset_to_index"]
|
||||
chunk_idx = latest_ep["data/chunk_index"]
|
||||
file_idx = latest_ep["data/file_index"]
|
||||
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size)
|
||||
self._current_file_start_frame = global_frame_index
|
||||
else:
|
||||
latest_ep = self._latest_episode
|
||||
chunk_idx = latest_ep["data/chunk_index"]
|
||||
file_idx = latest_ep["data/file_index"]
|
||||
global_frame_index = latest_ep["index"][-1] + 1
|
||||
|
||||
latest_path = self._root / self._meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
latest_size_in_mb = get_file_size_in_mb(latest_path)
|
||||
|
||||
frames_in_current_file = global_frame_index - self._current_file_start_frame
|
||||
av_size_per_frame = (
|
||||
latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0
|
||||
)
|
||||
|
||||
if latest_size_in_mb + av_size_per_frame * ep_num_frames >= self._meta.data_files_size_in_mb:
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size)
|
||||
self.close_writer()
|
||||
self._current_file_start_frame = global_frame_index
|
||||
|
||||
ep_dict["data/chunk_index"] = chunk_idx
|
||||
ep_dict["data/file_index"] = file_idx
|
||||
|
||||
path = self._root / self._meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
table = ep_dataset.with_format("arrow")[:]
|
||||
if not self._pq_writer:
|
||||
self._pq_writer = pq.ParquetWriter(
|
||||
path, schema=table.schema, compression="snappy", use_dictionary=True
|
||||
)
|
||||
self._pq_writer.write_table(table)
|
||||
|
||||
metadata = {
|
||||
"data/chunk_index": chunk_idx,
|
||||
"data/file_index": file_idx,
|
||||
"dataset_from_index": global_frame_index,
|
||||
"dataset_to_index": global_frame_index + ep_num_frames,
|
||||
}
|
||||
|
||||
self._latest_episode = {**ep_dict, **metadata}
|
||||
self._recorded_frames += ep_num_frames
|
||||
|
||||
return metadata
|
||||
|
||||
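The file rollover decision in `_save_episode_data` can be read in isolation as a small pure function. The following is only a sketch under stated assumptions: `should_start_new_file` and its parameter names are hypothetical and do not exist in the repository; only the arithmetic follows the hunk above.

```python
def should_start_new_file(
    latest_size_in_mb: float,
    frames_in_current_file: int,
    ep_num_frames: int,
    max_file_size_in_mb: float,
) -> bool:
    """Hypothetical helper mirroring the size projection in `_save_episode_data`.

    The average size per frame of the current file is used to estimate how
    large the file would become once the new episode is appended.
    """
    av_size_per_frame = (
        latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0
    )
    projected_size_in_mb = latest_size_in_mb + av_size_per_frame * ep_num_frames
    # Roll over as soon as the projection reaches the configured limit.
    return projected_size_in_mb >= max_file_size_in_mb
```

With an empty current file (`frames_in_current_file == 0`) the per-frame estimate is zero, so the first episode written to a file never triggers a rollover by itself.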
def _save_episode_video(
|
||||
self,
|
||||
video_key: str,
|
||||
episode_index: int,
|
||||
temp_path: Path | None = None,
|
||||
) -> dict:
|
||||
if temp_path is None:
|
||||
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
|
||||
else:
|
||||
ep_path = temp_path
|
||||
|
||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||
|
||||
if (
|
||||
episode_index == 0
|
||||
or self._meta.latest_episode is None
|
||||
or f"videos/{video_key}/chunk_index" not in self._meta.latest_episode
|
||||
):
|
||||
chunk_idx, file_idx = 0, 0
|
||||
if self._meta.episodes is not None and len(self._meta.episodes) > 0:
|
||||
old_chunk_idx = self._meta.episodes[-1][f"videos/{video_key}/chunk_index"]
|
||||
old_file_idx = self._meta.episodes[-1][f"videos/{video_key}/file_index"]
|
||||
chunk_idx, file_idx = update_chunk_file_indices(
|
||||
old_chunk_idx, old_file_idx, self._meta.chunks_size
|
||||
)
|
||||
latest_duration_in_s = 0.0
|
||||
new_path = self._root / self._meta.video_path.format(
|
||||
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(ep_path), str(new_path))
|
||||
else:
|
||||
latest_ep = self._meta.latest_episode
|
||||
chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"][0]
|
||||
file_idx = latest_ep[f"videos/{video_key}/file_index"][0]
|
||||
|
||||
latest_path = self._root / self._meta.video_path.format(
|
||||
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
latest_size_in_mb = get_file_size_in_mb(latest_path)
|
||||
latest_duration_in_s = latest_ep[f"videos/{video_key}/to_timestamp"][0]
|
||||
|
||||
if latest_size_in_mb + ep_size_in_mb >= self._meta.video_files_size_in_mb:
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self._meta.chunks_size)
|
||||
new_path = self._root / self._meta.video_path.format(
|
||||
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.move(str(ep_path), str(new_path))
|
||||
latest_duration_in_s = 0.0
|
||||
else:
|
||||
concatenate_video_files(
|
||||
[latest_path, ep_path],
|
||||
latest_path,
|
||||
)
|
||||
|
||||
# Remove temporary directory
|
||||
shutil.rmtree(str(ep_path.parent))
|
||||
|
||||
# Update video info (only needed when the first episode is encoded)
|
||||
if episode_index == 0:
|
||||
self._meta.update_video_info(video_key)
|
||||
write_info(self._meta.info, self._meta.root)
|
||||
|
||||
metadata = {
|
||||
"episode_index": episode_index,
|
||||
f"videos/{video_key}/chunk_index": chunk_idx,
|
||||
f"videos/{video_key}/file_index": file_idx,
|
||||
f"videos/{video_key}/from_timestamp": latest_duration_in_s,
|
||||
f"videos/{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
|
||||
}
|
||||
return metadata
|
||||
|
||||
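The timestamp bookkeeping in `_save_episode_video` reduces to one rule: an episode appended to an existing video file starts where the previous episode ended, and an episode that opens a new file starts at zero. A minimal sketch, assuming a hypothetical helper `place_episode` (not part of the codebase) and ignoring the actual file moves and concatenation:

```python
def place_episode(
    latest_size_in_mb: float,
    latest_to_timestamp: float,
    ep_size_in_mb: float,
    ep_duration_in_s: float,
    max_video_size_in_mb: float,
) -> tuple[bool, float, float]:
    """Return (starts_new_file, from_timestamp, to_timestamp) for one episode.

    Hypothetical illustration of the size check and timestamp offsets only.
    """
    starts_new_file = latest_size_in_mb + ep_size_in_mb >= max_video_size_in_mb
    # A new file resets the running offset; otherwise continue after the last episode.
    from_timestamp = 0.0 if starts_new_file else latest_to_timestamp
    return starts_new_file, from_timestamp, from_timestamp + ep_duration_in_s
```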
def clear_episode_buffer(self, delete_images: bool = True) -> None:
|
||||
"""Discard the current episode buffer and optionally delete temp images.
|
||||
|
||||
Args:
|
||||
delete_images: If ``True``, remove temporary image directories
|
||||
written for the current episode.
|
||||
"""
|
||||
# Cancel streaming encoder if active
|
||||
if self._streaming_encoder is not None:
|
||||
self._streaming_encoder.cancel_episode()
|
||||
|
||||
if delete_images:
|
||||
if self.image_writer is not None:
|
||||
self._wait_image_writer()
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
# episode_index is `int` when freshly created, but becomes `np.ndarray` after
|
||||
# save_episode() mutates the buffer. Handle both types here.
|
||||
if isinstance(episode_index, np.ndarray):
|
||||
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
|
||||
for cam_key in self._meta.image_keys:
|
||||
img_dir = self._get_image_file_dir(episode_index, cam_key)
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(img_dir)
|
||||
|
||||
self.episode_buffer = self._create_episode_buffer()
|
||||
|
||||
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
|
||||
"""Start an :class:`AsyncImageWriter` for background image persistence.
|
||||
|
||||
Args:
|
||||
num_processes: Number of subprocesses. ``0`` means threads only.
|
||||
num_threads: Number of threads per process.
|
||||
"""
|
||||
if isinstance(self.image_writer, AsyncImageWriter):
|
||||
logger.warning(
|
||||
"You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
|
||||
)
|
||||
|
||||
self.image_writer = AsyncImageWriter(
|
||||
num_processes=num_processes,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
|
||||
def stop_image_writer(self) -> None:
|
||||
"""Stop the image writer (needed before pickling the dataset for DataLoader)."""
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.stop()
|
||||
self.image_writer = None
|
||||
|
||||
def _wait_image_writer(self) -> None:
|
||||
"""Wait for asynchronous image writer to finish."""
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.wait_until_done()
|
||||
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
|
||||
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
|
||||
return _encode_video_worker(
|
||||
video_key, episode_index, self._root, self._meta.fps, self._vcodec, self._encoder_threads
|
||||
)
|
||||
|
||||
def close_writer(self) -> None:
|
||||
"""Close and cleanup the parquet writer if it exists."""
|
||||
if self._pq_writer is not None:
|
||||
self._pq_writer.close()
|
||||
self._pq_writer = None
|
||||
|
||||
def flush_pending_videos(self) -> None:
|
||||
"""Flush any pending video encoding (streaming or batch).
|
||||
|
||||
For streaming encoding: closes the encoder.
|
||||
For batch encoding: encodes any remaining episodes that haven't been batch-encoded yet.
|
||||
"""
|
||||
if self._streaming_encoder is not None:
|
||||
self._streaming_encoder.close()
|
||||
elif self._episodes_since_last_encoding > 0:
|
||||
start_ep = self._meta.total_episodes - self._episodes_since_last_encoding
|
||||
end_ep = self._meta.total_episodes
|
||||
logger.info(
|
||||
f"Encoding remaining {self._episodes_since_last_encoding} episodes, "
|
||||
f"from episode {start_ep} to {end_ep - 1}"
|
||||
)
|
||||
self._batch_save_episode_video(start_ep, end_ep)
|
||||
|
||||
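Both `save_episode` and `flush_pending_videos` derive the episodes that still need encoding from two counters. The sketch below shows that range computation on its own; `pending_batch_range` is a hypothetical name introduced here for illustration.

```python
def pending_batch_range(total_episodes: int, episodes_since_last_encoding: int) -> range:
    """Hypothetical helper: episode indices saved but not yet video-encoded."""
    start_ep = total_episodes - episodes_since_last_encoding
    end_ep = total_episodes  # exclusive, matching range(start_episode, end_episode)
    return range(start_ep, end_ep)
```

For example, with 10 saved episodes of which 3 are pending, the range covers episodes 7 through 9, which is why the log message prints `end_ep - 1` as the last episode.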
def cancel_pending_videos(self) -> None:
|
||||
"""Cancel any in-progress streaming encoding without flushing."""
|
||||
if self._streaming_encoder is not None:
|
||||
self._streaming_encoder.cancel_episode()
|
||||
|
||||
def cleanup_interrupted_episode(self, episode_index: int) -> None:
|
||||
"""Remove temporary image directories for an interrupted episode."""
|
||||
for key in self._meta.video_keys:
|
||||
img_dir = self._get_image_file_path(
|
||||
episode_index=episode_index, image_key=key, frame_index=0
|
||||
).parent
|
||||
if img_dir.exists():
|
||||
logger.debug(
|
||||
f"Cleaning up interrupted episode images for episode {episode_index}, camera {key}"
|
||||
)
|
||||
shutil.rmtree(img_dir)
|
||||
|
||||
def finalize(self) -> None:
|
||||
"""Flush all pending work and release all resources.
|
||||
|
||||
Idempotent — safe to call multiple times.
|
||||
"""
|
||||
if getattr(self, "_finalized", False):
|
||||
return
|
||||
# 1. Wait for async image writes to complete, then stop
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.wait_until_done()
|
||||
self.image_writer.stop()
|
||||
self.image_writer = None
|
||||
# 2. Flush pending video encoding (streaming or batch)
|
||||
self.flush_pending_videos()
|
||||
# 3. Close own parquet writer
|
||||
self.close_writer()
|
||||
# 4. Finalize metadata (idempotent)
|
||||
self._meta.finalize()
|
||||
self._finalized = True
|
||||
|
||||
def __del__(self):
|
||||
"""Safety net: release resources on garbage collection."""
|
||||
# During interpreter shutdown, referenced objects may already be collected.
|
||||
with contextlib.suppress(Exception):
|
||||
self.finalize()
|
||||
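The `finalize` / `__del__` pair above follows a common idempotent-cleanup pattern. A self-contained sketch, assuming a hypothetical `Recorder` class that stands in for the dataset writer:

```python
import contextlib


class Recorder:
    """Hypothetical resource holder used only to illustrate the pattern."""

    def __init__(self) -> None:
        self.release_calls = 0

    def finalize(self) -> None:
        # getattr() with a default keeps this safe even if the flag was never
        # set, e.g. when __init__ failed before the object was fully built.
        if getattr(self, "_finalized", False):
            return
        self.release_calls += 1  # stands in for flushing and closing writers
        self._finalized = True

    def __del__(self) -> None:
        # During interpreter shutdown, dependencies may already be gone,
        # so any error raised here is deliberately swallowed.
        with contextlib.suppress(Exception):
            self.finalize()


recorder = Recorder()
recorder.finalize()
recorder.finalize()  # second call is a no-op
```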
@@ -25,7 +25,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.transforms import ImageTransforms
|
||||
from lerobot.utils.constants import ACTION, OBS_PREFIX, OBS_STATE, REWARD
|
||||
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
|
||||
|
||||
IMAGENET_STATS = {
|
||||
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
|
||||
@@ -52,15 +52,12 @@ def resolve_delta_timestamps(
|
||||
returns `None` if the resulting dict is empty.
|
||||
"""
|
||||
delta_timestamps = {}
|
||||
state_delta = getattr(cfg, "state_delta_indices", None)
|
||||
for key in ds_meta.features:
|
||||
if key == REWARD and cfg.reward_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
|
||||
if key == ACTION and cfg.action_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
|
||||
if key == OBS_STATE and state_delta is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in state_delta]
|
||||
elif key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
|
||||
if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
|
||||
|
||||
if len(delta_timestamps) == 0:
|
||||
|
||||
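Each branch of `resolve_delta_timestamps` applies the same conversion: a frame offset divided by the dataset frame rate gives a time offset in seconds. A minimal sketch, with the hypothetical helper name `delta_indices_to_timestamps`:

```python
def delta_indices_to_timestamps(delta_indices: list[int], fps: int) -> list[float]:
    """Hypothetical helper: convert frame offsets to second offsets."""
    return [i / fps for i in delta_indices]
```

At 30 fps, the offsets `[0, 15, 30]` correspond to 0 s, 0.5 s and 1 s relative to the current frame.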
@@ -365,10 +365,6 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic
|
||||
|
||||
|
||||
def validate_frame(frame: dict, features: dict) -> None:
|
||||
# DEFAULT_FEATURES (timestamp, frame_index, episode_index, index, task_index) are
|
||||
# auto-populated by the recording pipeline (add_frame / save_episode) and must not
|
||||
# be supplied by the caller. Excluding them here means any frame dict that contains
|
||||
# these keys will be rejected as extra features.
|
||||
expected_features = set(features) - set(DEFAULT_FEATURES)
|
||||
actual_features = set(frame)
|
||||
|
||||
|
||||
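The set arithmetic at the start of `validate_frame` can be sketched as follows. This assumes the default feature names listed in the comment above; `DEFAULT_FEATURE_KEYS` and `diff_frame_keys` are hypothetical names, and the real function may perform further checks not shown in this hunk.

```python
# Assumption: names taken from the comment in the hunk, not imported from lerobot.
DEFAULT_FEATURE_KEYS = {"timestamp", "frame_index", "episode_index", "index", "task_index"}


def diff_frame_keys(frame: dict, features: dict) -> tuple[set[str], set[str]]:
    """Return (missing, extra) keys of a frame relative to the caller-supplied schema."""
    expected = set(features) - DEFAULT_FEATURE_KEYS
    actual = set(frame)
    return expected - actual, actual - expected


features = {"action": {}, "observation.state": {}, "timestamp": {}, "index": {}}
frame = {"action": [0.0], "timestamp": 0.0}
missing, extra = diff_frame_keys(frame, features)
```

Because the default keys are removed from the expected set, a caller-provided `timestamp` shows up as an extra key rather than being silently accepted.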
@@ -32,10 +32,10 @@ def safe_stop_image_writer(func):
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
dataset = kwargs.get("dataset")
|
||||
writer = getattr(dataset, "writer", None) if dataset else None
|
||||
if writer is not None and writer.image_writer is not None:
|
||||
image_writer = getattr(dataset, "image_writer", None) if dataset else None
|
||||
if image_writer is not None:
|
||||
logger.warning("Waiting for image writer to terminate...")
|
||||
writer.image_writer.stop()
|
||||
image_writer.stop()
|
||||
raise e
|
||||
|
||||
return wrapper
|
||||
|
||||
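The decorator in this hunk stops a background writer before re-raising the original error. The sketch below reproduces the idea with stand-in classes; `stop_writer_on_error`, `FakeWriter` and `FakeDataset` are hypothetical and exist only for this illustration.

```python
import functools


def stop_writer_on_error(func):
    """Stop `dataset.image_writer` (if any) when the wrapped call raises."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            dataset = kwargs.get("dataset")
            image_writer = getattr(dataset, "image_writer", None)
            if image_writer is not None:
                image_writer.stop()
            raise  # bare raise keeps the original traceback

    return wrapper


class FakeWriter:
    def __init__(self) -> None:
        self.stopped = False

    def stop(self) -> None:
        self.stopped = True


class FakeDataset:
    def __init__(self) -> None:
        self.image_writer = FakeWriter()


@stop_writer_on_error
def record(dataset=None):
    raise RuntimeError("camera disconnected")


ds = FakeDataset()
try:
    record(dataset=ds)
except RuntimeError:
    pass
```

Note that the lookup relies on `dataset` being passed as a keyword argument; a positional dataset would not be found by `kwargs.get`.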
File diff suppressed because it is too large
@@ -22,7 +22,6 @@ import torch
|
||||
import torch.utils
|
||||
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.video_utils import VideoFrame
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
@@ -126,13 +125,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
def features(self) -> datasets.Features:
|
||||
features = {}
|
||||
for dataset in self._datasets:
|
||||
features.update(
|
||||
{
|
||||
k: v
|
||||
for k, v in get_hf_features_from_features(dataset.features).items()
|
||||
if k not in self.disabled_features
|
||||
}
|
||||
)
|
||||
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
|
||||
return features
|
||||
|
||||
@property
|
||||
|
||||
@@ -255,9 +255,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
Args:
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset.
|
||||
root (Path | None, optional): Local directory to use for local datasets. When omitted, Hub
|
||||
metadata is resolved through a revision-safe snapshot cache under
|
||||
``$HF_LEROBOT_HOME/hub``.
|
||||
root (Path | None, optional): Local directory to use for downloading/writing files.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list.
|
||||
image_transforms (Callable | None, optional): Transform to apply to image data.
|
||||
@@ -273,8 +271,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self._requested_root = Path(root) if root else None
|
||||
self.root = self._requested_root if self._requested_root is not None else HF_LEROBOT_HOME / repo_id
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||
self.streaming_from_local = root is not None
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
@@ -291,15 +288,12 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
||||
self.video_decoder_cache = None
|
||||
|
||||
if self._requested_root is not None:
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# Load metadata
|
||||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id, self._requested_root, self.revision, force_cache_sync=force_cache_sync
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
self.root = self.meta.root
|
||||
self.revision = self.meta.revision
|
||||
# Check version
|
||||
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
@@ -102,18 +101,6 @@ DEFAULT_FEATURES = {
|
||||
}
|
||||
|
||||
|
||||
def has_legacy_hub_download_metadata(root: Path) -> bool:
|
||||
"""Return ``True`` when *root* looks like a legacy Hub ``local_dir`` mirror.
|
||||
|
||||
``snapshot_download(local_dir=...)`` stores lightweight metadata under
|
||||
``<local_dir>/.cache/huggingface/download/``. The presence of this
|
||||
directory is a reliable indicator that the dataset was downloaded with
|
||||
the old non-revision-safe ``local_dir`` mode and should be re-fetched
|
||||
through the snapshot cache instead.
|
||||
"""
|
||||
return (root / ".cache" / "huggingface" / "download").exists()
|
||||
|
||||
|
||||
def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
|
||||
if file_idx == chunks_size - 1:
|
||||
file_idx = 0
|
||||
|
||||
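The hunk above cuts off after `file_idx = 0`, so the rest of `update_chunk_file_indices` is not visible here. The sketch below shows one plausible completion of the rollover logic; the chunk increment and the `else` branch are assumptions, and `next_chunk_file_indices` is a hypothetical name.

```python
def next_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]:
    """Advance to the next file, moving to a new chunk when the current one is full.

    Assumption: a chunk holds `chunks_size` files, indexed 0..chunks_size - 1.
    """
    if file_idx == chunks_size - 1:
        # Last file of the chunk: wrap the file index and (assumed) open the next chunk.
        return chunk_idx + 1, 0
    return chunk_idx, file_idx + 1
```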
@@ -741,7 +741,6 @@ class StreamingVideoEncoder:
|
||||
self._video_paths: dict[str, Path] = {}
|
||||
self._dropped_frames: dict[str, int] = {}
|
||||
self._episode_active = False
|
||||
self._closed = False
|
||||
|
||||
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
|
||||
"""Start encoder threads for a new episode.
|
||||
@@ -896,11 +895,8 @@ class StreamingVideoEncoder:
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the encoder, canceling any in-progress episode."""
|
||||
if self._closed:
|
||||
return
|
||||
if self._episode_active:
|
||||
self.cancel_episode()
|
||||
self._closed = True
|
||||
|
||||
def _cleanup(self) -> None:
|
||||
"""Clean up queues and thread tracking dicts."""
|
||||
@@ -1067,19 +1063,43 @@ class VideoEncodingManager:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
writer = self.dataset.writer
|
||||
if writer is not None:
|
||||
if exc_type is not None and writer._streaming_encoder is not None:
|
||||
writer.cancel_pending_videos()
|
||||
streaming_encoder = getattr(self.dataset, "_streaming_encoder", None)
|
||||
|
||||
# finalize() handles flush_pending_videos + parquet + metadata
|
||||
self.dataset.finalize()
|
||||
if streaming_encoder is not None:
|
||||
# Handle streaming encoder cleanup
|
||||
if exc_type is not None:
|
||||
streaming_encoder.cancel_episode()
|
||||
streaming_encoder.close()
|
||||
elif self.dataset.episodes_since_last_encoding > 0:
|
||||
# Handle any remaining episodes that haven't been batch encoded
|
||||
if exc_type is not None:
|
||||
logger.info("Exception occurred. Encoding remaining episodes before exit...")
|
||||
else:
|
||||
logger.info("Recording stopped. Encoding remaining episodes...")
|
||||
|
||||
# Clean up episode images if recording was interrupted (only for non-streaming mode)
|
||||
if exc_type is not None and writer._streaming_encoder is None:
|
||||
writer.cleanup_interrupted_episode(self.dataset.num_episodes)
|
||||
else:
|
||||
self.dataset.finalize()
|
||||
start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding
|
||||
end_ep = self.dataset.num_episodes
|
||||
logger.info(
|
||||
f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
|
||||
f"from episode {start_ep} to {end_ep - 1}"
|
||||
)
|
||||
self.dataset._batch_save_episode_video(start_ep, end_ep)
|
||||
|
||||
# Finalize the dataset to properly close all writers
|
||||
self.dataset.finalize()
|
||||
|
||||
# Clean up episode images if recording was interrupted (only for non-streaming mode)
|
||||
if exc_type is not None and streaming_encoder is None:
|
||||
interrupted_episode_index = self.dataset.num_episodes
|
||||
for key in self.dataset.meta.video_keys:
|
||||
img_dir = self.dataset._get_image_file_path(
|
||||
episode_index=interrupted_episode_index, image_key=key, frame_index=0
|
||||
).parent
|
||||
if img_dir.exists():
|
||||
logger.debug(
|
||||
f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
|
||||
)
|
||||
shutil.rmtree(img_dir)
|
||||
|
||||
# Clean up any remaining images directory if it's empty
|
||||
img_dir = self.dataset.root / "images"
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
from .act.configuration_act import ACTConfig as ACTConfig
|
||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||
from .groot.configuration_groot import GrootConfig as GrootConfig
|
||||
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as MultiTaskDiTConfig
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
@@ -29,7 +28,6 @@ from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
|
||||
__all__ = [
|
||||
"ACTConfig",
|
||||
"DiffusionConfig",
|
||||
"MultiTaskDiTConfig",
|
||||
"PI0Config",
|
||||
"PI05Config",
|
||||
"PI0FastConfig",
|
||||
|
||||
@@ -31,7 +31,6 @@ from lerobot.envs.utils import env_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
@@ -59,29 +58,6 @@ from lerobot.utils.constants import (
|
||||
)
|
||||
|
||||
|
||||
def _reconnect_relative_absolute_steps(
|
||||
preprocessor: PolicyProcessorPipeline, postprocessor: PolicyProcessorPipeline
|
||||
) -> None:
|
||||
"""Wire AbsoluteActionsProcessorStep.relative_step to the RelativeActionsProcessorStep after deserialization.
|
||||
|
||||
After a policy is loaded from disk, the preprocessor and postprocessor are reconstructed
|
||||
independently from their configs. AbsoluteActionsProcessorStep needs a live reference to
|
||||
the RelativeActionsProcessorStep so it can read the cached state at inference time.
|
||||
That reference is not serializable, so we re-establish it here after loading.
|
||||
"""
|
||||
from lerobot.processor.relative_action_processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
RelativeActionsProcessorStep,
|
||||
)
|
||||
|
||||
relative_step = next((s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep)), None)
|
||||
if relative_step is None:
|
||||
return
|
||||
for step in postprocessor.steps:
|
||||
if isinstance(step, AbsoluteActionsProcessorStep) and step.relative_step is None:
|
||||
step.relative_step = relative_step
|
||||
|
||||
|
||||
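The wiring performed by `_reconnect_relative_absolute_steps` can be illustrated with plain stand-in classes. This is a sketch only: `RelativeStep`, `AbsoluteStep` and `reconnect` are hypothetical simplifications of the processor steps named in the docstring.

```python
class RelativeStep:
    """Stand-in for the preprocessor step that caches state."""


class AbsoluteStep:
    """Stand-in for the postprocessor step that needs a live reference."""

    def __init__(self) -> None:
        self.relative_step = None


def reconnect(pre_steps: list, post_steps: list) -> None:
    # Find the first relative step, or None if the pipeline has none.
    relative = next((s for s in pre_steps if isinstance(s, RelativeStep)), None)
    if relative is None:
        return
    for step in post_steps:
        # Only fill in references that are still unset.
        if isinstance(step, AbsoluteStep) and step.relative_step is None:
            step.relative_step = relative


rel, wired, unwired = RelativeStep(), AbsoluteStep(), AbsoluteStep()
reconnect([object(), rel], [wired])
reconnect([], [unwired])
```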
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
"""
|
||||
Retrieves a policy class by its registered name.
|
||||
@@ -91,7 +67,8 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"multi_task_dit", "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
|
||||
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
|
||||
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
|
||||
@@ -110,10 +87,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
|
||||
return ACTPolicy
|
||||
elif name == "multi_task_dit":
|
||||
from lerobot.policies.multi_task_dit.modeling_multi_task_dit import MultiTaskDiTPolicy
|
||||
|
||||
return MultiTaskDiTPolicy
|
||||
elif name == "vqbet":
|
||||
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||
|
||||
@@ -174,8 +147,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"multi_task_dit", "diffusion", "act", "vqbet", "pi0", "pi05", "sac",
|
||||
"smolvla", "reward_classifier", "wall_x".
|
||||
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
|
||||
"reward_classifier", "wall_x".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -190,8 +163,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return DiffusionConfig(**kwargs)
|
||||
elif policy_type == "act":
|
||||
return ACTConfig(**kwargs)
|
||||
elif policy_type == "multi_task_dit":
|
||||
return MultiTaskDiTConfig(**kwargs)
|
||||
elif policy_type == "vqbet":
|
||||
return VQBeTConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
@@ -292,26 +263,26 @@ def make_pre_post_processors(
|
||||
kwargs["preprocessor_overrides"] = preprocessor_overrides
|
||||
kwargs["postprocessor_overrides"] = postprocessor_overrides
|
||||
|
||||
preprocessor = PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=kwargs.get(
|
||||
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
|
||||
return (
|
||||
PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=kwargs.get(
|
||||
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
|
||||
),
|
||||
overrides=kwargs.get("preprocessor_overrides", {}),
|
||||
to_transition=batch_to_transition,
|
||||
to_output=transition_to_batch,
|
||||
),
|
||||
overrides=kwargs.get("preprocessor_overrides", {}),
|
||||
to_transition=batch_to_transition,
|
||||
to_output=transition_to_batch,
|
||||
)
|
||||
postprocessor = PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=kwargs.get(
|
||||
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=kwargs.get(
|
||||
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||
),
|
||||
overrides=kwargs.get("postprocessor_overrides", {}),
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
overrides=kwargs.get("postprocessor_overrides", {}),
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
)
|
||||
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
|
||||
return preprocessor, postprocessor
|
||||
|
||||
# Create a new processor based on policy type
|
||||
if isinstance(policy_cfg, TDMPCConfig):
|
||||
@@ -338,16 +309,6 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, MultiTaskDiTConfig):
|
||||
from lerobot.policies.multi_task_dit.processor_multi_task_dit import (
|
||||
make_multi_task_dit_pre_post_processors,
|
||||
)
|
||||
|
||||
processors = make_multi_task_dit_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, VQBeTConfig):
|
||||
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors
|
||||
|
||||
@@ -509,13 +470,6 @@ def make_policy(
|
||||
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
if not cfg.input_features:
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
|
||||
# Store action feature names for relative_exclude_joints support
|
||||
if ds_meta is not None and hasattr(cfg, "action_feature_names"):
|
||||
action_names = ds_meta.features.get(ACTION, {}).get("names")
|
||||
if action_names is not None:
|
||||
cfg.action_feature_names = list(action_names)
|
||||
|
||||
kwargs["config"] = cfg
|
||||
|
||||
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
# Multitask DiT Policy
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this work, please cite the following works:
|
||||
|
||||
```bibtex
|
||||
@misc{jones2025multitaskditpolicy,
|
||||
author = {Bryson Jones},
|
||||
title = {Dissecting and Open-Sourcing Multitask Diffusion Transformer Policy},
|
||||
year = {2025},
|
||||
url = {https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy},
|
||||
note = {Blog post}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{trilbmteam2025carefulexaminationlargebehaviormodels,
|
||||
author = {TRI LBM Team},
|
||||
title = {A Careful Examination of Large Behavior Models for Multitask Dexterous Manipulation},
|
||||
year = {2025},
|
||||
eprint = {arXiv:2507.05331},
|
||||
archivePrefix = {arXiv},
|
||||
primaryClass = {cs.RO},
|
||||
url = {https://arxiv.org/abs/2507.05331}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@misc{bostondynamics2025largebehaviormodelsatlas,
|
||||
author = {Boston Dynamics and TRI Research Team},
|
||||
title = {Large Behavior Models and Atlas Find New Footing},
|
||||
year = {2025},
|
||||
url = {https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/},
|
||||
note = {Blog post}
|
||||
}
|
||||
```
|
||||
@@ -1,21 +0,0 @@
#!/usr/bin/env python

# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .configuration_multi_task_dit import MultiTaskDiTConfig
from .modeling_multi_task_dit import MultiTaskDiTPolicy
from .processor_multi_task_dit import make_multi_task_dit_pre_post_processors

__all__ = ["MultiTaskDiTConfig", "MultiTaskDiTPolicy", "make_multi_task_dit_pre_post_processors"]
@@ -1,256 +0,0 @@
#!/usr/bin/env python

# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from dataclasses import dataclass, field

from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamConfig
from lerobot.optim.schedulers import DiffuserSchedulerConfig


@PreTrainedConfig.register_subclass("multi_task_dit")
@dataclass
class MultiTaskDiTConfig(PreTrainedConfig):
    """Configuration for the Multi-Task Diffusion Transformer (DiT) policy.

    A transformer-based policy that supports both diffusion and flow matching objectives
    for multi-task robot learning with text and vision conditioning.
    """

    n_obs_steps: int = 2  # Number of observation steps for temporal context
    horizon: int = 32  # Number of action steps to predict
    n_action_steps: int = 24  # Actions executed per policy call (~0.8s at 30Hz)

    # Objective Selection
    objective: str = "diffusion"  # "diffusion" or "flow_matching"

    # --- Diffusion-specific (used when objective="diffusion") ---
    noise_scheduler_type: str = "DDPM"  # "DDPM" or "DDIM"
    num_train_timesteps: int = 100  # Number of diffusion timesteps
    beta_schedule: str = "squaredcos_cap_v2"  # Noise schedule type
    beta_start: float = 0.0001  # Starting noise level
    beta_end: float = 0.02  # Ending noise level
    prediction_type: str = "epsilon"  # "epsilon" (predict noise) or "sample" (predict clean)
    clip_sample: bool = True  # Clip samples during denoising
    clip_sample_range: float = 1.0  # Clipping range [-x, x]
    num_inference_steps: int | None = None  # Denoising steps at inference (defaults to num_train_timesteps)

    # --- Flow Matching-specific (used when objective="flow_matching") ---
    sigma_min: float = 0.0  # Minimum noise in flow interpolation path
    num_integration_steps: int = 100  # ODE integration steps at inference
    integration_method: str = "euler"  # ODE solver: "euler" or "rk4"
    timestep_sampling_strategy: str = "beta"  # "uniform" or "beta"

    timestep_sampling_s: float = 0.999  # (beta only) Max timestep threshold
    timestep_sampling_alpha: float = 1.5  # (beta only) Beta distribution alpha
    timestep_sampling_beta: float = 1.0  # (beta only) Beta distribution beta

    # Transformer Architecture
    hidden_dim: int = 512  # Transformer hidden dimension
    num_layers: int = 6  # Number of transformer layers
    num_heads: int = 8  # Number of attention heads
    dropout: float = 0.1  # Dropout rate
    use_positional_encoding: bool = False  # Use absolute positional encoding
    timestep_embed_dim: int = 256  # Timestep embedding dimension
    use_rope: bool = True  # Use Rotary Position Embedding
    rope_base: float = 10000.0  # RoPE base frequency

    # Vision Encoder (CLIP)
    vision_encoder_name: str = "openai/clip-vit-base-patch16"  # HuggingFace CLIP model
    use_separate_rgb_encoder_per_camera: bool = False  # Separate encoder per camera view
    vision_encoder_lr_multiplier: float = 0.1  # LR multiplier for vision encoder
    image_resize_shape: tuple[int, int] | None = None  # Resize images before crop
    image_crop_shape: tuple[int, int] | None = (224, 224)  # Crop shape (CLIP default)
    image_crop_is_random: bool = True  # Random crop during training, center at inference

    # Text Encoder (CLIP)
    text_encoder_name: str = "openai/clip-vit-base-patch16"  # HuggingFace CLIP model
    tokenizer_max_length: int = 77  # Max length for tokenized text (CLIP default is 77)
    tokenizer_padding: str = "max_length"  # Padding strategy: "max_length" or "longest"
    tokenizer_padding_side: str = "right"  # Padding side: "left" or "right"
    tokenizer_truncation: bool = True  # Whether to truncate sequences longer than max_length

    # Normalization
    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.MEAN_STD,
            "STATE": NormalizationMode.MIN_MAX,
            "ACTION": NormalizationMode.MIN_MAX,
        }
    )

    # Training/Optimizer
    optimizer_lr: float = 2e-5
    optimizer_betas: tuple = (0.95, 0.999)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 0.0
    scheduler_name: str = "cosine"
    scheduler_warmup_steps: int = 0
    do_mask_loss_for_padding: bool = False

    # Auto-calculated
    drop_n_last_frames: int | None = None

    def __post_init__(self):
        super().__post_init__()

        if self.drop_n_last_frames is None:
            self.drop_n_last_frames = self.horizon - self.n_action_steps - self.n_obs_steps + 1

        self._validate()

    def _validate(self):
        """Validate configuration parameters."""
        # Objective validation
        if self.objective not in ["diffusion", "flow_matching"]:
            raise ValueError(f"objective must be 'diffusion' or 'flow_matching', got '{self.objective}'")

        # Transformer validation
        if self.hidden_dim <= 0:
            raise ValueError("hidden_dim must be positive")
        if self.num_layers <= 0:
            raise ValueError("num_layers must be positive")
        if self.num_heads <= 0:
            raise ValueError("num_heads must be positive")
        if self.hidden_dim % self.num_heads != 0:
            raise ValueError("hidden_dim must be divisible by num_heads")
        if not (0.0 <= self.dropout <= 1.0):
            raise ValueError("dropout must be between 0.0 and 1.0")

        # Vision encoder validation
        if "clip" not in self.vision_encoder_name.lower():
            raise ValueError(
                f"vision_encoder_name must be a CLIP model (contain 'clip'), got '{self.vision_encoder_name}'"
            )
        if (
            self.image_resize_shape
            and self.image_crop_shape
            and (
                self.image_crop_shape[0] > self.image_resize_shape[0]
                or self.image_crop_shape[1] > self.image_resize_shape[1]
            )
        ):
            logging.warning(
                "image_crop_shape %s must be <= image_resize_shape %s; disabling cropping.",
                self.image_crop_shape,
                self.image_resize_shape,
            )
            self.image_crop_shape = None

        # Text encoder validation
        if "clip" not in self.text_encoder_name.lower():
            raise ValueError(
                f"text_encoder_name must be a CLIP model (contain 'clip'), got '{self.text_encoder_name}'"
            )

        # Objective-specific validation
        if self.objective == "diffusion":
            if self.noise_scheduler_type not in ["DDPM", "DDIM"]:
                raise ValueError(
                    f"noise_scheduler_type must be 'DDPM' or 'DDIM', got {self.noise_scheduler_type}"
                )
            if self.prediction_type not in ["epsilon", "sample"]:
                raise ValueError(f"prediction_type must be 'epsilon' or 'sample', got {self.prediction_type}")
            if self.num_train_timesteps <= 0:
                raise ValueError(f"num_train_timesteps must be positive, got {self.num_train_timesteps}")
            if not (0.0 <= self.beta_start <= self.beta_end <= 1.0):
                raise ValueError(f"Invalid beta values: {self.beta_start}, {self.beta_end}")

        elif self.objective == "flow_matching":
            if not (0.0 <= self.sigma_min <= 1.0):
                raise ValueError(f"sigma_min must be in [0, 1], got {self.sigma_min}")
            if self.num_integration_steps <= 0:
                raise ValueError(f"num_integration_steps must be positive, got {self.num_integration_steps}")
            if self.integration_method not in ["euler", "rk4"]:
                raise ValueError(
                    f"integration_method must be 'euler' or 'rk4', got {self.integration_method}"
                )
            if self.timestep_sampling_strategy not in ["uniform", "beta"]:
                raise ValueError("timestep_sampling_strategy must be 'uniform' or 'beta'")
            if self.timestep_sampling_strategy == "beta":
                if not (0.0 < self.timestep_sampling_s <= 1.0):
                    raise ValueError(f"timestep_sampling_s must be in (0, 1], got {self.timestep_sampling_s}")
                if self.timestep_sampling_alpha <= 0:
                    raise ValueError("timestep_sampling_alpha must be positive")
                if self.timestep_sampling_beta <= 0:
                    raise ValueError("timestep_sampling_beta must be positive")

    def get_optimizer_preset(self) -> AdamConfig:
        return AdamConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
        )

    def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
        return DiffuserSchedulerConfig(
            name=self.scheduler_name,
            num_warmup_steps=self.scheduler_warmup_steps,
        )

    def validate_features(self) -> None:
        """Validate that required input features are present and properly configured."""
        # If the configured crop doesn't fit, disable cropping instead of erroring.
        # Note: if image_resize_shape is set, cropping is applied *after* resizing.
        if self.image_crop_shape is not None:
            for key, image_ft in self.image_features.items():
                # image_ft.shape is (C, H, W)
                effective_h, effective_w = (
                    self.image_resize_shape
                    if self.image_resize_shape is not None
                    else (image_ft.shape[1], image_ft.shape[2])
                )
                if self.image_crop_shape[0] > effective_h or self.image_crop_shape[1] > effective_w:
                    logging.warning(
                        "image_crop_shape %s doesn't fit within effective image shape (%s, %s) for '%s'; disabling cropping.",
                        self.image_crop_shape,
                        effective_h,
                        effective_w,
                        key,
                    )
                    self.image_crop_shape = None
                    break

        if len(self.image_features) > 0:
            first_key, first_ft = next(iter(self.image_features.items()))
            for key, image_ft in self.image_features.items():
                if image_ft.shape != first_ft.shape:
                    raise ValueError(
                        f"Image '{key}' shape {image_ft.shape} != '{first_key}' shape {first_ft.shape}"
                    )

    @property
    def is_diffusion(self) -> bool:
        return self.objective == "diffusion"

    @property
    def is_flow_matching(self) -> bool:
        return self.objective == "flow_matching"

    @property
    def observation_delta_indices(self) -> list:
        return list(range(1 - self.n_obs_steps, 1))

    @property
    def action_delta_indices(self) -> list:
        return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))

    @property
    def reward_delta_indices(self) -> None:
        return None
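The index properties and the `drop_n_last_frames` fallback in this config are plain arithmetic on `n_obs_steps`, `horizon`, and `n_action_steps`. The sketch below restates those three expressions as free functions (the function names are illustrative and not part of the package) and evaluates them at the default values.

```python
def observation_delta_indices(n_obs_steps: int) -> list[int]:
    # Same expression as MultiTaskDiTConfig.observation_delta_indices.
    return list(range(1 - n_obs_steps, 1))


def action_delta_indices(n_obs_steps: int, horizon: int) -> list[int]:
    # Same expression as MultiTaskDiTConfig.action_delta_indices.
    return list(range(1 - n_obs_steps, 1 - n_obs_steps + horizon))


def default_drop_n_last_frames(horizon: int, n_action_steps: int, n_obs_steps: int) -> int:
    # Same expression as the fallback in __post_init__.
    return horizon - n_action_steps - n_obs_steps + 1


# Defaults from the config: n_obs_steps=2, horizon=32, n_action_steps=24.
obs_idx = observation_delta_indices(n_obs_steps=2)
act_idx = action_delta_indices(n_obs_steps=2, horizon=32)
dropped = default_drop_n_last_frames(horizon=32, n_action_steps=24, n_obs_steps=2)

print(obs_idx)                  # [-1, 0]
print(act_idx[0], act_idx[-1])  # -1 30
print(dropped)                  # 7
```

With the defaults, the action window starts one frame before the current observation and extends 30 frames past it, and 7 trailing frames per episode are dropped.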
@@ -1,803 +0,0 @@
#!/usr/bin/env python

# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Multi-Task Diffusion Transformer (DiT) Policy

Transformer-based diffusion policy for multi-task robot learning with text and vision conditioning.
Supports both diffusion and flow matching objectives for action generation.

References:
- https://arxiv.org/abs/2507.05331
- https://bostondynamics.com/blog/large-behavior-models-atlas-find-new-footing/
- https://brysonkjones.substack.com/p/dissecting-and-open-sourcing-multitask-diffusion-transformer-policy
"""

import math
from collections import deque
from typing import TYPE_CHECKING

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor

from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.utils.import_utils import _transformers_available

# Conditional import for type checking and lazy loading
if TYPE_CHECKING or _transformers_available:
    from transformers import CLIPTextModel, CLIPVisionModel
else:
    CLIPTextModel = None
    CLIPVisionModel = None
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import (
    ACTION,
    OBS_IMAGES,
    OBS_LANGUAGE_ATTENTION_MASK,
    OBS_LANGUAGE_TOKENS,
    OBS_STATE,
)

# -- Policy --


class MultiTaskDiTPolicy(PreTrainedPolicy):
    config_class = MultiTaskDiTConfig
    name = "multi_task_dit"

    def __init__(self, config: MultiTaskDiTConfig, **kwargs):
        super().__init__(config)
        config.validate_features()
        self.config = config

        self._queues = None

        self.observation_encoder = ObservationEncoder(config)
        conditioning_dim = self.observation_encoder.conditioning_dim
        self.noise_predictor = DiffusionTransformer(config, conditioning_dim=conditioning_dim)

        action_dim = config.action_feature.shape[0]
        horizon = config.horizon

        if config.is_diffusion:
            self.objective = DiffusionObjective(
                config,
                action_dim=action_dim,
                horizon=horizon,
                do_mask_loss_for_padding=config.do_mask_loss_for_padding,
            )
        elif config.is_flow_matching:
            self.objective = FlowMatchingObjective(
                config,
                action_dim=action_dim,
                horizon=horizon,
                do_mask_loss_for_padding=config.do_mask_loss_for_padding,
            )
        else:
            raise ValueError(f"Unsupported objective: {config.objective}")

        self.reset()

    def get_optim_params(self) -> list:
        """Returns parameter groups with different learning rates for vision vs non-vision parameters"""
        non_vision_params = []
        vision_encoder_params = []

        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue

            if "observation_encoder.vision_encoder" in name:
                vision_encoder_params.append(param)
            else:
                non_vision_params.append(param)

        return [
            {"params": non_vision_params},
            {
                "params": vision_encoder_params,
                "lr": self.config.optimizer_lr * self.config.vision_encoder_lr_multiplier,
            },
        ]

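The grouping rule in `get_optim_params` can be tried without a model. The sketch below is a hypothetical stand-in: it takes `(name, requires_grad)` pairs instead of real parameters, and the parameter names are invented for illustration. It applies the same substring test, which also matches the plural `vision_encoders` module used when `use_separate_rgb_encoder_per_camera` is enabled.

```python
import math

VISION_KEY = "observation_encoder.vision_encoder"


def split_param_groups(named_params, base_lr, vision_lr_multiplier):
    """Group parameter names the way get_optim_params groups parameters.

    named_params: iterable of (name, requires_grad) pairs standing in for
    the output of nn.Module.named_parameters() (assumption for this sketch).
    """
    vision, other = [], []
    for name, requires_grad in named_params:
        if not requires_grad:
            continue  # frozen parameters are left out of both groups
        (vision if VISION_KEY in name else other).append(name)
    return [
        {"params": other},
        {"params": vision, "lr": base_lr * vision_lr_multiplier},
    ]


# Hypothetical parameter names, chosen only to exercise each branch.
params = [
    ("observation_encoder.vision_encoder.model.embeddings.weight", True),
    ("observation_encoder.vision_encoders.0.model.embeddings.weight", True),
    ("observation_encoder.text_encoder.text_encoder.embeddings.weight", False),
    ("observation_encoder.text_encoder.projection.weight", True),
    ("noise_predictor.input_proj.weight", True),
]
groups = split_param_groups(params, base_lr=2e-5, vision_lr_multiplier=0.1)

print(len(groups[0]["params"]), len(groups[1]["params"]))  # 2 2
print(math.isclose(groups[1]["lr"], 2e-6))                 # True
```

The first group carries no `lr` key, so it would fall back to the optimizer's default learning rate.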
    def _generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
        batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
        assert n_obs_steps == self.config.n_obs_steps

        conditioning_vec = self.observation_encoder.encode(batch)
        actions = self.objective.conditional_sample(self.noise_predictor, batch_size, conditioning_vec)

        start = n_obs_steps - 1
        end = start + self.config.n_action_steps
        actions = actions[:, start:end]
        return actions

    def reset(self):
        """Clear observation and action queues. Should be called on `env.reset()`"""
        self._queues = {
            OBS_STATE: deque(maxlen=self.config.n_obs_steps),
            ACTION: deque(maxlen=self.config.n_action_steps),
        }

        if self.config.image_features:
            self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)

    @torch.no_grad()
    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
        """Predict a chunk of actions given environment observations"""
        self.eval()

        for k in batch:
            if k in self._queues:
                batch[k] = torch.stack(list(self._queues[k]), dim=1)

        actions = self._generate_actions(batch)
        return actions

    def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Prepare batch by stacking image features if needed."""
        if self.config.image_features:
            batch = dict(batch)  # shallow copy to avoid modifying original
            batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)

        return batch

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select a single action given environment observations"""
        if ACTION in batch:
            batch = dict(batch)  # shallow copy to avoid modifying original
            batch.pop(ACTION)

        batch = self._prepare_batch(batch)

        self._queues = populate_queues(self._queues, batch)

        if len(self._queues[ACTION]) == 0:
            actions = self.predict_action_chunk(batch)
            self._queues[ACTION].extend(actions.transpose(0, 1))

        action = self._queues[ACTION].popleft()
        return action

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
        """Run the batch through the model and compute the loss for training"""
        batch = self._prepare_batch(batch)

        conditioning_vec = self.observation_encoder.encode(batch)
        loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)

        return loss, None


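The interplay between `_generate_actions` and the action queue in `select_action` can be sketched without tensors. The class below is a toy stand-in (the name `ChunkedActionQueue` and the integer "actions" are assumptions for illustration): each predicted horizon is sliced from `n_obs_steps - 1` for `n_action_steps` entries, queued, and consumed one action per call, with a new prediction only when the queue runs empty.

```python
from collections import deque


class ChunkedActionQueue:
    """Toy stand-in for the action-queue logic in select_action (hypothetical helper)."""

    def __init__(self, n_obs_steps: int, horizon: int, n_action_steps: int):
        self.n_obs_steps = n_obs_steps
        self.horizon = horizon
        self.n_action_steps = n_action_steps
        self.queue = deque(maxlen=n_action_steps)
        self.num_predictions = 0

    def _predict_chunk(self) -> list[int]:
        # Stand-in for sampling: "action" i is just its position in the horizon.
        self.num_predictions += 1
        full_horizon = list(range(self.horizon))
        start = self.n_obs_steps - 1
        end = start + self.n_action_steps
        return full_horizon[start:end]

    def select_action(self) -> int:
        # Refill only when every queued action has been consumed.
        if len(self.queue) == 0:
            self.queue.extend(self._predict_chunk())
        return self.queue.popleft()


policy = ChunkedActionQueue(n_obs_steps=2, horizon=32, n_action_steps=24)
executed = [policy.select_action() for _ in range(30)]

print(executed[:3])            # [1, 2, 3]
print(executed[23:26])         # [24, 1, 2]
print(policy.num_predictions)  # 2
```

With the default sizes, 30 control steps trigger two predictions: horizon position 0 (the step aligned with the older observation) is skipped, positions 1 to 24 are executed, and positions 25 to 31 are never used.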
# -- Observation Encoders --


class CLIPVisionEncoder(nn.Module):
    """CLIP vision encoder using the CLS token for global image representation."""

    def __init__(self, model_name: str):
        super().__init__()
        self.model_name = model_name
        self.model = CLIPVisionModel.from_pretrained(self.model_name)
        self.num_non_spatial_tokens = 1
        self.embed_dim = self.model.config.hidden_size

    def forward(self, x: Tensor) -> Tensor:
        """Encode RGB image to CLS token."""
        outputs = self.model(pixel_values=x, output_hidden_states=False)
        cls_token = outputs.last_hidden_state[:, 0]
        b, embed_dim = cls_token.shape
        return cls_token.reshape(b, embed_dim, 1, 1)

    def get_output_shape(self) -> tuple:
        return (self.embed_dim, 1, 1)


class CLIPTextEncoder(nn.Module):
    """CLIP text encoder with frozen weights and a learnable projection layer.

    Accepts pre-tokenized inputs (input_ids and attention_mask) from the processor pipeline,
    which handles the tokenization.
    """

    def __init__(self, model_name: str = "openai/clip-vit-base-patch16", projection_dim: int = 512):
        super().__init__()
        self.model_name = model_name
        self.projection_dim = projection_dim
        self.text_encoder = CLIPTextModel.from_pretrained(model_name)

        for param in self.text_encoder.parameters():
            param.requires_grad = False

        self.text_embed_dim = self.text_encoder.config.hidden_size
        self.projection = nn.Linear(self.text_embed_dim, projection_dim)

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        """Encode pre-tokenized text to feature vectors."""
        # Ensure inputs are on the same device as the model
        device = next(self.parameters()).device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        with torch.no_grad():
            outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
            clip_features = outputs.pooler_output

        return self.projection(clip_features)


class ObservationEncoder(nn.Module):
    """Handles all observation processing for the conditioning vector."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self._setup_preprocessing(config)

        if config.image_features:
            self.num_cameras = len(config.image_features)
            self.camera_names = list(config.image_features.keys())

            if config.use_separate_rgb_encoder_per_camera:
                self.vision_encoders = nn.ModuleList(
                    [CLIPVisionEncoder(model_name=config.vision_encoder_name) for _ in self.camera_names]
                )
                self.vision_encoder = None
            else:
                self.vision_encoder = CLIPVisionEncoder(model_name=config.vision_encoder_name)
                self.vision_encoders = None
        else:
            self.vision_encoder = None
            self.vision_encoders = None
            self.camera_names = []
            self.num_cameras = 0

        if hasattr(config, "robot_state_feature") and config.robot_state_feature:
            self.robot_state_dim = config.robot_state_feature.shape[0]
        else:
            self.robot_state_dim = 0

        self.text_dim = config.hidden_dim
        self.text_encoder = CLIPTextEncoder(model_name=config.text_encoder_name, projection_dim=self.text_dim)

        self._setup_vector_output()

    def _apply_preprocessing(self, images: Tensor) -> Tensor:
        if self.do_resize:
            images = self.resize(images)
        if self.do_crop:
            images = self.maybe_random_crop(images) if self.training else self.center_crop(images)
        return images

    def _setup_preprocessing(self, config):
        if config.image_resize_shape is not None:
            self.do_resize = True
            self.resize = torchvision.transforms.Resize(
                size=config.image_resize_shape,
                interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
                antialias=True,
            )
        else:
            self.do_resize = False

        if config.image_crop_shape is not None:
            self.do_crop = True
            self.center_crop = torchvision.transforms.CenterCrop(config.image_crop_shape)
            if config.image_crop_is_random:
                self.maybe_random_crop = torchvision.transforms.RandomCrop(config.image_crop_shape)
            else:
                self.maybe_random_crop = self.center_crop
        else:
            self.do_crop = False

    def _setup_vector_output(self):
        total_dim = 0

        if self.vision_encoder is not None or self.vision_encoders is not None:
            encoder_to_check = self.vision_encoder or next(iter(self.vision_encoders))
            feature_map_shape = encoder_to_check.get_output_shape()
            c, h, w = feature_map_shape
            spatial_feature_dim = c * h * w
            total_dim += spatial_feature_dim * self.num_cameras

        total_dim += self.robot_state_dim
        total_dim += self.text_dim

        self.conditioning_dim = total_dim * self.config.n_obs_steps

    def encode(self, batch: dict) -> Tensor:
        """Encode observations to vector format."""
        batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
        conditioning_feats = []

        conditioning_feats.append(batch[OBS_STATE])

        if self.vision_encoder is not None or self.vision_encoders is not None:
            images = batch[OBS_IMAGES]

            if len(images.shape) == 5:
                images = images.unsqueeze(1)

            if self.config.use_separate_rgb_encoder_per_camera:
                camera_features = []
                for cam_idx in range(self.num_cameras):
                    cam_images = images[:, :, cam_idx]
                    cam_images_flat = einops.rearrange(cam_images, "b s c h w -> (b s) c h w")
                    cam_images_flat = self._apply_preprocessing(cam_images_flat)
                    cam_features = self.vision_encoders[cam_idx](cam_images_flat)
                    cam_visual_features = cam_features.flatten(start_dim=1)
                    cam_features_reshaped = einops.rearrange(
                        cam_visual_features, "(b s) f -> b s f", b=batch_size, s=n_obs_steps
                    )
                    camera_features.append(cam_features_reshaped)
                img_features = torch.cat(camera_features, dim=-1)
                conditioning_feats.append(img_features)
            else:
                images_flat = einops.rearrange(images, "b s n c h w -> (b s n) c h w")
                images_flat = self._apply_preprocessing(images_flat)
                visual_features = self.vision_encoder(images_flat).flatten(start_dim=1)
                img_features = einops.rearrange(
                    visual_features, "(b s n) f -> b s (n f)", b=batch_size, s=n_obs_steps, n=self.num_cameras
                )
                conditioning_feats.append(img_features)

        if self.text_encoder is not None and OBS_LANGUAGE_TOKENS in batch:
            input_ids = batch[OBS_LANGUAGE_TOKENS]  # [batch_size, seq_length]
            attention_mask = batch[OBS_LANGUAGE_ATTENTION_MASK]  # [batch_size, seq_length]

            text_features = self.text_encoder(input_ids, attention_mask)

            text_features = text_features.unsqueeze(1).expand(-1, n_obs_steps, -1)
            conditioning_feats.append(text_features)

        combined_features = torch.cat(conditioning_feats, dim=-1)
        return combined_features.flatten(start_dim=1)


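The size of the conditioning vector follows directly from `_setup_vector_output`: per observation step, each camera contributes one CLS embedding, plus the robot state and the projected text feature, and the total is multiplied by `n_obs_steps`. The sketch below works one example. The camera count (2) and state dimension (14) are made-up values, and 768 is assumed to be the vision hidden size of the default `openai/clip-vit-base-patch16` checkpoint.

```python
def conditioning_dim(
    num_cameras: int,
    vision_embed_dim: int,
    robot_state_dim: int,
    text_dim: int,
    n_obs_steps: int,
) -> int:
    # Mirrors _setup_vector_output: the CLS feature map has shape (embed_dim, 1, 1),
    # so each camera contributes embed_dim * 1 * 1 features per observation step.
    per_step = vision_embed_dim * num_cameras + robot_state_dim + text_dim
    return per_step * n_obs_steps


# Assumed sizes: 2 cameras, 768-d CLS token, 14-d state, text_dim = hidden_dim = 512.
dim_with_text = conditioning_dim(2, 768, 14, 512, n_obs_steps=2)
dim_without_text = conditioning_dim(2, 768, 14, 0, n_obs_steps=2)

print(dim_with_text)     # 4124
print(dim_without_text)  # 3100
```

One consequence worth noting: `_setup_vector_output` always counts `text_dim`, while `encode` appends text features only when `OBS_LANGUAGE_TOKENS` is present in the batch. Under the assumed sizes, a batch without language tokens would yield 3100 features where the downstream model expects 4124.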
# -- Transformer Components --


def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
    """Modulate input with shift and scale for AdaLN-Zero."""
    return x * (1 + scale) + shift


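A scalar sketch may help show why the `1 + scale` form matters. With zero `shift` and `scale`, `modulate` is the identity, and with a zero gate the gated residual used in `TransformerBlock.forward` passes its input through unchanged. The list-based functions below are illustrative stand-ins for the tensor versions; whether the modulation layer is actually zero-initialised is not visible in this part of the diff.

```python
def modulate(x, shift, scale):
    # Element-wise x * (1 + scale) + shift, as in the tensor version.
    return [xi * (1 + sc) + sh for xi, sh, sc in zip(x, shift, scale)]


def gated_residual(x, gate, branch_out):
    # Element-wise x + gate * branch_out, the residual form used for
    # both the attention branch and the MLP branch.
    return [xi + g * bi for xi, g, bi in zip(x, gate, branch_out)]


x = [0.5, -1.0, 2.0]
zeros = [0.0, 0.0, 0.0]
ones = [1.0, 1.0, 1.0]

identity = modulate(x, shift=zeros, scale=zeros)
skipped = gated_residual(x, gate=zeros, branch_out=[9.0, 9.0, 9.0])
scaled = modulate(x, shift=ones, scale=ones)

print(identity)  # [0.5, -1.0, 2.0]
print(skipped)   # [0.5, -1.0, 2.0]
print(scaled)    # [2.0, -1.0, 5.0]
```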
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embeddings for timesteps."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


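The same embedding can be written for a single scalar timestep using only the standard library, which makes the frequency schedule easier to inspect. This is a sketch of the formula in `SinusoidalPosEmb.forward`, not a drop-in replacement; the function name is illustrative.

```python
import math


def sinusoidal_embedding(t: float, dim: int) -> list[float]:
    half_dim = dim // 2
    # Geometric frequency schedule from 1 down to 1/10000.
    # Requires half_dim >= 2, otherwise the division below is by zero.
    step = math.log(10000) / (half_dim - 1)
    freqs = [math.exp(-step * i) for i in range(half_dim)]
    angles = [t * f for f in freqs]
    # First half holds the sines, second half the cosines.
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


emb_zero = sinusoidal_embedding(t=0.0, dim=8)
emb_ten = sinusoidal_embedding(t=10.0, dim=8)

print(emb_zero)       # [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
print(len(emb_ten))   # 8
```

Because of the `half_dim - 1` denominator, an embedding dimension of 2 or 3 would fail; the default `timestep_embed_dim` of 256 is well clear of that.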
class RotaryPositionalEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) for transformers."""

    def __init__(self, head_dim: int, max_seq_len: int = 512, base: float = 10000.0):
        super().__init__()
        assert head_dim % 2 == 0, "head_dim must be even for RoPE"

        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.base = base

        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._precompute_cache(max_seq_len)

    def _precompute_cache(self, seq_len: int):
        t = torch.arange(seq_len, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("_cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("_sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def _rotate_half(self, x: Tensor) -> Tensor:
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, q: Tensor, k: Tensor) -> tuple[Tensor, Tensor]:
        seq_len = q.shape[2]
        if seq_len > self.max_seq_len:
            raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}.")

        cos = self._cos_cached[:, :, :seq_len, :].to(q.dtype)
        sin = self._sin_cached[:, :, :seq_len, :].to(q.dtype)

        q_rotated = (q * cos) + (self._rotate_half(q) * sin)
        k_rotated = (k * cos) + (self._rotate_half(k) * sin)
        return q_rotated, k_rotated


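The rotate-half formulation above pairs feature `j` with feature `j + head_dim/2` and rotates each pair by an angle proportional to the token position. A minimal single-vector sketch in plain Python (the helper names `rope` and `dot` are illustrative) shows the two properties this relies on: the rotation preserves vector norm, and the query-key dot product depends only on the relative offset between positions.

```python
import math


def rope(x: list[float], pos: int, base: float = 10000.0) -> list[float]:
    d = len(x)
    assert d % 2 == 0, "dimension must be even"
    half = d // 2
    # Same schedule as inv_freq: base ** (-(2 * i) / d) for i in 0..half-1.
    inv_freq = [1.0 / (base ** (2 * i / d)) for i in range(half)]
    # Angles are duplicated across both halves, like cat((freqs, freqs)).
    angles = [pos * f for f in inv_freq] * 2
    # rotate_half: (-x2, x1)
    rotated = [-v for v in x[half:]] + x[:half]
    return [xi * math.cos(a) + ri * math.sin(a) for xi, ri, a in zip(x, rotated, angles)]


def dot(a: list[float], b: list[float]) -> float:
    return sum(p * q for p, q in zip(a, b))


q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.0]

# Same relative offset (2) at different absolute positions.
score_a = dot(rope(q, 5), rope(k, 3))
score_b = dot(rope(q, 12), rope(k, 10))

print(math.isclose(score_a, score_b, abs_tol=1e-9))                     # True
print(math.isclose(dot(rope(q, 7), rope(q, 7)), dot(q, q), abs_tol=1e-9))  # True
```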
class RoPEAttention(nn.Module):
    """Multi-head self-attention with Rotary Position Embedding (RoPE)."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout: float = 0.0,
        max_seq_len: int = 512,
        rope_base: float = 10000.0,
    ):
        super().__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim**-0.5

        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=True)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.rope = RotaryPositionalEmbedding(head_dim=self.head_dim, max_seq_len=max_seq_len, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        B, T, _ = x.shape  # noqa: N806

        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q, k = self.rope(q, k)

        attn_out = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.dropout.p if isinstance(self.dropout, nn.Dropout) and self.training else 0.0,
        )

        attn_out = attn_out.transpose(1, 2).reshape(B, T, self.hidden_size)
        return self.out_proj(attn_out)


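Two separate constraints apply to the attention dimensions when RoPE is enabled: the config's `_validate` requires `hidden_dim` to be divisible by `num_heads`, and `RotaryPositionalEmbedding` additionally asserts that the resulting head dimension is even. The helper below is a hypothetical pre-flight check (not part of the package) that combines both, so a configuration that passes `_validate` but would fail inside the RoPE constructor is caught early.

```python
def check_attention_dims(hidden_dim: int, num_heads: int) -> list[str]:
    """Return a list of problems; an empty list means both constraints hold."""
    problems = []
    if hidden_dim % num_heads != 0:
        # Checked by MultiTaskDiTConfig._validate and asserted in RoPEAttention.
        problems.append("hidden_dim must be divisible by num_heads")
        return problems
    head_dim = hidden_dim // num_heads
    if head_dim % 2 != 0:
        # Asserted in RotaryPositionalEmbedding.__init__, not in _validate.
        problems.append("head_dim must be even for RoPE")
    return problems


print(check_attention_dims(512, 8))  # []
print(check_attention_dims(72, 8))   # ['head_dim must be even for RoPE']
print(check_attention_dims(100, 8))  # ['hidden_dim must be divisible by num_heads']
```

The defaults (`hidden_dim=512`, `num_heads=8`) give a head dimension of 64 and satisfy both checks.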
class TransformerBlock(nn.Module):
|
||||
"""DiT-style transformer block with AdaLN-Zero."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 128,
|
||||
num_heads: int = 4,
|
||||
num_features: int = 128,
|
||||
dropout: float = 0.0,
|
||||
use_rope: bool = False,
|
||||
max_seq_len: int = 512,
|
||||
rope_base: float = 10000.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_rope = use_rope
|
||||
|
||||
if use_rope:
|
||||
self.attn = RoPEAttention(
|
||||
hidden_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
dropout=dropout,
|
||||
max_seq_len=max_seq_len,
|
||||
rope_base=rope_base,
|
||||
)
|
||||
else:
|
||||
self.multihead_attn = nn.MultiheadAttention(
|
||||
hidden_size, num_heads=num_heads, batch_first=True, dropout=dropout
|
||||
)
|
||||
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(hidden_size, hidden_size * 4),
|
||||
nn.GELU(approximate="tanh"),
|
||||
nn.Linear(hidden_size * 4, hidden_size),
|
||||
)
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(num_features, 6 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x: Tensor, features: Tensor) -> Tensor:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
|
||||
features
|
||||
).chunk(6, dim=1)
|
||||
|
||||
attn_input = modulate(self.norm1(x), shift_msa.unsqueeze(1), scale_msa.unsqueeze(1))
|
||||
|
||||
if self.use_rope:
|
||||
attn_out = self.attn(attn_input)
|
||||
else:
|
||||
attn_out, _ = self.multihead_attn(attn_input, attn_input, attn_input)
|
||||
|
||||
x = x + gate_msa.unsqueeze(1) * attn_out
|
||||
|
||||
mlp_input = modulate(self.norm2(x), shift_mlp.unsqueeze(1), scale_mlp.unsqueeze(1))
|
||||
mlp_out = self.mlp(mlp_input)
|
||||
x = x + gate_mlp.unsqueeze(1) * mlp_out
|
||||
|
||||
return x
|
||||
|
||||
|
||||
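The gating in `TransformerBlock.forward` can be illustrated with a scalar sketch. It is not the repository's code: the `modulate` helper is not shown in this excerpt, so the sketch assumes the form commonly used in DiT implementations, `x * (1 + scale) + shift`, and `gated_residual` and `toy_branch` are hypothetical names. Layer normalization is left out to keep the example in plain Python.

```python
# Scalar sketch of AdaLN-Zero gating (hypothetical helper names).
# Assumption: modulate follows the common DiT form x * (1 + scale) + shift.


def modulate(x: list[float], shift: float, scale: float) -> list[float]:
    return [v * (1.0 + scale) + shift for v in x]


def toy_branch(x: list[float]) -> list[float]:
    # Stand-in for the attention or MLP branch: any function of its input.
    return [3.0 * v - 1.0 for v in x]


def gated_residual(x, branch, shift, scale, gate):
    """Return x + gate * branch(modulate(x, shift, scale)), element-wise."""
    branch_out = branch(modulate(x, shift, scale))
    return [v + gate * b for v, b in zip(x, branch_out)]


x = [0.5, -2.0, 4.0]

# A zero-initialized modulation layer outputs shift = scale = gate = 0,
# so the residual block starts out as the identity map.
at_init = gated_residual(x, toy_branch, shift=0.0, scale=0.0, gate=0.0)

# Once training moves the gate away from zero, the branch contributes.
later = gated_residual(x, toy_branch, shift=0.1, scale=0.2, gate=0.5)
```

Under these assumptions `at_init` equals `x`, which is the usual motivation for zero-initializing the last layer of `adaLN_modulation`.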
class DiffusionTransformer(nn.Module):
    """Transformer-based diffusion noise prediction model."""

    def __init__(self, config, conditioning_dim: int):
        super().__init__()
        self.config = config
        self.conditioning_dim = conditioning_dim

        self.action_dim = config.action_feature.shape[0]
        self.horizon = config.horizon
        self.hidden_size = config.hidden_dim
        self.num_layers = config.num_layers
        self.num_heads = config.num_heads
        self.dropout = config.dropout
        self.use_rope = config.use_rope

        self.timestep_embed_dim = config.timestep_embed_dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(self.timestep_embed_dim),
            nn.Linear(self.timestep_embed_dim, 2 * self.timestep_embed_dim),
            nn.GELU(),
            nn.Linear(2 * self.timestep_embed_dim, self.timestep_embed_dim),
            nn.GELU(),
        )

        self.cond_dim = self.timestep_embed_dim + conditioning_dim
        self.input_proj = nn.Linear(self.action_dim, self.hidden_size)

        if config.use_positional_encoding:
            self.pos_embedding = nn.Parameter(
                torch.empty(1, self.horizon, self.hidden_size).normal_(std=0.02)
            )
        else:
            self.pos_embedding = None

        self.transformer_blocks = nn.ModuleList(
            [
                TransformerBlock(
                    hidden_size=self.hidden_size,
                    num_heads=self.num_heads,
                    num_features=self.cond_dim,
                    dropout=self.dropout,
                    use_rope=self.use_rope,
                    max_seq_len=self.horizon,
                    rope_base=config.rope_base,
                )
                for _ in range(self.num_layers)
            ]
        )

        self.output_proj = nn.Linear(self.hidden_size, self.action_dim)
        self._initialize_weights()

    def _initialize_weights(self):
        for block in self.transformer_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, x: Tensor, timestep: Tensor, conditioning_vec: Tensor) -> Tensor:
        _, seq_len, _ = x.shape

        timestep_features = self.time_mlp(timestep)
        cond_features = torch.cat([timestep_features, conditioning_vec], dim=-1)

        hidden_seq = self.input_proj(x)

        if self.pos_embedding is not None:
            hidden_seq = hidden_seq + self.pos_embedding[:, :seq_len, :]

        for block in self.transformer_blocks:
            hidden_seq = block(hidden_seq, cond_features)

        return self.output_proj(hidden_seq)


# -- Objectives --


class DiffusionObjective(nn.Module):
    """Standard diffusion (DDPM/DDIM) objective implementation."""

    def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
        super().__init__()
        self.config = config
        self.action_dim = action_dim
        self.horizon = horizon
        self.do_mask_loss_for_padding = do_mask_loss_for_padding

        scheduler_kwargs = {
            "num_train_timesteps": config.num_train_timesteps,
            "beta_start": config.beta_start,
            "beta_end": config.beta_end,
            "beta_schedule": config.beta_schedule,
            "clip_sample": config.clip_sample,
            "clip_sample_range": config.clip_sample_range,
            "prediction_type": config.prediction_type,
        }

        if config.noise_scheduler_type == "DDPM":
            self.noise_scheduler: DDPMScheduler | DDIMScheduler = DDPMScheduler(**scheduler_kwargs)
        elif config.noise_scheduler_type == "DDIM":
            self.noise_scheduler = DDIMScheduler(**scheduler_kwargs)
        else:
            raise ValueError(f"Unsupported noise scheduler type {config.noise_scheduler_type}")

        self.num_inference_steps = (
            config.num_inference_steps
            if config.num_inference_steps is not None
            else self.noise_scheduler.config.num_train_timesteps
        )

    def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
        clean_actions = batch[ACTION]
        noise = torch.randn_like(clean_actions)
        timesteps = torch.randint(
            low=0,
            high=self.noise_scheduler.config.num_train_timesteps,
            size=(clean_actions.shape[0],),
            device=clean_actions.device,
        ).long()
        noisy_actions = self.noise_scheduler.add_noise(clean_actions, noise, timesteps)

        prediction_type = self.noise_scheduler.config.prediction_type
        if prediction_type == "epsilon":
            target = noise
        elif prediction_type == "sample":
            target = clean_actions
        else:
            raise ValueError(f"Unsupported prediction type: {prediction_type}")

        predicted = model(noisy_actions, timesteps, conditioning_vec=conditioning_vec)
        loss = F.mse_loss(predicted, target, reduction="none")

        if self.do_mask_loss_for_padding and "action_is_pad" in batch:
            valid_actions = ~batch["action_is_pad"]
            loss = loss * valid_actions.unsqueeze(-1)

        return loss.mean()

    def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype

        sample = torch.randn(
            size=(batch_size, self.horizon, self.action_dim),
            dtype=dtype,
            device=device,
        )

        self.noise_scheduler.set_timesteps(self.num_inference_steps)
        for t in self.noise_scheduler.timesteps:
            model_output = model(
                sample,
                torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device),
                conditioning_vec=conditioning_vec,
            )
            sample = self.noise_scheduler.step(model_output, t, sample).prev_sample

        return sample

||||
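The training step in `DiffusionObjective.compute_loss` noises the clean actions and then regresses either the noise or the clean sample. The sketch below restates that idea in plain Python using the standard DDPM forward process, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. The linear schedule, its values, and every function name are illustrative assumptions; the actual schedule comes from the scheduler configuration and is not reproduced here.

```python
import math
import random


def linear_alpha_bars(num_steps: int, beta_start: float, beta_end: float) -> list[float]:
    """Cumulative products of (1 - beta_t) for a linear beta schedule."""
    alpha_bars, running = [], 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def add_noise(x0: list[float], eps: list[float], alpha_bar: float) -> list[float]:
    """DDPM forward process for one timestep."""
    a, s = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return [a * x + s * e for x, e in zip(x0, eps)]


def training_target(prediction_type: str, x0: list[float], eps: list[float]) -> list[float]:
    """Pick the regression target the way the two prediction types imply."""
    if prediction_type == "epsilon":
        return eps
    if prediction_type == "sample":
        return x0
    raise ValueError(f"Unsupported prediction type: {prediction_type}")


rng = random.Random(0)
alpha_bars = linear_alpha_bars(num_steps=100, beta_start=1e-4, beta_end=0.02)  # illustrative values

clean = [0.3, -0.7, 1.2]
noise = [rng.gauss(0.0, 1.0) for _ in clean]
t = rng.randrange(len(alpha_bars))

noisy = add_noise(clean, noise, alpha_bars[t])
target = training_target("epsilon", clean, noise)
```

A model trained with the `"epsilon"` target would be asked to recover `noise` from `noisy` and `t`; with `"sample"` it would be asked to recover `clean` directly.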
class FlowMatchingObjective(nn.Module):
    """Flow matching objective: trains a model to predict velocity fields."""

    def __init__(self, config, action_dim: int, horizon: int, do_mask_loss_for_padding: bool = False):
        super().__init__()
        self.config = config
        self.action_dim = action_dim
        self.horizon = horizon
        self.do_mask_loss_for_padding = do_mask_loss_for_padding

    def _sample_timesteps(self, batch_size: int, device: torch.device) -> Tensor:
        if self.config.timestep_sampling_strategy == "uniform":
            return torch.rand(batch_size, device=device)
        elif self.config.timestep_sampling_strategy == "beta":
            beta_dist = torch.distributions.Beta(
                self.config.timestep_sampling_alpha, self.config.timestep_sampling_beta
            )
            u = beta_dist.sample((batch_size,)).to(device)
            return self.config.timestep_sampling_s * (1.0 - u)
        else:
            raise ValueError(f"Unknown timestep strategy: {self.config.timestep_sampling_strategy}")

    def compute_loss(self, model: nn.Module, batch: dict[str, Tensor], conditioning_vec: Tensor) -> Tensor:
        data = batch[ACTION]
        batch_size = data.shape[0]
        device = data.device

        noise = torch.randn_like(data)
        t = self._sample_timesteps(batch_size, device)
        t_expanded = t.view(-1, 1, 1)
        x_t = t_expanded * data + (1 - (1 - self.config.sigma_min) * t_expanded) * noise

        target_velocity = data - (1 - self.config.sigma_min) * noise
        predicted_velocity = model(x_t, t, conditioning_vec=conditioning_vec)
        loss = F.mse_loss(predicted_velocity, target_velocity, reduction="none")

        if self.do_mask_loss_for_padding and "action_is_pad" in batch:
            valid_mask = ~batch["action_is_pad"]
            loss = loss * valid_mask.unsqueeze(-1)

        return loss.mean()

    def conditional_sample(self, model: nn.Module, batch_size: int, conditioning_vec: Tensor) -> Tensor:
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype

        x = torch.randn((batch_size, self.horizon, self.action_dim), dtype=dtype, device=device)

        num_steps = self.config.num_integration_steps
        time_grid = torch.linspace(0, 1, num_steps + 1, device=device)

        if self.config.integration_method == "euler":
            x = self._euler_integrate(model, x, time_grid, conditioning_vec)
        elif self.config.integration_method == "rk4":
            x = self._rk4_integrate(model, x, time_grid, conditioning_vec)
        else:
            raise ValueError(f"Unknown integration method: {self.config.integration_method}")

        return x

    def _euler_integrate(
        self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
    ) -> Tensor:
        x = x_init
        for i in range(len(time_grid) - 1):
            t_scalar = time_grid[i].item()
            dt = (time_grid[i + 1] - time_grid[i]).item()
            t_batch = torch.full((x.shape[0],), t_scalar, dtype=x.dtype, device=x.device)
            with torch.no_grad():
                velocity = model(x, t_batch, conditioning_vec=conditioning_vec)
            x = x + dt * velocity
        return x

    def _rk4_integrate(
        self, model: nn.Module, x_init: Tensor, time_grid: Tensor, conditioning_vec: Tensor
    ) -> Tensor:
        x = x_init

        def dynamics(x_val: Tensor, t_scalar: float) -> Tensor:
            t_batch = torch.full((x_val.shape[0],), t_scalar, dtype=x_val.dtype, device=x_val.device)
            with torch.no_grad():
                return model(x_val, t_batch, conditioning_vec=conditioning_vec)

        for i in range(len(time_grid) - 1):
            t = time_grid[i].item()
            dt = (time_grid[i + 1] - time_grid[i]).item()

            k1 = dynamics(x, t)
            k2 = dynamics(x + dt * k1 / 2, t + dt / 2)
            k3 = dynamics(x + dt * k2 / 2, t + dt / 2)
            k4 = dynamics(x + dt * k3, t + dt)

            x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

        return x
||||
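The interpolation path and velocity target in `FlowMatchingObjective.compute_loss` are consistent with each other: differentiating `x_t = t * data + (1 - (1 - sigma_min) * t) * noise` with respect to `t` gives `data - (1 - sigma_min) * noise`. The plain-Python sketch below checks this by Euler-integrating an oracle velocity field from noise at `t = 0` to `t = 1`. The function names and the `sigma_min` value are illustrative assumptions, and a trained network would only approximate the oracle used here.

```python
def interpolate(data, noise, t, sigma_min):
    """Point on the probability path at time t (t=0 is noise, t=1 is near data)."""
    return [t * d + (1.0 - (1.0 - sigma_min) * t) * n for d, n in zip(data, noise)]


def target_velocity(data, noise, sigma_min):
    """Time derivative of interpolate(); constant along the path."""
    return [d - (1.0 - sigma_min) * n for d, n in zip(data, noise)]


def euler_integrate(velocity_fn, x, num_steps):
    """Fixed-step Euler integration of dx/dt = velocity_fn(x, t) over [0, 1]."""
    dt = 1.0 / num_steps
    for i in range(num_steps):
        v = velocity_fn(x, i * dt)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


data = [1.0, -2.0]
noise = [0.5, 0.25]
sigma_min = 0.01  # illustrative value


def oracle(x, t):
    # Stand-in for a perfectly trained model: returns the exact target velocity.
    return target_velocity(data, noise, sigma_min)


sampled = euler_integrate(oracle, list(noise), num_steps=10)
endpoint = interpolate(data, noise, 1.0, sigma_min)
```

With the oracle field, `sampled` matches `endpoint`, which is `data + sigma_min * noise`: the sample lands on the data up to a residual noise term scaled by `sigma_min`.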
@@ -1,105 +0,0 @@
#!/usr/bin/env python

# Copyright 2025 Bryson Jones and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch

from lerobot.policies.multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    PolicyAction,
    PolicyProcessorPipeline,
    RenameObservationsProcessorStep,
    TokenizerProcessorStep,
    UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME


def make_multi_task_dit_pre_post_processors(
    config: MultiTaskDiTConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """
    Constructs pre-processor and post-processor pipelines for a Multi-Task DiT policy.

    The pre-processing pipeline prepares the input data for the model by:
    1. Renaming features.
    2. Adding a batch dimension.
    3. Tokenizing the language task description (if present).
    4. Moving the data to the specified device.
    5. Normalizing the input and output features based on dataset statistics.

    The post-processing pipeline handles the model's output by:
    1. Unnormalizing the output features to their original scale.
    2. Moving the data to the CPU.

    Args:
        config: The configuration object for the Multi-Task DiT policy,
            containing feature definitions, normalization mappings, and device information.
        dataset_stats: A dictionary of statistics used for normalization.
            Defaults to None.

    Returns:
        A tuple containing the configured pre-processor and post-processor pipelines.
    """

    input_steps = [
        RenameObservationsProcessorStep(rename_map={}),
        AddBatchDimensionProcessorStep(),
        TokenizerProcessorStep(
            tokenizer_name=config.text_encoder_name,
            padding=config.tokenizer_padding,
            padding_side=config.tokenizer_padding_side,
            max_length=config.tokenizer_max_length,
            truncation=config.tokenizer_truncation,
        ),
        DeviceProcessorStep(device=config.device),
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
            device=config.device,
        ),
    ]
    output_steps = [
        UnnormalizerProcessorStep(
            features=config.output_features,
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        DeviceProcessorStep(device="cpu"),
    ]

    return (
        PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
            steps=input_steps,
            name=POLICY_PREPROCESSOR_DEFAULT_NAME,
        ),
        PolicyProcessorPipeline[PolicyAction, PolicyAction](
            steps=output_steps,
            name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
            to_transition=policy_action_to_transition,
            to_output=transition_to_policy_action,
        ),
    )
||||
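The pre/post pipeline described in the docstring above is, at its core, an ordered list of steps where the post-processor undoes the normalization applied by the pre-processor. The toy sketch below shows only that ordering idea in plain Python. It does not use or reproduce the `lerobot.processor` classes; the step factories, the dictionary layout, and the statistics are all hypothetical.

```python
from typing import Callable

Step = Callable[[dict], dict]


def add_batch_dim(sample: dict) -> dict:
    """Wrap the unbatched action in a batch of size one."""
    return {**sample, "action": [sample["action"]]}


def make_normalizer(mean: float, std: float) -> Step:
    def step(sample: dict) -> dict:
        batch = [[(a - mean) / std for a in row] for row in sample["action"]]
        return {**sample, "action": batch}

    return step


def make_unnormalizer(mean: float, std: float) -> Step:
    def step(sample: dict) -> dict:
        batch = [[a * std + mean for a in row] for row in sample["action"]]
        return {**sample, "action": batch}

    return step


def run_pipeline(steps: list[Step], sample: dict) -> dict:
    """Apply each step in order, feeding the output of one into the next."""
    for step in steps:
        sample = step(sample)
    return sample


stats = {"mean": 0.5, "std": 2.0}  # hypothetical dataset statistics
pre = [add_batch_dim, make_normalizer(**stats)]
post = [make_unnormalizer(**stats)]

raw = {"action": [1.5, -0.5]}
model_input = run_pipeline(pre, raw)
restored = run_pipeline(post, model_input)
```

The same statistics must be used on both sides, which is why the function above passes one `dataset_stats` object to both the normalizer and the unnormalizer.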
@@ -17,65 +17,6 @@ It is designed as a **Vision-Language-Action model for general robot control**.

---

## Relative Actions

π₀ supports training with **relative actions**, where the model learns relative offsets
from the current robot state instead of absolute joint positions. This mirrors the
relative-action transform in OpenPI (`DeltaActions`) and can improve performance.

### How it works

1. **During preprocessing**, absolute actions are converted to relative offsets:
   `relative = action - state` (for selected joints).
2. The relative actions are normalized using statistics computed from the relative distribution.
3. **During postprocessing**, predicted relative actions are converted back to absolute:
   `absolute = relative + state`.

Joints listed in `relative_exclude_joints` (e.g., gripper) are kept absolute.
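
The conversion and the joint exclusion can be sketched in a few lines of plain Python. This is an illustration of the arithmetic only, not the processor step shipped with the library: the function names and joint names are hypothetical, and the real step operates on batched action chunks rather than a single action vector.

```python
def build_relative_mask(action_names: list[str], exclude_joints: list[str]) -> list[bool]:
    """True for dimensions converted to offsets; excluded names match by substring."""
    return [not any(token in name for token in exclude_joints) for name in action_names]


def to_relative(action: list[float], state: list[float], mask: list[bool]) -> list[float]:
    return [a - s if m else a for a, s, m in zip(action, state, mask)]


def to_absolute(relative: list[float], state: list[float], mask: list[bool]) -> list[float]:
    return [r + s if m else r for r, s, m in zip(relative, state, mask)]


names = ["shoulder_pan.pos", "elbow_flex.pos", "gripper.pos"]  # hypothetical joint names
mask = build_relative_mask(names, exclude_joints=["gripper"])

state = [0.10, -0.40, 0.75]
action = [0.15, -0.35, 0.20]

relative = to_relative(action, state, mask)    # arm joints become offsets, gripper is untouched
restored = to_absolute(relative, state, mask)  # round trip back to absolute targets
```

The round trip only works if postprocessing sees the same state that preprocessing subtracted, so the state has to be kept around between the two steps.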

### Configuration

| Parameter                 | Type        | Default       | Description                                                      |
| ------------------------- | ----------- | ------------- | ---------------------------------------------------------------- |
| `use_relative_actions`    | `bool`      | `False`       | Enable relative-action training                                  |
| `relative_exclude_joints` | `list[str]` | `["gripper"]` | Joint names to keep absolute (matched by substring)              |
| `action_feature_names`    | `list[str]` | `None`        | Auto-populated from dataset metadata at runtime by `make_policy` |

### Training example

```bash
python -m lerobot.scripts.lerobot_train \
  --policy.type=pi0 \
  --dataset.repo_id=your_org/your_dataset \
  --policy.use_relative_actions=true \
  --policy.relative_exclude_joints='["gripper"]'
```

When `use_relative_actions=true`, the training script automatically:

- Computes relative action statistics from the dataset (sampled chunk-level relative actions)
- Replaces the standard action stats with relative stats for normalization
- Broadcasts these stats across all ranks in distributed training
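
To make "chunk-level relative actions" concrete, here is a rough sketch for a single joint in plain Python. It assumes each chunk of future actions is expressed relative to the state at the chunk's first timestep. How the training script samples chunks and which statistic keys it stores are not shown in this document, so `relative_chunks` and the `stats` layout are hypothetical.

```python
import statistics


def relative_chunks(actions: list[float], states: list[float], chunk_size: int):
    """Yield actions[t : t + chunk_size] expressed as offsets from states[t]."""
    for t in range(len(actions) - chunk_size + 1):
        yield [a - states[t] for a in actions[t : t + chunk_size]]


# One joint over five timesteps (made-up numbers).
actions = [0.0, 0.1, 0.3, 0.6, 1.0]
states = [0.0, 0.0, 0.1, 0.3, 0.6]

offsets = [v for chunk in relative_chunks(actions, states, chunk_size=2) for v in chunk]

stats = {
    "mean": statistics.fmean(offsets),
    "std": statistics.pstdev(offsets),
    "min": min(offsets),
    "max": max(offsets),
}
```

The offsets are concentrated near zero compared with the absolute positions, which is why the normalizer needs statistics computed on the relative distribution rather than the original action stats.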

### Recomputing stats for an existing dataset

If you want to precompute relative action stats offline, use `recompute_stats` from
`lerobot.datasets.dataset_tools`:

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.dataset_tools import recompute_stats

dataset = LeRobotDataset("your_org/your_dataset")
dataset = recompute_stats(
    dataset,
    relative_action=True,
    relative_exclude_joints=["gripper"],
)
```

---

## Citation

If you use this work, please cite both **OpenPI** and the π₀ paper:

@@ -50,35 +50,6 @@ class PI0Config(PreTrainedConfig):
    min_period: float = 4e-3
    max_period: float = 4.0

    # Relative actions: converts absolute actions to relative (relative to state).
    use_relative_actions: bool = False
    # Joint names to exclude from relative (kept absolute). Empty list = all dims relative.
    relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
    # Populated at runtime from dataset metadata by make_policy.
    action_feature_names: list[str] | None = None

    # Relative state (UMI-style relative proprioception): converts multi-timestep
    # observation.state to offsets from the current timestep, providing velocity info.
    # Requires state_obs_steps >= 2. The flattened multi-timestep state is padded to
    # max_state_dim, so ensure state_obs_steps * state_dim <= max_state_dim.
    use_relative_state: bool = False
    state_obs_steps: int = 1
    relative_exclude_state_joints: list[str] = field(default_factory=list)
    # Populated at runtime from dataset metadata by make_policy.
    state_feature_names: list[str] | None = None

    # Derive observation.state from the action column (UMI-style).
    # When True, action_delta_indices loads one extra leading timestep [-1, 0, ..., chunk_size-1],
    # DeriveStateFromActionStep extracts [action[t-1], action[t]] as a 2-step state,
    # and strips the extra timestep from the action chunk.
    # Implies use_relative_state=True and state_obs_steps=2.
    derive_state_from_action: bool = False

    # Latency compensation: skip this many steps from the start of each predicted
    # action chunk during inference. E.g. at 10Hz with ~200ms total latency,
    # latency_skip_steps=2 compensates for the delay.
    latency_skip_steps: int = 0

    # Real-Time Chunking (RTC) configuration
    rtc_config: RTCConfig | None = None

@@ -128,10 +99,6 @@ class PI0Config(PreTrainedConfig):
    def __post_init__(self):
        super().__post_init__()

        if self.derive_state_from_action:
            self.use_relative_state = True
            self.state_obs_steps = 2

        # Validate configuration
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
@@ -147,13 +114,6 @@ class PI0Config(PreTrainedConfig):
        if self.dtype not in ["bfloat16", "float32"]:
            raise ValueError(f"Invalid dtype: {self.dtype}")

        if self.use_relative_state and self.state_obs_steps < 2:
            raise ValueError(
                "use_relative_state requires state_obs_steps >= 2 "
                f"(got {self.state_obs_steps}). Set state_obs_steps=2 for "
                "UMI-style relative proprioception."
            )

    def validate_features(self) -> None:
        """Validate and set up input/output features."""
        for i in range(self.empty_cameras):
@@ -199,16 +159,8 @@ class PI0Config(PreTrainedConfig):
    def observation_delta_indices(self) -> None:
        return None

    @property
    def state_delta_indices(self) -> list[int] | None:
        if self.state_obs_steps >= 2:
            return list(range(-(self.state_obs_steps - 1), 1))
        return None

    @property
    def action_delta_indices(self) -> list:
        if self.derive_state_from_action:
            return [-1] + list(range(self.chunk_size))
        return list(range(self.chunk_size))

    @property
||||
@@ -1230,11 +1230,8 @@ class PI0Policy(PreTrainedPolicy):
         return images, img_masks

     def prepare_state(self, batch):
-        """Flatten multi-timestep state and pad to max_state_dim."""
-        state = batch[OBS_STATE]
-        if state.ndim == 3:
-            state = state.flatten(start_dim=1)
-        state = pad_vector(state, self.config.max_state_dim)
+        """Pad state"""
+        state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
         return state

     def prepare_action(self, batch):
@@ -1253,8 +1250,7 @@ class PI0Policy(PreTrainedPolicy):

         # Action queue logic for n_action_steps > 1
         if len(self._action_queue) == 0:
-            skip = self.config.latency_skip_steps
-            actions = self.predict_action_chunk(batch)[:, skip : skip + self.config.n_action_steps]
+            actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
             # Transpose to get shape (n_action_steps, batch_size, action_dim)
             self._action_queue.extend(actions.transpose(0, 1))

|
||||
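The latency-compensated slice in the older version of this hunk can be tried out on a toy chunk. The sketch below uses plain lists and a `deque` in place of tensors; the chunk contents are placeholders, and `skip = 2` follows the 10 Hz / ~200 ms example given in the configuration comment.

```python
from collections import deque

chunk_size = 8
chunk = [f"a{t}" for t in range(chunk_size)]  # placeholder for one predicted action chunk

skip = 2            # latency_skip_steps: steps assumed stale by the time they would run
n_action_steps = 3  # actions queued per prediction

# Same slice as chunk[:, skip : skip + n_action_steps], minus the batch dimension.
queue = deque(chunk[skip : skip + n_action_steps])
executed = [queue.popleft() for _ in range(n_action_steps)]

# Slicing past the end does not raise; it silently returns fewer actions.
short = chunk[7 : 7 + n_action_steps]
```

Here `executed` holds `a2`, `a3`, `a4`, while `short` holds a single action. The validation visible in this diff only compares `n_action_steps` with `chunk_size`, so keeping `latency_skip_steps + n_action_steps` within the chunk appears to be left to the user.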
@@ -21,18 +21,14 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
    AbsoluteActionsProcessorStep,
    AddBatchDimensionProcessorStep,
    ComplementaryDataProcessorStep,
    DeriveStateFromActionStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    PolicyAction,
    PolicyProcessorPipeline,
    ProcessorStep,
    ProcessorStepRegistry,
    RelativeActionsProcessorStep,
    RelativeStateProcessorStep,
    RenameObservationsProcessorStep,
    TokenizerProcessorStep,
    UnnormalizerProcessorStep,
@@ -130,25 +126,7 @@ def make_pi0_pre_post_processors(
        A tuple containing the configured pre-processor and post-processor pipelines.
    """

    derive_state_step = DeriveStateFromActionStep(
        enabled=getattr(config, "derive_state_from_action", False),
    )

    relative_step = RelativeActionsProcessorStep(
        enabled=config.use_relative_actions,
        exclude_joints=getattr(config, "relative_exclude_joints", []),
        action_names=getattr(config, "action_feature_names", None),
    )

    relative_state_step = RelativeStateProcessorStep(
        enabled=getattr(config, "use_relative_state", False),
        exclude_joints=getattr(config, "relative_exclude_state_joints", []),
        state_names=getattr(config, "state_feature_names", None),
    )

    # Order: DeriveStateFromAction extracts state from the extended action chunk,
    # then relative_action uses current state[t] for subtraction,
    # then relative_state converts the multi-timestep state to offsets.
    # Add remaining processors
    input_steps: list[ProcessorStep] = [
        RenameObservationsProcessorStep(rename_map={}),  # To mimic the same processor as pretrained one
        AddBatchDimensionProcessorStep(),
@@ -160,9 +138,6 @@ def make_pi0_pre_post_processors(
            padding="max_length",
        ),
        DeviceProcessorStep(device=config.device),
        derive_state_step,
        relative_step,
        relative_state_step,
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
@@ -174,7 +149,6 @@ def make_pi0_pre_post_processors(
        UnnormalizerProcessorStep(
            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
        AbsoluteActionsProcessorStep(enabled=config.use_relative_actions, relative_step=relative_step),
        DeviceProcessorStep(device="cpu"),
    ]

@@ -17,48 +17,6 @@ It is designed as a **Vision-Language-Action model with open-world generalizatio

---

## Relative Actions

π₀.₅ supports training with **relative actions**, where the model learns relative offsets
from the current robot state instead of absolute joint positions. This mirrors the
relative-action transform in OpenPI (`DeltaActions`) and can improve performance.

### How it works

1. **During preprocessing**, absolute actions are converted to relative offsets:
   `relative = action - state` (for selected joints).
2. The relative actions are normalized using statistics computed from the relative distribution.
3. **During postprocessing**, predicted relative actions are converted back to absolute:
   `absolute = relative + state`.

Joints listed in `relative_exclude_joints` (e.g., gripper) are kept absolute.

### Configuration

| Parameter                 | Type        | Default       | Description                                                      |
| ------------------------- | ----------- | ------------- | ---------------------------------------------------------------- |
| `use_relative_actions`    | `bool`      | `False`       | Enable relative-action training                                  |
| `relative_exclude_joints` | `list[str]` | `["gripper"]` | Joint names to keep absolute (matched by substring)              |
| `action_feature_names`    | `list[str]` | `None`        | Auto-populated from dataset metadata at runtime by `make_policy` |

### Training example

```bash
python -m lerobot.scripts.lerobot_train \
  --policy.type=pi05 \
  --dataset.repo_id=your_org/your_dataset \
  --policy.use_relative_actions=true \
  --policy.relative_exclude_joints='["gripper"]'
```

When `use_relative_actions=true`, the training script automatically:

- Computes relative action statistics from the dataset (sampled chunk-level relative actions)
- Replaces the standard action stats with relative stats for normalization
- Broadcasts these stats across all ranks in distributed training

---

## Citation

If you use this work, please cite both **OpenPI** and the π₀.₅ paper:

@@ -50,13 +50,6 @@ class PI05Config(PreTrainedConfig):
    min_period: float = 4e-3
    max_period: float = 4.0

    # Relative actions: converts absolute actions to relative (relative to state).
    use_relative_actions: bool = False
    # Joint names to exclude from relative (kept absolute). Empty list = all dims relative.
    relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
    # Populated at runtime from dataset metadata by make_policy.
    action_feature_names: list[str] | None = None

    # Real-Time Chunking (RTC) configuration
    rtc_config: RTCConfig | None = None

@@ -24,7 +24,6 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.processor import (
    AbsoluteActionsProcessorStep,
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
@@ -32,7 +31,6 @@ from lerobot.processor import (
    PolicyProcessorPipeline,
    ProcessorStep,
    ProcessorStepRegistry,
    RelativeActionsProcessorStep,
    RenameObservationsProcessorStep,
    TokenizerProcessorStep,
    UnnormalizerProcessorStep,
@@ -127,17 +125,10 @@ def make_pi05_pre_post_processors(
        A tuple containing the configured pre-processor and post-processor pipelines.
    """

    relative_step = RelativeActionsProcessorStep(
        enabled=config.use_relative_actions,
        exclude_joints=getattr(config, "relative_exclude_joints", []),
        action_names=getattr(config, "action_feature_names", None),
    )

    # OpenPI order: raw → relative → normalize → model → unnormalize → absolute
    # Add remaining processors
    input_steps: list[ProcessorStep] = [
        RenameObservationsProcessorStep(rename_map={}),  # To mimic the same processor as pretrained one
        AddBatchDimensionProcessorStep(),
        relative_step,
        # NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
        # because the tokenizer step expects normalized state in [-1, 1] range for discretization
        NormalizerProcessorStep(
@@ -159,7 +150,6 @@ def make_pi05_pre_post_processors(
        UnnormalizerProcessorStep(
            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
        AbsoluteActionsProcessorStep(enabled=config.use_relative_actions, relative_step=relative_step),
        DeviceProcessorStep(device="cpu"),
    ]

@@ -41,13 +41,6 @@ class PI0FastConfig(PreTrainedConfig):
    max_action_dim: int = 32
    max_action_tokens: int = 256

    # Relative actions: converts absolute actions to relative (relative to state).
    use_relative_actions: bool = False
    # Joint names to exclude from relative (kept absolute). Empty list = all dims relative.
    relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
    # Populated at runtime from dataset metadata by make_policy.
    action_feature_names: list[str] | None = None

    # Real-Time Chunking (RTC) configuration
    rtc_config: RTCConfig | None = None

@@ -24,7 +24,6 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
from lerobot.processor import (
    AbsoluteActionsProcessorStep,
    ActionTokenizerProcessorStep,
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
@@ -33,7 +32,6 @@ from lerobot.processor import (
    PolicyProcessorPipeline,
    ProcessorStep,
    ProcessorStepRegistry,
    RelativeActionsProcessorStep,
    RenameObservationsProcessorStep,
    TokenizerProcessorStep,
    UnnormalizerProcessorStep,
@@ -127,24 +125,12 @@ def make_pi0_fast_pre_post_processors(
    Returns:
        A tuple containing the configured pre-processor and post-processor pipelines.
    """
    relative_step = RelativeActionsProcessorStep(
        enabled=config.use_relative_actions,
        exclude_joints=getattr(config, "relative_exclude_joints", []),
        action_names=getattr(config, "action_feature_names", None),
    )

    # Pi0Fast order: relative → normalize → tokenize → model → unnormalize → absolute
    # This matches pi0/pi0.5: RelativeActionsProcessorStep runs first on raw absolute actions,
    # caching the raw state. NormalizerProcessorStep then normalizes the raw relative actions,
    # so the normalizer (and action tokenizer) sees delta values, relative stats are required.
    # NOTE: RelativeActionsProcessorStep only modifies the action in the transition; it reads
    # state from the observation but does not change it. NormalizerProcessorStep still runs
    # before Pi0FastPrepareStateAndLanguageTokenizerProcessorStep, so the state tokenizer
    # continues to receive normalized state in [-1, 1] as expected.
    # Add remaining processors
    input_steps: list[ProcessorStep] = [
        RenameObservationsProcessorStep(rename_map={}),  # To mimic the same processor as pretrained one
        AddBatchDimensionProcessorStep(),
        relative_step,
        # NOTE: NormalizerProcessorStep MUST come before Pi0FastPrepareStateAndLanguageTokenizerProcessorStep
        # because the tokenizer step expects normalized state in [-1, 1] range for discretization
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
@@ -170,7 +156,6 @@ def make_pi0_fast_pre_post_processors(
        UnnormalizerProcessorStep(
            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
        AbsoluteActionsProcessorStep(enabled=config.use_relative_actions, relative_step=relative_step),
        DeviceProcessorStep(device="cpu"),
    ]

@@ -55,7 +55,7 @@ class SmolVLAConfig(PreTrainedConfig):
|
||||
# the space used by the pi internal runtime which was used to train the base model.
|
||||
adapt_to_pi_aloha: bool = False
|
||||
|
||||
# Converts joint dimensions to relative values with respect to the current state before passing to the model.
|
||||
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
|
||||
# Gripper dimensions will remain in absolute values.
|
||||
use_delta_joint_actions_aloha: bool = False
|
||||
|
||||
|
||||
@@ -75,15 +75,6 @@ from .policy_robot_bridge import (
|
||||
PolicyActionToRobotActionProcessorStep,
|
||||
RobotActionToPolicyActionProcessorStep,
|
||||
)
|
||||
from .relative_action_processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
DeriveStateFromActionStep,
|
||||
RelativeActionsProcessorStep,
|
||||
RelativeStateProcessorStep,
|
||||
to_absolute_actions,
|
||||
to_relative_actions,
|
||||
to_relative_state,
|
||||
)
|
||||
from .rename_processor import RenameObservationsProcessorStep
|
||||
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
|
||||
|
||||
@@ -109,10 +100,6 @@ __all__ = [
|
||||
"make_default_teleop_action_processor",
|
||||
"make_default_robot_action_processor",
|
||||
"make_default_robot_observation_processor",
|
||||
"AbsoluteActionsProcessorStep",
|
||||
"DeriveStateFromActionStep",
|
||||
"RelativeActionsProcessorStep",
|
||||
"RelativeStateProcessorStep",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"NormalizerProcessorStep",
|
||||
@@ -142,9 +129,6 @@ __all__ = [
|
||||
"transition_to_batch",
|
||||
"TransitionKey",
|
||||
"TruncatedProcessorStep",
|
||||
"to_absolute_actions",
|
||||
"to_relative_actions",
|
||||
"to_relative_state",
|
||||
"UnnormalizerProcessorStep",
|
||||
"VanillaObservationProcessorStep",
|
||||
]
|
||||
|
||||
@@ -1,367 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
||||
from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
# Re-export for backward compatibility
|
||||
__all__ = [
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"DeriveStateFromActionStep",
|
||||
"RelativeActionsProcessorStep",
|
||||
"AbsoluteActionsProcessorStep",
|
||||
"RelativeStateProcessorStep",
|
||||
"to_relative_actions",
|
||||
"to_absolute_actions",
|
||||
"to_relative_state",
|
||||
]
|
||||
|
||||
|
||||
def to_relative_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
    """Convert absolute actions to relative: relative = action - state (for masked dims).

    Args:
        actions: (B, T, action_dim) or (B, action_dim).
        state: (B, state_dim). Broadcast across time dimension.
        mask: Which dims to convert. Can be shorter than action_dim.
    """
    mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
    dims = mask_t.shape[0]
    # Align state to the same device/dtype as actions. _last_state is cached before
    # DeviceProcessorStep moves the transition, so it can be on CPU while actions are on CUDA.
    if state.device != actions.device or state.dtype != actions.dtype:
        state = state.to(device=actions.device, dtype=actions.dtype)
    state_offset = state[..., :dims] * mask_t
    if actions.ndim == 3:
        state_offset = state_offset.unsqueeze(-2)
    actions = actions.clone()
    actions[..., :dims] -= state_offset
    return actions
|
||||
|
||||
|
||||
def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
    """Convert relative actions back to absolute: absolute = relative + state (for masked dims).

    Args:
        actions: (B, T, action_dim) or (B, action_dim).
        state: (B, state_dim). Broadcast across time dimension.
        mask: Which dims to convert. Can be shorter than action_dim.
    """
    mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
    dims = mask_t.shape[0]
    # Align state to the same device/dtype as actions. _last_state is cached before
    # DeviceProcessorStep moves the transition, so it can be on CPU while actions are on CUDA.
    if state.device != actions.device or state.dtype != actions.dtype:
        state = state.to(device=actions.device, dtype=actions.dtype)
    state_offset = state[..., :dims] * mask_t
    if actions.ndim == 3:
        state_offset = state_offset.unsqueeze(-2)
    actions = actions.clone()
    actions[..., :dims] += state_offset
    return actions
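
The round trip performed by `to_relative_actions` and `to_absolute_actions` can be illustrated with a minimal, dependency-free sketch that uses plain Python lists instead of tensors. The names `to_relative` and `to_absolute` and the sample values are hypothetical and exist only for illustration; the sketch assumes a single action vector rather than a batched chunk. Dims whose mask entry is False, and dims beyond the mask length, are left untouched.

```python
def to_relative(action, state, mask):
    """Subtract state from action on dims where mask is True (hypothetical helper)."""
    out = list(action)
    for i, use in enumerate(mask):
        if use:
            out[i] -= state[i]
    return out


def to_absolute(action, state, mask):
    """Add state back to action on dims where mask is True (hypothetical helper)."""
    out = list(action)
    for i, use in enumerate(mask):
        if use:
            out[i] += state[i]
    return out


state = [10, 20, 30]
action = [12, 25, 31, 7]    # one extra dim beyond the mask length
mask = [True, True, False]  # third dim (e.g. a gripper) stays absolute

relative = to_relative(action, state, mask)
restored = to_absolute(relative, state, mask)
```

Because both directions use the same state and mask, converting to relative and back recovers the original action, which is what the paired pre- and post-processor steps rely on.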
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("derive_state_from_action_processor")
|
||||
@dataclass
|
||||
class DeriveStateFromActionStep(ProcessorStep):
|
||||
"""Derives 2-step observation.state from the action chunk (UMI-style).
|
||||
|
||||
Expects action with one extra leading timestep: [B, chunk_size+1, D]
|
||||
from action_delta_indices = [-1, 0, 1, ..., chunk_size-1].
|
||||
Extracts [action[t-1], action[t]] as state and strips the extra timestep.
|
||||
No-op during inference (state comes from robot).
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
if not self.enabled:
|
||||
return transition
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is None or action.ndim < 3:
|
||||
return transition
|
||||
new_transition = transition.copy()
|
||||
new_obs = dict(new_transition.get(TransitionKey.OBSERVATION, {}))
|
||||
new_obs[OBS_STATE] = action[..., :2, :]
|
||||
new_transition[TransitionKey.ACTION] = action[..., 1:, :]
|
||||
new_transition[TransitionKey.OBSERVATION] = new_obs
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"enabled": self.enabled}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("delta_actions_processor")
|
||||
@dataclass
|
||||
class RelativeActionsProcessorStep(ProcessorStep):
|
||||
"""Converts absolute actions to relative actions (action -= state) for masked dimensions.
|
||||
|
||||
Mirrors OpenPI's DeltaActions transform. Applied during preprocessing so the model
|
||||
trains on relative offsets instead of absolute positions.
|
||||
Caches the last seen state so a paired AbsoluteActionsProcessorStep can reverse
|
||||
the conversion during postprocessing.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether to apply the relative conversion.
|
||||
exclude_joints: Joint names to keep absolute (not converted to relative).
|
||||
action_names: Action dimension names from dataset metadata, used to build
|
||||
the mask from exclude_joints. If None, all dims are converted.
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
exclude_joints: list[str] = field(default_factory=list)
|
||||
action_names: list[str] | None = None
|
||||
_last_state: torch.Tensor | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def _build_mask(self, action_dim: int) -> list[bool]:
|
||||
if not self.exclude_joints or self.action_names is None:
|
||||
return [True] * action_dim
|
||||
|
||||
exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
|
||||
if not exclude_tokens:
|
||||
return [True] * action_dim
|
||||
|
||||
mask = []
|
||||
for name in self.action_names[:action_dim]:
|
||||
action_name = str(name).lower()
|
||||
is_excluded = any(token in action_name for token in exclude_tokens)
|
||||
mask.append(not is_excluded)
|
||||
|
||||
if len(mask) < action_dim:
|
||||
mask.extend([True] * (action_dim - len(mask)))
|
||||
|
||||
return mask
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION, {})
|
||||
raw_state = observation.get(OBS_STATE) if observation else None
|
||||
|
||||
# When state_delta_indices loads multi-timestep state [B, n_obs, D],
|
||||
# use only the current (last) timestep for relative action conversion.
|
||||
if raw_state is not None:
|
||||
state = raw_state[..., -1, :] if raw_state.ndim >= 3 else raw_state
|
||||
else:
|
||||
state = None
|
||||
|
||||
# Always cache state for the paired AbsoluteActionsProcessorStep
|
||||
if state is not None:
|
||||
self._last_state = state
|
||||
|
||||
if not self.enabled:
|
||||
return transition
|
||||
|
||||
new_transition = transition.copy()
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is None or state is None:
|
||||
return new_transition
|
||||
|
||||
mask = self._build_mask(action.shape[-1])
|
||||
new_transition[TransitionKey.ACTION] = to_relative_actions(action, state, mask)
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"enabled": self.enabled,
|
||||
"exclude_joints": self.exclude_joints,
|
||||
"action_names": self.action_names,
|
||||
}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
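
The name-based masking done by `_build_mask` in the class above can be sketched as a standalone function. This is an illustrative re-statement under assumptions, not the library's API: `build_mask` is a hypothetical free function and the joint names are made up. It shows that matching is case-insensitive and substring-based, and that dims without a name default to relative (True).

```python
def build_mask(action_names, exclude_joints, action_dim):
    """Return one bool per action dim: True = convert to relative (hypothetical helper)."""
    if not exclude_joints or action_names is None:
        return [True] * action_dim
    tokens = [str(name).lower() for name in exclude_joints if name]
    if not tokens:
        return [True] * action_dim
    mask = [
        not any(token in str(name).lower() for token in tokens)
        for name in action_names[:action_dim]
    ]
    # Dims without a name are treated as relative.
    mask.extend([True] * (action_dim - len(mask)))
    return mask


# Hypothetical joint names; the last action dim has no name.
names = ["left_joint_1.pos", "left_gripper.pos", "right_joint_1.pos", "right_Gripper.pos"]
mask = build_mask(names, ["gripper"], action_dim=5)
```

Note that substring matching means a short token such as `"gripper"` excludes every dim whose name contains it, on both arms.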
|
||||
|
||||
|
||||
def to_relative_state(state: Tensor, mask: Sequence[bool]) -> Tensor:
    """Convert multi-timestep absolute state to relative (offset from current timestep).

    Each timestep becomes: ``state[..., t, :] - state[..., -1, :]`` for masked dims.
    The last (current) timestep becomes zeros for masked dims.

    Args:
        state: (..., n_obs, state_dim); the last timestep is the reference (current).
        mask: Which dims to convert. Can be shorter than state_dim.
    """
    mask_t = torch.tensor(mask, dtype=state.dtype, device=state.device)
    dims = mask_t.shape[0]
    current = state[..., -1:, :]  # (..., 1, state_dim)
    state = state.clone()
    state[..., :dims] -= current[..., :dims] * mask_t
    return state
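
A small worked example may help make the relative-state convention concrete. The sketch below mirrors the arithmetic of `to_relative_state` on nested Python lists rather than tensors; the function name `to_relative_state_lists` and the sample values are hypothetical.

```python
def to_relative_state_lists(history, mask):
    """Express each timestep as an offset from the last one, on masked dims only."""
    current = history[-1]
    return [
        [
            value - current[i] if i < len(mask) and mask[i] else value
            for i, value in enumerate(step)
        ]
        for step in history
    ]


# Two timesteps (previous, current) of a 3-dim state; the last dim stays absolute.
history = [[1, 2, 9], [4, 6, 9]]
mask = [True, True, False]

relative_history = to_relative_state_lists(history, mask)
```

The current timestep becomes zero on the masked dims, so the previous timestep carries the displacement information, while the unmasked dim keeps its absolute value.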
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("relative_state_processor")
|
||||
@dataclass
|
||||
class RelativeStateProcessorStep(ProcessorStep):
|
||||
"""Converts observation.state to relative (offset from current timestep).
|
||||
|
||||
UMI-style relative proprioception: each state timestep is expressed as
|
||||
an offset from the current EE pose, providing velocity information.
|
||||
|
||||
During training (multi-timestep input from ``state_delta_indices``):
|
||||
``state[..., t, :] -= state[..., -1, :]`` — subtract current from all.
|
||||
|
||||
During inference (single timestep): buffers the previous state and stacks
|
||||
``[previous, current]`` before applying the relative conversion, producing
|
||||
the same ``[n_obs, D]`` shape the model expects.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether to apply the relative conversion.
|
||||
exclude_joints: Joint/dim names to keep absolute.
|
||||
state_names: State dimension names from dataset metadata.
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
exclude_joints: list[str] = field(default_factory=list)
|
||||
state_names: list[str] | None = None
|
||||
_previous_state: torch.Tensor | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def _build_mask(self, state_dim: int) -> list[bool]:
|
||||
if not self.exclude_joints or self.state_names is None:
|
||||
return [True] * state_dim
|
||||
|
||||
exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
|
||||
if not exclude_tokens:
|
||||
return [True] * state_dim
|
||||
|
||||
mask = []
|
||||
for name in self.state_names[:state_dim]:
|
||||
state_name = str(name).lower()
|
||||
is_excluded = any(token in state_name for token in exclude_tokens)
|
||||
mask.append(not is_excluded)
|
||||
|
||||
if len(mask) < state_dim:
|
||||
mask.extend([True] * (state_dim - len(mask)))
|
||||
|
||||
return mask
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
if not self.enabled:
|
||||
return transition
|
||||
|
||||
observation = transition.get(TransitionKey.OBSERVATION, {})
|
||||
state = observation.get(OBS_STATE) if observation else None
|
||||
|
||||
if state is None:
|
||||
return transition
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_obs = dict(new_transition.get(TransitionKey.OBSERVATION, {}))
|
||||
mask = self._build_mask(state.shape[-1])
|
||||
|
||||
if state.ndim >= 3:
|
||||
# [B, n_obs, D] — multi-timestep (training with state_delta_indices)
|
||||
relative = to_relative_state(state, mask)
|
||||
new_obs[OBS_STATE] = relative.flatten(start_dim=-2) # [B, n_obs*D]
|
||||
elif state.ndim == 2:
|
||||
# [B, D] — single timestep (inference): buffer previous and stack
|
||||
current = state
|
||||
if self._previous_state is None:
|
||||
self._previous_state = current.clone()
|
||||
prev = self._previous_state
|
||||
if prev.device != current.device or prev.dtype != current.dtype:
|
||||
prev = prev.to(device=current.device, dtype=current.dtype)
|
||||
stacked = torch.stack([prev, current], dim=-2) # [B, 2, D]
|
||||
relative = to_relative_state(stacked, mask)
|
||||
new_obs[OBS_STATE] = relative.flatten(start_dim=-2) # [B, 2*D]
|
||||
self._previous_state = current.clone()
|
||||
|
||||
new_transition[TransitionKey.OBSERVATION] = new_obs
|
||||
return new_transition
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the state buffer. Call at episode boundaries during inference."""
|
||||
self._previous_state = None
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"enabled": self.enabled,
|
||||
"exclude_joints": self.exclude_joints,
|
||||
"state_names": self.state_names,
|
||||
}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("absolute_actions_processor")
|
||||
@dataclass
|
||||
class AbsoluteActionsProcessorStep(ProcessorStep):
|
||||
"""Converts relative actions back to absolute actions (action += state) for masked dimensions.
|
||||
|
||||
Mirrors OpenPI's AbsoluteActions transform. Applied during postprocessing so
|
||||
predicted relative offsets are converted back to absolute positions for execution.
|
||||
Reads the cached state from its paired RelativeActionsProcessorStep.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether to apply the absolute conversion.
|
||||
relative_step: Reference to the paired RelativeActionsProcessorStep that caches state.
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
relative_step: RelativeActionsProcessorStep | None = field(default=None, repr=False)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
if not self.enabled:
|
||||
return transition
|
||||
|
||||
if self.relative_step is None:
|
||||
raise RuntimeError(
|
||||
"AbsoluteActionsProcessorStep requires a paired RelativeActionsProcessorStep "
|
||||
"but relative_step is None. Ensure relative_step is set when constructing the postprocessor."
|
||||
)
|
||||
|
||||
if self.relative_step._last_state is None:
|
||||
raise RuntimeError(
|
||||
"AbsoluteActionsProcessorStep requires state from RelativeActionsProcessorStep "
|
||||
"but no state has been cached. Ensure the preprocessor runs before the postprocessor."
|
||||
)
|
||||
|
||||
new_transition = transition.copy()
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is None:
|
||||
return new_transition
|
||||
|
||||
mask = self.relative_step._build_mask(action.shape[-1])
|
||||
new_transition[TransitionKey.ACTION] = to_absolute_actions(
|
||||
action, self.relative_step._last_state, mask
|
||||
)
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"enabled": self.enabled}
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
@@ -563,7 +563,7 @@ class ReplayBuffer:
|
||||
)
|
||||
|
||||
# Start writing images if needed
|
||||
lerobot_dataset.writer.start_image_writer(num_processes=0, num_threads=3)
|
||||
lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
|
||||
|
||||
# Convert transitions into episodes and frames
|
||||
|
||||
@@ -603,10 +603,10 @@ class ReplayBuffer:
|
||||
lerobot_dataset.save_episode()
|
||||
|
||||
# Save any remaining frames in the buffer
|
||||
if lerobot_dataset.has_pending_frames():
|
||||
if lerobot_dataset.episode_buffer["size"] > 0:
|
||||
lerobot_dataset.save_episode()
|
||||
|
||||
lerobot_dataset.writer.stop_image_writer()
|
||||
lerobot_dataset.stop_image_writer()
|
||||
lerobot_dataset.finalize()
|
||||
|
||||
return lerobot_dataset
|
||||
|
||||
@@ -752,7 +752,8 @@ def replay_trajectory(
|
||||
episodes=[cfg.dataset.replay_episode],
|
||||
download_videos=False,
|
||||
)
|
||||
actions = dataset.select_columns(ACTION)
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.replay_episode)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
_, info = env.reset()
|
||||
|
||||
|
||||
@@ -39,31 +39,19 @@ class BiOpenArmFollower(Robot):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Top-level cameras are distributed evenly: each arm's OpenArmFollower
|
||||
# will only open the cameras assigned to it. Per-arm cameras are used
|
||||
# as fallback when top-level cameras are empty.
|
||||
if config.cameras:
|
||||
left_cameras = config.cameras
|
||||
right_cameras = {}
|
||||
else:
|
||||
left_cameras = config.left_arm_config.cameras
|
||||
right_cameras = config.right_arm_config.cameras
|
||||
|
||||
left_arm_config = OpenArmFollowerConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.left_arm_config.port,
|
||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||
max_relative_target=config.left_arm_config.max_relative_target,
|
||||
cameras=left_cameras,
|
||||
cameras=config.left_arm_config.cameras,
|
||||
side=config.left_arm_config.side,
|
||||
can_interface=config.left_arm_config.can_interface,
|
||||
use_can_fd=config.left_arm_config.use_can_fd,
|
||||
can_bitrate=config.left_arm_config.can_bitrate,
|
||||
can_data_bitrate=config.left_arm_config.can_data_bitrate,
|
||||
motor_config=config.left_arm_config.motor_config,
|
||||
gripper_port=config.left_arm_config.gripper_port,
|
||||
gripper_motor_ids=config.left_arm_config.gripper_motor_ids,
|
||||
position_kd=config.left_arm_config.position_kd,
|
||||
position_kp=config.left_arm_config.position_kp,
|
||||
joint_limits=config.left_arm_config.joint_limits,
|
||||
@@ -75,15 +63,13 @@ class BiOpenArmFollower(Robot):
|
||||
port=config.right_arm_config.port,
|
||||
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
|
||||
max_relative_target=config.right_arm_config.max_relative_target,
|
||||
cameras=right_cameras,
|
||||
cameras=config.right_arm_config.cameras,
|
||||
side=config.right_arm_config.side,
|
||||
can_interface=config.right_arm_config.can_interface,
|
||||
use_can_fd=config.right_arm_config.use_can_fd,
|
||||
can_bitrate=config.right_arm_config.can_bitrate,
|
||||
can_data_bitrate=config.right_arm_config.can_data_bitrate,
|
||||
motor_config=config.right_arm_config.motor_config,
|
||||
gripper_port=config.right_arm_config.gripper_port,
|
||||
gripper_motor_ids=config.right_arm_config.gripper_motor_ids,
|
||||
position_kd=config.right_arm_config.position_kd,
|
||||
position_kp=config.right_arm_config.position_kp,
|
||||
joint_limits=config.right_arm_config.joint_limits,
|
||||
@@ -107,10 +93,13 @@ class BiOpenArmFollower(Robot):
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
# Cameras already have unique user-chosen names (e.g. "left_wrist", "base",
|
||||
# "right_wrist"), so we merge them directly — unlike motors which need the
|
||||
# left_/right_ prefix to disambiguate identical per-arm joint names.
|
||||
return {**self.left_arm._cameras_ft, **self.right_arm._cameras_ft}
|
||||
left_arm_cameras_ft = self.left_arm._cameras_ft
|
||||
right_arm_cameras_ft = self.right_arm._cameras_ft
|
||||
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
|
||||
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
@@ -150,17 +139,13 @@ class BiOpenArmFollower(Robot):
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
|
||||
# Camera keys that should NOT get the arm prefix (they already have unique names)
|
||||
left_cam_keys = set(self.left_arm.cameras.keys())
|
||||
right_cam_keys = set(self.right_arm.cameras.keys())
|
||||
|
||||
# Add "left_" prefix
|
||||
left_obs = self.left_arm.get_observation()
|
||||
for key, value in left_obs.items():
|
||||
obs_dict[key if key in left_cam_keys else f"left_{key}"] = value
|
||||
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
|
||||
|
||||
# Add "right_" prefix
|
||||
right_obs = self.right_arm.get_observation()
|
||||
for key, value in right_obs.items():
|
||||
obs_dict[key if key in right_cam_keys else f"right_{key}"] = value
|
||||
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
|
||||
|
||||
return obs_dict
|
||||
|
||||
|
||||
@@ -14,9 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
from lerobot.robots.openarm_follower import OpenArmFollowerConfigBase
|
||||
|
||||
from ..config import RobotConfig
|
||||
@@ -29,6 +28,3 @@ class BiOpenArmFollowerConfig(RobotConfig):
|
||||
|
||||
left_arm_config: OpenArmFollowerConfigBase
|
||||
right_arm_config: OpenArmFollowerConfigBase
|
||||
|
||||
# Top-level cameras shared across both arms.
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -28,8 +28,7 @@ LEFT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = {
|
||||
"joint_5": (-85.0, 85.0),
|
||||
"joint_6": (-40.0, 40.0),
|
||||
"joint_7": (-80.0, 80.0),
|
||||
"proximal": (0.0, 100.0),
|
||||
"distal": (0.0, 100.0),
|
||||
"gripper": (-65.0, 0.0),
|
||||
}
|
||||
|
||||
RIGHT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = {
|
||||
@@ -40,8 +39,7 @@ RIGHT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = {
|
||||
"joint_5": (-85.0, 85.0),
|
||||
"joint_6": (-40.0, 40.0),
|
||||
"joint_7": (-80.0, 80.0),
|
||||
"proximal": (0.0, 100.0),
|
||||
"distal": (0.0, 100.0),
|
||||
"gripper": (-65.0, 0.0),
|
||||
}
|
||||
|
||||
|
||||
@@ -75,8 +73,13 @@ class OpenArmFollowerConfigBase:
|
||||
# Camera configurations
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
# Arm motor configuration (7 DOF, Damiao on CAN bus)
|
||||
# Motor configuration for OpenArms (7 DOF per arm)
|
||||
# Maps motor names to (send_can_id, recv_can_id, motor_type)
|
||||
# Based on: https://docs.openarm.dev/software/setup/configure-test
|
||||
# OpenArms uses 4 types of motors:
|
||||
# - DM8009 (DM-J8009P-2EC) for shoulders (high torque)
|
||||
# - DM4340P and DM4340 for shoulder rotation and elbow
|
||||
# - DM4310 (DM-J4310-2EC V1.1) for wrist and gripper
|
||||
motor_config: dict[str, tuple[int, int, str]] = field(
|
||||
default_factory=lambda: {
|
||||
"joint_1": (0x01, 0x11, "dm8009"), # J1 - Shoulder pan (DM8009)
|
||||
@@ -86,18 +89,19 @@ class OpenArmFollowerConfigBase:
|
||||
"joint_5": (0x05, 0x15, "dm4310"), # J5 - Wrist roll (DM4310)
|
||||
"joint_6": (0x06, 0x16, "dm4310"), # J6 - Wrist pitch (DM4310)
|
||||
"joint_7": (0x07, 0x17, "dm4310"), # J7 - Wrist rotation (DM4310)
|
||||
"gripper": (0x08, 0x18, "dm4310"), # J8 - Gripper (DM4310)
|
||||
}
|
||||
)
|
||||
|
||||
# UMI-style gripper (Feetech STS3215 on serial bus)
|
||||
gripper_port: str = "/dev/ttyUSB0"
|
||||
gripper_motor_ids: dict[str, int] = field(default_factory=lambda: {"proximal": 1, "distal": 2})
|
||||
# MIT control parameters for position control (used in send_action)
|
||||
# List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
|
||||
position_kp: list[float] = field(
|
||||
default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 25.0]
|
||||
)
|
||||
position_kd: list[float] = field(default_factory=lambda: [5.0, 5.0, 3.0, 5.0, 0.3, 0.3, 0.3, 0.3])
|
||||
|
||||
# MIT control parameters for the 7 arm joints
|
||||
position_kp: list[float] = field(default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0])
|
||||
position_kd: list[float] = field(default_factory=lambda: [5.0, 5.0, 3.0, 5.0, 0.3, 0.3, 0.3])
|
||||
|
||||
# Joint limits. Can be overridden via CLI or by setting config.side to 'left' or 'right'.
|
||||
# Values for joint limits. Can be overridden via CLI (for custom values) or by setting config.side to either 'left' or 'right'.
|
||||
# If config.side is left set to None and no CLI values are passed, the default joint limit values are small for safety.
|
||||
joint_limits: dict[str, tuple[float, float]] = field(
|
||||
default_factory=lambda: {
|
||||
"joint_1": (-5.0, 5.0),
|
||||
@@ -107,8 +111,7 @@ class OpenArmFollowerConfigBase:
|
||||
"joint_5": (-5.0, 5.0),
|
||||
"joint_6": (-5.0, 5.0),
|
||||
"joint_7": (-5.0, 5.0),
|
||||
"proximal": (0.0, 100.0),
|
||||
"distal": (0.0, 100.0),
|
||||
"gripper": (-5.0, 0.0),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -22,7 +22,6 @@ from typing import Any
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.damiao import DamiaoMotorsBus
|
||||
from lerobot.motors.feetech import FeetechMotorsBus, OperatingMode
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
@@ -39,7 +38,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class OpenArmFollower(Robot):
|
||||
"""
|
||||
OpenArms Follower Robot: 7 DOF Damiao arm (CAN) + UMI-style Feetech gripper (serial).
|
||||
OpenArms Follower Robot which uses CAN bus communication to control 7 DOF arm with a gripper.
|
||||
The arm uses Damiao motors in MIT control mode.
|
||||
"""
|
||||
|
||||
config_class = OpenArmFollowerConfig
|
||||
@@ -49,17 +49,19 @@ class OpenArmFollower(Robot):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Arm motors (Damiao on CAN bus)
|
||||
arm_motors: dict[str, Motor] = {}
|
||||
# Arm motors
|
||||
motors: dict[str, Motor] = {}
|
||||
for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items():
|
||||
motor = Motor(send_id, motor_type_str, MotorNormMode.DEGREES)
|
||||
motor = Motor(
|
||||
send_id, motor_type_str, MotorNormMode.DEGREES
|
||||
) # Always use degrees for Damiao motors
|
||||
motor.recv_id = recv_id
|
||||
motor.motor_type_str = motor_type_str
|
||||
arm_motors[motor_name] = motor
|
||||
motors[motor_name] = motor
|
||||
|
||||
self.bus = DamiaoMotorsBus(
|
||||
port=self.config.port,
|
||||
motors=arm_motors,
|
||||
motors=motors,
|
||||
calibration=self.calibration,
|
||||
can_interface=self.config.can_interface,
|
||||
use_can_fd=self.config.use_can_fd,
|
||||
@@ -67,17 +69,6 @@ class OpenArmFollower(Robot):
|
||||
data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None,
|
||||
)
|
||||
|
||||
# Gripper motors (Feetech STS3215 on serial bus)
|
||||
gripper_motors: dict[str, Motor] = {
|
||||
name: Motor(motor_id, "sts3215", MotorNormMode.RANGE_0_100)
|
||||
for name, motor_id in config.gripper_motor_ids.items()
|
||||
}
|
||||
self.gripper_bus = FeetechMotorsBus(
|
||||
port=config.gripper_port,
|
||||
motors=gripper_motors,
|
||||
calibration=self.calibration,
|
||||
)
|
||||
|
||||
if config.side is not None:
|
||||
if config.side == "left":
|
||||
config.joint_limits = LEFT_DEFAULT_JOINTS_LIMITS
|
||||
@@ -93,6 +84,7 @@ class OpenArmFollower(Robot):
|
||||
)
|
||||
logger.info(f"Values used for joint limits: {config.joint_limits}.")
|
||||
|
||||
# Initialize cameras
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
@property
|
||||
@@ -101,10 +93,8 @@ class OpenArmFollower(Robot):
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus.motors:
|
||||
features[f"{motor}.pos"] = float
|
||||
features[f"{motor}.vel"] = float
|
||||
features[f"{motor}.torque"] = float
|
||||
for motor in self.gripper_bus.motors:
|
||||
features[f"{motor}.pos"] = float
|
||||
features[f"{motor}.vel"] = float # Add this
|
||||
features[f"{motor}.torque"] = float # Add this
|
||||
return features
|
||||
|
||||
@property
|
||||
@@ -126,11 +116,8 @@ class OpenArmFollower(Robot):
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return (
|
||||
self.bus.is_connected
|
||||
and self.gripper_bus.is_connected
|
||||
and all(cam.is_connected for cam in self.cameras.values())
|
||||
)
|
||||
"""Check if robot is connected."""
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
@@ -140,12 +127,12 @@ class OpenArmFollower(Robot):
|
||||
We assume that at connection time, the arms are in a safe rest position,
|
||||
and torque can be safely disabled to run calibration if needed.
|
||||
"""
|
||||
|
||||
# Connect to CAN bus
|
||||
logger.info(f"Connecting arm on {self.config.port}...")
|
||||
self.bus.connect()
|
||||
|
||||
logger.info(f"Connecting gripper on {self.config.gripper_port}...")
|
||||
self.gripper_bus.connect()
|
||||
|
||||
# Run calibration if needed
|
||||
if not self.is_calibrated and calibrate:
|
||||
logger.info(
|
||||
"Mismatch between calibration values in the motor and the calibration file or no calibration file found"
|
||||
@@ -157,7 +144,7 @@ class OpenArmFollower(Robot):
|
||||
|
||||
self.configure()
|
||||
|
||||
if self.bus.is_calibrated:
|
||||
if self.is_calibrated:
|
||||
self.bus.set_zero_position()
|
||||
|
||||
self.bus.enable_torque()
|
||||
@@ -166,39 +153,47 @@ class OpenArmFollower(Robot):
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.bus.is_calibrated and self.gripper_bus.is_calibrated
|
||||
"""Check if robot is calibrated."""
|
||||
return self.bus.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
"""
|
||||
Run calibration for both the Damiao arm and Feetech gripper.
|
||||
Run calibration procedure for OpenArms robot.
|
||||
|
||||
Arm calibration: set zero position with arm hanging, ±90° default range.
|
||||
Gripper calibration: SO100-style half-turn homing + range recording.
|
||||
The calibration procedure:
|
||||
1. Disable torque
|
||||
2. Ask user to position arms in hanging position with grippers closed
|
||||
3. Set this as zero position
|
||||
4. Record range of motion for each joint
|
||||
5. Save calibration
|
||||
"""
|
||||
if self.calibration:
|
||||
# Calibration file exists, ask user whether to use it or run new calibration
|
||||
user_input = input(
|
||||
f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
|
||||
)
|
||||
if user_input.strip().lower() != "c":
|
||||
logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
|
||||
self.bus.write_calibration(self.calibration)
|
||||
self.gripper_bus.write_calibration(self.calibration)
|
||||
return
|
||||
|
||||
logger.info(f"\nRunning calibration for {self}")
|
||||
|
||||
# --- Arm calibration (Damiao) ---
|
||||
self.bus.disable_torque()
|
||||
|
||||
# Step 1: Set zero position
|
||||
input(
|
||||
"\nCalibration: Set Zero Position\n"
|
||||
"\nCalibration: Set Zero Position)\n"
|
||||
"Position the arm in the following configuration:\n"
|
||||
" - Arm hanging straight down\n"
|
||||
" - Gripper closed\n"
|
||||
"Press ENTER when ready..."
|
||||
)
|
||||
|
||||
# Set current position as zero for all motors
|
||||
self.bus.set_zero_position()
|
||||
logger.info("Arm zero position set.")
|
||||
|
||||
logger.info("Setting range: -90° to +90° for safety by default for all joints")
|
||||
for motor_name, motor in self.bus.motors.items():
|
||||
self.calibration[motor_name] = MotorCalibration(
|
||||
id=motor.id,
|
||||
@@ -207,52 +202,17 @@ class OpenArmFollower(Robot):
|
||||
range_min=-90,
|
||||
range_max=90,
|
||||
)
|
||||
|
||||
self.bus.write_calibration(self.calibration)
|
||||
|
||||
# --- Gripper calibration (Feetech) ---
|
||||
self.gripper_bus.disable_torque()
|
||||
for motor in self.gripper_bus.motors:
|
||||
self.gripper_bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
input("Move gripper to the middle of its range of motion and press ENTER....")
|
||||
homing_offsets = self.gripper_bus.set_half_turn_homings()
|
||||
|
||||
gripper_motor_names = list(self.gripper_bus.motors.keys())
|
||||
print(
|
||||
f"Move gripper joints ({', '.join(gripper_motor_names)}) through their "
|
||||
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
|
||||
)
|
||||
range_mins, range_maxes = self.gripper_bus.record_ranges_of_motion(gripper_motor_names)
|
||||
|
||||
for motor_name, m in self.gripper_bus.motors.items():
|
||||
self.calibration[motor_name] = MotorCalibration(
|
||||
id=m.id,
|
||||
drive_mode=0,
|
||||
homing_offset=homing_offsets[motor_name],
|
||||
range_min=range_mins[motor_name],
|
||||
range_max=range_maxes[motor_name],
|
||||
)
|
||||
self.gripper_bus.write_calibration(self.calibration)
|
||||
|
||||
self._save_calibration()
|
||||
print(f"Calibration saved to {self.calibration_fpath}")
|
||||
|
||||
def configure(self) -> None:
|
||||
"""Configure both arm (Damiao) and gripper (Feetech) motors."""
|
||||
"""Configure motors with appropriate settings."""
|
||||
# TODO(Steven, Pepijn): Slightly different from what happens in the leader
|
||||
with self.bus.torque_disabled():
|
||||
self.bus.configure_motors()
|
||||
|
||||
with self.gripper_bus.torque_disabled():
|
||||
self.gripper_bus.configure_motors()
|
||||
for motor in self.gripper_bus.motors:
|
||||
self.gripper_bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
self.gripper_bus.write("P_Coefficient", motor, 16)
|
||||
self.gripper_bus.write("I_Coefficient", motor, 0)
|
||||
self.gripper_bus.write("D_Coefficient", motor, 32)
|
||||
self.gripper_bus.write("Max_Torque_Limit", motor, 500)
|
||||
self.gripper_bus.write("Protection_Current", motor, 250)
|
||||
self.gripper_bus.write("Overload_Torque", motor, 25)
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
@@ -260,23 +220,25 @@ class OpenArmFollower(Robot):
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""Read all motor states from arm (CAN) and gripper (serial), plus cameras."""
|
||||
"""
|
||||
Get current observation from robot including position, velocity, and torque.
|
||||
|
||||
Reads all motor states (pos/vel/torque) in one CAN refresh cycle
|
||||
instead of 3 separate reads.
|
||||
"""
|
||||
start = time.perf_counter()
|
||||
|
||||
obs_dict: dict[str, Any] = {}
|
||||
|
||||
# Arm motors (Damiao) — pos/vel/torque in one CAN refresh cycle
|
||||
states = self.bus.sync_read_all_states()
|
||||
|
||||
for motor in self.bus.motors:
|
||||
state = states.get(motor, {})
|
||||
obs_dict[f"{motor}.pos"] = state.get("position", 0.0)
|
||||
obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0)
|
||||
obs_dict[f"{motor}.torque"] = state.get("torque", 0.0)
|
||||
|
||||
# Gripper motors (Feetech) — position only
|
||||
gripper_positions = self.gripper_bus.sync_read("Present_Position")
|
||||
for motor, val in gripper_positions.items():
|
||||
obs_dict[f"{motor}.pos"] = val
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[cam_key] = cam.read_latest()
|
||||
@@ -296,76 +258,86 @@ class OpenArmFollower(Robot):
|
||||
custom_kd: dict[str, float] | None = None,
|
||||
) -> RobotAction:
|
||||
"""
|
||||
Send action command to robot. Arm joints go to Damiao CAN bus,
|
||||
gripper joints go to Feetech serial bus.
|
||||
Send action command to robot.
|
||||
|
||||
The action magnitude may be clipped based on safety limits.
|
||||
|
||||
Args:
|
||||
action: Dictionary with motor positions (e.g., "joint_1.pos", "proximal.pos")
|
||||
custom_kp: Optional custom kp gains per arm motor
|
||||
custom_kd: Optional custom kd gains per arm motor
|
||||
action: Dictionary with motor positions (e.g., "joint_1.pos", "joint_2.pos")
|
||||
custom_kp: Optional custom kp gains per motor (e.g., {"joint_1": 120.0, "joint_2": 150.0})
|
||||
custom_kd: Optional custom kd gains per motor (e.g., {"joint_1": 1.5, "joint_2": 2.0})
|
||||
|
||||
Returns:
|
||||
The action actually sent (potentially clipped)
|
||||
"""
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
# Apply joint limit clipping
|
||||
# Apply joint limit clipping to arm
|
||||
for motor_name, position in goal_pos.items():
|
||||
if motor_name in self.config.joint_limits:
|
||||
min_limit, max_limit = self.config.joint_limits[motor_name]
|
||||
clipped_position = max(min_limit, min(max_limit, position))
|
||||
if clipped_position != position:
|
||||
logger.debug(f"Clipped {motor_name} from {position:.2f} to {clipped_position:.2f}")
|
||||
logger.debug(f"Clipped {motor_name} from {position:.2f}° to {clipped_position:.2f}°")
|
||||
goal_pos[motor_name] = clipped_position
|
||||
|
||||
# Split into arm and gripper actions
|
||||
arm_motors = set(self.bus.motors.keys())
|
||||
gripper_motors = set(self.gripper_bus.motors.keys())
|
||||
arm_goal = {k: v for k, v in goal_pos.items() if k in arm_motors}
|
||||
gripper_goal = {k: v for k, v in goal_pos.items() if k in gripper_motors}
|
||||
|
||||
# Cap arm goal position when too far away from present position
|
||||
if self.config.max_relative_target is not None and arm_goal:
|
||||
# Cap goal position when too far away from present position.
|
||||
# /!\ Slower fps expected due to reading from the follower.
|
||||
if self.config.max_relative_target is not None:
|
||||
present_pos = self.bus.sync_read("Present_Position")
|
||||
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in arm_goal.items()}
|
||||
arm_goal = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
|
||||
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
|
||||
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
|
||||
|
||||
# Arm: batch MIT control (Damiao)
|
||||
if arm_goal:
|
||||
arm_motor_names = list(self.bus.motors.keys())
|
||||
commands = {}
|
||||
for motor_name, position_degrees in arm_goal.items():
|
||||
idx = arm_motor_names.index(motor_name) if motor_name in arm_motor_names else 0
|
||||
if custom_kp is not None and motor_name in custom_kp:
|
||||
kp = custom_kp[motor_name]
|
||||
else:
|
||||
kp = (
|
||||
self.config.position_kp[idx]
|
||||
if isinstance(self.config.position_kp, list)
|
||||
else self.config.position_kp
|
||||
)
|
||||
if custom_kd is not None and motor_name in custom_kd:
|
||||
kd = custom_kd[motor_name]
|
||||
else:
|
||||
kd = (
|
||||
self.config.position_kd[idx]
|
||||
if isinstance(self.config.position_kd, list)
|
||||
else self.config.position_kd
|
||||
)
|
||||
commands[motor_name] = (kp, kd, position_degrees, 0.0, 0.0)
|
||||
self.bus._mit_control_batch(commands)
|
||||
# TODO(Steven, Pepijn): Refactor writing
|
||||
# Motor name to index mapping for gains
|
||||
motor_index = {
|
||||
"joint_1": 0,
|
||||
"joint_2": 1,
|
||||
"joint_3": 2,
|
||||
"joint_4": 3,
|
||||
"joint_5": 4,
|
||||
"joint_6": 5,
|
||||
"joint_7": 6,
|
||||
"gripper": 7,
|
||||
}
|
||||
|
||||
# Gripper: position control (Feetech)
|
||||
if gripper_goal:
|
||||
self.gripper_bus.sync_write("Goal_Position", gripper_goal)
|
||||
# Use batch MIT control for arm (sends all commands, then collects responses)
|
||||
commands = {}
|
||||
for motor_name, position_degrees in goal_pos.items():
|
||||
idx = motor_index.get(motor_name, 0)
|
||||
# Use custom gains if provided, otherwise use config defaults
|
||||
if custom_kp is not None and motor_name in custom_kp:
|
||||
kp = custom_kp[motor_name]
|
||||
else:
|
||||
kp = (
|
||||
self.config.position_kp[idx]
|
||||
if isinstance(self.config.position_kp, list)
|
||||
else self.config.position_kp
|
||||
)
|
||||
if custom_kd is not None and motor_name in custom_kd:
|
||||
kd = custom_kd[motor_name]
|
||||
else:
|
||||
kd = (
|
||||
self.config.position_kd[idx]
|
||||
if isinstance(self.config.position_kd, list)
|
||||
else self.config.position_kd
|
||||
)
|
||||
commands[motor_name] = (kp, kd, position_degrees, 0.0, 0.0)
|
||||
|
||||
self.bus._mit_control_batch(commands)
|
||||
|
||||
goal_pos.update(arm_goal)
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
"""Disconnect from robot."""
|
||||
|
||||
# Disconnect CAN bus
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
self.gripper_bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
|
||||
# Disconnect cameras
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
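The send_action hunks above strip the ".pos" suffix from each action key and clamp every goal position to its configured joint limits before any command is sent. A minimal standalone sketch of that clamping step follows, assuming limits are given as (min, max) pairs in degrees. The helper name clip_to_joint_limits and the sample values are hypothetical and not part of lerobot.

```python
# Standalone sketch of the joint-limit clamping shown in the diff above.
# clip_to_joint_limits and the sample values below are hypothetical.


def clip_to_joint_limits(
    goal_pos: dict[str, float],
    joint_limits: dict[str, tuple[float, float]],
) -> dict[str, float]:
    """Return a copy of goal_pos with each value clamped to its (min, max) limits.

    Motors without an entry in joint_limits are passed through unchanged.
    """
    clipped: dict[str, float] = {}
    for motor_name, position in goal_pos.items():
        if motor_name in joint_limits:
            min_limit, max_limit = joint_limits[motor_name]
            position = max(min_limit, min(max_limit, position))
        clipped[motor_name] = position
    return clipped


# Action keys carry a ".pos" suffix; strip it to recover motor names.
action = {"joint_1.pos": 120.0, "joint_2.pos": -10.0, "gripper.pos": 50.0}
goal_pos = {
    key.removesuffix(".pos"): val
    for key, val in action.items()
    if key.endswith(".pos")
}

# Assumed limits in degrees; "gripper" deliberately has no entry.
limits = {"joint_1": (-90.0, 90.0), "joint_2": (-90.0, 90.0)}

safe_goal = clip_to_joint_limits(goal_pos, limits)
print(safe_goal)
```

Returning a new dictionary rather than mutating goal_pos in place keeps the requested and the clamped values both available, which is convenient when logging how much a command was clipped.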
Binary file not shown.
@@ -1,630 +0,0 @@
|
||||
<?xml version='1.0' encoding='utf-8'?>
|
||||
<robot name="openarm">
|
||||
<link name="world" />
|
||||
<joint name="openarm_body_world_joint" type="fixed">
|
||||
<parent link="world" />
|
||||
<child link="openarm_body_link0" />
|
||||
<origin rpy="0 0 0" xyz="0 0 0" />
|
||||
</joint>
|
||||
<link name="openarm_body_link0">
|
||||
<visual name="openarm_body_link0_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/body/v10/visual/body_link0.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_body_link0_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/body/v10/collision/body_link0_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
|
||||
<mass value="13.89" />
|
||||
<inertia ixx="1.653" ixy="0.0" ixz="0.0" iyy="1.653" iyz="0.0" izz="0.051" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_openarm_body_link0_joint" type="fixed">
|
||||
<parent link="openarm_body_link0" />
|
||||
<child link="openarm_left_link0" />
|
||||
<origin rpy="-1.5708 0 0" xyz="0.0 0.031 0.698" />
|
||||
</joint>
|
||||
<link name="openarm_left_link0">
|
||||
<visual name="openarm_left_link0_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link0.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_link0_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link0_symp.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0009483362816297526 -0.0001580207020448382 0.03076860287587199" />
|
||||
<mass value="1.1432284943239561" />
|
||||
<inertia ixx="0.001128" ixy="-4e-06" ixz="-3.3e-05" iyy="0.000962" iyz="-7e-06" izz="0.00147" />
|
||||
</inertial>
|
||||
</link>
|
||||
<link name="openarm_left_link1">
|
||||
<visual name="openarm_left_link1_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 0.0 -0.0625" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link1.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_link1_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 0.0 -0.0625" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link1_symp.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0011467657911800769 -3.319987657026362e-05 0.05395284380736254" />
|
||||
<mass value="1.1416684646202298" />
|
||||
<inertia ixx="0.001567" ixy="-1e-06" ixz="-2.9e-05" iyy="0.001273" iyz="1e-06" izz="0.001016" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_joint1" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="0.0 0.0 0.0625" />
|
||||
<parent link="openarm_left_link0" />
|
||||
<child link="openarm_left_link1" />
|
||||
<axis xyz="0 0 1" />
|
||||
<limit effort="40" lower="-3.490659" upper="1.3962629999999998" velocity="16.754666" />
|
||||
</joint>
|
||||
<link name="openarm_left_link2">
|
||||
<visual name="openarm_left_link2_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0301 0.0 -0.1225" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link2.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_link2_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0301 0.0 -0.1225" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link2_symp.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.00839629182351943 2.0145102027597523e-08 0.03256649300522363" />
|
||||
<mass value="0.2775092746011571" />
|
||||
<inertia ixx="0.000359" ixy="1e-06" ixz="-0.000109" iyy="0.000376" iyz="1e-06" izz="0.000232" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_joint2" type="revolute">
|
||||
<origin rpy="-1.57079632679 0 0" xyz="-0.0301 0.0 0.06" />
|
||||
<parent link="openarm_left_link1" />
|
||||
<child link="openarm_left_link2" />
|
||||
<axis xyz="-1 0 0" />
|
||||
<limit effort="40" lower="-3.3161253267948965" upper="0.17453267320510335" velocity="16.754666" />
|
||||
</joint>
|
||||
<link name="openarm_left_link3">
|
||||
<visual name="openarm_left_link3_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.18875" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link3.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_link3_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.18875" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link3_symp.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.002104752099628911 -0.0005549085042607548 0.09047470545721961" />
|
||||
<mass value="1.073863338202347" />
|
||||
<inertia ixx="0.004372" ixy="1e-06" ixz="1.1e-05" iyy="0.004319" iyz="-3.6e-05" izz="0.000661" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_joint3" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="0.0301 0.0 0.06625" />
|
||||
<parent link="openarm_left_link2" />
|
||||
<child link="openarm_left_link3" />
|
||||
<axis xyz="0 0 1" />
|
||||
<limit effort="27" lower="-1.570796" upper="1.570796" velocity="5.445426" />
|
||||
</joint>
|
||||
<link name="openarm_left_link4">
|
||||
<visual name="openarm_left_link4_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0315 -0.3425" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link4.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_link4_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0315 -0.3425" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link4_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0029006831074562967 -0.03030575826634669 0.06339637422196209" />
|
||||
<mass value="0.6348534566833373" />
|
||||
<inertia ixx="0.000623" ixy="-1e-06" ixz="-1.9e-05" iyy="0.000511" iyz="3.8e-05" izz="0.000334" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_joint4" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="-0.0 0.0315 0.15375" />
|
||||
<parent link="openarm_left_link3" />
|
||||
<child link="openarm_left_link4" />
|
||||
<axis xyz="0 1 0" />
|
||||
<limit effort="27" lower="0.0" upper="2.443461" velocity="5.445426" />
|
||||
</joint>
|
||||
<link name="openarm_left_link5">
|
||||
<visual name="openarm_left_link5_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.438" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link5.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_link5_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.438" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link5_symp.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.003049665024221911 -0.0008866902457326625 0.043079803024980934" />
|
||||
<mass value="0.6156588026168502" />
|
||||
<inertia ixx="0.000423" ixy="-8e-06" ixz="6e-06" iyy="0.000445" iyz="-6e-06" izz="0.000324" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_joint5" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="0.0 -0.0315 0.0955" />
|
||||
<parent link="openarm_left_link4" />
|
||||
<child link="openarm_left_link5" />
|
||||
<axis xyz="0 0 1" />
|
||||
<limit effort="7" lower="-1.570796" upper="1.570796" velocity="20.943946" />
|
||||
</joint>
|
||||
<link name="openarm_left_link6">
|
||||
<visual name="openarm_left_link6_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0375 -0.0 -0.5585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link6.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_link6_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0375 -0.0 -0.5585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link6_symp.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.037136587005447405 -0.00033230528343419053 -9.498374522309838e-05" />
|
||||
<mass value="0.475202773187987" />
|
||||
<inertia ixx="0.000143" ixy="1e-06" ixz="1e-06" iyy="0.000157" iyz="1e-06" izz="0.000159" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_joint6" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="0.0375 0.0 0.1205" />
|
||||
<parent link="openarm_left_link5" />
|
||||
<child link="openarm_left_link6" />
|
||||
<axis xyz="1 0 0" />
|
||||
<limit effort="7" lower="-0.785398" upper="0.785398" velocity="20.943946" />
|
||||
</joint>
|
||||
<link name="openarm_left_link7">
|
||||
<visual name="openarm_left_link7_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0 -0.5585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link7.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_link7_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0 -0.5585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link7_symp.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="6.875510271106056e-05 -0.01266175250761268 0.06951945409987448" />
|
||||
<mass value="0.4659771327380578" />
|
||||
<inertia ixx="0.000639" ixy="1e-06" ixz="1e-06" iyy="0.000497" iyz="8.9e-05" izz="0.000342" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_joint7" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="-0.0375 0.0 0.0" />
|
||||
<parent link="openarm_left_link6" />
|
||||
<child link="openarm_left_link7" />
|
||||
<axis xyz="0 -1 0" />
|
||||
<limit effort="7" lower="-1.570796" upper="1.570796" velocity="20.943946" />
|
||||
</joint>
|
||||
<joint name="openarm_right_openarm_body_link0_joint" type="fixed">
|
||||
<parent link="openarm_body_link0" />
|
||||
<child link="openarm_right_link0" />
|
||||
<origin rpy="1.5708 0 0" xyz="0.0 -0.031 0.698" />
|
||||
</joint>
|
||||
<link name="openarm_right_link0">
|
||||
<visual name="openarm_right_link0_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link0.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_link0_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 0.0" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link0_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0009483362816297526 0.0001580207020448382 0.03076860287587199" />
|
||||
<mass value="1.1432284943239561" />
|
||||
<inertia ixx="0.001128" ixy="-4e-06" ixz="-3.3e-05" iyy="0.000962" iyz="-7e-06" izz="0.00147" />
|
||||
</inertial>
|
||||
</link>
|
||||
<link name="openarm_right_link1">
|
||||
<visual name="openarm_right_link1_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 0.0 -0.0625" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link1.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_link1_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 0.0 -0.0625" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link1_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0011467657911800769 3.319987657026362e-05 0.05395284380736254" />
|
||||
<mass value="1.1416684646202298" />
|
||||
<inertia ixx="0.001567" ixy="-1e-06" ixz="-2.9e-05" iyy="0.001273" iyz="1e-06" izz="0.001016" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_joint1" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="0.0 0.0 0.0625" />
|
||||
<parent link="openarm_right_link0" />
|
||||
<child link="openarm_right_link1" />
|
||||
<axis xyz="0 0 1" />
|
||||
<limit effort="40" lower="-1.396263" upper="3.490659" velocity="16.754666" />
|
||||
</joint>
|
||||
<link name="openarm_right_link2">
|
||||
<visual name="openarm_right_link2_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0301 0.0 -0.1225" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link2.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_link2_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0301 0.0 -0.1225" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link2_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.00839629182351943 -2.0145102027597523e-08 0.03256649300522363" />
|
||||
<mass value="0.2775092746011571" />
|
||||
<inertia ixx="0.000359" ixy="1e-06" ixz="-0.000109" iyy="0.000376" iyz="1e-06" izz="0.000232" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_joint2" type="revolute">
|
||||
<origin rpy="1.57079632679 0 0" xyz="-0.0301 0.0 0.06" />
|
||||
<parent link="openarm_right_link1" />
|
||||
<child link="openarm_right_link2" />
|
||||
<axis xyz="-1 0 0" />
|
||||
<limit effort="40" lower="-0.17453267320510335" upper="3.3161253267948965" velocity="16.754666" />
|
||||
</joint>
|
||||
<link name="openarm_right_link3">
|
||||
<visual name="openarm_right_link3_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.18875" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link3.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_link3_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.18875" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link3_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.002104752099628911 0.0005549085042607548 0.09047470545721961" />
|
||||
<mass value="1.073863338202347" />
|
||||
<inertia ixx="0.004372" ixy="1e-06" ixz="1.1e-05" iyy="0.004319" iyz="-3.6e-05" izz="0.000661" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_joint3" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="0.0301 0.0 0.06625" />
|
||||
<parent link="openarm_right_link2" />
|
||||
<child link="openarm_right_link3" />
|
||||
<axis xyz="0 0 1" />
|
||||
<limit effort="27" lower="-1.570796" upper="1.570796" velocity="5.445426" />
|
||||
</joint>
|
||||
<link name="openarm_right_link4">
|
||||
<visual name="openarm_right_link4_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0315 -0.3425" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link4.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_link4_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0315 -0.3425" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link4_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0029006831074562967 -0.03030575826634669 0.06339637422196209" />
|
||||
<mass value="0.6348534566833373" />
|
||||
<inertia ixx="0.000623" ixy="-1e-06" ixz="-1.9e-05" iyy="0.000511" iyz="3.8e-05" izz="0.000334" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_joint4" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="-0.0 0.0315 0.15375" />
|
||||
<parent link="openarm_right_link3" />
|
||||
<child link="openarm_right_link4" />
|
||||
<axis xyz="0 1 0" />
|
||||
<limit effort="27" lower="0.0" upper="2.443461" velocity="5.445426" />
|
||||
</joint>
|
||||
<link name="openarm_right_link5">
|
||||
<visual name="openarm_right_link5_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.438" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link5.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_link5_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0 -0.0 -0.438" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link5_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.003049665024221911 0.0008866902457326625 0.043079803024980934" />
|
||||
<mass value="0.6156588026168502" />
|
||||
<inertia ixx="0.000423" ixy="-8e-06" ixz="6e-06" iyy="0.000445" iyz="-6e-06" izz="0.000324" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_joint5" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="0.0 -0.0315 0.0955" />
|
||||
<parent link="openarm_right_link4" />
|
||||
<child link="openarm_right_link5" />
|
||||
<axis xyz="0 0 1" />
|
||||
<limit effort="7" lower="-1.570796" upper="1.570796" velocity="20.943946" />
|
||||
</joint>
|
||||
<link name="openarm_right_link6">
|
||||
<visual name="openarm_right_link6_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0375 -0.0 -0.5585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link6.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_link6_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.0375 -0.0 -0.5585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link6_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="-0.037136587005447405 0.00033230528343419053 -9.498374522309838e-05" />
|
||||
<mass value="0.475202773187987" />
|
||||
<inertia ixx="0.000143" ixy="1e-06" ixz="1e-06" iyy="0.000157" iyz="1e-06" izz="0.000159" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_joint6" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="0.0375 0.0 0.1205" />
|
||||
<parent link="openarm_right_link5" />
|
||||
<child link="openarm_right_link6" />
|
||||
<axis xyz="1 0 0" />
|
||||
<limit effort="7" lower="-0.785398" upper="0.785398" velocity="20.943946" />
|
||||
</joint>
|
||||
<link name="openarm_right_link7">
|
||||
<visual name="openarm_right_link7_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0 -0.5585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/visual/link7.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_link7_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.0 -0.5585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/arm/v10/collision/link7_symp.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0.0 0.0 0.0" xyz="6.875510271106056e-05 0.01266175250761268 0.06951945409987448" />
|
||||
<mass value="0.4659771327380578" />
|
||||
<inertia ixx="0.000639" ixy="1e-06" ixz="1e-06" iyy="0.000497" iyz="8.9e-05" izz="0.000342" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_joint7" type="revolute">
|
||||
<origin rpy="0 0 0" xyz="-0.0375 0.0 0.0" />
|
||||
<parent link="openarm_right_link6" />
|
||||
<child link="openarm_right_link7" />
|
||||
<axis xyz="0 1 0" />
|
||||
<limit effort="7" lower="-1.570796" upper="1.570796" velocity="20.943946" />
|
||||
</joint>
|
||||
<link name="openarm_left_hand">
|
||||
<visual name="openarm_left_hand_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 -0.6585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/visual/hand.dae" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_hand_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 -0.6585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/collision/hand.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0 0 0" xyz="0.0 0.002 0.03" />
|
||||
<mass value="0.35" />
|
||||
<inertia ixx="0.0002473" ixy="1e-06" ixz="1e-06" iyy="1.763e-05" iyz="1e-06" izz="0.0002521" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="left_openarm_hand_joint" type="fixed">
|
||||
<parent link="openarm_left_link7" />
|
||||
<child link="openarm_left_hand" />
|
||||
<origin rpy="0 0 0" xyz="0 -0.0 0.1001" />
|
||||
</joint>
|
||||
<link name="openarm_left_hand_tcp">
|
||||
<inertial>
|
||||
<origin xyz="0 0 0" rpy="0 0 0" />
|
||||
<mass value="0.001" />
|
||||
<inertia ixx="0.000001" ixy="0.0" ixz="0.0" iyy="0.000001" iyz="0.0" izz="0.000001" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_hand_tcp_joint" type="fixed">
|
||||
<origin rpy="0 0 0" xyz="0 -0.0 0.08" />
|
||||
<parent link="openarm_left_hand" />
|
||||
<child link="openarm_left_hand_tcp" />
|
||||
</joint>
|
||||
<link name="openarm_left_left_finger">
|
||||
<visual name="openarm_left_left_finger_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.05 -0.673001" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/visual/finger.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_left_finger_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.05 -0.673001" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/collision/finger.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0 0 0" xyz="0.0064528 0.01702 0.0219685" />
|
||||
<mass value="0.03602545343277134" />
|
||||
<inertia ixx="2.3749999999999997e-06" ixy="1e-06" ixz="1e-06" iyy="2.3749999999999997e-06" iyz="1e-06" izz="7.5e-07" />
|
||||
</inertial>
|
||||
</link>
|
||||
<link name="openarm_left_right_finger">
|
||||
<visual name="openarm_left_right_finger_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.05 -0.673001" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/visual/finger.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_left_right_finger_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.05 -0.673001" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/collision/finger.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0 0 0" xyz="0.0064528 -0.01702 0.0219685" />
|
||||
<mass value="0.03602545343277134" />
|
||||
<inertia ixx="2.3749999999999997e-06" ixy="1e-06" ixz="1e-06" iyy="2.3749999999999997e-06" iyz="1e-06" izz="7.5e-07" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_left_finger_joint1" type="prismatic">
|
||||
<parent link="openarm_left_hand" />
|
||||
<child link="openarm_left_right_finger" />
|
||||
<origin rpy="0 0 0" xyz="0 -0.006 0.015" />
|
||||
<axis xyz="0 -1 0" />
|
||||
<limit effort="333" lower="0.0" upper="0.044" velocity="10.0" />
|
||||
</joint>
|
||||
<joint name="openarm_left_finger_joint2" type="prismatic">
|
||||
<parent link="openarm_left_hand" />
|
||||
<child link="openarm_left_left_finger" />
|
||||
<origin rpy="0 0 0" xyz="0 0.006 0.015" />
|
||||
<axis xyz="0 1 0" />
|
||||
<limit effort="333" lower="0.0" upper="0.044" velocity="10.0" />
|
||||
<mimic joint="openarm_left_finger_joint1" />
|
||||
</joint>
|
||||
<link name="openarm_right_hand">
|
||||
<visual name="openarm_right_hand_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 -0.6585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/visual/hand.dae" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_hand_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.0 -0.6585" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/collision/hand.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0 0 0" xyz="0.0 0.002 0.03" />
|
||||
<mass value="0.35" />
|
||||
<inertia ixx="0.0002473" ixy="1e-06" ixz="1e-06" iyy="1.763e-05" iyz="1e-06" izz="0.0002521" />
|
||||
</inertial>
|
||||
</link>
|
||||
<link name="openarm_right_ee_target">
|
||||
<inertial>
|
||||
<origin xyz="0 0 0" rpy="0 0 0" />
|
||||
<mass value="0.001" />
|
||||
<inertia ixx="0.000001" ixy="0.0" ixz="0.0" iyy="0.000001" iyz="0.0" izz="0.000001" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_ee_target_joint" type="fixed">
|
||||
<parent link="openarm_right_link7" />
|
||||
<child link="openarm_right_ee_target" />
|
||||
<origin rpy="0 0 0" xyz="0 0.0 0.07" />
|
||||
</joint>
|
||||
<joint name="right_openarm_hand_joint" type="fixed">
|
||||
<parent link="openarm_right_link7" />
|
||||
<child link="openarm_right_hand" />
|
||||
<origin rpy="0 0 0" xyz="0 -0.0 0.1001" />
|
||||
</joint>
|
||||
<link name="openarm_right_hand_tcp">
|
||||
<inertial>
|
||||
<origin xyz="0 0 0" rpy="0 0 0" />
|
||||
<mass value="0.001" />
|
||||
<inertia ixx="0.000001" ixy="0.0" ixz="0.0" iyy="0.000001" iyz="0.0" izz="0.000001" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_hand_tcp_joint" type="fixed">
|
||||
<origin rpy="0 0 0" xyz="0 -0.0 0.08" />
|
||||
<parent link="openarm_right_hand" />
|
||||
<child link="openarm_right_hand_tcp" />
|
||||
</joint>
|
||||
<link name="openarm_right_left_finger">
|
||||
<visual name="openarm_right_left_finger_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.05 -0.673001" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/visual/finger.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_left_finger_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 -0.05 -0.673001" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/collision/finger.stl" scale="0.001 0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0 0 0" xyz="0.0064528 0.01702 0.0219685" />
|
||||
<mass value="0.03602545343277134" />
|
||||
<inertia ixx="2.3749999999999997e-06" ixy="1e-06" ixz="1e-06" iyy="2.3749999999999997e-06" iyz="1e-06" izz="7.5e-07" />
|
||||
</inertial>
|
||||
</link>
|
||||
<link name="openarm_right_right_finger">
|
||||
<visual name="openarm_right_right_finger_visual">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.05 -0.673001" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/visual/finger.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</visual>
|
||||
<collision name="openarm_right_right_finger_collision">
|
||||
<origin rpy="0.0 0.0 0.0" xyz="0.0 0.05 -0.673001" />
|
||||
<geometry>
|
||||
<mesh filename="./meshes/ee/openarm_hand/collision/finger.stl" scale="0.001 -0.001 0.001" />
|
||||
</geometry>
|
||||
</collision>
|
||||
<inertial>
|
||||
<origin rpy="0 0 0" xyz="0.0064528 -0.01702 0.0219685" />
|
||||
<mass value="0.03602545343277134" />
|
||||
<inertia ixx="2.3749999999999997e-06" ixy="1e-06" ixz="1e-06" iyy="2.3749999999999997e-06" iyz="1e-06" izz="7.5e-07" />
|
||||
</inertial>
|
||||
</link>
|
||||
<joint name="openarm_right_finger_joint1" type="prismatic">
|
||||
<parent link="openarm_right_hand" />
|
||||
<child link="openarm_right_right_finger" />
|
||||
<origin rpy="0 0 0" xyz="0 -0.006 0.015" />
|
||||
<axis xyz="0 -1 0" />
|
||||
<limit effort="333" lower="0.0" upper="0.044" velocity="10.0" />
|
||||
</joint>
|
||||
<joint name="openarm_right_finger_joint2" type="prismatic">
|
||||
<parent link="openarm_right_hand" />
|
||||
<child link="openarm_right_left_finger" />
|
||||
<origin rpy="0 0 0" xyz="0 0.006 0.015" />
|
||||
<axis xyz="0 1 0" />
|
||||
<limit effort="333" lower="0.0" upper="0.044" velocity="10.0" />
|
||||
<mimic joint="openarm_right_finger_joint1" />
|
||||
</joint>
|
||||
</robot>
|
||||
@@ -1,408 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<title>Dataset Replay — EE Frame Viewer</title>
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body { background: #0d1117; overflow: hidden; font-family: 'JetBrains Mono', monospace; color: #c9d1d9; }
|
||||
canvas { display: block; }
|
||||
|
||||
#panel {
|
||||
position: absolute; top: 14px; left: 14px;
|
||||
background: rgba(13,17,23,0.92); border: 1px solid #30363d;
|
||||
border-radius: 10px; padding: 16px 20px; z-index: 10;
|
||||
width: 340px; backdrop-filter: blur(8px);
|
||||
}
|
||||
#panel h2 { font-size: 14px; color: #58a6ff; margin-bottom: 10px; letter-spacing: 0.5px; }
|
||||
|
||||
.row { display: flex; align-items: center; gap: 8px; margin: 6px 0; font-size: 12px; }
|
||||
.row label { width: 70px; color: #8b949e; flex-shrink: 0; }
|
||||
.row .val { color: #f0f6fc; font-variant-numeric: tabular-nums; }
|
||||
|
||||
#transport {
|
||||
margin-top: 12px; display: flex; align-items: center; gap: 8px;
|
||||
}
|
||||
#transport button {
|
||||
background: #21262d; color: #c9d1d9; border: 1px solid #30363d;
|
||||
padding: 6px 14px; border-radius: 6px; cursor: pointer;
|
||||
font-family: inherit; font-size: 12px; transition: background 0.15s;
|
||||
}
|
||||
#transport button:hover { background: #30363d; }
|
||||
#transport button.active { background: #1f6feb; border-color: #1f6feb; color: #fff; }
|
||||
|
||||
#scrubber {
|
||||
width: 100%; margin-top: 8px;
|
||||
-webkit-appearance: none; appearance: none;
|
||||
height: 6px; border-radius: 3px; background: #21262d; outline: none;
|
||||
}
|
||||
#scrubber::-webkit-slider-thumb {
|
||||
-webkit-appearance: none; width: 14px; height: 14px;
|
||||
border-radius: 50%; background: #58a6ff; cursor: pointer;
|
||||
}
|
||||
|
||||
#speed-ctrl { margin-top: 6px; }
|
||||
#speed-ctrl select {
|
||||
background: #21262d; color: #c9d1d9; border: 1px solid #30363d;
|
||||
padding: 4px 8px; border-radius: 4px; font-family: inherit; font-size: 11px;
|
||||
}
|
||||
|
||||
#frame-counter {
|
||||
font-size: 11px; color: #8b949e; margin-top: 6px;
|
||||
font-variant-numeric: tabular-nums;
|
||||
}
|
||||
|
||||
.legend { display: flex; align-items: center; gap: 6px; margin: 3px 0; font-size: 11px; }
|
||||
.dot { width: 10px; height: 10px; border-radius: 50%; display: inline-block; }
|
||||
</style>
|
||||
<link href="https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;600&display=swap" rel="stylesheet">
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<div id="panel">
|
||||
<h2>DATASET REPLAY — EE FRAME</h2>
|
||||
<div style="font-size:11px;color:#8b949e;margin-bottom:8px;">glannuzel/grabette-dataset · episode 0</div>
|
||||
|
||||
<div class="legend"><span class="dot" style="background:#ff6b6b"></span> EE target (dataset)</div>
|
||||
<div class="legend"><span class="dot" style="background:#ffd43b"></span> Trajectory (past)</div>
|
||||
<div class="legend"><span class="dot" style="background:#30363d"></span> Trajectory (future)</div>
|
||||
|
||||
<div class="row"><label>x</label><span class="val" id="v-x">—</span></div>
|
||||
<div class="row"><label>y</label><span class="val" id="v-y">—</span></div>
|
||||
<div class="row"><label>z</label><span class="val" id="v-z">—</span></div>
|
||||
<div class="row"><label>ax</label><span class="val" id="v-ax">—</span></div>
|
||||
<div class="row"><label>ay</label><span class="val" id="v-ay">—</span></div>
|
||||
<div class="row"><label>az</label><span class="val" id="v-az">—</span></div>
|
||||
<div class="row"><label>gripper</label><span class="val" id="v-grip">—</span></div>
|
||||
|
||||
<div id="transport">
|
||||
<button id="btn-play" onclick="togglePlay()">▶ Play</button>
|
||||
<button onclick="stepFrame(-1)">◀</button>
|
||||
<button onclick="stepFrame(1)">▶</button>
|
||||
<button onclick="resetPlay()">⟳</button>
|
||||
</div>
|
||||
<input type="range" id="scrubber" min="0" max="1" value="0" step="1" />
|
||||
<div id="speed-ctrl">
|
||||
<label style="font-size:11px;color:#8b949e;">Speed:</label>
|
||||
<select id="speed-select" onchange="setSpeed(this.value)">
|
||||
<option value="0.25">0.25×</option>
|
||||
<option value="0.5">0.5×</option>
|
||||
<option value="1" selected>1×</option>
|
||||
<option value="2">2×</option>
|
||||
<option value="4">4×</option>
|
||||
</select>
|
||||
</div>
|
||||
<div id="frame-counter">Frame 0 / 0 · 0.00s</div>
|
||||
</div>
|
||||
|
||||
<script type="importmap">
|
||||
{
|
||||
"imports": {
|
||||
"three": "https://cdn.jsdelivr.net/npm/three@0.169.0/build/three.module.js",
|
||||
"three/examples/jsm/": "https://cdn.jsdelivr.net/npm/three@0.169.0/examples/jsm/"
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<script type="module">
|
||||
import * as THREE from 'three';
|
||||
import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls.js';
|
||||
import { STLLoader } from 'three/examples/jsm/loaders/STLLoader.js';
|
||||
|
||||
let trajectory = null;
|
||||
let currentFrame = 0;
|
||||
let playing = false;
|
||||
let speed = 1.0;
|
||||
let lastTime = 0;
|
||||
let accumulator = 0;
|
||||
|
||||
// Anchor: EE tip world position at zero-joint pose (in Y-up Three.js space)
|
||||
const eeAnchor = new THREE.Vector3();
|
||||
// Z-up → Y-up rotation (same as robotGroup): -90° around X
|
||||
const zUpToYUp = new THREE.Quaternion().setFromAxisAngle(new THREE.Vector3(1, 0, 0), -Math.PI / 2);
|
||||
|
||||
const scene = new THREE.Scene();
|
||||
scene.background = new THREE.Color(0x0d1117);
|
||||
|
||||
const camera = new THREE.PerspectiveCamera(50, window.innerWidth / window.innerHeight, 0.01, 100);
|
||||
|
||||
const renderer = new THREE.WebGLRenderer({ antialias: true });
|
||||
renderer.setSize(window.innerWidth, window.innerHeight);
|
||||
renderer.setPixelRatio(window.devicePixelRatio);
|
||||
renderer.shadowMap.enabled = true;
|
||||
document.body.appendChild(renderer.domElement);
|
||||
|
||||
const controls = new OrbitControls(camera, renderer.domElement);
|
||||
controls.enableDamping = true;
|
||||
controls.dampingFactor = 0.08;
|
||||
|
||||
scene.add(new THREE.AmbientLight(0xffffff, 0.8));
|
||||
const dirLight = new THREE.DirectionalLight(0xffffff, 1.4);
|
||||
dirLight.position.set(2, 4, 3);
|
||||
scene.add(dirLight);
|
||||
scene.add(new THREE.DirectionalLight(0x8899cc, 0.6).translateX(-2).translateY(1).translateZ(-3));
|
||||
scene.add(new THREE.DirectionalLight(0xffffff, 0.5).translateY(-1).translateZ(2));
|
||||
|
||||
const grid = new THREE.GridHelper(2, 20, 0x21262d, 0x161b22);
|
||||
scene.add(grid);
|
||||
scene.add(new THREE.AxesHelper(0.15));
|
||||
|
||||
// EE marker
|
||||
const eeMarker = new THREE.Mesh(
|
||||
new THREE.SphereGeometry(0.012, 20, 20),
|
||||
new THREE.MeshStandardMaterial({ color: 0xff6b6b, emissive: 0xff6b6b, emissiveIntensity: 0.7 })
|
||||
);
|
||||
scene.add(eeMarker);
|
||||
eeMarker.add(new THREE.AxesHelper(0.06));
|
||||
|
||||
// Trajectory lines
|
||||
const MAX_POINTS = 2000;
|
||||
const pastGeo = new THREE.BufferGeometry();
|
||||
pastGeo.setAttribute('position', new THREE.Float32BufferAttribute(new Float32Array(MAX_POINTS * 3), 3));
|
||||
const pastLine = new THREE.Line(pastGeo, new THREE.LineBasicMaterial({ color: 0xffd43b, linewidth: 2 }));
|
||||
scene.add(pastLine);
|
||||
|
||||
const futureGeo = new THREE.BufferGeometry();
|
||||
futureGeo.setAttribute('position', new THREE.Float32BufferAttribute(new Float32Array(MAX_POINTS * 3), 3));
|
||||
const futureLine = new THREE.Line(futureGeo, new THREE.LineBasicMaterial({ color: 0x30363d, linewidth: 1 }));
|
||||
scene.add(futureLine);
|
||||
|
||||
// URDF
|
||||
const stlLoader = new STLLoader();
|
||||
const robotGroup = new THREE.Group();
|
||||
// URDF is Z-up; Three.js is Y-up → rotate -90° around X
|
||||
robotGroup.rotation.x = -Math.PI / 2;
|
||||
scene.add(robotGroup);
|
||||
let urdfLinks = {};
|
||||
|
||||
function rotvecToQuat(ax, ay, az) {
|
||||
const angle = Math.sqrt(ax * ax + ay * ay + az * az);
|
||||
if (angle < 1e-8) return new THREE.Quaternion();
|
||||
return new THREE.Quaternion().setFromAxisAngle(
|
||||
new THREE.Vector3(ax / angle, ay / angle, az / angle), angle
|
||||
);
|
||||
}
|
||||
|
||||
async function loadURDF() {
|
||||
const resp = await fetch('./openarm_bimanual_pybullet.urdf');
|
||||
if (!resp.ok) throw new Error(`Failed to fetch URDF: HTTP ${resp.status}`);
const text = await resp.text();
|
||||
const xml = new DOMParser().parseFromString(text, 'text/xml');
|
||||
|
||||
const links = {};
|
||||
|
||||
for (const linkEl of xml.querySelectorAll('link')) {
|
||||
const name = linkEl.getAttribute('name');
|
||||
const group = new THREE.Group();
|
||||
group.name = name;
|
||||
|
||||
const visual = linkEl.querySelector('visual');
|
||||
if (visual) {
|
||||
const meshEl = visual.querySelector('mesh');
|
||||
const originEl = visual.querySelector('origin');
|
||||
if (meshEl) {
|
||||
const filename = meshEl.getAttribute('filename');
|
||||
const scaleStr = meshEl.getAttribute('scale');
|
||||
const sc = scaleStr ? scaleStr.split(' ').map(Number) : [1, 1, 1];
|
||||
let xyz = [0, 0, 0];
|
||||
if (originEl && originEl.getAttribute('xyz'))
|
||||
xyz = originEl.getAttribute('xyz').split(' ').map(Number);
|
||||
if (filename.endsWith('.stl')) {
|
||||
try {
|
||||
const geo = await new Promise((res, rej) =>
|
||||
stlLoader.load(filename, res, undefined, rej));
|
||||
const mesh = new THREE.Mesh(geo, new THREE.MeshStandardMaterial({
|
||||
color: 0x8899bb, metalness: 0.3, roughness: 0.5,
|
||||
}));
|
||||
mesh.scale.set(sc[0], sc[1], sc[2]);
|
||||
mesh.position.set(xyz[0], xyz[1], xyz[2]);
|
||||
group.add(mesh);
|
||||
} catch (e) { console.warn(`Skipping mesh ${filename}:`, e); }
|
||||
}
|
||||
}
|
||||
}
|
||||
links[name] = group;
|
||||
}
|
||||
|
||||
const rootLinks = new Set(Object.keys(links));
|
||||
|
||||
for (const jointEl of xml.querySelectorAll('joint')) {
|
||||
const parentName = jointEl.querySelector('parent').getAttribute('link');
|
||||
const childName = jointEl.querySelector('child').getAttribute('link');
|
||||
rootLinks.delete(childName);
|
||||
|
||||
const originEl = jointEl.querySelector('origin');
|
||||
let xyz = [0, 0, 0], rpy = [0, 0, 0];
|
||||
if (originEl) {
|
||||
if (originEl.getAttribute('xyz')) xyz = originEl.getAttribute('xyz').split(' ').map(Number);
|
||||
if (originEl.getAttribute('rpy')) rpy = originEl.getAttribute('rpy').split(' ').map(Number);
|
||||
}
|
||||
|
||||
const parent = links[parentName];
|
||||
const child = links[childName];
|
||||
if (!parent || !child) continue;
|
||||
|
||||
child.position.set(xyz[0], xyz[1], xyz[2]);
|
||||
if (rpy[0] || rpy[1] || rpy[2])
|
||||
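child.rotation.set(rpy[0], rpy[1], rpy[2], 'ZYX'); // URDF rpy is fixed-axis roll-pitch-yaw: R = Rz(yaw) * Ry(pitch) * Rx(roll), i.e. Euler order 'ZYX' in Three.js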
|
||||
parent.add(child);
|
||||
}
|
||||
|
||||
for (const n of rootLinks)
|
||||
if (links[n]) robotGroup.add(links[n]);
|
||||
|
||||
// EE target marker on the URDF
|
||||
const eeTargetLink = links['openarm_right_ee_target'];
|
||||
if (eeTargetLink) {
|
||||
eeTargetLink.add(new THREE.Mesh(
|
||||
new THREE.TorusGeometry(0.02, 0.002, 8, 32),
|
||||
new THREE.MeshStandardMaterial({ color: 0xffaa00, emissive: 0xffaa00, emissiveIntensity: 0.5 })
|
||||
));
|
||||
eeTargetLink.add(new THREE.AxesHelper(0.05));
|
||||
}
|
||||
|
||||
urdfLinks = links;
|
||||
}
|
||||
|
||||
async function loadTrajectory() {
|
||||
const resp = await fetch('./trajectory_ep0.json');
|
||||
if (!resp.ok) throw new Error(`Failed to fetch trajectory: HTTP ${resp.status}`);
trajectory = await resp.json();
|
||||
document.getElementById('scrubber').max = trajectory.num_frames - 1;
|
||||
document.getElementById('scrubber').value = 0;
|
||||
}
|
||||
|
||||
function computeOffset() {
|
||||
if (!trajectory || !urdfLinks['openarm_right_ee_target']) return;
|
||||
|
||||
robotGroup.updateMatrixWorld(true);
|
||||
const eeLink = urdfLinks['openarm_right_ee_target'];
|
||||
eeLink.getWorldPosition(eeAnchor);
|
||||
|
||||
controls.target.copy(eeAnchor);
|
||||
camera.position.set(eeAnchor.x + 0.8, eeAnchor.y + 0.3, eeAnchor.z + 0.0);
|
||||
controls.update();
|
||||
|
||||
updateFrame(0);
|
||||
}
|
||||
|
||||
function mapFramePos(f) {
|
||||
const f0 = trajectory.frames[0];
|
||||
const delta = new THREE.Vector3(f.x - f0.x, f.y - f0.y, f.z - f0.z);
|
||||
delta.applyQuaternion(zUpToYUp);
|
||||
return delta.add(eeAnchor);
|
||||
}
|
||||
|
||||
function updateFrame(idx) {
|
||||
if (!trajectory) return;
|
||||
currentFrame = Math.max(0, Math.min(idx, trajectory.num_frames - 1));
|
||||
|
||||
const f = trajectory.frames[currentFrame];
|
||||
const pos = mapFramePos(f);
|
||||
|
||||
eeMarker.position.copy(pos);
|
||||
// Orientation: rotate the dataset axis-angle into Y-up space
|
||||
const q = rotvecToQuat(f.ax, f.ay, f.az);
|
||||
eeMarker.quaternion.copy(zUpToYUp).multiply(q);
|
||||
|
||||
// Past trajectory
|
||||
const pastArr = pastGeo.attributes.position.array;
|
||||
let pi = 0;
|
||||
for (let i = 0; i <= currentFrame && i < MAX_POINTS; i++) {
|
||||
const p = mapFramePos(trajectory.frames[i]);
|
||||
pastArr[pi++] = p.x; pastArr[pi++] = p.y; pastArr[pi++] = p.z;
|
||||
}
|
||||
pastGeo.setDrawRange(0, Math.min(currentFrame + 1, MAX_POINTS));
|
||||
pastGeo.attributes.position.needsUpdate = true;
|
||||
|
||||
// Future trajectory
|
||||
const futArr = futureGeo.attributes.position.array;
|
||||
let fi = 0;
|
||||
for (let i = currentFrame; i < trajectory.num_frames && (i - currentFrame) < MAX_POINTS; i++) {
|
||||
const p = mapFramePos(trajectory.frames[i]);
|
||||
futArr[fi++] = p.x; futArr[fi++] = p.y; futArr[fi++] = p.z;
|
||||
}
|
||||
futureGeo.setDrawRange(0, Math.min(trajectory.num_frames - currentFrame, MAX_POINTS));
|
||||
futureGeo.attributes.position.needsUpdate = true;
|
||||
|
||||
// UI
|
||||
document.getElementById('v-x').textContent = pos.x.toFixed(4);
|
||||
document.getElementById('v-y').textContent = pos.y.toFixed(4);
|
||||
document.getElementById('v-z').textContent = pos.z.toFixed(4);
|
||||
document.getElementById('v-ax').textContent = f.ax.toFixed(4);
|
||||
document.getElementById('v-ay').textContent = f.ay.toFixed(4);
|
||||
document.getElementById('v-az').textContent = f.az.toFixed(4);
|
||||
document.getElementById('v-grip').textContent =
|
||||
`p=${f.proximal.toFixed(2)} d=${f.distal.toFixed(2)}`;
|
||||
|
||||
document.getElementById('scrubber').value = currentFrame;
|
||||
const timeS = (currentFrame / trajectory.fps).toFixed(2);
|
||||
document.getElementById('frame-counter').textContent =
|
||||
`Frame ${currentFrame} / ${trajectory.num_frames - 1} · ${timeS}s`;
|
||||
}
|
||||
|
||||
// Playback controls
|
||||
window.togglePlay = function() {
|
||||
playing = !playing;
|
||||
const btn = document.getElementById('btn-play');
|
||||
btn.textContent = playing ? '⏸ Pause' : '▶ Play';
|
||||
btn.classList.toggle('active', playing);
|
||||
if (playing) { lastTime = performance.now(); accumulator = 0; }
|
||||
};
|
||||
|
||||
window.stepFrame = function(delta) {
|
||||
playing = false;
|
||||
document.getElementById('btn-play').textContent = '▶ Play';
|
||||
document.getElementById('btn-play').classList.remove('active');
|
||||
updateFrame(currentFrame + delta);
|
||||
};
|
||||
|
||||
window.resetPlay = function() {
|
||||
playing = false;
|
||||
document.getElementById('btn-play').textContent = '▶ Play';
|
||||
document.getElementById('btn-play').classList.remove('active');
|
||||
updateFrame(0);
|
||||
};
|
||||
|
||||
window.setSpeed = function(v) { speed = parseFloat(v); };
|
||||
|
||||
document.getElementById('scrubber').addEventListener('input', (e) => {
|
||||
updateFrame(parseInt(e.target.value, 10));
|
||||
});
|
||||
|
||||
window.addEventListener('resize', () => {
|
||||
camera.aspect = window.innerWidth / window.innerHeight;
|
||||
camera.updateProjectionMatrix();
|
||||
renderer.setSize(window.innerWidth, window.innerHeight);
|
||||
});
|
||||
|
||||
function animate(now) {
|
||||
requestAnimationFrame(animate);
|
||||
controls.update();
|
||||
|
||||
if (playing && trajectory) {
|
||||
const dt = (now - lastTime) / 1000;
|
||||
lastTime = now;
|
||||
accumulator += dt * speed;
|
||||
const frameDuration = 1.0 / trajectory.fps;
|
||||
while (accumulator >= frameDuration) {
|
||||
accumulator -= frameDuration;
|
||||
if (currentFrame < trajectory.num_frames - 1) {
|
||||
updateFrame(currentFrame + 1);
|
||||
} else {
|
||||
playing = false;
|
||||
document.getElementById('btn-play').textContent = '▶ Play';
|
||||
document.getElementById('btn-play').classList.remove('active');
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
renderer.render(scene, camera);
|
||||
}
|
||||
requestAnimationFrame(animate);
|
||||
|
||||
Promise.all([loadURDF(), loadTrajectory()])
|
||||
.then(() => computeOffset())
|
||||
.catch(err => console.error(err));
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,311 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<title>OpenArm URDF Viewer</title>
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body { background: #1a1a2e; overflow: hidden; font-family: 'IBM Plex Mono', monospace; }
|
||||
canvas { display: block; }
|
||||
#info {
|
||||
position: absolute; top: 16px; left: 16px;
|
||||
color: #e0e0e0; font-size: 13px; line-height: 1.6;
|
||||
background: rgba(0,0,0,0.7); padding: 14px 18px; border-radius: 8px;
|
||||
border: 1px solid #333; max-width: 340px; z-index: 10;
|
||||
}
|
||||
#info h2 { font-size: 15px; color: #fff; margin-bottom: 8px; }
|
||||
.legend { display: flex; align-items: center; gap: 8px; margin: 4px 0; }
|
||||
.dot { width: 12px; height: 12px; border-radius: 50%; display: inline-block; flex-shrink: 0; }
|
||||
.dot-red { background: #ff4444; }
|
||||
.dot-green { background: #44ff44; }
|
||||
.dot-blue { background: #4488ff; }
|
||||
#frame-select { margin-top: 10px; }
|
||||
#frame-select button {
|
||||
background: #333; color: #e0e0e0; border: 1px solid #555;
|
||||
padding: 6px 10px; margin: 2px; border-radius: 4px; cursor: pointer;
|
||||
font-family: inherit; font-size: 12px;
|
||||
}
|
||||
#frame-select button:hover { background: #555; }
|
||||
#frame-select button.active { background: #4488ff; color: #fff; border-color: #4488ff; }
|
||||
#status { margin-top: 8px; font-size: 11px; color: #888; }
|
||||
</style>
|
||||
<link href="https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;600&display=swap" rel="stylesheet">
|
||||
</head>
|
||||
<body>
|
||||
|
||||
<div id="info">
|
||||
<h2>OpenArm Right Arm — EE Frame Options</h2>
|
||||
<div class="legend"><span class="dot dot-red"></span> openarm_right_link7 (wrist output)</div>
|
||||
<div class="legend"><span class="dot" style="background:#ffaa00"></span> openarm_right_ee_target (+7cm)</div>
|
||||
<div class="legend"><span class="dot dot-green"></span> openarm_right_hand (+10cm)</div>
|
||||
<div class="legend"><span class="dot dot-blue"></span> openarm_right_hand_tcp (+18cm)</div>
|
||||
<div id="frame-select">
|
||||
<button onclick="focusFrame('link7')" class="active">link7</button>
|
||||
<button onclick="focusFrame('ee_target')">ee_target</button>
|
||||
<button onclick="focusFrame('hand')">hand</button>
|
||||
<button onclick="focusFrame('tcp')">hand_tcp</button>
|
||||
</div>
|
||||
<div id="status">Loading URDF...</div>
|
||||
<p style="margin-top:8px;font-size:11px;color:#888;">Drag to orbit · Scroll to zoom · Right-drag to pan</p>
|
||||
</div>
|
||||
|
||||
<script type="importmap">
|
||||
{
|
||||
"imports": {
|
||||
"three": "https://cdn.jsdelivr.net/npm/three@0.169.0/build/three.module.js",
|
||||
"three/examples/jsm/": "https://cdn.jsdelivr.net/npm/three@0.169.0/examples/jsm/"
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<script type="module">
|
||||
import * as THREE from 'three';
|
||||
import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls.js';
|
||||
import { STLLoader } from 'three/examples/jsm/loaders/STLLoader.js';
|
||||
|
||||
const statusEl = document.getElementById('status');
|
||||
|
||||
const scene = new THREE.Scene();
|
||||
scene.background = new THREE.Color(0x1a1a2e);
|
||||
|
||||
const camera = new THREE.PerspectiveCamera(50, window.innerWidth / window.innerHeight, 0.01, 100);
|
||||
camera.position.set(0.8, 1.0, 1.8);
|
||||
|
||||
const renderer = new THREE.WebGLRenderer({ antialias: true });
|
||||
renderer.setSize(window.innerWidth, window.innerHeight);
|
||||
renderer.setPixelRatio(window.devicePixelRatio);
|
||||
renderer.shadowMap.enabled = true;
|
||||
document.body.appendChild(renderer.domElement);
|
||||
|
||||
const controls = new OrbitControls(camera, renderer.domElement);
|
||||
controls.target.set(0, 0, 0.9);
|
||||
controls.enableDamping = true;
|
||||
controls.dampingFactor = 0.08;
|
||||
controls.update();
|
||||
|
||||
// Lighting
|
||||
scene.add(new THREE.AmbientLight(0xffffff, 0.6));
|
||||
const dirLight = new THREE.DirectionalLight(0xffffff, 1.2);
|
||||
dirLight.position.set(3, 5, 4);
|
||||
scene.add(dirLight);
|
||||
scene.add(new THREE.DirectionalLight(0x8888ff, 0.4).translateX(-2).translateY(1).translateZ(-3));
|
||||
|
||||
// Ground grid
|
||||
scene.add(new THREE.GridHelper(4, 40, 0x333355, 0x222244));
|
||||
scene.add(new THREE.AxesHelper(0.3));
|
||||
|
||||
// Parse URDF manually — build the kinematic tree and load STL meshes
|
||||
const stlLoader = new STLLoader();
|
||||
const robotGroup = new THREE.Group();
|
||||
scene.add(robotGroup);
|
||||
|
||||
async function loadURDF() {
|
||||
const resp = await fetch('./openarm_bimanual_pybullet.urdf');
|
||||
const text = await resp.text();
|
||||
const parser = new DOMParser();
|
||||
const xml = parser.parseFromString(text, 'text/xml');
|
||||
|
||||
// Parse links and joints
|
||||
const links = {};
|
||||
const joints = [];
|
||||
|
||||
for (const linkEl of xml.querySelectorAll('link')) {
|
||||
const name = linkEl.getAttribute('name');
|
||||
const group = new THREE.Group();
|
||||
group.name = name;
|
||||
|
||||
// Try to load visual mesh
|
||||
const visual = linkEl.querySelector('visual');
|
||||
if (visual) {
|
||||
const meshEl = visual.querySelector('mesh');
|
||||
const originEl = visual.querySelector('origin');
|
||||
if (meshEl) {
|
||||
const filename = meshEl.getAttribute('filename');
|
||||
const scaleStr = meshEl.getAttribute('scale');
|
||||
const scale = scaleStr ? scaleStr.split(' ').map(Number) : [1, 1, 1];
|
||||
|
||||
let xyz = [0, 0, 0];
|
||||
if (originEl && originEl.getAttribute('xyz')) {
|
||||
xyz = originEl.getAttribute('xyz').trim().split(/\s+/).map(Number);
|
||||
}
|
||||
|
||||
if (filename.endsWith('.stl')) {
|
||||
try {
|
||||
const geo = await new Promise((resolve, reject) => {
|
||||
stlLoader.load(filename, resolve, undefined, reject);
|
||||
});
|
||||
const mat = new THREE.MeshStandardMaterial({
|
||||
color: 0x6688aa,
|
||||
metalness: 0.3,
|
||||
roughness: 0.6,
|
||||
transparent: true,
|
||||
opacity: 0.7,
|
||||
});
|
||||
const mesh = new THREE.Mesh(geo, mat);
|
||||
mesh.scale.set(scale[0], scale[1], scale[2]);
|
||||
mesh.position.set(xyz[0], xyz[1], xyz[2]);
|
||||
group.add(mesh);
|
||||
} catch (e) {
|
||||
// Mesh file not found, skip
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
links[name] = group;
|
||||
}
|
||||
|
||||
// Parse joints and build hierarchy
|
||||
for (const jointEl of xml.querySelectorAll('joint')) {
|
||||
const name = jointEl.getAttribute('name');
|
||||
const type = jointEl.getAttribute('type');
|
||||
const parentName = jointEl.querySelector('parent').getAttribute('link');
|
||||
const childName = jointEl.querySelector('child').getAttribute('link');
|
||||
const originEl = jointEl.querySelector('origin');
|
||||
|
||||
let xyz = [0, 0, 0];
|
||||
let rpy = [0, 0, 0];
|
||||
if (originEl) {
|
||||
if (originEl.getAttribute('xyz')) xyz = originEl.getAttribute('xyz').trim().split(/\s+/).map(Number);
|
||||
if (originEl.getAttribute('rpy')) rpy = originEl.getAttribute('rpy').trim().split(/\s+/).map(Number);
|
||||
}
|
||||
|
||||
joints.push({ name, type, parentName, childName, xyz, rpy });
|
||||
}
|
||||
|
||||
// Build tree
|
||||
const rootLinks = new Set(Object.keys(links));
|
||||
for (const j of joints) {
|
||||
rootLinks.delete(j.childName);
|
||||
}
|
||||
|
||||
for (const j of joints) {
|
||||
const parent = links[j.parentName];
|
||||
const child = links[j.childName];
|
||||
if (parent && child) {
|
||||
child.position.set(j.xyz[0], j.xyz[1], j.xyz[2]);
|
||||
if (j.rpy[0] || j.rpy[1] || j.rpy[2]) {
|
||||
child.rotation.set(j.rpy[0], j.rpy[1], j.rpy[2], 'ZYX'); // URDF rpy is fixed-axis roll-pitch-yaw: R = Rz * Ry * Rx
|
||||
}
|
||||
parent.add(child);
|
||||
}
|
||||
}
|
||||
|
||||
// Add root links to scene
|
||||
for (const name of rootLinks) {
|
||||
if (links[name]) robotGroup.add(links[name]);
|
||||
}
|
||||
|
||||
// Place markers at the four EE frame candidates
|
||||
const targets = {
|
||||
link7: 'openarm_right_link7',
|
||||
ee_target: 'openarm_right_ee_target',
|
||||
hand: 'openarm_right_hand',
|
||||
tcp: 'openarm_right_hand_tcp',
|
||||
};
|
||||
const colors = { link7: 0xff4444, ee_target: 0xffaa00, hand: 0x44ff44, tcp: 0x4488ff };
|
||||
const labels = { link7: 'link7', ee_target: 'ee_target', hand: 'hand', tcp: 'hand_tcp' };
|
||||
|
||||
for (const [key, linkName] of Object.entries(targets)) {
|
||||
const link = links[linkName];
|
||||
if (!link) continue;
|
||||
|
||||
// Marker sphere
|
||||
const sphere = new THREE.Mesh(
|
||||
new THREE.SphereGeometry(0.018, 16, 16),
|
||||
new THREE.MeshStandardMaterial({ color: colors[key], emissive: colors[key], emissiveIntensity: 0.6 })
|
||||
);
|
||||
link.add(sphere);
|
||||
|
||||
// Ring around sphere for visibility
|
||||
const ring = new THREE.Mesh(
|
||||
new THREE.TorusGeometry(0.03, 0.003, 8, 32),
|
||||
new THREE.MeshStandardMaterial({ color: colors[key], emissive: colors[key], emissiveIntensity: 0.4 })
|
||||
);
|
||||
link.add(ring);
|
||||
|
||||
// Axes helper
|
||||
link.add(new THREE.AxesHelper(0.08));
|
||||
|
||||
// Sprite label
|
||||
const canvas = document.createElement('canvas');
|
||||
canvas.width = 512; canvas.height = 80;
|
||||
const ctx = canvas.getContext('2d');
|
||||
ctx.font = 'bold 36px IBM Plex Mono, monospace';
|
||||
ctx.fillStyle = '#' + colors[key].toString(16).padStart(6, '0');
|
||||
ctx.fillText(labels[key], 4, 50);
|
||||
const tex = new THREE.CanvasTexture(canvas);
|
||||
const sprite = new THREE.Sprite(new THREE.SpriteMaterial({ map: tex, depthTest: false }));
|
||||
sprite.scale.set(0.3, 0.05, 1);
|
||||
sprite.position.set(0.06, 0.0, 0.03);
|
||||
link.add(sprite);
|
||||
}
|
||||
|
||||
// Dashed lines between markers (in world space)
|
||||
robotGroup.updateMatrixWorld(true);
|
||||
const positions = {};
|
||||
for (const [key, linkName] of Object.entries(targets)) {
|
||||
const link = links[linkName];
|
||||
if (link) {
|
||||
const wp = new THREE.Vector3();
|
||||
link.getWorldPosition(wp);
|
||||
positions[key] = wp;
|
||||
}
|
||||
}
|
||||
|
||||
function addDashedLine(from, to) {
|
||||
const geo = new THREE.BufferGeometry().setFromPoints([from, to]);
|
||||
const mat = new THREE.LineDashedMaterial({ color: 0xaaaaaa, dashSize: 0.012, gapSize: 0.008 });
|
||||
const line = new THREE.Line(geo, mat);
|
||||
line.computeLineDistances();
|
||||
scene.add(line);
|
||||
}
|
||||
if (positions.link7 && positions.hand) addDashedLine(positions.link7, positions.hand);
|
||||
if (positions.hand && positions.tcp) addDashedLine(positions.hand, positions.tcp);
|
||||
|
||||
// Store for focus buttons
|
||||
window._framePositions = positions;
|
||||
window._links = links;
|
||||
window._targets = targets;
|
||||
|
||||
// Focus on the hand area
|
||||
if (positions.hand) {
|
||||
controls.target.copy(positions.hand);
|
||||
camera.position.set(positions.hand.x + 0.5, positions.hand.y + 0.4, positions.hand.z + 0.5);
|
||||
controls.update();
|
||||
}
|
||||
|
||||
const meshCount = robotGroup.children.length;
|
||||
statusEl.textContent = `Loaded. Right arm chain visible with ${Object.keys(links).length} links.`;
|
||||
}
|
||||
|
||||
window.focusFrame = function(key) {
|
||||
const pos = window._framePositions?.[key];
|
||||
if (!pos) return;
|
||||
controls.target.copy(pos);
|
||||
camera.position.set(pos.x + 0.35, pos.y + 0.25, pos.z + 0.35);
|
||||
controls.update();
|
||||
document.querySelectorAll('#frame-select button').forEach(b => b.classList.remove('active'));
|
||||
window.event?.target?.classList.add('active'); // relies on the inline-handler event; no-op when called programmatically
|
||||
};
|
||||
|
||||
window.addEventListener('resize', () => {
|
||||
camera.aspect = window.innerWidth / window.innerHeight;
|
||||
camera.updateProjectionMatrix();
|
||||
renderer.setSize(window.innerWidth, window.innerHeight);
|
||||
});
|
||||
|
||||
function animate() {
|
||||
requestAnimationFrame(animate);
|
||||
controls.update();
|
||||
renderer.render(scene, camera);
|
||||
}
|
||||
animate();
|
||||
|
||||
loadURDF().catch(err => {
|
||||
statusEl.textContent = `Error: ${err.message}`;
|
||||
console.error(err);
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
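The viewer above places each child link using the joint's `<origin xyz="..." rpy="...">`. In URDF, `rpy` is fixed-axis roll, pitch, yaw (rotations about X, then Y, then Z of the parent frame), so the composed rotation is R = Rz(yaw) · Ry(pitch) · Rx(roll), and the axis order matters when mapping to any Euler-angle API. A minimal sketch of that composition, written in Python for consistency with the rest of the diff (the helper name `rpy_to_matrix` is hypothetical):

```python
import math


def rpy_to_matrix(roll, pitch, yaw):
    """Return the 3x3 rotation Rz(yaw) @ Ry(pitch) @ Rx(roll) as nested lists.

    This is the fixed-axis roll-pitch-yaw convention used by URDF <origin rpy>.
    """
    cr, sr = math.cos(roll), math.sin(roll)
    cp, sp = math.cos(pitch), math.sin(pitch)
    cy, sy = math.cos(yaw), math.sin(yaw)
    return [
        [cy * cp, cy * sp * sr - sy * cr, cy * sp * cr + sy * sr],
        [sy * cp, sy * sp * sr + cy * cr, sy * sp * cr - cy * sr],
        [-sp, cp * sr, cp * cr],
    ]


# Roll and yaw of 90 degrees each: the result is Rz @ Rx, not Rx @ Rz.
R = rpy_to_matrix(math.pi / 2, 0.0, math.pi / 2)
```

With a single non-zero angle the order is irrelevant; it only shows up once two or more of roll, pitch, and yaw are non-zero, which is where a mismatched Euler order would misplace links.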
@@ -255,16 +255,19 @@ class InverseKinematicsEEToJoints(RobotActionProcessorStep):
|
||||
"""
|
||||
Computes desired joint positions from a target end-effector pose using inverse kinematics (IK).
|
||||
|
||||
This step translates a Cartesian command (position and orientation of the end-effector) into
|
||||
the corresponding joint-space commands for each motor.
|
||||
|
||||
Attributes:
|
||||
kinematics: The robot's kinematic model for inverse kinematics.
|
||||
motor_names: Arm joint names for IK computation.
|
||||
gripper_names: Gripper joint name(s). ee.gripper_pos is written to all of them.
|
||||
motor_names: A list of motor names for which to compute joint positions.
|
||||
q_curr: Internal state storing the last joint positions, used as an initial guess for the IK solver.
|
||||
initial_guess_current_joints: If True, use the robot's current joint state as the IK guess.
|
||||
If False, use the solution from the previous step.
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
motor_names: list[str]
|
||||
gripper_names: list[str] = field(default_factory=lambda: ["gripper"])
|
||||
q_curr: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
initial_guess_current_joints: bool = True
|
||||
|
||||
@@ -275,73 +278,63 @@ class InverseKinematicsEEToJoints(RobotActionProcessorStep):
|
||||
wx = action.pop("ee.wx")
|
||||
wy = action.pop("ee.wy")
|
||||
wz = action.pop("ee.wz")
|
||||
gripper_pos = action.pop("ee.gripper_pos")
|
||||
|
||||
ee_keys = [x, y, z, wx, wy, wz]
|
||||
|
||||
if self.gripper_names:
|
||||
gripper_pos = action.pop("ee.gripper_pos")
|
||||
ee_keys.append(gripper_pos)
|
||||
if None in ee_keys:
|
||||
raise ValueError("Missing required end-effector pose components in action")
|
||||
if None in (x, y, z, wx, wy, wz, gripper_pos):
|
||||
raise ValueError(
|
||||
"Missing required end-effector pose components: ee.x, ee.y, ee.z, ee.wx, ee.wy, ee.wz, ee.gripper_pos must all be present in action"
|
||||
)
|
||||
|
||||
observation = self.transition.get(TransitionKey.OBSERVATION).copy()
|
||||
if observation is None:
|
||||
raise ValueError("Joints observation is required for computing robot kinematics")
|
||||
raise ValueError("Joints observation is require for computing robot kinematics")
|
||||
|
||||
q_raw = np.array(
|
||||
[
|
||||
float(v)
|
||||
for k, v in observation.items()
|
||||
if isinstance(k, str) and k.endswith(".pos") and k.removesuffix(".pos") in self.motor_names
|
||||
],
|
||||
[float(v) for k, v in observation.items() if isinstance(k, str) and k.endswith(".pos")],
|
||||
dtype=float,
|
||||
)
|
||||
if q_raw is None:
|
||||
raise ValueError("Joints observation is require for computing robot kinematics")
|
||||
|
||||
if self.initial_guess_current_joints:
|
||||
if self.initial_guess_current_joints: # Use current joints as initial guess
|
||||
self.q_curr = q_raw
|
||||
else:
|
||||
else: # Use previous ik solution as initial guess
|
||||
if self.q_curr is None:
|
||||
self.q_curr = q_raw
|
||||
|
||||
# Build desired 4x4 transform from pos + rotvec (twist)
|
||||
t_des = np.eye(4, dtype=float)
|
||||
t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
|
||||
t_des[:3, 3] = [x, y, z]
|
||||
|
||||
# Compute inverse kinematics
|
||||
q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des)
|
||||
self.q_curr = q_target
|
||||
|
||||
# TODO: This is sensitive to the order of motor_names = q_target mapping
|
||||
for i, name in enumerate(self.motor_names):
|
||||
action[f"{name}.pos"] = float(q_target[i])
|
||||
if name != "gripper":
|
||||
action[f"{name}.pos"] = float(q_target[i])
|
||||
else:
|
||||
action["gripper.pos"] = float(gripper_pos)
|
||||
|
||||
if self.gripper_names:
|
||||
for gname in self.gripper_names:
|
||||
action[f"{gname}.pos"] = float(gripper_pos)
|
||||
|
||||
# When gripper_names is empty, gripper keys (e.g. proximal.pos, distal.pos)
|
||||
# are already in the action dict as absolute positions — left untouched.
|
||||
return action
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
ee_feats = ["x", "y", "z", "wx", "wy", "wz"]
|
||||
if self.gripper_names:
|
||||
ee_feats.append("gripper_pos")
|
||||
for feat in ee_feats:
|
||||
for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]:
|
||||
features[PipelineFeatureType.ACTION].pop(f"ee.{feat}", None)
|
||||
|
||||
for name in self.motor_names:
|
||||
features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
for name in self.gripper_names:
|
||||
features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
def reset(self):
|
||||
"""Resets the initial guess for the IK solver."""
|
||||
self.q_curr = None
|
||||
|
||||
|
||||
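The IK step in this hunk builds `t_des` from a position and a rotation vector with SciPy's `Rotation.from_rotvec(...).as_matrix()`. As a rough, dependency-free sketch of what that construction computes, Rodrigues' formula can be written out directly. The helper names `rotvec_to_matrix` and `pose_to_transform` are hypothetical and not part of LeRobot; the real code should keep using SciPy.

```python
import math


def rotvec_to_matrix(wx, wy, wz):
    """Rodrigues' formula: rotation vector (unit axis * angle in radians) -> 3x3 matrix."""
    theta = math.sqrt(wx * wx + wy * wy + wz * wz)
    if theta < 1e-12:
        # Zero rotation vector means the identity rotation.
        return [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    kx, ky, kz = wx / theta, wy / theta, wz / theta
    c, s = math.cos(theta), math.sin(theta)
    v = 1.0 - c
    return [
        [c + kx * kx * v, kx * ky * v - kz * s, kx * kz * v + ky * s],
        [ky * kx * v + kz * s, c + ky * ky * v, ky * kz * v - kx * s],
        [kz * kx * v - ky * s, kz * ky * v + kx * s, c + kz * kz * v],
    ]


def pose_to_transform(x, y, z, wx, wy, wz):
    """Assemble the 4x4 homogeneous transform [R | p; 0 0 0 1] as nested lists."""
    r = rotvec_to_matrix(wx, wy, wz)
    return [
        r[0] + [x],
        r[1] + [y],
        r[2] + [z],
        [0.0, 0.0, 0.0, 1.0],
    ]


# A 90 degree rotation about Z combined with a translation.
t_des = pose_to_transform(0.1, 0.2, 0.3, 0.0, 0.0, math.pi / 2)
```

The `ee.wx`, `ee.wy`, `ee.wz` action keys are therefore an axis-angle encoding: the vector's direction is the rotation axis and its norm is the angle in radians.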
@@ -409,39 +402,24 @@ class GripperVelocityToJoint(RobotActionProcessorStep):
|
||||
|
||||
|
||||
def compute_forward_kinematics_joints_to_ee(
|
||||
joints: dict[str, Any],
|
||||
kinematics: RobotKinematics,
|
||||
motor_names: list[str],
|
||||
gripper_names: list[str] | None = None,
|
||||
joints: dict[str, Any], kinematics: RobotKinematics, motor_names: list[str]
|
||||
) -> dict[str, Any]:
|
||||
if gripper_names is None:
|
||||
gripper_names = ["gripper"]
|
||||
|
||||
motor_joint_values = [joints[f"{n}.pos"] for n in motor_names]
|
||||
|
||||
q = np.array(motor_joint_values, dtype=float)
|
||||
t = kinematics.forward_kinematics(q)
|
||||
pos = t[:3, 3]
|
||||
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
|
||||
|
||||
gripper_pos = joints["gripper.pos"]
|
||||
for n in motor_names:
|
||||
joints.pop(f"{n}.pos")
|
||||
|
||||
joints["ee.x"] = float(pos[0])
|
||||
joints["ee.y"] = float(pos[1])
|
||||
joints["ee.z"] = float(pos[2])
|
||||
joints["ee.wx"] = float(tw[0])
|
||||
joints["ee.wy"] = float(tw[1])
|
||||
joints["ee.wz"] = float(tw[2])
|
||||
|
||||
# When gripper_names is non-empty, fold them into ee.gripper_pos (e.g. SO100).
|
||||
# When empty, gripper joints pass through as-is (absolute position control).
|
||||
if gripper_names:
|
||||
gripper_pos = joints[f"{gripper_names[0]}.pos"]
|
||||
for n in gripper_names:
|
||||
joints.pop(f"{n}.pos", None)
|
||||
joints["ee.gripper_pos"] = float(gripper_pos)
|
||||
|
||||
joints["ee.gripper_pos"] = float(gripper_pos)
|
||||
return joints
|
||||
|
||||
|
||||
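The `gripper_names` variant of `compute_forward_kinematics_joints_to_ee` in this hunk has two modes: a non-empty list folds the gripper joints into a single `ee.gripper_pos` value, and an empty list leaves them untouched. A minimal sketch of just that branching, assuming plain dicts and a hypothetical helper name `fold_gripper` (the FK computation itself is omitted):

```python
def fold_gripper(joints, gripper_names):
    """Mimic the gripper handling of the gripper_names variant on a copy of `joints`.

    Non-empty gripper_names: read the first gripper joint, drop every listed
    gripper key, and write the value to "ee.gripper_pos".
    Empty gripper_names: gripper keys pass through unchanged.
    """
    out = dict(joints)
    if gripper_names:
        gripper_pos = out[f"{gripper_names[0]}.pos"]
        for name in gripper_names:
            out.pop(f"{name}.pos", None)
        out["ee.gripper_pos"] = float(gripper_pos)
    return out


# Single-gripper robot: "gripper.pos" becomes "ee.gripper_pos".
folded = fold_gripper({"ee.x": 1.0, "gripper.pos": 0.5}, ["gripper"])

# Multi-joint hand with absolute position control: nothing is folded.
passthrough = fold_gripper({"proximal.pos": 0.1, "distal.pos": 0.2}, [])
```

Note that only the first name in `gripper_names` supplies the value; any further names are removed without contributing, which matches the hunk as shown.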
@@ -451,33 +429,27 @@ class ForwardKinematicsJointsToEEObservation(ObservationProcessorStep):
|
||||
"""
|
||||
Computes the end-effector pose from joint positions using forward kinematics (FK).
|
||||
|
||||
This step is typically used to add the robot's Cartesian pose to the observation space,
|
||||
which can be useful for visualization or as an input to a policy.
|
||||
|
||||
Attributes:
|
||||
kinematics: The robot's kinematic model.
|
||||
motor_names: Arm joint names used for FK computation.
|
||||
gripper_names: Gripper joint name(s) to fold into ee.gripper_pos.
|
||||
Empty list means gripper joints pass through as absolute positions.
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
motor_names: list[str]
|
||||
gripper_names: list[str] = field(default_factory=lambda: ["gripper"])
|
||||
|
||||
def observation(self, observation: RobotObservation) -> RobotObservation:
|
||||
return compute_forward_kinematics_joints_to_ee(
|
||||
observation, self.kinematics, self.motor_names, self.gripper_names
|
||||
)
|
||||
return compute_forward_kinematics_joints_to_ee(observation, self.kinematics, self.motor_names)
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# We only use the ee pose in the dataset, so we don't need the joint positions
|
||||
for n in self.motor_names:
|
||||
features[PipelineFeatureType.OBSERVATION].pop(f"{n}.pos", None)
|
||||
ee_keys = ["x", "y", "z", "wx", "wy", "wz"]
|
||||
if self.gripper_names:
|
||||
for n in self.gripper_names:
|
||||
features[PipelineFeatureType.OBSERVATION].pop(f"{n}.pos", None)
|
||||
ee_keys.append("gripper_pos")
|
||||
for k in ee_keys:
|
||||
# We specify the dataset features of this step that we want to be stored in the dataset
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]:
|
||||
features[PipelineFeatureType.OBSERVATION][f"ee.{k}"] = PolicyFeature(
|
||||
type=FeatureType.STATE, shape=(1,)
|
||||
)
|
||||
@@ -490,33 +462,27 @@ class ForwardKinematicsJointsToEEAction(RobotActionProcessorStep):
|
||||
"""
|
||||
Computes the end-effector pose from joint positions using forward kinematics (FK).
|
||||
|
||||
This step is typically used to add the robot's Cartesian pose to the observation space,
|
||||
which can be useful for visualization or as an input to a policy.
|
||||
|
||||
Attributes:
|
||||
kinematics: The robot's kinematic model.
|
||||
motor_names: Arm joint names used for FK computation.
|
||||
gripper_names: Gripper joint name(s) to fold into ee.gripper_pos.
|
||||
Empty list means gripper joints pass through as absolute positions.
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
motor_names: list[str]
|
||||
gripper_names: list[str] = field(default_factory=lambda: ["gripper"])
|
||||
|
||||
def action(self, action: RobotAction) -> RobotAction:
|
||||
return compute_forward_kinematics_joints_to_ee(
|
||||
action, self.kinematics, self.motor_names, self.gripper_names
|
||||
)
|
||||
return compute_forward_kinematics_joints_to_ee(action, self.kinematics, self.motor_names)
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# We only use the ee pose in the dataset, so we don't need the joint positions
|
||||
for n in self.motor_names:
|
||||
features[PipelineFeatureType.ACTION].pop(f"{n}.pos", None)
|
||||
ee_keys = ["x", "y", "z", "wx", "wy", "wz"]
|
||||
if self.gripper_names:
|
||||
for n in self.gripper_names:
|
||||
features[PipelineFeatureType.ACTION].pop(f"{n}.pos", None)
|
||||
ee_keys.append("gripper_pos")
|
||||
for k in ee_keys:
|
||||
# We specify the dataset features of this step that we want to be stored in the dataset
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]:
|
||||
features[PipelineFeatureType.ACTION][f"ee.{k}"] = PolicyFeature(
|
||||
type=FeatureType.STATE, shape=(1,)
|
||||
)
|
||||
@@ -528,14 +494,13 @@ class ForwardKinematicsJointsToEEAction(RobotActionProcessorStep):
|
||||
class ForwardKinematicsJointsToEE(ProcessorStep):
|
||||
kinematics: RobotKinematics
|
||||
motor_names: list[str]
|
||||
gripper_names: list[str] = field(default_factory=lambda: ["gripper"])
|
||||
|
||||
def __post_init__(self):
|
||||
self.joints_to_ee_action_processor = ForwardKinematicsJointsToEEAction(
|
||||
kinematics=self.kinematics, motor_names=self.motor_names, gripper_names=self.gripper_names
|
||||
kinematics=self.kinematics, motor_names=self.motor_names
|
||||
)
|
||||
self.joints_to_ee_observation_processor = ForwardKinematicsJointsToEEObservation(
|
||||
kinematics=self.kinematics, motor_names=self.motor_names, gripper_names=self.gripper_names
|
||||
kinematics=self.kinematics, motor_names=self.motor_names
|
||||
)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
@@ -559,13 +524,13 @@ class ForwardKinematicsJointsToEE(ProcessorStep):
|
||||
@dataclass
|
||||
class InverseKinematicsRLStep(ProcessorStep):
|
||||
"""
|
||||
IK step for the RL pipeline. Same logic as InverseKinematicsEEToJoints but
|
||||
operates on EnvTransition directly and stores the IK solution.
|
||||
Computes desired joint positions from a target end-effector pose using inverse kinematics (IK).
|
||||
|
||||
This is modified from the InverseKinematicsEEToJoints step to be used in the RL pipeline.
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
motor_names: list[str]
|
||||
gripper_names: list[str] = field(default_factory=lambda: ["gripper"])
|
||||
q_curr: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
initial_guess_current_joints: bool = True
|
||||
|
||||
@@ -573,7 +538,7 @@ class InverseKinematicsRLStep(ProcessorStep):
|
||||
new_transition = dict(transition)
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
if action is None:
|
||||
raise ValueError("Action is required for InverseKinematicsRLStep")
|
||||
raise ValueError("Action is required for InverseKinematicsEEToJoints")
|
||||
action = dict(action)
|
||||
|
||||
x = action.pop("ee.x")
|
||||
@@ -582,46 +547,45 @@ class InverseKinematicsRLStep(ProcessorStep):
|
||||
wx = action.pop("ee.wx")
|
||||
wy = action.pop("ee.wy")
|
||||
wz = action.pop("ee.wz")
|
||||
gripper_pos = action.pop("ee.gripper_pos")
|
||||
|
||||
ee_keys = [x, y, z, wx, wy, wz]
|
||||
if self.gripper_names:
|
||||
gripper_pos = action.pop("ee.gripper_pos")
|
||||
ee_keys.append(gripper_pos)
|
||||
if None in ee_keys:
|
||||
raise ValueError("Missing required end-effector pose components in action")
|
||||
if None in (x, y, z, wx, wy, wz, gripper_pos):
|
||||
raise ValueError(
|
||||
"Missing required end-effector pose components: ee.x, ee.y, ee.z, ee.wx, ee.wy, ee.wz, ee.gripper_pos must all be present in action"
|
||||
)
|
||||
|
||||
observation = new_transition.get(TransitionKey.OBSERVATION).copy()
|
||||
if observation is None:
|
||||
raise ValueError("Joints observation is required for computing robot kinematics")
|
||||
raise ValueError("Joints observation is require for computing robot kinematics")
|
||||
|
||||
q_raw = np.array(
|
||||
[
|
||||
float(v)
|
||||
for k, v in observation.items()
|
||||
if isinstance(k, str) and k.endswith(".pos") and k.removesuffix(".pos") in self.motor_names
|
||||
],
|
||||
[float(v) for k, v in observation.items() if isinstance(k, str) and k.endswith(".pos")],
|
||||
dtype=float,
|
||||
)
|
||||
if q_raw is None:
|
||||
raise ValueError("Joints observation is require for computing robot kinematics")
|
||||
|
||||
if self.initial_guess_current_joints:
|
||||
if self.initial_guess_current_joints: # Use current joints as initial guess
|
||||
self.q_curr = q_raw
|
||||
else:
|
||||
else: # Use previous ik solution as initial guess
|
||||
if self.q_curr is None:
|
||||
self.q_curr = q_raw
|
||||
|
||||
# Build desired 4x4 transform from pos + rotvec (twist)
|
||||
t_des = np.eye(4, dtype=float)
|
||||
t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix()
|
||||
t_des[:3, 3] = [x, y, z]
|
||||
|
||||
# Compute inverse kinematics
|
||||
q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des)
|
||||
self.q_curr = q_target
|
||||
|
||||
# TODO: This is sensitive to the order of motor_names = q_target mapping
|
||||
for i, name in enumerate(self.motor_names):
|
||||
action[f"{name}.pos"] = float(q_target[i])
|
||||
|
||||
if self.gripper_names:
|
||||
for gname in self.gripper_names:
|
||||
action[f"{gname}.pos"] = float(gripper_pos)
|
||||
if name != "gripper":
|
||||
action[f"{name}.pos"] = float(q_target[i])
|
||||
else:
|
||||
action["gripper.pos"] = float(gripper_pos)
|
||||
|
||||
new_transition[TransitionKey.ACTION] = action
|
||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
@@ -632,22 +596,16 @@ class InverseKinematicsRLStep(ProcessorStep):
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
ee_feats = ["x", "y", "z", "wx", "wy", "wz"]
|
||||
if self.gripper_names:
|
||||
ee_feats.append("gripper_pos")
|
||||
for feat in ee_feats:
|
||||
for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]:
|
||||
features[PipelineFeatureType.ACTION].pop(f"ee.{feat}", None)
|
||||
|
||||
for name in self.motor_names:
|
||||
features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
for name in self.gripper_names:
|
||||
features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
def reset(self):
|
||||
"""Resets the initial guess for the IK solver."""
|
||||
self.q_curr = None
|
||||
|
||||
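Both IK steps build `q_raw` by scanning the observation for `.pos` keys, and one variant additionally keeps only keys whose prefix is in `motor_names`. That filter follows the observation's iteration order, not the order of `motor_names`, which is one way the ordering issue flagged in the TODO can arise. A small sketch of the difference, using hypothetical helper names and a plain dict as the observation:

```python
def select_joint_positions(observation, motor_names):
    """Filter as in the hunk: keep '<name>.pos' entries for known motors, in observation order."""
    return [
        float(v)
        for k, v in observation.items()
        if isinstance(k, str) and k.endswith(".pos") and k.removesuffix(".pos") in motor_names
    ]


def select_in_motor_order(observation, motor_names):
    """Alternative (an assumption, not in the diff): index by motor_names so order is explicit."""
    return [float(observation[f"{name}.pos"]) for name in motor_names]


obs = {"elbow.pos": 2.0, "shoulder.pos": 1.0, "gripper.pos": 9.0, "camera": "frame"}
motors = ["shoulder", "elbow"]

by_observation_order = select_joint_positions(obs, motors)
by_motor_order = select_in_motor_order(obs, motors)
```

Here the first list comes out as elbow then shoulder, while the second follows `motors`. Whether this matters in practice depends on how the observation dict is built upstream, which this excerpt does not show.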
@@ -18,7 +18,7 @@
|
||||
Edit LeRobot datasets using various transformation tools.
|
||||
|
||||
This script allows you to delete episodes, split datasets, merge datasets,
|
||||
remove features, modify tasks, recompute stats, and convert image datasets to video format.
|
||||
remove features, modify tasks, and convert image datasets to video format.
|
||||
When new_repo_id is specified, creates a new dataset.
|
||||
|
||||
Path semantics (v2): --root and --new_root are exact dataset folders containing
|
||||
@@ -148,21 +148,6 @@ Show dataset information without feature details:
|
||||
--operation.type info \
|
||||
--operation.show_features false
|
||||
|
||||
Recompute dataset statistics:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type recompute_stats
|
||||
|
||||
Recompute stats for relative actions and push to hub:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type recompute_stats \
|
||||
--operation.relative_action true \
|
||||
--operation.chunk_size 50 \
|
||||
--operation.relative_exclude_joints "['gripper']" \
|
||||
--operation.num_workers 4 \
|
||||
--push_to_hub true
|
||||
|
||||
Using JSON config file:
|
||||
lerobot-edit-dataset \
|
||||
--config_path path/to/edit_config.json
|
||||
@@ -183,7 +168,6 @@ from lerobot.datasets.dataset_tools import (
|
||||
delete_episodes,
|
||||
merge_datasets,
|
||||
modify_tasks,
|
||||
recompute_stats,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
@@ -246,20 +230,6 @@ class ConvertImageToVideoConfig(OperationConfig):
|
||||
max_frames_per_batch: int | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("recompute_stats")
|
||||
@dataclass
|
||||
class RecomputeStatsConfig(OperationConfig):
|
||||
skip_image_video: bool = True
|
||||
relative_action: bool = False
|
||||
relative_exclude_joints: list[str] | None = None
|
||||
chunk_size: int = 50
|
||||
num_workers: int = 0
|
||||
relative_state: bool = False
|
||||
relative_exclude_state_joints: list[str] | None = None
|
||||
state_obs_steps: int = 2
|
||||
derive_state_from_action: bool = False
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("info")
|
||||
@dataclass
|
||||
class InfoConfig(OperationConfig):
|
||||
@@ -555,47 +525,6 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
logging.info("Dataset saved locally (not pushed to hub)")
|
||||
|
||||
|
||||
def handle_recompute_stats(cfg: EditDatasetConfig) -> None:
|
||||
if not isinstance(cfg.operation, RecomputeStatsConfig):
|
||||
raise ValueError("Operation config must be RecomputeStatsConfig")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
|
||||
logging.info(f"Recomputing stats for {cfg.repo_id}")
|
||||
if cfg.operation.relative_action:
|
||||
logging.info(
|
||||
f"Relative action stats enabled (chunk_size={cfg.operation.chunk_size}, "
|
||||
f"exclude_joints={cfg.operation.relative_exclude_joints})"
|
||||
)
|
||||
if cfg.operation.relative_state:
|
||||
logging.info(
|
||||
f"Relative state stats enabled (state_obs_steps={cfg.operation.state_obs_steps}, "
|
||||
f"exclude_state_joints={cfg.operation.relative_exclude_state_joints})"
|
||||
)
|
||||
|
||||
if cfg.operation.derive_state_from_action:
|
||||
logging.info("Derive state from action enabled (implies relative_state=True, state_obs_steps=2)")
|
||||
|
||||
recompute_stats(
|
||||
dataset,
|
||||
skip_image_video=cfg.operation.skip_image_video,
|
||||
relative_action=cfg.operation.relative_action,
|
||||
relative_exclude_joints=cfg.operation.relative_exclude_joints,
|
||||
chunk_size=cfg.operation.chunk_size,
|
||||
num_workers=cfg.operation.num_workers,
|
||||
relative_state=cfg.operation.relative_state,
|
||||
relative_exclude_state_joints=cfg.operation.relative_exclude_state_joints,
|
||||
state_obs_steps=cfg.operation.state_obs_steps,
|
||||
derive_state_from_action=cfg.operation.derive_state_from_action,
|
||||
)
|
||||
|
||||
logging.info(f"Stats written to {dataset.root}")
|
||||
|
||||
if cfg.push_to_hub:
|
||||
logging.info(f"Pushing to hub as {dataset.meta.repo_id}...")
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
def _get_dataset_size(repo_path):
|
||||
import os
|
||||
|
||||
@@ -667,8 +596,6 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_modify_tasks(cfg)
|
||||
elif operation_type == "convert_image_to_video":
|
||||
handle_convert_image_to_video(cfg)
|
||||
elif operation_type == "recompute_stats":
|
||||
handle_recompute_stats(cfg)
|
||||
elif operation_type == "info":
|
||||
handle_info(cfg)
|
||||
else:
|
||||
|
||||
@@ -65,7 +65,6 @@ def get_sys_info() -> dict[str, str]:
|
||||
"Platform": platform.platform(),
|
||||
"Python version": platform.python_version(),
|
||||
"Huggingface Hub version": get_package_version("huggingface_hub"),
|
||||
"Transformers version": get_package_version("transformers"),
|
||||
"Datasets version": get_package_version("datasets"),
|
||||
"Numpy version": get_package_version("numpy"),
|
||||
"FFmpeg version": get_ffmpeg_version(),
|
||||
|
||||
@@ -468,8 +468,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
try:
|
||||
if cfg.resume:
|
||||
num_cameras = len(robot.cameras) if hasattr(robot, "cameras") else 0
|
||||
dataset = LeRobotDataset.resume(
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
@@ -477,11 +476,13 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
|
||||
if num_cameras > 0
|
||||
else 0,
|
||||
)
|
||||
|
||||
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
|
||||
dataset.start_image_writer(
|
||||
num_processes=cfg.dataset.num_image_writer_processes,
|
||||
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features)
|
||||
else:
|
||||
# Create empty dataset or load existing saved episodes
|
||||
|
||||
@@ -104,13 +104,15 @@ def replay(cfg: ReplayConfig):
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
|
||||
|
||||
actions = dataset.select_columns(ACTION)
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.episode)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
robot.connect()
|
||||
|
||||
try:
|
||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||
for idx in range(dataset.num_frames):
|
||||
for idx in range(len(episode_frames)):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action_array = actions[idx][ACTION]
|
||||
|
||||
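One side of this hunk filters `dataset.hf_dataset` down to the requested episode and loops over `len(episode_frames)` instead of `dataset.num_frames`. The idea can be sketched with a plain list of frame dicts standing in for the Hugging Face dataset; the helper name `frames_for_episode` and the sample frames are hypothetical.

```python
def frames_for_episode(frames, episode_index):
    """Keep only the frames that belong to one episode, preserving their order."""
    return [frame for frame in frames if frame["episode_index"] == episode_index]


# Frames from several episodes may be stored together in one chunk.
frames = [
    {"episode_index": 0, "action": [0.0]},
    {"episode_index": 1, "action": [0.1]},
    {"episode_index": 1, "action": [0.2]},
    {"episode_index": 2, "action": [0.3]},
]

episode_frames = frames_for_episode(frames, 1)
actions = [frame["action"] for frame in episode_frames]
```

Iterating over the filtered length keeps the replay loop bounded by the episode itself rather than by however many frames the loaded dataset reports.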
@@ -252,22 +252,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# Wait for all processes to finish policy creation before continuing
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
processor_pretrained_path = cfg.policy.pretrained_path
|
||||
if (
|
||||
getattr(cfg.policy, "use_relative_actions", False)
|
||||
and processor_pretrained_path is not None
|
||||
and not cfg.resume
|
||||
):
|
||||
logging.warning(
|
||||
"use_relative_actions=true with pretrained processors can skip relative transforms if "
|
||||
"the checkpoint processors do not define them. Building processors from current policy config."
|
||||
)
|
||||
processor_pretrained_path = None
|
||||
|
||||
# Create processors - only provide dataset_stats if not resuming from saved processors
|
||||
processor_kwargs = {}
|
||||
postprocessor_kwargs = {}
|
||||
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
|
||||
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
|
||||
# Only provide dataset_stats when not resuming from saved processor state
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
@@ -275,7 +263,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
if cfg.policy.type == "sarm":
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
if processor_pretrained_path is not None:
|
||||
if cfg.policy.pretrained_path is not None:
|
||||
processor_kwargs["preprocessor_overrides"] = {
|
||||
"device_processor": {"device": device.type},
|
||||
"normalizer_processor": {
|
||||
@@ -297,7 +285,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):

     preprocessor, postprocessor = make_pre_post_processors(
         policy_cfg=cfg.policy,
-        pretrained_path=processor_pretrained_path,
+        pretrained_path=cfg.policy.pretrained_path,
         **processor_kwargs,
         **postprocessor_kwargs,
     )

@@ -15,7 +15,7 @@

 This script:
 1. Loads action chunks from LeRobotDataset (with episode sampling)
-2. Optionally applies relative transforms (relative vs absolute actions)
+2. Optionally applies delta transforms (relative vs absolute actions)
 3. Extracts specified action dimensions for encoding
 4. Applies normalization (MEAN_STD, MIN_MAX, QUANTILES, or other modes)
 5. Trains FAST tokenizer (BPE on DCT coefficients) on the action chunks
@@ -32,8 +32,8 @@ lerobot-train-tokenizer \
     --max_episodes=100 \
     --sample_fraction=0.1 \
     --encoded_dims="0:6" \
-    --relative_dims="0,1,2,3,4,5" \
-    --use_relative_transform=true \
+    --delta_dims="0,1,2,3,4,5" \
+    --use_delta_transform=true \
     --state_key="observation.state" \
     --normalization_mode="QUANTILES" \
     --vocab_size=1024 \
@@ -82,10 +82,10 @@ class TokenizerTrainingConfig:
     sample_fraction: float = 0.1
     # Comma-separated dimension ranges to encode (e.g., "0:6,7:23")
     encoded_dims: str = "0:6,7:23"
-    # Comma-separated dimension indices for relative transform (e.g., "0,1,2,3,4,5")
-    relative_dims: str | None = None
-    # Whether to apply relative transform (relative actions vs absolute actions)
-    use_relative_transform: bool = False
+    # Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5")
+    delta_dims: str | None = None
+    # Whether to apply delta transform (relative actions vs absolute actions)
+    use_delta_transform: bool = False
     # Dataset key for state observations (default: "observation.state")
     state_key: str = OBS_STATE
     # Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY)
@@ -104,27 +104,25 @@ class TokenizerTrainingConfig:
     hub_private: bool = False


-def apply_relative_transform(
-    state: np.ndarray, actions: np.ndarray, relative_dims: list[int] | None
-) -> np.ndarray:
-    """Apply relative transform to specified dimensions.
+def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: list[int] | None) -> np.ndarray:
+    """Apply delta transform to specified dimensions.

     Args:
         state: Current state [D]
         actions: Future actions [D]
-        relative_dims: List of dimension indices to apply relative transform to
+        delta_dims: List of dimension indices to apply delta transform to

     Returns:
         Transformed actions [D]
     """
-    if relative_dims is None or len(relative_dims) == 0:
+    if delta_dims is None or len(delta_dims) == 0:
         return actions

-    relative_actions = actions.copy()
-    for dim in relative_dims:
-        relative_actions[dim] = actions[dim] - state[dim]
+    delta_actions = actions.copy()
+    for dim in delta_dims:
+        delta_actions[dim] = actions[dim] - state[dim]

-    return relative_actions
+    return delta_actions


 def apply_normalization(
@@ -187,7 +185,7 @@ def apply_normalization(

 def process_episode(args):
     """Process single episode and return action chunks."""
-    dataset, ep_idx, action_horizon, relative_dims, sample_fraction, state_key, use_relative_transform = args
+    dataset, ep_idx, action_horizon, delta_dims, sample_fraction, state_key, use_delta_transform = args

     try:
         # get episode info
@@ -206,15 +204,15 @@ def process_episode(args):

         for abs_idx in range(from_idx, to_idx):
             # map absolute index to relative index if needed
-            if dataset.reader._absolute_to_relative_idx is not None:
-                if abs_idx not in dataset.reader._absolute_to_relative_idx:
+            if dataset._absolute_to_relative_idx is not None:
+                if abs_idx not in dataset._absolute_to_relative_idx:
                     # this episode's frames aren't in the filtered dataset
                     return None
-                rel_idx = dataset.reader._absolute_to_relative_idx[abs_idx]
+                rel_idx = dataset._absolute_to_relative_idx[abs_idx]
             else:
                 rel_idx = abs_idx

-            frame = dataset.get_raw_item(rel_idx)
+            frame = dataset.hf_dataset[rel_idx]

             # get state (could be from observation.state or other state key)
             if state_key in frame:
@@ -224,7 +222,7 @@ def process_episode(args):
                     else np.array(frame[state_key])
                 )
             else:
-                # if no state key, use zeros (no relative transform)
+                # if no state key, use zeros (no delta transform)
                 state = np.zeros_like(
                     frame[ACTION].numpy() if torch.is_tensor(frame[ACTION]) else np.array(frame[ACTION])
                 )
@@ -245,18 +243,18 @@ def process_episode(args):
             current_state = states[i]  # First state in chunk
             future_absolute_actions = actions[i : i + action_horizon]

-            if use_relative_transform:
+            if use_delta_transform:
                 # relative actions
-                relative_chunk = np.zeros_like(future_absolute_actions)
+                delta_chunk = np.zeros_like(future_absolute_actions)
                 for t in range(action_horizon):
-                    relative_chunk[t] = apply_relative_transform(
+                    delta_chunk[t] = apply_delta_transform(
                         current_state,
                         future_absolute_actions[t],
-                        relative_dims,
+                        delta_dims,
                     )
-                action_chunks.append(relative_chunk)
+                action_chunks.append(delta_chunk)
             else:
-                # absolute actions (no relative transform)
+                # absolute actions (no delta)
                 action_chunks.append(future_absolute_actions)

         if len(action_chunks) == 0:
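The loop in this hunk subtracts the first state of the chunk from every timestep, one row at a time. As an illustration only, the same result can be obtained with a single broadcasted NumPy subtraction. In the sketch below, `apply_delta_transform` is copied from the diff, while `apply_delta_transform_chunk` is a hypothetical helper that is not part of the changeset.

```python
import numpy as np


def apply_delta_transform(state, actions, delta_dims):
    """Per-timestep transform, as it appears in the diff."""
    if delta_dims is None or len(delta_dims) == 0:
        return actions
    delta_actions = actions.copy()
    for dim in delta_dims:
        delta_actions[dim] = actions[dim] - state[dim]
    return delta_actions


def apply_delta_transform_chunk(state, chunk, delta_dims):
    """Hypothetical vectorized variant for a whole [T, D] chunk."""
    if delta_dims is None or len(delta_dims) == 0:
        return chunk
    out = chunk.copy()
    # state[delta_dims] has shape [k] and broadcasts across the T rows.
    out[:, delta_dims] = chunk[:, delta_dims] - state[delta_dims]
    return out


state = np.array([1.0, 2.0, 3.0])
chunk = np.array([[1.5, 2.5, 10.0], [2.0, 3.0, 11.0]])
delta_dims = [0, 1]

# Reference result built the same way as the loop in process_episode.
looped = np.stack([apply_delta_transform(state, row, delta_dims) for row in chunk])
vectorized = apply_delta_transform_chunk(state, chunk, delta_dims)
```

Dimensions not listed in `delta_dims` (index 2 here) stay absolute in both versions.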
@@ -409,20 +407,17 @@ def train_tokenizer(cfg: TokenizerTrainingConfig):
     total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges)
     print(f"Encoding {total_encoded_dims} dimensions: {cfg.encoded_dims}")

-    # parse relative dimensions
-    relative_dim_list = None
-    if cfg.relative_dims is not None and cfg.relative_dims.strip():
-        relative_dim_list = [int(d.strip()) for d in cfg.relative_dims.split(",")]
-        print(f"Relative dimensions: {relative_dim_list}")
+    # parse delta dimensions
+    delta_dim_list = None
+    if cfg.delta_dims is not None and cfg.delta_dims.strip():
+        delta_dim_list = [int(d.strip()) for d in cfg.delta_dims.split(",")]
+        print(f"Delta dimensions: {delta_dim_list}")
     else:
-        print("No relative dimensions specified")
+        print("No delta dimensions specified")

-    print(f"Use relative transform: {cfg.use_relative_transform}")
-    if cfg.use_relative_transform and (relative_dim_list is None or len(relative_dim_list) == 0):
-        print(
-            "Warning: use_relative_transform=True but no relative_dims specified. "
-            "No relative transform will be applied."
-        )
+    print(f"Use delta transform: {cfg.use_delta_transform}")
+    if cfg.use_delta_transform and (delta_dim_list is None or len(delta_dim_list) == 0):
+        print("Warning: use_delta_transform=True but no delta_dims specified. No delta will be applied.")

     print(f"Action horizon: {cfg.action_horizon}")
     print(f"State key: {cfg.state_key}")
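The parsing step in this hunk turns a comma-separated string such as `"0,1,2,3,4,5"` into a list of integer indices, and treats `None` or a blank string as "no dimensions". A minimal sketch of that behaviour is below; `parse_dim_list` is a hypothetical name, since the diff performs the parsing inline.

```python
def parse_dim_list(dims):
    """Parse a comma-separated index string; return None when nothing is specified."""
    if dims is None or not dims.strip():
        return None
    # Mirrors the inline expression in the diff: strip each token, then convert.
    return [int(d.strip()) for d in dims.split(",")]


parsed = parse_dim_list("0, 1,2 ,3")
empty = parse_dim_list("   ")
```

As in the inline version, a trailing comma (for example `"0,1,"`) leaves an empty token and raises `ValueError` from `int("")`.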
@@ -445,10 +440,10 @@ def train_tokenizer(cfg: TokenizerTrainingConfig):
                 dataset,
                 ep_idx,
                 cfg.action_horizon,
-                relative_dim_list,
+                delta_dim_list,
                 cfg.sample_fraction,
                 cfg.state_key,
-                cfg.use_relative_transform,
+                cfg.use_delta_transform,
             )
         )
         if chunks is not None:
@@ -549,9 +544,9 @@ def train_tokenizer(cfg: TokenizerTrainingConfig):
         "encoded_dims": cfg.encoded_dims,
         "encoded_dim_ranges": encoded_dim_ranges,
         "total_encoded_dims": total_encoded_dims,
-        "relative_dims": cfg.relative_dims,
-        "relative_dim_list": relative_dim_list,
-        "use_relative_transform": cfg.use_relative_transform,
+        "delta_dims": cfg.delta_dims,
+        "delta_dim_list": delta_dim_list,
+        "use_delta_transform": cfg.use_delta_transform,
         "state_key": cfg.state_key,
         "normalization_mode": norm_mode.value,
         "action_horizon": cfg.action_horizon,

@@ -65,10 +65,6 @@ if "LEROBOT_HOME" in os.environ:
 # cache dir
 default_cache_path = Path(HF_HOME) / "lerobot"
 HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
-# LeRobot's own revision-safe Hub cache (NOT the system-wide ~/.cache/huggingface/hub/).
-# Used as the ``cache_dir`` argument to ``snapshot_download`` so that different
-# dataset revisions are stored in isolated snapshot directories.
-HF_LEROBOT_HUB_CACHE = HF_LEROBOT_HOME / "hub"

 # calibration dir
 default_calibration_path = HF_LEROBOT_HOME / "calibration"

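The constants in this hunk resolve the LeRobot home directory from the `HF_LEROBOT_HOME` environment variable, falling back to `<HF_HOME>/lerobot`, and the removed lines derived a `hub` subdirectory from it. A minimal sketch of that resolution follows. The function name and the explicit `hf_home` and `env` parameters are assumptions made so the example runs without `huggingface_hub`, where `HF_HOME` presumably originates.

```python
import os
from pathlib import Path


def resolve_lerobot_home(hf_home, env=None):
    """Hypothetical helper mirroring the HF_LEROBOT_HOME resolution in the diff."""
    env = os.environ if env is None else env
    default_cache_path = Path(hf_home) / "lerobot"
    # An explicit HF_LEROBOT_HOME wins; "~" is expanded in either case.
    return Path(env.get("HF_LEROBOT_HOME", default_cache_path)).expanduser()


default_home = resolve_lerobot_home("/tmp/hf", env={})
custom_home = resolve_lerobot_home("/tmp/hf", env={"HF_LEROBOT_HOME": "/data/lerobot"})

# Derived paths, following the pattern of the constants in the hunk.
hub_cache = custom_home / "hub"  # counterpart of the removed HF_LEROBOT_HUB_CACHE
calibration = custom_home / "calibration"
```

Passing `env` explicitly keeps the sketch independent of the real process environment.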
Some files were not shown because too many files have changed in this diff.