diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index 45d83e652..e31d8b043 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -17,7 +17,6 @@ from __future__ import annotations import logging from collections import deque from pathlib import Path -from typing import TYPE_CHECKING import numpy as np import torch @@ -28,13 +27,7 @@ from torch import Tensor, nn from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.utils import populate_queues from lerobot.utils.constants import ACTION, OBS_STATE -from lerobot.utils.import_utils import _transformers_available, require_package - -if TYPE_CHECKING or _transformers_available: - from transformers import AutoModel, AutoVideoProcessor -else: - AutoModel = None - AutoVideoProcessor = None +from lerobot.utils.import_utils import require_package from .action_head import VLAJEPAActionHead from .configuration_vla_jepa import VLAJEPAConfig @@ -81,6 +74,8 @@ class VLAJEPAModel(nn.Module): # JEPA world model components if config.enable_world_model: + from transformers import AutoModel, AutoVideoProcessor + self.video_encoder = AutoModel.from_pretrained( config.jepa_encoder_name, torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype), diff --git a/src/lerobot/policies/vla_jepa/qwen_interface.py b/src/lerobot/policies/vla_jepa/qwen_interface.py index 24f530efc..8482ceb8c 100644 --- a/src/lerobot/policies/vla_jepa/qwen_interface.py +++ b/src/lerobot/policies/vla_jepa/qwen_interface.py @@ -15,26 +15,19 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING import numpy as np import torch from PIL import Image -from lerobot.utils.import_utils import _transformers_available - -if TYPE_CHECKING or _transformers_available: - from transformers import AutoProcessor, Qwen3VLForConditionalGeneration -else: - AutoProcessor = None - Qwen3VLForConditionalGeneration = None - from .configuration_vla_jepa import VLAJEPAConfig class Qwen3VLInterface(torch.nn.Module): def __init__(self, config: VLAJEPAConfig) -> None: super().__init__() + from transformers import AutoProcessor, Qwen3VLForConditionalGeneration + self.config = config self.model = Qwen3VLForConditionalGeneration.from_pretrained( config.qwen_model_name, diff --git a/tests/policies/vla_jepa/conftest.py b/tests/policies/vla_jepa/conftest.py index 5301b5bc7..c1feef329 100644 --- a/tests/policies/vla_jepa/conftest.py +++ b/tests/policies/vla_jepa/conftest.py @@ -258,16 +258,10 @@ class _FakeVideoProcessor: @pytest.fixture def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None: + from transformers import AutoModel, AutoVideoProcessor + from lerobot.policies.vla_jepa import modeling_vla_jepa monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface) - monkeypatch.setattr( - modeling_vla_jepa.AutoModel, - "from_pretrained", - lambda *args, **kwargs: _FakeVideoEncoder(), - ) - monkeypatch.setattr( - modeling_vla_jepa.AutoVideoProcessor, - "from_pretrained", - lambda *args, **kwargs: _FakeVideoProcessor(), - ) + monkeypatch.setattr(AutoModel, "from_pretrained", lambda *args, **kwargs: _FakeVideoEncoder()) + monkeypatch.setattr(AutoVideoProcessor, "from_pretrained", lambda *args, **kwargs: _FakeVideoProcessor())