Files
lerobot-clone/test_2.py
Jade Choghari f52cf79d8e logits matching
2025-11-15 19:23:27 +01:00

72 lines
2.4 KiB
Python

from xvla.models.processing_xvla import XVLAProcessor
from xvla.models.modeling_xvla import XVLA
from xvla.models.configuration_xvla import XVLAConfig
import torch
import random
import numpy as np
from PIL import Image
from lerobot.policies.factory import make_policy
from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.factory import make_env_config
# Local checkpoint directory of the original XVLA model used as the reference.
XVLA_MODEL_PATH = "/raid/jade/models/xvla-libero"

# Seed every RNG before anything else runs, so any randomness consumed while
# constructing/loading the model, and everything after it, is reproducible.
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

# NOTE(review): cfg is not used by the XVLA code path below and is reassigned
# to the LeRobot config later in this script.
cfg = XVLAConfig.from_pretrained(XVLA_MODEL_PATH)
model = XVLA.from_pretrained(XVLA_MODEL_PATH)
model.eval()
model.to("cuda")
processor = XVLAProcessor.from_pretrained(XVLA_MODEL_PATH)
def make_random_pil_images(num_images=3, H=480, W=640):
    """Build a list of ``num_images`` random-noise RGB PIL images.

    Every image comes from an ``(H, W, 3)`` uint8 array sampled in [0, 255]
    with NumPy's global RNG. The output therefore depends on the current
    ``np.random`` state, and images are drawn in list order.
    """
    return [
        Image.fromarray(np.random.randint(0, 256, (H, W, 3), dtype=np.uint8))
        for _ in range(num_images)
    ]
# Build a synthetic batch: random images, a dummy instruction, random
# proprioception and a fixed domain id.
images = make_random_pil_images()
language_instruction = "This is a random image"
# Multimodal preprocessing by processor
inputs = processor(images, language_instruction)
required_keys = {"input_ids", "image_input", "image_mask"}
missing_keys = required_keys - set(inputs)
if missing_keys:
    # Name the missing keys so a processor/model mismatch is easy to diagnose.
    raise ValueError(
        f"Processor did not return the expected keys. Missing: {sorted(missing_keys)}"
    )
# NOTE(review): 20 is assumed to be this checkpoint's proprio width — confirm.
proprio = torch.randn(1, 20)
domain_id = torch.tensor([0], dtype=torch.long)
# Align to model's device/dtype (read off the loaded model's parameters).
device = model.device
dtype = next(model.parameters()).dtype
def to_model(t: torch.Tensor) -> torch.Tensor:
    """Place ``t`` on the module-level ``device``, casting floats to ``dtype``.

    Non-tensor inputs (lists, NumPy arrays, ...) are first converted with
    ``torch.as_tensor``. Floating-point tensors are cast to the module-level
    ``dtype``. Integral and bool tensors only change device and keep their dtype.
    """
    tensor = t if isinstance(t, torch.Tensor) else torch.as_tensor(t)
    if not tensor.is_floating_point():
        return tensor.to(device=device)
    return tensor.to(device=device, dtype=dtype)
# Move every processor output onto the model's device (floats cast to its dtype).
inputs = {k: to_model(v) for k, v in inputs.items()}
inputs.update({
    "proprio": to_model(proprio),
    "domain_id": domain_id.to(device),
})
# Inference.
# Re-seed immediately before sampling: generate_actions(steps=10) presumably
# draws noise for an iterative sampler (TODO confirm), so fixing the RNG state
# here makes this output reproducible for a side-by-side comparison.
torch.manual_seed(42)
# no_grad guarantees the later .numpy() works even if generate_actions does
# not disable autograd itself (numpy() raises on tensors requiring grad).
with torch.no_grad():
    action = model.generate_actions(**inputs, steps=10)
action = action.squeeze(0).float().cpu().numpy()
#### now for lerobot model #####################################################
# Load the migrated LeRobot checkpoint and run it on the same inputs.
policy_path = "/raid/jade/models/xvla-libero-og_migrated"
cfg = PreTrainedConfig.from_pretrained(policy_path)
env_cfg = make_env_config("libero", task="libero_spatial")
cfg.pretrained_path = policy_path
policy = make_policy(cfg=cfg, env_cfg=env_cfg)
policy.eval()
policy.to("cuda")

# `inputs` was cast to the XVLA model's device/dtype. Re-cast floating tensors
# to the policy's own parameter dtype/device in case the two checkpoints
# differ; this is a no-op when they match.
# Assumes policy is an nn.Module (it is .eval()/.to()'d above) — confirm.
policy_param = next(policy.parameters())
policy_inputs = {
    k: (
        v.to(device=policy_param.device, dtype=policy_param.dtype)
        if v.is_floating_point()
        else v.to(device=policy_param.device)
    )
    for k, v in inputs.items()
}

# Re-seed immediately before sampling so the noise drawn here does not depend
# on how much RNG state the policy construction above consumed.
torch.manual_seed(42)
with torch.no_grad():
    action_1 = policy.model.generate_actions(**policy_inputs, steps=10)
action_1 = action_1.squeeze(0).float().cpu().numpy()

# Report how closely the two implementations agree, then inspect interactively.
if action.shape == action_1.shape:
    abs_diff = np.abs(action - action_1)
    print(f"max abs diff: {abs_diff.max():.3e} | mean abs diff: {abs_diff.mean():.3e}")
else:
    print(f"shape mismatch: {action.shape} vs {action_1.shape}")
breakpoint()