mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
Reward models refactor (#3142)
* feat(rewards): add RewardModelConfig and PreTrainedRewardModel base classes * refactor(rewards): migrate Classifier from policies/sac/reward_model/ to rewards/classifier/ * refactor(rewards): migrate SARM from policies/sarm/ to rewards/sarm/ * refactor(rewards): add rewards/factory.py and remove reward model code from policies/factory.py * refactor(rewards): update imports and delete old reward model locations * test(rewards): add reward model tests and update existing test imports * fix(rewards): restore full Classifier and SARM implementations * test(rewards): restore missing CUDA and mixed precision classifier processor tests * refactor(lerobot_train.py): remove rabc specific configuration and replace it with a generic samplerweight class in lerobot_train * refactor(lerobot_train.py): add missing sampling weight script * linter + missing files * add testing for sampl weighter * revert some useless changes, improve typing * update docs * add automatic detection of the progress path * remove type exp * improve comment * fix: move rabc.py to rewards/sarm/ and update import paths * refactor(imports): update reward model imports to new module structure * refactor(imports): update reward model imports to reflect new module structure * refactor(imports): conditionally import pandas based on availability * feat(configs): add reward_model field to TrainPipelineConfig and Hub fields to RewardModelConfig * refactor(policies): remove reward model branches from policy factory and __init__ * refactor(rewards): expand __init__ facade and fix SARMConfig __post_init__ crash * feat(train): route reward model training through rewards/factory instead of policies/factory * refactor(train): streamline reward model training logic * fix(rewards): ensure FileNotFoundError is raised for missing config_file * refactor(train): update __get_path_fields__ to include reward_model for config loading * refactor(classifier): remove redundant input normalization in predict_reward method * 
fix(train): raise ValueError for non-trainable reward models in train function * refactor(pretrained_rm): add model card template * refactor(tests): reward models * refactor(sarm): update reset method and remove unused action prediction methods * refactor(wandb): differentiate tags for reward model and policy training in cfg_to_group function * fix(train): raise ValueError for PEFT usage in reward model training * refactor(rewards): enhance RewardModelConfig with device handling and delta indices properties --------- Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
@@ -47,6 +47,7 @@ from lerobot.datasets import EpisodeAwareSampler, make_dataset
|
||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.rewards import make_reward_pre_post_processors
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
@@ -70,8 +71,8 @@ def update_policy(
|
||||
accelerator: "Accelerator",
|
||||
lr_scheduler=None,
|
||||
lock=None,
|
||||
rabc_weights_provider=None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
sample_weighter=None,
|
||||
) -> tuple[MetricsTracker, dict | None]:
|
||||
"""
|
||||
Performs a single training step to update the policy's weights.
|
||||
|
||||
@@ -87,7 +88,7 @@ def update_policy(
|
||||
accelerator: The Accelerator instance for distributed training and mixed precision.
|
||||
lr_scheduler: An optional learning rate scheduler.
|
||||
lock: An optional lock for thread-safe optimizer updates.
|
||||
rabc_weights_provider: Optional RABCWeights instance for sample weighting.
|
||||
sample_weighter: Optional SampleWeighter instance for per-sample loss weighting.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
@@ -97,27 +98,31 @@ def update_policy(
|
||||
start_time = time.perf_counter()
|
||||
policy.train()
|
||||
|
||||
# Get RA-BC weights if enabled
|
||||
rabc_batch_weights = None
|
||||
rabc_batch_stats = None
|
||||
if rabc_weights_provider is not None:
|
||||
rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
|
||||
# Compute sample weights if a weighter is provided
|
||||
sample_weights = None
|
||||
weight_stats = None
|
||||
if sample_weighter is not None:
|
||||
sample_weights, weight_stats = sample_weighter.compute_batch_weights(batch)
|
||||
|
||||
# Let accelerator handle mixed precision
|
||||
with accelerator.autocast():
|
||||
# Use per-sample loss when RA-BC is enabled for proper weighting
|
||||
if rabc_batch_weights is not None:
|
||||
# Get per-sample losses
|
||||
if sample_weights is not None:
|
||||
# Use per-sample loss for weighted training
|
||||
# Note: Policies supporting sample weighting must implement forward(batch, reduction="none")
|
||||
per_sample_loss, output_dict = policy.forward(batch, reduction="none")
|
||||
|
||||
# Apply RA-BC weights: L_RA-BC = Σ(w_i * l_i) / (Σw_i + ε)
|
||||
# rabc_batch_weights is already normalized to sum to batch_size
|
||||
# Weighted loss: each sample's contribution is scaled by its weight.
|
||||
# We divide by weight sum (not batch size) so that if some weights are zero,
|
||||
# the remaining samples contribute proportionally more, preserving gradient scale.
|
||||
# Weights are pre-normalized to sum to batch_size for stable training dynamics.
|
||||
epsilon = 1e-6
|
||||
loss = (per_sample_loss * rabc_batch_weights).sum() / (rabc_batch_weights.sum() + epsilon)
|
||||
# Log raw mean weight (before normalization) - this is the meaningful metric
|
||||
output_dict["rabc_mean_weight"] = rabc_batch_stats["raw_mean_weight"]
|
||||
output_dict["rabc_num_zero_weight"] = rabc_batch_stats["num_zero_weight"]
|
||||
output_dict["rabc_num_full_weight"] = rabc_batch_stats["num_full_weight"]
|
||||
loss = (per_sample_loss * sample_weights).sum() / (sample_weights.sum() + epsilon)
|
||||
|
||||
# Log weighting statistics
|
||||
if output_dict is None:
|
||||
output_dict = {}
|
||||
for key, value in weight_stats.items():
|
||||
output_dict[f"sample_weight_{key}"] = value
|
||||
else:
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
@@ -188,8 +193,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||
# Force the device to be CPU when policy.device is set to CPU.
|
||||
force_cpu = cfg.policy.device == "cpu"
|
||||
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
|
||||
force_cpu = cfg.trainable_config.device == "cpu"
|
||||
accelerator = Accelerator(
|
||||
step_scheduler_with_optimizer=False,
|
||||
kwargs_handlers=[ddp_kwargs],
|
||||
@@ -245,26 +250,44 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
logging.info("Creating env")
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
|
||||
if is_main_process:
|
||||
logging.info("Creating policy")
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
rename_map=cfg.rename_map,
|
||||
)
|
||||
if cfg.is_reward_model_training:
|
||||
if is_main_process:
|
||||
logging.info("Creating reward model")
|
||||
from lerobot.rewards import make_reward_model
|
||||
|
||||
policy = make_reward_model(
|
||||
cfg=cfg.reward_model,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
dataset_meta=dataset.meta,
|
||||
)
|
||||
if not policy.is_trainable:
|
||||
raise ValueError(
|
||||
f"Reward model '{policy.name}' is zero-shot and cannot be trained via lerobot-train. "
|
||||
"Use it directly for inference via compute_reward() (e.g. offline precompute)."
|
||||
)
|
||||
else:
|
||||
if is_main_process:
|
||||
logging.info("Creating policy")
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
rename_map=cfg.rename_map,
|
||||
)
|
||||
|
||||
if cfg.peft is not None:
|
||||
if cfg.is_reward_model_training:
|
||||
raise ValueError("PEFT is only supported for policy training. ")
|
||||
logging.info("Using PEFT! Wrapping model.")
|
||||
# Convert CLI peft config to dict for overrides
|
||||
peft_cli_overrides = dataclasses.asdict(cfg.peft)
|
||||
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
|
||||
|
||||
# Wait for all processes to finish policy creation before continuing
|
||||
# Wait for all processes to finish model creation before continuing
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
processor_pretrained_path = cfg.policy.pretrained_path
|
||||
active_cfg = cfg.trainable_config
|
||||
processor_pretrained_path = active_cfg.pretrained_path
|
||||
if (
|
||||
getattr(cfg.policy, "use_relative_actions", False)
|
||||
getattr(active_cfg, "use_relative_actions", False)
|
||||
and processor_pretrained_path is not None
|
||||
and not cfg.resume
|
||||
):
|
||||
@@ -274,18 +297,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
)
|
||||
processor_pretrained_path = None
|
||||
|
||||
# Create processors - only provide dataset_stats if not resuming from saved processors
|
||||
processor_kwargs = {}
|
||||
postprocessor_kwargs = {}
|
||||
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
|
||||
# Only provide dataset_stats when not resuming from saved processor state
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
# For SARM, always provide dataset_meta for progress normalization
|
||||
if cfg.policy.type == "sarm":
|
||||
if cfg.is_reward_model_training:
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
if processor_pretrained_path is not None:
|
||||
if not cfg.is_reward_model_training and processor_pretrained_path is not None:
|
||||
processor_kwargs["preprocessor_overrides"] = {
|
||||
"device_processor": {"device": device.type},
|
||||
"normalizer_processor": {
|
||||
@@ -305,38 +325,36 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
},
|
||||
}
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
if cfg.is_reward_model_training:
|
||||
preprocessor, postprocessor = make_reward_pre_post_processors(
|
||||
cfg.reward_model,
|
||||
**processor_kwargs,
|
||||
)
|
||||
else:
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
|
||||
# Load precomputed SARM progress for RA-BC if enabled
|
||||
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
|
||||
rabc_weights = None
|
||||
if cfg.use_rabc:
|
||||
from lerobot.utils.rabc import RABCWeights
|
||||
# Create sample weighter if configured (e.g., for RA-BC training)
|
||||
sample_weighter = None
|
||||
if cfg.sample_weighting is not None:
|
||||
from lerobot.utils.sample_weighting import make_sample_weighter
|
||||
|
||||
# Get chunk_size from policy config
|
||||
chunk_size = getattr(policy.config, "chunk_size", None)
|
||||
if chunk_size is None:
|
||||
raise ValueError("Chunk size is not found in policy config")
|
||||
|
||||
head_mode = getattr(cfg, "rabc_head_mode", "sparse")
|
||||
logging.info(f"Loading SARM progress for RA-BC from {cfg.rabc_progress_path}")
|
||||
logging.info(f"Using chunk_size={chunk_size} from policy config, head_mode={head_mode}")
|
||||
rabc_weights = RABCWeights(
|
||||
progress_path=cfg.rabc_progress_path,
|
||||
chunk_size=chunk_size,
|
||||
head_mode=head_mode,
|
||||
kappa=getattr(cfg, "rabc_kappa", 0.01),
|
||||
epsilon=getattr(cfg, "rabc_epsilon", 1e-6),
|
||||
device=device,
|
||||
if is_main_process:
|
||||
logging.info(f"Creating sample weighter: {cfg.sample_weighting.type}")
|
||||
sample_weighter = make_sample_weighter(
|
||||
cfg.sample_weighting,
|
||||
policy,
|
||||
device,
|
||||
dataset_root=cfg.dataset.root,
|
||||
dataset_repo_id=cfg.dataset.repo_id,
|
||||
)
|
||||
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
@@ -365,13 +383,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
# create dataloader for offline training
|
||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
@@ -448,7 +466,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
accelerator=accelerator,
|
||||
lr_scheduler=lr_scheduler,
|
||||
rabc_weights_provider=rabc_weights,
|
||||
sample_weighter=sample_weighter,
|
||||
)
|
||||
|
||||
# Note: eval and checkpoint happen *after* the `step`th training update has completed, so we
|
||||
@@ -467,16 +485,10 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
# Log RA-BC statistics if enabled
|
||||
if rabc_weights is not None:
|
||||
rabc_stats = rabc_weights.get_stats()
|
||||
wandb_log_dict.update(
|
||||
{
|
||||
"rabc_delta_mean": rabc_stats["delta_mean"],
|
||||
"rabc_delta_std": rabc_stats["delta_std"],
|
||||
"rabc_num_frames": rabc_stats["num_frames"],
|
||||
}
|
||||
)
|
||||
# Log sample weighting statistics if enabled
|
||||
if sample_weighter is not None:
|
||||
weighter_stats = sample_weighter.get_stats()
|
||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
@@ -558,14 +570,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if is_main_process:
|
||||
logging.info("End of training")
|
||||
|
||||
if cfg.policy.push_to_hub:
|
||||
unwrapped_policy = accelerator.unwrap_model(policy)
|
||||
if cfg.policy.use_peft:
|
||||
unwrapped_policy.push_model_to_hub(cfg, peft_model=unwrapped_policy)
|
||||
if getattr(active_cfg, "push_to_hub", False):
|
||||
unwrapped_model = accelerator.unwrap_model(policy)
|
||||
# PEFT only applies when training a policy — reward models use the plain path.
|
||||
if not cfg.is_reward_model_training and cfg.policy.use_peft:
|
||||
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
|
||||
else:
|
||||
unwrapped_policy.push_model_to_hub(cfg)
|
||||
preprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
postprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
unwrapped_model.push_model_to_hub(cfg)
|
||||
preprocessor.push_to_hub(active_cfg.repo_id)
|
||||
postprocessor.push_to_hub(active_cfg.repo_id)
|
||||
|
||||
# Properly clean up the distributed process group
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
Reference in New Issue
Block a user