def get_policy_config_from_peft_checkpoint(peft_config):
    """Infer the policy config class from a PEFT adapter config.

    The adapter config's ``auto_mapping`` entry is read for two keys:
    ``base_model_class`` (the name of the base model class) and
    ``parent_library`` (the dotted name of the module that class is looked
    up on). The module is imported, the class is fetched from it, and the
    class's ``config_class`` attribute is returned.

    Args:
        peft_config: Object exposing an ``auto_mapping`` attribute that can
            be subscripted with the two keys above.

    Returns:
        The ``config_class`` attribute of the resolved base model class.

    Raises:
        ValueError: If ``peft_config`` has no ``auto_mapping`` attribute or
            the attribute is ``None``.
        KeyError: If ``auto_mapping`` lacks ``base_model_class`` or
            ``parent_library`` (when it is a dict).
        ImportError: If the module named by ``parent_library`` cannot be
            imported.
        AttributeError: If the module does not define the named class, or
            the class has no ``config_class`` attribute.
    """
    # Look the mapping up once and reuse it for both the check and the reads.
    auto_mapping = getattr(peft_config, "auto_mapping", None)
    if auto_mapping is None:
        raise ValueError(
            "No auto-mapping config found in adapter config. Cannot determine policy config."
        )

    base_model_class = auto_mapping["base_model_class"]
    parent_library_name = auto_mapping["parent_library"]

    # NOTE(review): the module name comes straight from the adapter config, and
    # importing a module runs its import-time code -- confirm adapters are only
    # loaded from trusted sources.
    parent_library = importlib.import_module(parent_library_name)
    target_class = getattr(parent_library, base_model_class)
    return target_class.config_class
self.peft_config = PeftConfig.from_pretrained(policy_path) - if getattr(self.peft_config, "auto_mapping", None) is None: - raise ValueError( - "No auto-mapping config found in adapter config. Cannot determine policy config." - ) - - auto_mapping = getattr(self.peft_config, "auto_mapping", None) - base_model_class = auto_mapping["base_model_class"] - parent_library_name = auto_mapping["parent_library"] - - parent_library = importlib.import_module(parent_library_name) - target_class = getattr(parent_library, base_model_class) - policy_config_class = target_class.config_class + policy_config_class = get_policy_config_from_peft_checkpoint(self.peft_config) self.policy = policy_config_class() self.policy.pretrained_path = policy_path + self.policy.use_peft = True else: self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) @@ -308,7 +313,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset: # Load pretrained policy - if cfg.use_peft: + if cfg.policy.use_peft: # in case of PEFT we re-use the policy pretrained path to point to the adapter path. peft_path = cfg.policy.pretrained_path cfg.policy.pretrained_path = None