Implement loading of PEFT adapters

Loading a PEFT adapter is currently done by initializing a policy with the default config
and then applying the adapter to the resulting model. This has the obvious drawback
that any configuration changes made during training are not applied to the adapted model.

Currently, the `use_peft` attribute of `PreTrainedConfig` is only set during loading
to signal to the subsequent code that it has to deal with a PEFT adapter. However,
we could imagine a scenario where this is already set at training time and stored
alongside the adapter.
This commit is contained in:
nemo
2025-06-22 19:10:10 +02:00
parent 98856662c1
commit 7fd8b4c773
2 changed files with 22 additions and 14 deletions

View File

@@ -60,6 +60,9 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
# automatic gradient scaling is used.
use_amp: bool = False
# Whether the policy employed PEFT for training.
use_peft: bool = False
def __post_init__(self):
self.pretrained_path = None
if not self.device or not is_torch_device_available(self.device):

View File

@@ -130,6 +130,21 @@ class DatasetRecordConfig:
raise ValueError("You need to provide a task as argument in `single_task`.")
def get_policy_config_from_peft_checkpoint(peft_config):
    """Resolve the policy config class referenced by a PEFT adapter config.

    The adapter config's ``auto_mapping`` entry names the base model class and
    the library that provides it. That library is imported, the class is looked
    up on it, and the class's ``config_class`` attribute is returned.

    Args:
        peft_config: Adapter config object. Its ``auto_mapping`` attribute must
            be a mapping with the keys ``"base_model_class"`` and
            ``"parent_library"``.

    Returns:
        The ``config_class`` attribute of the resolved base model class.

    Raises:
        ValueError: If ``auto_mapping`` is absent or ``None``, or if it lacks
            one of the required keys.
        ImportError: If the parent library cannot be imported.
        AttributeError: If the parent library has no attribute named after the
            base model class, or that class has no ``config_class``.
    """
    # Look the mapping up once; a missing attribute and an explicit None are
    # treated the same way.
    auto_mapping = getattr(peft_config, "auto_mapping", None)
    if auto_mapping is None:
        raise ValueError(
            "No auto-mapping config found in adapter config. Cannot determine policy config."
        )

    try:
        base_model_class = auto_mapping["base_model_class"]
        parent_library_name = auto_mapping["parent_library"]
    except KeyError as err:
        # An incomplete mapping is the same failure mode as a missing one:
        # report it as a ValueError instead of leaking a bare KeyError.
        raise ValueError(
            f"Auto-mapping config in adapter config is missing the {err.args[0]!r} entry. "
            "Cannot determine policy config."
        ) from err

    parent_library = importlib.import_module(parent_library_name)
    target_class = getattr(parent_library, base_model_class)
    return target_class.config_class
@dataclass
class RecordConfig:
robot: RobotConfig
@@ -152,28 +167,18 @@ class RecordConfig:
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
if (policy_path / 'adapter_config.json').exists():
if (Path(policy_path) / 'adapter_config.json').exists():
# The pretrained checkpoint is a PEFT adapter, cool. Currently we don't upload the
# policy's config alongside the adapter config but to initialize the policy we
# need a policy config. We assume that the config hasn't changed and we infer
# the policy's config class from the base class mentioned in the adapter config.
self.peft_config = PeftConfig.from_pretrained(policy_path)
if getattr(self.peft_config, "auto_mapping", None) is None:
raise ValueError(
"No auto-mapping config found in adapter config. Cannot determine policy config."
)
auto_mapping = getattr(self.peft_config, "auto_mapping", None)
base_model_class = auto_mapping["base_model_class"]
parent_library_name = auto_mapping["parent_library"]
parent_library = importlib.import_module(parent_library_name)
target_class = getattr(parent_library, base_model_class)
policy_config_class = target_class.config_class
policy_config_class = get_policy_config_from_peft_checkpoint(self.peft_config)
self.policy = policy_config_class()
self.policy.pretrained_path = policy_path
self.policy.use_peft = True
else:
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
@@ -308,7 +313,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
# Load pretrained policy
if cfg.use_peft:
if cfg.policy.use_peft:
# in case of PEFT we re-use the policy pretrained path to point to the adapter path.
peft_path = cfg.policy.pretrained_path
cfg.policy.pretrained_path = None