mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 12:21:27 +00:00
Merge branch 'main' into docs/add-lelab
This commit is contained in:
@@ -177,6 +177,12 @@ class TrainPipelineConfig(HubMixin):
|
||||
)
|
||||
|
||||
active_cfg = self.trainable_config
|
||||
if self.rename_map and active_cfg.pretrained_path is None:
|
||||
raise ValueError(
|
||||
"`rename_map` requires a pretrained policy checkpoint. "
|
||||
"Fresh initialization derives feature names from the current dataset, so no rename is applied."
|
||||
)
|
||||
|
||||
if not self.job_name:
|
||||
if self.env is None:
|
||||
self.job_name = f"{active_cfg.type}"
|
||||
|
||||
@@ -81,7 +81,7 @@ def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) ->
|
||||
return actions
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("delta_actions_processor")
|
||||
@ProcessorStepRegistry.register("relative_actions_processor")
|
||||
@dataclass
|
||||
class RelativeActionsProcessorStep(ProcessorStep):
|
||||
"""Converts absolute actions to relative actions (action -= state) for masked dimensions.
|
||||
|
||||
@@ -292,19 +292,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
active_cfg = cfg.trainable_config
|
||||
processor_pretrained_path = active_cfg.pretrained_path
|
||||
if (
|
||||
getattr(active_cfg, "use_relative_actions", False)
|
||||
and processor_pretrained_path is not None
|
||||
and not cfg.resume
|
||||
):
|
||||
logging.warning(
|
||||
"use_relative_actions=true with pretrained processors can skip relative transforms if "
|
||||
"the checkpoint processors do not define them. Building processors from current policy config."
|
||||
)
|
||||
processor_pretrained_path = None
|
||||
|
||||
processor_kwargs = {}
|
||||
postprocessor_kwargs = {}
|
||||
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
@@ -312,24 +301,31 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
if not cfg.is_reward_model_training and processor_pretrained_path is not None:
|
||||
processor_kwargs["preprocessor_overrides"] = {
|
||||
preprocessor_overrides = {
|
||||
"device_processor": {"device": device.type},
|
||||
"normalizer_processor": {
|
||||
"stats": dataset.meta.stats,
|
||||
"features": {**policy.config.input_features, **policy.config.output_features},
|
||||
"norm_map": policy.config.normalization_mapping,
|
||||
},
|
||||
"rename_observations_processor": {"rename_map": cfg.rename_map},
|
||||
}
|
||||
processor_kwargs["preprocessor_overrides"]["rename_observations_processor"] = {
|
||||
"rename_map": cfg.rename_map
|
||||
}
|
||||
postprocessor_kwargs["postprocessor_overrides"] = {
|
||||
postprocessor_overrides = {
|
||||
"unnormalizer_processor": {
|
||||
"stats": dataset.meta.stats,
|
||||
"features": policy.config.output_features,
|
||||
"norm_map": policy.config.normalization_mapping,
|
||||
},
|
||||
}
|
||||
if getattr(active_cfg, "use_relative_actions", False):
|
||||
preprocessor_overrides["relative_actions_processor"] = {
|
||||
"enabled": True,
|
||||
"exclude_joints": getattr(active_cfg, "relative_exclude_joints", []),
|
||||
"action_names": getattr(active_cfg, "action_feature_names", None),
|
||||
}
|
||||
postprocessor_overrides["absolute_actions_processor"] = {"enabled": True}
|
||||
processor_kwargs["preprocessor_overrides"] = preprocessor_overrides
|
||||
processor_kwargs["postprocessor_overrides"] = postprocessor_overrides
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
preprocessor, postprocessor = make_reward_pre_post_processors(
|
||||
@@ -341,7 +337,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
**processor_kwargs,
|
||||
**postprocessor_kwargs,
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
|
||||
Reference in New Issue
Block a user