diff --git a/docs/source/sarm.mdx b/docs/source/sarm.mdx index 65e49792b..3885e1fc5 100644 --- a/docs/source/sarm.mdx +++ b/docs/source/sarm.mdx @@ -465,15 +465,16 @@ This script: ### Step 5b: Train Policy with RA-BC -Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). Currently PI0, PI0.5 and SmolVLA are supported with RA-BC: +Once you have the progress file, train your policy with RA-BC weighting. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC: ```bash python src/lerobot/scripts/lerobot_train.py \ --dataset.repo_id=your-username/your-dataset \ --policy.type=pi0 \ - --use_rabc=true \ - --rabc_head_mode=sparse \ - --rabc_kappa=0.01 \ + --sample_weighting.type=rabc \ + --sample_weighting.progress_path=path/to/sarm_progress.parquet \ + --sample_weighting.head_mode=sparse \ + --sample_weighting.kappa=0.01 \ --output_dir=outputs/train/policy_rabc \ --batch_size=32 \ --steps=40000 @@ -488,12 +489,13 @@ The training script automatically: **RA-BC Arguments:** -| Argument | Description | Default | -| ---------------------- | ---------------------------------------------------------- | ---------------------------------- | -| `--use_rabc` | Enable RA-BC sample weighting | `false` | -| `--rabc_progress_path` | Path to progress parquet file (auto-detected from dataset) | `sarm_progress.parquet` in dataset | -| `--rabc_head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` | -| `--rabc_kappa` | Threshold κ for high-quality samples | `0.01` | +| Argument | Description | Default | +| ----------------------------------- | ------------------------------------------------------ | --------- | +| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` | +| `--sample_weighting.progress_path` | Path to progress parquet file (required for RABC) | (required)| +| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` | +| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` | +| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` | ### Tuning RA-BC Kappa @@ -511,30 +513,30 @@ The `kappa` parameter is the threshold that determines which samples get full we Monitor these WandB metrics during training: -| Metric | Healthy Range | Problem Indicator | -| ------------------ | ------------- | ------------------------- | -| `rabc_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low | -| `rabc_delta_mean` | > 0 | Should be positive | -| `rabc_delta_std` | > 0 | Variance in data quality | +| Metric | Healthy Range | Problem Indicator | +| ------------------------------- | ------------- | ------------------------- | +| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low | +| `sample_weighting/delta_mean` | > 0 | Should be positive | +| `sample_weighting/delta_std` | > 0 | Variance in data quality | -**If `rabc_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC. +**If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC. **Setting kappa based on your data:** -The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `rabc_delta_mean` and `rabc_delta_std`: +The default `kappa=0.01` was tuned for the paper's T-shirt folding task (~90s episodes at 30fps). For your dataset, check the logged `sample_weighting/delta_mean` and `sample_weighting/delta_std`: ``` # If delta_mean ≈ 0.03 and delta_std ≈ 0.02: # Most deltas fall in range [0.01, 0.05] # Option 1: Set kappa = delta_mean (medium selectivity) ---rabc_kappa=0.03 +--sample_weighting.kappa=0.03 # Option 2: Set kappa = delta_mean + delta_std (high selectivity) ---rabc_kappa=0.05 +--sample_weighting.kappa=0.05 # Option 3: Set kappa = delta_mean + 2*delta_std (very selective) ---rabc_kappa=0.07 +--sample_weighting.kappa=0.07 ``` **When RA-BC may not help:** @@ -550,8 +552,9 @@ accelerate launch \ src/lerobot/scripts/lerobot_train.py \ --dataset.repo_id=your-username/your-dataset \ --policy.type=pi0 \ - --use_rabc=true \ - --rabc_kappa=0.01 \ + --sample_weighting.type=rabc \ + --sample_weighting.progress_path=path/to/sarm_progress.parquet \ + --sample_weighting.kappa=0.01 \ --output_dir=outputs/train/policy_rabc \ --batch_size=32 \ --steps=40000 @@ -576,7 +579,7 @@ accelerate launch \ ### RA-BC 1. **Train SARM first**: RA-BC quality depends entirely on SARM quality -2. **Monitor `rabc_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa)) +2. **Monitor `sample_weight_mean_weight`**: If it's ≈ 1.0, increase kappa (see [Tuning RA-BC Kappa](#tuning-ra-bc-kappa)) ---