Compare commits

...

1 Commits

Author SHA1 Message Date
Andrew Wrenn
b568c41355 Add GR00T N1.7 support
Add GR00T N1.7 policy configuration, checkpoint compatibility, processor parity, LIBERO documentation, and focused tests.

Co-authored-by: Ryan Halabi <ryhalabi@nvidia.com>
2026-06-01 08:57:04 -07:00
11 changed files with 4773 additions and 83 deletions

View File

@@ -64,7 +64,7 @@
- local: eo1
title: EO-1
- local: groot
title: NVIDIA GR00T N1.5
title: NVIDIA GR00T
- local: xvla
title: X-VLA
- local: multi_task_dit

View File

@@ -1,16 +1,16 @@
# GR00T N1.5 Policy
# GR00T Policy
GR00T N1.5 is an open foundation model from NVIDIA designed for generalized humanoid robot reasoning and skills. It is a cross-embodiment model that accepts multimodal input, including language and images, to perform manipulation tasks in diverse environments.
GR00T is an NVIDIA foundation model family for generalized humanoid robot reasoning and skills. It is a cross-embodiment policy that accepts multimodal input, including language, images, and proprioception, to perform manipulation tasks in diverse environments.
This document outlines the specifics of its integration and usage within the LeRobot framework.
LeRobot integrates GR00T through the `groot` policy type. The default model family is GR00T N1.5, and GR00T N1.7 can be selected with `policy.model_version=n1.7`.
## Model Overview
NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. It is built to improve generalization and language-following abilities for humanoid robots.
NVIDIA Isaac GR00T N1.5 is an upgraded version of the GR00T N1 foundation model. GR00T N1.7 extends the family with a Cosmos-Reason2/Qwen3-VL backbone and N1.7 checkpoints for SimplerEnv, DROID, and LIBERO.
Developers and researchers can post-train GR00T N1.5 with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
Developers and researchers can post-train GR00T with their own real or synthetic data to adapt it for specific humanoid robots or tasks.
GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception.
GR00T uses pre-trained vision and language encoders with a flow matching action transformer to model a chunk of actions conditioned on vision, language, and proprioception.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-groot-paper1%20(1).png"
@@ -28,33 +28,35 @@ This approach allows the model to be highly adaptable through post-training for
## Installation Requirements
As of today, GR00T N1.5 requires flash attention for it's internal working.
We are working on making this optional, but in the meantime that means that we require an extra installation step and it can only be used in CUDA enabled devices.
1. Following the Environment Setup of our [Installation Guide](./installation). **Attention** don't install `lerobot` in this step.
2. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) by running:
Install LeRobot with the GR00T extra:
```bash
# Check https://pytorch.org/get-started/locally/ for your system
pip install "torch>=2.2.1,<2.8.0" "torchvision>=0.21.0,<0.23.0" # --index-url https://download.pytorch.org/whl/cu1XX
pip install ninja "packaging>=24.2,<26.0" # flash attention dependencies
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
pip install "lerobot[groot]"
```
3. Install LeRobot by running:
GR00T is intended for NVIDIA GPU-accelerated systems. The `groot` extra installs the policy dependencies, including `transformers`, `diffusers`, `peft`, `dm-tree`, and Flash Attention where available. If Flash Attention is unavailable or incompatible, LeRobot falls back to SDPA attention in supported GR00T paths, with lower expected throughput.
For a source checkout, follow the Environment Setup in the [Installation Guide](./installation), then install the extra:
```bash
pip install lerobot[groot]
uv sync --locked --extra groot
```
If you need to install Flash Attention manually for your CUDA/PyTorch build, use the wheel or source build recommended by the [Flash Attention project](https://github.com/Dao-AILab/flash-attention).
## Usage
To use GR00T in your LeRobot configuration, specify the policy type as:
To use GR00T N1.5 in your LeRobot configuration, specify the policy type:
```python
policy.type=groot
```bash
--policy.type=groot
```
To use GR00T N1.7:
```bash
--policy.type=groot \
--policy.model_version=n1.7
```
## Training
@@ -85,14 +87,20 @@ accelerate launch \
--job_name=$JOB_NAME
```
For N1.7, add:
```bash
--policy.model_version=n1.7
```
## Performance Results
### Libero Benchmark Results
### LIBERO Benchmark Results
> [!NOTE]
> Follow our instructions for Libero usage: [Libero](./libero)
> Follow the [LIBERO](./libero) setup instructions before running `lerobot-eval`.
GR00T has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the Libero dataset and compared the results to the GR00T reference results.
GR00T has demonstrated strong performance on the LIBERO benchmark suite. To compare and test its LeRobot implementation, we finetuned the GR00T N1.5 model for 30k steps on the LIBERO dataset and compared the results to the GR00T reference results.
| Benchmark | LeRobot Implementation | GR00T Reference |
| ------------------ | ---------------------- | --------------- |
@@ -101,7 +109,49 @@ GR00T has demonstrated strong performance on the Libero benchmark suite. To comp
| **Libero Long** | 82.0% | 76.0% |
| **Average** | 87.0% | 87.0% |
These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
These results demonstrate GR00T's strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, follow the instructions in the [LIBERO](./libero) section.
### GR00T N1.7 LIBERO Checkpoints
NVIDIA publishes GR00T N1.7 LIBERO checkpoints at [`nvidia/GR00T-N1.7-LIBERO`](https://huggingface.co/nvidia/GR00T-N1.7-LIBERO), with one subdirectory per LIBERO suite:
| Suite | Checkpoint subdirectory |
| -------------- | ----------------------- |
| LIBERO Spatial | `libero_spatial` |
| LIBERO Object | `libero_object` |
| LIBERO Goal | `libero_goal` |
| LIBERO 10 | `libero_10` |
Preliminary LeRobot integration results:
| Suite | Status | Success rate | n_episodes |
| -------------- | ------ | -----------: | ---------: |
| LIBERO Spatial | ✓ | ~95% | XX |
| LIBERO Object | ✓ | XX% | XX |
| LIBERO Goal | ✓ | XX% | XX |
| LIBERO 10 | ✓ | XX% | XX |
| **Average** | ✓ | **XX%** | **XX** |
Replace the `XX` placeholders with final eval artifacts before merge.
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
```bash
huggingface-cli download nvidia/GR00T-N1.7-LIBERO \
--include "libero_spatial/*" \
--local-dir ./GR00T-N1.7-LIBERO
lerobot-eval \
--policy.type=groot \
--policy.model_version=n1.7 \
--policy.base_model_path=./GR00T-N1.7-LIBERO/libero_spatial \
--policy.embodiment_tag=libero_sim \
--env.type=libero \
--env.task=libero_spatial \
--eval.n_episodes=50
```
Use `eval.n_episodes >= 50` per suite when reporting success rates.
### Evaluate in your hardware setup
@@ -131,4 +181,4 @@ lerobot-rollout\
## License
This model follows NVIDIA's proprietary license, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T). Future versions (starting from N1.7) will follow **Apache 2.0 License**.
GR00T N1.5 follows NVIDIA's license terms, consistent with the original [GR00T repository](https://github.com/NVIDIA/Isaac-GR00T). GR00T N1.7 is released under the [NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/).

View File

@@ -24,4 +24,8 @@ Code: https://github.com/NVIDIA/Isaac-GR00T
Blog: https://developer.nvidia.com/isaac/gr00t
Hugging Face Model: https://huggingface.co/nvidia/GR00T-N1.5-3B
Hugging Face Models:
- GR00T N1.5: https://huggingface.co/nvidia/GR00T-N1.5-3B
- GR00T N1.7: https://huggingface.co/nvidia/GR00T-N1.7-3B
- GR00T N1.7 LIBERO checkpoints: https://huggingface.co/nvidia/GR00T-N1.7-LIBERO

View File

@@ -18,6 +18,7 @@ from __future__ import annotations
import importlib
import logging
from copy import copy
from typing import TYPE_CHECKING, Any, TypedDict, Unpack
import torch
@@ -48,7 +49,7 @@ from .act.configuration_act import ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig
from .eo1.configuration_eo1 import EO1Config
from .gaussian_actor.configuration_gaussian_actor import GaussianActorConfig
from .groot.configuration_groot import GrootConfig
from .groot.configuration_groot import GROOT_N1_7, GrootConfig
from .molmoact2.configuration_molmoact2 import MolmoAct2Config
from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig
from .pi0.configuration_pi0 import PI0Config
@@ -273,24 +274,47 @@ def make_pre_post_processors(
policy configuration type.
"""
if pretrained_path:
if isinstance(policy_cfg, GrootConfig):
from .groot.configuration_groot import is_raw_groot_n1_7_checkpoint
if is_raw_groot_n1_7_checkpoint(pretrained_path):
from .groot.processor_groot import make_groot_pre_post_processors
processor_cfg = copy(policy_cfg)
processor_cfg.base_model_path = str(pretrained_path)
return make_groot_pre_post_processors(
config=processor_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
# TODO(Steven): Temporary patch, implement correctly the processors for Gr00t
if isinstance(policy_cfg, GrootConfig):
# GROOT handles normalization in groot_pack_inputs_v3 step
# GROOT handles normalization in its pack-inputs step
# Need to override both stats AND normalize_min_max since saved config might be empty
preprocessor_overrides = {}
postprocessor_overrides = {}
preprocessor_overrides["groot_pack_inputs_v3"] = {
"stats": kwargs.get("dataset_stats"),
"normalize_min_max": True,
}
dataset_stats = kwargs.get("dataset_stats")
preprocessor_overrides = dict(kwargs.get("preprocessor_overrides", {}))
postprocessor_overrides = dict(kwargs.get("postprocessor_overrides", {}))
pack_inputs_key = (
"groot_n1_7_pack_inputs_v1"
if policy_cfg.model_version == GROOT_N1_7
else "groot_pack_inputs_v3"
)
pack_input_overrides = dict(preprocessor_overrides.get(pack_inputs_key, {}))
pack_input_overrides["normalize_min_max"] = True
if dataset_stats is not None and policy_cfg.model_version != GROOT_N1_7:
pack_input_overrides["stats"] = dataset_stats
preprocessor_overrides[pack_inputs_key] = pack_input_overrides
# Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats
env_action_dim = policy_cfg.output_features[ACTION].shape[0]
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = {
"stats": kwargs.get("dataset_stats"),
"normalize_min_max": True,
"env_action_dim": env_action_dim,
}
action_unpack_overrides = dict(
postprocessor_overrides.get("groot_action_unpack_unnormalize_v1", {})
)
action_unpack_overrides["normalize_min_max"] = True
action_unpack_overrides["env_action_dim"] = env_action_dim
if dataset_stats is not None and policy_cfg.model_version != GROOT_N1_7:
action_unpack_overrides["stats"] = dataset_stats
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = action_unpack_overrides
kwargs["preprocessor_overrides"] = preprocessor_overrides
kwargs["postprocessor_overrides"] = postprocessor_overrides

View File

@@ -18,4 +18,12 @@ from .configuration_groot import GrootConfig
from .modeling_groot import GrootPolicy
from .processor_groot import make_groot_pre_post_processors
__all__ = ["GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
__all__ = ["GR00TN17", "GR00TN17Config", "GrootConfig", "GrootPolicy", "make_groot_pre_post_processors"]
def __getattr__(name: str):
if name in {"GR00TN17", "GR00TN17Config"}:
from .groot_n1_7 import GR00TN17, GR00TN17Config
return {"GR00TN17": GR00TN17, "GR00TN17Config": GR00TN17Config}[name]
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -181,8 +181,7 @@ class BasicTransformerBlock(nn.Module):
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
# encoder_attention_mask=encoder_attention_mask,
attention_mask=encoder_attention_mask if encoder_hidden_states is not None else attention_mask,
)
if self.final_dropout:
attn_output = self.final_dropout(attn_output)
@@ -318,6 +317,71 @@ class DiT(ModelMixin, ConfigMixin):
return self.proj_out_2(hidden_states)
class AlternateVLDiT(DiT):
"""N1.7 DiT variant that alternates cross-attention over image and text tokens."""
def __init__(self, *args, attend_text_every_n_blocks: int = 2, **kwargs):
super().__init__(*args, **kwargs)
self.attend_text_every_n_blocks = attend_text_every_n_blocks
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor | None = None,
encoder_attention_mask: torch.Tensor | None = None,
return_all_hidden_states: bool = False,
image_mask: torch.Tensor | None = None,
backbone_attention_mask: torch.Tensor | None = None,
):
if image_mask is None:
raise ValueError("image_mask is required for AlternateVLDiT.")
if backbone_attention_mask is None:
raise ValueError("backbone_attention_mask is required for AlternateVLDiT.")
temb = self.timestep_encoder(timestep)
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
image_attention_mask = image_mask & backbone_attention_mask
non_image_attention_mask = (~image_mask) & backbone_attention_mask
all_hidden_states = [hidden_states]
if not self.config.interleave_self_attention:
raise ValueError("AlternateVLDiT requires interleave_self_attention=True.")
for idx, block in enumerate(self.transformer_blocks):
if idx % 2 == 1:
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
temb=temb,
)
else:
curr_encoder_attention_mask = (
non_image_attention_mask
if idx % (2 * self.attend_text_every_n_blocks) == 0
else image_attention_mask
)
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=curr_encoder_attention_mask,
temb=temb,
)
all_hidden_states.append(hidden_states)
conditioning = temb
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
if return_all_hidden_states:
return self.proj_out_2(hidden_states), all_hidden_states
return self.proj_out_2(hidden_states)
class SelfAttentionTransformer(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True

View File

@@ -14,12 +14,295 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
GROOT_N1_5 = "n1.5"
GROOT_N1_7 = "n1.7"
GROOT_N1_5_BASE_MODEL = "nvidia/GR00T-N1.5-3B"
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
_GROOT_MODEL_VERSION_ALIASES = {
"n1.5": GROOT_N1_5,
"n1_5": GROOT_N1_5,
"n15": GROOT_N1_5,
"1.5": GROOT_N1_5,
"n1.7": GROOT_N1_7,
"n1_7": GROOT_N1_7,
"n1d7": GROOT_N1_7,
"n17": GROOT_N1_7,
"1.7": GROOT_N1_7,
}
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
"none": None,
"": None,
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
}
def normalize_groot_model_version(model_version: str) -> str:
normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
if normalized is None:
supported = ", ".join(sorted({GROOT_N1_5, GROOT_N1_7}))
raise ValueError(
f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
)
return normalized
def normalize_groot_action_decode_transform(transform: str | None) -> str | None:
if transform is None:
return None
normalized = _GROOT_ACTION_DECODE_TRANSFORM_ALIASES.get(transform.lower())
if normalized is None and transform.lower() not in _GROOT_ACTION_DECODE_TRANSFORM_ALIASES:
supported = ", ".join(
sorted(key for key, value in _GROOT_ACTION_DECODE_TRANSFORM_ALIASES.items() if value is not None)
)
raise ValueError(
f"Unsupported GR00T N1.7 action decode transform '{transform}'. "
f"Supported transforms: none, {supported}."
)
return normalized
def infer_groot_model_version(model_path: str | None) -> str | None:
if not model_path:
return None
model_path_lower = model_path.lower()
if "gr00t-n1.7" in model_path_lower or "gr00t_n1.7" in model_path_lower:
return GROOT_N1_7
if "gr00t-n1.5" in model_path_lower or "gr00t_n1.5" in model_path_lower:
return GROOT_N1_5
config_version = _infer_groot_model_version_from_local_config(model_path)
if config_version is not None:
return config_version
return None
def is_raw_groot_n1_7_checkpoint(model_path: str | Path | None) -> bool:
if model_path is None:
return False
path = Path(model_path).expanduser()
if path.is_dir():
config_path = path / "config.json"
elif path.name == "config.json":
config_path = path
else:
return False
try:
with config_path.open() as f:
config = json.load(f)
except (OSError, json.JSONDecodeError):
return False
return "type" not in config and _infer_groot_model_version_from_config(config) == GROOT_N1_7
def infer_groot_n1_7_embodiment_tag(model_path: str | Path | None) -> str | None:
if model_path is None:
return None
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
try:
with processor_config_path.open() as f:
processor_config = json.load(f)
except (OSError, json.JSONDecodeError):
return None
modality_configs = processor_config.get("processor_kwargs", {}).get("modality_configs", {})
if not isinstance(modality_configs, dict):
return None
if "libero_sim" in modality_configs:
return "libero_sim"
if len(modality_configs) == 1:
return next(iter(modality_configs))
return None
def infer_groot_n1_7_action_horizon(
model_path: str | Path | None, embodiment_tag: str | None = None
) -> int | None:
if model_path is None:
return None
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
try:
with processor_config_path.open() as f:
processor_config = json.load(f)
except (OSError, json.JSONDecodeError):
return None
processor_kwargs = processor_config.get("processor_kwargs", {})
if not isinstance(processor_kwargs, dict):
return None
modality_configs = processor_kwargs.get("modality_configs", {})
if not isinstance(modality_configs, dict):
return None
if embodiment_tag is None:
embodiment_tag = infer_groot_n1_7_embodiment_tag(model_path)
if embodiment_tag is None:
return None
embodiment_config = modality_configs.get(embodiment_tag, {})
if not isinstance(embodiment_config, dict):
return None
action_config = embodiment_config.get("action", {})
if not isinstance(action_config, dict):
return None
delta_indices = action_config.get("delta_indices", [])
if not isinstance(delta_indices, list):
return None
return len(delta_indices) or None
def infer_groot_n1_7_action_execution_horizon(
model_path: str | Path | None, embodiment_tag: str | None = None
) -> int | None:
action_horizon = infer_groot_n1_7_action_horizon(model_path, embodiment_tag)
if action_horizon is None:
return None
if embodiment_tag is None:
embodiment_tag = infer_groot_n1_7_embodiment_tag(model_path)
if embodiment_tag == "libero_sim":
# NVIDIA's N1.7 LIBERO rollout wrapper replans after 8 of the 16 decoded
# actions. Keeping that execution cadence avoids stale open-loop chunks.
return min(action_horizon, 8)
return action_horizon
def resolve_groot_n1_7_backbone_model(model_name: str, cache_dir: str | Path | None = None) -> str:
model_path = Path(model_name).expanduser()
if model_path.exists():
return str(model_path)
cached_snapshot = _find_cached_hf_snapshot(model_name, cache_dir=cache_dir)
return str(cached_snapshot) if cached_snapshot is not None else model_name
def _find_cached_hf_snapshot(repo_id: str, cache_dir: str | Path | None = None) -> Path | None:
repo_cache_name = f"models--{repo_id.replace('/', '--')}"
required_files = (
"config.json",
"tokenizer_config.json",
"preprocessor_config.json",
"video_preprocessor_config.json",
)
for hub_cache in _candidate_hf_hub_caches(cache_dir):
repo_cache = hub_cache / repo_cache_name
snapshots_dir = repo_cache / "snapshots"
if not snapshots_dir.is_dir():
continue
candidates: list[Path] = []
ref_path = repo_cache / "refs" / "main"
try:
ref = ref_path.read_text().strip()
except OSError:
ref = ""
if ref:
candidates.append(snapshots_dir / ref)
candidates.extend(
sorted(
(path for path in snapshots_dir.iterdir() if path.is_dir()),
key=lambda path: path.stat().st_mtime,
reverse=True,
)
)
seen: set[Path] = set()
for snapshot in candidates:
if snapshot in seen:
continue
seen.add(snapshot)
if all((snapshot / filename).exists() for filename in required_files):
return snapshot
return None
def _candidate_hf_hub_caches(cache_dir: str | Path | None) -> list[Path]:
candidates: list[Path] = []
if cache_dir is not None:
cache_path = Path(cache_dir).expanduser()
candidates.append(cache_path)
candidates.append(cache_path / "hub")
hub_cache = os.environ.get("HUGGINGFACE_HUB_CACHE")
if hub_cache:
candidates.append(Path(hub_cache).expanduser())
hf_home = os.environ.get("HF_HOME")
if hf_home:
candidates.append(Path(hf_home).expanduser() / "hub")
candidates.append(Path.home() / ".cache" / "huggingface" / "hub")
deduped: list[Path] = []
seen: set[Path] = set()
for candidate in candidates:
resolved = candidate.resolve() if candidate.exists() else candidate
if resolved not in seen:
seen.add(resolved)
deduped.append(candidate)
return deduped
def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
path = Path(model_path).expanduser()
if path.is_dir():
config_path = path / "config.json"
elif path.name == "config.json":
config_path = path
else:
return None
if not config_path.exists():
return None
try:
with config_path.open() as f:
config = json.load(f)
except (OSError, json.JSONDecodeError):
return None
return _infer_groot_model_version_from_config(config)
def _infer_groot_model_version_from_config(config: dict) -> str | None:
model_version = config.get("model_version")
if isinstance(model_version, str):
try:
return normalize_groot_model_version(model_version)
except ValueError:
return None
candidates = [config.get("model_type"), *(config.get("architectures") or [])]
for candidate in candidates:
if not isinstance(candidate, str):
continue
normalized = candidate.lower().replace("-", "_")
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
return GROOT_N1_7
if normalized in {"gr00t_n1_5", "gr00tn15", "gr00t_n1d5"}:
return GROOT_N1_5
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
return GROOT_N1_7
return None
@PreTrainedConfig.register_subclass("groot")
@dataclass
@@ -52,12 +335,21 @@ class GrootConfig(PreTrainedConfig):
# Groot-specific model parameters (from groot_finetune_script.py)
# Explicit GR00T model family selection. Defaults to N1.5 to preserve existing behavior.
model_version: str = GROOT_N1_5
# Path or HuggingFace model ID for the base Groot model
base_model_path: str = "nvidia/GR00T-N1.5-3B"
base_model_path: str | None = None
# HF repo ID (or local path) that hosts vocab.json and merges.txt for Eagle tokenizer.
tokenizer_assets_repo: str = "lerobot/eagle2hg-processor-groot-n1p5"
# HF repo ID (or local path) for the GR00T N1.7 Cosmos/Qwen3-VL backbone processor.
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
action_decode_transform: str | None = None
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
embodiment_tag: str = "new_embodiment"
@@ -117,6 +409,35 @@ class GrootConfig(PreTrainedConfig):
resume: bool = False
def __post_init__(self):
self.model_version = normalize_groot_model_version(self.model_version)
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
if self.base_model_path is None:
self.base_model_path = (
GROOT_N1_7_BASE_MODEL if self.model_version == GROOT_N1_7 else GROOT_N1_5_BASE_MODEL
)
if self.action_decode_transform is not None and self.model_version != GROOT_N1_7:
raise ValueError("action_decode_transform can only be used with model_version='n1.7'.")
if self.model_version == GROOT_N1_7:
if self.max_state_dim == 64:
self.max_state_dim = 132
if self.max_action_dim == 32:
self.max_action_dim = 132
if self.chunk_size == 50:
self.chunk_size = 40
if self.n_action_steps == 50:
self.n_action_steps = 40
if tuple(self.image_size) == (224, 224):
self.image_size = (256, 256)
inferred_version = infer_groot_model_version(self.base_model_path)
if inferred_version is not None and inferred_version != self.model_version:
raise ValueError(
f"GR00T model_version '{self.model_version}' does not match base_model_path "
f"'{self.base_model_path}', which looks like '{inferred_version}'."
)
super().__post_init__()
if self.n_action_steps > self.chunk_size:
@@ -192,7 +513,12 @@ class GrootConfig(PreTrainedConfig):
@property
def action_delta_indices(self) -> list[int]:
"""Return indices for delta actions."""
return list(range(min(self.chunk_size, 16)))
model_action_horizon = 16
if self.model_version == GROOT_N1_7:
model_action_horizon = (
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
)
return list(range(min(self.chunk_size, model_action_horizon)))
@property
def reward_delta_indices(self) -> None:

View File

@@ -0,0 +1,962 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import importlib
import json
import logging
from contextlib import suppress
from copy import deepcopy
from typing import TYPE_CHECKING, Any
import torch
import torch.nn.functional as F # noqa: N812
from huggingface_hub import snapshot_download
from huggingface_hub.errors import HFValidationError, RepositoryNotFoundError
from torch import nn
from torch.distributions import Beta
from lerobot.utils.import_utils import _transformers_available, require_package
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
if TYPE_CHECKING or _transformers_available:
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.feature_extraction_utils import BatchFeature
else:
AutoConfig = None
AutoModel = None
PretrainedConfig = object
PreTrainedModel = object
BatchFeature = None
try:
import tree
except ImportError:
tree = None
try:
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration
except ImportError:
Qwen3VLConfig = None
Qwen3VLForConditionalGeneration = None
logger = logging.getLogger(__name__)
def _copy_default(value: Any) -> Any:
return deepcopy(value)
GR00T_N1_7_DEFAULTS: dict[str, Any] = {
"model_dtype": "bfloat16",
"dtype": "bfloat16",
"model_name": "nvidia/Cosmos-Reason2-2B",
"backbone_model_type": "qwen",
"model_revision": None,
"tune_top_llm_layers": 0,
"backbone_embedding_dim": 2048,
"tune_llm": False,
"tune_visual": False,
"select_layer": 12,
"reproject_vision": False,
"use_flash_attention": True,
"load_bf16": False,
"backbone_trainable_params_fp32": True,
"image_crop_size": (230, 230),
"image_target_size": (256, 256),
"shortest_image_edge": None,
"crop_fraction": None,
"random_rotation_angle": None,
"color_jitter_params": None,
"use_albumentations_transforms": True,
"extra_augmentation_config": None,
"formalize_language": True,
"apply_sincos_state_encoding": False,
"use_percentiles": True,
"use_relative_action": False,
"max_state_dim": 132,
"max_action_dim": 132,
"action_horizon": 40,
"hidden_size": 1024,
"input_embedding_dim": 1536,
"state_history_length": 1,
"add_pos_embed": True,
"attn_dropout": 0.2,
"use_vlln": True,
"max_seq_len": 1024,
"use_alternate_vl_dit": True,
"attend_text_every_n_blocks": 2,
"diffusion_model_cfg": {
"positional_embeddings": None,
"num_layers": 32,
"num_attention_heads": 32,
"attention_head_dim": 48,
"norm_type": "ada_norm",
"dropout": 0.2,
"final_dropout": True,
"output_dim": 1024,
"interleave_self_attention": True,
},
"vl_self_attention_cfg": {
"positional_embeddings": None,
"num_layers": 4,
"num_attention_heads": 32,
"attention_head_dim": 64,
"dropout": 0.2,
"final_dropout": True,
},
"num_inference_timesteps": 4,
"noise_beta_alpha": 1.5,
"noise_beta_beta": 1.0,
"noise_s": 0.999,
"num_timestep_buckets": 1000,
"tune_projector": True,
"tune_diffusion_model": True,
"tune_vlln": True,
"state_dropout_prob": 0.2,
"exclude_state": False,
"use_mean_std": False,
"max_num_embodiments": 32,
"rtc_ramp_rate": 6.0,
}
class GR00TN17Config(PretrainedConfig):
"""Configuration for NVIDIA GR00T N1.7.
N1.7 uses the Cosmos-Reason2-2B / Qwen3-VL backbone and a multi-embodiment
flow-matching action head. This mirrors the public N1.7 checkpoint config
while keeping it local to LeRobot and independent from the external
Isaac-GR00T ``gr00t`` Python package.
"""
model_type = "Gr00tN1d7"
_defaults = GR00T_N1_7_DEFAULTS
def __init__(self, **kwargs):
super().__init__(**kwargs)
for key, value in GR00T_N1_7_DEFAULTS.items():
setattr(self, key, _copy_default(kwargs.pop(key, value)))
for key, value in kwargs.items():
setattr(self, key, value)
def to_filtered_dict(self, exclude_augment: bool = True) -> dict[str, Any]:
cfg = self.to_dict()
if not exclude_augment:
return cfg
exclude_keys = {
"random_rotation_angle",
"color_jitter_params",
"use_albumentations_transforms",
"formalize_language",
"image_crop_size",
"image_target_size",
"shortest_image_edge",
"crop_fraction",
}
return {k: v for k, v in cfg.items() if k not in exclude_keys}
def to_filtered_json(self, exclude_augment: bool = True, **kwargs) -> str:
return json.dumps(self.to_filtered_dict(exclude_augment), indent=2, default=str, **kwargs)
class CategorySpecificLinear(nn.Module):
"""Linear layer with category-specific weights for multi-embodiment support."""
def __init__(self, num_categories: int, input_dim: int, hidden_dim: int):
super().__init__()
self.num_categories = num_categories
self.W = nn.Parameter(0.02 * torch.randn(num_categories, input_dim, hidden_dim))
self.b = nn.Parameter(torch.zeros(num_categories, hidden_dim))
def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
selected_w = self.W[cat_ids]
selected_b = self.b[cat_ids]
return torch.bmm(x, selected_w) + selected_b.unsqueeze(1)
class CategorySpecificMLP(nn.Module):
"""Two-layer MLP with category-specific weights."""
def __init__(self, num_categories: int, input_dim: int, hidden_dim: int, output_dim: int):
super().__init__()
self.layer1 = CategorySpecificLinear(num_categories, input_dim, hidden_dim)
self.layer2 = CategorySpecificLinear(num_categories, hidden_dim, output_dim)
def forward(self, x: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
hidden = F.relu(self.layer1(x, cat_ids))
return self.layer2(hidden, cat_ids)
class SinusoidalPositionalEncoding(nn.Module):
"""Sinusoidal encoding of shape ``(B, T, D)`` for timestep tensors ``(B, T)``.
The frequency scalar is intentionally created on CPU and then broadcast with
the device-local arange result. That mirrors Isaac-GR00T's N1.7 timestep
embedding and avoids tiny dtype/device construction differences in parity
tests.
"""
def __init__(self, embedding_dim: int):
super().__init__()
self.embedding_dim = embedding_dim
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
timesteps = timesteps.float()
half_dim = self.embedding_dim // 2
exponent = -torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * (
torch.log(torch.tensor(10000.0)) / half_dim
)
freqs = timesteps.unsqueeze(-1) * exponent.exp()
return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1)
def swish(x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)
class MultiEmbodimentActionEncoder(nn.Module):
"""Action encoder with category-specific projections and sinusoidal time encoding."""
def __init__(self, action_dim: int, hidden_size: int, num_embodiments: int):
super().__init__()
self.W1 = CategorySpecificLinear(num_embodiments, action_dim, hidden_size)
self.W2 = CategorySpecificLinear(num_embodiments, 2 * hidden_size, hidden_size)
self.W3 = CategorySpecificLinear(num_embodiments, hidden_size, hidden_size)
self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)
def forward(self, actions: torch.Tensor, timesteps: torch.Tensor, cat_ids: torch.Tensor) -> torch.Tensor:
batch_size, horizon, _ = actions.shape
if timesteps.dim() != 1 or timesteps.shape[0] != batch_size:
raise ValueError("Expected `timesteps` to have shape (B,).")
timesteps = timesteps.unsqueeze(1).expand(-1, horizon)
action_emb = self.W1(actions, cat_ids)
time_emb = self.pos_encoding(timesteps).to(dtype=action_emb.dtype)
x = swish(self.W2(torch.cat([action_emb, time_emb], dim=-1), cat_ids))
return self.W3(x, cat_ids)
class Qwen3Backbone(nn.Module):
"""Cosmos-Reason2/Qwen3-VL backbone used by GR00T N1.7.
The public checkpoint stores the action head in the GR00T checkpoint but
uses a Hugging Face Qwen3-VL-compatible backbone interface. This wrapper
keeps the nested HF module layout compatible across transformer versions
and exposes the hidden states consumed by the action head.
"""
def __init__(
self,
model_name: str = "nvidia/Cosmos-Reason2-2B",
tune_llm: bool = False,
tune_visual: bool = False,
select_layer: int = -1,
reproject_vision: bool = False,
use_flash_attention: bool = False,
load_bf16: bool = False,
tune_top_llm_layers: int = 0,
trainable_params_fp32: bool = False,
transformers_loading_kwargs: dict[str, Any] | None = None,
load_pretrained_weights: bool = True,
):
if Qwen3VLForConditionalGeneration is None:
raise ImportError(
"Qwen3VLForConditionalGeneration is required for GR00T N1.7. "
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'` "
"or use a transformers version that provides Qwen3-VL support."
)
super().__init__()
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
extra_kwargs: dict[str, Any] = {}
if use_flash_attention:
try:
import flash_attn # noqa: F401
extra_kwargs["attn_implementation"] = "flash_attention_2"
except ImportError:
logger.warning("flash_attn is not installed. Falling back to SDPA attention.")
extra_kwargs["attn_implementation"] = "sdpa"
if load_bf16:
extra_kwargs["torch_dtype"] = torch.bfloat16
if load_pretrained_weights:
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
model_name,
**extra_kwargs,
**transformers_loading_kwargs,
).eval()
else:
self.model = self._from_backbone_config(
model_name=model_name,
model_kwargs=extra_kwargs,
config_kwargs=transformers_loading_kwargs,
).eval()
while len(self.language_model.layers) > select_layer:
self.language_model.layers.pop(-1)
self.select_layer = select_layer
self.set_trainable_parameters(tune_llm, tune_visual, tune_top_llm_layers)
if load_bf16 and trainable_params_fp32:
for parameter in self.parameters():
if parameter.requires_grad:
parameter.data = parameter.data.to(torch.float32)
def set_trainable_parameters(
self, tune_llm: bool, tune_visual: bool, tune_top_llm_layers: int = 0
) -> None:
self.tune_llm = tune_llm
self.tune_visual = tune_visual
for parameter in self.parameters():
parameter.requires_grad = True
if not tune_llm:
self.language_model.requires_grad_(False)
if not tune_visual:
self.visual.requires_grad_(False)
if tune_top_llm_layers > 0:
for layer in self.language_model.layers[-tune_top_llm_layers:]:
for parameter in layer.parameters():
parameter.requires_grad = True
def set_frozen_modules_to_eval_mode(self) -> None:
if self.training:
if self.language_model and not self.tune_llm:
self.language_model.eval()
if self.visual and not self.tune_visual:
self.visual.eval()
@property
def language_model(self) -> nn.Module:
return getattr(self.model, "model", self.model).language_model
@property
def visual(self) -> nn.Module:
return getattr(self.model, "model", self.model).visual
def _from_backbone_config(
self,
*,
model_name: str,
model_kwargs: dict[str, Any],
config_kwargs: dict[str, Any],
) -> nn.Module:
if _is_cosmos_reason2_backbone(model_name):
backbone_config = _cosmos_reason2_qwen3_vl_config()
else:
if AutoConfig is None:
raise ImportError(
"AutoConfig is required to initialize a GR00T N1.7 backbone from config. "
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
)
backbone_config = AutoConfig.from_pretrained(model_name, **config_kwargs)
return Qwen3VLForConditionalGeneration._from_config(backbone_config, **model_kwargs)
def prepare_input(self, batch: dict[str, Any]) -> BatchFeature:
return BatchFeature(data=batch)
def _ensure_mm_token_type_ids(self, model_input: dict[str, torch.Tensor]) -> None:
if "mm_token_type_ids" in model_input:
return
if "image_grid_thw" not in model_input and "video_grid_thw" not in model_input:
return
input_ids = model_input.get("input_ids")
if input_ids is None:
return
mm_token_type_ids = torch.zeros(input_ids.shape, dtype=torch.int32, device=input_ids.device)
image_token_id = getattr(self.model.config, "image_token_id", None)
video_token_id = getattr(self.model.config, "video_token_id", None)
if image_token_id is not None:
mm_token_type_ids[input_ids == image_token_id] = 1
if video_token_id is not None:
mm_token_type_ids[input_ids == video_token_id] = 2
model_input["mm_token_type_ids"] = mm_token_type_ids
def _ensure_legacy_qwen3_position_ids(self, model_input: dict[str, torch.Tensor]) -> None:
"""Restore the Qwen3-VL text position ids used by older Transformers releases.
Transformers 5.x computes 3-row multimodal RoPE ids for Qwen3-VL and then
drops text position ids before calling text-layer flash attention. GR00T
N1.7 was aligned against the older Transformers path, where a fourth text
position row is forwarded alongside the temporal/height/width rows. Adding
the row here preserves the newer multimodal position computation while
keeping flash attention on the legacy code path.
"""
if "position_ids" in model_input:
return
qwen3_model = getattr(self.model, "model", self.model)
compute_3d_position_ids = getattr(qwen3_model, "compute_3d_position_ids", None)
if compute_3d_position_ids is None:
return
position_ids = compute_3d_position_ids(
input_ids=model_input.get("input_ids"),
image_grid_thw=model_input.get("image_grid_thw"),
video_grid_thw=model_input.get("video_grid_thw"),
inputs_embeds=None,
attention_mask=model_input.get("attention_mask"),
past_key_values=None,
mm_token_type_ids=model_input.get("mm_token_type_ids"),
)
if position_ids.ndim == 3 and position_ids.shape[0] == 3:
position_ids = torch.cat([position_ids[:1], position_ids], dim=0)
model_input["position_ids"] = position_ids
def _last_decoder_layer_output(self, model_input: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return the pre-final-norm decoder output consumed by the N1.7 action head.
Older Transformers releases exposed this tensor as ``hidden_states[-1]``.
Newer releases expose the post-final-norm tensor there instead. Capturing
the last decoder layer output directly keeps the N1.7 action head input
stable across Transformers versions.
"""
captured: dict[str, torch.Tensor] = {}
def capture_output(_module: nn.Module, _inputs: tuple[Any, ...], output: Any) -> None:
if isinstance(output, torch.Tensor):
captured["features"] = output
elif isinstance(output, (tuple, list)) and output:
captured["features"] = output[0]
elif hasattr(output, "last_hidden_state"):
captured["features"] = output.last_hidden_state
hook = self.language_model.layers[-1].register_forward_hook(capture_output)
try:
outputs = self.model(**model_input, output_hidden_states=True)
finally:
hook.remove()
return captured.get("features", outputs.hidden_states[-1])
def forward(self, vl_input: BatchFeature) -> BatchFeature:
self.set_frozen_modules_to_eval_mode()
keys_to_use = ["input_ids", "attention_mask", "pixel_values", "image_grid_thw"]
optional_keys = ["mm_token_type_ids", "pixel_values_videos", "video_grid_thw"]
model_input = {key: vl_input[key] for key in keys_to_use}
model_input.update({key: vl_input[key] for key in optional_keys if key in vl_input})
self._ensure_mm_token_type_ids(model_input)
self._ensure_legacy_qwen3_position_ids(model_input)
features = self._last_decoder_layer_output(model_input)
image_mask = model_input["input_ids"] == self.model.config.image_token_id
attention_mask = model_input["attention_mask"] == 1
return BatchFeature(
data={
"backbone_features": features,
"backbone_attention_mask": attention_mask,
"image_mask": image_mask,
}
)
class GR00TN17ActionHead(nn.Module):
supports_gradient_checkpointing = True
def __init__(self, config: GR00TN17Config):
require_package("diffusers", extra="groot")
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.input_embedding_dim = config.input_embedding_dim
if config.use_alternate_vl_dit:
self.model = AlternateVLDiT(
**config.diffusion_model_cfg,
cross_attention_dim=config.backbone_embedding_dim,
attend_text_every_n_blocks=config.attend_text_every_n_blocks,
)
else:
self.model = DiT(
**config.diffusion_model_cfg,
cross_attention_dim=config.backbone_embedding_dim,
)
self.action_dim = config.max_action_dim
self.action_horizon = config.action_horizon
self.num_inference_timesteps = config.num_inference_timesteps
self.state_encoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=config.max_state_dim * config.state_history_length,
hidden_dim=self.hidden_size,
output_dim=self.input_embedding_dim,
)
self.action_encoder = MultiEmbodimentActionEncoder(
action_dim=self.action_dim,
hidden_size=self.input_embedding_dim,
num_embodiments=config.max_num_embodiments,
)
self.action_decoder = CategorySpecificMLP(
num_categories=config.max_num_embodiments,
input_dim=self.hidden_size,
hidden_dim=self.hidden_size,
output_dim=self.action_dim,
)
self.vlln = nn.LayerNorm(config.backbone_embedding_dim) if config.use_vlln else nn.Identity()
vl_self_attention_cfg = getattr(config, "vl_self_attention_cfg", None)
if vl_self_attention_cfg and vl_self_attention_cfg.get("num_layers", 0) > 0:
self.vl_self_attention = SelfAttentionTransformer(**vl_self_attention_cfg)
else:
self.vl_self_attention = nn.Identity()
if config.add_pos_embed:
self.position_embedding = nn.Embedding(config.max_seq_len, self.input_embedding_dim)
nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
self.state_dropout_prob = config.state_dropout_prob
self._noise_beta_alpha = config.noise_beta_alpha
self._noise_beta_beta = config.noise_beta_beta
self._beta_dist = None
self.num_timestep_buckets = config.num_timestep_buckets
self.set_trainable_parameters(config.tune_projector, config.tune_diffusion_model, config.tune_vlln)
def set_trainable_parameters(
self, tune_projector: bool, tune_diffusion_model: bool, tune_vlln: bool
) -> None:
self.tune_projector = tune_projector
self.tune_diffusion_model = tune_diffusion_model
self.tune_vlln = tune_vlln
for parameter in self.parameters():
parameter.requires_grad = True
if not tune_projector:
self.state_encoder.requires_grad_(False)
self.action_encoder.requires_grad_(False)
self.action_decoder.requires_grad_(False)
if self.config.add_pos_embed:
self.position_embedding.requires_grad_(False)
if not tune_diffusion_model:
self.model.requires_grad_(False)
if not tune_vlln:
self.vlln.requires_grad_(False)
self.vl_self_attention.requires_grad_(False)
def set_frozen_modules_to_eval_mode(self) -> None:
if self.training:
if not self.tune_projector:
self.state_encoder.eval()
self.action_encoder.eval()
self.action_decoder.eval()
if self.config.add_pos_embed:
self.position_embedding.eval()
if not self.tune_diffusion_model:
self.model.eval()
if not self.tune_vlln:
self.vlln.eval()
self.vl_self_attention.eval()
def sample_time(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
if self._beta_dist is None:
beta_alpha = torch.tensor(self._noise_beta_alpha, device="cpu", dtype=torch.float32)
beta_beta = torch.tensor(self._noise_beta_beta, device="cpu", dtype=torch.float32)
self._beta_dist = Beta(beta_alpha, beta_beta, validate_args=False)
sample = self._beta_dist.sample([batch_size]).to(device, dtype=dtype)
return (1 - sample) * self.config.noise_s
def process_backbone_output(self, backbone_output: BatchFeature) -> BatchFeature:
backbone_features = self.vlln(backbone_output["backbone_features"])
backbone_output["backbone_features"] = self.vl_self_attention(backbone_features)
return backbone_output
def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
self.set_frozen_modules_to_eval_mode()
backbone_output = self.process_backbone_output(backbone_output)
vl_embeds = backbone_output.backbone_features
device = vl_embeds.device
embodiment_id = action_input.embodiment_id
if action_input.state.shape[1] != self.config.state_history_length:
raise ValueError("state history length does not match GR00T N1.7 config.")
state = action_input.state.view(action_input.state.shape[0], 1, -1)
state_features = self.state_encoder(state, embodiment_id)
if self.training and self.state_dropout_prob > 0:
do_dropout = (
torch.rand(state_features.shape[0], device=state_features.device) < self.state_dropout_prob
)
state_features = state_features * (1 - do_dropout[:, None, None].to(dtype=state_features.dtype))
actions = action_input.action
noise = torch.randn(actions.shape, device=actions.device, dtype=actions.dtype)
t = self.sample_time(actions.shape[0], device=actions.device, dtype=actions.dtype)
t = t[:, None, None]
noisy_trajectory = (1 - t) * noise + t * actions
velocity = actions - noise
t_discretized = (t[:, 0, 0] * self.num_timestep_buckets).long()
action_features = self.action_encoder(noisy_trajectory, t_discretized, embodiment_id)
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0)
sa_embs = torch.cat((state_features, action_features), dim=1)
if self.config.use_alternate_vl_dit:
model_output, _ = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embeds,
encoder_attention_mask=backbone_output.backbone_attention_mask,
timestep=t_discretized,
return_all_hidden_states=True,
image_mask=backbone_output.image_mask,
backbone_attention_mask=backbone_output.backbone_attention_mask,
)
else:
model_output, _ = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embeds,
encoder_attention_mask=backbone_output.backbone_attention_mask,
timestep=t_discretized,
return_all_hidden_states=True,
)
pred = self.action_decoder(model_output, embodiment_id)
pred_actions = pred[:, -actions.shape[1] :]
action_mask = action_input.action_mask.to(dtype=pred_actions.dtype)
action_loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
loss = action_loss.sum() / (action_mask.sum() + 1e-6)
return BatchFeature(
data={
"loss": loss,
"action_loss": action_loss,
"action_mask": action_mask,
"backbone_features": vl_embeds,
"state_features": state_features,
}
)
def _encode_features(self, backbone_output: BatchFeature, action_input: BatchFeature) -> BatchFeature:
backbone_output = self.process_backbone_output(backbone_output)
state = action_input.state
if state.shape[1] != self.config.state_history_length:
raise ValueError("state history length does not match GR00T N1.7 config.")
state = state.view(state.shape[0], 1, -1)
state_features = self.state_encoder(state, action_input.embodiment_id)
return BatchFeature(
data={"backbone_features": backbone_output.backbone_features, "state_features": state_features}
)
@torch.no_grad()
def get_action_with_features(
self,
backbone_features: torch.Tensor,
state_features: torch.Tensor,
embodiment_id: torch.Tensor,
backbone_output: BatchFeature,
action_input: BatchFeature,
options: dict[str, Any] | None = None,
) -> BatchFeature:
vl_embeds = backbone_features
batch_size = vl_embeds.shape[0]
device = vl_embeds.device
actions = torch.randn(
size=(batch_size, self.config.action_horizon, self.action_dim),
dtype=vl_embeds.dtype,
device=device,
)
dt = 1.0 / self.num_inference_timesteps
vel_strength = torch.ones_like(actions)
if "action" in action_input:
if options is None:
raise ValueError("RTC options are required when action is provided to get_action.")
action_horizon_before_padding = options["action_horizon"]
actions[:, : options["rtc_overlap_steps"], :] = action_input["action"][
:,
action_horizon_before_padding - options["rtc_overlap_steps"] : action_horizon_before_padding,
:,
]
vel_strength[:, : options["rtc_frozen_steps"], :] = 0.0
intermediate_steps = options["rtc_overlap_steps"] - options["rtc_frozen_steps"]
t = torch.linspace(0.0, 1.0, intermediate_steps + 2, device=device)
ramp = 1 - torch.exp(-options["rtc_ramp_rate"] * t)
ramp = ramp / ramp[-1].clamp_min(1e-8)
vel_strength[:, options["rtc_frozen_steps"] : options["rtc_overlap_steps"], :] = ramp[1:-1][
None, :, None
].to(device)
for t_step in range(self.num_inference_timesteps):
t_cont = t_step / float(self.num_inference_timesteps)
t_discretized = int(t_cont * self.num_timestep_buckets)
timesteps_tensor = torch.full(size=(batch_size,), fill_value=t_discretized, device=device)
action_features = self.action_encoder(actions, timesteps_tensor, embodiment_id)
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
action_features = action_features + self.position_embedding(pos_ids).unsqueeze(0)
sa_embs = torch.cat((state_features, action_features), dim=1)
if self.config.use_alternate_vl_dit:
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embeds,
timestep=timesteps_tensor,
image_mask=backbone_output.image_mask,
backbone_attention_mask=backbone_output.backbone_attention_mask,
)
else:
model_output = self.model(
hidden_states=sa_embs,
encoder_hidden_states=vl_embeds,
timestep=timesteps_tensor,
)
pred = self.action_decoder(model_output, embodiment_id)
actions = actions + dt * pred[:, -self.action_horizon :] * vel_strength
return BatchFeature(
data={
"action_pred": actions,
"backbone_features": vl_embeds,
"state_features": state_features,
}
)
@torch.no_grad()
def get_action(
self,
backbone_output: BatchFeature,
action_input: BatchFeature,
options: dict[str, Any] | None = None,
) -> BatchFeature:
features = self._encode_features(backbone_output, action_input)
return self.get_action_with_features(
backbone_features=features.backbone_features,
state_features=features.state_features,
embodiment_id=action_input.embodiment_id,
backbone_output=backbone_output,
action_input=action_input,
options=options,
)
@property
def device(self) -> torch.device:
return next(iter(self.parameters())).device
@property
def dtype(self) -> torch.dtype:
return next(iter(self.parameters())).dtype
def prepare_input(self, batch: dict[str, Any]) -> BatchFeature:
return BatchFeature(data=batch)
def _is_cosmos_reason2_backbone(model_name: str) -> bool:
return str(model_name).rstrip("/") == "nvidia/Cosmos-Reason2-2B"
def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
if Qwen3VLConfig is None:
raise ImportError(
"Qwen3VLConfig is required for GR00T N1.7. "
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
)
return Qwen3VLConfig(
image_token_id=151655,
video_token_id=151656,
vision_start_token_id=151652,
vision_end_token_id=151653,
tie_word_embeddings=True,
text_config={
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"dtype": "bfloat16",
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 6144,
"max_position_embeddings": 262144,
"model_type": "qwen3_vl_text",
"num_attention_heads": 16,
"num_hidden_layers": 28,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-6,
"rope_scaling": {
"mrope_interleaved": True,
"mrope_section": [24, 20, 20],
"rope_type": "default",
},
"rope_theta": 5000000,
"tie_word_embeddings": True,
"use_cache": True,
"vocab_size": 151936,
},
vision_config={
"deepstack_visual_indexes": [5, 11, 17],
"depth": 24,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1024,
"in_channels": 3,
"initializer_range": 0.02,
"intermediate_size": 4096,
"model_type": "qwen3_vl",
"num_heads": 16,
"num_position_embeddings": 2304,
"out_hidden_size": 2048,
"patch_size": 16,
"spatial_merge_size": 2,
"temporal_patch_size": 2,
},
)
def get_backbone_cls(config: GR00TN17Config):
if (
config.backbone_model_type == "qwen"
or "nvidia/Cosmos-Reason2" in config.model_name
or "Qwen/Qwen3-VL" in config.model_name
):
return Qwen3Backbone
raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
class GR00TN17(PreTrainedModel):
"""GR00T N1.7 model with a Cosmos-Reason2/Qwen3-VL backbone."""
config_class = GR00TN17Config
supports_gradient_checkpointing = True
def __init__(
self,
config: GR00TN17Config,
transformers_loading_kwargs: dict[str, Any] | None = None,
load_backbone_weights: bool = True,
):
super().__init__(config)
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
self.config = config
backbone_cls = get_backbone_cls(config)
self.backbone = backbone_cls(
model_name=config.model_name,
tune_llm=config.tune_llm,
tune_visual=config.tune_visual,
select_layer=config.select_layer,
reproject_vision=config.reproject_vision,
use_flash_attention=config.use_flash_attention,
load_bf16=config.load_bf16,
tune_top_llm_layers=config.tune_top_llm_layers,
trainable_params_fp32=config.backbone_trainable_params_fp32,
transformers_loading_kwargs=transformers_loading_kwargs,
load_pretrained_weights=load_backbone_weights,
)
self.action_head = GR00TN17ActionHead(config)
self.post_init()
def prepare_input(self, inputs: dict[str, Any]) -> tuple[BatchFeature, BatchFeature]:
global tree
if tree is None:
require_package("dm-tree", extra="groot", import_name="tree")
tree = importlib.import_module("tree")
backbone_inputs = self.backbone.prepare_input(inputs)
action_inputs = self.action_head.prepare_input(inputs)
def to_device_with_dtype(x):
if not isinstance(x, torch.Tensor):
return x
if torch.is_floating_point(x):
return x.to(self.device, dtype=self.dtype)
return x.to(self.device)
return (
tree.map_structure(to_device_with_dtype, backbone_inputs),
tree.map_structure(to_device_with_dtype, action_inputs),
)
def forward(self, inputs: dict[str, Any]) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
backbone_outputs = self.backbone(backbone_inputs)
return self.action_head(backbone_outputs, action_inputs)
def get_action(self, inputs: dict[str, Any], options: dict[str, Any] | None = None) -> BatchFeature:
backbone_inputs, action_inputs = self.prepare_input(inputs)
backbone_outputs = self.backbone(backbone_inputs)
return self.action_head.get_action(backbone_outputs, action_inputs, options)
@property
def device(self) -> torch.device:
return next(iter(self.parameters())).device
@property
def dtype(self) -> torch.dtype:
return next(iter(self.parameters())).dtype
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
tune_visual = kwargs.pop("tune_visual", True)
tune_llm = kwargs.pop("tune_llm", False)
tune_projector = kwargs.pop("tune_projector", True)
tune_diffusion_model = kwargs.pop("tune_diffusion_model", True)
tune_vlln = kwargs.pop("tune_vlln", True)
transformers_loading_kwargs = kwargs.pop("transformers_loading_kwargs", None) or {
"trust_remote_code": True
}
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
for key in ("revision", "cache_dir", "local_files_only", "token"):
if key in kwargs:
transformers_loading_kwargs.setdefault(key, kwargs[key])
try:
local_model_path = snapshot_download(
pretrained_model_name_or_path,
repo_type="model",
revision=kwargs.get("revision"),
cache_dir=kwargs.get("cache_dir"),
local_files_only=kwargs.get("local_files_only", False),
token=kwargs.get("token"),
)
except (HFValidationError, RepositoryNotFoundError):
local_model_path = pretrained_model_name_or_path
pretrained_model = super().from_pretrained(
local_model_path,
transformers_loading_kwargs=transformers_loading_kwargs,
load_backbone_weights=load_backbone_weights,
**kwargs,
)
pretrained_model.backbone.set_trainable_parameters(
tune_visual=tune_visual,
tune_llm=tune_llm,
tune_top_llm_layers=pretrained_model.config.tune_top_llm_layers,
)
pretrained_model.action_head.set_trainable_parameters(
tune_projector=tune_projector,
tune_diffusion_model=tune_diffusion_model,
tune_vlln=tune_vlln,
)
return pretrained_model
def _register_with_transformers() -> None:
if AutoConfig is None or AutoModel is None:
return
try:
AutoConfig.register(GR00TN17Config.model_type, GR00TN17Config, exist_ok=True)
except TypeError:
with suppress(ValueError):
AutoConfig.register(GR00TN17Config.model_type, GR00TN17Config)
try:
AutoModel.register(GR00TN17Config, GR00TN17, exist_ok=True)
except TypeError:
with suppress(ValueError):
AutoModel.register(GR00TN17Config, GR00TN17)
_register_with_transformers()

View File

@@ -46,7 +46,15 @@ from lerobot.utils.constants import ACTION, OBS_IMAGES
from lerobot.utils.import_utils import require_package
from ..pretrained import PreTrainedPolicy
from .configuration_groot import GrootConfig
from .configuration_groot import (
GROOT_N1_5,
GROOT_N1_7,
GrootConfig,
infer_groot_model_version,
infer_groot_n1_7_action_execution_horizon,
infer_groot_n1_7_action_horizon,
normalize_groot_model_version,
)
from .groot_n1 import GR00TN15
T = TypeVar("T", bound="GrootPolicy")
@@ -67,6 +75,7 @@ class GrootPolicy(PreTrainedPolicy):
# Initialize GR00T model using ported components
self._groot_model = self._create_groot_model()
self._action_queue_steps = self._resolve_action_queue_steps()
self.reset()
@@ -82,13 +91,23 @@ class GrootPolicy(PreTrainedPolicy):
# Handle Flash Attention compatibility issues
self._handle_flash_attention_compatibility()
model = GR00TN15.from_pretrained(
pretrained_model_name_or_path=self.config.base_model_path,
tune_llm=self.config.tune_llm,
tune_visual=self.config.tune_visual,
tune_projector=self.config.tune_projector,
tune_diffusion_model=self.config.tune_diffusion_model,
)
model_kwargs = {
"pretrained_model_name_or_path": self.config.base_model_path,
"tune_llm": self.config.tune_llm,
"tune_visual": self.config.tune_visual,
"tune_projector": self.config.tune_projector,
"tune_diffusion_model": self.config.tune_diffusion_model,
}
if self.config.model_version == GROOT_N1_7:
from .groot_n1_7 import GR00TN17
model = GR00TN17.from_pretrained(
**model_kwargs,
tune_vlln=True,
transformers_loading_kwargs={"trust_remote_code": True},
)
else:
model = GR00TN15.from_pretrained(**model_kwargs)
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
model.config.compute_dtype = model.compute_dtype
@@ -97,7 +116,7 @@ class GrootPolicy(PreTrainedPolicy):
def reset(self):
"""Reset policy state when environment resets."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
self._action_queue = deque([], maxlen=self._action_queue_steps)
@classmethod
def from_pretrained(
@@ -141,8 +160,13 @@ class GrootPolicy(PreTrainedPolicy):
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
requested_version = (
normalize_groot_model_version(config.model_version)
if config is not None
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_5
)
print(
"The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n"
f"The Groot policy is a wrapper around Nvidia's GR00T {requested_version} model.\n"
f"Loading pretrained model from: {pretrained_name_or_path}"
)
@@ -193,8 +217,12 @@ class GrootPolicy(PreTrainedPolicy):
print("Detected base GR00T model, loading from HuggingFace...")
if config is None:
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_5
# Create default config with the pretrained path
config = GrootConfig(base_model_path=str(pretrained_name_or_path))
config = GrootConfig(
model_version=model_version,
base_model_path=str(pretrained_name_or_path),
)
# Add minimal visual feature required for validation
# validate_features() will automatically add state and action features
@@ -215,6 +243,25 @@ class GrootPolicy(PreTrainedPolicy):
if hasattr(config, key):
setattr(config, key, value)
config.model_version = normalize_groot_model_version(config.model_version)
inferred_version = infer_groot_model_version(config.base_model_path)
if inferred_version is not None and inferred_version != config.model_version:
raise ValueError(
f"GR00T model_version '{config.model_version}' does not match base_model_path "
f"'{config.base_model_path}', which looks like '{inferred_version}'."
)
if config.model_version == GROOT_N1_7:
if config.max_state_dim == 64:
config.max_state_dim = 132
if config.max_action_dim == 32:
config.max_action_dim = 132
if config.chunk_size == 50:
config.chunk_size = 40
if config.n_action_steps == 50:
config.n_action_steps = 40
if tuple(config.image_size) == (224, 224):
config.image_size = (256, 256)
# Create a fresh policy instance - this will automatically load the GR00T model
# in __init__ via _create_groot_model()
policy = cls(config)
@@ -225,18 +272,59 @@ class GrootPolicy(PreTrainedPolicy):
def get_optim_params(self) -> dict:
return self.parameters()
def _resolve_action_queue_steps(self) -> int:
n_action_steps = int(self.config.n_action_steps)
if self.config.model_version != GROOT_N1_7:
return n_action_steps
checkpoint_action_horizon = infer_groot_n1_7_action_horizon(
self.config.base_model_path,
self.config.embodiment_tag,
)
execution_horizon = infer_groot_n1_7_action_execution_horizon(
self.config.base_model_path,
self.config.embodiment_tag,
)
horizons = [n_action_steps]
if checkpoint_action_horizon is not None:
horizons.append(checkpoint_action_horizon)
if execution_horizon is not None:
horizons.append(execution_horizon)
return min(horizons)
def _filter_groot_inputs(self, batch: dict[str, Tensor], *, include_action: bool) -> dict[str, Tensor]:
allowed_base = {"state", "state_mask", "embodiment_id"}
if include_action:
allowed_base.update({"action", "action_mask"})
if self.config.model_version == GROOT_N1_7:
allowed_base.update(
{
"input_ids",
"attention_mask",
"pixel_values",
"image_grid_thw",
"mm_token_type_ids",
"pixel_values_videos",
"video_grid_thw",
}
)
allowed_base.add("action_mask")
else:
allowed_base.update({"action_mask"} if include_action else set())
return {
k: v
for k, v in batch.items()
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
}
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Training forward pass.
Delegates to Isaac-GR00T model.forward when inputs are compatible.
"""
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
allowed_base = {"state", "state_mask", "action", "action_mask", "embodiment_id"}
groot_inputs = {
k: v
for k, v in batch.items()
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
}
groot_inputs = self._filter_groot_inputs(batch, include_action=True)
# Get device from model parameters
device = next(self.parameters()).device
@@ -261,15 +349,10 @@ class GrootPolicy(PreTrainedPolicy):
"""
self.eval()
# Build a clean input dict for GR00T: keep only tensors GR00T consumes
# Preprocessing is handled by the processor pipeline, so we just filter the batch
# NOTE: During inference, we should NOT pass action/action_mask (that's what we're predicting)
allowed_base = {"state", "state_mask", "embodiment_id"}
groot_inputs = {
k: v
for k, v in batch.items()
if (k in allowed_base or k.startswith("eagle_")) and not (k.startswith("next.") or k == "info")
}
# Preprocessing is handled by the processor pipeline, so we just filter the batch.
# During inference, we do not pass action because it is predicted.
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
groot_inputs = self._filter_groot_inputs(batch, include_action=False)
# Get device from model parameters
device = next(self.parameters()).device
@@ -292,7 +375,7 @@ class GrootPolicy(PreTrainedPolicy):
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch)
self._action_queue.extend(actions.transpose(0, 1))
self._action_queue.extend(actions[:, : self._action_queue_steps].transpose(0, 1))
return self._action_queue.popleft()
# -------------------------

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff