Files
lerobot-clone/src/lerobot/policies/groot/modeling_groot.py
Steven Palma be46bdea8f feat(policies): add Nvidia Gr00t N1.5 model (#2292)
* feat(policies): add Nvidia Gr00t N1.5 model

Co-authored-by: lbenhorin <lbenhorin@nvidia.com>
Co-authored-by: Aravindh <aravindhs@nvidia.com>
Co-authored-by: nv-sachdevkartik <ksachdev@nvidia.com>
Co-authored-by: youliangt <youliangt@nvidia.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>

* fix(docs): add groot to index

Co-authored-by: sachdevkartik <sachdev.kartik25@gmail.com>

---------

Co-authored-by: lbenhorin <lbenhorin@nvidia.com>
Co-authored-by: Aravindh <aravindhs@nvidia.com>
Co-authored-by: nv-sachdevkartik <ksachdev@nvidia.com>
Co-authored-by: youliangt <youliangt@nvidia.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: sachdevkartik <sachdev.kartik25@gmail.com>
2025-10-23 13:50:30 +02:00

199 lines
7.8 KiB
Python

#!/usr/bin/env python
# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Groot Policy Wrapper for LeRobot Integration
Minimal integration that delegates to Isaac-GR00T components where possible
without porting their code. The intent is to:
- Download and load the pretrained GR00T model via GR00TN15.from_pretrained
- Optionally align action horizon similar to gr00t_finetune.py
- Expose predict_action_chunk / select_action via GR00T model.get_action
- Provide a training forward that can call the GR00T model forward if batch
structure matches.
Notes:
- Dataset loading and full training orchestration is handled by Isaac-GR00T
TrainRunner in their codebase. If you want to invoke that flow end-to-end
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
"""
import os
from collections import deque
import torch
from torch import Tensor
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.groot.groot_n1 import GR00TN15
from lerobot.policies.pretrained import PreTrainedPolicy
class GrootPolicy(PreTrainedPolicy):
    """Wrapper around external Groot model for LeRobot integration.

    The wrapped GR00T model is loaded in ``__init__`` and stored on
    ``self._groot_model``. Training goes through :meth:`forward`, inference
    through :meth:`predict_action_chunk` / :meth:`select_action`. Actions
    produced for inference are buffered in a bounded queue that is cleared
    by :meth:`reset`.
    """

    name = "groot"
    config_class = GrootConfig

    def __init__(self, config: GrootConfig):
        """Initialize Groot policy wrapper.

        Args:
            config: Policy configuration. ``config.validate_features()`` is
                called before the GR00T model is created.
        """
        super().__init__(config)
        config.validate_features()
        self.config = config

        # Initialize GR00T model using ported components
        self._groot_model = self._create_groot_model()

        self.reset()

    def _create_groot_model(self):
        """Create and initialize the GR00T model.

        Steps:
            1) Run the Flash Attention compatibility probe.
            2) Load the pretrained model via ``GR00TN15.from_pretrained`` from
               ``config.base_model_path``, forwarding the ``tune_*`` flags.
            3) Set the compute dtype to ``"bfloat16"`` when ``config.use_bf16``
               is set (otherwise keep the model's own value) and mirror it
               onto ``model.config``.

        Returns:
            The loaded GR00T model instance.
        """
        # Handle Flash Attention compatibility issues
        self._handle_flash_attention_compatibility()

        model = GR00TN15.from_pretrained(
            pretrained_model_name_or_path=self.config.base_model_path,
            tune_llm=self.config.tune_llm,
            tune_visual=self.config.tune_visual,
            tune_projector=self.config.tune_projector,
            tune_diffusion_model=self.config.tune_diffusion_model,
        )

        model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
        model.config.compute_dtype = model.compute_dtype

        return model

    def reset(self):
        """Reset policy state when environment resets.

        Replaces the action queue with an empty one bounded to
        ``config.n_action_steps`` entries.
        """
        self._action_queue = deque([], maxlen=self.config.n_action_steps)

    def get_optim_params(self) -> dict:
        """Return the parameters to optimize (the result of ``self.parameters()``)."""
        return self.parameters()

    @staticmethod
    def _filter_groot_inputs(batch: dict[str, Tensor], allowed_base: set[str]) -> dict[str, Tensor]:
        """Keep only the batch entries that are forwarded to the GR00T model.

        An entry is kept when its key is in ``allowed_base`` or starts with
        ``"eagle_"``. Keys such as ``"info"`` or ``"next.*"`` can never satisfy
        either condition, so they need no separate exclusion.
        """
        return {k: v for k, v in batch.items() if k in allowed_base or k.startswith("eagle_")}

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
        """Training forward pass.

        Filters ``batch`` down to the state/action/embodiment keys plus any
        ``eagle_*`` keys and calls the wrapped model's ``forward`` on them.

        Args:
            batch: Mapping of input names to values.

        Returns:
            A tuple ``(loss, {"loss": loss.item()})`` where ``loss`` is the
            ``"loss"`` entry of the model output.
        """
        # Build a clean input dict for GR00T: keep only tensors GR00T consumes
        allowed_base = {"state", "state_mask", "action", "action_mask", "embodiment_id"}
        groot_inputs = self._filter_groot_inputs(batch, allowed_base)

        # Get device from model parameters
        device = next(self.parameters()).device

        # Run GR00T forward under bf16 autocast when enabled to reduce activation memory
        # Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
            outputs = self._groot_model.forward(groot_inputs)

        # Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
        loss = outputs.get("loss")

        loss_dict = {"loss": loss.item()}

        return loss, loss_dict

    @torch.no_grad()
    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
        """Predict a chunk of actions for inference via the model's ``get_action``.

        Puts the module in eval mode, filters ``batch`` down to the state and
        embodiment keys plus any ``eagle_*`` keys, and runs ``get_action``.

        Args:
            batch: Mapping of input names to values.

        Returns:
            The ``"action_pred"`` output, indexed as a 3-D tensor and truncated
            along its last dimension to
            ``config.output_features["action"].shape[0]``. The other
            dimensions are returned as produced by the model (presumably
            ``(batch, horizon, action_dim)`` - confirm against the model).
        """
        self.eval()

        # Build a clean input dict for GR00T: keep only tensors GR00T consumes
        # Preprocessing is handled by the processor pipeline, so we just filter the batch
        # NOTE: During inference, we should NOT pass action/action_mask (that's what we're predicting)
        allowed_base = {"state", "state_mask", "embodiment_id"}
        groot_inputs = self._filter_groot_inputs(batch, allowed_base)

        # Get device from model parameters
        device = next(self.parameters()).device

        # Use bf16 autocast for inference to keep memory low and match backbone dtype
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
            outputs = self._groot_model.get_action(groot_inputs)

        actions = outputs.get("action_pred")

        # Drop any trailing action dimensions beyond the configured action size
        original_action_dim = self.config.output_features["action"].shape[0]
        actions = actions[:, :, :original_action_dim]

        return actions

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select single action from action queue.

        When the queue is empty, a new chunk is predicted and its first
        ``config.n_action_steps`` steps are enqueued; the oldest queued action
        is then popped and returned.

        Args:
            batch: Mapping of input names to values, forwarded to
                :meth:`predict_action_chunk` when a refill is needed.

        Returns:
            The next queued action (one slice along dimension 1 of the
            predicted chunk).
        """
        self.eval()

        if not self._action_queue:
            actions = self.predict_action_chunk(batch)
            # The queue is a deque bounded to n_action_steps: extending it with a
            # longer chunk would silently discard the *earliest* actions and keep
            # the tail. Truncate explicitly so the first steps are the ones executed.
            actions = actions[:, : self.config.n_action_steps]
            self._action_queue.extend(actions.transpose(0, 1))
        return self._action_queue.popleft()

    # -------------------------
    # Internal helpers
    # -------------------------
    def _handle_flash_attention_compatibility(self) -> None:
        """Probe Flash Attention and report its status without failing.

        Sets default values for two Flash Attention environment variables,
        then tries to import ``flash_attn``. Any failure is reported on stdout
        and swallowed on purpose so that model loading can continue.

        This addresses the common 'undefined symbol' error that occurs when
        Flash Attention is compiled against a different PyTorch version than
        what's currently installed.
        """
        # Set environment variables to handle Flash Attention compatibility
        # These help with symbol resolution issues
        os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
        os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")

        # Try to import flash_attn and handle failures gracefully
        try:
            import flash_attn

            print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
        except Exception as e:  # deliberate best-effort probe: never abort model creation
            # A binary built against another PyTorch fails to load with an
            # ImportError whose message contains "undefined symbol", so this
            # check must come before the generic ImportError handling.
            if "undefined symbol" in str(e):
                print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
                print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
                print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
                print(" pip uninstall flash-attn")
                print(" pip install --no-build-isolation flash-attn==2.6.3")
                print("[GROOT] Continuing with fallback attention mechanism")
            elif isinstance(e, ImportError):
                print(f"[GROOT] Flash Attention not available: {e}")
                print("[GROOT] Will use fallback attention mechanism")
            else:
                print(f"[GROOT] Flash Attention error: {e}")
                print("[GROOT] Continuing with fallback attention mechanism")