mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
* feat(policies): add Nvidia Gr00t N1.5 model Co-authored-by: lbenhorin <lbenhorin@nvidia.com> Co-authored-by: Aravindh <aravindhs@nvidia.com> Co-authored-by: nv-sachdevkartik <ksachdev@nvidia.com> Co-authored-by: youliangt <youliangt@nvidia.com> Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Jade Choghari <chogharijade@gmail.com> * fix(docs): add groot to index Co-authored-by: sachdevkartik <sachdev.kartik25@gmail.com> --------- Co-authored-by: lbenhorin <lbenhorin@nvidia.com> Co-authored-by: Aravindh <aravindhs@nvidia.com> Co-authored-by: nv-sachdevkartik <ksachdev@nvidia.com> Co-authored-by: youliangt <youliangt@nvidia.com> Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Jade Choghari <chogharijade@gmail.com> Co-authored-by: sachdevkartik <sachdev.kartik25@gmail.com>
199 lines
7.8 KiB
Python
199 lines
7.8 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
Groot Policy Wrapper for LeRobot Integration
|
|
|
|
Minimal integration that delegates to Isaac-GR00T components where possible
|
|
without porting their code. The intent is to:
|
|
|
|
- Download and load the pretrained GR00T model via GR00TN15.from_pretrained
|
|
- Optionally align action horizon similar to gr00t_finetune.py
|
|
- Expose predict_action_chunk / select_action via GR00T model.get_action
|
|
- Provide a training forward that can call the GR00T model forward if batch
|
|
structure matches.
|
|
|
|
Notes:
|
|
- Dataset loading and full training orchestration is handled by Isaac-GR00T
|
|
TrainRunner in their codebase. If you want to invoke that flow end-to-end
|
|
from LeRobot, call the Isaac-GR00T runner directly; this module does not define a `GrootPolicy.finetune_with_groot_runner` helper.
|
|
"""
|
|
|
|
import os
|
|
from collections import deque
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
from lerobot.policies.groot.configuration_groot import GrootConfig
|
|
from lerobot.policies.groot.groot_n1 import GR00TN15
|
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
|
|
|
|
|
class GrootPolicy(PreTrainedPolicy):
    """Wrapper around the GR00T N1.5 model for LeRobot integration.

    The wrapper loads a pretrained ``GR00TN15`` model, forwards filtered
    training batches to it, and serves inference actions one step at a time
    from an internal queue that is refilled with a freshly predicted chunk
    whenever it runs empty.
    """

    name = "groot"
    config_class = GrootConfig

    # Batch keys (besides the "eagle_*" keys) handed to the model while training.
    _TRAIN_INPUT_KEYS = frozenset({"state", "state_mask", "action", "action_mask", "embodiment_id"})
    # During inference action/action_mask are what we predict, so they are withheld.
    _INFERENCE_INPUT_KEYS = frozenset({"state", "state_mask", "embodiment_id"})

    def __init__(self, config: GrootConfig):
        """Initialize the policy: validate the config, load the model, reset state.

        Args:
            config: Policy configuration. ``config.validate_features()`` is
                called before the model is created.
        """
        super().__init__(config)
        config.validate_features()
        self.config = config

        # Initialize GR00T model using ported components
        self._groot_model = self._create_groot_model()

        self.reset()

    def _create_groot_model(self):
        """Load the pretrained GR00T model described by ``self.config``.

        Steps:
            1) Run the Flash Attention compatibility check (diagnostics only).
            2) Load the model via ``GR00TN15.from_pretrained`` using the
               ``base_model_path`` and the ``tune_*`` flags from the config.
            3) When ``config.use_bf16`` is set, switch the model's
               ``compute_dtype`` to ``"bfloat16"`` and mirror the value into
               ``model.config.compute_dtype``.

        Returns:
            The loaded model instance.
        """
        # Handle Flash Attention compatibility issues
        self._handle_flash_attention_compatibility()

        model = GR00TN15.from_pretrained(
            pretrained_model_name_or_path=self.config.base_model_path,
            tune_llm=self.config.tune_llm,
            tune_visual=self.config.tune_visual,
            tune_projector=self.config.tune_projector,
            tune_diffusion_model=self.config.tune_diffusion_model,
        )

        if self.config.use_bf16:
            model.compute_dtype = "bfloat16"
        # Keep the model config in sync with whatever dtype the model ends up using.
        model.config.compute_dtype = model.compute_dtype

        return model

    def reset(self):
        """Reset policy state when the environment resets.

        Replaces the action queue with an empty one bounded to
        ``config.n_action_steps`` entries.
        """
        self._action_queue = deque([], maxlen=self.config.n_action_steps)

    def get_optim_params(self) -> dict:
        """Return the parameters to optimize.

        NOTE(review): the value returned is the result of ``self.parameters()``,
        not a ``dict``; the annotation is left unchanged so the signature stays
        as callers already see it.
        """
        return self.parameters()

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
        """Training forward pass.

        Filters ``batch`` down to the keys the model consumes and delegates to
        ``self._groot_model.forward`` under bf16 autocast when
        ``config.use_bf16`` is enabled.

        Args:
            batch: Mapping of input names to tensors; unrelated keys are ignored.

        Returns:
            A ``(loss, loss_dict)`` tuple where ``loss_dict["loss"]`` is the
            Python scalar obtained from ``loss.item()``.

        Raises:
            ValueError: If the model output does not contain a ``"loss"`` entry.
        """
        groot_inputs = self._filter_groot_inputs(batch, self._TRAIN_INPUT_KEYS)

        # Run GR00T forward under bf16 autocast when enabled to reduce activation memory
        # Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
        with self._autocast():
            outputs = self._groot_model.forward(groot_inputs)

        loss = outputs.get("loss")
        if loss is None:
            # Fail with a readable message instead of `None.item()` blowing up below.
            raise ValueError(
                f"GR00T forward pass returned no 'loss' entry; available keys: {list(outputs.keys())}"
            )

        return loss, {"loss": loss.item()}

    @torch.no_grad()
    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
        """Predict a chunk of actions for inference via ``get_action``.

        Preprocessing is handled by the processor pipeline, so the batch is
        only filtered here. ``action`` / ``action_mask`` are deliberately not
        forwarded because they are what is being predicted.

        Args:
            batch: Mapping of input names to tensors; unrelated keys are ignored.

        Returns:
            The ``"action_pred"`` tensor with its last dimension truncated to
            the action dimension declared in ``config.output_features["action"]``.
            The tensor is indexed as 3-D (batch, step, action_dim).

        Raises:
            ValueError: If the model output does not contain ``"action_pred"``.
        """
        self.eval()

        groot_inputs = self._filter_groot_inputs(batch, self._INFERENCE_INPUT_KEYS)

        # Use bf16 autocast for inference to keep memory low and match backbone dtype
        with self._autocast():
            outputs = self._groot_model.get_action(groot_inputs)

        actions = outputs.get("action_pred")
        if actions is None:
            raise ValueError(
                f"GR00T get_action returned no 'action_pred' entry; available keys: {list(outputs.keys())}"
            )

        # Drop any extra trailing action dimensions beyond the configured action size.
        original_action_dim = self.config.output_features["action"].shape[0]
        return actions[:, :, :original_action_dim]

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Return the next action, refilling the queue when it is empty.

        When the queue is empty a new chunk is predicted and its FIRST
        ``config.n_action_steps`` steps are enqueued. The explicit truncation
        matters: the queue is a ``deque`` with ``maxlen=n_action_steps``, and
        extending a bounded deque with a longer sequence discards the earliest
        items, which would otherwise make the policy execute the tail of the
        chunk instead of its head.

        Args:
            batch: Mapping of input names to tensors for the current step.

        Returns:
            The action for the current step, one entry per batch element.
        """
        self.eval()

        if not self._action_queue:
            actions = self.predict_action_chunk(batch)
            actions = actions[:, : self.config.n_action_steps]
            # Queue is time-major: each entry holds one step for the whole batch.
            self._action_queue.extend(actions.transpose(0, 1))
        return self._action_queue.popleft()

    # -------------------------
    # Internal helpers
    # -------------------------
    @staticmethod
    def _filter_groot_inputs(batch: dict[str, Tensor], allowed_base) -> dict[str, Tensor]:
        """Keep only the entries of ``batch`` that are passed on to the model.

        An entry is kept when its key is in ``allowed_base`` or starts with
        ``"eagle_"``; everything else (e.g. ``"next.*"`` or ``"info"``) is dropped.
        """
        return {k: v for k, v in batch.items() if k in allowed_base or k.startswith("eagle_")}

    def _autocast(self):
        """Return a bf16 ``torch.autocast`` context for the model's device.

        The context is a no-op when ``config.use_bf16`` is false. The device
        type is taken from the first model parameter.
        """
        device = next(self.parameters()).device
        return torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16)

    def _handle_flash_attention_compatibility(self) -> None:
        """Probe Flash Attention and report problems without failing.

        This addresses the common 'undefined symbol' error that occurs when
        Flash Attention is compiled against a different PyTorch version than
        what's currently installed. Any failure while importing ``flash_attn``
        is reported on stdout and swallowed on purpose (best effort), so model
        loading can continue.
        """
        # Set environment variables to handle Flash Attention compatibility
        # These help with symbol resolution issues; existing values are left untouched.
        os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
        os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")

        # Try to import flash_attn and handle failures gracefully
        try:
            import flash_attn

            print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
        except ImportError as e:
            print(f"[GROOT] Flash Attention not available: {e}")
            print("[GROOT] Will use fallback attention mechanism")
        except Exception as e:  # deliberate best-effort: never block model creation
            if "undefined symbol" in str(e):
                print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
                print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
                print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
                print(" pip uninstall flash-attn")
                print(" pip install --no-build-isolation flash-attn==2.6.3")
            else:
                print(f"[GROOT] Flash Attention error: {e}")
            print("[GROOT] Continuing with fallback attention mechanism")
|