add imagenet as a norm type

This commit is contained in:
Jade Choghari
2025-11-15 22:37:23 +01:00
parent f52cf79d8e
commit b928c123fb
4 changed files with 112 additions and 3 deletions

View File

@@ -37,6 +37,7 @@ class NormalizationMode(str, Enum):
IDENTITY = "IDENTITY"
QUANTILES = "QUANTILES"
QUANTILE10 = "QUANTILE10"
IMAGENET = "IMAGENET"
@dataclass

View File

@@ -319,7 +319,9 @@ class XVLAPolicy(PreTrainedPolicy):
return total_loss, log_dict
def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
print("get_action_chunk")
inputs = self._build_model_inputs(batch)
breakpoint()
actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps)
actions = self._trim_action_dim(actions)
return actions

View File

@@ -27,7 +27,7 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION
from lerobot.datasets.factory import IMAGENET_STATS
from .converters import from_tensor_to_numpy, to_tensor
from .core import EnvTransition, PolicyAction, TransitionKey
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry
@@ -303,14 +303,15 @@ class _NormalizationMixin:
ValueError: If an unsupported normalization mode is encountered.
"""
norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats:
breakpoint()
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats and norm_mode != NormalizationMode.IMAGENET:
return tensor
if norm_mode not in (
NormalizationMode.MEAN_STD,
NormalizationMode.MIN_MAX,
NormalizationMode.QUANTILES,
NormalizationMode.QUANTILE10,
NormalizationMode.IMAGENET,
):
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
@@ -320,6 +321,22 @@ class _NormalizationMixin:
if first_stat.device != tensor.device or first_stat.dtype != tensor.dtype:
self.to(device=tensor.device, dtype=tensor.dtype)
if norm_mode == NormalizationMode.IMAGENET:
mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype)
std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype)
# Expand mean/std to match tensor dims (e.g., BCHW or BNCHW)
while mean.dim() < tensor.dim():
mean = mean.unsqueeze(0)
std = std.unsqueeze(0)
if inverse:
# De-normalize
return (tensor * std + mean) * 255.0
# Normalize
return (tensor / 255.0 - mean) / std
stats = self._tensor_stats[key]
if norm_mode == NormalizationMode.MEAN_STD:
mean = stats.get("mean", None)

89
test_3.py Normal file
View File

@@ -0,0 +1,89 @@
from lerobot.policies.factory import make_policy, make_pre_post_processors
# from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.factory import make_env_config
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
import torch
import numpy as np
import random
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
observation_height: int = 224
observation_width: int = 224 # todo: jadechoghari, image size is different for the two models
# create an observation dict
OBS = {
f"{OBS_IMAGES}.image": torch.randn(1, 3, observation_height, observation_width),
f"{OBS_IMAGES}.image2": torch.randn(1, 3, observation_height, observation_width),
OBS_STATE: torch.randn(1, 9), # ONLY if OBS_STATE is already a string
"task": "put the object in the box",
}
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)
def fake_rgb(H, W):
arr = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8)
t = torch.from_numpy(arr).permute(2, 0, 1) # CHW
t = t.unsqueeze(0).float()
# normalize pixel to imagenet
return t
OBS[f"{OBS_IMAGES}.image"] = fake_rgb(observation_height, observation_width)
OBS[f"{OBS_IMAGES}.image2"] = fake_rgb(observation_height, observation_width)
cfg = PreTrainedConfig.from_pretrained("/raid/jade/models/xvla-libero-og_migrated")
cfg.pretrained_path = "/raid/jade/models/xvla-libero-og_migrated"
env_cfg = make_env_config("libero", task="libero_spatial")
policy = make_policy(
cfg=cfg,
env_cfg=env_cfg,
)
policy.eval()
preprocessor_overrides = {
"device_processor": {"device": str(cfg.device)},
}
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg,
pretrained_path=cfg.pretrained_path,
preprocessor_overrides=preprocessor_overrides,
)
observation = preprocessor(OBS)
inputs = policy._build_model_inputs(observation)
breakpoint()
#### now the og model ###########################################################
from xvla.models.processing_xvla import XVLAProcessor
processor = XVLAProcessor.from_pretrained("/raid/jade/models/xvla-libero", num_views=2)
inputs_1 = processor([OBS[f"{OBS_IMAGES}.image"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"])
for k in inputs.keys() & inputs_1.keys(): # intersection of keys
a = inputs[k]
b = inputs_1[k].to("cuda")
print(f"\n🔎 Key: {k}")
# Check shape
print(" shape:", a.shape, b.shape)
# Check if close
if torch.allclose(a, b, atol=1e-5, rtol=1e-5):
print(" ✔️ tensors are equal (allclose)")
else:
diff = torch.abs(a - b)
print(" ❌ tensors differ")
print(" max diff:", diff.max().item())
print(" mean diff:", diff.mean().item())
breakpoint()
# (Pdb) inputs['input_ids'].shape
# torch.Size([1, 64])
# (Pdb) inputs_1['input_ids'].shape
# torch.Size([1, 50])
# (Pdb) [0, 0, :, :4, 0]