feat(policies): add Nvidia Gr00t N1.5 model (#2292)

* feat(policies): add Nvidia Gr00t N1.5 model

Co-authored-by: lbenhorin <lbenhorin@nvidia.com>
Co-authored-by: Aravindh <aravindhs@nvidia.com>
Co-authored-by: nv-sachdevkartik <ksachdev@nvidia.com>
Co-authored-by: youliangt <youliangt@nvidia.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>

* fix(docs): add groot to index

Co-authored-by: sachdevkartik <sachdev.kartik25@gmail.com>

---------

Co-authored-by: lbenhorin <lbenhorin@nvidia.com>
Co-authored-by: Aravindh <aravindhs@nvidia.com>
Co-authored-by: nv-sachdevkartik <ksachdev@nvidia.com>
Co-authored-by: youliangt <youliangt@nvidia.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: sachdevkartik <sachdev.kartik25@gmail.com>
This commit is contained in:
Steven Palma
2025-10-23 13:50:30 +02:00
committed by GitHub
parent 306429a85b
commit be46bdea8f
26 changed files with 4766 additions and 6 deletions

View File

@@ -505,14 +505,17 @@ def eval_main(cfg: EvalPipelineConfig):
)
policy.eval()
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
preprocessor_overrides = {
"device_processor": {"device": str(policy.config.device)},
"rename_observations_processor": {"rename_map": cfg.rename_map},
}
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
preprocessor_overrides={
"device_processor": {"device": str(policy.config.device)},
"rename_observations_processor": {"rename_map": cfg.rename_map},
},
preprocessor_overrides=preprocessor_overrides,
)
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy_all(