logits matching atol1e-2

This commit is contained in:
Jade Choghari
2025-11-15 22:55:49 +01:00
parent b928c123fb
commit cde2e24d79
3 changed files with 21 additions and 6 deletions

View File

@@ -321,7 +321,6 @@ class XVLAPolicy(PreTrainedPolicy):
def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
print("get_action_chunk")
inputs = self._build_model_inputs(batch)
breakpoint()
actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps)
actions = self._trim_action_dim(actions)
return actions

View File

@@ -303,7 +303,6 @@ class _NormalizationMixin:
ValueError: If an unsupported normalization mode is encountered.
"""
norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
breakpoint()
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats and norm_mode != NormalizationMode.IMAGENET:
return tensor
if norm_mode not in (

View File

@@ -3,6 +3,7 @@ from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.factory import make_env_config
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from xvla.models.modeling_xvla import XVLA
import torch
import numpy as np
import random
@@ -15,7 +16,7 @@ observation_width: int = 224 # todo: jadechoghari, image size is different for t
OBS = {
f"{OBS_IMAGES}.image": torch.randn(1, 3, observation_height, observation_width),
f"{OBS_IMAGES}.image2": torch.randn(1, 3, observation_height, observation_width),
OBS_STATE: torch.randn(1, 9), # ONLY if OBS_STATE is already a string
OBS_STATE: torch.randn(1, 20), # ONLY if OBS_STATE is already a string
"task": "put the object in the box",
}
@@ -54,13 +55,18 @@ preprocessor, postprocessor = make_pre_post_processors(
observation = preprocessor(OBS)
inputs = policy._build_model_inputs(observation)
breakpoint()
#### now the og model ###########################################################
from xvla.models.processing_xvla import XVLAProcessor
processor = XVLAProcessor.from_pretrained("/raid/jade/models/xvla-libero", num_views=2)
inputs_1 = processor([OBS[f"{OBS_IMAGES}.image"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"])
domain_id = torch.tensor([int(3)], dtype=torch.long)
inputs.update({
"proprio": OBS[OBS_STATE].to("cuda"),
"domain_id": domain_id.to("cuda"),
})
for k in inputs.keys() & inputs_1.keys(): # intersection of keys
a = inputs[k]
@@ -79,11 +85,22 @@ for k in inputs.keys() & inputs_1.keys(): # intersection of keys
print(" ❌ tensors differ")
print(" max diff:", diff.max().item())
print(" mean diff:", diff.mean().item())
breakpoint()
model = XVLA.from_pretrained("/raid/jade/models/xvla-libero")
model.eval()
model.to("cuda")
action = model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
# (Pdb) inputs['input_ids'].shape
# torch.Size([1, 64])
# (Pdb) inputs_1['input_ids'].shape
# torch.Size([1, 50])
# (Pdb) [0, 0, :, :4, 0]
# (Pdb) [0, 0, :, :4, 0]
action_1 = policy.model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
#np all close
print(np.allclose(action, action_1, atol=1e-4, rtol=1e-4))
print("max diff:", np.max(np.abs(action - action_1)))
print("mean diff:", np.mean(np.abs(action - action_1)))
breakpoint()