diff --git a/pyproject.toml b/pyproject.toml index 7653e3bb8..241a5bcf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -259,6 +259,7 @@ all = [ "lerobot[smolvla]", # "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn "lerobot[xvla]", + "lerobot[evo1]", "lerobot[hilserl]", "lerobot[async]", "lerobot[dev]", diff --git a/src/lerobot/policies/evo1/README.md b/src/lerobot/policies/evo1/README.md new file mode 120000 index 000000000..bcd30fe73 --- /dev/null +++ b/src/lerobot/policies/evo1/README.md @@ -0,0 +1 @@ +../../../../docs/source/evo1.mdx \ No newline at end of file diff --git a/src/lerobot/policies/evo1/internvl3_embedder.py b/src/lerobot/policies/evo1/internvl3_embedder.py index 4cf1b00e0..8962b8f0d 100644 --- a/src/lerobot/policies/evo1/internvl3_embedder.py +++ b/src/lerobot/policies/evo1/internvl3_embedder.py @@ -17,6 +17,7 @@ from __future__ import annotations import functools import logging from collections.abc import Sequence +from typing import TYPE_CHECKING import torch import torch.nn as nn @@ -24,7 +25,13 @@ import torchvision.transforms.functional as TF from PIL import Image from torchvision.transforms.functional import to_pil_image -from lerobot.utils.import_utils import require_package +from lerobot.utils.import_utils import _transformers_available, require_package + +if TYPE_CHECKING or _transformers_available: + from transformers import AutoModel, AutoTokenizer +else: + AutoModel = None + AutoTokenizer = None IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_STD = (0.229, 0.224, 0.225) @@ -111,7 +118,6 @@ class InternVL3Embedder(nn.Module): self.gradient_checkpointing_use_reentrant = bool(gradient_checkpointing_use_reentrant) require_package("transformers", extra="evo1") - from transformers import AutoModel, AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False) if isinstance(model_dtype, str): diff --git a/uv.lock b/uv.lock index d35f51d3b..36560c289 100644 --- a/uv.lock +++ b/uv.lock @@ -3079,6 +3079,7 @@ requires-dist = [ { name = "lerobot", extras = ["diffusers-dep"], marker = "extra == 'multi-task-dit'" }, { name = "lerobot", extras = ["diffusion"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["dynamixel"], marker = "extra == 'all'" }, + { name = "lerobot", extras = ["evo1"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["feetech"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["feetech"], marker = "extra == 'hopejr'" }, { name = "lerobot", extras = ["feetech"], marker = "extra == 'lekiwi'" },