add automatic detection of the progress path

This commit is contained in:
Michel Aractingi
2026-01-14 17:08:23 +01:00
parent faa276b8cf
commit 94efcea867
4 changed files with 108 additions and 49 deletions

View File

@@ -465,14 +465,13 @@ This script:
### Step 5b: Train Policy with RA-BC
Once you have the progress file, train your policy with RA-BC weighting. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`) if not explicitly provided. Currently PI0, PI0.5 and SmolVLA are supported with RA-BC:
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your-username/your-dataset \
--policy.type=pi0 \
--sample_weighting.type=rabc \
--sample_weighting.progress_path=path/to/sarm_progress.parquet \
--sample_weighting.head_mode=sparse \
--sample_weighting.kappa=0.01 \
--output_dir=outputs/train/policy_rabc \
@@ -489,13 +488,13 @@ The training script automatically:
**RA-BC Arguments:**
| Argument | Description | Default |
| ----------------------------------- | ------------------------------------------------------ | --------- |
| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
| `--sample_weighting.progress_path` | Path to progress parquet file (required for RABC) | (required)|
| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
| Argument | Description | Default |
| ---------------------------------- | ------------------------------------------------------ | ----------------------- |
| `--sample_weighting.type` | Weighting strategy type (`rabc` or `uniform`) | `rabc` |
| `--sample_weighting.progress_path` | Path to progress parquet file | `sarm_progress.parquet` |
| `--sample_weighting.head_mode` | Which SARM head's progress to use: `sparse` or `dense` | `sparse` |
| `--sample_weighting.kappa` | Threshold κ for high-quality samples | `0.01` |
| `--sample_weighting.epsilon` | Small constant for numerical stability | `1e-6` |
### Tuning RA-BC Kappa
@@ -513,11 +512,11 @@ The `kappa` parameter is the threshold that determines which samples get full we
Monitor these WandB metrics during training:
| Metric | Healthy Range | Problem Indicator |
| ------------------------------- | ------------- | ------------------------- |
| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
| `sample_weighting/delta_mean` | > 0 | Should be positive |
| `sample_weighting/delta_std` | > 0 | Variance in data quality |
| Metric | Healthy Range | Problem Indicator |
| ----------------------------- | ------------- | ------------------------- |
| `sample_weight_mean_weight` | 0.3 - 0.8 | ≈ 1.0 means kappa too low |
| `sample_weighting/delta_mean` | > 0 | Should be positive |
| `sample_weighting/delta_std` | > 0 | Variance in data quality |
**If `sample_weight_mean_weight ≈ 1.0`:** Your kappa is too low. Most samples have `delta > kappa` and bypass the soft-weighting entirely. RA-BC becomes equivalent to vanilla BC.
@@ -553,7 +552,6 @@ accelerate launch \
--dataset.repo_id=your-username/your-dataset \
--policy.type=pi0 \
--sample_weighting.type=rabc \
--sample_weighting.progress_path=path/to/sarm_progress.parquet \
--sample_weighting.kappa=0.01 \
--output_dir=outputs/train/policy_rabc \
--batch_size=32 \