mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
logits matching atol1e-2
This commit is contained in:
@@ -321,7 +321,6 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
print("get_action_chunk")
|
||||
inputs = self._build_model_inputs(batch)
|
||||
breakpoint()
|
||||
actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps)
|
||||
actions = self._trim_action_dim(actions)
|
||||
return actions
|
||||
|
||||
@@ -303,7 +303,6 @@ class _NormalizationMixin:
|
||||
ValueError: If an unsupported normalization mode is encountered.
|
||||
"""
|
||||
norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
|
||||
breakpoint()
|
||||
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats and norm_mode != NormalizationMode.IMAGENET:
|
||||
return tensor
|
||||
if norm_mode not in (
|
||||
|
||||
25
test_3.py
25
test_3.py
@@ -3,6 +3,7 @@ from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
from xvla.models.modeling_xvla import XVLA
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
@@ -15,7 +16,7 @@ observation_width: int = 224 # todo: jadechoghari, image size is different for t
|
||||
OBS = {
|
||||
f"{OBS_IMAGES}.image": torch.randn(1, 3, observation_height, observation_width),
|
||||
f"{OBS_IMAGES}.image2": torch.randn(1, 3, observation_height, observation_width),
|
||||
OBS_STATE: torch.randn(1, 9), # ONLY if OBS_STATE is already a string
|
||||
OBS_STATE: torch.randn(1, 20), # ONLY if OBS_STATE is already a string
|
||||
"task": "put the object in the box",
|
||||
}
|
||||
|
||||
@@ -54,13 +55,18 @@ preprocessor, postprocessor = make_pre_post_processors(
|
||||
|
||||
observation = preprocessor(OBS)
|
||||
inputs = policy._build_model_inputs(observation)
|
||||
breakpoint()
|
||||
|
||||
|
||||
#### now the og model ###########################################################
|
||||
from xvla.models.processing_xvla import XVLAProcessor
|
||||
|
||||
processor = XVLAProcessor.from_pretrained("/raid/jade/models/xvla-libero", num_views=2)
|
||||
inputs_1 = processor([OBS[f"{OBS_IMAGES}.image"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"])
|
||||
domain_id = torch.tensor([int(3)], dtype=torch.long)
|
||||
inputs.update({
|
||||
"proprio": OBS[OBS_STATE].to("cuda"),
|
||||
"domain_id": domain_id.to("cuda"),
|
||||
})
|
||||
|
||||
for k in inputs.keys() & inputs_1.keys(): # intersection of keys
|
||||
a = inputs[k]
|
||||
@@ -79,11 +85,22 @@ for k in inputs.keys() & inputs_1.keys(): # intersection of keys
|
||||
print(" ❌ tensors differ")
|
||||
print(" max diff:", diff.max().item())
|
||||
print(" mean diff:", diff.mean().item())
|
||||
breakpoint()
|
||||
|
||||
|
||||
model = XVLA.from_pretrained("/raid/jade/models/xvla-libero")
|
||||
model.eval()
|
||||
model.to("cuda")
|
||||
|
||||
action = model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
|
||||
# (Pdb) inputs['input_ids'].shape
|
||||
# torch.Size([1, 64])
|
||||
# (Pdb) inputs_1['input_ids'].shape
|
||||
# torch.Size([1, 50])
|
||||
# (Pdb) [0, 0, :, :4, 0]
|
||||
# (Pdb) [0, 0, :, :4, 0]
|
||||
action_1 = policy.model.generate_actions(**inputs, steps=10).squeeze(0).float().cpu().numpy()
|
||||
|
||||
#np all close
|
||||
print(np.allclose(action, action_1, atol=1e-4, rtol=1e-4))
|
||||
print("max diff:", np.max(np.abs(action - action_1)))
|
||||
print("mean diff:", np.mean(np.abs(action - action_1)))
|
||||
breakpoint()
|
||||
Reference in New Issue
Block a user