From 7fd8b4c773db9a3f1366e1d5cd702ee344fd08e4 Mon Sep 17 00:00:00 2001 From: nemo Date: Sun, 22 Jun 2025 19:10:10 +0200 Subject: [PATCH] Implement loading of PEFT adapters Loading a PEFT adapter is currently done by initializing a policy with default config and then applying the adapter on the resulting model. This has the obvious drawback that any configurations done during training are not applied in the adapted model. Currently the `use_peft` attribute of `PreTrainedConfig` is only set during loading to signal the following code that it has to deal with a PEFT adapter. However we could imagine a scenario where this is already set at training time and stored alongside the adapter. --- lerobot/configs/policies.py | 3 +++ lerobot/record.py | 33 +++++++++++++++++++-------------- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/lerobot/configs/policies.py b/lerobot/configs/policies.py index a85ca187f..c4e344e7e 100644 --- a/lerobot/configs/policies.py +++ b/lerobot/configs/policies.py @@ -60,6 +60,9 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # automatic gradient scaling is used. use_amp: bool = False + # Whether the policy employed PEFT for training. + use_peft: bool = False + def __post_init__(self): self.pretrained_path = None if not self.device or not is_torch_device_available(self.device): diff --git a/lerobot/record.py b/lerobot/record.py index 8ae34d76c..4f3c32c3a 100644 --- a/lerobot/record.py +++ b/lerobot/record.py @@ -130,6 +130,21 @@ class DatasetRecordConfig: raise ValueError("You need to provide a task as argument in `single_task`.") +def get_policy_config_from_peft_checkpoint(peft_config): + if getattr(peft_config, "auto_mapping", None) is None: + raise ValueError( + "No auto-mapping config found in adapter config. Cannot determine policy config." + ) + + auto_mapping = getattr(peft_config, "auto_mapping", None) + base_model_class = auto_mapping["base_model_class"] + parent_library_name = auto_mapping["parent_library"] + + parent_library = importlib.import_module(parent_library_name) + target_class = getattr(parent_library, base_model_class) + return target_class.config_class + + @dataclass class RecordConfig: robot: RobotConfig @@ -152,28 +167,18 @@ class RecordConfig: if policy_path: cli_overrides = parser.get_cli_overrides("policy") - if (policy_path / 'adapter_config.json').exists(): + if (Path(policy_path) / 'adapter_config.json').exists(): # The pretrained checkpoint is a PEFT adapter, cool. Currently we don't upload the # policy's config alongside the adapter config but to initialize the policy we # need a policy config. We assume that the config hasn't changed and we infer # the policy's config class from the base class mentioned in the adapter config. self.peft_config = PeftConfig.from_pretrained(policy_path) - if getattr(self.peft_config, "auto_mapping", None) is None: - raise ValueError( - "No auto-mapping config found in adapter config. Cannot determine policy config." - ) - - auto_mapping = getattr(self.peft_config, "auto_mapping", None) - base_model_class = auto_mapping["base_model_class"] - parent_library_name = auto_mapping["parent_library"] - - parent_library = importlib.import_module(parent_library_name) - target_class = getattr(parent_library, base_model_class) - policy_config_class = target_class.config_class + policy_config_class = get_policy_config_from_peft_checkpoint(self.peft_config) self.policy = policy_config_class() self.policy.pretrained_path = policy_path + self.policy.use_peft = True else: self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) @@ -308,7 +313,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset: # Load pretrained policy - if cfg.use_peft: + if cfg.policy.use_peft: # in case of PEFT we re-use the policy pretrained path to point to the adapter path. peft_path = cfg.policy.pretrained_path cfg.policy.pretrained_path = None