mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
Compare commits
73 Commits
feat/eval-
...
fix/add-xv
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7e2bab392b | ||
|
|
2051cc6908 | ||
|
|
9ee793be34 | ||
|
|
d3e5af007d | ||
|
|
174588cd18 | ||
|
|
8e633bf7d9 | ||
|
|
18fd4f740c | ||
|
|
8f59d93458 | ||
|
|
d4e6d60ec3 | ||
|
|
4ad41f7a76 | ||
|
|
9cdf46bd3d | ||
|
|
d22fa47ac0 | ||
|
|
602fb7bf36 | ||
|
|
5a9f3e2555 | ||
|
|
ac1de3719c | ||
|
|
0b326053e9 | ||
|
|
ca4b3d035b | ||
|
|
863ae89ff2 | ||
|
|
fbcf118dcb | ||
|
|
171d50e854 | ||
|
|
1f00978b2a | ||
|
|
825146d218 | ||
|
|
81cf4d8ed5 | ||
|
|
15dc2fd867 | ||
|
|
4e9acd4afe | ||
|
|
f62cfc9ca2 | ||
|
|
829428ac81 | ||
|
|
066fb1bd5d | ||
|
|
abaf870e00 | ||
|
|
6d2166cf04 | ||
|
|
2044e52e36 | ||
|
|
0e21f3fdf7 | ||
|
|
936a6728f0 | ||
|
|
722766b825 | ||
|
|
8f2321af27 | ||
|
|
5052d4d70b | ||
|
|
15188b0cf8 | ||
|
|
90627ca85b | ||
|
|
8ed2755a59 | ||
|
|
e61722fa78 | ||
|
|
a3a5cb1bac | ||
|
|
0ccc60f20b | ||
|
|
9d13b6ceea | ||
|
|
7cfe4c768f | ||
|
|
119ee85dab | ||
|
|
70582ed226 | ||
|
|
99b0722425 | ||
|
|
9c6c8d075b | ||
|
|
efacf8f0e0 | ||
|
|
b16bc5f1ff | ||
|
|
a6404f61e1 | ||
|
|
9896ba4ee4 | ||
|
|
8591fc10b3 | ||
|
|
42d615b69d | ||
|
|
858626dea5 | ||
|
|
5277a9909d | ||
|
|
fb6f59e074 | ||
|
|
f3b25eb425 | ||
|
|
cb7d2ed0fc | ||
|
|
f4547299e4 | ||
|
|
a28a74e43c | ||
|
|
ab763abff3 | ||
|
|
818c75713b | ||
|
|
589788e760 | ||
|
|
cde2e24d79 | ||
|
|
b928c123fb | ||
|
|
f52cf79d8e | ||
|
|
39260a581a | ||
|
|
2219c29690 | ||
|
|
8d9a992953 | ||
|
|
3cb14248a4 | ||
|
|
8a65623dec | ||
|
|
d9e4d374c5 |
11
.github/workflows/fast_tests.yml
vendored
11
.github/workflows/fast_tests.yml
vendored
@@ -60,12 +60,17 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
env:
|
||||||
MUJOCO_GL: egl
|
MUJOCO_GL: egl
|
||||||
|
HF_HOME: /mnt/cache/.cache/huggingface
|
||||||
|
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
lfs: true
|
lfs: true
|
||||||
|
|
||||||
|
- name: Setup /mnt storage
|
||||||
|
run: sudo chown -R $USER:$USER /mnt
|
||||||
|
|
||||||
# TODO(Steven): Evaluate the need of these dependencies
|
# TODO(Steven): Evaluate the need of these dependencies
|
||||||
- name: Install apt dependencies
|
- name: Install apt dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -80,8 +85,14 @@ jobs:
|
|||||||
version: ${{ env.UV_VERSION }}
|
version: ${{ env.UV_VERSION }}
|
||||||
python-version: ${{ env.PYTHON_VERSION }}
|
python-version: ${{ env.PYTHON_VERSION }}
|
||||||
|
|
||||||
|
- name: Check disk usage
|
||||||
|
run: df -h
|
||||||
|
|
||||||
- name: Install lerobot with test extras
|
- name: Install lerobot with test extras
|
||||||
run: uv sync --extra "test"
|
run: uv sync --extra "test"
|
||||||
|
|
||||||
|
- name: Check disk usage
|
||||||
|
run: df -h
|
||||||
|
|
||||||
- name: Run pytest
|
- name: Run pytest
|
||||||
run: uv run pytest tests -vv --maxfail=10
|
run: uv run pytest tests -vv --maxfail=10
|
||||||
|
|||||||
14
.github/workflows/full_tests.yml
vendored
14
.github/workflows/full_tests.yml
vendored
@@ -58,12 +58,17 @@ jobs:
|
|||||||
github.event_name == 'workflow_dispatch'
|
github.event_name == 'workflow_dispatch'
|
||||||
env:
|
env:
|
||||||
MUJOCO_GL: egl
|
MUJOCO_GL: egl
|
||||||
|
HF_HOME: /mnt/cache/.cache/huggingface
|
||||||
|
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
lfs: true
|
lfs: true
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
|
||||||
|
- name: Setup /mnt storage
|
||||||
|
run: sudo chown -R $USER:$USER /mnt
|
||||||
|
|
||||||
- name: Install apt dependencies
|
- name: Install apt dependencies
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||||
@@ -80,12 +85,21 @@ jobs:
|
|||||||
- name: Install lerobot with all extras
|
- name: Install lerobot with all extras
|
||||||
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
|
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
|
||||||
|
|
||||||
|
- name: Check disk usage
|
||||||
|
run: df -h
|
||||||
|
|
||||||
- name: Run pytest (all extras)
|
- name: Run pytest (all extras)
|
||||||
run: uv run pytest tests -vv --maxfail=10
|
run: uv run pytest tests -vv --maxfail=10
|
||||||
|
|
||||||
|
- name: Check disk usage
|
||||||
|
run: df -h
|
||||||
|
|
||||||
- name: Run end-to-end tests
|
- name: Run end-to-end tests
|
||||||
run: uv run make test-end-to-end
|
run: uv run make test-end-to-end
|
||||||
|
|
||||||
|
- name: Check disk usage
|
||||||
|
run: df -h
|
||||||
|
|
||||||
# This job builds a GPU enabled image for testing
|
# This job builds a GPU enabled image for testing
|
||||||
# It runs everytime a PR is approved or a push to main
|
# It runs everytime a PR is approved or a push to main
|
||||||
# TODO(Steven): For now we skip this job for community PRs
|
# TODO(Steven): For now we skip this job for community PRs
|
||||||
|
|||||||
4
.github/workflows/unbound_deps_tests.yml
vendored
4
.github/workflows/unbound_deps_tests.yml
vendored
@@ -45,11 +45,15 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
env:
|
||||||
MUJOCO_GL: egl
|
MUJOCO_GL: egl
|
||||||
|
HF_HOME: /mnt/cache/.cache/huggingface
|
||||||
|
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
lfs: true
|
lfs: true
|
||||||
persist-credentials: false
|
persist-credentials: false
|
||||||
|
- name: Setup /mnt storage
|
||||||
|
run: sudo chown -R $USER:$USER /mnt
|
||||||
|
|
||||||
- name: Install apt dependencies
|
- name: Install apt dependencies
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -37,6 +37,8 @@
|
|||||||
title: π₀.₅ (Pi05)
|
title: π₀.₅ (Pi05)
|
||||||
- local: groot
|
- local: groot
|
||||||
title: NVIDIA GR00T N1.5
|
title: NVIDIA GR00T N1.5
|
||||||
|
- local: xvla
|
||||||
|
title: X-VLA
|
||||||
title: "Policies"
|
title: "Policies"
|
||||||
- sections:
|
- sections:
|
||||||
- local: async
|
- local: async
|
||||||
|
|||||||
@@ -62,6 +62,11 @@ lerobot-eval \
|
|||||||
|
|
||||||
- Pass a comma-separated list to `--env.task` for multi-suite evaluation.
|
- Pass a comma-separated list to `--env.task` for multi-suite evaluation.
|
||||||
|
|
||||||
|
### Control Mode
|
||||||
|
|
||||||
|
LIBERO now supports two control modes: relative and absolute. This matters because different VLA checkpoints are trained to output actions in different modes, and therefore with different control parameterizations.
|
||||||
|
You can switch between them with `env.control_mode = "relative"` or `env.control_mode = "absolute"`.
|
||||||
|
|
||||||
### Policy inputs and outputs
|
### Policy inputs and outputs
|
||||||
|
|
||||||
When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**:
|
When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**:
|
||||||
|
|||||||
543
docs/source/xvla.mdx
Normal file
543
docs/source/xvla.mdx
Normal file
@@ -0,0 +1,543 @@
|
|||||||
|
# X-VLA: The First Soft-Prompted Robot Foundation Model for Any Robot, Any Task
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
For years, robotics has aspired to build agents that can follow natural human instructions and operate dexterously across many environments and robot bodies. Recent breakthroughs in LLMs and VLMs suggest a path forward: extend these foundation-model architectures to embodied control by grounding them in actions. This has led to the rise of Vision-Language-Action (VLA) models, with the hope that a single generalist model could combine broad semantic understanding with robust manipulation skills.
|
||||||
|
|
||||||
|
But training such models is difficult. Robot data is fragmented across platforms, sensors, embodiments, and collection protocols. Heterogeneity appears everywhere: different arm configurations, different action spaces, different camera setups, different visual domains, and different task distributions. These inconsistencies create major distribution shifts that make pretraining unstable and adaptation unreliable.
|
||||||
|
|
||||||
|
Inspired by meta-learning and prompt learning, we ask: **"What if a VLA model could learn the structure of each robot and dataset the same way LLMs learn tasks, through prompts?"**
|
||||||
|
|
||||||
|
**X-VLA** is a soft-prompted, flow-matching VLA framework that treats each hardware setup as a "task" and encodes it using a small set of learnable embeddings. These **Soft Prompts** capture embodiment and domain-specific variations, guiding the Transformer from the earliest stages of multimodal fusion. With this mechanism, X-VLA can reconcile diverse robot morphologies, data types, and sensor setups within a single unified architecture.
|
||||||
|
|
||||||
|
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture.png" width="400">
|
||||||
|
|
||||||
|
Built from pure Transformer encoders, X-VLA scales naturally with model size and dataset diversity. Across 6 simulation benchmarks and 3 real robots, Soft Prompts consistently outperform existing methods in handling hardware and domain differences. X-VLA-0.9B, trained on 290K episodes spanning seven robotic platforms, learns an embodiment-agnostic generalist policy in Phase I, and adapts efficiently to new robots in Phase II simply by learning a new set of prompts, while keeping the backbone frozen.
|
||||||
|
|
||||||
|
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture2.png" width="400">
|
||||||
|
|
||||||
|
With only 1% of parameters tuned (9M), X-VLA-0.9B achieves near-π₀ performance on LIBERO and Simpler-WidowX, despite using **300× fewer trainable parameters**. It also demonstrates strong real-world dexterity with minimal demonstrations, including folding cloths in under two minutes.
|
||||||
|
|
||||||
|
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-fold.png" width="400">
|
||||||
|
|
||||||
|
X-VLA shows that generalist robot intelligence does not require increasingly complex architectures, only the right way to absorb heterogeneity. Soft Prompts offer a simple, scalable mechanism for unifying diverse robotic data, paving the way toward adaptable, cross-embodiment robot foundation models.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
After installing LeRobot, install the X-VLA dependencies:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e .[xvla]
|
||||||
|
```
|
||||||
|
|
||||||
|
After the new release, you'll be able to do:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install lerobot[xvla]
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
To use X-VLA in your LeRobot configuration, specify the policy type as:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
policy.type=xvla
|
||||||
|
```
|
||||||
|
|
||||||
|
### Evaluating Pre-trained Checkpoints
|
||||||
|
|
||||||
|
Example evaluation with LIBERO:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-eval \
|
||||||
|
--policy.path="lerobot/xvla-libero" \
|
||||||
|
--env.type=libero \
|
||||||
|
--env.task=libero_spatial,libero_goal,libero_10 \
|
||||||
|
--env.control_mode=absolute \
|
||||||
|
--eval.batch_size=1 \
|
||||||
|
--eval.n_episodes=1 \
|
||||||
|
--env.episode_length=800 \
|
||||||
|
--seed=142
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Available Checkpoints
|
||||||
|
|
||||||
|
### 🎯 Base Model
|
||||||
|
|
||||||
|
**[lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base)**
|
||||||
|
|
||||||
|
A 0.9B parameter instantiation of X-VLA, trained with a carefully designed data processing and learning recipe. The training pipeline consists of two phases:
|
||||||
|
|
||||||
|
- **Phase I: Pretraining** - Pretrained on 290K episodes from Droid, Robomind, and Agibot, spanning seven platforms across five types of robotic arms (single-arm to bi-manual setups). By leveraging soft prompts to absorb embodiment-specific variations, the model learns an embodiment-agnostic generalist policy.
|
||||||
|
|
||||||
|
- **Phase II: Domain Adaptation** - Adapted to deployable policies for target domains. A new set of soft prompts is introduced and optimized to encode the hardware configuration of the novel domain, while the pretrained backbone remains frozen.
|
||||||
|
|
||||||
|
### 🎮 Simulation Checkpoints
|
||||||
|
|
||||||
|
**[lerobot/xvla-libero](https://huggingface.co/lerobot/xvla-libero)**
|
||||||
|
|
||||||
|
Achieves 93% success rate on LIBERO benchmarks. Fine-tuned from the base model for simulation tasks.
|
||||||
|
|
||||||
|
**[lerobot/xvla-widowx](https://huggingface.co/lerobot/xvla-widowx)**
|
||||||
|
|
||||||
|
Fine-tuned on BridgeData for pick-and-place experiments on compact WidowX platforms. Demonstrates robust manipulation capabilities.
|
||||||
|
|
||||||
|
### 🤖 Real-World Checkpoints
|
||||||
|
|
||||||
|
**[lerobot/xvla-folding](https://huggingface.co/lerobot/xvla-folding)**
|
||||||
|
|
||||||
|
A fine-tuned dexterous manipulation model trained on the high-quality Soft-FOLD cloth folding dataset. Achieves 100% success rate over 2 hours of continuous cloth folding.
|
||||||
|
|
||||||
|
**[lerobot/xvla-agibot-world](https://huggingface.co/lerobot/xvla-agibot-world)**
|
||||||
|
|
||||||
|
Optimized for AgileX robot dexterous manipulation tasks.
|
||||||
|
|
||||||
|
**[lerobot/xvla-google-robot](https://huggingface.co/lerobot/xvla-google-robot)**
|
||||||
|
|
||||||
|
Adapted for Google Robot platforms.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Training X-VLA
|
||||||
|
|
||||||
|
### Recommended Training Configuration
|
||||||
|
|
||||||
|
When fine-tuning X-VLA for a new embodiment or task, we recommend the following freezing strategy:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--dataset.repo_id=YOUR_DATASET \
|
||||||
|
--output_dir=./outputs/xvla_training \
|
||||||
|
--job_name=xvla_training \
|
||||||
|
--policy.path="lerobot/xvla-base" \
|
||||||
|
--policy.repo_id="HF_USER/xvla-your-robot" \
|
||||||
|
--steps=3000 \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--policy.freeze_vision_encoder=True \
|
||||||
|
--policy.freeze_language_encoder=True \
|
||||||
|
--policy.train_policy_transformer=True \
|
||||||
|
--policy.train_soft_prompts=True \
|
||||||
|
--policy.action_mode=YOUR_ACTION_MODE
|
||||||
|
```
|
||||||
|
|
||||||
|
### Training Parameters Explained
|
||||||
|
|
||||||
|
| Parameter | Default | Description |
|
||||||
|
| -------------------------- | ------- | ---------------------------------------- |
|
||||||
|
| `freeze_vision_encoder` | `True` | Freeze the VLM vision encoder weights |
|
||||||
|
| `freeze_language_encoder` | `True` | Freeze the VLM language encoder weights |
|
||||||
|
| `train_policy_transformer` | `True` | Allow policy transformer layers to train |
|
||||||
|
| `train_soft_prompts` | `True` | Allow soft prompts to train |
|
||||||
|
|
||||||
|
**💡 Best Practice**: For Phase II adaptation to new embodiments, freeze the VLM encoders and only train the policy transformer and soft prompts. This provides excellent sample efficiency with minimal compute.
|
||||||
|
|
||||||
|
### Example: Training on Bimanual Robot
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
|
||||||
|
--output_dir=./outputs/xvla_bimanual \
|
||||||
|
--job_name=xvla_so101_training \
|
||||||
|
--policy.path="lerobot/xvla-base" \
|
||||||
|
--policy.repo_id="YOUR_USERNAME/xvla-biso101" \
|
||||||
|
--steps=3000 \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--policy.action_mode=so101_bimanual \
|
||||||
|
--policy.freeze_vision_encoder=True \
|
||||||
|
--policy.freeze_language_encoder=True \
|
||||||
|
--policy.train_policy_transformer=True \
|
||||||
|
--policy.train_soft_prompts=True
|
||||||
|
```
|
||||||
|
|
||||||
|
💡 **Best Performance:** If you have sufficient computational resources and want to achieve the best X-VLA finetuning performance, you should follow the official finetuning strategy:
|
||||||
|
|
||||||
|
**🔥 Full-finetune all components with a custom learning-rate scheme**
|
||||||
|
|
||||||
|
To ensure stable optimization, the Vision-Language Model (VLM) must be trained with only 1/10 of the base learning rate, while all other components use the full LR.
|
||||||
|
This LR ratio is crucial for achieving strong and stable finetuning performance.
|
||||||
|
To enable this behavior, you must:
|
||||||
|
|
||||||
|
1. Implement a custom optimizer and register it in your training config
|
||||||
|
|
||||||
|
```python
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
from lerobot.optim.optimizers import OptimizerConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
@OptimizerConfig.register_subclass("xvla-adamw")
|
||||||
|
@dataclass
|
||||||
|
class XVLAAdamW(OptimizerConfig):
|
||||||
|
lr: float = 1e-4
|
||||||
|
betas: tuple[float, float] = (0.9, 0.99)
|
||||||
|
eps: float = 1e-8
|
||||||
|
weight_decay: float = 0.0
|
||||||
|
grad_clip_norm: float = 10.0
|
||||||
|
|
||||||
|
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||||
|
"""
|
||||||
|
Expect `named_parameters()` as input.
|
||||||
|
Apply lr = lr / 10 for all VLM-related parameters.
|
||||||
|
"""
|
||||||
|
assert isinstance(params, dict), \
|
||||||
|
"Custom LR optimizer requires `named_parameters()` as inputs."
|
||||||
|
kwargs = asdict(self)
|
||||||
|
kwargs.pop("grad_clip_norm")
|
||||||
|
vlm_group, other_group = [], []
|
||||||
|
for name, p in params.items():
|
||||||
|
if not p.requires_grad:
|
||||||
|
continue
|
||||||
|
if "vlm" in name.lower():
|
||||||
|
vlm_group.append(p)
|
||||||
|
else:
|
||||||
|
other_group.append(p)
|
||||||
|
|
||||||
|
param_groups = [
|
||||||
|
{"params": vlm_group, "lr": self.lr * 0.1, "weight_decay": self.weight_decay * 0.1},
|
||||||
|
{"params": other_group, "lr": self.lr, "weight_decay": self.weight_decay},
|
||||||
|
]
|
||||||
|
|
||||||
|
return torch.optim.AdamW(param_groups, **kwargs)
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Modify X-VLA’s get_optim_params to return named parameters
|
||||||
|
|
||||||
|
Replace:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_optim_params(self) -> dict:
|
||||||
|
"""Return only trainable parameters for optimization."""
|
||||||
|
return filter(lambda p: p.requires_grad, self.parameters())
|
||||||
|
```
|
||||||
|
|
||||||
|
with:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_optim_params(self):
|
||||||
|
"""Return trainable named parameters."""
|
||||||
|
return {name: p for name, p in self.named_parameters() if p.requires_grad}
|
||||||
|
```
|
||||||
|
|
||||||
|
This ensures the optimizer receives a dict of named parameters, allowing it to correctly detect VLM modules and apply the 1/10 LR rule.
|
||||||
|
|
||||||
|
❕Note
|
||||||
|
|
||||||
|
Completely matching the official reported performance may require an additional warm-up LR schedule for soft-prompts, which can bring minor improvements.
|
||||||
|
We encourage implementing this in your customized training pipeline for optimal results.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Core Concepts
|
||||||
|
|
||||||
|
### 1. Action Modes
|
||||||
|
|
||||||
|
X-VLA uses an **Action Registry** system to handle different action spaces and embodiments. The `action_mode` parameter defines how actions are processed, what loss functions are used, and how predictions are post-processed.
|
||||||
|
|
||||||
|
#### Available Action Modes
|
||||||
|
|
||||||
|
| Action Mode | Action Dim | Description | Use Case |
|
||||||
|
| ---------------- | --------------------- | ------------------------------------------- | ------------------------------------ |
|
||||||
|
| `ee6d` | 20 | End-effector with xyz, 6D rotation, gripper | Dual-arm setups with spatial control |
|
||||||
|
| `joint` | 14 | Joint-space with gripper | Direct joint control robots |
|
||||||
|
| `agibot_ee6d` | 20 | AGI-bot variant with MSE loss | AGI-bot platforms |
|
||||||
|
| `franka_joint7` | 7 | Franka Panda 7-joint control | Franka robots without gripper |
|
||||||
|
| `so101_bimanual` | 20 (model), 12 (real) | SO101 bimanual robot | Bimanual manipulation tasks |
|
||||||
|
|
||||||
|
#### Why Action Modes Matter
|
||||||
|
|
||||||
|
When you have a pretrained checkpoint like `lerobot/xvla-base` trained with `action_dim=20`, and you want to train on a dataset with a different action dimension (e.g., 14 for bimanual arms), you can't simply trim the action dimension. The action mode orchestrates:
|
||||||
|
|
||||||
|
1. **Loss Computation**: Different loss functions for different action components (MSE for joints, BCE for grippers, etc.)
|
||||||
|
2. **Preprocessing**: Zeroing out gripper channels, padding dimensions
|
||||||
|
3. **Postprocessing**: Applying sigmoid to gripper logits, trimming padding
|
||||||
|
|
||||||
|
#### Example: BimanualSO101 Action Space
|
||||||
|
|
||||||
|
The `so101_bimanual` action mode handles the mismatch between model output (20D) and real robot control (12D):
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Model outputs 20 dimensions for compatibility
|
||||||
|
dim_action = 20
|
||||||
|
|
||||||
|
# Real robot only needs 12 dimensions
|
||||||
|
# [left_arm (6), right_arm (6)] = [joints (5) + gripper (1)] × 2
|
||||||
|
REAL_DIM = 12
|
||||||
|
|
||||||
|
# Preprocessing: Pad 12D actions to 20D for training
|
||||||
|
# Postprocessing: Trim 20D predictions to 12D for deployment
|
||||||
|
```
|
||||||
|
|
||||||
|
See the [action_hub.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py) implementation for details.
|
||||||
|
|
||||||
|
### 2. Domain IDs
|
||||||
|
|
||||||
|
Domain IDs are learnable identifiers for different robot configurations and camera setups. They allow X-VLA to distinguish between:
|
||||||
|
|
||||||
|
- Different robots (Robot 1 vs Robot 2)
|
||||||
|
- Different camera configurations (cam1 vs cam2)
|
||||||
|
- Different combinations (Robot1-cam1-cam2 vs Robot1-cam1 vs Robot2-cam1)
|
||||||
|
|
||||||
|
#### Setting Domain IDs
|
||||||
|
|
||||||
|
**During Training**: By default, domain_id is set to 0 for general training.
|
||||||
|
|
||||||
|
**During Evaluation**: Specify the domain_id that matches your checkpoint's training configuration.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Example: LIBERO checkpoint uses domain_id=3
|
||||||
|
domain_id = 3
|
||||||
|
```
|
||||||
|
|
||||||
|
The domain_id is automatically added to observations by the `XVLAAddDomainIdProcessorStep` in the preprocessing pipeline.
|
||||||
|
|
||||||
|
### 3. Processor Steps
|
||||||
|
|
||||||
|
X-VLA requires specific preprocessing and postprocessing steps for proper operation.
|
||||||
|
|
||||||
|
#### Required Preprocessing Steps
|
||||||
|
|
||||||
|
1. **XVLAImageToFloatProcessorStep**: Converts images from [0, 255] to [0, 1] range
|
||||||
|
2. **XVLAImageNetNormalizeProcessorStep**: Applies ImageNet normalization (required for VLM backbone)
|
||||||
|
3. **XVLAAddDomainIdProcessorStep**: Adds domain_id to observations
|
||||||
|
|
||||||
|
#### Example Custom Processor
|
||||||
|
|
||||||
|
For LIBERO environments, a custom processor handles the specific observation format:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.policies.xvla.processor_xvla import LiberoProcessorStep
|
||||||
|
|
||||||
|
processor = LiberoProcessorStep()
|
||||||
|
# Handles robot_state dictionary, converts rotation matrices to 6D representation
|
||||||
|
# Applies 180° image rotation for camera convention
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Configuration Parameters
|
||||||
|
|
||||||
|
Key configuration parameters for X-VLA:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Observation and action
|
||||||
|
n_obs_steps: int = 1 # Number of observation timesteps
|
||||||
|
chunk_size: int = 32 # Action sequence length
|
||||||
|
n_action_steps: int = 32 # Number of action steps to execute
|
||||||
|
|
||||||
|
# Model architecture
|
||||||
|
hidden_size: int = 1024 # Transformer hidden dimension
|
||||||
|
depth: int = 24 # Number of transformer layers
|
||||||
|
num_heads: int = 16 # Number of attention heads
|
||||||
|
num_domains: int = 30 # Maximum number of domain IDs
|
||||||
|
len_soft_prompts: int = 32 # Length of soft prompt embeddings
|
||||||
|
|
||||||
|
# Action space
|
||||||
|
action_mode: str = "ee6d" # Action space type
|
||||||
|
use_proprio: bool = True # Use proprioceptive state
|
||||||
|
max_state_dim: int = 32 # Maximum state dimension
|
||||||
|
|
||||||
|
# Vision
|
||||||
|
num_image_views: int | None # Number of camera views
|
||||||
|
resize_imgs_with_padding: tuple[int, int] | None # Target image size with padding
|
||||||
|
|
||||||
|
# Training
|
||||||
|
num_denoising_steps: int = 10 # Flow matching denoising steps
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Creating Custom Action Modes
|
||||||
|
|
||||||
|
If your robot has a unique action space, you can create a custom action mode:
|
||||||
|
|
||||||
|
### Step 1: Define Your Action Space
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.policies.xvla.action_hub import BaseActionSpace, register_action
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
@register_action("my_custom_robot")
|
||||||
|
class MyCustomActionSpace(BaseActionSpace):
|
||||||
|
"""Custom action space for my robot."""
|
||||||
|
|
||||||
|
dim_action = 15 # Your robot's action dimension
|
||||||
|
gripper_idx = (7, 14) # Gripper channel indices
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.mse = nn.MSELoss()
|
||||||
|
self.bce = nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
|
def compute_loss(self, pred, target):
|
||||||
|
"""Define your loss computation."""
|
||||||
|
# Example: MSE for joints, BCE for grippers
|
||||||
|
joints_loss = self.mse(pred[:, :, :7], target[:, :, :7])
|
||||||
|
gripper_loss = self.bce(pred[:, :, self.gripper_idx],
|
||||||
|
target[:, :, self.gripper_idx])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"joints_loss": joints_loss,
|
||||||
|
"gripper_loss": gripper_loss,
|
||||||
|
}
|
||||||
|
|
||||||
|
def preprocess(self, proprio, action, mode="train"):
|
||||||
|
"""Preprocess actions before training."""
|
||||||
|
# Example: Zero out grippers in proprioception
|
||||||
|
proprio_m = proprio.clone()
|
||||||
|
action_m = action.clone() if action is not None else None
|
||||||
|
proprio_m[..., self.gripper_idx] = 0.0
|
||||||
|
if action_m is not None:
|
||||||
|
action_m[..., self.gripper_idx] = 0.0
|
||||||
|
return proprio_m, action_m
|
||||||
|
|
||||||
|
def postprocess(self, action):
|
||||||
|
"""Post-process predictions for deployment."""
|
||||||
|
# Example: Apply sigmoid to gripper logits
|
||||||
|
action[..., self.gripper_idx] = action[..., self.gripper_idx].sigmoid()
|
||||||
|
return action
|
||||||
|
```
|
||||||
|
|
||||||
|
### Step 2: Use Your Custom Action Mode
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--policy.action_mode=my_custom_robot \
|
||||||
|
--dataset.repo_id=YOUR_DATASET \
|
||||||
|
--policy.path="lerobot/xvla-base" \
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Advanced Topics
|
||||||
|
|
||||||
|
### Multi-Camera Support
|
||||||
|
|
||||||
|
X-VLA supports multiple camera views through the `num_image_views` parameter:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Configure for 3 camera views
|
||||||
|
policy.num_image_views=3
|
||||||
|
|
||||||
|
# Add empty cameras if you have fewer physical cameras
|
||||||
|
policy.empty_cameras=1 # Adds 1 zero-padded camera view
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom Preprocessing Pipeline
|
||||||
|
|
||||||
|
Create a custom preprocessing pipeline for your environment:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.processor import PolicyProcessorPipeline
|
||||||
|
from lerobot.policies.xvla.processor_xvla import (
|
||||||
|
XVLAImageToFloatProcessorStep,
|
||||||
|
XVLAImageNetNormalizeProcessorStep,
|
||||||
|
XVLAAddDomainIdProcessorStep,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build custom pipeline
|
||||||
|
preprocessor = PolicyProcessorPipeline(
|
||||||
|
steps=[
|
||||||
|
YourCustomProcessorStep(), # Your custom processing
|
||||||
|
XVLAImageToFloatProcessorStep(), # Required: convert to float
|
||||||
|
XVLAImageNetNormalizeProcessorStep(), # Required: ImageNet norm
|
||||||
|
XVLAAddDomainIdProcessorStep(domain_id=5), # Your domain ID
|
||||||
|
]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Handling Different Action Dimensions
|
||||||
|
|
||||||
|
When your dataset has fewer action dimensions than the pretrained model:
|
||||||
|
|
||||||
|
**Option 1**: Use padding (automatic in most action modes)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Model expects 20D, dataset has 12D
|
||||||
|
# Action mode handles padding internally
|
||||||
|
action_mode = "so101_bimanual" # Pads 12 → 20
|
||||||
|
```
|
||||||
|
|
||||||
|
**Option 2**: Create a custom action mode that maps dimensions explicitly
|
||||||
|
|
||||||
|
```python
|
||||||
|
@register_action("my_mapped_action")
|
||||||
|
class MappedActionSpace(BaseActionSpace):
|
||||||
|
dim_action = 20
|
||||||
|
REAL_DIM = 12
|
||||||
|
|
||||||
|
def _pad_to_model_dim(self, x):
|
||||||
|
# Custom padding logic
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
**Issue**: "Action dimension mismatch"
|
||||||
|
|
||||||
|
- **Solution**: Check that your `action_mode` matches your robot's action space. Create a custom action mode if needed.
|
||||||
|
|
||||||
|
**Issue**: "Image values outside [0, 1] range"
|
||||||
|
|
||||||
|
- **Solution**: Ensure images are preprocessed with `XVLAImageToFloatProcessorStep` before normalization.
|
||||||
|
|
||||||
|
**Issue**: "Domain ID not found"
|
||||||
|
|
||||||
|
- **Solution**: Make sure `XVLAAddDomainIdProcessorStep` is in your preprocessing pipeline with the correct domain_id.
|
||||||
|
|
||||||
|
**Issue**: "Low success rate on new embodiment"
|
||||||
|
|
||||||
|
- **Solution**:
|
||||||
|
1. Verify your action_mode is correct
|
||||||
|
2. Check that soft prompts are being trained (`train_soft_prompts=True`)
|
||||||
|
3. Ensure proper preprocessing (ImageNet normalization, domain_id)
|
||||||
|
4. Consider increasing training steps
|
||||||
|
|
||||||
|
**Issue**: "Out of memory during training"
|
||||||
|
|
||||||
|
- **Solution**:
|
||||||
|
1. Reduce `chunk_size` (e.g., from 32 to 16)
|
||||||
|
2. Enable gradient checkpointing
|
||||||
|
3. Reduce batch size
|
||||||
|
4. Freeze more components
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
If you use X-VLA in your research, please cite:
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@article{zheng2025x,
|
||||||
|
title = {X-VLA: Soft-Prompted Transformer as Scalable Cross-Embodiment Vision-Language-Action Model},
|
||||||
|
author = {Zheng, Jinliang and Li, Jianxiong and Wang, Zhihao and Liu, Dongxiu and Kang, Xirui
|
||||||
|
and Feng, Yuchun and Zheng, Yinan and Zou, Jiayin and Chen, Yilun and Zeng, Jia and others},
|
||||||
|
journal = {arXiv preprint arXiv:2510.10274},
|
||||||
|
year = {2025}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Additional Resources
|
||||||
|
|
||||||
|
- [X-VLA Paper](https://arxiv.org/abs/2510.10274)
|
||||||
|
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
||||||
|
- [Action Registry Implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py)
|
||||||
|
- [Processor Implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/processor_xvla.py)
|
||||||
|
- [Model Configuration](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/configuration_xvla.py)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
We welcome contributions! If you've implemented a new action mode or processor for your robot, please consider submitting a PR to help the community.
|
||||||
@@ -129,6 +129,7 @@ groot = [
|
|||||||
"ninja>=1.11.1,<2.0.0",
|
"ninja>=1.11.1,<2.0.0",
|
||||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||||
]
|
]
|
||||||
|
xvla = ["lerobot[transformers-dep]"]
|
||||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||||
|
|
||||||
# Features
|
# Features
|
||||||
@@ -157,6 +158,7 @@ all = [
|
|||||||
"lerobot[pi]",
|
"lerobot[pi]",
|
||||||
"lerobot[smolvla]",
|
"lerobot[smolvla]",
|
||||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||||
|
"lerobot[xvla]",
|
||||||
"lerobot[hilserl]",
|
"lerobot[hilserl]",
|
||||||
"lerobot[async]",
|
"lerobot[async]",
|
||||||
"lerobot[dev]",
|
"lerobot[dev]",
|
||||||
|
|||||||
@@ -245,7 +245,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
|||||||
class LiberoEnv(EnvConfig):
|
class LiberoEnv(EnvConfig):
|
||||||
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
|
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
|
||||||
fps: int = 30
|
fps: int = 30
|
||||||
episode_length: int = 520
|
episode_length: int | None = None
|
||||||
obs_type: str = "pixels_agent_pos"
|
obs_type: str = "pixels_agent_pos"
|
||||||
render_mode: str = "rgb_array"
|
render_mode: str = "rgb_array"
|
||||||
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
|
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
|
||||||
@@ -272,6 +272,7 @@ class LiberoEnv(EnvConfig):
|
|||||||
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
|
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
control_mode: str = "relative" # or "absolute"
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.obs_type == "pixels":
|
if self.obs_type == "pixels":
|
||||||
|
|||||||
@@ -19,8 +19,10 @@ from typing import Any
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.envs.registration import registry as gym_registry
|
from gymnasium.envs.registration import registry as gym_registry
|
||||||
|
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
|
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
|
||||||
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
||||||
|
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||||
from lerobot.processor import ProcessorStep
|
from lerobot.processor import ProcessorStep
|
||||||
from lerobot.processor.env_processor import LiberoProcessorStep
|
from lerobot.processor.env_processor import LiberoProcessorStep
|
||||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||||
@@ -39,6 +41,7 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
|||||||
|
|
||||||
def make_env_pre_post_processors(
|
def make_env_pre_post_processors(
|
||||||
env_cfg: EnvConfig,
|
env_cfg: EnvConfig,
|
||||||
|
policy_cfg: PreTrainedConfig,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
@@ -61,6 +64,10 @@ def make_env_pre_post_processors(
|
|||||||
# Preprocessor and Postprocessor steps are Identity for most environments
|
# Preprocessor and Postprocessor steps are Identity for most environments
|
||||||
preprocessor_steps: list[ProcessorStep] = []
|
preprocessor_steps: list[ProcessorStep] = []
|
||||||
postprocessor_steps: list[ProcessorStep] = []
|
postprocessor_steps: list[ProcessorStep] = []
|
||||||
|
if isinstance(policy_cfg, XVLAConfig):
|
||||||
|
from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors
|
||||||
|
|
||||||
|
return make_xvla_libero_pre_post_processors()
|
||||||
|
|
||||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
||||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||||
@@ -136,6 +143,8 @@ def make_env(
|
|||||||
init_states=cfg.init_states,
|
init_states=cfg.init_states,
|
||||||
gym_kwargs=cfg.gym_kwargs,
|
gym_kwargs=cfg.gym_kwargs,
|
||||||
env_cls=env_cls,
|
env_cls=env_cls,
|
||||||
|
control_mode=cfg.control_mode,
|
||||||
|
episode_length=cfg.episode_length,
|
||||||
)
|
)
|
||||||
elif "metaworld" in cfg.type:
|
elif "metaworld" in cfg.type:
|
||||||
from lerobot.envs.metaworld import create_metaworld_envs
|
from lerobot.envs.metaworld import create_metaworld_envs
|
||||||
|
|||||||
@@ -80,10 +80,7 @@ def get_libero_dummy_action():
|
|||||||
return [0, 0, 0, 0, 0, 0, -1]
|
return [0, 0, 0, 0, 0, 0, -1]
|
||||||
|
|
||||||
|
|
||||||
OBS_STATE_DIM = 8
|
|
||||||
ACTION_DIM = 7
|
ACTION_DIM = 7
|
||||||
AGENT_POS_LOW = -1000.0
|
|
||||||
AGENT_POS_HIGH = 1000.0
|
|
||||||
ACTION_LOW = -1.0
|
ACTION_LOW = -1.0
|
||||||
ACTION_HIGH = 1.0
|
ACTION_HIGH = 1.0
|
||||||
TASK_SUITE_MAX_STEPS: dict[str, int] = {
|
TASK_SUITE_MAX_STEPS: dict[str, int] = {
|
||||||
@@ -103,6 +100,7 @@ class LiberoEnv(gym.Env):
|
|||||||
task_suite: Any,
|
task_suite: Any,
|
||||||
task_id: int,
|
task_id: int,
|
||||||
task_suite_name: str,
|
task_suite_name: str,
|
||||||
|
episode_length: int | None = None,
|
||||||
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
||||||
obs_type: str = "pixels",
|
obs_type: str = "pixels",
|
||||||
render_mode: str = "rgb_array",
|
render_mode: str = "rgb_array",
|
||||||
@@ -114,6 +112,7 @@ class LiberoEnv(gym.Env):
|
|||||||
episode_index: int = 0,
|
episode_index: int = 0,
|
||||||
camera_name_mapping: dict[str, str] | None = None,
|
camera_name_mapping: dict[str, str] | None = None,
|
||||||
num_steps_wait: int = 10,
|
num_steps_wait: int = 10,
|
||||||
|
control_mode: str = "relative",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.task_id = task_id
|
self.task_id = task_id
|
||||||
@@ -141,14 +140,19 @@ class LiberoEnv(gym.Env):
|
|||||||
self.camera_name_mapping = camera_name_mapping
|
self.camera_name_mapping = camera_name_mapping
|
||||||
self.num_steps_wait = num_steps_wait
|
self.num_steps_wait = num_steps_wait
|
||||||
self.episode_index = episode_index
|
self.episode_index = episode_index
|
||||||
|
self.episode_length = episode_length
|
||||||
# Load once and keep
|
# Load once and keep
|
||||||
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
|
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
|
||||||
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
||||||
|
|
||||||
self._env = self._make_envs_task(task_suite, self.task_id)
|
self._env = self._make_envs_task(task_suite, self.task_id)
|
||||||
default_steps = 500
|
default_steps = 500
|
||||||
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
self._max_episode_steps = (
|
||||||
|
TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
||||||
|
if self.episode_length is None
|
||||||
|
else self.episode_length
|
||||||
|
)
|
||||||
|
self.control_mode = control_mode
|
||||||
images = {}
|
images = {}
|
||||||
for cam in self.camera_name:
|
for cam in self.camera_name:
|
||||||
images[self.camera_name_mapping[cam]] = spaces.Box(
|
images[self.camera_name_mapping[cam]] = spaces.Box(
|
||||||
@@ -296,6 +300,15 @@ class LiberoEnv(gym.Env):
|
|||||||
# Increasing this value can improve determinism and reproducibility across resets.
|
# Increasing this value can improve determinism and reproducibility across resets.
|
||||||
for _ in range(self.num_steps_wait):
|
for _ in range(self.num_steps_wait):
|
||||||
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
|
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
|
||||||
|
|
||||||
|
if self.control_mode == "absolute":
|
||||||
|
for robot in self._env.robots:
|
||||||
|
robot.controller.use_delta = False
|
||||||
|
elif self.control_mode == "relative":
|
||||||
|
for robot in self._env.robots:
|
||||||
|
robot.controller.use_delta = True
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid control mode: {self.control_mode}")
|
||||||
observation = self._format_raw_obs(raw_obs)
|
observation = self._format_raw_obs(raw_obs)
|
||||||
info = {"is_success": False}
|
info = {"is_success": False}
|
||||||
return observation, info
|
return observation, info
|
||||||
@@ -341,8 +354,10 @@ def _make_env_fns(
|
|||||||
task_id: int,
|
task_id: int,
|
||||||
n_envs: int,
|
n_envs: int,
|
||||||
camera_names: list[str],
|
camera_names: list[str],
|
||||||
|
episode_length: int | None,
|
||||||
init_states: bool,
|
init_states: bool,
|
||||||
gym_kwargs: Mapping[str, Any],
|
gym_kwargs: Mapping[str, Any],
|
||||||
|
control_mode: str,
|
||||||
) -> list[Callable[[], LiberoEnv]]:
|
) -> list[Callable[[], LiberoEnv]]:
|
||||||
"""Build n_envs factory callables for a single (suite, task_id)."""
|
"""Build n_envs factory callables for a single (suite, task_id)."""
|
||||||
|
|
||||||
@@ -354,7 +369,9 @@ def _make_env_fns(
|
|||||||
task_suite_name=suite_name,
|
task_suite_name=suite_name,
|
||||||
camera_name=camera_names,
|
camera_name=camera_names,
|
||||||
init_states=init_states,
|
init_states=init_states,
|
||||||
|
episode_length=episode_length,
|
||||||
episode_index=episode_index,
|
episode_index=episode_index,
|
||||||
|
control_mode=control_mode,
|
||||||
**local_kwargs,
|
**local_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -374,6 +391,8 @@ def create_libero_envs(
|
|||||||
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
||||||
init_states: bool = True,
|
init_states: bool = True,
|
||||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||||
|
control_mode: str = "relative",
|
||||||
|
episode_length: int | None = None,
|
||||||
) -> dict[str, dict[int, Any]]:
|
) -> dict[str, dict[int, Any]]:
|
||||||
"""
|
"""
|
||||||
Create vectorized LIBERO environments with a consistent return shape.
|
Create vectorized LIBERO environments with a consistent return shape.
|
||||||
@@ -415,12 +434,14 @@ def create_libero_envs(
|
|||||||
for tid in selected:
|
for tid in selected:
|
||||||
fns = _make_env_fns(
|
fns = _make_env_fns(
|
||||||
suite=suite,
|
suite=suite,
|
||||||
|
episode_length=episode_length,
|
||||||
suite_name=suite_name,
|
suite_name=suite_name,
|
||||||
task_id=tid,
|
task_id=tid,
|
||||||
n_envs=n_envs,
|
n_envs=n_envs,
|
||||||
camera_names=camera_names,
|
camera_names=camera_names,
|
||||||
init_states=init_states,
|
init_states=init_states,
|
||||||
gym_kwargs=gym_kwargs,
|
gym_kwargs=gym_kwargs,
|
||||||
|
control_mode=control_mode,
|
||||||
)
|
)
|
||||||
out[suite_name][tid] = env_cls(fns)
|
out[suite_name][tid] = env_cls(fns)
|
||||||
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
|
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
|||||||
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
|
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
|
||||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||||
|
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ACTConfig",
|
"ACTConfig",
|
||||||
@@ -31,4 +32,5 @@ __all__ = [
|
|||||||
"TDMPCConfig",
|
"TDMPCConfig",
|
||||||
"VQBeTConfig",
|
"VQBeTConfig",
|
||||||
"GrootConfig",
|
"GrootConfig",
|
||||||
|
"XVLAConfig",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
|||||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
from lerobot.policies.utils import validate_visual_features_consistency
|
from lerobot.policies.utils import validate_visual_features_consistency
|
||||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||||
|
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||||
from lerobot.processor.converters import (
|
from lerobot.processor.converters import (
|
||||||
batch_to_transition,
|
batch_to_transition,
|
||||||
@@ -107,6 +108,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
|||||||
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
||||||
|
|
||||||
return GrootPolicy
|
return GrootPolicy
|
||||||
|
elif name == "xvla":
|
||||||
|
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
|
||||||
|
|
||||||
|
return XVLAPolicy
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||||
|
|
||||||
@@ -150,6 +155,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||||||
return RewardClassifierConfig(**kwargs)
|
return RewardClassifierConfig(**kwargs)
|
||||||
elif policy_type == "groot":
|
elif policy_type == "groot":
|
||||||
return GrootConfig(**kwargs)
|
return GrootConfig(**kwargs)
|
||||||
|
elif policy_type == "xvla":
|
||||||
|
return XVLAConfig(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||||
|
|
||||||
@@ -329,6 +336,15 @@ def make_pre_post_processors(
|
|||||||
config=policy_cfg,
|
config=policy_cfg,
|
||||||
dataset_stats=kwargs.get("dataset_stats"),
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
)
|
)
|
||||||
|
elif isinstance(policy_cfg, XVLAConfig):
|
||||||
|
from lerobot.policies.xvla.processor_xvla import (
|
||||||
|
make_xvla_pre_post_processors,
|
||||||
|
)
|
||||||
|
|
||||||
|
processors = make_xvla_pre_post_processors(
|
||||||
|
config=policy_cfg,
|
||||||
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
||||||
|
|||||||
6
src/lerobot/policies/xvla/__init__.py
Normal file
6
src/lerobot/policies/xvla/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# register the processor steps
|
||||||
|
from lerobot.policies.xvla.processor_xvla import (
|
||||||
|
XVLAAddDomainIdProcessorStep,
|
||||||
|
XVLAImageNetNormalizeProcessorStep,
|
||||||
|
XVLAImageToFloatProcessorStep,
|
||||||
|
)
|
||||||
454
src/lerobot/policies/xvla/action_hub.py
Normal file
454
src/lerobot/policies/xvla/action_hub.py
Normal file
@@ -0,0 +1,454 @@
|
|||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Copyright 2025 2toINF and HuggingFace Inc. (https://github.com/2toINF)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Registry
|
||||||
|
# =============================================================================
|
||||||
|
# Maps lower-cased action-space names to the classes registered under them.
# Entries are added by the `register_action` decorator and looked up by
# `build_action_space`.
ACTION_REGISTRY: dict[str, type[BaseActionSpace]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_action(name: str):
    """Build a class decorator that files an action space under `name`.

    The lower-cased name becomes the key in `ACTION_REGISTRY` and is also
    written to the decorated class as its `name` attribute.

    Raises:
        KeyError: when the lower-cased name is already present in the registry
            (raised at decoration time, not when `register_action` is called).
    """

    def _decorate(action_cls):
        key = name.lower()
        already_taken = key in ACTION_REGISTRY
        if already_taken:
            raise KeyError(f"ActionSpace '{key}' already registered -> {ACTION_REGISTRY[key]}")

        # Record the class first, then stamp it with its canonical name.
        ACTION_REGISTRY[key] = action_cls
        action_cls.name = key
        return action_cls

    return _decorate
|
||||||
|
|
||||||
|
|
||||||
|
def build_action_space(name: str, **kwargs) -> BaseActionSpace:
    """Look up a registered action space (case-insensitively) and construct it.

    Args:
        name: registry name of the action space; lower-cased before lookup.
        **kwargs: forwarded unchanged to the action-space constructor.

    Raises:
        KeyError: when no action space is registered under `name`.
    """
    try:
        space_cls = ACTION_REGISTRY[name.lower()]
    except KeyError:
        raise KeyError(f"Unknown action space '{name}'. Available: {list(ACTION_REGISTRY.keys())}") from None
    # Construct outside the try so a KeyError raised by the constructor is not relabelled.
    return space_cls(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Base class
|
||||||
|
# =============================================================================
|
||||||
|
class BaseActionSpace(nn.Module):
    """
    Common parent of every action-space definition.

    Class-level attributes a subclass is expected to set:
        - `name`: registry name (overwritten by `register_action`).
        - `dim_action`: size of the action vector.
        - `gripper_idx`: indices of the gripper channels.

    Methods a subclass may override:
        - `compute_loss(pred, target)`: supervised loss; the base version
          raises `NotImplementedError`.
        - `preprocess(proprio, action, mode)`: pre-step hook; identity here.
        - `postprocess(action)`: post-step hook; identity here.
    """

    name: str = "base"
    dim_action: int = 0
    gripper_idx: tuple[int, ...] = ()

    def __init__(self):
        super().__init__()

    # ---------------------------------------------------------------------
    # Space-level hooks (identity by default)
    # ---------------------------------------------------------------------
    def preprocess(
        self,
        proprio: torch.Tensor,
        action: torch.Tensor,
        mode: str = "train",
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Hand back `(proprio, action)` untouched; `mode` is ignored here."""
        return proprio, action

    def postprocess(self, action: torch.Tensor) -> torch.Tensor:
        """Hand back `action` untouched."""
        return action

    # ---------------------------------------------------------------------
    # Supervised loss
    # ---------------------------------------------------------------------
    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
        """Subclasses must implement this; the base class has no loss."""
        raise NotImplementedError

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
        """Calling the module is the same as calling `compute_loss`."""
        return self.compute_loss(pred, target)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Utilities
|
||||||
|
# =============================================================================
|
||||||
|
def _ensure_indices_valid(dim_action: int, idx: Iterable[int], name: str) -> None:
|
||||||
|
bad = [i for i in idx if i < 0 or i >= dim_action]
|
||||||
|
if bad:
|
||||||
|
raise IndexError(f"{name} contains out-of-range indices {bad} for action dim dim_action={dim_action}")
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Implementations
|
||||||
|
# =============================================================================
|
||||||
|
@register_action("ee6d")
class EE6DActionSpace(BaseActionSpace):
    """End-effector layout with xyz, 6D rotation, and gripper channels.

    The 20 channels form two 10-channel groups (presumably one per arm —
    TODO confirm): 3 position channels, 6 rotation channels and 1 gripper
    channel each. Position and rotation are trained with MSE, the gripper
    channels with BCE-with-logits.
    """

    dim_action = 20
    gripper_idx = (9, 19)

    # Per-component loss weights.
    GRIPPER_SCALE = 1.0
    XYZ_SCALE = 500.0
    ROT_SCALE = 10.0

    # Channel indices of each component in the two groups.
    POS_IDX_1 = (0, 1, 2)
    POS_IDX_2 = (10, 11, 12)
    ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
    ROT_IDX_2 = (13, 14, 15, 16, 17, 18)

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()

    def compute_loss(self, pred, target):
        """Return the position, rotation and gripper losses as a dict.

        Args:
            pred: 3-D tensor of predictions; gripper channels are treated as logits.
            target: tensor with the same shape as `pred`.

        Returns:
            Dict with keys "position_loss", "rotate6D_loss" and "gripper_loss".

        Raises:
            AssertionError: when the shapes differ.
            IndexError: when the last dimension is too small for `gripper_idx`.
        """
        assert pred.shape == target.shape, "pred/target shapes must match"
        # Unpacking three values also enforces that `pred` is 3-D.
        _, _, action_dim = pred.shape
        _ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")

        # Gripper BCE, averaged over the gripper channels.
        g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
        gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE

        # XYZ position
        pos_loss = (
            self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
            + self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
        ) * self.XYZ_SCALE

        # Rotation 6D
        rot_loss = (
            self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
            + self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
        ) * self.ROT_SCALE

        return {
            "position_loss": pos_loss,
            "rotate6D_loss": rot_loss,
            "gripper_loss": gripper_loss,
        }

    def preprocess(self, proprio, action, mode="train"):
        """Return copies of `proprio` and `action` with the gripper channels zeroed.

        The inputs are cloned, so the caller's tensors are left untouched.
        `mode` is accepted for interface compatibility and not used.
        """
        proprio_m = proprio.clone()
        action_m = action.clone()
        proprio_m[..., self.gripper_idx] = 0.0
        action_m[..., self.gripper_idx] = 0.0
        return proprio_m, action_m

    def postprocess(self, action: torch.Tensor) -> torch.Tensor:
        """Apply sigmoid to the gripper logits and return the result.

        When the last dimension is too small to contain every gripper index,
        `action` is returned as-is. Otherwise a new tensor is returned; the
        input is not modified, mirroring the clone-first behaviour of
        `preprocess`.
        """
        if action.size(-1) <= max(self.gripper_idx):
            return action
        # Clone so the caller's tensor (e.g. raw model output) keeps its logits.
        result = action.clone()
        result[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
        return result
|
||||||
|
|
||||||
|
|
||||||
|
@register_action("joint")
class JointActionSpace(BaseActionSpace):
    """Joint-space layout with joints + gripper only.

    Of the 14 channels, indices 6 and 13 are gripper channels (trained with
    BCE-with-logits); every other channel is a joint trained with MSE.
    """

    dim_action = 14
    gripper_idx = (6, 13)

    # Per-component loss weights.
    GRIPPER_SCALE = 0.1
    JOINTS_SCALE = 1.0

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()

    def compute_loss(self, pred, target):
        """Return the joint and gripper losses as a dict.

        Args:
            pred: 3-D tensor of predictions; gripper channels are treated as logits.
            target: tensor with the same shape as `pred`.

        Returns:
            Dict with keys "joints_loss" and "gripper_loss".

        Raises:
            AssertionError: when the shapes differ.
            IndexError: when the last dimension is too small for `gripper_idx`.
        """
        assert pred.shape == target.shape
        # Unpacking three values also enforces that `pred` is 3-D.
        _, _, action_dim = pred.shape
        _ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")

        # Gripper BCE, averaged over the gripper channels.
        g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
        gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE

        # Every non-gripper channel is a joint; build the lookup set once.
        gripper_set = set(self.gripper_idx)
        joints_idx = tuple(i for i in range(action_dim) if i not in gripper_set)
        joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE

        return {
            "joints_loss": joints_loss,
            "gripper_loss": gripper_loss,
        }

    def preprocess(self, proprio, action, mode="train"):
        """Return copies of `proprio` and `action` with the gripper channels zeroed.

        The inputs are cloned, so the caller's tensors are left untouched.
        `mode` is accepted for interface compatibility and not used.
        """
        proprio_m = proprio.clone()
        action_m = action.clone()
        proprio_m[..., self.gripper_idx] = 0.0
        action_m[..., self.gripper_idx] = 0.0
        return proprio_m, action_m

    def postprocess(self, action: torch.Tensor) -> torch.Tensor:
        """Apply sigmoid to the gripper logits and return the result.

        When the last dimension is too small to contain every gripper index,
        `action` is returned as-is. Otherwise a new tensor is returned; the
        input is not modified, mirroring the clone-first behaviour of
        `preprocess`.
        """
        if action.size(-1) <= max(self.gripper_idx):
            return action
        # Clone so the caller's tensor (e.g. raw model output) keeps its logits.
        result = action.clone()
        result[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
        return result
|
||||||
|
|
||||||
|
|
||||||
|
@register_action("agibot_ee6d")
class AGIBOTEE6DActionSpace(BaseActionSpace):
    """AGI-bot flavour of the EE6D layout: every component, gripper included, uses MSE."""

    dim_action = 20
    gripper_idx = (9, 19)

    # Per-component loss weights.
    GRIPPER_SCALE = 10.0
    XYZ_SCALE = 500.0
    ROT_SCALE = 10.0

    # Channel indices of each component in the two 10-channel groups.
    POS_IDX_1 = (0, 1, 2)
    POS_IDX_2 = (10, 11, 12)
    ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
    ROT_IDX_2 = (13, 14, 15, 16, 17, 18)

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def compute_loss(self, pred, target):
        """Return weighted MSE losses for position, rotation and gripper channels.

        `pred` and `target` must be 3-D tensors of identical shape. The result
        is a dict with keys "position_loss", "rotate6D_loss" and "gripper_loss".
        """
        assert pred.shape == target.shape
        _batch, _steps, action_dim = pred.shape
        _ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")

        def _mse_over(channels):
            # MSE restricted to the given channels of the last dimension.
            return self.mse(pred[:, :, channels], target[:, :, channels])

        gripper_term = _mse_over(self.gripper_idx) * self.GRIPPER_SCALE
        position_term = (_mse_over(self.POS_IDX_1) + _mse_over(self.POS_IDX_2)) * self.XYZ_SCALE
        rotation_term = (_mse_over(self.ROT_IDX_1) + _mse_over(self.ROT_IDX_2)) * self.ROT_SCALE

        return {
            "position_loss": position_term,
            "rotate6D_loss": rotation_term,
            "gripper_loss": gripper_term,
        }

    def preprocess(self, proprio, action, mode="train"):
        """Identity: the AGIBOT variant leaves proprio and action as they are."""
        return proprio, action

    def postprocess(self, action: torch.Tensor) -> torch.Tensor:
        """Identity: the AGIBOT variant returns the action unchanged."""
        return action
|
||||||
|
|
||||||
|
|
||||||
|
@register_action("franka_joint7")
class FrankaJoint7ActionSpace(BaseActionSpace):
    """Joint-space layout for a Franka Panda: seven joint channels and no gripper."""

    dim_action = 7
    JOINTS_SCALE = 1.0  # weight applied to the joint MSE

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def compute_loss(self, pred, target):
        """Return `{"joints_loss": ...}`: the scaled MSE over all channels.

        `pred` and `target` must have identical shapes.
        """
        assert pred.shape == target.shape, "pred/target shapes must match"
        scaled_mse = self.mse(pred, target) * self.JOINTS_SCALE
        return {"joints_loss": scaled_mse}

    def preprocess(self, proprio, action, mode="train"):
        """Identity: nothing to adjust for plain joint actions."""
        return proprio, action

    def postprocess(self, action: torch.Tensor) -> torch.Tensor:
        """Identity: there is no gripper channel, so no sigmoid is applied."""
        return action
|
||||||
|
|
||||||
|
|
||||||
|
@register_action("so101_bimanual")
class BimanualSO101ActionSpace(BaseActionSpace):
    """
    Bimanual SO101 robot: 2 arms with 5 joints each + gripper.

    Layout (real robot):
        [left_arm (5 joints + gripper), right_arm (5 joints + gripper)]
        - Left arm:  shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
        - Right arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper

    Real action dim: 12
    Model-facing dim: 20 (extra 8 dummy dims at the end)
    """

    # Model output / training dimension (to match pretrained policy)
    dim_action = 20

    # Real robot action dimension
    REAL_DIM = 12

    # Indices of real vs dummy channels
    REAL_IDXS = tuple(range(REAL_DIM))  # 0..11
    DUMMY_IDXS = tuple(range(REAL_DIM, dim_action))  # 12..19

    # Grippers live in the real part
    gripper_idx = (5, 11)  # left_gripper at idx 5, right_gripper at idx 11
    GRIPPER_SCALE = 1.0
    JOINTS_SCALE = 1.0

    # Indices for left and right arm joints (excluding grippers)
    LEFT_ARM_JOINTS = (0, 1, 2, 3, 4)
    RIGHT_ARM_JOINTS = (6, 7, 8, 9, 10)

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        # Kept for compatibility; compute_loss in this class only uses MSE.
        self.bce = nn.BCEWithLogitsLoss()

    # ---------- helpers ----------

    def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
        """Zero-pad the last dim from REAL_DIM (12) up to dim_action (20).

        Returns ``None`` for ``None`` and returns ``x`` itself when it is
        already ``dim_action`` wide.

        Raises:
            ValueError: if the last dim is neither REAL_DIM nor dim_action.
        """
        if x is None:
            return None
        width = x.size(-1)
        if width == self.dim_action:
            return x
        if width != self.REAL_DIM:
            raise ValueError(
                f"Expected last dim to be {self.REAL_DIM} or {self.dim_action}, got {x.size(-1)}"
            )
        pad_shape = list(x.shape[:-1]) + [self.dim_action - self.REAL_DIM]
        # new_zeros keeps the dtype and device of x.
        pad = x.new_zeros(pad_shape)
        return torch.cat([x, pad], dim=-1)

    def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
        """Keep only the first REAL_DIM (12) dims for the real robot."""
        return x[..., : self.REAL_DIM]

    # ---------- loss ----------

    def compute_loss(self, pred, target):
        """Compute MSE losses on the real (non-dummy) channels.

        Args:
            pred: model output whose last dim is 12 or 20.
            target: ground truth whose last dim is 12 or 20.

        Both inputs are padded to 20 channels; the dummy channels never enter
        any loss term.

        Returns:
            dict with ``joints_loss`` (all real dims, scaled by JOINTS_SCALE),
            ``gripper_loss`` (gripper channels only, scaled by GRIPPER_SCALE),
            and unscaled ``left_arm_loss`` / ``right_arm_loss``, each covering
            one arm's block including its gripper channel.

        Raises:
            ValueError: if the padded shapes differ or a last dim is invalid.
        """
        pred = self._pad_to_model_dim(pred)
        target = self._pad_to_model_dim(target)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if pred.shape != target.shape:
            raise ValueError(
                f"pred and target shapes differ: {tuple(pred.shape)} vs {tuple(target.shape)}"
            )

        real_dims = self.REAL_DIM
        # The left arm block ends right after the left gripper channel.
        left_end = self.gripper_idx[0] + 1
        grippers = list(self.gripper_idx)

        # ---- MSE for all real dims ----
        joints_loss = self.mse(pred[..., :real_dims], target[..., :real_dims]) * self.JOINTS_SCALE

        left_arm_loss = self.mse(pred[..., :left_end], target[..., :left_end])
        right_arm_loss = self.mse(
            pred[..., left_end:real_dims],
            target[..., left_end:real_dims],
        )

        gripper_loss = self.mse(pred[..., grippers], target[..., grippers]) * self.GRIPPER_SCALE

        return {
            "joints_loss": joints_loss,
            "gripper_loss": gripper_loss,
            "left_arm_loss": left_arm_loss,
            "right_arm_loss": right_arm_loss,
        }

    # ---------- preprocess / postprocess ----------

    def preprocess(self, proprio, action, mode="train"):
        """Prepare proprio/action for the model without touching the inputs.

        - If proprio/action are 12-dim, pad them to 20 for the model.
        - Zero-out gripper channels in proprio/action to focus learning on joints.

        ``action`` may be ``None``; ``mode`` is accepted but not used here.

        Returns:
            Tuple ``(proprio, action)`` of new tensors (action may be ``None``).
        """
        grippers = list(self.gripper_idx)

        # Clone first so the in-place zeroing below never reaches caller data.
        proprio_m = self._pad_to_model_dim(proprio.clone())
        proprio_m[..., grippers] = 0.0

        action_m = None
        if action is not None:
            action_m = self._pad_to_model_dim(action.clone())
            action_m[..., grippers] = 0.0

        return proprio_m, action_m

    def postprocess(self, action: torch.Tensor) -> torch.Tensor:
        """Convert a model-space action into the 12-dim real-robot command.

        - Model outputs [*, 20] (anything with at least 12 channels is accepted)
        - Apply sigmoid to the gripper channels
        - Return only the first 12 dims for the real robot:
            ["left_shoulder_pan.pos",
             "left_shoulder_lift.pos",
             "left_elbow_flex.pos",
             "left_wrist_flex.pos",
             "left_wrist_roll.pos",
             "left_gripper.pos",
             "right_shoulder_pan.pos",
             "right_shoulder_lift.pos",
             "right_elbow_flex.pos",
             "right_wrist_flex.pos",
             "right_wrist_roll.pos",
             "right_gripper.pos"]

        The input tensor is left unmodified; a new tensor is returned.

        Raises:
            ValueError: if the action has fewer than REAL_DIM channels.
        """
        if action.size(-1) < self.REAL_DIM:
            raise ValueError(f"Expected at least {self.REAL_DIM} dims in action, got {action.size(-1)}")

        # Work on a copy of the real slice: writing the sigmoid into ``action``
        # itself would silently alter the caller's tensor.
        real_action = self._trim_to_real_dim(action).clone()

        # NOTE(review): compute_loss trains the gripper channels with MSE and
        # preprocess zeroes them, yet a sigmoid is applied here — confirm this
        # is the intended gripper handling.
        grippers = list(self.gripper_idx)
        real_action[..., grippers] = torch.sigmoid(real_action[..., grippers])

        return real_action
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Exports
|
||||||
|
# =============================================================================
|
||||||
|
# Public names of this module, in the order they are exported.
__all__ = [
    "BaseActionSpace",
    "build_action_space",
    "register_action",
    "EE6DActionSpace",
    "JointActionSpace",
    "AGIBOTEE6DActionSpace",
    "FrankaJoint7ActionSpace",
    "BimanualSO101ActionSpace",
    "ACTION_REGISTRY",
]
|
||||||
353
src/lerobot/policies/xvla/configuration_florence2.py
Normal file
353
src/lerobot/policies/xvla/configuration_florence2.py
Normal file
@@ -0,0 +1,353 @@
|
|||||||
|
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
""" Florence-2 configuration"""
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Florence2VisionConfig(PretrainedConfig):
    r"""
    Configuration of the Florence-2 vision model (``model_type = "davit"``).

    Every argument is stored as an attribute of the same name; remaining keyword
    arguments are forwarded to [`PretrainedConfig`]. Arguments left as ``None`` are
    replaced by the defaults listed below (a fresh list/dict is created per instance).

    Args:
        drop_path_rate (`float`, *optional*, defaults to 0.1):
            The dropout rate of the drop path layer.
        patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
            The patch size of the image.
        patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
            The patch stride of the image.
        patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
            The patch padding of the image.
        patch_prenorm (`List[bool]`, *optional*, defaults to [False, True, True, True]):
            Whether to apply layer normalization before the patch embedding layer.
        enable_checkpoint (`bool`, *optional*, defaults to False):
            Whether to enable checkpointing.
        dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
            The dimension of the embedding layer.
        num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
            The number of attention heads.
        num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
            The number of groups.
        depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
            The depth of the model.
        window_size (`int`, *optional*, defaults to 12):
            The window size of the model.
        projection_dim (`int`, *optional*, defaults to 1024):
            The dimension of the projection layer.
        visual_temporal_embedding (`dict`, *optional*):
            The configuration of the visual temporal embedding. Defaults to
            `{"type": "COSINE", "max_temporal_embeddings": 100}`.
        image_pos_embed (`dict`, *optional*):
            The configuration of the image position embedding. Defaults to
            `{"type": "learned_abs_2d", "max_pos_embeddings": 1000}`.
        image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
            The source of the image feature.

    Example:

    ```python
    >>> # Initializing a Florence2 Vision style configuration with default values
    >>> configuration = Florence2VisionConfig()

    >>> # Reading back a stored hyper-parameter
    >>> configuration.window_size
    12
    ```"""

    model_type = "davit"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        drop_path_rate=0.1,
        patch_size=None,
        patch_stride=None,
        patch_padding=None,
        patch_prenorm=None,
        enable_checkpoint=False,
        dim_embed=None,
        num_heads=None,
        num_groups=None,
        depths=None,
        window_size=12,
        projection_dim=1024,
        visual_temporal_embedding=None,
        image_pos_embed=None,
        image_feature_source=None,
        **kwargs,
    ):
        # Fill in defaults for every argument the caller left as None.
        if patch_size is None:
            patch_size = [7, 3, 3, 3]
        if patch_stride is None:
            patch_stride = [4, 2, 2, 2]
        if patch_padding is None:
            patch_padding = [3, 1, 1, 1]
        if patch_prenorm is None:
            patch_prenorm = [False, True, True, True]
        if dim_embed is None:
            dim_embed = [256, 512, 1024, 2048]
        if num_heads is None:
            num_heads = [8, 16, 32, 64]
        if num_groups is None:
            num_groups = [8, 16, 32, 64]
        if depths is None:
            depths = [1, 1, 9, 1]
        if image_feature_source is None:
            image_feature_source = ["spatial_avg_pool", "temporal_avg_pool"]

        self.drop_path_rate = drop_path_rate
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.patch_padding = patch_padding
        self.patch_prenorm = patch_prenorm
        self.enable_checkpoint = enable_checkpoint
        self.dim_embed = dim_embed
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.depths = depths
        self.window_size = window_size
        self.projection_dim = projection_dim

        self.visual_temporal_embedding = (
            visual_temporal_embedding
            if visual_temporal_embedding is not None
            else {"type": "COSINE", "max_temporal_embeddings": 100}
        )

        self.image_pos_embed = (
            image_pos_embed
            if image_pos_embed is not None
            else {"type": "learned_abs_2d", "max_pos_embeddings": 1000}
        )

        self.image_feature_source = image_feature_source

        super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Florence2LanguageConfig(PretrainedConfig):
    r"""
    Configuration of the Florence-2 language model, a BART-style encoder-decoder.

    The defaults yield a configuration similar to the BART
    [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
    Every argument is stored as an attribute; the token ids, `num_labels`,
    `is_encoder_decoder` and any extra keyword arguments are forwarded to
    [`PretrainedConfig`].

    Args:
        vocab_size (`int`, *optional*, defaults to 51289):
            Vocabulary size of the Florence2Language model. Defines the number of different
            tokens that can be represented by the `inputs_ids` passed when calling
            [`Florence2LanguageModel`].
        d_model (`int`, *optional*, defaults to 1024):
            Dimensionality of the layers and the pooler layer.
        encoder_layers (`int`, *optional*, defaults to 12):
            Number of encoder layers. Also stored as `num_hidden_layers`.
        decoder_layers (`int`, *optional*, defaults to 12):
            Number of decoder layers.
        encoder_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        decoder_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler.
            If string, `"gelu"`, `"relu"`, `"silu"` and `"gelu_new"` are supported.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder,
            and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        classifier_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the classifier.
        max_position_embeddings (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set
            this to something large just in case (e.g., 512 or 1024 or 2048).
        init_std (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all
            weight matrices.
        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the encoder. See the
            [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more details.
        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the decoder. See the
            [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more details.
        scale_embedding (`bool`, *optional*, defaults to `False`):
            When `True`, embeddings are scaled with a factor of sqrt(d_model).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by
            all models).
        num_labels (`int`, *optional*, defaults to 3):
            The number of labels to use in [`Florence2LanguageForSequenceClassification`].
        forced_eos_token_id (`int`, *optional*, defaults to 2):
            The id of the token to force as the last generated token when `max_length` is
            reached. Usually set to `eos_token_id`.

    Example:

    ```python
    >>> # Initializing a Florence2 Language style configuration with default values
    >>> configuration = Florence2LanguageConfig()

    >>> # Reading back a stored hyper-parameter
    >>> configuration.d_model
    1024
    ```"""

    model_type = "florence2_language"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

    def __init__(
        self,
        vocab_size=51289,
        max_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
        scale_embedding=False,
        use_cache=True,
        num_labels=3,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        is_encoder_decoder=True,
        decoder_start_token_id=2,
        forced_eos_token_id=2,
        **kwargs,
    ):
        # Sizes shared by encoder and decoder.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model

        # Encoder stack.
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads

        # Decoder stack.
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads

        # Regularisation, activation and initialisation.
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.classifier_dropout = classifier_dropout

        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True

        # Arguments owned by the base class are handed over together with the extras.
        base_arguments = {
            "num_labels": num_labels,
            "pad_token_id": pad_token_id,
            "bos_token_id": bos_token_id,
            "eos_token_id": eos_token_id,
            "is_encoder_decoder": is_encoder_decoder,
            "decoder_start_token_id": decoder_start_token_id,
            "forced_eos_token_id": forced_eos_token_id,
        }
        super().__init__(**base_arguments, **kwargs)

        # ensure backward compatibility for BART CNN models
        bos_not_forced = self.forced_bos_token_id is None
        legacy_flag = kwargs.get("force_bos_token_to_be_generated", False)
        if bos_not_forced and legacy_flag:
            self.forced_bos_token_id = self.bos_token_id
            warnings.warn(
                f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
                "The config can simply be saved and uploaded again to be fixed.",
                stacklevel=2,
            )
|
||||||
|
|
||||||
|
|
||||||
|
class Florence2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`Florence2ForConditionalGeneration`]. It is used to instantiate a Florence-2 model
    according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the
    model outputs. Read the documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`Florence2VisionConfig` or `dict`, *optional*):
            Vision config. A `dict` is expanded into `Florence2VisionConfig(**vision_config)`;
            a `Florence2VisionConfig` instance is stored as given; `None` is stored as `None`.
        text_config (`Florence2LanguageConfig` or `dict`, *optional*):
            Config of the text backbone. A `dict` is expanded into
            `Florence2LanguageConfig(**text_config)`; a `Florence2LanguageConfig` instance is
            stored as given; `None` is stored as `None`.
        ignore_index (`int`, *optional*, defaults to -100):
            The ignore index for the loss function.
        vocab_size (`int`, *optional*, defaults to 51289):
            Vocabulary size of the Florence2 model. Defines the number of different tokens that
            can be represented by the `inputs_ids` passed when calling
            [`~Florence2ForConditionalGeneration`]
        projection_dim (`int`, *optional*, defaults to 1024):
            Dimension of the multimodal projection space.

    Example:

    ```python
    >>> # Sub-configs given as dicts (empty dicts select the defaults)
    >>> configuration = Florence2Config(vision_config={}, text_config={})

    >>> # Sub-configs given as already-built config objects
    >>> configuration = Florence2Config(
    ...     vision_config=Florence2VisionConfig(),
    ...     text_config=Florence2LanguageConfig(),
    ... )
    ```"""

    model_type = "florence2"
    is_composition = False

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        ignore_index=-100,
        vocab_size=51289,
        projection_dim=1024,
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.vocab_size = vocab_size
        self.projection_dim = projection_dim

        # Only mappings are expanded with ``**``. An already-built sub-config is
        # kept untouched, because expanding it with ``**`` would raise a TypeError.
        if vision_config is not None and not isinstance(vision_config, Florence2VisionConfig):
            vision_config = Florence2VisionConfig(**vision_config)
        self.vision_config = vision_config

        if text_config is not None and not isinstance(text_config, Florence2LanguageConfig):
            text_config = Florence2LanguageConfig(**text_config)
        self.text_config = text_config

        super().__init__(**kwargs)
|
||||||
190
src/lerobot/policies/xvla/configuration_xvla.py
Normal file
190
src/lerobot/policies/xvla/configuration_xvla.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
|
from lerobot.optim.optimizers import AdamWConfig
|
||||||
|
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||||
|
from lerobot.utils.constants import OBS_IMAGES
|
||||||
|
|
||||||
|
# Conditional import for type checking and lazy loading
|
||||||
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
|
if TYPE_CHECKING or _transformers_available:
|
||||||
|
from .configuration_florence2 import Florence2Config
|
||||||
|
else:
|
||||||
|
Florence2Config = None
|
||||||
|
|
||||||
|
|
||||||
|
@PreTrainedConfig.register_subclass("xvla")
@dataclass
class XVLAConfig(PreTrainedConfig):
    """
    Configuration class for the XVLA (Extended Vision-Language-Action) policy so it can
    plug into the LeRobot training stack.

    The config mirrors the knobs exposed in the original XVLA repository but also
    declares the input/output feature contract required by LeRobot.
    """

    # Input / output structure
    n_obs_steps: int = 1
    chunk_size: int = 32
    n_action_steps: int = 32
    dtype: str = "float32"  # Options: "bfloat16", "float32"

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.IDENTITY,
            "ACTION": NormalizationMode.MEAN_STD,
        }
    )

    # Florence2 backbone and tokenizer configuration
    florence_config: dict[str, Any] = field(default_factory=dict)
    tokenizer_name: str = "facebook/bart-large"
    tokenizer_max_length: int = 64
    tokenizer_padding_side: str = "right"
    pad_language_to: str = "max_length"

    # Transformer head
    hidden_size: int = 1024
    depth: int = 24
    num_heads: int = 16
    mlp_ratio: float = 4.0
    num_domains: int = 30
    len_soft_prompts: int = 32
    dim_time: int = 32
    max_len_seq: int = 512
    use_hetero_proj: bool = False

    # Action & proprioception
    action_mode: str = "ee6d"
    num_denoising_steps: int = 10
    use_proprio: bool = True
    max_state_dim: int = 32
    domain_feature_key: str | None = None

    # Vision preprocessing
    resize_imgs_with_padding: tuple[int, int] | None = None
    num_image_views: int | None = None
    empty_cameras: int = 0

    # Freezing options for VLM components
    # By default, VLM encoders are frozen and only policy transformer + soft prompts train
    freeze_vision_encoder: bool = True  # Freeze VLM vision encoder weights
    freeze_language_encoder: bool = True  # Freeze VLM language encoder weights
    train_policy_transformer: bool = True  # Allow policy transformer to train
    train_soft_prompts: bool = True  # Allow soft prompts to train

    # Training presets
    optimizer_lr: float = 1e-4
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-4
    optimizer_grad_clip_norm: float = 10.0

    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    def __post_init__(self) -> None:
        """Validate the scalar settings and reset the Florence2 config cache.

        Raises:
            ValueError: on a non-positive `chunk_size`, `n_action_steps` larger than
                `chunk_size`, a non-positive `num_image_views`, or an unknown `dtype`.
        """
        super().__post_init__()

        if self.chunk_size <= 0:
            raise ValueError("`chunk_size` must be strictly positive.")
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"`n_action_steps` ({self.n_action_steps}) must be <= `chunk_size` ({self.chunk_size})."
            )
        if self.num_image_views is not None and self.num_image_views <= 0:
            raise ValueError("`num_image_views` must be > 0 when specified.")
        if self.dtype not in ["bfloat16", "float32"]:
            raise ValueError(f"Invalid dtype: {self.dtype}")
        # Filled lazily by get_florence_config().
        self._florence_config_obj: Florence2Config | None = None

    def get_florence_config(self) -> Florence2Config:
        """
        Build (and cache) the Florence2 transformer config that should back the VLM.

        Raises:
            ImportError: if `Florence2Config` could not be imported (the module-level
                fallback set it to ``None``).
            ValueError: if `florence_config` lacks `vision_config` or `text_config`.
        """
        # getattr guards against instances on which __post_init__ never ran.
        cached = getattr(self, "_florence_config_obj", None)
        if cached is not None:
            return cached

        # Fail with a readable message instead of "'NoneType' object is not callable".
        if Florence2Config is None:
            raise ImportError(
                "Florence2Config is unavailable: XVLA needs the `transformers` package to build it."
            )

        config_dict = dict(self.florence_config)
        if "vision_config" not in config_dict or config_dict["vision_config"] is None:
            raise ValueError("vision_config is required")

        if "text_config" not in config_dict or config_dict["text_config"] is None:
            raise ValueError("text_config is required")

        self._florence_config_obj = Florence2Config(**config_dict)
        return self._florence_config_obj

    def validate_features(self) -> None:
        """Check the feature contract and register placeholder camera features.

        Requires at least one image feature, and a robot state feature when
        `use_proprio` is set. `num_image_views` is raised to at least
        ``len(image_features) + empty_cameras``, and one `(3, H, W)` visual feature
        per empty camera is added to `input_features` (H, W default to 480, 640 unless
        `resize_imgs_with_padding` is set).

        Raises:
            ValueError: if a required feature is missing.
        """
        if not self.image_features:
            raise ValueError("XVLA requires at least one visual feature in the inputs.")
        if self.use_proprio and self.robot_state_feature is None:
            raise ValueError("`use_proprio=True` requires a proprioceptive state feature.")

        # NOTE(review): if `image_features` is derived from `input_features`, a second
        # call would count the placeholders added below again — confirm.
        expected_views = len(self.image_features) + self.empty_cameras
        if self.num_image_views is None:
            self.num_image_views = expected_views
        else:
            self.num_image_views = max(self.num_image_views, expected_views)

        if self.empty_cameras > 0:
            height, width = (480, 640)
            if self.resize_imgs_with_padding is not None:
                height, width = self.resize_imgs_with_padding
            for idx in range(self.empty_cameras):
                key = f"{OBS_IMAGES}.empty_camera_{idx}"
                # Existing entries are left as they are.
                if key not in self.input_features:
                    self.input_features[key] = PolicyFeature(
                        type=FeatureType.VISUAL,
                        shape=(3, height, width),
                    )

    def get_optimizer_preset(self) -> AdamWConfig:
        """Return an AdamW config built from the `optimizer_*` fields."""
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
            grad_clip_norm=self.optimizer_grad_clip_norm,
        )

    def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
        """Return a cosine-decay-with-warmup config; its peak LR is `optimizer_lr`."""
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )

    @property
    def observation_delta_indices(self) -> list[int] | None:
        """Always ``None`` for this config."""
        return None

    @property
    def action_delta_indices(self) -> list[int]:
        """Indices ``0 .. chunk_size - 1``."""
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> list[int] | None:
        """Always ``None`` for this config."""
        return None
|
||||||
2754
src/lerobot/policies/xvla/modeling_florence2.py
Normal file
2754
src/lerobot/policies/xvla/modeling_florence2.py
Normal file
File diff suppressed because it is too large
Load Diff
526
src/lerobot/policies/xvla/modeling_xvla.py
Normal file
526
src/lerobot/policies/xvla/modeling_xvla.py
Normal file
@@ -0,0 +1,526 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import builtins
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from collections import deque
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F # noqa: N812
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||||
|
from lerobot.policies.utils import populate_queues
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||||
|
|
||||||
|
from .action_hub import build_action_space
|
||||||
|
from .configuration_florence2 import Florence2Config
|
||||||
|
from .configuration_xvla import XVLAConfig
|
||||||
|
from .modeling_florence2 import Florence2ForConditionalGeneration
|
||||||
|
from .soft_transformer import SoftPromptedTransformer
|
||||||
|
|
||||||
|
|
||||||
|
class XVLAModel(nn.Module):
    """
    XVLA backbone that stitches Florence-2 embeddings with the temporal/action transformer head.

    The Florence-2 model (``self.vlm``) encodes the language tokens together with the
    image views; its encoder output and the remaining image features are passed to a
    ``SoftPromptedTransformer`` (``self.transformer``) that predicts an action chunk
    from a noisy action, a timestep, the proprioceptive state and a domain id.
    """

    def __init__(
        self,
        config: XVLAConfig,
        florence_config: Florence2Config,
        proprio_dim: int,
    ) -> None:
        """Build the VLM encoder and the action transformer, then freeze and cast them.

        Args:
            config: XVLA policy configuration (chunk size, action mode, transformer
                hyper-parameters, freezing flags and dtype are read from it).
            florence_config: Configuration used to instantiate the Florence-2 model.
            proprio_dim: Size of the proprioceptive input given to the transformer.

        Raises:
            ValueError: If the Florence-2 config exposes no ``projection_dim``.
        """
        super().__init__()
        self.config = config
        self.chunk_size: int = config.chunk_size
        self.use_proprio: bool = config.use_proprio
        self.action_space = build_action_space(config.action_mode.lower())
        self.dim_action = self.action_space.dim_action
        self.dim_proprio = proprio_dim

        self.vlm = Florence2ForConditionalGeneration(florence_config)
        # This class only ever calls the language model's encoder (see `forward_vlm`),
        # so the decoder and the LM head are removed from the module tree.
        if hasattr(self.vlm, "language_model"):
            lm = self.vlm.language_model
            if hasattr(lm, "model") and hasattr(lm.model, "decoder"):
                del lm.model.decoder
            if hasattr(lm, "lm_head"):
                del lm.lm_head

        projection_dim = getattr(self.vlm.config, "projection_dim", None)
        if projection_dim is None:
            raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.")

        self.transformer = SoftPromptedTransformer(
            hidden_size=config.hidden_size,
            multi_modal_input_size=projection_dim,
            depth=config.depth,
            num_heads=config.num_heads,
            mlp_ratio=config.mlp_ratio,
            num_domains=config.num_domains,
            dim_action=self.dim_action,
            dim_propio=self.dim_proprio,
            len_soft_prompts=config.len_soft_prompts,
            dim_time=config.dim_time,
            max_len_seq=config.max_len_seq,
            use_hetero_proj=config.use_hetero_proj,
        )

        # Apply freezing based on config
        self._apply_freezing()

        # Apply dtype casting based on config
        self._apply_dtype()

    def _get_target_dtype(self) -> torch.dtype:
        """Return ``torch.bfloat16`` when ``config.dtype == "bfloat16"``, else ``torch.float32``."""
        if self.config.dtype == "bfloat16":
            return torch.bfloat16
        return torch.float32

    def _apply_dtype(self) -> None:
        """Cast the whole module (VLM and transformer) to the configured dtype."""
        target_dtype = self._get_target_dtype()
        self.to(dtype=target_dtype)

    def _apply_freezing(self) -> None:
        """
        Set ``requires_grad`` on sub-modules according to the config flags.

        - ``freeze_vision_encoder``: freezes ``vlm.vision_tower``.
        - ``freeze_language_encoder``: freezes the language model's encoder and its
          shared embeddings.
        - ``train_policy_transformer`` (when false): freezes every transformer
          parameter except the soft prompts.
        - ``train_soft_prompts`` (when false): freezes ``transformer.soft_prompt_hub``.
        """
        # Freeze vision encoder
        if self.config.freeze_vision_encoder and hasattr(self.vlm, "vision_tower"):
            for param in self.vlm.vision_tower.parameters():
                param.requires_grad = False

        # Freeze language encoder
        if self.config.freeze_language_encoder and hasattr(self.vlm, "language_model"):
            lm = self.vlm.language_model
            # Freeze encoder
            if hasattr(lm, "model") and hasattr(lm.model, "encoder"):
                for param in lm.model.encoder.parameters():
                    param.requires_grad = False
            # Freeze shared embeddings
            if hasattr(lm, "model") and hasattr(lm.model, "shared"):
                for param in lm.model.shared.parameters():
                    param.requires_grad = False

        # Freeze or unfreeze policy transformer
        if not self.config.train_policy_transformer:
            for name, param in self.transformer.named_parameters():
                # Soft prompts are owned by `transformer.soft_prompt_hub` and are governed
                # only by `train_soft_prompts` below, so they must be skipped here. Match
                # on that attribute name; the "soft_prompts" substring is kept as well
                # for compatibility with any parameter that does carry it.
                is_soft_prompt = "soft_prompt_hub" in name or "soft_prompts" in name
                if not is_soft_prompt:
                    param.requires_grad = False

        # Freeze or unfreeze soft prompts
        if not self.config.train_soft_prompts and hasattr(self.transformer, "soft_prompt_hub"):
            for param in self.transformer.soft_prompt_hub.parameters():
                param.requires_grad = False

    def forward_vlm(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        image_mask: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """
        Encode text and multi-view images via Florence2 encoder.

        Args:
            input_ids: Language token ids, embedded with the VLM input embeddings.
            pixel_values: Images whose first two dimensions are (batch, views).
            image_mask: One flag per (batch, view) entry; only entries flagged as valid
                are run through the image encoder, the others keep all-zero features.

        Returns:
            A dict with ``"vlm_features"`` (encoder output for the text merged with the
            first view) and ``"aux_visual_inputs"`` (features of the remaining views,
            flattened per sample).

        Raises:
            ValueError: If no view in the whole batch is flagged as valid.
        """
        batch_size, num_views = pixel_values.shape[:2]
        # `reshape` (not `view`) so a non-contiguous mask is accepted as well.
        flat_mask = image_mask.reshape(-1).to(dtype=torch.bool)
        flat_images = pixel_values.flatten(0, 1)
        num_valid = int(flat_mask.sum().item())
        if num_valid == 0:
            raise ValueError("At least one image view must be valid per batch.")

        # Encode only the valid views, then scatter them back into a zero-filled
        # buffer that has one slot per (batch, view) pair.
        valid_images = flat_images[flat_mask]
        valid_feats = self.vlm._encode_image(valid_images)
        tokens_per_view, hidden_dim = valid_feats.shape[1:]

        image_features = valid_feats.new_zeros((batch_size * num_views, tokens_per_view, hidden_dim))
        image_features[flat_mask] = valid_feats
        image_features = image_features.view(batch_size, num_views, tokens_per_view, hidden_dim)

        # Only the first view is merged with the text embeddings for the encoder.
        inputs_embeds = self.vlm.get_input_embeddings()(input_ids)
        merged_embeds, attention_mask = self.vlm._merge_input_ids_with_image_features(
            image_features[:, 0],
            inputs_embeds,
        )

        enc_out = self.vlm.language_model.model.encoder(
            attention_mask=attention_mask,
            inputs_embeds=merged_embeds,
        )[0]

        # All other views bypass the encoder and are handed over as auxiliary tokens.
        aux_visual_inputs = image_features[:, 1:].reshape(batch_size, -1, hidden_dim)
        return {"vlm_features": enc_out, "aux_visual_inputs": aux_visual_inputs}

    def forward(
        self,
        input_ids: torch.LongTensor,
        image_input: torch.FloatTensor,
        image_mask: torch.Tensor,
        domain_id: torch.LongTensor,
        proprio: torch.Tensor,
        action: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Compute the training losses for one batch.

        A timestep ``t`` is drawn per sample, the target action is blended with
        Gaussian noise according to ``t``, and the transformer predicts the action
        from that noisy input.

        Args:
            input_ids: Language token ids.
            image_input: Multi-view images, see ``forward_vlm``.
            image_mask: Validity flags for the views, see ``forward_vlm``.
            domain_id: Domain index per sample, forwarded to the transformer.
            proprio: Proprioceptive state per sample.
            action: Target action chunk.

        Returns:
            Whatever ``action_space.compute_loss(pred_action, action)`` returns.
        """
        enc = self.forward_vlm(input_ids, image_input, image_mask)

        batch_size = input_ids.shape[0]
        # One shared random offset plus evenly spaced per-sample offsets (i / batch_size),
        # wrapped by the modulo so that every t lies in [0, 1 - 1e-5).
        t = (
            torch.rand(1, device=input_ids.device)
            + torch.arange(batch_size, device=input_ids.device) / batch_size
        ) % (1 - 1e-5)

        # Linear blend: t = 0 keeps the clean action, t close to 1 is almost pure noise.
        action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
        proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)

        pred_action = self.transformer(
            domain_id=domain_id,
            action_with_noise=action_noisy_m,
            t=t,
            proprio=proprio_m,
            **enc,
        )
        return self.action_space.compute_loss(pred_action, action)

    @torch.no_grad()
    def generate_actions(
        self,
        input_ids: torch.LongTensor,
        image_input: torch.FloatTensor,
        image_mask: torch.Tensor,
        domain_id: torch.LongTensor,
        proprio: torch.Tensor,
        steps: int,
    ) -> torch.Tensor:
        """Sample an action chunk by iterative denoising.

        Puts the module in eval mode, encodes the observation once, then runs
        ``steps`` transformer passes with ``t`` going from 1 down to ``1 / steps``.

        Args:
            input_ids: Language token ids.
            image_input: Multi-view images, see ``forward_vlm``.
            image_mask: Validity flags for the views, see ``forward_vlm``.
            domain_id: Domain index per sample.
            proprio: Proprioceptive state; its device and dtype are used for the noise.
            steps: Number of denoising iterations (values below 1 are treated as 1).

        Returns:
            The last predicted action after ``action_space.postprocess``.
        """
        self.eval()
        enc = self.forward_vlm(input_ids, image_input, image_mask)

        batch_size = input_ids.shape[0]
        action_dim = self.dim_action

        # x1 is the fixed noise sample; `action` is the running prediction (zeros at first).
        x1 = torch.randn(batch_size, self.chunk_size, action_dim, device=proprio.device, dtype=proprio.dtype)
        action = torch.zeros_like(x1)

        steps = max(1, int(steps))
        for i in range(steps, 0, -1):
            t = torch.full((batch_size,), i / steps, device=proprio.device, dtype=proprio.dtype)
            # Same blend as in training, with the current prediction in place of the target.
            x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
            proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
            action = self.transformer(
                domain_id=domain_id,
                action_with_noise=x_t_m,
                proprio=proprio_m,
                t=t,
                **enc,
            )
        return self.action_space.postprocess(action)
|
||||||
|
|
||||||
|
|
||||||
|
class XVLAPolicy(PreTrainedPolicy):
    """LeRobot-compliant wrapper built around the XVLA model.

    Converts LeRobot batches into the tensors ``XVLAModel`` expects (language tokens,
    stacked image views with a validity mask, domain id, proprioceptive state),
    exposes the training ``forward`` and keeps an action queue for ``select_action``.
    """

    config_class = XVLAConfig
    name = "xvla"

    def __init__(self, config: XVLAConfig):
        """Validate the config features, build the ``XVLAModel`` and reset the action queue."""
        super().__init__(config)
        config.validate_features()
        florence_config = config.get_florence_config()
        # Without proprioception the transformer is built with a zero-width state input.
        proprio_dim = config.max_state_dim if config.use_proprio else 0
        self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim)
        self.reset()

    def reset(self) -> None:
        """Drop any queued actions; the queue holds at most ``n_action_steps`` entries."""
        self._queues = {
            ACTION: deque(maxlen=self.config.n_action_steps),
        }

    def get_optim_params(self) -> list[nn.Parameter]:
        """Return only trainable parameters for optimization.

        The parameters are materialized into a list so that the result can be
        iterated more than once (a lazy ``filter`` object is exhausted after the
        first pass).
        """
        return [param for param in self.parameters() if param.requires_grad]

    def _prepare_state(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
        """Return the proprioceptive state padded/truncated to ``model.dim_proprio``.

        When proprioception is disabled or the batch has no state, a zero-width
        ``(batch_size, 0)`` tensor is returned. For states with more than two
        dimensions only the last entry along dimension 1 is used.
        """
        if not self.config.use_proprio or OBS_STATE not in batch:
            return torch.zeros(batch_size, 0, device=device)
        state = batch[OBS_STATE]
        if state.ndim > 2:
            state = state[:, -1, :]
        return pad_vector(state, self.model.dim_proprio)

    def _prepare_images(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Stack the available camera images into one tensor plus a validity mask.

        Returns:
            ``(images, mask)`` where views are stacked along dimension 1. Views that
            are present in the batch are marked valid; if ``config.num_image_views``
            asks for more views than are present, zero images with a false mask are
            appended.

        Raises:
            ValueError: If none of the configured image features is in the batch.
        """
        present_img_keys = [key for key in self.config.image_features if key in batch]
        if len(present_img_keys) == 0:
            raise ValueError(
                "All image features are missing from the batch. "
                f"Batch keys: {list(batch.keys())}, expected at least one of {list(self.config.image_features)}."
            )

        images = []
        masks = []
        for key in present_img_keys:
            # 5-D inputs carry an extra dimension at index 1; keep only its last entry.
            img = batch[key][:, -1] if batch[key].ndim == 5 else batch[key]
            if self.config.resize_imgs_with_padding is not None:
                img = resize_with_pad(img, *self.config.resize_imgs_with_padding)
            images.append(img)
            masks.append(torch.ones(img.size(0), dtype=torch.bool, device=img.device))

        stacked_imgs = torch.stack(images, dim=1)
        stacked_masks = torch.stack(masks, dim=1)

        # Never drop views: the target count is at least the number of views present.
        total_views = self.config.num_image_views or stacked_imgs.size(1)
        total_views = max(total_views, stacked_imgs.size(1))
        num_pad = total_views - stacked_imgs.size(1)
        if num_pad > 0:
            pad_shape = (stacked_imgs.size(0), num_pad, *stacked_imgs.shape[2:])
            pad_imgs = stacked_imgs.new_zeros(pad_shape)
            pad_masks = stacked_masks.new_zeros((stacked_masks.size(0), num_pad))
            stacked_imgs = torch.cat([stacked_imgs, pad_imgs], dim=1)
            stacked_masks = torch.cat([stacked_masks, pad_masks], dim=1)

        return stacked_imgs, stacked_masks

    def _get_domain_id(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
        """Return one ``long`` domain id per sample.

        The id is read from ``config.domain_feature_key`` if that key is in the batch,
        otherwise from ``"domain_id"``; when neither is present all ids are zero.
        Scalars are broadcast to the batch, and for inputs with more than one
        dimension the first element of each sample is used.
        """
        candidate = None
        if self.config.domain_feature_key and self.config.domain_feature_key in batch:
            candidate = batch[self.config.domain_feature_key]
        elif "domain_id" in batch:
            candidate = batch["domain_id"]

        if candidate is None:
            return torch.zeros(batch_size, dtype=torch.long, device=device)

        if not isinstance(candidate, torch.Tensor):
            candidate = torch.as_tensor(candidate, device=device)
        else:
            candidate = candidate.to(device=device)

        if candidate.ndim == 0:
            candidate = candidate.expand(batch_size)
        if candidate.ndim > 1:
            candidate = candidate.view(candidate.shape[0], -1)[:, 0]
        if candidate.shape[0] != batch_size:
            # `expand` only succeeds for a length-1 tensor; other lengths raise.
            candidate = candidate.expand(batch_size)
        return candidate.to(dtype=torch.long)

    def _prepare_action_targets(self, batch: dict[str, Tensor]) -> Tensor:
        """Return the target actions sized to ``(…, chunk_size, model.dim_action)``.

        2-D actions get a length-1 chunk dimension; the chunk dimension is then
        zero-padded or truncated to ``config.chunk_size`` and the last dimension
        to ``model.dim_action``.

        Raises:
            ValueError: If the batch holds no action targets.
        """
        if ACTION not in batch:
            raise ValueError("Batch is missing action targets required for training.")
        actions = batch[ACTION]
        if actions.ndim == 2:
            actions = actions.unsqueeze(1)
        actions = pad_tensor_along_dim(actions, self.config.chunk_size, dim=1)
        if actions.shape[-1] != self.model.dim_action:
            actions = pad_vector(actions, self.model.dim_action)
        return actions

    def _build_model_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
        """Assemble the keyword arguments shared by training and inference calls.

        The batch size comes from the language tokens; the device used for the
        domain ids and the default state comes from the stacked images.
        """
        input_ids = batch[OBS_LANGUAGE_TOKENS]
        batch_size = input_ids.shape[0]
        images, image_mask = self._prepare_images(batch)
        domain_id = self._get_domain_id(batch, batch_size, images.device)
        proprio = self._prepare_state(batch, batch_size, images.device)
        return {
            "input_ids": input_ids,
            "image_input": images,
            "image_mask": image_mask,
            "domain_id": domain_id,
            "proprio": proprio,
        }

    def _trim_action_dim(self, actions: Tensor) -> Tensor:
        """Truncate or zero-pad the last dimension of ``actions`` to ``model.dim_action``.

        Returns ``actions`` unchanged when the config has no action feature or the
        size already matches.
        """
        feature = self.config.action_feature
        if feature is None:
            return actions
        desired_dim = self.model.dim_action
        if desired_dim == actions.shape[-1]:
            return actions
        if desired_dim < actions.shape[-1]:
            return actions[..., :desired_dim]
        return pad_vector(actions, desired_dim)

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
        """Run a training step.

        Returns:
            ``(total_loss, log_dict)`` where ``total_loss`` is the sum of all losses
            returned by the model and ``log_dict`` maps each loss name, plus
            ``"loss"`` for the total, to a Python float.
        """
        inputs = self._build_model_inputs(batch)
        targets = self._prepare_action_targets(batch)
        losses = self.model(action=targets, **inputs)
        total_loss = sum(losses.values())

        log_dict = {k: v.detach().item() for k, v in losses.items()}
        log_dict["loss"] = total_loss.detach().item()
        return total_loss, log_dict

    def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
        """Generate one action chunk with ``config.num_denoising_steps`` denoising steps."""
        inputs = self._build_model_inputs(batch)
        actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps)
        actions = self._trim_action_dim(actions)
        return actions

    @torch.no_grad()
    def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:  # noqa: ARG002
        """Predict a full action chunk for ``batch`` (``noise`` is unused)."""
        self.eval()
        self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
        return self._get_action_chunk(batch)

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:  # noqa: ARG002
        """Return the next action, generating a new chunk when the queue is empty.

        At most ``config.n_action_steps`` actions of a freshly generated chunk are
        queued; ``noise`` is unused.
        """
        self.eval()
        self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])

        if len(self._queues[ACTION]) == 0:
            actions = self._get_action_chunk(batch)
            # Swap the first two dimensions so that each queue entry is one step for the whole batch.
            self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])

        return self._queues[ACTION].popleft()

    @classmethod
    def from_pretrained(
        cls: builtins.type[T],
        pretrained_name_or_path: str | Path,
        *,
        config: PreTrainedConfig | None = None,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        strict: bool = False,
        **kwargs,
    ):
        """
        Instantiate the policy and load its weights from ``model.safetensors``.

        Steps:
        1. Load the config from ``pretrained_name_or_path`` when none is given.
        2. Locate ``model.safetensors`` in the local directory, or download it
           from the Hub when the path is not a directory.
        3. Load the state dict; if the encoder token-embedding weight is present,
           the same tensor is also registered under the shared-embedding key.
        4. Load the weights into the new instance.
        5. Re-apply the configured dtype, move to ``config.device`` and switch to
           eval mode.

        Note:
            ``strict`` is accepted for signature compatibility but is not forwarded:
            the state dict is always loaded with ``strict=True``.

        Raises:
            FileNotFoundError: If the Hub request for ``model.safetensors`` fails
                with an ``HfHubHTTPError``.
        """
        import safetensors.torch

        # step 1: load config
        # TODO: jadechoghari, fix this
        if config is None:
            config = PreTrainedConfig.from_pretrained(
                pretrained_name_or_path=pretrained_name_or_path,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                token=token,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
                revision=revision,
                **kwargs,
            )

        model_id = str(pretrained_name_or_path)
        instance = cls(config, **kwargs)
        # step 2: locate model.safetensors
        if os.path.isdir(model_id):
            logging.info("Loading weights from local directory")
            model_file = os.path.join(model_id, "model.safetensors")
        else:
            try:
                from huggingface_hub import hf_hub_download
                from huggingface_hub.utils import HfHubHTTPError

                model_file = hf_hub_download(
                    repo_id=model_id,
                    filename="model.safetensors",
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e

        logging.info(f"Loading checkpoint from {model_file}")
        # step 3: load state dict
        state_dict = safetensors.torch.load_file(model_file)
        encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight"
        shared_key = "model.vlm.language_model.model.shared.weight"
        if encoder_key in state_dict:
            # Both keys reference the same tensor object (no copy is made).
            state_dict[shared_key] = state_dict[encoder_key]
        # step 4: load into instance
        instance.load_state_dict(state_dict, strict=True)
        logging.info("Loaded XVLA checkpoint")
        # step 5: finalize
        # Reapply dtype after loading state dict
        instance.model._apply_dtype()
        instance.to(config.device)
        instance.eval()
        return instance
|
||||||
|
|
||||||
|
|
||||||
|
def resize_with_pad(img: torch.Tensor, height: int, width: int, pad_value: float = 0.0) -> torch.Tensor:
    """Resize a batch of images to fit inside ``height`` x ``width`` and pad the rest.

    The image is scaled bilinearly by a single factor (aspect ratio preserved up to
    integer truncation) so that it fits the target box, then padded on the left and
    on the top with ``pad_value``.

    Args:
        img: Image batch with four dimensions ``(b, c, h, w)``.
        height: Target height.
        width: Target width.
        pad_value: Fill value for the padded area.

    Returns:
        The input itself when it already has the target size, otherwise the resized
        and padded tensor.

    Raises:
        ValueError: If ``img`` does not have exactly four dimensions.
    """
    if img.ndim != 4:
        raise ValueError(f"(b,c,h,w) expected, but got {img.shape}")

    src_h, src_w = img.shape[2:]
    if (src_h, src_w) == (height, width):
        return img

    # Dividing by the larger of the two ratios guarantees both sides fit the target.
    scale = max(src_w / width, src_h / height)
    new_h, new_w = int(src_h / scale), int(src_w / scale)
    scaled = F.interpolate(img, size=(new_h, new_w), mode="bilinear", align_corners=False)

    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    left = max(0, width - new_w)
    top = max(0, height - new_h)
    return F.pad(scaled, (left, 0, top, 0), value=pad_value)
|
||||||
|
|
||||||
|
|
||||||
|
def pad_vector(vector: Tensor, new_dim: int) -> Tensor:
    """Resize the last dimension of ``vector`` to ``new_dim``.

    Shorter inputs are zero-padded at the end, longer inputs are truncated.

    Args:
        vector: Tensor whose last dimension is resized.
        new_dim: Requested size of the last dimension (0 gives an empty last dimension).

    Returns:
        ``vector`` itself when the size already matches, otherwise a new tensor with
        the same dtype and device.
    """
    old_dim = vector.shape[-1]
    if old_dim == new_dim:
        return vector

    # Start from zeros and copy over as many leading entries as both sizes allow;
    # for new_dim == 0 the copy is an empty no-op.
    resized = vector.new_zeros(*vector.shape[:-1], new_dim)
    keep = min(old_dim, new_dim)
    resized[..., :keep] = vector[..., :keep]
    return resized
|
||||||
|
|
||||||
|
|
||||||
|
def pad_tensor_along_dim(tensor: Tensor, target_len: int, dim: int = 1) -> Tensor:
    """Resize ``tensor`` along ``dim`` to ``target_len``.

    Longer inputs are cut to their first ``target_len`` entries, shorter inputs are
    extended with trailing zeros.

    Args:
        tensor: Tensor to resize.
        target_len: Requested length along ``dim``.
        dim: Dimension to resize.

    Returns:
        ``tensor`` itself when the length already matches, a slice of it when it is
        too long, otherwise a new tensor with zeros appended.
    """
    missing = target_len - tensor.size(dim)
    if missing == 0:
        return tensor

    if missing < 0:
        # Full slices for every dimension before `dim`, then the truncating slice.
        axis = dim % tensor.dim()
        return tensor[(slice(None),) * axis + (slice(0, target_len),)]

    filler_shape = list(tensor.shape)
    filler_shape[dim] = missing
    return torch.cat([tensor, tensor.new_zeros(filler_shape)], dim=dim)
|
||||||
551
src/lerobot/policies/xvla/processor_xvla.py
Normal file
551
src/lerobot/policies/xvla/processor_xvla.py
Normal file
@@ -0,0 +1,551 @@
|
|||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||||
|
from lerobot.datasets.factory import IMAGENET_STATS
|
||||||
|
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||||
|
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
|
||||||
|
from lerobot.processor import (
|
||||||
|
AddBatchDimensionProcessorStep,
|
||||||
|
DeviceProcessorStep,
|
||||||
|
NormalizerProcessorStep,
|
||||||
|
ObservationProcessorStep,
|
||||||
|
PolicyAction,
|
||||||
|
PolicyProcessorPipeline,
|
||||||
|
ProcessorStep,
|
||||||
|
ProcessorStepRegistry,
|
||||||
|
RenameObservationsProcessorStep,
|
||||||
|
TokenizerProcessorStep,
|
||||||
|
UnnormalizerProcessorStep,
|
||||||
|
)
|
||||||
|
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||||
|
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||||
|
from lerobot.utils.constants import (
|
||||||
|
OBS_IMAGES,
|
||||||
|
OBS_STATE,
|
||||||
|
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
|
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_xvla_pre_post_processors(
    config: XVLAConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """
    Build the LeRobot processor pipelines for XVLA.

    Args:
        config: XVLA configuration; tokenizer settings, device, features and the
            normalization mapping are read from it.
        dataset_stats: Statistics handed to the normalizer and unnormalizer steps.

    Returns:
        ``(preprocessor, postprocessor)``. The preprocessor renames observations,
        adds a batch dimension, tokenizes, converts and normalizes images, adds the
        domain id, moves data to ``config.device`` and normalizes features. The
        postprocessor unnormalizes the output features and moves them to the CPU.
    """
    # Input and output features together; output entries win on duplicate keys.
    merged_features = dict(config.input_features)
    merged_features.update(config.output_features)

    tokenizer_step = TokenizerProcessorStep(
        tokenizer_name=config.tokenizer_name,
        max_length=config.tokenizer_max_length,
        padding=config.pad_language_to,
        padding_side=config.tokenizer_padding_side,
    )
    pre_steps = [
        RenameObservationsProcessorStep(rename_map={}),
        AddBatchDimensionProcessorStep(),
        tokenizer_step,
        XVLAImageToFloatProcessorStep(),
        XVLAImageNetNormalizeProcessorStep(),
        XVLAAddDomainIdProcessorStep(),
        DeviceProcessorStep(device=config.device),
        NormalizerProcessorStep(
            features=merged_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
    ]
    post_steps = [
        UnnormalizerProcessorStep(
            features=config.output_features,
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        DeviceProcessorStep(device="cpu"),
    ]

    preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
        steps=pre_steps,
        name=POLICY_PREPROCESSOR_DEFAULT_NAME,
    )
    postprocessor = PolicyProcessorPipeline[PolicyAction, PolicyAction](
        steps=post_steps,
        name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
        to_transition=policy_action_to_transition,
        to_output=transition_to_policy_action,
    )
    return preprocessor, postprocessor
|
||||||
|
|
||||||
|
|
||||||
|
# Custom XVLA processor steps
|
||||||
|
@dataclass
class LiberoProcessorStep(ObservationProcessorStep):
    """
    Processes LIBERO observations into the LeRobot format.

    **State Processing:**
    - Pops the nested ``"observation.robot_state"`` dictionary, if present.
    - Builds a 10-value proprioceptive vector per sample:
        - End-effector position (``eef.pos``, 3 values)
        - 6D rotation taken from the end-effector rotation matrix (``eef.mat``, 6 values)
        - One zero placeholder
    - Appends 10 more zeros and stores the resulting 20-value float32 vector
      under ``"observation.state"``.

    **Image Processing:**
    - Flips the ``<OBS_IMAGES>.image`` observation along dimensions 2 and 3
      (a 180 degree rotation for ``(B, C, H, W)`` input). Other image keys are
      left unchanged.
    """

    def _process_observation(self, observation):
        """
        Return a shallow copy of ``observation`` with the image flip and state flattening applied.
        """
        processed_obs = observation.copy()

        # Only this one camera key is flipped (both H and W).
        flipped_key = f"{OBS_IMAGES}.image"
        if flipped_key in processed_obs:
            processed_obs[flipped_key] = torch.flip(processed_obs[flipped_key], dims=[2, 3])

        if "observation.robot_state" not in processed_obs:
            return processed_obs

        # Process robot_state into a flat state vector
        robot_state = processed_obs.pop("observation.robot_state")
        position = robot_state["eef"]["pos"]  # (B, 3)
        rot6d = self._mat_to_rotate6d(robot_state["eef"]["mat"])  # (B, 3, 3) -> (B, 6)

        placeholder = torch.zeros((position.shape[0], 1), dtype=torch.float32, device=position.device)
        proprio_state = torch.cat((position, rot6d, placeholder), dim=-1)  # (B, 10)
        # Second half of the state is all zeros -> (B, 20), cast to float32.
        state = torch.cat((proprio_state, torch.zeros_like(proprio_state)), dim=-1).float()
        if state.dim() == 1:
            state = state.unsqueeze(0)

        processed_obs[OBS_STATE] = state
        return processed_obs

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """
        Transforms feature keys from the LIBERO format to the LeRobot standard.

        Every non-STATE feature group is copied as is; the STATE group is replaced
        by a single 20-value ``"observation.state"`` feature.
        """
        new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {
            ft: feats.copy() for ft, feats in features.items() if ft != PipelineFeatureType.STATE
        }

        # NOTE(review): PolicyFeature is constructed with `key`/`dtype` keywords here;
        # confirm that its definition accepts them.
        new_features[PipelineFeatureType.STATE] = {
            "observation.state": PolicyFeature(
                key="observation.state",
                shape=(20,),
                dtype="float32",
            )
        }
        return new_features

    def _mat_to_rotate6d(self, rot_mats: torch.Tensor) -> torch.Tensor:
        """
        Convert batched rotation matrices (B, 3, 3) into 6D rotation representation (B, 6).

        The result is the first matrix column followed by the second one.

        Args:
            rot_mats (Tensor): Rotation matrices of shape (B, 3, 3)

        Returns:
            Tensor: float32 6D rotation representation, shape (B, 6)

        Raises:
            TypeError: if input is not a torch tensor
            ValueError: if shape is not (B, 3, 3)
        """
        if not isinstance(rot_mats, torch.Tensor):
            raise TypeError(f"mat_to_rot6d expects a torch.Tensor, got {type(rot_mats)}")
        if rot_mats.ndim != 3 or rot_mats.shape[1:] != (3, 3):
            raise ValueError(f"mat_to_rot6d expects shape (B, 3, 3), got {tuple(rot_mats.shape)}")

        mats = rot_mats.to(torch.float32)
        first_col, second_col = mats[:, :3, 0], mats[:, :3, 1]  # each (B, 3)
        return torch.cat([first_col, second_col], dim=-1)

    def observation(self, observation):
        """Entry point of the step; delegates to ``_process_observation``."""
        return self._process_observation(observation)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
@ProcessorStepRegistry.register(name="xvla_image_scale")
class XVLAImageScaleProcessorStep(ProcessorStep):
    """Multiply image observations by 255 (e.g. to go from a [0, 1] to a [0, 255] range).

    Args:
        image_keys: Observation keys holding the images to scale. When None, every
            observation key starting with "observation.images." is scaled.
    """

    image_keys: list[str] | None = None

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of ``transition`` whose selected tensor observations are multiplied by 255."""
        new_transition = transition.copy()
        obs = new_transition.get(TransitionKey.OBSERVATION, {})
        if obs is None:
            return new_transition

        # Work on a shallow copy so the caller's observation dict is left untouched.
        obs = obs.copy()

        if self.image_keys is None:
            # Auto-detect image keys
            selected_keys = [k for k in obs if k.startswith("observation.images.")]
        else:
            selected_keys = self.image_keys

        # Missing keys and non-tensor values are skipped silently.
        for key in selected_keys:
            value = obs.get(key)
            if isinstance(value, torch.Tensor):
                obs[key] = value * 255

        new_transition[TransitionKey.OBSERVATION] = obs
        return new_transition

    def transform_features(self, features):
        """Return ``features`` unchanged; scaling does not alter the feature structure."""
        return features

    def get_config(self) -> dict[str, Any]:
        """Return the serializable configuration of this step."""
        return {"image_keys": self.image_keys}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
@ProcessorStepRegistry.register(name="xvla_image_to_float")
class XVLAImageToFloatProcessorStep(ProcessorStep):
    """Convert image observations from the [0, 255] range to floats in [0, 1].

    Each selected tensor is cast with ``.float()`` and divided by 255. The
    transition and its observation dict are shallow-copied, so the caller's
    containers are not modified.

    Args:
        image_keys: Observation keys holding the images to convert. If None, every
            key starting with "observation.images." is converted.
        validate_range: If True (default), check that every value lies in [0, 255]
            before converting.

    Raises:
        ValueError: If validate_range is True and image values are not in [0, 255] range.
    """

    image_keys: list[str] | None = None
    validate_range: bool = True

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of ``transition`` whose image tensors are divided by 255."""
        result = transition.copy()
        observations = result.get(TransitionKey.OBSERVATION, {})
        if observations is None:
            return result

        # Work on a shallow copy so the incoming observation dict stays untouched.
        converted = observations.copy()

        if self.image_keys is not None:
            selected_keys = self.image_keys
        else:
            # No explicit list configured: pick up the image keys by prefix.
            selected_keys = [k for k in converted if k.startswith("observation.images.")]

        for key in selected_keys:
            tensor = converted.get(key)
            if not isinstance(tensor, torch.Tensor):
                # Missing keys and non-tensor values are skipped without error.
                continue

            if self.validate_range:
                min_val = tensor.min().item()
                max_val = tensor.max().item()
                if min_val < 0.0 or max_val > 255.0:
                    raise ValueError(
                        f"Image '{key}' has values outside [0, 255] range: "
                        f"min={min_val:.4f}, max={max_val:.4f}. "
                        f"Cannot convert to [0, 1] range."
                    )

            converted[key] = tensor.float() / 255.0

        result[TransitionKey.OBSERVATION] = converted
        return result

    def transform_features(self, features):
        """Return ``features`` unchanged; conversion does not alter the feature structure."""
        return features

    def get_config(self) -> dict[str, Any]:
        """Return the serializable configuration of this step."""
        return {"image_keys": self.image_keys, "validate_range": self.validate_range}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
@ProcessorStepRegistry.register(name="xvla_imagenet_normalize")
class XVLAImageNetNormalizeProcessorStep(ProcessorStep):
    """Normalize image observations using ImageNet statistics.

    Every selected image tensor is first checked to lie in the [0, 1] range and
    is then normalized as ``(image - mean) / std`` with the per-channel values
    from ``IMAGENET_STATS``.

    Images are assumed to be channel-first with layout (..., C, H, W); the
    statistics are reshaped to (C, 1, 1) so that they broadcast over any number
    of leading batch dimensions.

    Args:
        image_keys: List of observation keys that contain images to normalize.
            If None, will automatically detect keys starting with "observation.images."

    Raises:
        ValueError: If image values are not in the [0, 1] range.
    """

    image_keys: list[str] | None = None

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of ``transition`` with ImageNet-normalized image tensors."""
        new_transition = transition.copy()
        obs = new_transition.get(TransitionKey.OBSERVATION, {})
        if obs is None:
            return new_transition

        # Work on a shallow copy so the incoming observation dict stays untouched.
        obs = obs.copy()

        keys_to_normalize = self.image_keys
        if keys_to_normalize is None:
            # No explicit list configured: pick up the image keys by prefix.
            keys_to_normalize = [k for k in obs if k.startswith("observation.images.")]

        for key in keys_to_normalize:
            tensor = obs.get(key)
            if not isinstance(tensor, torch.Tensor):
                # Missing keys and non-tensor values are skipped without error.
                continue

            min_val = tensor.min().item()
            max_val = tensor.max().item()
            if min_val < 0.0 or max_val > 1.0:
                raise ValueError(
                    f"Image '{key}' has values outside [0, 1] range: "
                    f"min={min_val:.4f}, max={max_val:.4f}. "
                    f"ImageNet normalization requires input values in [0, 1]."
                )

            # Integer/bool images would truncate the fractional mean/std to zero
            # (and then divide by zero), so promote them to float first.
            if not tensor.is_floating_point():
                tensor = tensor.float()

            # Reshape the statistics to (C, 1, 1): this is a no-op if
            # IMAGENET_STATS already stores them that way, makes a flat
            # per-channel list broadcast along the channel axis, and lets
            # broadcasting handle any leading batch dimensions (BCHW, BNCHW, ...).
            mean = torch.tensor(
                IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype
            ).reshape(-1, 1, 1)
            std = torch.tensor(
                IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype
            ).reshape(-1, 1, 1)

            obs[key] = (tensor - mean) / std

        new_transition[TransitionKey.OBSERVATION] = obs
        return new_transition

    def transform_features(self, features):
        """ImageNet normalization doesn't change feature structure."""
        return features

    def get_config(self) -> dict[str, Any]:
        """Return serializable configuration."""
        return {
            "image_keys": self.image_keys,
        }
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
@ProcessorStepRegistry.register(name="xvla_add_domain_id")
class XVLAAddDomainIdProcessorStep(ProcessorStep):
    """Add domain_id to complementary data.

    This processor step adds a ``domain_id`` tensor to the complementary data,
    which is used by XVLA to identify different robot embodiments or task domains.
    The tensor has one entry per batch element; the batch size is taken from the
    leading dimension of the first non-scalar observation tensor (1 if there is none).

    Args:
        domain_id: The domain ID to add (default: 0)
    """

    domain_id: int = 0

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of ``transition`` with ``domain_id`` in its complementary data."""
        new_transition = transition.copy()
        comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
        # Copy so the caller's complementary-data dict is never mutated.
        comp = {} if comp is None else comp.copy()

        # Infer the batch size from the first observation tensor that has a
        # leading dimension; 0-dim tensors carry no batch axis and are skipped.
        obs = new_transition.get(TransitionKey.OBSERVATION, {})
        batch_size = 1
        if obs:
            for v in obs.values():
                if isinstance(v, torch.Tensor) and v.ndim > 0:
                    batch_size = v.shape[0]
                    break

        # One copy of the configured id per batch element, created on the CPU.
        comp["domain_id"] = torch.full((batch_size,), int(self.domain_id), dtype=torch.long)

        new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp
        return new_transition

    def transform_features(self, features):
        """Domain ID addition doesn't change feature structure."""
        return features

    def get_config(self) -> dict[str, Any]:
        """Return serializable configuration."""
        return {
            "domain_id": self.domain_id,
        }
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
@ProcessorStepRegistry.register(name="xvla_rotation_6d_to_axis_angle")
class XVLARotation6DToAxisAngleProcessorStep(ProcessorStep):
    """Convert 6D rotation representation to axis-angle and reorganize action dimensions.

    This processor step takes actions with 6D rotation representation and converts them to
    axis-angle representation, reorganizing the action dimensions as:
        - action[..., :3] -> target_eef (end-effector position)
        - action[..., 3:9] -> 6D rotation (converted to axis-angle, 3D)
        - action[..., 9:10] -> gripper action (binarized: > 0.5 -> 1.0, otherwise -1.0)

    Final output: [target_eef (3), axis_angle (3), gripper (1)] = 7D action

    Any action dimensions past index 9 are dropped. Leading dimensions are
    preserved, so (D,), (B, D) and (B, T, D) actions are all accepted.

    Args:
        expected_action_dim: Expected input action dimension (default: 10, supports 6D rotation + extras).
            It is stored in the serialized config; the conversion itself only
            requires at least 10 action dimensions.

    Raises:
        ValueError: If the action tensor has fewer than 10 entries along its last dimension.
    """

    expected_action_dim: int = 10

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of ``transition`` whose action is the 7D axis-angle form."""
        new_transition = transition.copy()
        action = new_transition.get(TransitionKey.ACTION)

        # Missing or non-tensor actions are passed through untouched.
        if not isinstance(action, torch.Tensor):
            return new_transition

        # The slicing below needs indices 0..9; with fewer dims the gripper slice
        # would be empty and the binarization would clobber a rotation component.
        min_action_dim = 10
        if action.ndim < 1 or action.shape[-1] < min_action_dim:
            raise ValueError(
                f"Expected an action with at least {min_action_dim} dimensions "
                f"(3 position + 6 rotation + 1 gripper), got shape {tuple(action.shape)}"
            )

        device = action.device
        dtype = action.dtype
        lead_shape = tuple(action.shape[:-1])

        # NumPy cannot view tensors that track gradients or use bfloat16.
        host_action = action.detach().cpu()
        if host_action.dtype == torch.bfloat16:
            host_action = host_action.float()
        # Flatten all leading dims so the helper always sees (N, D).
        action_np = host_action.numpy().reshape(-1, action.shape[-1])

        target_eef = action_np[:, :3]  # (N, 3)
        rotation_6d = action_np[:, 3:9]  # (N, 6)
        target_act = action_np[:, 9:10]  # (N, 1)

        if rotation_6d.shape[0] == 0:
            # Nothing to convert; the helper cannot stack an empty batch.
            target_axis = np.zeros((0, 3), dtype=action_np.dtype)
        else:
            target_axis = rotate6d_to_axis_angle(rotation_6d)  # (N, 3)

        # Concatenate: [eef (3), axis_angle (3), gripper (1)] = 7D
        action_np = np.concatenate([target_eef, target_axis, target_act], axis=-1)

        # Binarize the gripper command to -1 or 1.
        action_np[:, -1] = np.where(action_np[:, -1] > 0.5, 1.0, -1.0)

        # Restore the leading dims, then the original device and dtype.
        action_np = action_np.reshape(*lead_shape, 7)
        action = torch.from_numpy(action_np).to(device=device, dtype=dtype)

        new_transition[TransitionKey.ACTION] = action
        return new_transition

    def transform_features(self, features):
        """Return ``features`` unchanged.

        Note: the conversion changes the action dimension from 10 to 7, but this
        simplified implementation does not update the action feature shape.
        """
        return features

    def get_config(self) -> dict[str, Any]:
        """Return serializable configuration."""
        return {
            "expected_action_dim": self.expected_action_dim,
        }
|
||||||
|
|
||||||
|
|
||||||
|
def make_xvla_libero_pre_post_processors() -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """
    Build the (pre-processor, post-processor) pipelines for XVLA with the LIBERO environment.

    The pre-processor runs LiberoProcessorStep, ImageNet normalization and domain-id
    injection, in that order; the post-processor converts the 6D rotation in the
    action to axis-angle. All steps use their default configuration.
    """
    pre_processor_steps: list[ProcessorStep] = [
        LiberoProcessorStep(),
        XVLAImageNetNormalizeProcessorStep(),
        XVLAAddDomainIdProcessorStep(),
    ]
    post_processor_steps: list[ProcessorStep] = [XVLARotation6DToAxisAngleProcessorStep()]

    preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
        steps=pre_processor_steps,
    )
    postprocessor = PolicyProcessorPipeline[PolicyAction, PolicyAction](
        steps=post_processor_steps,
    )
    return (preprocessor, postprocessor)
|
||||||
415
src/lerobot/policies/xvla/soft_transformer.py
Normal file
415
src/lerobot/policies/xvla/soft_transformer.py
Normal file
@@ -0,0 +1,415 @@
|
|||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
# Copyright 2025 2toINF (https://github.com/2toINF)
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from functools import partial
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as functional
|
||||||
|
|
||||||
|
# ------------------------------- Small utils ----------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _to_2tuple(x) -> tuple:
|
||||||
|
"""Minimal replacement for timm.layers.to_2tuple."""
|
||||||
|
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
|
||||||
|
t = tuple(x)
|
||||||
|
return (t[0], t[1]) if len(t) >= 2 else (t[0], t[0])
|
||||||
|
return (x, x)
|
||||||
|
|
||||||
|
|
||||||
|
def _has_sdp_attention() -> bool:
|
||||||
|
"""Check if we can use PyTorch fused scaled_dot_product_attention."""
|
||||||
|
return hasattr(functional, "scaled_dot_product_attention")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------- MLP --------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class Mlp(nn.Module):
    """
    Two-layer feed-forward network used in ViT-style blocks.

    Pipeline: fc1 -> GELU(tanh) -> dropout -> optional norm -> fc2 -> dropout.
    The two projections are nn.Linear by default, or 1x1 nn.Conv2d when
    ``use_conv`` is set. ``bias`` and ``drop`` may each be a single value or a
    pair (first layer, second layer).
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int | None = None,
        out_features: int | None = None,
        norm_layer: type[nn.Module] | None = None,
        bias: bool | tuple[bool, bool] = True,
        drop: float | tuple[float, float] = 0.0,
        use_conv: bool = False,
    ) -> None:
        super().__init__()
        # Unset (or zero) widths fall back to the input width.
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        bias_first, bias_second = _to_2tuple(bias)
        drop_first, drop_second = _to_2tuple(drop)

        if use_conv:
            make_layer = partial(nn.Conv2d, kernel_size=1)
        else:
            make_layer = nn.Linear

        # Submodules are created in application order.
        self.fc1 = make_layer(in_features, hidden_features, bias=bias_first)
        self.act = nn.GELU(approximate="tanh")
        self.drop1 = nn.Dropout(drop_first)
        if norm_layer is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(hidden_features)
        self.fc2 = make_layer(hidden_features, out_features, bias=bias_second)
        self.drop2 = nn.Dropout(drop_second)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stages in order; the caller supplies a shape the layers accept."""
        stages = (self.fc1, self.act, self.drop1, self.norm, self.fc2, self.drop2)
        for stage in stages:
            x = stage(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------- Attention ----------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
    """
    Multi-head self-attention with a fused-kernel fast path.

    When PyTorch exposes `scaled_dot_product_attention` it is used; otherwise
    the attention weights are computed explicitly (scale, matmul, softmax,
    dropout, matmul).
    """

    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = head_dim**-0.5
        self.fused_attn = _has_sdp_attention()

        # A single projection produces queries, keys and values together.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        if qk_norm:
            self.q_norm = norm_layer(head_dim)
            self.k_norm = norm_layer(head_dim)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : Tensor, shape [batch_size, seq_len, channels]
            Input sequence.

        Returns
        -------
        Tensor, shape [batch_size, seq_len, channels]
            Output sequence after MHSA + projection.
        """
        batch_size, seq_len, channels = x.shape

        packed = self.qkv(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        # -> 3 x [batch_size, num_heads, seq_len, head_dim]
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)

        if not self.fused_attn:
            # Manual path: softmax((q * scale) @ k^T) @ v with dropout on the weights.
            scores = (q * self.scale) @ k.transpose(-2, -1)  # [batch_size, num_heads, seq_len, seq_len]
            weights = self.attn_drop(scores.softmax(dim=-1))
            context = weights @ v  # [batch_size, num_heads, seq_len, head_dim]
        else:
            # The fused kernel applies dropout itself, so it is only enabled in training.
            context = functional.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )  # [batch_size, num_heads, seq_len, head_dim]

        # Merge the heads back into the channel dimension.
        merged = context.transpose(1, 2).reshape(batch_size, seq_len, channels)
        return self.proj_drop(self.proj(merged))
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------- Utilities -----------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def basic_init(module: nn.Module) -> None:
    """
    Initialize an nn.Linear in place; every other module type is left untouched.

    - Weight: Xavier uniform initialization.
    - Bias: Set to zero (when the layer has one).
    """
    if not isinstance(module, nn.Linear):
        return

    nn.init.xavier_uniform_(module.weight)
    bias = module.bias
    if bias is not None:
        nn.init.constant_(bias, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torch.Tensor:
    """
    Build sinusoidal embeddings for a batch of timesteps.

    The first ``dim // 2`` columns are cosines and the next ``dim // 2`` are
    sines of ``t * freq``, with frequencies decaying geometrically from 1
    towards ``1 / max_period``. For odd ``dim`` a final zero column is appended.

    Parameters
    ----------
    t : torch.Tensor
        Shape [B]. Each element is a timestep index, may be fractional.
    dim : int
        Dimensionality of the output embedding.
    max_period : int, default=100
        Controls the minimum frequency of the sinusoids.

    Returns
    -------
    torch.Tensor
        Shape [B, dim]. Sinusoidal embeddings.
    """
    half = dim // 2
    steps = torch.arange(start=0, end=half, dtype=t.dtype, device=t.device)
    log_freqs = -math.log(max_period) * steps / half
    freqs = torch.exp(log_freqs)

    angles = t[:, None] * freqs[None]  # [B, half]
    columns = [torch.cos(angles), torch.sin(angles)]
    if dim % 2 == 1:
        # Pad with a zero column so the width equals ``dim``.
        columns.append(torch.zeros_like(angles[:, :1]))
    return torch.cat(columns, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------- Core Layers ----------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class DomainAwareLinear(nn.Module):
    """
    Linear layer whose weight and bias are looked up per sample by domain id.

    Every domain owns one flattened weight matrix and one bias vector, both
    stored as rows of embedding tables.
    """

    def __init__(self, input_size: int, output_size: int, num_domains: int = 20) -> None:
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        # One flattened (input_size x output_size) matrix per domain.
        self.fc = nn.Embedding(num_domains, output_size * input_size)
        self.bias = nn.Embedding(num_domains, output_size)
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.bias.weight)

    def forward(self, x: torch.Tensor, domain_id: torch.LongTensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : Tensor
            [B, I] or [B, T, I]
        domain_id : LongTensor
            [B], domain indices.

        Returns
        -------
        Tensor
            [batch_size, output_size] or [batch_size, seq_len, output_size]
        """
        batch_size = domain_id.shape[0]

        # Treat a 2-D input as a length-1 sequence and undo that at the end.
        was_2d = x.dim() == 2
        tokens = x.unsqueeze(1) if was_2d else x

        weight = self.fc(domain_id).view(batch_size, self.input_size, self.output_size)
        bias = self.bias(domain_id).view(batch_size, 1, self.output_size)

        out = torch.matmul(tokens, weight) + bias
        return out.squeeze(1) if was_2d else out
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
    """
    Pre-LayerNorm Transformer block: LN → MHSA → residual, then LN → MLP → residual.

    Attention dropout and MLP dropout are both fixed at 0.1.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0) -> None:
        super().__init__()
        mlp_hidden = int(hidden_size * mlp_ratio)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, attn_drop=0.1)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden, drop=0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : Tensor, [B, T, H]

        Returns
        -------
        Tensor, [B, T, H]
        """
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------- Main Model ---------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class SoftPromptedTransformer(nn.Module):
    """
    Multi-modal, domain-aware Transformer with optional soft prompts.

    Noisy actions (concatenated with proprioception and a timestep embedding)
    are encoded into tokens, joined with projected VLM and auxiliary visual
    tokens plus per-domain soft prompts, run through a stack of Transformer
    blocks, and the action tokens are decoded back to the action dimension.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        multi_modal_input_size: int = 768,
        depth: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        num_domains: int = 20,
        dim_action: int = 20,
        dim_propio: int = 20,
        dim_time: int = 32,
        len_soft_prompts: int = 32,
        max_len_seq: int = 512,
        use_hetero_proj: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.dim_action = dim_action
        self.dim_time = dim_time
        self.len_soft_prompts = len_soft_prompts
        self.use_hetero_proj = use_hetero_proj

        self.blocks = nn.ModuleList(
            TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        )

        # Visual projections: per-domain weights when heterogeneous, shared otherwise.
        if use_hetero_proj:

            def make_visual_proj() -> nn.Module:
                return DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)

        else:

            def make_visual_proj() -> nn.Module:
                return nn.Linear(multi_modal_input_size, hidden_size)

        self.vlm_proj = make_visual_proj()
        self.aux_visual_proj = make_visual_proj()

        # Learned positional embedding covering up to max_len_seq tokens.
        self.pos_emb = nn.Parameter(torch.zeros(1, max_len_seq, hidden_size), requires_grad=True)
        nn.init.normal_(self.pos_emb, std=0.02)

        self.norm = nn.LayerNorm(hidden_size)
        encoder_input_size = dim_action + dim_time + dim_propio
        self.action_encoder = DomainAwareLinear(encoder_input_size, hidden_size, num_domains=num_domains)
        self.action_decoder = DomainAwareLinear(hidden_size, dim_action, num_domains=num_domains)

        if len_soft_prompts > 0:
            # One block of len_soft_prompts prompt tokens per domain, stored flattened.
            self.soft_prompt_hub = nn.Embedding(num_domains, len_soft_prompts * hidden_size)
            nn.init.normal_(self.soft_prompt_hub.weight, std=0.02)

        # basic_init only touches nn.Linear modules, so this re-initializes every
        # Linear created above while embeddings keep their own initialization.
        self.apply(basic_init)

    def forward(
        self,
        domain_id: torch.LongTensor,
        vlm_features: torch.Tensor,
        aux_visual_inputs: torch.Tensor,
        action_with_noise: torch.Tensor,
        proprio: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass.

        Inputs
        ------
        domain_id : [B]
        vlm_features : [B, T_vlm, D]
        aux_visual_inputs : [B, T_aux, D]
        action_with_noise : [B, T_action, dim_action]
        proprio : [B, dim_propio]
        t : [B]

        Returns
        -------
        Tensor
            Predicted actions, [batch_size, num_actions, dim_action]

        Raises
        ------
        ValueError
            If the action + VLM + auxiliary token count exceeds ``max_len_seq``.
        """
        batch_size, num_actions = action_with_noise.shape[:2]

        # Broadcast the per-sample time embedding and proprio vector to every action step.
        time_emb = timestep_embedding(t, self.dim_time)  # [batch_size, dim_time]
        time_tokens = time_emb.unsqueeze(1).expand(batch_size, num_actions, self.dim_time)
        proprio_tokens = proprio.unsqueeze(1).expand(batch_size, num_actions, proprio.shape[-1])
        encoder_input = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1)
        action_hidden = self.action_encoder(encoder_input, domain_id)  # [batch_size, num_actions, hidden_size]

        # Project both visual streams into the hidden size.
        if self.use_hetero_proj:
            vlm_tokens = self.vlm_proj(vlm_features, domain_id)
            aux_tokens = self.aux_visual_proj(aux_visual_inputs, domain_id)
        else:
            vlm_tokens = self.vlm_proj(vlm_features)
            aux_tokens = self.aux_visual_proj(aux_visual_inputs)
        x = torch.cat([action_hidden, vlm_tokens, aux_tokens], dim=1)

        # Add positional embeddings; sequences longer than the table are rejected.
        seq_len = x.shape[1]
        if seq_len > self.pos_emb.shape[1]:
            raise ValueError(f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}.")
        x = x + self.pos_emb[:, :seq_len, :]

        # Soft prompts are appended after the positional embedding has been added.
        if self.len_soft_prompts > 0:
            soft_prompts = self.soft_prompt_hub(domain_id).view(
                batch_size, self.len_soft_prompts, self.hidden_size
            )
            x = torch.cat([x, soft_prompts], dim=1)

        for block in self.blocks:
            x = block(x)

        # Action tokens occupy the first num_actions positions; decode only those.
        action_states = self.norm(x[:, :num_actions])
        return self.action_decoder(action_states, domain_id)
|
||||||
138
src/lerobot/policies/xvla/utils.py
Normal file
138
src/lerobot/policies/xvla/utils.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def mat2quat(rmat):
    """
    Convert a rotation matrix to a quaternion.

    The quaternion is the eigenvector, for the largest eigenvalue, of a
    symmetric 4x4 matrix built from the rotation entries. The sign is chosen so
    that the scalar part w is non-negative.

    Args:
        rmat (np.array): 3x3 rotation matrix (only the top-left 3x3 part is used)

    Returns:
        np.array: (x,y,z,w) float quaternion angles
    """
    mat = np.asarray(rmat).astype(np.float32)[:3, :3]

    m00, m01, m02 = mat[0, 0], mat[0, 1], mat[0, 2]
    m10, m11, m12 = mat[1, 0], mat[1, 1], mat[1, 2]
    m20, m21, m22 = mat[2, 0], mat[2, 1], mat[2, 2]

    # Only the lower triangle is filled in; eigh reads the lower triangle by default.
    zero = np.float32(0.0)
    k = np.array(
        [
            [m00 - m11 - m22, zero, zero, zero],
            [m01 + m10, m11 - m00 - m22, zero, zero],
            [m02 + m20, m12 + m21, m22 - m00 - m11, zero],
            [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22],
        ],
        dtype=np.float32,
    )
    k /= 3.0

    eigenvalues, eigenvectors = np.linalg.eigh(k)
    # Eigenvector rows are ordered (x, y, z, w); reorder to (w, x, y, z).
    quat_wxyz = eigenvectors[[3, 0, 1, 2], np.argmax(eigenvalues)]
    if quat_wxyz[0] < 0.0:
        quat_wxyz = -quat_wxyz
    # Back to (x, y, z, w) for the caller.
    return quat_wxyz[[1, 2, 3, 0]]
|
||||||
|
|
||||||
|
|
||||||
|
def quat2axisangle(quat):
    """
    Converts quaternion to axis-angle format.
    Returns a unit vector direction scaled by its angle in radians.

    The scalar part w is clipped to [-1, 1] before use. The clipping is applied
    to a local value only, so the caller's array is never modified.

    Args:
        quat (np.array): (x,y,z,w) vec4 float angles; any array-like is accepted

    Returns:
        np.array: (ax,ay,az) axis-angle exponential coordinates
    """
    quat = np.asarray(quat)

    # Clip locally instead of writing into the input; np.clip keeps the scalar dtype.
    w = np.clip(quat[3], -1.0, 1.0)

    den = np.sqrt(1.0 - w * w)
    if math.isclose(den, 0.0):
        # Zero-degree rotation (w is exactly +-1): the axis is undefined, return zeros.
        return np.zeros(3)

    return (quat[:3] * 2.0 * math.acos(w)) / den
|
||||||
|
|
||||||
|
|
||||||
|
def rotate6d_to_axis_angle(r6d):
    """
    Converts 6D rotation representations into axis-angle vectors.

    The two 3-vectors packed in ``r6d`` are orthonormalized (the second one has its
    component along the first removed), completed with a cross product, stacked as
    the columns of a 3x3 matrix and then converted one sample at a time through
    ``mat2quat`` and ``quat2axisangle``.

    Args:
        r6d (np.ndarray): shape (N, 6), or (6,) for a single rotation

    Returns:
        np.ndarray: shape (N, 3) axis-angle vectors, or (3,) when the input was 1-D
    """
    # Remember whether a leading batch axis has to be added and removed again.
    single_input = len(r6d.shape) == 1
    batch = r6d[None, ...] if single_input else r6d

    first = batch[:, 0:3]
    second = batch[:, 3:6]

    # b1: normalized first vector (the epsilon keeps the division finite)
    b1 = first / (np.linalg.norm(first, axis=-1, keepdims=True) + 1e-6)

    # b2: second vector with its b1 component removed, then normalized
    along_b1 = np.sum(b1 * second, axis=-1, keepdims=True)
    residual = second - along_b1 * b1
    b2 = residual / (np.linalg.norm(residual, axis=-1, keepdims=True) + 1e-6)

    # b3: cross product of the first two
    b3 = np.cross(b1, b2, axis=-1)

    # b1, b2, b3 are stacked on the last axis -> shape (N, 3, 3)
    matrices = np.stack([b1, b2, b3], axis=-1)

    # Per-sample conversion: matrix -> quaternion -> axis-angle, shape (N, 3)
    converted = np.stack([quat2axisangle(mat2quat(matrix)) for matrix in matrices], axis=0)

    return converted[0] if single_input else converted
|
||||||
|
|
||||||
|
|
||||||
|
def mat_to_rotate6d(abs_action):
    """
    Extracts the 6D rotation representation from one or more matrices.

    The result is the first column followed by the second column of the top
    three rows of each matrix.

    Args:
        abs_action (np.ndarray): a single matrix of shape (R, C), or a stack of
            matrices with any number of leading batch dimensions, shape (..., R, C).
            Only rows 0-2 and columns 0-1 are read.

    Returns:
        np.ndarray: shape (6,) for a single matrix, (..., 6) for a stack

    Raises:
        NotImplementedError: if the input has fewer than 2 dimensions
    """
    if len(abs_action.shape) < 2:
        raise NotImplementedError(
            f"mat_to_rotate6d expects at least a 2-D input, got shape {tuple(abs_action.shape)}"
        )

    # Ellipsis indexing treats the single-matrix and the batched cases uniformly.
    first_column = abs_action[..., :3, 0]
    second_column = abs_action[..., :3, 1]
    return np.concatenate([first_column, second_column], axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
    """Apply stochastic depth ("drop path") to ``x`` with one keep/drop decision per sample.

    Args:
        x: input tensor; its first dimension indexes the samples.
        drop_prob: probability of zeroing a sample.
        training: the input is returned unchanged unless this is True.
        scale_by_keep: when True (and the keep probability is positive), kept samples
            are divided by the keep probability.

    Returns:
        ``x`` itself when ``drop_prob`` is 0 or ``training`` is False, otherwise
        ``x`` multiplied by a per-sample random mask.
    """
    if not training or drop_prob == 0.0:
        return x

    survival_prob = 1 - drop_prob

    # One mask entry per sample, broadcast over all remaining dimensions of x.
    mask_shape = (x.shape[0], *([1] * (x.ndim - 1)))
    mask = x.new_empty(mask_shape).bernoulli_(survival_prob)

    if scale_by_keep and survival_prob > 0.0:
        mask.div_(survival_prob)

    return x * mask
|
||||||
@@ -533,7 +533,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
|
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
|
||||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
|
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
|
||||||
|
|
||||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||||
info = eval_policy_all(
|
info = eval_policy_all(
|
||||||
|
|||||||
@@ -260,7 +260,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
if cfg.env is not None:
|
if cfg.env is not None:
|
||||||
logging.info(f"{cfg.env.task=}")
|
logging.info(f"{cfg.env.task=}")
|
||||||
logging.info("Creating environment processors")
|
logging.info("Creating environment processors")
|
||||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
|
env_preprocessor, env_postprocessor = make_env_pre_post_processors(
|
||||||
|
env_cfg=cfg.env, policy_cfg=cfg.policy
|
||||||
|
)
|
||||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
||||||
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
||||||
logging.info(f"{dataset.num_episodes=}")
|
logging.info(f"{dataset.num_episodes=}")
|
||||||
|
|||||||
361
tests/policies/xvla/test_xvla_original_vs_lerobot.py
Normal file
361
tests/policies/xvla/test_xvla_original_vs_lerobot.py
Normal file
@@ -0,0 +1,361 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Test script to verify XVLA policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
||||||
|
# ruff: noqa: E402
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import random
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pytest.importorskip("transformers")
|
||||||
|
|
||||||
|
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||||
|
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
|
||||||
|
from lerobot.policies.xvla.processor_xvla import make_xvla_pre_post_processors
|
||||||
|
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
|
||||||
|
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
|
||||||
|
from tests.utils import require_cuda # noqa: E402
|
||||||
|
|
||||||
|
# Constants shared by the helpers and tests in this module
DUMMY_ACTION_DIM = 7  # Standard robot arm action dimension
DUMMY_STATE_DIM = 20  # Proprioceptive state dimension
IMAGE_HEIGHT = 224  # Height of the dummy camera images
IMAGE_WIDTH = 224  # Width of the dummy camera images
NUM_VIEWS = 2  # Number of camera views
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # Use CUDA when available, otherwise CPU
MODEL_PATH_LEROBOT = "lerobot/xvla-widowx"  # Default model path for instantiate_lerobot_xvla
LIBERO_DOMAIN_ID = 0  # Domain ID, for example purposes

# Expected values from original XVLA implementation (reference values);
# test_xvla_action_generation compares the generated actions against them.
EXPECTED_ACTIONS_SHAPE = (30, 20)
EXPECTED_ACTIONS_MEAN = 0.117606
EXPECTED_ACTIONS_STD = 0.245411
EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.2742, 0.4977, 0.0500, 0.7040, -0.2653])
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_memory():
    """Run garbage collection and empty the CUDA/MPS caches to reduce OOM risk between tests."""
    print("\nCleaning up memory...")

    # Collect unreachable Python objects before the accelerator caches are emptied.
    gc.collect()

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    mps_present = torch.backends.mps.is_available()
    if mps_present:
        torch.mps.empty_cache()

    print("Memory cleanup complete.")
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed_all(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic PyTorch behavior.

    Args:
        seed: value passed to every seeding function.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    # Set deterministic behavior
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # warn_only=True: operations without a deterministic implementation warn instead of raising.
    torch.use_deterministic_algorithms(True, warn_only=True)
|
||||||
|
|
||||||
|
|
||||||
|
def instantiate_lerobot_xvla(
    from_pretrained: bool = False,
    model_path: str = MODEL_PATH_LEROBOT,
) -> tuple[
    Any,  # Policy
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Build the LeRobot XVLA policy together with its pre- and post-processing pipelines.

    Args:
        from_pretrained: when True the policy is loaded with ``XVLAPolicy.from_pretrained``;
            otherwise it is constructed from a new ``XVLAConfig``.
        model_path: passed as ``pretrained_name_or_path`` or as ``base_model_path``,
            depending on ``from_pretrained``.

    Returns:
        The tuple ``(policy, preprocessor, postprocessor)``.
    """
    if not from_pretrained:
        # Open question kept from the original code: add resize_imgs_with_padding=IMAGE_SIZE, IMAGE_SIZE?
        fresh_config = XVLAConfig(
            base_model_path=model_path,
            n_action_steps=DUMMY_ACTION_DIM,
            chunk_size=DUMMY_ACTION_DIM,
            device=DEVICE,
            num_image_views=NUM_VIEWS,
        )
        policy = XVLAPolicy(fresh_config)
    else:
        policy = XVLAPolicy.from_pretrained(pretrained_name_or_path=model_path, strict=False)

    policy.to(DEVICE)
    policy.config.device = DEVICE

    # Pass None for dataset_stats to disable normalization (original XVLA doesn't normalize)
    processors = make_xvla_pre_post_processors(config=policy.config, dataset_stats=None)
    preprocessor, postprocessor = processors

    return policy, preprocessor, postprocessor
|
||||||
|
|
||||||
|
|
||||||
|
def create_dummy_data(device=DEVICE):
    """Create one random observation batch: two camera images, a state vector and a task prompt.

    The images are uint8 tensors in CHW layout; no conversion to float is done here.
    Random values are drawn in a fixed order (first image, second image, state), so the
    batch is the same for a given seed.

    Args:
        device: device the image and state tensors are placed on.

    Returns:
        dict with the two image entries, the state entry and the ``"task"`` list.
    """
    batch_size = 1
    prompt = "Pick up the red block and place it in the bin"

    def random_uint8_chw(height, width):
        # The upper bound of np.random.randint is exclusive: values lie in [0, 254].
        pixels_hwc = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
        return torch.from_numpy(pixels_hwc).permute(2, 0, 1)  # CHW

    def image_batch():
        frames = [random_uint8_chw(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)]
        return torch.stack(frames).to(device)

    # Keep this draw order; the seeded reference values depend on it.
    first_view = image_batch()
    second_view = image_batch()
    state = torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device)

    return {
        f"{OBS_IMAGES}.image": first_view,
        f"{OBS_IMAGES}.image2": second_view,
        OBS_STATE: state,
        "task": [prompt] * batch_size,
    }
|
||||||
|
|
||||||
|
|
||||||
|
# Pytest fixtures
|
||||||
|
@pytest.fixture(scope="module")
def xvla_components():
    """Module-scoped fixture that loads the pretrained XVLA policy and its processors.

    Yields the ``(policy, preprocessor, postprocessor)`` tuple and runs
    ``cleanup_memory`` on teardown.
    """
    print(f"\nTesting with DEVICE='{DEVICE}'")
    print("\n[Setup] Instantiating LeRobot XVLA policy...")
    components = instantiate_lerobot_xvla(from_pretrained=True)
    print("✔️ Model loaded successfully")
    yield components
    # Teardown: runs once the tests of the module are done with the fixture.
    cleanup_memory()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def policy(xvla_components):
    """Module-scoped fixture exposing the first element of ``xvla_components`` (the policy)."""
    policy_obj, _preprocessor_obj, _postprocessor_obj = xvla_components
    return policy_obj
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def preprocessor(xvla_components):
    """Module-scoped fixture exposing the second element of ``xvla_components`` (the preprocessor)."""
    _policy_obj, preprocessor_obj, _postprocessor_obj = xvla_components
    return preprocessor_obj
|
||||||
|
|
||||||
|
|
||||||
|
@require_cuda
def test_xvla_preprocessor_alignment(policy, preprocessor):
    """Test that LeRobot XVLA preprocessor produces expected outputs.

    The dummy batch is preprocessed and turned into model inputs with
    ``policy._build_model_inputs``. Every expected key must be present in those
    inputs and must have the expected shape.
    """
    print("\n" + "=" * 80)
    print("Test: XVLA Preprocessor Outputs")
    print("=" * 80)

    set_seed_all(42)

    print("\nCreating dummy data...")
    batch = create_dummy_data()

    print("\n[LeRobot] Preprocessing...")
    lerobot_observation = preprocessor(deepcopy(batch))
    lerobot_inputs = policy._build_model_inputs(lerobot_observation)

    print("\nVerifying preprocessor outputs:")
    print("-" * 80)

    # Expected shapes from tester.txt
    expected_shapes = {
        "domain_id": (1,),
        "input_ids": (1, 50),
        "proprio": (1, 20),
        "image_mask": (1, 2),
        "image_input": (1, 2, 3, 224, 224),
    }

    # A missing key has to fail the test: otherwise the success message at the
    # end would be printed although not every shape was actually checked.
    missing_keys = [key for key in expected_shapes if key not in lerobot_inputs]
    for key in missing_keys:
        print(f"\nKey '{key}' not found in inputs!")
    assert not missing_keys, f"Keys missing from model inputs: {missing_keys}"

    for key, expected_shape in expected_shapes.items():
        actual_shape = tuple(lerobot_inputs[key].shape)
        print(f"\nKey: {key}")
        print(f"Expected shape: {expected_shape}")
        print(f"Actual shape: {actual_shape}")
        print("Shape matches!" if actual_shape == expected_shape else "Shape mismatch!")

        assert actual_shape == expected_shape, f"Shape mismatch for {key}"

    print("\nAll preprocessor outputs have correct shapes!")
|
||||||
|
|
||||||
|
|
||||||
|
@require_cuda
def test_xvla_action_generation(policy, preprocessor):
    """Test XVLA LeRobot implementation generates expected actions.

    Actions are generated for the seeded dummy batch. Their shape and the first five
    values of the first row are asserted against the reference constants; mean and
    standard deviation are printed for comparison only.
    """
    separator = "=" * 80
    print("\n" + separator)
    print("Test: XVLA Action Generation Against Expected Values")
    print(separator)

    set_seed_all(42)

    print("\nCreating dummy data...")
    batch = create_dummy_data()

    print("\n[LeRobot] Running inference...")
    model_inputs = policy._build_model_inputs(preprocessor(deepcopy(batch)))

    # Reset seed for inference
    torch.manual_seed(42)
    with torch.no_grad():
        generated = policy.model.generate_actions(**model_inputs, steps=10)
        generated = generated.squeeze(0).float().cpu()

    # Statistics are computed once and reused by the report and the comparison below.
    actual_mean = generated.mean().item()
    actual_std = generated.std().item()

    print(f"LeRobot actions shape: {generated.shape}")
    print(f"LeRobot actions mean: {actual_mean:.6f}")
    print(f"LeRobot actions std: {actual_std:.6f}")
    actual_first_5 = generated[0, :5]
    print(f"LeRobot actions first 5: {actual_first_5}")

    print("\nExpected values (from original XVLA):")
    print(f"Expected actions shape: {EXPECTED_ACTIONS_SHAPE}")
    print(f"Expected actions mean: {EXPECTED_ACTIONS_MEAN:.6f}")
    print(f"Expected actions std: {EXPECTED_ACTIONS_STD:.6f}")
    print(f"Expected actions first 5: {EXPECTED_ACTIONS_FIRST_5}")

    print("\nAction Comparison:")
    print("-" * 80)

    # Compare shapes
    actual_shape = tuple(generated.shape)
    assert actual_shape == EXPECTED_ACTIONS_SHAPE, (
        f"Shape mismatch: {actual_shape} vs {EXPECTED_ACTIONS_SHAPE}"
    )
    print(f"✔️ Shape matches: {actual_shape}")

    # Compare statistics (reported only, not asserted)
    mean_gap = abs(actual_mean - EXPECTED_ACTIONS_MEAN)
    std_gap = abs(actual_std - EXPECTED_ACTIONS_STD)
    print(f"\nMean: {actual_mean:.6f} (expected: {EXPECTED_ACTIONS_MEAN:.6f}, diff: {mean_gap:.6e})")
    print(f"Std: {actual_std:.6f} (expected: {EXPECTED_ACTIONS_STD:.6f}, diff: {std_gap:.6e})")

    # Compare first 5 actions
    elementwise_gap = torch.abs(actual_first_5 - EXPECTED_ACTIONS_FIRST_5)
    print("\nFirst 5 actions comparison:")
    print(f" Actual: {actual_first_5}")
    print(f" Expected: {EXPECTED_ACTIONS_FIRST_5}")
    print(f" Max diff: {elementwise_gap.max().item():.6e}")
    print(f" Mean diff: {elementwise_gap.mean().item():.6e}")

    # Report closeness at several tolerances
    for tol in (1e-5, 1e-4, 1e-3, 1e-2):
        is_close = torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tol)
        status = "Success" if is_close else "Failure"
        print(f"{status}: First 5 actions close (atol={tol}): {is_close}")

    # Assert with reasonable tolerance
    tolerance = 1e-3
    assert torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tolerance), (
        f"First 5 actions differ by more than tolerance ({tolerance})"
    )
    print(f"\nSuccess: Actions match expected values within tolerance ({tolerance})!")
|
||||||
|
|
||||||
|
|
||||||
|
@require_cuda
def test_xvla_inference_reproducibility(policy, preprocessor):
    """Test that XVLA inference is reproducible with the same seed.

    The same dummy batch is run through preprocessing and action generation twice,
    re-seeding before each run, and the two results must agree within ``atol=1e-6``.
    """
    separator = "=" * 80
    print("\n" + separator)
    print("Test: XVLA Inference Reproducibility")
    print(separator)

    print("\nCreating dummy data...")
    batch = create_dummy_data()

    def seeded_run(banner):
        # One full pass: announce, re-seed, preprocess, build inputs, generate actions.
        print(banner)
        set_seed_all(42)
        model_inputs = policy._build_model_inputs(preprocessor(deepcopy(batch)))
        with torch.no_grad():
            generated = policy.model.generate_actions(**model_inputs, steps=10)
            return generated.squeeze(0).float().cpu()

    # First inference
    actions_1 = seeded_run("\n[Run 1] Running inference...")
    # Second inference with same seed
    actions_2 = seeded_run("\n[Run 2] Running inference with same seed...")

    print("\nComparing two runs:")
    print("-" * 80)
    if not torch.allclose(actions_1, actions_2, atol=1e-8):
        diff = torch.abs(actions_1 - actions_2)
        print("Small differences detected:")
        print(f" Max diff: {diff.max().item():.6e}")
        print(f" Mean diff: {diff.mean().item():.6e}")
    else:
        print("Inference is perfectly reproducible!")

    assert torch.allclose(actions_1, actions_2, atol=1e-6), "Inference should be reproducible!"

    print("\nInference is reproducible!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: load the model once, run the three tests in order and
    # always release memory afterwards.
    print("\n" + "=" * 80)
    print("XVLA LeRobot Validation Test Suite")
    print("=" * 80)

    try:
        # Initialize model once for all tests
        print("\n[Setup] Instantiating LeRobot XVLA policy...")
        policy, preprocessor, postprocessor = instantiate_lerobot_xvla(from_pretrained=True)
        print("✔️ Model loaded successfully")

        # Run all tests with the same model instance
        test_xvla_preprocessor_alignment(policy, preprocessor)
        test_xvla_action_generation(policy, preprocessor)
        test_xvla_inference_reproducibility(policy, preprocessor)

        print("\n" + "=" * 80)
        print("All tests passed!")
        print("=" * 80)
    except Exception as e:
        print("\n" + "=" * 80)
        print(f"Test failed with error: {e}")
        print("=" * 80)
        raise
    finally:
        # Single cleanup point: runs after success, after a failure and also on
        # exits that are not Exception subclasses (e.g. KeyboardInterrupt).
        cleanup_memory()
|
||||||
Reference in New Issue
Block a user