Compare commits

..

1 Commits

34 changed files with 434 additions and 790 deletions

View File

@@ -61,7 +61,6 @@ jobs:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
steps:
- uses: actions/checkout@v6
with:
@@ -90,10 +89,5 @@ jobs:
- name: Install lerobot with test extras
run: uv sync --extra "test"
- name: Login to Hugging Face
run: |
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
uv run hf auth whoami
- name: Run pytest
run: uv run pytest tests -vv --maxfail=10

View File

@@ -60,7 +60,6 @@ jobs:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
steps:
- uses: actions/checkout@v6
with:
@@ -88,11 +87,6 @@ jobs:
- name: Install lerobot with all extras
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
- name: Login to Hugging Face
run: |
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
uv run hf auth whoami
- name: Run pytest (all extras)
run: uv run pytest tests -vv --maxfail=10
@@ -168,7 +162,6 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -180,10 +173,6 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Fix ptxas permissions
run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas
- name: Run pytest on GPU

View File

@@ -119,7 +119,6 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --shm-size "16gb"
@@ -131,10 +130,6 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
run: |
echo "$HF_USER_TOKEN" | hf auth login --token --add-to-git-credential
hf auth whoami
- name: Run pytest on CPU
run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests
@@ -151,7 +146,6 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -163,10 +157,6 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
run: |
echo "$HF_USER_TOKEN" | hf auth login --token --add-to-git-credential
hf auth whoami
- name: Run pytest on GPU
run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests
@@ -184,7 +174,6 @@ jobs:
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
CUDA_VISIBLE_DEVICES: "0,1,2,3"
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -196,10 +185,6 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
run: |
echo "$HF_USER_TOKEN" | hf auth login --token --add-to-git-credential
hf auth whoami
- name: Verify GPU availability
run: |
nvidia-smi
@@ -207,4 +192,5 @@ jobs:
- name: Run multi-GPU training tests
# TODO(Steven): Investigate why motors tests are failing in multi-GPU setup
run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
run: pytest tests -vv --maxfail=10 --ignore=tests/motors/
timeout-minutes: 10

View File

@@ -48,7 +48,6 @@ jobs:
MUJOCO_GL: egl
HF_HOME: /mnt/cache/.cache/huggingface
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
steps:
- uses: actions/checkout@v6
with:
@@ -80,10 +79,7 @@ jobs:
- name: Install lerobot with all extras
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
- name: Login to Hugging Face
run: |
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
uv run hf auth whoami
- name: Run pytest (all extras)
run: uv run pytest tests -vv
@@ -141,7 +137,6 @@ jobs:
HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot
TORCH_HOME: /home/user_lerobot/.cache/torch
TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton
HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
container:
image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images]
options: --gpus all --shm-size "16gb"
@@ -153,10 +148,6 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Login to Hugging Face
run: |
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
hf auth whoami
- name: Run pytest on GPU
run: pytest tests -vv
- name: Run end-to-end tests

View File

@@ -52,7 +52,7 @@ This approach can transform **any existing VLM** into a VLA by training it to pr
You have two options for the FAST tokenizer:
1. **Use the pre-trained tokenizer**: The `lerobot/fast-action-tokenizer` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
1. **Use the pre-trained tokenizer**: The `physical-intelligence/fast` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.
2. **Train your own tokenizer**: For maximum performance on your specific dataset, you can finetune the tokenizer on your own data.
@@ -114,15 +114,15 @@ lerobot-train \
### Key Training Parameters
| Parameter | Description | Default |
| -------------------------------------- | -------------------------------------------------- | ------------------------------- |
| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
| `--policy.n_action_steps` | Number of action steps to execute | `50` |
| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
| `--policy.action_tokenizer_name` | FAST tokenizer to use | `lerobot/fast-action-tokenizer` |
| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
| Parameter | Description | Default |
| -------------------------------------- | -------------------------------------------------- | ---------------------------- |
| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` |
| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` |
| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` |
| `--policy.n_action_steps` | Number of action steps to execute | `50` |
| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` |
| `--policy.action_tokenizer_name` | FAST tokenizer to use | `physical-intelligence/fast` |
| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` |
## Inference

View File

@@ -61,7 +61,7 @@ dependencies = [
# Hugging Face dependencies
"datasets>=4.0.0,<5.0.0",
"diffusers>=0.27.2,<0.36.0",
"huggingface-hub[cli]>=1.0.0,<2.0.0",
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
"accelerate>=1.10.0,<2.0.0",
# Core dependencies
@@ -96,12 +96,9 @@ dependencies = [
# Common
pygame-dep = ["pygame>=2.5.1,<2.7.0"]
placo-dep = ["placo>=0.9.6,<0.10.0"]
transformers-dep = ["transformers>=5.3.0,<6.0.0"]
transformers-dep = ["transformers>=4.57.1,<5.0.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
can-dep = ["python-can>=4.2.0,<5.0.0"]
peft-dep = ["peft>=0.18.0,<1.0.0"]
scipy-dep = ["scipy>=1.14.0,<2.0.0"]
qwen-vl-utils-dep = ["qwen-vl-utils>=0.0.11,<0.1.0"]
# Motors
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
@@ -132,17 +129,17 @@ phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0", "fastapi<1.0"]
# Policies
wallx = [
"lerobot[transformers-dep]",
"lerobot[peft]",
"lerobot[scipy-dep]",
"torchdiffeq>=0.2.4,<0.3.0",
"lerobot[qwen-vl-utils-dep]",
"transformers==4.49.0",
"peft==0.17.1",
"scipy==1.15.3",
"torchdiffeq==0.2.5",
"qwen_vl_utils==0.0.11"
]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi", "scipy>=1.10.1,<1.15"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
groot = [
"lerobot[transformers-dep]",
"lerobot[peft]",
"peft>=0.13.0,<1.0.0",
"dm-tree>=0.1.8,<1.0.0",
"timm>=1.0.0,<1.1.0",
"safetensors>=0.4.3,<1.0.0",
@@ -151,13 +148,13 @@ groot = [
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "lerobot[qwen-vl-utils-dep]"]
sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14,<0.1.0"]
xvla = ["lerobot[transformers-dep]"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
# Features
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"]
peft = ["lerobot[transformers-dep]", "peft>=0.18.0,<1.0.0"]
# Development
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
@@ -179,8 +176,8 @@ all = [
"lerobot[reachy2]",
"lerobot[kinematics]",
"lerobot[intelrealsense]",
"lerobot[wallx]",
"lerobot[pi]",
# "lerobot[wallx]",
# "lerobot[pi]", TODO(Pepijn): Update pi to transformers v5
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[xvla]",
@@ -400,3 +397,85 @@ ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.scripts.*"
# ignore_errors = false
[tool.uv]
# wallx requires transformers==4.49.0 which conflicts with other extras that need >=4.53.0
conflicts = [
[
{ extra = "wallx" },
{ extra = "transformers-dep" },
],
[
{ extra = "wallx" },
{ extra = "pi" },
],
[
{ extra = "wallx" },
{ extra = "smolvla" },
],
[
{ extra = "wallx" },
{ extra = "groot" },
],
[
{ extra = "wallx" },
{ extra = "xvla" },
],
[
{ extra = "wallx" },
{ extra = "sarm" },
],
[
{ extra = "wallx" },
{ extra = "hilserl" },
],
[
{ extra = "wallx" },
{ extra = "libero" },
],
[
{ extra = "wallx" },
{ extra = "peft" },
],
[
{ extra = "wallx" },
{ extra = "all" },
],
# pi uses custom branch which conflicts with transformers-dep
[
{ extra = "pi" },
{ extra = "transformers-dep" },
],
[
{ extra = "pi" },
{ extra = "smolvla" },
],
[
{ extra = "pi" },
{ extra = "groot" },
],
[
{ extra = "pi" },
{ extra = "xvla" },
],
[
{ extra = "pi" },
{ extra = "sarm" },
],
[
{ extra = "pi" },
{ extra = "hilserl" },
],
[
{ extra = "pi" },
{ extra = "libero" },
],
[
{ extra = "pi" },
{ extra = "peft" },
],
[
{ extra = "pi" },
{ extra = "all" },
],
]

View File

@@ -21,6 +21,8 @@ from pathlib import Path
import datasets
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import tqdm
from lerobot.datasets.compute_stats import aggregate_stats
@@ -35,7 +37,6 @@ from lerobot.datasets.utils import (
get_file_size_in_mb,
get_hf_features_from_features,
get_parquet_file_size_in_mb,
to_parquet_with_hf_images,
update_chunk_file_indices,
write_info,
write_stats,
@@ -80,28 +81,41 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
return fps, robot_type, features
def update_data_df(df, src_meta, dst_meta):
"""Updates a data DataFrame with new indices and task mappings for aggregation.
def update_data_table(table: pa.Table, src_meta, dst_meta) -> pa.Table:
"""Updates a pyarrow Table with new indices and task mappings for aggregation.
Adjusts episode indices, frame indices, and task indices to account for
previously aggregated data in the destination dataset.
Args:
df: DataFrame containing the data to be updated.
table: pyarrow Table containing the data to be updated.
src_meta: Source dataset metadata.
dst_meta: Destination dataset metadata.
Returns:
pd.DataFrame: Updated DataFrame with adjusted indices.
pa.Table: Updated Table with adjusted indices.
"""
ep_offset = dst_meta.info["total_episodes"]
idx_offset = dst_meta.info["total_frames"]
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
df["index"] = df["index"] + dst_meta.info["total_frames"]
ep_col = table.column("episode_index")
new_ep = pa.array([v + ep_offset for v in ep_col.to_pylist()], type=ep_col.type)
table = table.set_column(table.column_names.index("episode_index"), "episode_index", new_ep)
src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
idx_col = table.column("index")
new_idx = pa.array([v + idx_offset for v in idx_col.to_pylist()], type=idx_col.type)
table = table.set_column(table.column_names.index("index"), "index", new_idx)
return df
old_task_indices = table.column("task_index").to_pylist()
src_task_names = src_meta.tasks.index.take(old_task_indices)
new_task_indices = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy().tolist()
table = table.set_column(
table.column_names.index("task_index"),
"task_index",
pa.array(new_task_indices, type=table.schema.field("task_index").type),
)
return table
def update_meta_data(
@@ -468,18 +482,13 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
src_path = src_meta.root / DEFAULT_DATA_PATH.format(
chunk_index=src_chunk_idx, file_index=src_file_idx
)
if contains_images:
# Use HuggingFace datasets to read source data to preserve image format
src_ds = datasets.Dataset.from_parquet(str(src_path))
df = src_ds.to_pandas()
else:
df = pd.read_parquet(src_path)
df = update_data_df(df, src_meta, dst_meta)
table = pq.read_table(src_path)
table = update_data_table(table, src_meta, dst_meta)
# Write data and get the actual destination file it was written to
# This avoids duplicating the rotation logic here
data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file(
df,
table,
src_path,
data_idx,
data_files_size_in_mb,
@@ -554,8 +563,16 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
return meta_idx
def _write_table_with_hf_images(
table: pa.Table, path: Path, features: datasets.Features | None = None
) -> None:
"""Write a pyarrow Table to parquet with proper HF image encoding."""
ds = datasets.Dataset.from_dict(table.to_pydict(), features=features)
ds.to_parquet(path)
def append_or_create_parquet_file(
df: pd.DataFrame,
data: pd.DataFrame | pa.Table,
src_path: Path,
idx: dict[str, int],
max_mb: float,
@@ -571,7 +588,7 @@ def append_or_create_parquet_file(
from becoming too large. Handles both regular parquet files and those containing images.
Args:
df: DataFrame to write to the parquet file.
data: Data to write, as a pandas DataFrame or pyarrow Table.
src_path: Path to the source file (used for size estimation).
idx: Dictionary containing current 'chunk' and 'file' indices.
max_mb: Maximum allowed file size in MB before rotation.
@@ -585,15 +602,17 @@ def append_or_create_parquet_file(
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
and (dst_chunk, dst_file) is the actual destination file the data was written to.
"""
table = data if isinstance(data, pa.Table) else pa.Table.from_pandas(data)
dst_chunk, dst_file = idx["chunk"], idx["file"]
dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
if not dst_path.exists():
dst_path.parent.mkdir(parents=True, exist_ok=True)
if contains_images:
to_parquet_with_hf_images(df, dst_path, features=hf_features)
_write_table_with_hf_images(table, dst_path, features=hf_features)
else:
df.to_parquet(dst_path)
pq.write_table(table, dst_path)
return idx, (dst_chunk, dst_file)
src_size = get_parquet_file_size_in_mb(src_path)
@@ -604,22 +623,17 @@ def append_or_create_parquet_file(
dst_chunk, dst_file = idx["chunk"], idx["file"]
new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file)
new_path.parent.mkdir(parents=True, exist_ok=True)
final_df = df
final_table = table
target_path = new_path
else:
if contains_images:
# Use HuggingFace datasets to read existing data to preserve image format
existing_ds = datasets.Dataset.from_parquet(str(dst_path))
existing_df = existing_ds.to_pandas()
else:
existing_df = pd.read_parquet(dst_path)
final_df = pd.concat([existing_df, df], ignore_index=True)
existing_table = pq.read_table(dst_path)
final_table = pa.concat_tables([existing_table, table], promote_options="permissive")
target_path = dst_path
if contains_images:
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
_write_table_with_hf_images(final_table, target_path, features=hf_features)
else:
final_df.to_parquet(target_path)
pq.write_table(final_table, target_path)
return idx, (dst_chunk, dst_file)

View File

@@ -32,6 +32,9 @@ from pathlib import Path
import datasets
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as pa_ds
import pyarrow.parquet as pq
import torch
from tqdm import tqdm
@@ -496,13 +499,16 @@ def _copy_and_reindex_data(
global_index = 0
episode_data_metadata: dict[int, dict] = {}
episode_keys = list(episode_mapping.keys())
ep_filter = pa_ds.field("episode_index").isin(episode_keys)
if dst_meta.tasks is None:
all_task_indices = set()
all_task_indices: set[int] = set()
for src_path in file_to_episodes:
df = pd.read_parquet(src_dataset.root / src_path)
mask = df["episode_index"].isin(list(episode_mapping.keys()))
task_series: pd.Series = df[mask]["task_index"]
all_task_indices.update(task_series.unique().tolist())
table = pq.read_table(
src_dataset.root / src_path, columns=["episode_index", "task_index"], filters=ep_filter
)
all_task_indices.update(pc.unique(table.column("task_index")).to_pylist())
tasks = [src_dataset.meta.tasks.iloc[idx].name for idx in all_task_indices]
dst_meta.save_episode_tasks(list(set(tasks)))
@@ -514,52 +520,41 @@ def _copy_and_reindex_data(
task_mapping[old_task_idx] = new_task_idx
for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"):
df = pd.read_parquet(src_dataset.root / src_path)
all_episodes_in_file = set(df["episode_index"].unique())
table = pq.read_table(src_dataset.root / src_path, filters=ep_filter)
episodes_to_keep = file_to_episodes[src_path]
if all_episodes_in_file == episodes_to_keep:
df["episode_index"] = df["episode_index"].replace(episode_mapping)
df["index"] = range(global_index, global_index + len(df))
df["task_index"] = df["task_index"].replace(task_mapping)
if table.num_rows == 0:
continue
first_ep_old_idx = min(episodes_to_keep)
src_ep = src_dataset.meta.episodes[first_ep_old_idx]
chunk_idx = src_ep["data/chunk_index"]
file_idx = src_ep["data/file_index"]
else:
mask = df["episode_index"].isin(list(episode_mapping.keys()))
df = df[mask].copy().reset_index(drop=True)
table = _replace_column_values(table, "episode_index", episode_mapping)
col_pos = table.column_names.index("index")
new_indices = pa.array(range(global_index, global_index + table.num_rows), type=pa.int64())
table = table.set_column(col_pos, "index", new_indices)
table = _replace_column_values(table, "task_index", task_mapping)
if len(df) == 0:
continue
df["episode_index"] = df["episode_index"].replace(episode_mapping)
df["index"] = range(global_index, global_index + len(df))
df["task_index"] = df["task_index"].replace(task_mapping)
first_ep_old_idx = min(episodes_to_keep)
src_ep = src_dataset.meta.episodes[first_ep_old_idx]
chunk_idx = src_ep["data/chunk_index"]
file_idx = src_ep["data/file_index"]
first_ep_old_idx = min(episodes_to_keep)
src_ep = src_dataset.meta.episodes[first_ep_old_idx]
chunk_idx = src_ep["data/chunk_index"]
file_idx = src_ep["data/file_index"]
dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
dst_path.parent.mkdir(parents=True, exist_ok=True)
_write_parquet(df, dst_path, dst_meta)
_write_parquet(table, dst_path, dst_meta)
ep_col = table.column("episode_index").to_pylist()
idx_col = table.column("index").to_pylist()
for ep_old_idx in episodes_to_keep:
ep_new_idx = episode_mapping[ep_old_idx]
ep_df = df[df["episode_index"] == ep_new_idx]
ep_indices = [idx_col[i] for i, e in enumerate(ep_col) if e == ep_new_idx]
episode_data_metadata[ep_new_idx] = {
"data/chunk_index": chunk_idx,
"data/file_index": file_idx,
"dataset_from_index": int(ep_df["index"].min()),
"dataset_to_index": int(ep_df["index"].max() + 1),
"dataset_from_index": min(ep_indices),
"dataset_to_index": max(ep_indices) + 1,
}
global_index += len(df)
global_index += table.num_rows
return episode_data_metadata
@@ -910,15 +905,39 @@ def _copy_and_reindex_episodes_metadata(
write_stats(filtered_stats, dst_meta.root)
def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -> None:
"""Write DataFrame to parquet
def _replace_column_values(table: pa.Table, column: str, mapping: dict) -> pa.Table:
"""Replace values in a pyarrow Table column using a mapping dict."""
old_values = table.column(column).to_pylist()
new_values = [mapping.get(v, v) for v in old_values]
col_pos = table.column_names.index(column)
return table.set_column(col_pos, column, pa.array(new_values, type=table.schema.field(column).type))
def _write_parquet(
data: pd.DataFrame | pa.Table | dict, path: Path, meta: LeRobotDatasetMetadata
) -> None:
"""Write data to parquet.
This ensures images are properly embedded and the file can be loaded correctly by HF datasets.
Args:
data: Input data as a pandas DataFrame, pyarrow Table, or dict of lists.
path: Destination parquet file path.
meta: Dataset metadata for feature schema.
"""
from lerobot.datasets.utils import embed_images, get_hf_features_from_features
if isinstance(data, pd.DataFrame):
data_dict = data.to_dict(orient="list")
elif isinstance(data, pa.Table):
data_dict = data.to_pydict()
elif isinstance(data, dict):
data_dict = data
else:
raise TypeError(f"Unsupported data type: {type(data)}")
hf_features = get_hf_features_from_features(meta.features)
ep_dataset = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=hf_features, split="train")
ep_dataset = datasets.Dataset.from_dict(data_dict, features=hf_features, split="train")
if len(meta.image_keys) > 0:
ep_dataset = embed_images(ep_dataset)

View File

@@ -14,7 +14,7 @@ from transformers.image_processing_utils import (
)
from transformers.image_processing_utils_fast import (
BaseImageProcessorFast,
ImagesKwargs,
DefaultFastImageProcessorKwargs,
group_images_by_shape,
reorder_images,
)
@@ -77,7 +77,7 @@ def crop(img: torch.Tensor, left: int, top: int, right: int, bottom: int) -> tor
return img[:, top:bottom, left:right]
class Eagle25VLFastImageProcessorKwargs(ImagesKwargs):
class Eagle25VLFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
max_dynamic_tiles: int | None
min_dynamic_tiles: int | None
use_thumbnail: bool | None

View File

@@ -15,7 +15,6 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
@@ -33,21 +32,13 @@ from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaForCausalLM,
_gated_residual,
layernorm_forward,
)
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
else:
CONFIG_MAPPING = None
modeling_gemma = None
PiGemmaForCausalLM = None
_gated_residual = None
layernorm_forward = None
PaliGemmaForConditionalGenerationWithPiGemma = None
GemmaForCausalLM = None
PaliGemmaForConditionalGeneration = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
@@ -200,7 +191,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
if images.dtype == torch.uint8:
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
elif images.dtype == torch.float32:
resized_images = resized_images.clamp(0.0, 1.0)
resized_images = resized_images.clamp(-1.0, 1.0)
else:
raise ValueError(f"Unsupported image dtype: {images.dtype}")
@@ -211,7 +202,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
pad_w1 = pad_w0 + remainder_w
# Pad
constant_value = 0 if images.dtype == torch.uint8 else 0.0
constant_value = 0 if images.dtype == torch.uint8 else -1.0
padded_images = F.pad(
resized_images,
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
@@ -230,14 +221,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
models = [paligemma.language_model, gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
gates.append(gate)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
@@ -263,10 +254,10 @@ def compute_layer_complete(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma.model.language_model.layers[layer_idx].self_attn,
paligemma.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
@@ -274,7 +265,7 @@ def compute_layer_complete(
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
@@ -286,15 +277,15 @@ def compute_layer_complete(
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
# first residual
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
after_first_residual = out_emb.clone()
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)
out_emb = layer.mlp(out_emb)
# second residual
out_emb = _gated_residual(after_first_residual, out_emb, gate)
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
outputs_embeds.append(out_emb)
start_pos = end_pos
return outputs_embeds
@@ -367,7 +358,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.torch_dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
@@ -375,7 +366,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.dtype = "float32"
vlm_config_hf.vision_config.torch_dtype = "float32"
action_expert_config_hf = CONFIG_MAPPING["gemma"](
head_dim=action_expert_config.head_dim,
@@ -386,13 +377,13 @@ class PaliGemmaWithExpertModel(
num_key_value_heads=action_expert_config.num_kv_heads,
vocab_size=257152,
hidden_activation="gelu_pytorch_tanh",
dtype="float32",
torch_dtype="float32",
use_adarms=use_adarms[1],
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
)
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_for_selected_params(precision)
@@ -407,11 +398,10 @@ class PaliGemmaWithExpertModel(
else:
raise ValueError(f"Invalid precision: {precision}")
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
# "same dtype" error). Align with PI05.
params_to_keep_float32 = [
"vision_tower",
"multi_modal_projector",
"vision_tower.vision_model.embeddings.patch_embedding.weight",
"vision_tower.vision_model.embeddings.patch_embedding.bias",
"vision_tower.vision_model.embeddings.position_embedding.weight",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
@@ -423,8 +413,8 @@ class PaliGemmaWithExpertModel(
def _set_requires_grad(self):
if self.freeze_vision_encoder:
self.paligemma.model.vision_tower.eval()
for param in self.paligemma.model.vision_tower.parameters():
self.paligemma.vision_tower.eval()
for param in self.paligemma.vision_tower.parameters():
param.requires_grad = False
if self.train_expert_only:
self.paligemma.eval()
@@ -434,23 +424,15 @@ class PaliGemmaWithExpertModel(
def train(self, mode: bool = True):
super().train(mode)
if self.freeze_vision_encoder:
self.paligemma.model.vision_tower.eval()
self.paligemma.vision_tower.eval()
if self.train_expert_only:
self.paligemma.eval()
def embed_image(self, image: torch.Tensor):
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
return self.paligemma.model.get_image_features(image)
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.embed_tokens(tokens)
return self.paligemma.language_model.embed_tokens(tokens)
def forward(
self,
@@ -464,7 +446,7 @@ class PaliGemmaWithExpertModel(
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.model.language_model.forward(
prefix_output = self.paligemma.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
@@ -488,7 +470,7 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
models = [self.paligemma.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
@@ -528,7 +510,7 @@ class PaliGemmaWithExpertModel(
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -594,19 +576,29 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Also compile the main forward pass used during training
self.forward = torch.compile(self.forward, mode=config.compile_mode)
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
try:
from transformers.models.siglip import check
if not check.check_whether_transformers_replace_is_installed_correctly():
raise ValueError(msg)
except ImportError:
raise ValueError(msg) from None
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
logging.info("Enabled gradient checkpointing for PI0Pytorch model")
def gradient_checkpointing_disable(self):
"""Disable gradient checkpointing."""
self.gradient_checkpointing_enabled = False
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
@@ -768,7 +760,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time)
if (
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
@@ -842,7 +834,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
attention_mask=prefix_att_2d_masks_4d,
@@ -916,7 +908,6 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
@@ -1006,12 +997,14 @@ class PI0Policy(PreTrainedPolicy):
# Check if dataset_stats were provided in kwargs
model = cls(config, **kwargs)
# Load state dict (expects keys with "model." prefix)
# Now manually load and remap the state dict
try:
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try safetensors first
resolved_file = cached_file(
pretrained_name_or_path,
"model.safetensors",
@@ -1019,7 +1012,7 @@ class PI0Policy(PreTrainedPolicy):
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
token=kwargs.get("token"),
use_auth_token=kwargs.get("use_auth_token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
@@ -1032,7 +1025,7 @@ class PI0Policy(PreTrainedPolicy):
print("Returning model without loading pretrained weights")
return model
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it
@@ -1077,7 +1070,7 @@ class PI0Policy(PreTrainedPolicy):
print("All keys loaded successfully!")
except Exception as e:
print(f"Warning: Could not load state dict: {e}")
print(f"Warning: Could not remap state dict keys: {e}")
return model
@@ -1127,14 +1120,6 @@ class PI0Policy(PreTrainedPolicy):
# Some checkpoints might have this, but current model expects different structure
logging.warning(f"Vision embedding key might need handling: {key}")
if (
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
or key == "paligemma_with_expert.paligemma.lm_head.weight"
):
fixed_state_dict[
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
] = value.clone()
fixed_state_dict[new_key] = value
return fixed_state_dict

View File

@@ -15,7 +15,6 @@
# limitations under the License.
import builtins
import copy
import logging
import math
from collections import deque
@@ -33,20 +32,14 @@ from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers.models.auto import CONFIG_MAPPING
from transformers.models.gemma import modeling_gemma
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaForCausalLM,
_gated_residual,
layernorm_forward,
)
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
else:
CONFIG_MAPPING = None
modeling_gemma = None
PiGemmaForCausalLM = None
_gated_residual = None
layernorm_forward = None
PaliGemmaForConditionalGenerationWithPiGemma = None
GemmaForCausalLM = None
PaliGemmaForConditionalGeneration = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
@@ -99,11 +92,10 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
alpha_t = torch.tensor(alpha, dtype=torch.float32)
beta_t = torch.tensor(beta, dtype=torch.float32)
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
dist = torch.distributions.Beta(alpha_t, beta_t)
return dist.sample((bsize,)).to(device)
return dist.sample((bsize,))
def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
@@ -197,7 +189,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
if images.dtype == torch.uint8:
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
elif images.dtype == torch.float32:
resized_images = resized_images.clamp(0.0, 1.0)
resized_images = resized_images.clamp(-1.0, 1.0)
else:
raise ValueError(f"Unsupported image dtype: {images.dtype}")
@@ -208,7 +200,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
pad_w1 = pad_w0 + remainder_w
# Pad
constant_value = 0 if images.dtype == torch.uint8 else 0.0
constant_value = 0 if images.dtype == torch.uint8 else -1.0
padded_images = F.pad(
resized_images,
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
@@ -227,14 +219,14 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
def compute_layer_complete(
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert
):
models = [paligemma.model.language_model, gemma_expert.model]
models = [paligemma.language_model, gemma_expert.model]
query_states = []
key_states = []
value_states = []
gates = []
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
hidden_states, gate = layernorm_forward(layer.input_layernorm, hidden_states, adarms_cond[i])
hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901
gates.append(gate)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
@@ -260,10 +252,10 @@ def compute_layer_complete(
query_states, key_states, cos, sin, unsqueeze_dim=1
)
batch_size = query_states.shape[0]
scaling = paligemma.model.language_model.layers[layer_idx].self_attn.scaling
scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling
# Attention computation
att_output, _ = modeling_gemma.eager_attention_forward(
paligemma.model.language_model.layers[layer_idx].self_attn,
paligemma.language_model.layers[layer_idx].self_attn,
query_states,
key_states,
value_states,
@@ -271,7 +263,7 @@ def compute_layer_complete(
scaling,
)
# Get head_dim from the current layer, not from the model
head_dim = paligemma.model.language_model.layers[layer_idx].self_attn.head_dim
head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim
att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim)
# Process layer outputs
outputs_embeds = []
@@ -283,15 +275,15 @@ def compute_layer_complete(
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos])
# first residual
out_emb = _gated_residual(hidden_states, out_emb, gates[i])
out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001
after_first_residual = out_emb.clone()
out_emb, gate = layernorm_forward(layer.post_attention_layernorm, out_emb, adarms_cond[i])
out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i])
# Convert to bfloat16 if the next layer (mlp) uses bfloat16
if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
out_emb = out_emb.to(dtype=torch.bfloat16)
out_emb = layer.mlp(out_emb)
# second residual
out_emb = _gated_residual(after_first_residual, out_emb, gate)
out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001
outputs_embeds.append(out_emb)
start_pos = end_pos
return outputs_embeds
@@ -364,7 +356,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.torch_dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
@@ -372,7 +364,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.dtype = "float32"
vlm_config_hf.vision_config.torch_dtype = "float32"
action_expert_config_hf = CONFIG_MAPPING["gemma"](
head_dim=action_expert_config.head_dim,
@@ -383,13 +375,13 @@ class PaliGemmaWithExpertModel(
num_key_value_heads=action_expert_config.num_kv_heads,
vocab_size=257152,
hidden_activation="gelu_pytorch_tanh",
dtype="float32",
torch_dtype="float32",
use_adarms=use_adarms[1],
adarms_cond_dim=action_expert_config.width if use_adarms[1] else None,
)
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
self.gemma_expert = PiGemmaForCausalLM(config=action_expert_config_hf)
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf)
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_for_selected_params(precision)
@@ -404,11 +396,10 @@ class PaliGemmaWithExpertModel(
else:
raise ValueError(f"Invalid precision: {precision}")
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
# "same dtype" error). Saves memory vs full float32; more memory than only 3 params.
params_to_keep_float32 = [
"vision_tower",
"multi_modal_projector",
"vision_tower.vision_model.embeddings.patch_embedding.weight",
"vision_tower.vision_model.embeddings.patch_embedding.bias",
"vision_tower.vision_model.embeddings.position_embedding.weight",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
@@ -420,8 +411,8 @@ class PaliGemmaWithExpertModel(
def _set_requires_grad(self):
if self.freeze_vision_encoder:
self.paligemma.model.vision_tower.eval()
for param in self.paligemma.model.vision_tower.parameters():
self.paligemma.vision_tower.eval()
for param in self.paligemma.vision_tower.parameters():
param.requires_grad = False
if self.train_expert_only:
self.paligemma.eval()
@@ -431,23 +422,15 @@ class PaliGemmaWithExpertModel(
def train(self, mode: bool = True):
super().train(mode)
if self.freeze_vision_encoder:
self.paligemma.model.vision_tower.eval()
self.paligemma.vision_tower.eval()
if self.train_expert_only:
self.paligemma.eval()
def embed_image(self, image: torch.Tensor):
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32).
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
return self.paligemma.model.get_image_features(image)
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.embed_tokens(tokens)
return self.paligemma.language_model.embed_tokens(tokens)
def forward(
self,
@@ -461,7 +444,7 @@ class PaliGemmaWithExpertModel(
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.model.language_model.forward(
prefix_output = self.paligemma.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
@@ -485,7 +468,7 @@ class PaliGemmaWithExpertModel(
prefix_output = None
prefix_past_key_values = None
else:
models = [self.paligemma.model.language_model, self.gemma_expert.model]
models = [self.paligemma.language_model, self.gemma_expert.model]
num_layers = self.paligemma.config.text_config.num_hidden_layers
# Check if gradient checkpointing is enabled for any of the models
@@ -525,7 +508,7 @@ class PaliGemmaWithExpertModel(
def compute_final_norms(inputs_embeds, adarms_cond):
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
outputs_embeds.append(out_emb)
return outputs_embeds
@@ -590,19 +573,29 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
# Also compile the main forward pass used during training
self.forward = torch.compile(self.forward, mode=config.compile_mode)
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
try:
from transformers.models.siglip import check
if not check.check_whether_transformers_replace_is_installed_correctly():
raise ValueError(msg)
except ImportError:
raise ValueError(msg) from None
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True
logging.info("Enabled gradient checkpointing for PI05Pytorch model")
def gradient_checkpointing_disable(self):
"""Disable gradient checkpointing."""
self.gradient_checkpointing_enabled = False
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
@@ -744,7 +737,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time)
if (
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
suffix_embs = suffix_embs.to(dtype=torch.bfloat16)
@@ -815,7 +808,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks)
self.paligemma_with_expert.paligemma.model.language_model.config._attn_implementation = "eager" # noqa: SLF001
self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001
_, past_key_values = self.paligemma_with_expert.forward(
attention_mask=prefix_att_2d_masks_4d,
@@ -887,7 +880,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks)
self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001
past_key_values = copy.deepcopy(past_key_values)
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks_4d,
position_ids=position_ids,
@@ -977,12 +969,14 @@ class PI05Policy(PreTrainedPolicy):
# Check if dataset_stats were provided in kwargs
model = cls(config, **kwargs)
# Load state dict (expects keys with "model." prefix)
# Now manually load and remap the state dict
try:
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try safetensors first
resolved_file = cached_file(
pretrained_name_or_path,
"model.safetensors",
@@ -990,7 +984,7 @@ class PI05Policy(PreTrainedPolicy):
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
token=kwargs.get("token"),
use_auth_token=kwargs.get("use_auth_token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
@@ -1003,7 +997,7 @@ class PI05Policy(PreTrainedPolicy):
print("Returning model without loading pretrained weights")
return model
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it
@@ -1015,6 +1009,8 @@ class PI05Policy(PreTrainedPolicy):
new_key = f"model.{key}"
remapped_state_dict[new_key] = value
remap_count += 1
if remap_count <= 10: # Only print first 10 to avoid spam
print(f"Remapped: {key} -> {new_key}")
else:
remapped_state_dict[key] = value
@@ -1048,7 +1044,7 @@ class PI05Policy(PreTrainedPolicy):
print("All keys loaded successfully!")
except Exception as e:
print(f"Warning: Could not load state dict: {e}")
print(f"Warning: Could not remap state dict keys: {e}")
return model
@@ -1102,14 +1098,6 @@ class PI05Policy(PreTrainedPolicy):
# Some checkpoints might have this, but current model expects different structure
logging.warning(f"Vision embedding key might need handling: {key}")
if (
key == "model.paligemma_with_expert.paligemma.lm_head.weight"
or key == "paligemma_with_expert.paligemma.lm_head.weight"
):
fixed_state_dict[
"model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
] = value.clone()
fixed_state_dict[new_key] = value
return fixed_state_dict

View File

@@ -23,6 +23,7 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.modeling_pi05 import pad_vector
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
@@ -67,6 +68,9 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
# TODO: check if this necessary
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.max_state_dim)
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()

View File

@@ -54,7 +54,7 @@ class PI0FastConfig(PreTrainedConfig):
tokenizer_max_length: int = 200 # see openpi `__post_init__`
text_tokenizer_name: str = "google/paligemma-3b-pt-224"
action_tokenizer_name: str = "lerobot/fast-action-tokenizer"
action_tokenizer_name: str = "physical-intelligence/fast"
temperature: float = 0.0
max_decoding_steps: int = 256
fast_skip_tokens: int = 128

View File

@@ -38,16 +38,11 @@ else:
if TYPE_CHECKING or _transformers_available:
from transformers import AutoTokenizer
from transformers.models.auto import CONFIG_MAPPING
from lerobot.policies.pi_gemma import (
PaliGemmaForConditionalGenerationWithPiGemma,
PiGemmaModel,
)
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
else:
CONFIG_MAPPING = None
PaliGemmaForConditionalGeneration = None
AutoTokenizer = None
PiGemmaModel = None
PaliGemmaForConditionalGenerationWithPiGemma = None
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
@@ -126,7 +121,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
if images.dtype == torch.uint8:
resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
elif images.dtype == torch.float32:
resized_images = resized_images.clamp(0.0, 1.0)
resized_images = resized_images.clamp(-1.0, 1.0)
else:
raise ValueError(f"Unsupported image dtype: {images.dtype}")
@@ -137,7 +132,7 @@ def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy)
pad_w1 = pad_w0 + remainder_w
# Pad
constant_value = 0 if images.dtype == torch.uint8 else 0.0
constant_value = 0 if images.dtype == torch.uint8 else -1.0
padded_images = F.pad(
resized_images,
(pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom
@@ -211,22 +206,16 @@ class PI0FastPaliGemma(nn.Module):
vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth
vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads
vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh"
vlm_config_hf.text_config.dtype = "float32"
vlm_config_hf.text_config.torch_dtype = "float32"
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
vlm_config_hf.vision_config.dtype = "float32"
vlm_config_hf.vision_config.torch_dtype = "float32"
self.paligemma = PaliGemmaForConditionalGenerationWithPiGemma(config=vlm_config_hf)
# Use PI Gemma (AdaRMS) as language model when use_adarms[0] is True so that
# forward(..., adarms_cond=...) is supported (same as pi0/pi05).
if use_adarms[0]:
text_config = self.paligemma.config.text_config
self.paligemma.model.language_model = PiGemmaModel(text_config)
self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf)
self.to_bfloat16_for_selected_params(precision)
@@ -239,11 +228,10 @@ class PI0FastPaliGemma(nn.Module):
else:
raise ValueError(f"Invalid precision: {precision}")
# Keep full vision path in float32 so we never toggle (toggle causes optimizer
# "same dtype" error). Align with PI05.
params_to_keep_float32 = [
"vision_tower",
"multi_modal_projector",
"vision_tower.vision_model.embeddings.patch_embedding.weight",
"vision_tower.vision_model.embeddings.patch_embedding.bias",
"vision_tower.vision_model.embeddings.position_embedding.weight",
"input_layernorm",
"post_attention_layernorm",
"model.norm",
@@ -254,18 +242,10 @@ class PI0FastPaliGemma(nn.Module):
param.data = param.data.to(dtype=torch.float32)
def embed_image(self, image: torch.Tensor):
# Vision tower and multi_modal_projector are kept in float32 (params_to_keep_float32). Align with PI05.
out_dtype = image.dtype
if image.dtype != torch.float32:
image = image.to(torch.float32)
image_outputs = self.paligemma.model.get_image_features(image)
features = image_outputs.pooler_output * self.paligemma.config.text_config.hidden_size**0.5
if features.dtype != out_dtype:
features = features.to(out_dtype)
return features
return self.paligemma.model.get_image_features(image)
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.model.language_model.embed_tokens(tokens)
return self.paligemma.language_model.embed_tokens(tokens)
def forward(
self,
@@ -279,7 +259,7 @@ class PI0FastPaliGemma(nn.Module):
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
prefix_output = self.paligemma.model.language_model.forward(
prefix_output = self.paligemma.language_model.forward(
inputs_embeds=inputs_embeds[0],
attention_mask=attention_mask,
position_ids=position_ids,
@@ -326,14 +306,24 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
self.sample_actions_fast = torch.compile(self.sample_actions_fast, mode=config.compile_mode)
self.forward = torch.compile(self.forward, mode=config.compile_mode)
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
try:
from transformers.models.siglip import check
if not check.check_whether_transformers_replace_is_installed_correctly():
raise ValueError(msg)
except ImportError:
raise ValueError(msg) from None
def gradient_checkpointing_enable(self):
"""Enable gradient checkpointing for memory optimization."""
self.gradient_checkpointing_enabled = True
# Call the proper gradient_checkpointing_enable() method with use_reentrant=False for better memory efficiency
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_enable(
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_enable(
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
logging.info("Enabled gradient checkpointing for PI0FastPytorch model")
@@ -342,8 +332,8 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
"""Disable gradient checkpointing."""
self.gradient_checkpointing_enabled = False
# Call the proper gradient_checkpointing_disable() method
self.paligemma_with_expert.paligemma.model.language_model.gradient_checkpointing_disable()
self.paligemma_with_expert.paligemma.model.vision_tower.gradient_checkpointing_disable()
self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_disable()
self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_disable()
logging.info("Disabled gradient checkpointing for PI0FastPytorch model")
def _apply_checkpoint(self, func, *args, **kwargs):
@@ -533,7 +523,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
# Convert embeddings to bfloat16 if needed
if (
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
@@ -626,7 +616,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
)
if (
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
@@ -724,7 +714,7 @@ class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch`
# Ensure correct precision (bfloat16/float32)
if (
self.paligemma_with_expert.paligemma.model.language_model.layers[0].self_attn.q_proj.weight.dtype
self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype
== torch.bfloat16
):
prefix_embs = prefix_embs.to(dtype=torch.bfloat16)
@@ -907,12 +897,14 @@ class PI0FastPolicy(PreTrainedPolicy):
# Check if dataset_stats were provided in kwargs
model = cls(config, **kwargs)
# Load state dict (expects keys with "model." prefix)
# Now manually load and remap the state dict
try:
# Try to load the pytorch_model.bin or model.safetensors file
print(f"Loading model from: {pretrained_name_or_path}")
try:
from transformers.utils import cached_file
# Try safetensors first
resolved_file = cached_file(
pretrained_name_or_path,
"model.safetensors",
@@ -920,7 +912,7 @@ class PI0FastPolicy(PreTrainedPolicy):
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
token=kwargs.get("token"),
use_auth_token=kwargs.get("use_auth_token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)
@@ -933,9 +925,8 @@ class PI0FastPolicy(PreTrainedPolicy):
print("Returning model without loading pretrained weights")
return model
# First, fix any key differences (see openpi model.py, _fix_pytorch_state_dict_keys)
# First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys`
fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config)
# Then add "model." prefix for all keys that don't already have it
remapped_state_dict = {}
remap_count = 0
@@ -945,6 +936,8 @@ class PI0FastPolicy(PreTrainedPolicy):
new_key = f"model.{key}"
remapped_state_dict[new_key] = value
remap_count += 1
if remap_count <= 10: # Only print first 10 to avoid spam
print(f"Remapped: {key} -> {new_key}")
else:
remapped_state_dict[key] = value
@@ -978,7 +971,7 @@ class PI0FastPolicy(PreTrainedPolicy):
print("All keys loaded successfully!")
except Exception as e:
print(f"Warning: Could not load state dict: {e}")
print(f"Warning: Could not remap state dict keys: {e}")
return model

View File

@@ -23,6 +23,7 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
from lerobot.processor import (
ActionTokenizerProcessorStep,
AddBatchDimensionProcessorStep,
@@ -68,6 +69,9 @@ class Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
# TODO: check if this necessary
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.max_state_dim)
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()

View File

@@ -1,363 +0,0 @@
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from torch import nn
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers.cache_utils import DynamicCache
from transformers.masking_utils import create_causal_mask
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.gemma.modeling_gemma import (
GemmaAttention,
GemmaConfig,
GemmaForCausalLM,
GemmaMLP,
GemmaModel,
)
from transformers.models.paligemma.modeling_paligemma import (
PaliGemmaForConditionalGeneration,
PaliGemmaModel,
)
else:
GemmaAttention = None
GemmaConfig = None
GemmaForCausalLM = None
GemmaMLP = None
GemmaModel = None
PaliGemmaModel = None
PaliGemmaForConditionalGeneration = None
DynamicCache = None
GradientCheckpointingLayer = None
BaseModelOutputWithPast = None
create_causal_mask = None
def _gated_residual(
x: torch.Tensor | None,
y: torch.Tensor | None,
gate: torch.Tensor | None,
) -> torch.Tensor | None:
"""Gated residual: x + y when gate is None, else x + y * gate."""
if x is None and y is None:
return None
if x is None or y is None:
return x if x is not None else y
if gate is None:
return x + y
return x + y * gate
def layernorm_forward(
layernorm: nn.Module,
x: torch.Tensor,
cond: torch.Tensor | None = None,
):
"""
call layernorm and return hidden states and gate
if cond is not None, use conditional norm
otherwise, use normal gemma norm
"""
if cond is not None:
return layernorm(x, cond=cond)
else:
return layernorm(x)
class PiGemmaRMSNorm(nn.Module):
"""
Adaptive RMSNorm for PI Gemma (AdaRMS).
When cond_dim is set, uses cond to modulate scale/shift/gate; otherwise behaves like standard GemmaRMSNorm.
forward(x, cond=None) returns (output, gate) for use with _gated_residual.
"""
def __init__(self, dim: int, eps: float = 1e-6, cond_dim: int | None = None):
super().__init__()
self.eps = eps
self.dim = dim
self.cond_dim = cond_dim
if cond_dim is not None:
self.dense = nn.Linear(cond_dim, dim * 3, bias=True)
nn.init.zeros_(self.dense.weight)
else:
self.weight = nn.Parameter(torch.zeros(dim))
self.dense = None
def _norm(self, x):
# Compute variance in float32 (like the source implementation)
var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True)
# Compute normalization in float32
normed_inputs = x * torch.rsqrt(var + self.eps)
return normed_inputs
def forward(
self,
x: torch.Tensor,
cond: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
dtype = x.dtype
normed = self._norm(x)
if cond is None or self.dense is None:
normed = normed * (1.0 + self.weight.float())
return normed.type_as(x), None
if cond.shape[-1] != self.cond_dim:
raise ValueError(f"Expected cond dim {self.cond_dim}, got {cond.shape[-1]}")
modulation = self.dense(cond)
if len(x.shape) == 3:
modulation = modulation.unsqueeze(1)
scale, shift, gate = modulation.chunk(3, dim=-1)
normed = normed * (1 + scale.float()) + shift.float()
return normed.to(dtype), gate.to(dtype)
def extra_repr(self) -> str:
if self.dense is not None:
return f"dim={self.dim}, eps={self.eps}, adaptive=True, cond_dim={self.cond_dim}"
return f"dim={self.dim}, eps={self.eps}"
def _get_pi_gemma_decoder_layer_base():
"""base for PiGemmaDecoderLayer"""
class _PiGemmaDecoderLayerBase(GradientCheckpointingLayer):
"""Decoder layer that uses PiGemmaRMSNorm and _gated_residual, compatible with v5 Gemma."""
def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
self.mlp = GemmaMLP(config)
cond_dim = (
getattr(config, "adarms_cond_dim", None) if getattr(config, "use_adarms", False) else None
)
self.input_layernorm = PiGemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
)
self.post_attention_layernorm = PiGemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values=None,
use_cache: bool = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
adarms_cond: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
hidden_states, gate = self.input_layernorm(hidden_states, cond=adarms_cond)
hidden_states, _ = self.self_attn(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = _gated_residual(residual, hidden_states, gate)
residual = hidden_states
hidden_states, gate = self.post_attention_layernorm(hidden_states, cond=adarms_cond)
hidden_states = self.mlp(hidden_states)
hidden_states = _gated_residual(residual, hidden_states, gate)
return hidden_states
return _PiGemmaDecoderLayerBase
class PiGemmaModel(GemmaModel): # type: ignore[misc]
"""
GemmaModel extended with AdaRMS (adaptive RMSNorm) and gated residuals when config.use_adarms is True.
"""
def __init__(self, config: GemmaConfig, **kwargs):
super().__init__(config, **kwargs)
# if not getattr(config, "use_adarms", False):
# return
cond_dim = getattr(config, "adarms_cond_dim", None)
pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
self.layers = nn.ModuleList(
[pi_gemma_decoder_layer_base(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = PiGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim)
def forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: DynamicCache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
use_cache: bool | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
cache_position: torch.LongTensor | None = None,
adarms_cond: torch.Tensor | None = None,
**kwargs,
) -> BaseModelOutputWithPast:
"""
adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
Condition for ADARMS.
"""
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
import logging
logging.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = create_causal_mask(
config=self.config,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)
# embed positions
hidden_states = inputs_embeds
# Convert to bfloat16 if the first layer uses bfloat16
if len(self.layers) > 0 and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.bfloat16)
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# normalized
# Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
adarms_cond=adarms_cond,
**kwargs,
)
hidden_states = layer_outputs
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states, _ = self.norm(hidden_states, adarms_cond)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class PiGemmaForCausalLM(GemmaForCausalLM): # type: ignore[misc]
"""
Causal LM wrapper using PiGemmaModel as the backbone, for consistency with GemmaForCausalLM
and the language model used in pi0_fast. Use this for the action expert in pi0/pi05.
"""
def __init__(self, config: GemmaConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = PiGemmaModel(config)
class PaliGemmaModelWithPiGemma(PaliGemmaModel):
"""PaliGemmaModel whose language_model is PiGemmaModel (custom decoder with PiGemmaRMSNorm and gated residuals)."""
def __init__(self, config):
super().__init__(config)
self.language_model = PiGemmaModel(config.text_config)
class PaliGemmaForConditionalGenerationWithPiGemma(PaliGemmaForConditionalGeneration):
"""PaliGemmaForConditionalGeneration using PiGemma decoder for the language model."""
def __init__(self, config):
super().__init__(config)
self.model = PaliGemmaModelWithPiGemma(config)
# Make modules available through conditional class for BC
@property
def language_model(self):
return self.model.language_model
__all__ = [
"PiGemmaModel",
"PiGemmaForCausalLM",
"PiGemmaRMSNorm",
"_gated_residual",
"layernorm_forward",
"PaliGemmaModelWithPiGemma",
"PaliGemmaForConditionalGenerationWithPiGemma",
]

View File

@@ -33,7 +33,7 @@ class RewardClassifierConfig(PreTrainedConfig):
latent_dim: int = 256
image_embedding_pooling_dim: int = 8
dropout_rate: float = 0.1
model_name: str = "helper2424/resnet10" # TODO: This needs to be updated. The model on the Hub doesn't call self.post_init() in its __init__, which is required by transformers v5 to set all_tied_weights_keys. The from_pretrained call fails when it tries to access this attribute during _finalize_model_loading.
model_name: str = "helper2424/resnet10"
device: str = "cpu"
model_type: str = "cnn" # "transformer" or "cnn"
num_cameras: int = 2

View File

@@ -55,7 +55,7 @@ class WallXConfig(PreTrainedConfig):
pretrained_name_or_path: str = "x-square-robot/wall-oss-flow"
# Tokenizer settings
action_tokenizer_path: str | None = "lerobot/fast-action-tokenizer"
action_tokenizer_path: str | None = "physical-intelligence/fast"
# Action prediction mode: "diffusion" or "fast"
prediction_mode: str = "diffusion"

View File

@@ -261,15 +261,10 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
and optional LoRA fine-tuning support.
"""
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tied_weights_keys = ["lm_head.weight"]
config_class = Qwen2_5_VLConfig
_no_split_modules = ["Qwen2_5_VLDecoderLayer_with_MoE", "Qwen2_5_VLVisionBlock"]
def init_weights(self):
if getattr(self.model, "language_model", None) is not None:
return
super().init_weights()
@classmethod
def from_pretrained(
cls,
@@ -317,11 +312,6 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
processor.action_processor = action_tokenizer
else:
action_tokenizer = None
# add pad_token_id to config
config.pad_token_id = processor.tokenizer.pad_token_id
config.text_config.pad_token_id = processor.tokenizer.pad_token_id
# Initialize model with configuration and processor
model = cls(config, processor=processor, action_tokenizer=action_tokenizer, **kwargs)
@@ -341,7 +331,7 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
force_download=kwargs.get("force_download", False),
resume_download=kwargs.get("resume_download"),
proxies=kwargs.get("proxies"),
token=kwargs.get("token"),
use_auth_token=kwargs.get("use_auth_token"),
revision=kwargs.get("revision"),
local_files_only=kwargs.get("local_files_only", False),
)

View File

@@ -21,7 +21,6 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
window_size=112,
out_hidden_size=3584,
fullatt_block_indexes=[7, 15, 23, 31],
initializer_range=0.02,
**kwargs,
):
super().__init__(**kwargs)
@@ -39,7 +38,6 @@ class Qwen2_5_VLVisionConfig(PretrainedConfig):
self.window_size = window_size
self.fullatt_block_indexes = fullatt_block_indexes
self.out_hidden_size = out_hidden_size
self.initializer_range = initializer_range
class Qwen2_5_VLConfig(PretrainedConfig):

View File

@@ -11,6 +11,7 @@ from transformers.activations import ACT2FN
from transformers.cache_utils import (
Cache,
DynamicCache,
SlidingWindowCache,
StaticCache,
)
from transformers.generation import GenerationMixin
@@ -30,15 +31,6 @@ from transformers.utils import (
from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
# TODO(Steven): SlidingWindowCache was removed in transformers v5. Define a placeholder so isinstance checks
# always return False (which is the correct behavior when no sliding window cache is in use).
class _SlidingWindowCachePlaceholder:
pass
SlidingWindowCache = _SlidingWindowCachePlaceholder
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.layers.rotary import apply_rotary_emb
@@ -602,40 +594,19 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
return hidden_states
def _compute_default_rope_parameters_qwen2_5_vl(config, device=None):
"""
compute default rope parameters for Qwen2_5_VL
"""
base = config.text_config.rope_parameters["rope_theta"]
dim = config.hidden_size // config.num_attention_heads
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
)
return inv_freq, 1.0
class Qwen2_5_VLRotaryEmbedding(nn.Module):
def __init__(self, config: Qwen2_5_VLConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
elif hasattr(config, "rope_parameters") and config.rope_parameters is not None:
self.rope_type = config.rope_parameters.get("rope_type", "default")
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
if self.rope_type == "default":
self.rope_init_fn = _compute_default_rope_parameters_qwen2_5_vl
self.rope_kwargs = {}
else:
rope_type_key = "linear" if self.rope_type == "linear" else self.rope_type
self.rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type_key]
self.rope_kwargs = {}
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
@@ -1596,7 +1567,7 @@ QWEN2_5_VL_INPUTS_DOCSTRING = r"""
class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tied_weights_keys = ["lm_head.weight"]
config_class = Qwen2_5_VLConfig
_no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]

View File

@@ -144,7 +144,7 @@ def preprocesser_call(
"""
# Process image inputs
if images is not None and len(images) > 0:
image_inputs = processor.image_processor(images=images, return_tensors=return_tensors)
image_inputs = processor.image_processor(images=images, videos=None, return_tensors=return_tensors)
image_grid_thw = image_inputs["image_grid_thw"]
else:
image_inputs = {}
@@ -152,7 +152,7 @@ def preprocesser_call(
# Process video inputs
if videos is not None:
videos_inputs = processor.image_processor(videos=videos, return_tensors=return_tensors)
videos_inputs = processor.image_processor(images=None, videos=videos, return_tensors=return_tensors)
video_grid_thw = videos_inputs["video_grid_thw"]
else:
videos_inputs = {}

View File

@@ -276,8 +276,6 @@ class Florence2LanguageConfig(PretrainedConfig):
)
# ensure backward compatibility for BART CNN models
if not hasattr(self, "forced_bos_token_id"):
self.forced_bos_token_id = None
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
self.forced_bos_token_id = self.bos_token_id
warnings.warn(

View File

@@ -1951,10 +1951,7 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
_tied_weights_keys = {
"encoder.embed_tokens.weight": "shared.weight",
"decoder.embed_tokens.weight": "shared.weight",
}
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: Florence2LanguageConfig):
super().__init__(config)
@@ -2079,10 +2076,7 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
base_model_prefix = "model"
_tied_weights_keys = {
"model.encoder.embed_tokens.weight": "model.shared.weight",
"model.decoder.embed_tokens.weight": "model.shared.weight",
}
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
def __init__(self, config: Florence2LanguageConfig):
@@ -2442,10 +2436,11 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
FLORENCE2_START_DOCSTRING,
)
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
_tied_weights_keys = {
"language_model.model.encoder.embed_tokens.weight": "language_model.model.shared.weight",
"language_model.model.decoder.embed_tokens.weight": "language_model.model.shared.weight",
}
_tied_weights_keys = [
"language_model.encoder.embed_tokens.weight",
"language_model.decoder.embed_tokens.weight",
"language_model.lm_head.weight",
]
def __init__(self, config: Florence2Config):
super().__init__(config)

View File

@@ -336,7 +336,7 @@ class ActionTokenizerProcessorStep(ActionProcessorStep):
Requires the `transformers` library to be installed.
Attributes:
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "lerobot/fast-action-tokenizer").
tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast").
tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored.
trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers).
action_tokenizer: The internal tokenizer/processor instance, loaded during initialization.

View File

@@ -306,7 +306,7 @@ def train_fast_tokenizer(
# download the tokenizer source code (not pretrained weights)
# we'll train a new tokenizer on our own data
base_tokenizer = AutoProcessor.from_pretrained("lerobot/fast-action-tokenizer", trust_remote_code=True)
base_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
# convert action_chunks array to list of arrays (expected by .fit())
action_data_list = [action_chunks[i] for i in range(len(action_chunks))]

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
@@ -38,9 +37,6 @@ def test_classifier_output():
@require_package("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_binary_classifier_with_default_params():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
@@ -82,9 +78,6 @@ def test_binary_classifier_with_default_params():
@require_package("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_multiclass_classifier():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
@@ -124,9 +117,6 @@ def test_multiclass_classifier():
@require_package("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_default_device():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
@@ -139,9 +129,6 @@ def test_default_device():
@require_package("transformers")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_explicit_device_setup():
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier

View File

@@ -17,6 +17,7 @@
"""Test script to verify PI0Fast policy integration with LeRobot vs the original implementation"""
# ruff: noqa: E402
import os
import random
from copy import deepcopy
from typing import Any
@@ -27,6 +28,10 @@ import torch
pytest.importorskip("transformers")
pytest.importorskip("scipy")
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires accepting the model license",
)
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
from lerobot.policies.pi0_fast.modeling_pi0_fast import PI0FastPolicy
@@ -48,23 +53,22 @@ DUMMY_STATE_DIM = 20
IMAGE_HEIGHT = 224
IMAGE_WIDTH = 224
NUM_VIEWS = 2 # Number of camera views
DEVICE = "cuda"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATH_LEROBOT = "lerobot/pi0fast-base"
# Expected action token shape: (batch_size, max_decoding_steps)
EXPECTED_ACTION_TOKENS_SHAPE = (1, 2)
# Expected first 5 action tokens (for reproducibility check)
EXPECTED_ACTION_TOKENS_FIRST_5 = torch.tensor([255020, 255589])
EXPECTED_ACTION_TOKENS_FIRST_5 = torch.tensor([255657, 255362])
# Expected actions after detokenization
EXPECTED_ACTIONS_SHAPE = (1, 2, 32) # (batch_size, n_action_steps, action_dim)
EXPECTED_ACTIONS_MEAN = 0.046403881162405014
EXPECTED_ACTIONS_STD = 0.2607129216194153
EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.0000, 0.3536, 0.0707, 0.0000, 0.0000])
EXPECTED_ACTIONS_MEAN = 0.04419417306780815
EXPECTED_ACTIONS_STD = 0.26231569051742554
EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.0000, 1.4849, 0.0000, 0.0000, 0.0000])
@require_cuda
def set_seed_all(seed: int):
"""Set random seed for all RNG sources to ensure reproducibility."""
random.seed(seed)
@@ -81,7 +85,6 @@ def set_seed_all(seed: int):
torch.use_deterministic_algorithms(True, warn_only=True)
@require_cuda
def instantiate_lerobot_pi0_fast(
from_pretrained: bool = False,
model_path: str = MODEL_PATH_LEROBOT,
@@ -124,7 +127,6 @@ def instantiate_lerobot_pi0_fast(
return policy, preprocessor, postprocessor
@require_cuda
def create_dummy_data(device=DEVICE):
"""Create dummy data for testing both implementations."""
batch_size = 1
@@ -156,25 +158,22 @@ def create_dummy_data(device=DEVICE):
# Pytest fixtures
@pytest.fixture(scope="module")
@require_cuda
def pi0_fast_components():
"""Fixture to instantiate and provide all PI0Fast components for tests."""
print(f"\nTesting with DEVICE='{DEVICE}'")
print("\n[Setup] Instantiating LeRobot PI0Fast policy...")
policy_obj, preprocessor_obj, postprocessor_obj = instantiate_lerobot_pi0_fast(from_pretrained=True)
print("Model loaded successfully")
return policy_obj, preprocessor_obj, postprocessor_obj
yield policy_obj, preprocessor_obj, postprocessor_obj
@pytest.fixture(scope="module")
@require_cuda
def policy(pi0_fast_components):
"""Fixture to provide the PI0Fast policy for tests."""
return pi0_fast_components[0]
@pytest.fixture(scope="module")
@require_cuda
def preprocessor(pi0_fast_components):
"""Fixture to provide the PI0Fast preprocessor for tests."""
return pi0_fast_components[1]

View File

@@ -16,8 +16,17 @@
"""Test script to verify PI0 policy integration with LeRobot, only meant to be run locally!"""
import os
import pytest
import torch
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from lerobot.policies.factory import make_policy_config # noqa: E402
from lerobot.policies.pi0 import ( # noqa: E402
PI0Config,

View File

@@ -16,15 +16,25 @@
"""Test script to verify PI0.5 (pi05) support in PI0 policy, only meant to be run locally!"""
import os
import pytest
import torch
from lerobot.utils.random_utils import set_seed
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from lerobot.policies.factory import make_policy_config # noqa: E402
from lerobot.policies.pi05 import ( # noqa: E402
PI05Config,
PI05Policy,
make_pi05_pre_post_processors, # noqa: E402
)
from lerobot.utils.random_utils import set_seed
from tests.utils import require_cuda # noqa: E402

View File

@@ -24,10 +24,9 @@ import torch
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="TODO: This test seems to hang the CI",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
from lerobot.policies.pi05 import PI05Config, PI05Policy, make_pi05_pre_post_processors # noqa: E402
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402

View File

@@ -24,10 +24,9 @@ import torch
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="TODO: This test seems to hang the CI",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
from lerobot.policies.pi0 import PI0Config, PI0Policy, make_pi0_pre_post_processors # noqa: E402
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
@@ -89,7 +88,6 @@ def test_pi0_rtc_initialization_without_rtc_config():
print("✓ PI0 RTC initialization without RTC config: Test passed")
@require_cuda
def test_pi0_rtc_inference_with_prev_chunk():
"""Test PI0 policy inference with RTC and previous chunk."""
set_seed(42)

View File

@@ -305,9 +305,6 @@ def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_di
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
)
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
@pytest.mark.skip(
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
)
def test_sac_policy_with_pretrained_encoder(
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
):

View File

@@ -16,6 +16,8 @@
"""Test script to verify Wall-X policy integration with LeRobot, only meant to be run locally!"""
import os
import pytest
import torch
@@ -24,15 +26,19 @@ pytest.importorskip("peft")
pytest.importorskip("transformers")
pytest.importorskip("torchdiffeq")
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local Wall-X installation and is not meant for CI",
)
from lerobot.policies.factory import make_policy_config # noqa: E402
from lerobot.policies.wall_x import WallXConfig # noqa: E402
from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy # noqa: E402
from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors # noqa: E402
from lerobot.utils.random_utils import set_seed # noqa: E402
from tests.utils import require_cuda # noqa: E402
@require_cuda
def test_policy_instantiation():
# Create config
set_seed(42)
@@ -117,7 +123,6 @@ def test_policy_instantiation():
raise
@require_cuda
def test_config_creation():
"""Test policy config creation through factory."""
try:
@@ -129,3 +134,8 @@ def test_config_creation():
except Exception as e:
print(f"Config creation failed: {e}")
raise
if __name__ == "__main__":
test_policy_instantiation()
test_config_creation()