mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
Clean up loading code
- Centralized instantiation of the PEFT wrapper in `make_policy` for inference (e.g. in `lerobot-record`) - Training a PEFT policy also sets `cfg.use_peft` so that all inference code loading the policy can rely on that attribute to identify if PEFT loading is needed - Modified RTC example to also include PEFT policies. Mostly because this is an example I'm currently exploring.
This commit is contained in:
@@ -194,15 +194,9 @@ class RecordConfig:
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
|
||||
# In case of a PEFT model We assume that the user saved the policy config (`config.json`) alongside the
|
||||
# adapter parameters / config. If they didn't we could instantiate the default configuration for the policy
|
||||
# but we wouldn't know if that is correct. So, in case of a missing config this will simply fail.
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
|
||||
if (Path(policy_path) / "adapter_config.json").exists():
|
||||
self.policy.use_peft = True
|
||||
|
||||
if self.teleop is None and self.policy is None:
|
||||
raise ValueError("Choose a policy, a teleoperator or both to control the robot")
|
||||
|
||||
@@ -433,19 +427,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
)
|
||||
|
||||
# Load pretrained policy
|
||||
if cfg.policy and cfg.policy.use_peft:
|
||||
from peft import PeftModel
|
||||
|
||||
logging.info("Loading policy's PEFT adapter.")
|
||||
# in case of PEFT we re-use the policy pretrained path to point to the adapter path.
|
||||
peft_path = cfg.policy.pretrained_path
|
||||
cfg.policy.pretrained_path = None
|
||||
|
||||
policy = make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
policy = PeftModel.from_pretrained(policy, peft_path)
|
||||
|
||||
else:
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
|
||||
preprocessor = None
|
||||
postprocessor = None
|
||||
|
||||
Reference in New Issue
Block a user