Revert "trying to isolate pi0 side-effect"

This reverts commit ec42dd9973.
This commit is contained in:
Maximellerbach
2026-06-01 16:47:54 +02:00
parent ec42dd9973
commit 90c15db1a7
3 changed files with 27 additions and 9 deletions

View File

@@ -17,6 +17,7 @@ from __future__ import annotations
import logging
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
import torch
@@ -27,7 +28,13 @@ from torch import Tensor, nn
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.import_utils import require_package
from lerobot.utils.import_utils import _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import AutoModel, AutoVideoProcessor
else:
AutoModel = None
AutoVideoProcessor = None
from .action_head import VLAJEPAActionHead
from .configuration_vla_jepa import VLAJEPAConfig
@@ -74,8 +81,6 @@ class VLAJEPAModel(nn.Module):
# JEPA world model components
if config.enable_world_model:
from transformers import AutoModel, AutoVideoProcessor
self.video_encoder = AutoModel.from_pretrained(
config.jepa_encoder_name,
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),

View File

@@ -15,19 +15,26 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
import numpy as np
import torch
from PIL import Image
from lerobot.utils.import_utils import _transformers_available
if TYPE_CHECKING or _transformers_available:
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
else:
AutoProcessor = None
Qwen3VLForConditionalGeneration = None
from .configuration_vla_jepa import VLAJEPAConfig
class Qwen3VLInterface(torch.nn.Module):
def __init__(self, config: VLAJEPAConfig) -> None:
super().__init__()
from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
self.config = config
self.model = Qwen3VLForConditionalGeneration.from_pretrained(
config.qwen_model_name,

View File

@@ -258,10 +258,16 @@ class _FakeVideoProcessor:
@pytest.fixture
def patch_vla_jepa_external_models(monkeypatch: pytest.MonkeyPatch) -> None:
from transformers import AutoModel, AutoVideoProcessor
from lerobot.policies.vla_jepa import modeling_vla_jepa
monkeypatch.setattr(modeling_vla_jepa, "Qwen3VLInterface", _FakeQwenInterface)
monkeypatch.setattr(AutoModel, "from_pretrained", lambda *args, **kwargs: _FakeVideoEncoder())
monkeypatch.setattr(AutoVideoProcessor, "from_pretrained", lambda *args, **kwargs: _FakeVideoProcessor())
monkeypatch.setattr(
modeling_vla_jepa.AutoModel,
"from_pretrained",
lambda *args, **kwargs: _FakeVideoEncoder(),
)
monkeypatch.setattr(
modeling_vla_jepa.AutoVideoProcessor,
"from_pretrained",
lambda *args, **kwargs: _FakeVideoProcessor(),
)