Reward models refactor (#3142)

* feat(rewards): add RewardModelConfig and PreTrainedRewardModel base classes

* refactor(rewards): migrate Classifier from policies/sac/reward_model/ to rewards/classifier/

* refactor(rewards): migrate SARM from policies/sarm/ to rewards/sarm/

* refactor(rewards): add rewards/factory.py and remove reward model code from policies/factory.py

* refactor(rewards): update imports and delete old reward model locations

* test(rewards): add reward model tests and update existing test imports

* fix(rewards): restore full Classifier and SARM implementations

* test(rewards): restore missing CUDA and mixed precision classifier processor tests

* refactor(lerobot_train.py): remove rabc specific configuration and replace it with a generic samplerweight class in lerobot_train

* refactor(lerobot_train.py): add missing sampling weight script

* linter + missing files

* add testing for sampl weighter

* revert some useless changes, improve typing

* update docs

* add automatic detection of the progress path

* remove type exp

* improve comment

* fix: move rabc.py to rewards/sarm/ and update import paths

* refactor(imports): update reward model imports to new module structure

* refactor(imports): update reward model imports to reflect new module structure

* refactor(imports): conditionally import pandas based on availability

* feat(configs): add reward_model field to TrainPipelineConfig and Hub fields to RewardModelConfig

* refactor(policies): remove reward model branches from policy factory and __init__

* refactor(rewards): expand __init__ facade and fix SARMConfig __post_init__ crash

* feat(train): route reward model training through rewards/factory instead of policies/factory

* refactor(train): streamline reward model training logic

* fix(rewards): ensure FileNotFoundError is raised for missing config_file

* refactor(train): update __get_path_fields__ to include reward_model for config loading

* refactor(classifier): remove redundant input normalization in predict_reward method

* fix(train): raise ValueError for non-trainable reward models in train function

* refactor(pretrained_rm): add model card template

* refactor(tests): reward models

* refactor(sarm): update reset method and remove unused action prediction methods

* refactor(wandb): differentiate tags for reward model and policy training in cfg_to_group function

* fix(train): raise ValueError for PEFT usage in reward model training

* refactor(rewards): enhance RewardModelConfig with device handling and delta indices properties

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
Khalil Meftah
2026-04-28 17:56:24 +02:00
committed by GitHub
parent 03ee50e08f
commit 8a3d64033f
37 changed files with 2091 additions and 381 deletions

View File

@@ -47,6 +47,7 @@ from lerobot.datasets import EpisodeAwareSampler, make_dataset
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.rewards import make_reward_pre_post_processors
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
@@ -70,8 +71,8 @@ def update_policy(
accelerator: "Accelerator",
lr_scheduler=None,
lock=None,
rabc_weights_provider=None,
) -> tuple[MetricsTracker, dict]:
sample_weighter=None,
) -> tuple[MetricsTracker, dict | None]:
"""
Performs a single training step to update the policy's weights.
@@ -87,7 +88,7 @@ def update_policy(
accelerator: The Accelerator instance for distributed training and mixed precision.
lr_scheduler: An optional learning rate scheduler.
lock: An optional lock for thread-safe optimizer updates.
rabc_weights_provider: Optional RABCWeights instance for sample weighting.
sample_weighter: Optional SampleWeighter instance for per-sample loss weighting.
Returns:
A tuple containing:
@@ -97,27 +98,31 @@ def update_policy(
start_time = time.perf_counter()
policy.train()
# Get RA-BC weights if enabled
rabc_batch_weights = None
rabc_batch_stats = None
if rabc_weights_provider is not None:
rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
# Compute sample weights if a weighter is provided
sample_weights = None
weight_stats = None
if sample_weighter is not None:
sample_weights, weight_stats = sample_weighter.compute_batch_weights(batch)
# Let accelerator handle mixed precision
with accelerator.autocast():
# Use per-sample loss when RA-BC is enabled for proper weighting
if rabc_batch_weights is not None:
# Get per-sample losses
if sample_weights is not None:
# Use per-sample loss for weighted training
# Note: Policies supporting sample weighting must implement forward(batch, reduction="none")
per_sample_loss, output_dict = policy.forward(batch, reduction="none")
# Apply RA-BC weights: L_RA-BC = Σ(w_i * l_i) / (Σw_i + ε)
# rabc_batch_weights is already normalized to sum to batch_size
# Weighted loss: each sample's contribution is scaled by its weight.
# We divide by weight sum (not batch size) so that if some weights are zero,
# the remaining samples contribute proportionally more, preserving gradient scale.
# Weights are pre-normalized to sum to batch_size for stable training dynamics.
epsilon = 1e-6
loss = (per_sample_loss * rabc_batch_weights).sum() / (rabc_batch_weights.sum() + epsilon)
# Log raw mean weight (before normalization) - this is the meaningful metric
output_dict["rabc_mean_weight"] = rabc_batch_stats["raw_mean_weight"]
output_dict["rabc_num_zero_weight"] = rabc_batch_stats["num_zero_weight"]
output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"]
loss = (per_sample_loss * sample_weights).sum() / (sample_weights.sum() + epsilon)
# Log weighting statistics
if output_dict is None:
output_dict = {}
for key, value in weight_stats.items():
output_dict[f"sample_weight_{key}"] = value
else:
loss, output_dict = policy.forward(batch)
@@ -188,8 +193,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
# Force the device to be CPU when policy.device is set to CPU.
force_cpu = cfg.policy.device == "cpu"
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
force_cpu = cfg.trainable_config.device == "cpu"
accelerator = Accelerator(
step_scheduler_with_optimizer=False,
kwargs_handlers=[ddp_kwargs],
@@ -245,26 +250,44 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info("Creating env")
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
if is_main_process:
logging.info("Creating policy")
policy = make_policy(
cfg=cfg.policy,
ds_meta=dataset.meta,
rename_map=cfg.rename_map,
)
if cfg.is_reward_model_training:
if is_main_process:
logging.info("Creating reward model")
from lerobot.rewards import make_reward_model
policy = make_reward_model(
cfg=cfg.reward_model,
dataset_stats=dataset.meta.stats,
dataset_meta=dataset.meta,
)
if not policy.is_trainable:
raise ValueError(
f"Reward model '{policy.name}' is zero-shot and cannot be trained via lerobot-train. "
"Use it directly for inference via compute_reward() (e.g. offline precompute)."
)
else:
if is_main_process:
logging.info("Creating policy")
policy = make_policy(
cfg=cfg.policy,
ds_meta=dataset.meta,
rename_map=cfg.rename_map,
)
if cfg.peft is not None:
if cfg.is_reward_model_training:
raise ValueError("PEFT is only supported for policy training. ")
logging.info("Using PEFT! Wrapping model.")
# Convert CLI peft config to dict for overrides
peft_cli_overrides = dataclasses.asdict(cfg.peft)
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
# Wait for all processes to finish policy creation before continuing
# Wait for all processes to finish model creation before continuing
accelerator.wait_for_everyone()
processor_pretrained_path = cfg.policy.pretrained_path
active_cfg = cfg.trainable_config
processor_pretrained_path = active_cfg.pretrained_path
if (
getattr(cfg.policy, "use_relative_actions", False)
getattr(active_cfg, "use_relative_actions", False)
and processor_pretrained_path is not None
and not cfg.resume
):
@@ -274,18 +297,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
)
processor_pretrained_path = None
# Create processors - only provide dataset_stats if not resuming from saved processors
processor_kwargs = {}
postprocessor_kwargs = {}
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
# For SARM, always provide dataset_meta for progress normalization
if cfg.policy.type == "sarm":
if cfg.is_reward_model_training:
processor_kwargs["dataset_meta"] = dataset.meta
if processor_pretrained_path is not None:
if not cfg.is_reward_model_training and processor_pretrained_path is not None:
processor_kwargs["preprocessor_overrides"] = {
"device_processor": {"device": device.type},
"normalizer_processor": {
@@ -305,38 +325,36 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
},
}
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=processor_pretrained_path,
**processor_kwargs,
**postprocessor_kwargs,
)
if cfg.is_reward_model_training:
preprocessor, postprocessor = make_reward_pre_post_processors(
cfg.reward_model,
**processor_kwargs,
)
else:
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=processor_pretrained_path,
**processor_kwargs,
**postprocessor_kwargs,
)
if is_main_process:
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
# Load precomputed SARM progress for RA-BC if enabled
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
rabc_weights = None
if cfg.use_rabc:
from lerobot.utils.rabc import RABCWeights
# Create sample weighter if configured (e.g., for RA-BC training)
sample_weighter = None
if cfg.sample_weighting is not None:
from lerobot.utils.sample_weighting import make_sample_weighter
# Get chunk_size from policy config
chunk_size = getattr(policy.config, "chunk_size", None)
if chunk_size is None:
raise ValueError("Chunk size is not found in policy config")
head_mode = getattr(cfg, "rabc_head_mode", "sparse")
logging.info(f"Loading SARM progress for RA-BC from {cfg.rabc_progress_path}")
logging.info(f"Using chunk_size={chunk_size} from policy config, head_mode={head_mode}")
rabc_weights = RABCWeights(
progress_path=cfg.rabc_progress_path,
chunk_size=chunk_size,
head_mode=head_mode,
kappa=getattr(cfg, "rabc_kappa", 0.01),
epsilon=getattr(cfg, "rabc_epsilon", 1e-6),
device=device,
if is_main_process:
logging.info(f"Creating sample weighter: {cfg.sample_weighting.type}")
sample_weighter = make_sample_weighter(
cfg.sample_weighting,
policy,
device,
dataset_root=cfg.dataset.root,
dataset_repo_id=cfg.dataset.repo_id,
)
step = 0 # number of policy updates (forward + backward + optim)
@@ -365,13 +383,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# create dataloader for offline training
if hasattr(cfg.policy, "drop_n_last_frames"):
if hasattr(active_cfg, "drop_n_last_frames"):
shuffle = False
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
drop_n_last_frames=active_cfg.drop_n_last_frames,
shuffle=True,
)
else:
@@ -448,7 +466,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
cfg.optimizer.grad_clip_norm,
accelerator=accelerator,
lr_scheduler=lr_scheduler,
rabc_weights_provider=rabc_weights,
sample_weighter=sample_weighter,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -467,16 +485,10 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
# Log RA-BC statistics if enabled
if rabc_weights is not None:
rabc_stats = rabc_weights.get_stats()
wandb_log_dict.update(
{
"rabc_delta_mean": rabc_stats["delta_mean"],
"rabc_delta_std": rabc_stats["delta_std"],
"rabc_num_frames": rabc_stats["num_frames"],
}
)
# Log sample weighting statistics if enabled
if sample_weighter is not None:
weighter_stats = sample_weighter.get_stats()
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
@@ -558,14 +570,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if is_main_process:
logging.info("End of training")
if cfg.policy.push_to_hub:
unwrapped_policy = accelerator.unwrap_model(policy)
if cfg.policy.use_peft:
unwrapped_policy.push_model_to_hub(cfg, peft_model=unwrapped_policy)
if getattr(active_cfg, "push_to_hub", False):
unwrapped_model = accelerator.unwrap_model(policy)
# PEFT only applies when training a policy — reward models use the plain path.
if not cfg.is_reward_model_training and cfg.policy.use_peft:
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
else:
unwrapped_policy.push_model_to_hub(cfg)
preprocessor.push_to_hub(cfg.policy.repo_id)
postprocessor.push_to_hub(cfg.policy.repo_id)
unwrapped_model.push_model_to_hub(cfg)
preprocessor.push_to_hub(active_cfg.repo_id)
postprocessor.push_to_hub(active_cfg.repo_id)
# Properly clean up the distributed process group
accelerator.wait_for_everyone()