diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
index e31d8b043..45d83e652 100644
--- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
+++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py
@@ -17,6 +17,7 @@ from __future__ import annotations
 import logging
 from collections import deque
 from pathlib import Path
+from typing import TYPE_CHECKING
 
 import numpy as np
 import torch
@@ -27,7 +28,13 @@ from torch import Tensor, nn
 from lerobot.policies.pretrained import PreTrainedPolicy, T
 from lerobot.policies.utils import populate_queues
 from lerobot.utils.constants import ACTION, OBS_STATE
-from lerobot.utils.import_utils import require_package
+from lerobot.utils.import_utils import _transformers_available, require_package
+
+if TYPE_CHECKING or _transformers_available:
+    from transformers import AutoModel, AutoVideoProcessor
+else:
+    AutoModel = None
+    AutoVideoProcessor = None
 
 from .action_head import VLAJEPAActionHead
 from .configuration_vla_jepa import VLAJEPAConfig
@@ -74,8 +81,6 @@ class VLAJEPAModel(nn.Module):
 
         # JEPA world model components
         if config.enable_world_model:
-            from transformers import AutoModel, AutoVideoProcessor
-
             self.video_encoder = AutoModel.from_pretrained(
                 config.jepa_encoder_name,
                 torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
diff --git a/src/lerobot/policies/vla_jepa/qwen_interface.py b/src/lerobot/policies/vla_jepa/qwen_interface.py
index 8482ceb8c..24f530efc 100644
--- a/src/lerobot/policies/vla_jepa/qwen_interface.py
+++ b/src/lerobot/policies/vla_jepa/qwen_interface.py
@@ -15,19 +15,26 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
+from typing import TYPE_CHECKING
 
 import numpy as np
 import torch
 from PIL import Image
 
+from lerobot.utils.import_utils import _transformers_available
+
+if TYPE_CHECKING or _transformers_available:
+    from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
+else:
+    AutoProcessor = None
+    Qwen3VLForConditionalGeneration = None
+
 from .configuration_vla_jepa import VLAJEPAConfig
 
 
 class Qwen3VLInterface(torch.nn.Module):
     def __init__(self, config: VLAJEPAConfig) -> None:
         super().__init__()
-        from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
-
         self.config = config
         self.model = Qwen3VLForConditionalGeneration.from_pretrained(
             config.qwen_model_name,
diff --git a/tests/policies/vla_jepa/conftest.py b/tests/policies/vla_jepa/conftest.py
index c1feef329..5301b5bc7 100644
--- a/tests/policies/vla_jepa/conftest.py
+++ b/tests/policies/vla_jepa/conftest.py
@@ -258,10 +258,16 @@ class _FakeVideoProcessor:
 
 @pytest.fixture
 def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
-    from transformers import AutoModel, AutoVideoProcessor
-
     from lerobot.policies.vla_jepa import modeling_vla_jepa
 
     monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface)
-    monkeypatch.setattr(AutoModel, "from_pretrained", lambda *args, **kwargs: _FakeVideoEncoder())
-    monkeypatch.setattr(AutoVideoProcessor, "from_pretrained", lambda *args, **kwargs: _FakeVideoProcessor())
+    monkeypatch.setattr(
+        modeling_vla_jepa.AutoModel,
+        "from_pretrained",
+        lambda *args, **kwargs: _FakeVideoEncoder(),
+    )
+    monkeypatch.setattr(
+        modeling_vla_jepa.AutoVideoProcessor,
+        "from_pretrained",
+        lambda *args, **kwargs: _FakeVideoProcessor(),
+    )