From 2cdd9f43f7e104702ed74c9b6e809d079527314e Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Tue, 13 Jan 2026 01:42:53 +0100 Subject: [PATCH] fix: train tokenizer CLI entry point (#2784) --- .../scripts/lerobot_train_tokenizer.py | 177 +++++++++++------- 1 file changed, 105 insertions(+), 72 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train_tokenizer.py b/src/lerobot/scripts/lerobot_train_tokenizer.py index 03bfcaaf8..296447bad 100644 --- a/src/lerobot/scripts/lerobot_train_tokenizer.py +++ b/src/lerobot/scripts/lerobot_train_tokenizer.py @@ -14,11 +14,14 @@ """Train FAST tokenizer for action encoding. This script: -1. Loads action chunks from LeRobotDataset (with sampling) -2. Applies delta transforms and per-timestamp normalization -3. Trains FAST tokenizer on specified action dimensions -4. Saves tokenizer to assets directory -5. Reports compression statistics +1. Loads action chunks from LeRobotDataset (with episode sampling) +2. Optionally applies delta transforms (relative vs absolute actions) +3. Extracts specified action dimensions for encoding +4. Applies normalization (MEAN_STD, MIN_MAX, QUANTILES, or other modes) +5. Trains FAST tokenizer (BPE on DCT coefficients) on the action chunks +6. Saves tokenizer to output directory +7. Optionally pushes tokenizer to Hugging Face Hub +8. Reports compression statistics Example: @@ -42,18 +45,64 @@ lerobot-train-tokenizer \ """ import json +from dataclasses import dataclass from pathlib import Path +from typing import TYPE_CHECKING import numpy as np import torch -import tyro from huggingface_hub import HfApi -from transformers import AutoProcessor +from lerobot.utils.import_utils import _transformers_available + +if TYPE_CHECKING or _transformers_available: + from transformers import AutoProcessor +else: + AutoProcessor = None + +from lerobot.configs import parser from lerobot.configs.types import NormalizationMode from lerobot.datasets.lerobot_dataset import LeRobotDataset +@dataclass +class TokenizerTrainingConfig: + """Configuration for training FAST tokenizer.""" + + # LeRobot dataset repository ID + repo_id: str + # Root directory for dataset (default: ~/.cache/huggingface/lerobot) + root: str | None = None + # Number of future actions in each chunk + action_horizon: int = 10 + # Max episodes to use (None = all episodes in dataset) + max_episodes: int | None = None + # Fraction of chunks to sample per episode + sample_fraction: float = 0.1 + # Comma-separated dimension ranges to encode (e.g., "0:6,7:23") + encoded_dims: str = "0:6,7:23" + # Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5") + delta_dims: str | None = None + # Whether to apply delta transform (relative actions vs absolute actions) + use_delta_transform: bool = False + # Dataset key for state observations (default: "observation.state") + state_key: str = "observation.state" + # Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY) + normalization_mode: str = "QUANTILES" + # FAST vocabulary size (BPE vocab size) + vocab_size: int = 1024 + # DCT scaling factor (default: 10.0) + scale: float = 10.0 + # Directory to save tokenizer (default: ./fast_tokenizer_{repo_id}) + output_dir: str | None = None + # Whether to push the tokenizer to Hugging Face Hub + push_to_hub: bool = False + # Hub repository ID (e.g., "username/tokenizer-name"). If None, uses output_dir name + hub_repo_id: str | None = None + # Whether to create a private repository on the Hub + hub_private: bool = False + + def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: list[int] | None) -> np.ndarray: """Apply delta transform to specified dimensions. @@ -327,88 +376,57 @@ def compute_compression_stats(tokenizer, action_chunks: np.ndarray): return stats -def main( - repo_id: str, - root: str | None = None, - action_horizon: int = 10, - max_episodes: int | None = None, - sample_fraction: float = 0.1, - encoded_dims: str = "0:6,7:23", - delta_dims: str | None = None, - use_delta_transform: bool = False, - state_key: str = "observation.state", - normalization_mode: str = "QUANTILES", - vocab_size: int = 1024, - scale: float = 10.0, - output_dir: str | None = None, - push_to_hub: bool = False, - hub_repo_id: str | None = None, - hub_private: bool = False, -): +@parser.wrap() +def train_tokenizer(cfg: TokenizerTrainingConfig): """ Train FAST tokenizer for action encoding. Args: - repo_id: LeRobot dataset repository ID - root: Root directory for dataset (default: ~/.cache/huggingface/lerobot) - action_horizon: Number of future actions in each chunk - max_episodes: Max episodes to use (None = all episodes in dataset) - sample_fraction: Fraction of chunks to sample per episode - encoded_dims: Comma-separated dimension ranges to encode (e.g., "0:6,7:23") - delta_dims: Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5") - use_delta_transform: Whether to apply delta transform (relative actions vs absolute actions) - state_key: Dataset key for state observations (default: "observation.state") - normalization_mode: Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY) - vocab_size: FAST vocabulary size (BPE vocab size) - scale: DCT scaling factor (default: 10.0) - output_dir: Directory to save tokenizer (default: ./fast_tokenizer_{repo_id}) - push_to_hub: Whether to push the tokenizer to Hugging Face Hub - hub_repo_id: Hub repository ID (e.g., "username/tokenizer-name"). If None, uses output_dir name - hub_private: Whether to create a private repository on the Hub + cfg: TokenizerTrainingConfig dataclass with all configuration parameters """ # load dataset - print(f"Loading dataset: {repo_id}") - dataset = LeRobotDataset(repo_id=repo_id, root=root) + print(f"Loading dataset: {cfg.repo_id}") + dataset = LeRobotDataset(repo_id=cfg.repo_id, root=cfg.root) print(f"Dataset loaded: {dataset.num_episodes} episodes, {dataset.num_frames} frames") # parse normalization mode try: - norm_mode = NormalizationMode(normalization_mode) + norm_mode = NormalizationMode(cfg.normalization_mode) except ValueError as err: raise ValueError( - f"Invalid normalization_mode: {normalization_mode}. " + f"Invalid normalization_mode: {cfg.normalization_mode}. " f"Must be one of: {', '.join([m.value for m in NormalizationMode])}" ) from err print(f"Normalization mode: {norm_mode.value}") # parse encoded dimensions encoded_dim_ranges = [] - for range_str in encoded_dims.split(","): + for range_str in cfg.encoded_dims.split(","): start, end = map(int, range_str.strip().split(":")) encoded_dim_ranges.append((start, end)) total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges) - print(f"Encoding {total_encoded_dims} dimensions: {encoded_dims}") + print(f"Encoding {total_encoded_dims} dimensions: {cfg.encoded_dims}") # parse delta dimensions delta_dim_list = None - if delta_dims is not None and delta_dims.strip(): - delta_dim_list = [int(d.strip()) for d in delta_dims.split(",")] + if cfg.delta_dims is not None and cfg.delta_dims.strip(): + delta_dim_list = [int(d.strip()) for d in cfg.delta_dims.split(",")] print(f"Delta dimensions: {delta_dim_list}") else: print("No delta dimensions specified") - print(f"Use delta transform: {use_delta_transform}") - if use_delta_transform and (delta_dim_list is None or len(delta_dim_list) == 0): + print(f"Use delta transform: {cfg.use_delta_transform}") + if cfg.use_delta_transform and (delta_dim_list is None or len(delta_dim_list) == 0): print("Warning: use_delta_transform=True but no delta_dims specified. No delta will be applied.") - print(f"Action horizon: {action_horizon}") - print(f"State key: {state_key}") + print(f"Action horizon: {cfg.action_horizon}") + print(f"State key: {cfg.state_key}") # determine episodes to process num_episodes = dataset.num_episodes - if max_episodes is not None: - num_episodes = min(max_episodes, num_episodes) + if cfg.max_episodes is not None: + num_episodes = min(cfg.max_episodes, num_episodes) print(f"Processing {num_episodes} episodes...") @@ -419,7 +437,15 @@ def main( print(f" Processing episode {ep_idx}/{num_episodes}...") chunks = process_episode( - (dataset, ep_idx, action_horizon, delta_dim_list, sample_fraction, state_key, use_delta_transform) + ( + dataset, + ep_idx, + cfg.action_horizon, + delta_dim_list, + cfg.sample_fraction, + cfg.state_key, + cfg.use_delta_transform, + ) ) if chunks is not None: all_chunks.append(chunks) @@ -495,16 +521,17 @@ def main( # train FAST tokenizer tokenizer = train_fast_tokenizer( encoded_chunks, - vocab_size=vocab_size, - scale=scale, + vocab_size=cfg.vocab_size, + scale=cfg.scale, ) # compute compression statistics compression_stats = compute_compression_stats(tokenizer, encoded_chunks) # save tokenizer + output_dir = cfg.output_dir if output_dir is None: - output_dir = f"fast_tokenizer_{repo_id.replace('/', '_')}" + output_dir = f"fast_tokenizer_{cfg.repo_id.replace('/', '_')}" output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) @@ -512,18 +539,18 @@ def main( # save metadata metadata = { - "repo_id": repo_id, - "vocab_size": vocab_size, - "scale": scale, - "encoded_dims": encoded_dims, + "repo_id": cfg.repo_id, + "vocab_size": cfg.vocab_size, + "scale": cfg.scale, + "encoded_dims": cfg.encoded_dims, "encoded_dim_ranges": encoded_dim_ranges, "total_encoded_dims": total_encoded_dims, - "delta_dims": delta_dims, + "delta_dims": cfg.delta_dims, "delta_dim_list": delta_dim_list, - "use_delta_transform": use_delta_transform, - "state_key": state_key, + "use_delta_transform": cfg.use_delta_transform, + "state_key": cfg.state_key, "normalization_mode": norm_mode.value, - "action_horizon": action_horizon, + "action_horizon": cfg.action_horizon, "num_training_chunks": len(encoded_chunks), "compression_stats": compression_stats, } @@ -535,21 +562,22 @@ def main( print(f"Metadata: {json.dumps(metadata, indent=2)}") # push to Hugging Face Hub if requested - if push_to_hub: + if cfg.push_to_hub: # determine the hub repository ID + hub_repo_id = cfg.hub_repo_id if hub_repo_id is None: hub_repo_id = output_path.name print(f"\nNo hub_repo_id provided, using: {hub_repo_id}") print(f"\nPushing tokenizer to Hugging Face Hub: {hub_repo_id}") - print(f" Private: {hub_private}") + print(f" Private: {cfg.hub_private}") try: # use the tokenizer's push_to_hub method tokenizer.push_to_hub( repo_id=hub_repo_id, - private=hub_private, - commit_message=f"Upload FAST tokenizer trained on {repo_id}", + private=cfg.hub_private, + commit_message=f"Upload FAST tokenizer trained on {cfg.repo_id}", ) # also upload the metadata.json file separately @@ -568,5 +596,10 @@ def main( print(" Make sure you're logged in with `huggingface-cli login`") +def main(): + """CLI entry point that parses arguments and runs the tokenizer training.""" + train_tokenizer() + + if __name__ == "__main__": - tyro.cli(main) + main()