diff --git a/src/lerobot/policies/xvla/processor_xvla.py b/src/lerobot/policies/xvla/processor_xvla.py index 6b19fda59..e3b76fb56 100644 --- a/src/lerobot/policies/xvla/processor_xvla.py +++ b/src/lerobot/policies/xvla/processor_xvla.py @@ -21,6 +21,7 @@ import numpy as np import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.datasets.factory import IMAGENET_STATS from lerobot.policies.xvla.configuration_xvla import XVLAConfig from lerobot.policies.xvla.utils import rotate6d_to_axis_angle from lerobot.processor import ( @@ -265,6 +266,83 @@ class XVLAImageScaleProcessorStep(ProcessorStep): } +@dataclass +@ProcessorStepRegistry.register(name="xvla_imagenet_normalize") +class XVLAImageNetNormalizeProcessorStep(ProcessorStep): + """Normalize image observations using ImageNet statistics. + + This processor step applies ImageNet normalization (mean and std) to image observations. + It validates that input values are in the [0, 1] range before normalizing. + + The normalization formula is: (image - mean) / std + + Args: + image_keys: List of observation keys that contain images to normalize. + If None, will automatically detect keys starting with "observation.images." + + Raises: + ValueError: If image values are not in the [0, 1] range. + """ + + image_keys: list[str] | None = None + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Normalize image observations using ImageNet statistics.""" + new_transition = transition.copy() + obs = new_transition.get(TransitionKey.OBSERVATION, {}) + if obs is None: + return new_transition + + # Make a copy of observations to avoid modifying the original + obs = obs.copy() + + # Determine which keys to normalize + keys_to_normalize = self.image_keys + if keys_to_normalize is None: + # Auto-detect image keys + keys_to_normalize = [k for k in obs if k.startswith("observation.images.")] + + # Normalize each image + for key in keys_to_normalize: + if key in obs and isinstance(obs[key], torch.Tensor): + tensor = obs[key] + + # Validate that values are in [0, 1] range + min_val = tensor.min().item() + max_val = tensor.max().item() + if min_val < 0.0 or max_val > 1.0: + raise ValueError( + f"Image '{key}' has values outside [0, 1] range: " + f"min={min_val:.4f}, max={max_val:.4f}. " + f"ImageNet normalization requires input values in [0, 1]." + ) + + # Apply ImageNet normalization + mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype) + std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype) + + # Expand mean/std to match tensor dims (e.g., BCHW or BNCHW) + while mean.dim() < tensor.dim(): + mean = mean.unsqueeze(0) + std = std.unsqueeze(0) + + # Normalize: (image - mean) / std + obs[key] = (tensor - mean) / std + + new_transition[TransitionKey.OBSERVATION] = obs + return new_transition + + def transform_features(self, features): + """ImageNet normalization doesn't change feature structure.""" + return features + + def get_config(self) -> dict[str, Any]: + """Return serializable configuration.""" + return { + "image_keys": self.image_keys, + } + + @dataclass @ProcessorStepRegistry.register(name="xvla_add_domain_id") class XVLAAddDomainIdProcessorStep(ProcessorStep): @@ -389,7 +467,7 @@ def make_xvla_libero_pre_post_processors() -> tuple[ """ pre_processor_steps: list[ProcessorStep] = [] post_processor_steps: list[ProcessorStep] = [] - pre_processor_steps.extend([LiberoProcessorStep(), XVLAAddDomainIdProcessorStep()]) + pre_processor_steps.extend([LiberoProcessorStep(), XVLAImageNetNormalizeProcessorStep(), XVLAAddDomainIdProcessorStep()]) post_processor_steps.extend([XVLARotation6DToAxisAngleProcessorStep()]) return ( PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( @@ -399,7 +477,3 @@ def make_xvla_libero_pre_post_processors() -> tuple[ steps=post_processor_steps, ), ) - -__all__ = [ - "XVLAAddDomainIdProcessorStep", -] \ No newline at end of file diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index ae799edd9..cb17bd341 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -25,7 +25,6 @@ import torch from torch import Tensor from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature -from lerobot.datasets.factory import IMAGENET_STATS from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.utils.constants import ACTION @@ -304,18 +303,14 @@ class _NormalizationMixin: ValueError: If an unsupported normalization mode is encountered. """ norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY) - if ( - norm_mode == NormalizationMode.IDENTITY - or key not in self._tensor_stats - and norm_mode != NormalizationMode.IMAGENET - ): + if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats: return tensor + if norm_mode not in ( NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX, NormalizationMode.QUANTILES, NormalizationMode.QUANTILE10, - NormalizationMode.IMAGENET, ): raise ValueError(f"Unsupported normalization mode: {norm_mode}") @@ -325,22 +320,8 @@ class _NormalizationMixin: if first_stat.device != tensor.device or first_stat.dtype != tensor.dtype: self.to(device=tensor.device, dtype=tensor.dtype) - if norm_mode == NormalizationMode.IMAGENET: - mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype) - std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype) - # Expand mean/std to match tensor dims (e.g., BCHW or BNCHW) - while mean.dim() < tensor.dim(): - mean = mean.unsqueeze(0) - std = std.unsqueeze(0) - - if inverse: - # De-normalize - return (tensor * std + mean) * 255.0 - - # Normalize - return (tensor - mean) / std - stats = self._tensor_stats[key] + if norm_mode == NormalizationMode.MEAN_STD: mean = stats.get("mean", None) std = stats.get("std", None) @@ -576,4 +557,4 @@ def hotswap_stats( step.stats = stats # Re-initialize tensor_stats on the correct device. step._tensor_stats = to_tensor(stats, device=step.device, dtype=step.dtype) # type: ignore[assignment] - return rp + return rp \ No newline at end of file