From cd6b43ea7ac3ada097a162811acd065df417c7a8 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Wed, 29 Apr 2026 16:17:00 +0200 Subject: [PATCH] fix(train): migrate legacy RA-BC fields in train config loading (#3480) --- src/lerobot/configs/train.py | 47 ++++++++++++++++ tests/rewards/test_reward_model_base.py | 74 +++++++++++++++++++++++++ 2 files changed, 121 insertions(+) diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 3f78cc07b..b2b3cd7a0 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -13,7 +13,9 @@ # limitations under the License. import builtins import datetime as dt +import json import os +import tempfile from dataclasses import dataclass, field from pathlib import Path from typing import Any @@ -35,6 +37,42 @@ from .rewards import RewardModelConfig TRAIN_CONFIG_NAME = "train_config.json" +def _migrate_legacy_rabc_fields(config: dict[str, Any]) -> dict[str, Any] | None: + """Return migrated payload for legacy RA-BC fields, or None when no migration is needed.""" + legacy_fields = ( + "use_rabc", + "rabc_progress_path", + "rabc_kappa", + "rabc_epsilon", + "rabc_head_mode", + ) + if not any(key in config for key in legacy_fields): + return None + + migrated_config = dict(config) + use_rabc = bool(migrated_config.pop("use_rabc", False)) + rabc_progress_path = migrated_config.pop("rabc_progress_path", None) + rabc_kappa = migrated_config.pop("rabc_kappa", None) + rabc_epsilon = migrated_config.pop("rabc_epsilon", None) + rabc_head_mode = migrated_config.pop("rabc_head_mode", None) + + # New configs may already define sample_weighting explicitly. In that case, + # legacy fields are ignored after being stripped from the payload. + if migrated_config.get("sample_weighting") is None and use_rabc: + sample_weighting: dict[str, Any] = {"type": "rabc"} + if rabc_progress_path is not None: + sample_weighting["progress_path"] = rabc_progress_path + if rabc_kappa is not None: + sample_weighting["kappa"] = rabc_kappa + if rabc_epsilon is not None: + sample_weighting["epsilon"] = rabc_epsilon + if rabc_head_mode is not None: + sample_weighting["head_mode"] = rabc_head_mode + migrated_config["sample_weighting"] = sample_weighting + + return migrated_config + + @dataclass class TrainPipelineConfig(HubMixin): dataset: DatasetConfig @@ -218,6 +256,15 @@ class TrainPipelineConfig(HubMixin): ) from e cli_args = kwargs.pop("cli_args", []) + if config_file is not None: + with open(config_file) as f: + config = json.load(f) + migrated_config = _migrate_legacy_rabc_fields(config) + if migrated_config is not None: + with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f: + json.dump(migrated_config, f) + config_file = f.name + with draccus.config_type("json"): return draccus.parse(cls, config_file, args=cli_args) diff --git a/tests/rewards/test_reward_model_base.py b/tests/rewards/test_reward_model_base.py index c8755a0fa..1c4dad642 100644 --- a/tests/rewards/test_reward_model_base.py +++ b/tests/rewards/test_reward_model_base.py @@ -14,6 +14,7 @@ """Tests for the reward model base classes and registry.""" +import json from dataclasses import dataclass from pathlib import Path from types import SimpleNamespace @@ -251,6 +252,79 @@ def test_train_pipeline_config_trainable_config_returns_policy_when_set(): assert cfg.trainable_config.device == "cpu" +def test_train_pipeline_config_from_pretrained_migrates_legacy_rabc_fields(tmp_path): + """Legacy top-level RA-BC fields should be migrated into ``sample_weighting``.""" + from lerobot.configs.default import DatasetConfig + from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig + from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig + + cfg = TrainPipelineConfig( + dataset=DatasetConfig(repo_id="user/repo"), + policy=DiffusionConfig(device="cpu"), + ) + cfg._save_pretrained(tmp_path) + + config_path = tmp_path / TRAIN_CONFIG_NAME + with open(config_path) as f: + payload = json.load(f) + + payload.pop("sample_weighting", None) + payload.update( + { + "use_rabc": True, + "rabc_progress_path": "hf://datasets/user/repo/sarm_progress.parquet", + "rabc_kappa": 0.05, + "rabc_epsilon": 1e-5, + "rabc_head_mode": "dense", + } + ) + with open(config_path, "w") as f: + json.dump(payload, f) + + loaded = TrainPipelineConfig.from_pretrained(tmp_path) + + assert loaded.sample_weighting is not None + assert loaded.sample_weighting.type == "rabc" + assert loaded.sample_weighting.progress_path == "hf://datasets/user/repo/sarm_progress.parquet" + assert loaded.sample_weighting.kappa == 0.05 + assert loaded.sample_weighting.epsilon == 1e-5 + assert loaded.sample_weighting.head_mode == "dense" + + +def test_train_pipeline_config_from_pretrained_strips_legacy_rabc_when_disabled(tmp_path): + """Legacy RA-BC fields should be ignored when ``use_rabc`` was false.""" + from lerobot.configs.default import DatasetConfig + from lerobot.configs.train import TRAIN_CONFIG_NAME, TrainPipelineConfig + from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig + + cfg = TrainPipelineConfig( + dataset=DatasetConfig(repo_id="user/repo"), + policy=DiffusionConfig(device="cpu"), + ) + cfg._save_pretrained(tmp_path) + + config_path = tmp_path / TRAIN_CONFIG_NAME + with open(config_path) as f: + payload = json.load(f) + + payload.pop("sample_weighting", None) + payload.update( + { + "use_rabc": False, + "rabc_progress_path": "hf://datasets/user/repo/sarm_progress.parquet", + "rabc_kappa": 0.05, + "rabc_epsilon": 1e-5, + "rabc_head_mode": "dense", + } + ) + with open(config_path, "w") as f: + json.dump(payload, f) + + loaded = TrainPipelineConfig.from_pretrained(tmp_path) + + assert loaded.sample_weighting is None + + # --------------------------------------------------------------------------- # PreTrainedRewardModel hub upload: push_model_to_hub + generate_model_card. # We test the generation side (offline) fully, and the upload side with HfApi