mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
Implement loading of PEFT adapters
Loading a PEFT adapter is currently done by initializing a policy with default config and then applying the adapter on the resulting model. This has the obvious drawback that any configurations done during training are not applied in the adapted model. Currently the `use_peft` attribute of `PreTrainedConfig` is only set during loading to signal the following code that it has to deal with a PEFT adapter. However we could imagine a scenario where this is already set at training time and stored alongside the adapter.
This commit is contained in:
@@ -60,6 +60,9 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
# automatic gradient scaling is used.
|
||||
use_amp: bool = False
|
||||
|
||||
# Whether the policy employed PEFT for training.
|
||||
use_peft: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
self.pretrained_path = None
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
|
||||
@@ -130,6 +130,21 @@ class DatasetRecordConfig:
|
||||
raise ValueError("You need to provide a task as argument in `single_task`.")
|
||||
|
||||
|
||||
def get_policy_config_from_peft_checkpoint(peft_config):
|
||||
if getattr(peft_config, "auto_mapping", None) is None:
|
||||
raise ValueError(
|
||||
"No auto-mapping config found in adapter config. Cannot determine policy config."
|
||||
)
|
||||
|
||||
auto_mapping = getattr(peft_config, "auto_mapping", None)
|
||||
base_model_class = auto_mapping["base_model_class"]
|
||||
parent_library_name = auto_mapping["parent_library"]
|
||||
|
||||
parent_library = importlib.import_module(parent_library_name)
|
||||
target_class = getattr(parent_library, base_model_class)
|
||||
return target_class.config_class
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecordConfig:
|
||||
robot: RobotConfig
|
||||
@@ -152,28 +167,18 @@ class RecordConfig:
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
|
||||
if (policy_path / 'adapter_config.json').exists():
|
||||
if (Path(policy_path) / 'adapter_config.json').exists():
|
||||
# The pretrained checkpoint is a PEFT adapter, cool. Currently we don't upload the
|
||||
# policy's config alongside the adapter config but to initialize the policy we
|
||||
# need a policy config. We assume that the config hasn't changed and we infer
|
||||
# the policy's config class from the base class mentioned in the adapter config.
|
||||
self.peft_config = PeftConfig.from_pretrained(policy_path)
|
||||
|
||||
if getattr(self.peft_config, "auto_mapping", None) is None:
|
||||
raise ValueError(
|
||||
"No auto-mapping config found in adapter config. Cannot determine policy config."
|
||||
)
|
||||
|
||||
auto_mapping = getattr(self.peft_config, "auto_mapping", None)
|
||||
base_model_class = auto_mapping["base_model_class"]
|
||||
parent_library_name = auto_mapping["parent_library"]
|
||||
|
||||
parent_library = importlib.import_module(parent_library_name)
|
||||
target_class = getattr(parent_library, base_model_class)
|
||||
policy_config_class = target_class.config_class
|
||||
policy_config_class = get_policy_config_from_peft_checkpoint(self.peft_config)
|
||||
|
||||
self.policy = policy_config_class()
|
||||
self.policy.pretrained_path = policy_path
|
||||
self.policy.use_peft = True
|
||||
|
||||
else:
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
@@ -308,7 +313,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
# Load pretrained policy
|
||||
|
||||
if cfg.use_peft:
|
||||
if cfg.policy.use_peft:
|
||||
# in case of PEFT we re-use the policy pretrained path to point to the adapter path.
|
||||
peft_path = cfg.policy.pretrained_path
|
||||
cfg.policy.pretrained_path = None
|
||||
|
||||
Reference in New Issue
Block a user