From 94efcea8678ae1f998e947a01fb1ae574cac068f Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Wed, 14 Jan 2026 17:08:23 +0100 Subject: [PATCH] add automatic detection of the progress path --- docs/source/sarm.mdx | 28 ++++++------ src/lerobot/scripts/lerobot_train.py | 8 +++- src/lerobot/utils/sample_weighting.py | 64 ++++++++++++++------------- tests/utils/test_sample_weighting.py | 57 +++++++++++++++++++++++- 4 files changed, 108 insertions(+), 49 deletions(-) diff --git a/docs/source/sarm.mdx b/docs/source/sarm.mdx index 3885e1fc5..9455b50e5 100644 --- a/docs/source/sarm.mdx +++ b/docs/source/sarm.mdx @@ -465,14 +465,13 @@ This script: ### Step 5b: Train Policy with RA-BC -Once you have the progress file, train your policy with RA-BC weighting. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC: +Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`) if not explicitly provided. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC: ```bash python src/lerobot/scripts/lerobot_train.py \ --dataset.repo_id=your-username/your-dataset \ --policy.type=pi0 \ --sample_weighting.type=rabc \ - --sample_weighting.progress_path=path/to/sarm_progress.parquet \ --sample_weighting.head_mode=sparse \ --sample_weighting.kappa=0.01 \ --output_dir=outputs/train/policy_rabc \ @@ -489,13 +488,13 @@ The training script automatically: **RA-BC Arguments:** -| Argument | Description | Default | -| ----------------------------------- | ------------------------------------------------------ | --------- | -| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` | -| `--sample_weighting.progress_path` | Path to progress parquet file (required for RABC) | (required)| -| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` | -| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` | -| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` | +| Argument | Description | Default | +| ---------------------------------- | ------------------------------------------------------ | ----------------------- | +| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` | +| `--sample_weighting.progress_path` | Path to progress parquet file | `sarm_progress.parquet` | +| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` | +| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` | +| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` | ### Tuning RA-BC Kappa @@ -513,11 +512,11 @@ The `kappa` parameter is the threshold that determines which samples get full we Monitor these WandB metrics during training: -| Metric | Healthy Range | Problem Indicator | -| ------------------------------- | ------------- | ------------------------- | -| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low | -| `sample_weighting/delta_mean` | > 0 | Should be positive | -| `sample_weighting/delta_std` | > 0 | Variance in data quality | +| Metric | Healthy Range | Problem Indicator | +| ----------------------------- | ------------- | ------------------------- | +| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low | +| `sample_weighting/delta_mean` | > 0 | Should be positive | +| `sample_weighting/delta_std` | > 0 | Variance in data quality | **If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC. @@ -553,7 +552,6 @@ accelerate launch \ --dataset.repo_id=your-username/your-dataset \ --policy.type=pi0 \ --sample_weighting.type=rabc \ - --sample_weighting.progress_path=path/to/sarm_progress.parquet \ --sample_weighting.kappa=0.01 \ --output_dir=outputs/train/policy_rabc \ --batch_size=32 \ diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 7f08c3c55..6d6cea41e 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -382,7 +382,13 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): if is_main_process: logging.info(f"Creating sample weighter: {cfg.sample_weighting.type}") - sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device) + sample_weighter = make_sample_weighter( + cfg.sample_weighting, + policy, + device, + dataset_root=cfg.dataset.root, + dataset_repo_id=cfg.dataset.repo_id, + ) step = 0 # number of policy updates (forward + backward + optim) diff --git a/src/lerobot/utils/sample_weighting.py b/src/lerobot/utils/sample_weighting.py index 7d0f8989d..1e9329a6a 100644 --- a/src/lerobot/utils/sample_weighting.py +++ b/src/lerobot/utils/sample_weighting.py @@ -28,7 +28,7 @@ Example usage: kappa: 0.01 # In training script - sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device) + sample_weighter = make_sample_weighter(cfg.sample_weighting, policy, device, dataset_root=cfg.dataset.root, dataset_repo_id=cfg.dataset.repo_id) ... weights, stats = sample_weighter.compute_batch_weights(batch) """ @@ -37,6 +37,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from dataclasses import dataclass, field +from pathlib import Path from typing import TYPE_CHECKING import torch @@ -63,21 +64,12 @@ class SampleWeighter(ABC): Args: batch: Training batch dictionary containing at minimum an "index" key with global frame indices. - - Returns: - Tuple of: - - weights: Tensor of shape (batch_size,) with per-sample weights, - normalized to sum to batch_size for stable gradients. - - stats: Dictionary with logging-friendly statistics about the weights. """ @abstractmethod def get_stats(self) -> dict: """ Get global statistics about the weighting strategy. - - Returns: - Dictionary with statistics for logging (e.g., mean delta, coverage). """ @@ -112,6 +104,8 @@ def make_sample_weighter( config: SampleWeightingConfig | None, policy: PreTrainedPolicy, device: torch.device, + dataset_root: str | None = None, + dataset_repo_id: str | None = None, ) -> SampleWeighter | None: """ Factory function to create a SampleWeighter from config. @@ -122,18 +116,14 @@ def make_sample_weighter( config: Sample weighting configuration, or None to disable weighting. policy: The policy being trained (used to extract chunk_size, etc.) device: Device to place weight tensors on. - - Returns: - SampleWeighter instance, or None if config is None. - - Raises: - ValueError: If the weighting type is unknown or required params are missing. + dataset_root: Local path to dataset root (for auto-detecting progress_path). + dataset_repo_id: HuggingFace repo ID (for auto-detecting progress_path). """ if config is None: return None if config.type == "rabc": - return _make_rabc_weighter(config, policy, device) + return _make_rabc_weighter(config, policy, device, dataset_root, dataset_repo_id) if config.type == "uniform": # No-op weighter that returns uniform weights @@ -146,8 +136,18 @@ def _make_rabc_weighter( config: SampleWeightingConfig, policy: PreTrainedPolicy, device: torch.device, + dataset_root: str | None = None, + dataset_repo_id: str | None = None, ) -> SampleWeighter: - """Create RABC weighter with policy-specific initialization.""" + """Create RABC weighter with policy-specific initialization. + + Args: + config: Sample weighting configuration. + policy: The policy being trained (used to extract chunk_size). + device: Device to place weight tensors on. + dataset_root: Local path to dataset root (for auto-detecting progress_path). + dataset_repo_id: HuggingFace repo ID (for auto-detecting progress_path). + """ # Import here to avoid circular imports and keep RABC code in SARM module from lerobot.policies.sarm.rabc import RABCWeights @@ -159,15 +159,23 @@ def _make_rabc_weighter( "This is typically set for action-chunking policies like ACT, Diffusion, PI0, etc." ) - if config.progress_path is None: - raise ValueError( - "RABC sample weighting requires 'progress_path' to be set. " - "Generate progress values using: " - "python -m lerobot.policies.sarm.compute_rabc_weights --help" - ) + # Determine progress_path: use explicit config or auto-detect from dataset + progress_path = config.progress_path + if progress_path is None: + if dataset_root: + progress_path = str(Path(dataset_root) / "sarm_progress.parquet") + elif dataset_repo_id: + progress_path = f"hf://datasets/{dataset_repo_id}/sarm_progress.parquet" + else: + raise ValueError( + "RABC sample weighting requires 'progress_path' to be set, " + "or dataset_root/dataset_repo_id for auto-detection. " + "Generate progress values using: " + "python -m lerobot.policies.sarm.compute_rabc_weights --help" + ) return RABCWeights( - progress_path=config.progress_path, + progress_path=progress_path, chunk_size=chunk_size, head_mode=config.head_mode, kappa=config.kappa, @@ -209,12 +217,6 @@ class UniformWeighter(SampleWeighter): Args: batch: Training batch dictionary. - - Returns: - Batch size, or 1 if it cannot be determined. - - Raises: - ValueError: If batch is empty. """ if not batch: raise ValueError("Cannot determine batch size from empty batch") diff --git a/tests/utils/test_sample_weighting.py b/tests/utils/test_sample_weighting.py index 9e7bc61cc..e0055d317 100644 --- a/tests/utils/test_sample_weighting.py +++ b/tests/utils/test_sample_weighting.py @@ -231,8 +231,8 @@ def test_factory_rabc_requires_chunk_size(): make_sample_weighter(config, policy, device) -def test_factory_rabc_requires_progress_path(): - """Test that RABC weighter requires progress_path.""" +def test_factory_rabc_requires_progress_path_or_dataset_info(): + """Test that RABC weighter requires progress_path or dataset info for auto-detection.""" config = SampleWeightingConfig( type="rabc", progress_path=None, # No progress path @@ -242,10 +242,63 @@ def test_factory_rabc_requires_progress_path(): policy.config.chunk_size = 50 device = torch.device("cpu") + # Should fail when no progress_path AND no dataset info with pytest.raises(ValueError, match="progress_path"): make_sample_weighter(config, policy, device) +def test_factory_rabc_auto_detects_from_dataset_root(sample_progress_parquet): + """Test that RABC weighter auto-detects progress_path from dataset_root.""" + config = SampleWeightingConfig( + type="rabc", + progress_path=None, # Not provided, should auto-detect + ) + policy = Mock() + policy.config = Mock() + policy.config.chunk_size = 5 + device = torch.device("cpu") + + # The parquet file is at sample_progress_parquet, get its parent directory + dataset_root = sample_progress_parquet.parent + weighter = make_sample_weighter( + config, + policy, + device, + dataset_root=str(dataset_root), + ) + + assert weighter is not None + from lerobot.policies.sarm.rabc import RABCWeights + + assert isinstance(weighter, RABCWeights) + + +def test_factory_rabc_auto_detects_from_repo_id(): + """Test that RABC weighter constructs HF path from repo_id.""" + config = SampleWeightingConfig( + type="rabc", + progress_path=None, # Not provided, should auto-detect + ) + policy = Mock() + policy.config = Mock() + policy.config.chunk_size = 50 + device = torch.device("cpu") + + # This will construct the path but fail when trying to load (file doesn't exist) + # We just verify it doesn't raise the "progress_path required" error + with pytest.raises(Exception) as exc_info: + make_sample_weighter( + config, + policy, + device, + dataset_repo_id="test-user/test-dataset", + ) + # Should NOT be the "progress_path required" error - it should try to load the file + assert ( + "progress_path" not in str(exc_info.value).lower() or "auto-detection" in str(exc_info.value).lower() + ) + + # ============================================================================= # Integration Tests with RABCWeights # =============================================================================