From b928c123fbbe06d751d12109ffc5791179df268f Mon Sep 17 00:00:00 2001
From: Jade Choghari
Date: Sat, 15 Nov 2025 22:37:23 +0100
Subject: [PATCH] add imagenet as a norm type

---
Review notes (below the '---' marker, so not part of the commit message):
 * Removed the leftover print()/breakpoint() debugging calls from
   XVLAPolicy._get_action_chunk and from the _NormalizationMixin transform
   path; modeling_xvla.py therefore no longer appears in this patch.
 * Parenthesised the mixed `or`/`and` early-return guard (same semantics).
 * Kept the blank line between absolute and relative imports and placed the
   new IMAGENET_STATS import in sorted position.
 * test_3.py: removed breakpoint() calls, repaired mis-encoded emoji in the
   printed strings, made the fake image cover the full 0..255 range, compared
   tensors on the device of the preprocessed input instead of a hard-coded
   "cuda", and added the missing newline at end of file.
 * Indentation of context lines was reconstructed from whitespace-collapsed
   text; `index` lines are omitted where the post-image changed.

 src/lerobot/configs/types.py                 |  1 +
 src/lerobot/processor/normalize_processor.py | 24 +++++-
 test_3.py                                    | 87 ++++++++++++++++++++
 3 files changed, 111 insertions(+), 1 deletion(-)
 create mode 100644 test_3.py

diff --git a/src/lerobot/configs/types.py b/src/lerobot/configs/types.py
index 11a1f8d74..bbc6d95a8 100644
--- a/src/lerobot/configs/types.py
+++ b/src/lerobot/configs/types.py
@@ -37,6 +37,7 @@ class NormalizationMode(str, Enum):
     IDENTITY = "IDENTITY"
     QUANTILES = "QUANTILES"
     QUANTILE10 = "QUANTILE10"
+    IMAGENET = "IMAGENET"
 
 
 @dataclass
diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py
--- a/src/lerobot/processor/normalize_processor.py
+++ b/src/lerobot/processor/normalize_processor.py
@@ -27,4 +27,5 @@ from torch import Tensor
 from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
+from lerobot.datasets.factory import IMAGENET_STATS
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.utils.constants import ACTION
 
@@ -303,14 +304,18 @@ class _NormalizationMixin:
             ValueError: If an unsupported normalization mode is encountered.
         """
         norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
-        if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats:
+        # IMAGENET uses the fixed IMAGENET_STATS constants, so it must not be skipped
+        # just because no per-key stats were provided for this feature.
+        if norm_mode == NormalizationMode.IDENTITY or (
+            norm_mode != NormalizationMode.IMAGENET and key not in self._tensor_stats
+        ):
             return tensor
 
         if norm_mode not in (
             NormalizationMode.MEAN_STD,
             NormalizationMode.MIN_MAX,
             NormalizationMode.QUANTILES,
             NormalizationMode.QUANTILE10,
+            NormalizationMode.IMAGENET,
         ):
             raise ValueError(f"Unsupported normalization mode: {norm_mode}")
 
@@ -320,6 +325,23 @@ class _NormalizationMixin:
             if first_stat.device != tensor.device or first_stat.dtype != tensor.dtype:
                 self.to(device=tensor.device, dtype=tensor.dtype)
 
+        if norm_mode == NormalizationMode.IMAGENET:
+            # NOTE(review): the /255 scaling assumes pixel values in [0, 255] - confirm against callers.
+            mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype)
+            std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype)
+
+            # Expand mean/std to match tensor dims (e.g., BCHW or BNCHW)
+            while mean.dim() < tensor.dim():
+                mean = mean.unsqueeze(0)
+                std = std.unsqueeze(0)
+
+            if inverse:
+                # De-normalize
+                return (tensor * std + mean) * 255.0
+
+            # Normalize
+            return (tensor / 255.0 - mean) / std
+
         stats = self._tensor_stats[key]
         if norm_mode == NormalizationMode.MEAN_STD:
             mean = stats.get("mean", None)
diff --git a/test_3.py b/test_3.py
new file mode 100644
--- /dev/null
+++ b/test_3.py
@@ -0,0 +1,87 @@
+from lerobot.policies.factory import make_policy, make_pre_post_processors
+# from lerobot.policies.xvla.configuration_xvla import XVLAConfig
+from lerobot.configs.policies import PreTrainedConfig
+from lerobot.envs.factory import make_env_config
+from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
+import torch
+import numpy as np
+import random
+torch.manual_seed(42)
+random.seed(42)
+np.random.seed(42)
+observation_height: int = 224
+observation_width: int = 224 # todo: jadechoghari, image size is different for the two models
+# create an observation dict
+OBS = {
+    f"{OBS_IMAGES}.image": torch.randn(1, 3, observation_height, observation_width),
+    f"{OBS_IMAGES}.image2": torch.randn(1, 3, observation_height, observation_width),
+    OBS_STATE: torch.randn(1, 9), # ONLY if OBS_STATE is already a string
+    "task": "put the object in the box",
+}
+
+IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)
+IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)
+def fake_rgb(H, W):
+    arr = np.random.randint(0, 256, (H, W, 3), dtype=np.uint8)
+    t = torch.from_numpy(arr).permute(2, 0, 1) # CHW
+    t = t.unsqueeze(0).float()
+    # values stay in [0, 255]; no ImageNet normalization is applied here
+    return t
+
+
+OBS[f"{OBS_IMAGES}.image"] = fake_rgb(observation_height, observation_width)
+OBS[f"{OBS_IMAGES}.image2"] = fake_rgb(observation_height, observation_width)
+
+cfg = PreTrainedConfig.from_pretrained("/raid/jade/models/xvla-libero-og_migrated")
+cfg.pretrained_path = "/raid/jade/models/xvla-libero-og_migrated"
+env_cfg = make_env_config("libero", task="libero_spatial")
+policy = make_policy(
+    cfg=cfg,
+    env_cfg=env_cfg,
+)
+
+policy.eval()
+
+preprocessor_overrides = {
+    "device_processor": {"device": str(cfg.device)},
+}
+
+preprocessor, postprocessor = make_pre_post_processors(
+    policy_cfg=cfg,
+    pretrained_path=cfg.pretrained_path,
+    preprocessor_overrides=preprocessor_overrides,
+)
+
+observation = preprocessor(OBS)
+inputs = policy._build_model_inputs(observation)
+
+#### now the og model ###########################################################
+from xvla.models.processing_xvla import XVLAProcessor
+
+processor = XVLAProcessor.from_pretrained("/raid/jade/models/xvla-libero", num_views=2)
+inputs_1 = processor([OBS[f"{OBS_IMAGES}.image"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"])
+
+for k in inputs.keys() & inputs_1.keys(): # intersection of keys
+    a = inputs[k]
+    b = inputs_1[k].to(a.device)
+
+    print(f"\n🔎 Key: {k}")
+
+    # Check shape
+    print(" shape:", a.shape, b.shape)
+
+    # Check if close
+    if torch.allclose(a, b, atol=1e-5, rtol=1e-5):
+        print(" ✔️ tensors are equal (allclose)")
+    else:
+        diff = torch.abs(a - b)
+        print(" ❌ tensors differ")
+        print(" max diff:", diff.max().item())
+        print(" mean diff:", diff.mean().item())
+
+
+# (Pdb) inputs['input_ids'].shape
+# torch.Size([1, 64])
+# (Pdb) inputs_1['input_ids'].shape
+# torch.Size([1, 50])
+# (Pdb) [0, 0, :, :4, 0]