mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
Compare commits
3 Commits
feat/trim-
...
security-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aba8beddda | ||
|
|
85de893fa7 | ||
|
|
a4c66e530b |
1
.github/workflows/fast_tests.yml
vendored
1
.github/workflows/fast_tests.yml
vendored
@@ -91,6 +91,7 @@ jobs:
|
||||
run: uv sync --extra "test"
|
||||
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
uv run hf auth whoami
|
||||
|
||||
5
.github/workflows/full_tests.yml
vendored
5
.github/workflows/full_tests.yml
vendored
@@ -89,6 +89,7 @@ jobs:
|
||||
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
uv run hf auth whoami
|
||||
@@ -181,6 +182,7 @@ jobs:
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
@@ -200,7 +202,6 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Get Docker Hub Token and Delete Image
|
||||
# zizmor: ignore[template-injection]
|
||||
env:
|
||||
DOCKERHUB_LEROBOT_USERNAME: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}
|
||||
DOCKERHUB_LEROBOT_PASSWORD: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}
|
||||
@@ -232,4 +233,4 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# TODO(Steven): Check dockerimages pull in ubuntu
|
||||
# TODO(Steven): Check dockerimages pull in ubuntu
|
||||
3
.github/workflows/nightly.yml
vendored
3
.github/workflows/nightly.yml
vendored
@@ -132,6 +132,7 @@ jobs:
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
@@ -164,6 +165,7 @@ jobs:
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
@@ -197,6 +199,7 @@ jobs:
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
|
||||
8
.github/workflows/release.yml
vendored
8
.github/workflows/release.yml
vendored
@@ -83,14 +83,6 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Remove Tags with Git dependencies
|
||||
# TODO(Steven): Temporary patch to remove pi from PyPi 0.4.0 release due to its reliance on git dependencies.
|
||||
run: |
|
||||
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
|
||||
grep -E '@ git\+https|lerobot\[pi\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
|
||||
sed -E -i '/@ git\+https|lerobot\[pi\]/d' pyproject.toml
|
||||
echo "::info:: Git dependencies removed. Proceeding with build."
|
||||
|
||||
- name: Install build dependencies
|
||||
run: python -m pip install build
|
||||
|
||||
|
||||
2
.github/workflows/unbound_deps_tests.yml
vendored
2
.github/workflows/unbound_deps_tests.yml
vendored
@@ -81,6 +81,7 @@ jobs:
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --extra all # TODO(Steven): Make flash-attn optional
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
uv run hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
uv run hf auth whoami
|
||||
@@ -154,6 +155,7 @@ jobs:
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Login to Hugging Face
|
||||
if: env.HF_USER_TOKEN != ''
|
||||
run: |
|
||||
hf auth login --token "$HF_USER_TOKEN" --add-to-git-credential
|
||||
hf auth whoami
|
||||
|
||||
@@ -90,9 +90,6 @@ _Replace `[...]` with your desired features._
|
||||
For a full list of optional dependencies, see:
|
||||
https://pypi.org/project/lerobot/
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
||||
|
||||
@@ -34,11 +34,6 @@ As described by Physical Intelligence, while AI has achieved remarkable success
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
## Training Data and Capabilities
|
||||
|
||||
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
|
||||
|
||||
@@ -36,11 +36,6 @@ This diverse training mixture creates a "curriculum" that enables generalization
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
## Usage
|
||||
|
||||
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
@@ -43,11 +43,6 @@ This approach can transform **any existing VLM** into a VLA by training it to pr
|
||||
pip install -e ".[pi]"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install the pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
## Training a Custom FAST Tokenizer
|
||||
|
||||
You have two options for the FAST tokenizer:
|
||||
|
||||
@@ -25,7 +25,6 @@ This module provides utilities for:
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
@@ -46,8 +45,6 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_SUBTASKS_PATH,
|
||||
flatten_dict,
|
||||
get_parquet_file_size_in_mb,
|
||||
load_episodes,
|
||||
update_chunk_file_indices,
|
||||
@@ -144,315 +141,6 @@ def delete_episodes(
|
||||
return new_dataset
|
||||
|
||||
|
||||
def trim_episode_start(
|
||||
dataset: LeRobotDataset,
|
||||
seconds: float,
|
||||
episode_indices: list[int] | None = None,
|
||||
output_dir: str | Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Trim the first N seconds from selected episodes and create a new dataset.
|
||||
|
||||
The operation rewrites data parquet files and updates episode metadata so that:
|
||||
- frame_index starts at 0 for each trimmed episode
|
||||
- timestamp starts at 0 for each trimmed episode
|
||||
- global index remains contiguous across the full dataset
|
||||
- dataset_from_index / dataset_to_index reflect new frame ranges
|
||||
|
||||
Video files are copied as-is and per-episode video timestamps are shifted forward
|
||||
for trimmed episodes.
|
||||
|
||||
Episodes selected for trimming that are too short (length <= trim_frames) are skipped
|
||||
from the output dataset.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
seconds: Number of seconds to remove from episode starts.
|
||||
episode_indices: Optional list of episode indices to trim. If None, trims all episodes.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_trimmed" to original.
|
||||
"""
|
||||
if seconds <= 0:
|
||||
raise ValueError(f"seconds must be strictly positive, got {seconds}")
|
||||
|
||||
if dataset.meta.episodes is None:
|
||||
dataset.meta.episodes = load_episodes(dataset.meta.root)
|
||||
|
||||
trim_frames = int(seconds * dataset.meta.fps)
|
||||
if trim_frames <= 0:
|
||||
raise ValueError(
|
||||
f"seconds={seconds} corresponds to 0 frames at fps={dataset.meta.fps}. "
|
||||
"Increase seconds so at least one frame is trimmed."
|
||||
)
|
||||
|
||||
if episode_indices is None:
|
||||
episode_indices = list(range(dataset.meta.total_episodes))
|
||||
|
||||
if len(episode_indices) == 0:
|
||||
raise ValueError("No episodes specified to trim")
|
||||
|
||||
episode_indices = sorted(set(episode_indices))
|
||||
valid_indices = set(range(dataset.meta.total_episodes))
|
||||
invalid = set(episode_indices) - valid_indices
|
||||
if invalid:
|
||||
raise ValueError(f"Invalid episode indices: {invalid}")
|
||||
|
||||
too_short = sorted(
|
||||
ep_idx for ep_idx in episode_indices if int(dataset.meta.episodes[ep_idx]["length"]) <= trim_frames
|
||||
)
|
||||
trim_set = set(episode_indices)
|
||||
skipped_set = set(too_short)
|
||||
trim_set -= skipped_set
|
||||
|
||||
if too_short:
|
||||
logging.warning(
|
||||
f"Skipping {len(too_short)} episode(s) that are too short to trim "
|
||||
f"({trim_frames} frames): {too_short}"
|
||||
)
|
||||
|
||||
episodes_to_keep = [ep_idx for ep_idx in range(dataset.meta.total_episodes) if ep_idx not in skipped_set]
|
||||
if not episodes_to_keep:
|
||||
raise ValueError(
|
||||
"All episodes selected for trimming are too short and would be skipped. "
|
||||
"Try a smaller trim duration."
|
||||
)
|
||||
|
||||
logging.info(
|
||||
f"Trimming {len(trim_set)} episode(s) by {seconds}s and keeping {len(episodes_to_keep)} "
|
||||
f"episode(s) in output"
|
||||
)
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_trimmed"
|
||||
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=dataset.meta.features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_dir,
|
||||
use_videos=len(dataset.meta.video_keys) > 0,
|
||||
chunks_size=dataset.meta.chunks_size,
|
||||
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
|
||||
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
|
||||
)
|
||||
|
||||
if dataset.meta.tasks is not None:
|
||||
write_tasks(dataset.meta.tasks, new_meta.root)
|
||||
new_meta.tasks = dataset.meta.tasks.copy()
|
||||
|
||||
subtasks_path = dataset.root / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
dst_subtasks_path = new_meta.root / DEFAULT_SUBTASKS_PATH
|
||||
dst_subtasks_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(subtasks_path, dst_subtasks_path)
|
||||
|
||||
episode_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(episodes_to_keep)}
|
||||
trim_duration_s = trim_frames / dataset.meta.fps
|
||||
|
||||
episode_lengths: dict[int, int] = {}
|
||||
episode_ranges: dict[int, tuple[int, int]] = {}
|
||||
total_frames = 0
|
||||
for old_ep_idx in episodes_to_keep:
|
||||
new_ep_idx = episode_mapping[old_ep_idx]
|
||||
src_length = int(dataset.meta.episodes[old_ep_idx]["length"])
|
||||
new_length = src_length - trim_frames if old_ep_idx in trim_set else src_length
|
||||
episode_lengths[new_ep_idx] = new_length
|
||||
episode_ranges[new_ep_idx] = (total_frames, total_frames + new_length)
|
||||
total_frames += new_length
|
||||
|
||||
numeric_features = {
|
||||
k: v
|
||||
for k, v in dataset.meta.features.items()
|
||||
if v["dtype"] not in ["image", "video", "string"]
|
||||
}
|
||||
episode_stats_parts: dict[int, list[dict[str, dict]]] = defaultdict(list)
|
||||
episode_file_metadata: dict[int, dict[str, int]] = {}
|
||||
|
||||
data_dir = dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
if not parquet_files:
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
for src_path in tqdm(parquet_files, desc="Trimming data files"):
|
||||
df = pd.read_parquet(src_path).reset_index(drop=True)
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
if skipped_set:
|
||||
keep_mask = ~df["episode_index"].isin(skipped_set)
|
||||
if not keep_mask.all():
|
||||
df = df.loc[keep_mask].copy().reset_index(drop=True)
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
if trim_set:
|
||||
trim_mask = df["episode_index"].isin(trim_set) & (df["frame_index"] < trim_frames)
|
||||
if trim_mask.any():
|
||||
df = df.loc[~trim_mask].copy().reset_index(drop=True)
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
relative_path = src_path.relative_to(dataset.root)
|
||||
chunk_idx = int(relative_path.parts[1].split("-")[1])
|
||||
file_idx = int(relative_path.parts[2].split("-")[1].split(".")[0])
|
||||
|
||||
for old_ep_idx in sorted(df["episode_index"].unique().tolist()):
|
||||
ep_mask = df["episode_index"] == old_ep_idx
|
||||
new_ep_idx = episode_mapping[old_ep_idx]
|
||||
|
||||
if old_ep_idx in trim_set:
|
||||
df.loc[ep_mask, "frame_index"] = df.loc[ep_mask, "frame_index"] - trim_frames
|
||||
shifted_timestamps = df.loc[ep_mask, "timestamp"].to_numpy(dtype=np.float64) - trim_duration_s
|
||||
df.loc[ep_mask, "timestamp"] = np.clip(shifted_timestamps, a_min=0.0, a_max=None)
|
||||
|
||||
df.loc[ep_mask, "episode_index"] = new_ep_idx
|
||||
|
||||
ep_start, _ = episode_ranges[new_ep_idx]
|
||||
new_indices = ep_start + df.loc[ep_mask, "frame_index"].to_numpy(dtype=np.int64)
|
||||
df.loc[ep_mask, "index"] = new_indices
|
||||
|
||||
if new_ep_idx in episode_file_metadata:
|
||||
existing = episode_file_metadata[new_ep_idx]
|
||||
if (
|
||||
existing["data/chunk_index"] != chunk_idx
|
||||
or existing["data/file_index"] != file_idx
|
||||
):
|
||||
raise ValueError(
|
||||
f"Episode {old_ep_idx} spans multiple data files. "
|
||||
"trim_episode_start currently expects one data file per episode."
|
||||
)
|
||||
else:
|
||||
episode_file_metadata[new_ep_idx] = {
|
||||
"data/chunk_index": chunk_idx,
|
||||
"data/file_index": file_idx,
|
||||
}
|
||||
|
||||
if numeric_features:
|
||||
ep_df = df.loc[ep_mask]
|
||||
episode_data: dict[str, np.ndarray] = {}
|
||||
episode_feature_spec: dict[str, dict] = {}
|
||||
|
||||
for key, feature in numeric_features.items():
|
||||
if key not in ep_df.columns:
|
||||
continue
|
||||
|
||||
values = ep_df[key].to_numpy()
|
||||
if len(values) == 0:
|
||||
continue
|
||||
|
||||
first_value = values[0]
|
||||
if isinstance(first_value, np.ndarray):
|
||||
episode_data[key] = np.stack(values)
|
||||
elif isinstance(first_value, (list, tuple)):
|
||||
episode_data[key] = np.stack(values)
|
||||
else:
|
||||
episode_data[key] = np.asarray(values)
|
||||
|
||||
episode_feature_spec[key] = feature
|
||||
|
||||
if episode_data:
|
||||
episode_stats_parts[new_ep_idx].append(
|
||||
compute_episode_stats(episode_data, episode_feature_spec)
|
||||
)
|
||||
|
||||
df["index"] = df["index"].astype(np.int64)
|
||||
if "frame_index" in df.columns:
|
||||
df["frame_index"] = df["frame_index"].astype(np.int64)
|
||||
|
||||
dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
_write_parquet(df, dst_path, new_meta)
|
||||
|
||||
all_episode_stats = []
|
||||
for old_ep_idx in tqdm(episodes_to_keep, desc="Writing episode metadata"):
|
||||
new_ep_idx = episode_mapping[old_ep_idx]
|
||||
|
||||
if new_ep_idx not in episode_file_metadata:
|
||||
raise ValueError(f"Missing data file metadata for episode {old_ep_idx}")
|
||||
|
||||
from_idx, to_idx = episode_ranges[new_ep_idx]
|
||||
src_episode = dataset.meta.episodes[old_ep_idx]
|
||||
ep_data_meta = episode_file_metadata[new_ep_idx]
|
||||
|
||||
stats_parts = episode_stats_parts.get(new_ep_idx, [])
|
||||
ep_stats = aggregate_stats(stats_parts) if len(stats_parts) > 1 else (stats_parts[0] if stats_parts else {})
|
||||
if ep_stats:
|
||||
all_episode_stats.append(ep_stats)
|
||||
|
||||
episode_meta = {
|
||||
"data/chunk_index": ep_data_meta["data/chunk_index"],
|
||||
"data/file_index": ep_data_meta["data/file_index"],
|
||||
"dataset_from_index": from_idx,
|
||||
"dataset_to_index": to_idx,
|
||||
}
|
||||
|
||||
for video_key in dataset.meta.video_keys:
|
||||
from_ts = src_episode[f"videos/{video_key}/from_timestamp"]
|
||||
if old_ep_idx in trim_set:
|
||||
from_ts += trim_duration_s
|
||||
episode_meta.update(
|
||||
{
|
||||
f"videos/{video_key}/chunk_index": src_episode[f"videos/{video_key}/chunk_index"],
|
||||
f"videos/{video_key}/file_index": src_episode[f"videos/{video_key}/file_index"],
|
||||
f"videos/{video_key}/from_timestamp": from_ts,
|
||||
f"videos/{video_key}/to_timestamp": src_episode[f"videos/{video_key}/to_timestamp"],
|
||||
}
|
||||
)
|
||||
|
||||
episode_dict = {
|
||||
"episode_index": new_ep_idx,
|
||||
"tasks": src_episode["tasks"],
|
||||
"length": episode_lengths[new_ep_idx],
|
||||
}
|
||||
episode_dict.update(episode_meta)
|
||||
if ep_stats:
|
||||
episode_dict.update(flatten_dict({"stats": ep_stats}))
|
||||
|
||||
new_meta._save_episode_metadata(episode_dict)
|
||||
|
||||
new_meta._close_writer()
|
||||
|
||||
if new_meta.video_keys:
|
||||
_copy_videos(dataset, new_meta)
|
||||
|
||||
new_meta.info.update(
|
||||
{
|
||||
"total_episodes": len(episodes_to_keep),
|
||||
"total_frames": total_frames,
|
||||
"total_tasks": len(new_meta.tasks) if new_meta.tasks is not None else 0,
|
||||
"splits": {"train": f"0:{len(episodes_to_keep)}"},
|
||||
}
|
||||
)
|
||||
|
||||
if new_meta.video_keys and dataset.meta.video_keys:
|
||||
for key in new_meta.video_keys:
|
||||
if key in dataset.meta.features:
|
||||
new_meta.info["features"][key]["info"] = dataset.meta.info["features"][key].get("info", {})
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
merged_stats = aggregate_stats(all_episode_stats) if all_episode_stats else {}
|
||||
if dataset.meta.stats:
|
||||
for key, value in dataset.meta.stats.items():
|
||||
if key not in merged_stats:
|
||||
merged_stats[key] = value
|
||||
if merged_stats:
|
||||
write_stats(merged_stats, new_meta.root)
|
||||
|
||||
return LeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=output_dir,
|
||||
image_transforms=dataset.image_transforms,
|
||||
delta_timestamps=dataset.delta_timestamps,
|
||||
tolerance_s=dataset.tolerance_s,
|
||||
)
|
||||
|
||||
|
||||
def split_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
splits: dict[str, float | list[int]],
|
||||
|
||||
@@ -117,13 +117,6 @@ Modify tasks - set default task with overrides for specific episodes (WARNING: m
|
||||
--operation.new_task "Default task" \
|
||||
--operation.episode_tasks '{"5": "Special task for episode 5"}'
|
||||
|
||||
Trim first 3 seconds from all episodes:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_trim3s \
|
||||
--operation.type trim_episode_start \
|
||||
--operation.seconds 3.0
|
||||
|
||||
Convert image dataset to video format and save locally:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
@@ -177,7 +170,6 @@ from lerobot.datasets.dataset_tools import (
|
||||
modify_tasks,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
trim_episode_start,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
@@ -223,13 +215,6 @@ class ModifyTasksConfig(OperationConfig):
|
||||
episode_tasks: dict[str, str] | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("trim_episode_start")
|
||||
@dataclass
|
||||
class TrimEpisodeStartConfig(OperationConfig):
|
||||
seconds: float | None = None
|
||||
episode_indices: list[int] | None = None
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("convert_image_to_video")
|
||||
@dataclass
|
||||
class ConvertImageToVideoConfig(OperationConfig):
|
||||
@@ -479,41 +464,6 @@ def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
|
||||
modified_dataset.push_to_hub()
|
||||
|
||||
|
||||
def handle_trim_episode_start(cfg: EditDatasetConfig) -> None:
|
||||
if not isinstance(cfg.operation, TrimEpisodeStartConfig):
|
||||
raise ValueError("Operation config must be TrimEpisodeStartConfig")
|
||||
|
||||
if cfg.operation.seconds is None:
|
||||
raise ValueError("seconds must be specified for trim_episode_start operation")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
output_repo_id, output_dir = get_output_path(
|
||||
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
|
||||
)
|
||||
|
||||
if cfg.new_repo_id is None:
|
||||
dataset.root = Path(str(dataset.root) + "_old")
|
||||
|
||||
logging.info(
|
||||
f"Trimming first {cfg.operation.seconds}s from episodes "
|
||||
f"{cfg.operation.episode_indices if cfg.operation.episode_indices else 'ALL'} in {cfg.repo_id}"
|
||||
)
|
||||
new_dataset = trim_episode_start(
|
||||
dataset=dataset,
|
||||
seconds=cfg.operation.seconds,
|
||||
episode_indices=cfg.operation.episode_indices,
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
)
|
||||
|
||||
logging.info(f"Dataset saved to {output_dir}")
|
||||
logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}")
|
||||
|
||||
if cfg.push_to_hub:
|
||||
logging.info(f"Pushing to hub as {output_repo_id}")
|
||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||
|
||||
|
||||
def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
# Note: Parser may create any config type with the right fields, so we access fields directly
|
||||
# instead of checking isinstance()
|
||||
@@ -644,8 +594,6 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_remove_feature(cfg)
|
||||
elif operation_type == "modify_tasks":
|
||||
handle_modify_tasks(cfg)
|
||||
elif operation_type == "trim_episode_start":
|
||||
handle_trim_episode_start(cfg)
|
||||
elif operation_type == "convert_image_to_video":
|
||||
handle_convert_image_to_video(cfg)
|
||||
elif operation_type == "info":
|
||||
|
||||
@@ -29,7 +29,6 @@ from lerobot.datasets.dataset_tools import (
|
||||
modify_tasks,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
trim_episode_start,
|
||||
)
|
||||
from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset
|
||||
|
||||
@@ -143,104 +142,6 @@ def test_delete_empty_list(sample_dataset, tmp_path):
|
||||
)
|
||||
|
||||
|
||||
def test_trim_episode_start_updates_indices(sample_dataset, tmp_path):
|
||||
"""Test trimming episode starts updates frame/timestamp/index metadata consistently."""
|
||||
output_dir = tmp_path / "trimmed"
|
||||
trim_seconds = 0.1 # 3 frames at 30 FPS
|
||||
trim_frames = int(trim_seconds * sample_dataset.meta.fps)
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(output_dir)
|
||||
|
||||
new_dataset = trim_episode_start(
|
||||
sample_dataset,
|
||||
seconds=trim_seconds,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
|
||||
expected_length = 10 - trim_frames
|
||||
assert new_dataset.meta.total_episodes == sample_dataset.meta.total_episodes
|
||||
assert new_dataset.meta.total_frames == sample_dataset.meta.total_episodes * expected_length
|
||||
|
||||
indices = [int(i.item()) for i in new_dataset.hf_dataset["index"]]
|
||||
assert indices == list(range(new_dataset.meta.total_frames))
|
||||
|
||||
episode_indices = [int(i.item()) for i in new_dataset.hf_dataset["episode_index"]]
|
||||
frame_indices = [int(i.item()) for i in new_dataset.hf_dataset["frame_index"]]
|
||||
timestamps = [float(i.item()) for i in new_dataset.hf_dataset["timestamp"]]
|
||||
|
||||
for ep_idx in range(sample_dataset.meta.total_episodes):
|
||||
ep_frame_indices = [f for e, f in zip(episode_indices, frame_indices, strict=False) if e == ep_idx]
|
||||
ep_timestamps = [t for e, t in zip(episode_indices, timestamps, strict=False) if e == ep_idx]
|
||||
|
||||
assert len(ep_frame_indices) == expected_length
|
||||
assert ep_frame_indices == list(range(expected_length))
|
||||
assert ep_timestamps[0] == pytest.approx(0.0)
|
||||
assert ep_timestamps[-1] == pytest.approx((expected_length - 1) / sample_dataset.meta.fps)
|
||||
|
||||
ep_meta = new_dataset.meta.episodes[ep_idx]
|
||||
assert int(ep_meta["length"]) == expected_length
|
||||
assert int(ep_meta["dataset_from_index"]) == ep_idx * expected_length
|
||||
assert int(ep_meta["dataset_to_index"]) == (ep_idx + 1) * expected_length
|
||||
|
||||
|
||||
def test_trim_episode_start_skips_too_short_episodes(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test too-short episodes are skipped and remaining episodes are reindexed."""
|
||||
features = {
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": None},
|
||||
"observation.state": {"dtype": "float32", "shape": (2,), "names": None},
|
||||
"observation.images.top": {"dtype": "image", "shape": (32, 32, 3), "names": None},
|
||||
}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "source", features=features)
|
||||
|
||||
for ep_len in [10, 2, 10]:
|
||||
for _ in range(ep_len):
|
||||
dataset.add_frame(
|
||||
{
|
||||
"action": np.random.randn(2).astype(np.float32),
|
||||
"observation.state": np.random.randn(2).astype(np.float32),
|
||||
"observation.images.top": np.random.randint(0, 255, size=(32, 32, 3), dtype=np.uint8),
|
||||
"task": "task",
|
||||
}
|
||||
)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
trim_seconds = 0.1 # 3 frames at 30 FPS
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "trimmed")
|
||||
|
||||
new_dataset = trim_episode_start(
|
||||
dataset,
|
||||
seconds=trim_seconds,
|
||||
output_dir=tmp_path / "trimmed",
|
||||
)
|
||||
|
||||
# Episode 1 is too short and gets skipped. Remaining episodes are trimmed and reindexed.
|
||||
assert new_dataset.meta.total_episodes == 2
|
||||
assert new_dataset.meta.total_frames == 14
|
||||
assert sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}) == [0, 1]
|
||||
assert [int(ep["length"]) for ep in new_dataset.meta.episodes] == [7, 7]
|
||||
|
||||
|
||||
def test_trim_episode_start_rejects_when_all_selected_are_too_short(sample_dataset, tmp_path):
|
||||
"""Test trimming fails when all selected episodes are too short and would be skipped."""
|
||||
with pytest.raises(ValueError, match="All episodes selected for trimming are too short"):
|
||||
trim_episode_start(
|
||||
sample_dataset,
|
||||
seconds=1.0, # 30 frames > 10-frame episodes
|
||||
output_dir=tmp_path / "trimmed",
|
||||
)
|
||||
|
||||
|
||||
def test_split_by_episodes(sample_dataset, tmp_path):
|
||||
"""Test splitting dataset by specific episode indices."""
|
||||
splits = {
|
||||
|
||||
@@ -40,7 +40,7 @@ from lerobot.utils.constants import (
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
) # noqa: E402
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
from tests.utils import require_cuda, require_hf_token # noqa: E402
|
||||
|
||||
# Constants
|
||||
DUMMY_ACTION_DIM = 7
|
||||
@@ -65,6 +65,7 @@ EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.0000, 0.3536, 0.0707, 0.0000, 0.0000]
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def set_seed_all(seed: int):
|
||||
"""Set random seed for all RNG sources to ensure reproducibility."""
|
||||
random.seed(seed)
|
||||
@@ -82,6 +83,7 @@ def set_seed_all(seed: int):
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def instantiate_lerobot_pi0_fast(
|
||||
from_pretrained: bool = False,
|
||||
model_path: str = MODEL_PATH_LEROBOT,
|
||||
@@ -125,6 +127,7 @@ def instantiate_lerobot_pi0_fast(
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def create_dummy_data(device=DEVICE):
|
||||
"""Create dummy data for testing both implementations."""
|
||||
batch_size = 1
|
||||
@@ -157,6 +160,7 @@ def create_dummy_data(device=DEVICE):
|
||||
# Pytest fixtures
|
||||
@pytest.fixture(scope="module")
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def pi0_fast_components():
|
||||
"""Fixture to instantiate and provide all PI0Fast components for tests."""
|
||||
print(f"\nTesting with DEVICE='{DEVICE}'")
|
||||
@@ -168,6 +172,7 @@ def pi0_fast_components():
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def policy(pi0_fast_components):
|
||||
"""Fixture to provide the PI0Fast policy for tests."""
|
||||
return pi0_fast_components[0]
|
||||
@@ -175,12 +180,14 @@ def policy(pi0_fast_components):
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def preprocessor(pi0_fast_components):
|
||||
"""Fixture to provide the PI0Fast preprocessor for tests."""
|
||||
return pi0_fast_components[1]
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_pi0_fast_preprocessor_alignment(policy, preprocessor):
|
||||
"""Test that LeRobot PI0Fast preprocessor produces expected outputs."""
|
||||
print("\n" + "=" * 80)
|
||||
@@ -228,6 +235,7 @@ def test_pi0_fast_preprocessor_alignment(policy, preprocessor):
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_pi0_fast_action_generation(policy, preprocessor):
|
||||
"""Test PI0Fast LeRobot implementation generates expected actions."""
|
||||
print("\n" + "=" * 80)
|
||||
@@ -306,6 +314,7 @@ def test_pi0_fast_action_generation(policy, preprocessor):
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_pi0_fast_inference_reproducibility(policy, preprocessor):
|
||||
"""Test that PI0Fast inference is reproducible with the same seed."""
|
||||
print("\n" + "=" * 80)
|
||||
@@ -347,6 +356,7 @@ def test_pi0_fast_inference_reproducibility(policy, preprocessor):
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_pi0_fast_forward_pass_logits(policy, preprocessor):
|
||||
"""Test PI0Fast forward pass and compare logits against expected values."""
|
||||
print("\n" + "=" * 80)
|
||||
@@ -396,6 +406,7 @@ def test_pi0_fast_forward_pass_logits(policy, preprocessor):
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_pi0_fast_action_token_sampling(policy, preprocessor):
|
||||
"""Test PI0Fast action token sampling (autoregressive decoding)."""
|
||||
print("\n" + "=" * 80)
|
||||
@@ -452,6 +463,7 @@ def test_pi0_fast_action_token_sampling(policy, preprocessor):
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_pi0_fast_detokenization(policy, preprocessor):
|
||||
"""Test PI0Fast action detokenization (FAST decoding)."""
|
||||
print("\n" + "=" * 80)
|
||||
|
||||
@@ -14,10 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify PI0 policy integration with LeRobot, only meant to be run locally!"""
|
||||
"""Test script to verify PI0 policy integration with LeRobot"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.factory import make_policy_config # noqa: E402
|
||||
from lerobot.policies.pi0 import ( # noqa: E402
|
||||
PI0Config,
|
||||
@@ -25,10 +28,11 @@ from lerobot.policies.pi0 import ( # noqa: E402
|
||||
make_pi0_pre_post_processors, # noqa: E402
|
||||
)
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
from tests.utils import require_cuda, require_hf_token # noqa: E402
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_policy_instantiation():
|
||||
# Create config
|
||||
set_seed(42)
|
||||
@@ -105,6 +109,7 @@ def test_policy_instantiation():
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_config_creation():
|
||||
"""Test policy config creation through factory."""
|
||||
try:
|
||||
|
||||
@@ -14,10 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify PI0.5 (pi05) support in PI0 policy, only meant to be run locally!"""
|
||||
"""Test script to verify PI0.5 (pi05) support in PI0 policy"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.policies.factory import make_policy_config # noqa: E402
|
||||
from lerobot.policies.pi05 import ( # noqa: E402
|
||||
PI05Config,
|
||||
@@ -25,10 +28,11 @@ from lerobot.policies.pi05 import ( # noqa: E402
|
||||
make_pi05_pre_post_processors, # noqa: E402
|
||||
)
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
from tests.utils import require_cuda, require_hf_token # noqa: E402
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_policy_instantiation():
|
||||
# Create config
|
||||
set_seed(42)
|
||||
@@ -141,6 +145,7 @@ def test_policy_instantiation():
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_config_creation():
|
||||
"""Test policy config creation through factory."""
|
||||
try:
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
||||
"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation"""
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify PI0 policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
||||
"""Test script to verify PI0 policy integration with LeRobot vs the original implementation"""
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify Wall-X policy integration with LeRobot, only meant to be run locally!"""
|
||||
"""Test script to verify Wall-X policy integration with LeRobot"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -29,10 +29,11 @@ from lerobot.policies.wall_x import WallXConfig # noqa: E402
|
||||
from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy # noqa: E402
|
||||
from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors # noqa: E402
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
from tests.utils import require_cuda, require_hf_token # noqa: E402
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_policy_instantiation():
|
||||
# Create config
|
||||
set_seed(42)
|
||||
@@ -118,6 +119,7 @@ def test_policy_instantiation():
|
||||
|
||||
|
||||
@require_cuda
|
||||
@require_hf_token
|
||||
def test_config_creation():
|
||||
"""Test policy config creation through factory."""
|
||||
try:
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script to verify XVLA policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
|
||||
"""Test script to verify XVLA policy integration with LeRobot vs the original implementation"""
|
||||
# ruff: noqa: E402
|
||||
|
||||
import random
|
||||
|
||||
@@ -28,7 +28,6 @@ from lerobot.scripts.lerobot_edit_dataset import (
|
||||
RemoveFeatureConfig,
|
||||
SplitConfig,
|
||||
_validate_config,
|
||||
TrimEpisodeStartConfig,
|
||||
)
|
||||
|
||||
|
||||
@@ -48,7 +47,6 @@ class TestOperationTypeParsing:
|
||||
("merge", MergeConfig),
|
||||
("remove_feature", RemoveFeatureConfig),
|
||||
("modify_tasks", ModifyTasksConfig),
|
||||
("trim_episode_start", TrimEpisodeStartConfig),
|
||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||
("info", InfoConfig),
|
||||
],
|
||||
@@ -79,7 +77,6 @@ class TestOperationTypeParsing:
|
||||
("merge", MergeConfig),
|
||||
("remove_feature", RemoveFeatureConfig),
|
||||
("modify_tasks", ModifyTasksConfig),
|
||||
("trim_episode_start", TrimEpisodeStartConfig),
|
||||
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||
("info", InfoConfig),
|
||||
],
|
||||
|
||||
@@ -108,6 +108,22 @@ def require_cuda(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_hf_token(func):
|
||||
"""
|
||||
Decorator that skips the test if no Hugging Face Hub token is available.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
from huggingface_hub import get_token
|
||||
|
||||
if get_token() is None:
|
||||
pytest.skip("requires HF token for gated model access")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_env(func):
|
||||
"""
|
||||
Decorator that skips the test if the required environment package is not installed.
|
||||
|
||||
Reference in New Issue
Block a user