# File: lerobot-clone/src/lerobot/policies/groot/modeling_groot.py
# (file-viewer metadata: 330 lines, 13 KiB, Python)

#!/usr/bin/env python
# Copyright 2024 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Groot Policy Wrapper for LeRobot Integration
Minimal integration that delegates to Isaac-GR00T components where possible
without porting their code. The intent is to:
- Download and load the pretrained GR00T model via GR00TN15.from_pretrained
- Optionally align action horizon similar to gr00t_finetune.py
- Expose predict_action via GR00T model.get_action
- Provide a training forward that can call the GR00T model forward if batch
structure matches.
Notes:
- Dataset loading and full training orchestration is handled by Isaac-GR00T
TrainRunner in their codebase. If you want to invoke that flow end-to-end
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
"""
import builtins
import os
from collections import deque
from pathlib import Path
from typing import TypeVar
import torch
from torch import Tensor
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.utils.constants import ACTION, OBS_IMAGES
from ..pretrained import PreTrainedPolicy
from .configuration_groot import GrootConfig
from .groot_n1 import GR00TN15
# Type variable bound to GrootPolicy; used by `from_pretrained` so the annotated
# return type follows the class the classmethod is invoked on.
T = TypeVar("T", bound="GrootPolicy")
class GrootPolicy(PreTrainedPolicy):
    """Wrapper around external Groot model for LeRobot integration.

    The wrapped ``GR00TN15`` instance lives in ``self._groot_model``:
    ``forward`` delegates to its ``forward`` and ``predict_action_chunk`` to
    its ``get_action``. ``select_action`` hands out one action per call from
    an internal bounded queue that is refilled from ``predict_action_chunk``.
    """

    name = "groot"
    config_class = GrootConfig

    def __init__(self, config: GrootConfig, **kwargs):
        """Initialize Groot policy wrapper.

        Args:
            config: Policy configuration. ``config.validate_features()`` is
                called on it before the model is created.
            **kwargs: Accepted for signature compatibility; not used here.
        """
        super().__init__(config)
        config.validate_features()
        self.config = config

        # Initialize GR00T model using ported components
        self._groot_model = self._create_groot_model()

        self.reset()

    def _create_groot_model(self):
        """Create and initialize the GR00T model via ``GR00TN15.from_pretrained``.

        Called from ``__init__``, i.e. whenever a policy instance is constructed.

        Steps:
            1) Run the Flash Attention diagnostics helper.
            2) Load the model from ``config.base_model_path`` with the
               ``tune_*`` flags taken from the config.
            3) Set the compute dtype to ``"bfloat16"`` when ``config.use_bf16``
               is truthy and mirror the resulting value into ``model.config``.

        Returns:
            The loaded GR00T model.
        """
        # Handle Flash Attention compatibility issues
        self._handle_flash_attention_compatibility()

        model = GR00TN15.from_pretrained(
            pretrained_model_name_or_path=self.config.base_model_path,
            tune_llm=self.config.tune_llm,
            tune_visual=self.config.tune_visual,
            tune_projector=self.config.tune_projector,
            tune_diffusion_model=self.config.tune_diffusion_model,
        )

        model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
        model.config.compute_dtype = model.compute_dtype

        return model

    def reset(self):
        """Reset policy state when environment resets.

        Replaces the action queue with an empty deque bounded to
        ``config.n_action_steps`` entries.
        """
        self._action_queue = deque([], maxlen=self.config.n_action_steps)

    @classmethod
    def from_pretrained(
        cls: builtins.type[T],
        pretrained_name_or_path: str | Path,
        *,
        config: GrootConfig | None = None,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        strict: bool = True,
        **kwargs,
    ) -> T:
        """Load Groot policy from pretrained model.

        Handles two cases:
            1. Base GR00T models (e.g., 'nvidia/GR00T-N1.5-3B') - a fresh policy
               is constructed, which loads the raw model in ``__init__``.
            2. Fine-tuned LeRobot checkpoints (a ``model.safetensors``-style
               single file is present) - loading is delegated to the parent
               class ``from_pretrained``.

        Args:
            pretrained_name_or_path: Path to the GR00T model or fine-tuned checkpoint
            config: Optional GrootConfig. If None, a default one is created for
                the base-model case (the parent class handles it otherwise).
            force_download: Force download even if cached (forwarded to the
                parent loader only; the existence probe never forces a download)
            resume_download: Resume interrupted download
            proxies: Proxy settings
            token: HuggingFace authentication token
            cache_dir: Cache directory path
            local_files_only: Only use local files
            revision: Specific model revision
            strict: Strict state dict loading
            **kwargs: Additional arguments. In the base-model case, entries whose
                key is an existing config attribute override that attribute.

        Returns:
            Initialized GrootPolicy instance with loaded model
        """
        from huggingface_hub import hf_hub_download
        from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
        from huggingface_hub.errors import HfHubHTTPError

        print(
            "The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n"
            f"Loading pretrained model from: {pretrained_name_or_path}"
        )

        model_id = str(pretrained_name_or_path)
        is_finetuned_checkpoint = False

        # Check if this is a fine-tuned LeRobot checkpoint (has model.safetensors).
        # The probe is deliberately best-effort: any failure means "treat as base model".
        try:
            if os.path.isdir(model_id):
                is_finetuned_checkpoint = os.path.exists(os.path.join(model_id, SAFETENSORS_SINGLE_FILE))
            else:
                # Try to download the safetensors file to check if it exists
                hf_hub_download(
                    repo_id=model_id,
                    filename=SAFETENSORS_SINGLE_FILE,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=False,  # Just check, don't force download
                    proxies=proxies,
                    token=token,
                    local_files_only=local_files_only,
                )
                is_finetuned_checkpoint = True
        except HfHubHTTPError:
            # Hub answered with an HTTP error for the safetensors file.
            is_finetuned_checkpoint = False
        except Exception:
            # Anything else (invalid repo id, offline cache miss, ...) also
            # falls back to the base-model path instead of aborting the load.
            is_finetuned_checkpoint = False

        if is_finetuned_checkpoint:
            # This is a fine-tuned LeRobot checkpoint - use parent class loading
            print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
            return super().from_pretrained(
                pretrained_name_or_path=pretrained_name_or_path,
                config=config,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                token=token,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
                revision=revision,
                strict=strict,
                **kwargs,
            )

        # This is a base GR00T model - load it fresh
        print("Detected base GR00T model, loading from HuggingFace...")

        if config is None:
            # Create default config with the pretrained path
            config = GrootConfig(base_model_path=str(pretrained_name_or_path))

            # Add minimal visual feature required for validation
            # validate_features() will automatically add state and action features
            # These are placeholders - actual robot features come from the preprocessor
            if not config.input_features:
                config.input_features = {
                    f"{OBS_IMAGES}.camera": PolicyFeature(
                        type=FeatureType.VISUAL,
                        shape=(3, 224, 224),  # Default image size from config
                    ),
                }
        else:
            # Override the base_model_path with the provided path
            config.base_model_path = str(pretrained_name_or_path)

        # Pass through any additional config overrides from kwargs;
        # keys that are not existing config attributes are ignored.
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)

        # Create a fresh policy instance - this will automatically load the GR00T model
        # in __init__ via _create_groot_model()
        policy = cls(config)

        policy.eval()
        return policy

    def get_optim_params(self) -> dict:
        """Return the parameters to optimize (the result of ``self.parameters()``)."""
        return self.parameters()

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
        """Training forward pass.

        Delegates to Isaac-GR00T model.forward when inputs are compatible.

        Args:
            batch: Mapping of input names to tensors. Only the keys GR00T
                consumes (state/action tensors, their masks, the embodiment id
                and every ``eagle_*`` entry) are forwarded to the model.

        Returns:
            A ``(loss, loss_dict)`` tuple where ``loss`` is the value stored
            under ``"loss"`` in the model output and ``loss_dict`` holds its
            Python scalar under the same key.
        """
        # Build a clean input dict for GR00T: keep only tensors GR00T consumes
        allowed_base = {"state", "state_mask", "action", "action_mask", "embodiment_id"}
        groot_inputs = self._filter_groot_inputs(batch, allowed_base)

        # Get device from model parameters
        device = next(self.parameters()).device

        # Run GR00T forward under bf16 autocast when enabled to reduce activation memory
        # Rationale: Matches original GR00T finetuning (bf16 compute, fp32 params) and avoids fp32 upcasts.
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
            outputs = self._groot_model.forward(groot_inputs)

        # Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
        loss = outputs.get("loss")

        loss_dict = {"loss": loss.item()}

        return loss, loss_dict

    @torch.no_grad()
    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
        """Predict a chunk of actions for inference by delegating to Isaac-GR00T.

        Args:
            batch: Mapping of input names to tensors. ``action``/``action_mask``
                entries are dropped because they are what is being predicted.

        Returns:
            The model's ``"action_pred"`` output, a 3-D tensor whose last axis
            is trimmed to the action dimension declared in
            ``config.output_features[ACTION]``. The second (time) axis is
            returned as produced by the model and is not trimmed here.
        """
        self.eval()

        # Build a clean input dict for GR00T: keep only tensors GR00T consumes
        # Preprocessing is handled by the processor pipeline, so we just filter the batch
        # NOTE: During inference, we should NOT pass action/action_mask (that's what we're predicting)
        allowed_base = {"state", "state_mask", "embodiment_id"}
        groot_inputs = self._filter_groot_inputs(batch, allowed_base)

        # Get device from model parameters
        device = next(self.parameters()).device

        # Use bf16 autocast for inference to keep memory low and match backbone dtype
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=self.config.use_bf16):
            outputs = self._groot_model.get_action(groot_inputs)

        actions = outputs.get("action_pred")

        # Drop any trailing action dimensions beyond the configured action size.
        original_action_dim = self.config.output_features[ACTION].shape[0]
        actions = actions[:, :, :original_action_dim]

        return actions

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select single action from action queue.

        When the queue is empty a new chunk is predicted and its first
        ``config.n_action_steps`` time steps are enqueued; the oldest queued
        action is then returned.

        Args:
            batch: Observation batch forwarded to ``predict_action_chunk``.

        Returns:
            The next action from the queue (one time step of the chunk).
        """
        self.eval()

        if len(self._action_queue) == 0:
            actions = self.predict_action_chunk(batch)
            # The queue is a deque bounded to n_action_steps. Over-filling a
            # bounded deque discards items from the LEFT, which would silently
            # drop the earliest (next-to-execute) actions whenever the predicted
            # chunk is longer than n_action_steps. Keep the first steps instead.
            actions = actions[:, : self.config.n_action_steps]
            # Swap the first two axes so iterating yields one entry per time step.
            self._action_queue.extend(actions.transpose(0, 1))
        return self._action_queue.popleft()

    # -------------------------
    # Internal helpers
    # -------------------------
    @staticmethod
    def _filter_groot_inputs(batch: dict[str, Tensor], allowed_base: set[str]) -> dict[str, Tensor]:
        """Return the entries of ``batch`` that are passed on to the GR00T model.

        An entry is kept when its key is in ``allowed_base`` or starts with
        ``"eagle_"``. Such keys can never start with ``"next."`` or equal
        ``"info"``, so no separate exclusion is required.
        """
        return {k: v for k, v in batch.items() if k in allowed_base or k.startswith("eagle_")}

    def _handle_flash_attention_compatibility(self) -> None:
        """Handle Flash Attention compatibility issues by setting environment variables.

        This addresses the common 'undefined symbol' error that occurs when Flash Attention
        is compiled against a different PyTorch version than what's currently installed.

        The method only sets two environment defaults (existing values are kept)
        and prints diagnostics about whether ``flash_attn`` can be imported; it
        never raises for an import problem.
        """
        # Set environment variables to handle Flash Attention compatibility
        # These help with symbol resolution issues
        os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
        os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")

        # Try to import flash_attn and handle failures gracefully
        try:
            import flash_attn

            print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
        except ImportError as e:
            print(f"[GROOT] Flash Attention not available: {e}")
            print("[GROOT] Will use fallback attention mechanism")
        except Exception as e:
            # Best-effort diagnostics: report the problem and keep going.
            if "undefined symbol" in str(e):
                print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
                print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
                print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
                print("  pip uninstall flash-attn")
                print("  pip install --no-build-isolation flash-attn==2.6.3")
                print("[GROOT] Continuing with fallback attention mechanism")
            else:
                print(f"[GROOT] Flash Attention error: {e}")
                print("[GROOT] Continuing with fallback attention mechanism")