remove imagenet dependency

This commit is contained in:
Jade Choghari
2025-11-21 10:43:34 +01:00
parent 7cfe4c768f
commit 9d13b6ceea
2 changed files with 83 additions and 28 deletions

View File

@@ -21,6 +21,7 @@ import numpy as np
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.datasets.factory import IMAGENET_STATS
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
from lerobot.processor import (
@@ -265,6 +266,83 @@ class XVLAImageScaleProcessorStep(ProcessorStep):
}
@dataclass
@ProcessorStepRegistry.register(name="xvla_imagenet_normalize")
class XVLAImageNetNormalizeProcessorStep(ProcessorStep):
    """Normalize image observations using ImageNet statistics.

    Each selected image tensor is validated to lie in the [0, 1] range and is
    then replaced by ``(image - mean) / std``, where ``mean`` and ``std`` come
    from ``IMAGENET_STATS``. The incoming transition and its observation
    mapping are shallow-copied, so the caller's objects are not modified.

    Args:
        image_keys: List of observation keys that contain images to normalize.
            If None, every observation key starting with
            "observation.images." is normalized.

    Raises:
        ValueError: If image values are not in the [0, 1] range.
    """

    image_keys: list[str] | None = None

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of ``transition`` with its image observations normalized.

        Entries that are missing, or that are not ``torch.Tensor`` instances,
        are left untouched. Zero-element tensors are also passed through
        unchanged because there is nothing to validate or normalize.
        Non-floating-point tensors are converted to ``torch.float32`` first so
        that the statistics are not truncated to integers.

        Raises:
            ValueError: If any selected image has values outside [0, 1].
        """
        new_transition = transition.copy()
        obs = new_transition.get(TransitionKey.OBSERVATION, {})
        if obs is None:
            return new_transition

        # Make a copy of observations to avoid modifying the original
        obs = obs.copy()

        # Determine which keys to normalize; auto-detect when none are configured
        keys_to_normalize = self.image_keys
        if keys_to_normalize is None:
            keys_to_normalize = [k for k in obs if k.startswith("observation.images.")]

        for key in keys_to_normalize:
            if key not in obs or not isinstance(obs[key], torch.Tensor):
                continue
            tensor = obs[key]

            # min()/max() raise on zero-element tensors; nothing to do for them
            if tensor.numel() == 0:
                continue

            # Building mean/std with an integer or bool dtype would truncate them
            # to 0 (and divide by zero), so promote such inputs to float first.
            if not tensor.is_floating_point():
                tensor = tensor.to(torch.float32)

            # Validate that values are in [0, 1] range
            min_val = tensor.min().item()
            max_val = tensor.max().item()
            if min_val < 0.0 or max_val > 1.0:
                raise ValueError(
                    f"Image '{key}' has values outside [0, 1] range: "
                    f"min={min_val:.4f}, max={max_val:.4f}. "
                    f"ImageNet normalization requires input values in [0, 1]."
                )

            mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype)
            std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype)

            # Prepend singleton dims so the stats broadcast against the image
            # (e.g., BCHW or BNCHW).
            # NOTE(review): this only adds leading dims, so it assumes the stats
            # are already channel-shaped, e.g. (C, 1, 1) — TODO confirm.
            while mean.dim() < tensor.dim():
                mean = mean.unsqueeze(0)
                std = std.unsqueeze(0)

            # Normalize: (image - mean) / std
            obs[key] = (tensor - mean) / std

        new_transition[TransitionKey.OBSERVATION] = obs
        return new_transition

    def transform_features(self, features):
        """ImageNet normalization doesn't change feature structure."""
        return features

    def get_config(self) -> dict[str, Any]:
        """Return serializable configuration."""
        return {
            "image_keys": self.image_keys,
        }
@dataclass
@ProcessorStepRegistry.register(name="xvla_add_domain_id")
class XVLAAddDomainIdProcessorStep(ProcessorStep):
@@ -389,7 +467,7 @@ def make_xvla_libero_pre_post_processors() -> tuple[
"""
pre_processor_steps: list[ProcessorStep] = []
post_processor_steps: list[ProcessorStep] = []
pre_processor_steps.extend([LiberoProcessorStep(), XVLAAddDomainIdProcessorStep()])
pre_processor_steps.extend([LiberoProcessorStep(), XVLAImageNetNormalizeProcessorStep(), XVLAAddDomainIdProcessorStep()])
post_processor_steps.extend([XVLARotation6DToAxisAngleProcessorStep()])
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
@@ -399,7 +477,3 @@ def make_xvla_libero_pre_post_processors() -> tuple[
steps=post_processor_steps,
),
)
__all__ = [
"XVLAAddDomainIdProcessorStep",
]

View File

@@ -25,7 +25,6 @@ import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
from lerobot.datasets.factory import IMAGENET_STATS
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION
@@ -304,18 +303,14 @@ class _NormalizationMixin:
ValueError: If an unsupported normalization mode is encountered.
"""
norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
if (
norm_mode == NormalizationMode.IDENTITY
or key not in self._tensor_stats
and norm_mode != NormalizationMode.IMAGENET
):
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats:
return tensor
if norm_mode not in (
NormalizationMode.MEAN_STD,
NormalizationMode.MIN_MAX,
NormalizationMode.QUANTILES,
NormalizationMode.QUANTILE10,
NormalizationMode.IMAGENET,
):
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
@@ -325,22 +320,8 @@ class _NormalizationMixin:
if first_stat.device != tensor.device or first_stat.dtype != tensor.dtype:
self.to(device=tensor.device, dtype=tensor.dtype)
if norm_mode == NormalizationMode.IMAGENET:
mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype)
std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype)
# Expand mean/std to match tensor dims (e.g., BCHW or BNCHW)
while mean.dim() < tensor.dim():
mean = mean.unsqueeze(0)
std = std.unsqueeze(0)
if inverse:
# De-normalize
return (tensor * std + mean) * 255.0
# Normalize
return (tensor - mean) / std
stats = self._tensor_stats[key]
if norm_mode == NormalizationMode.MEAN_STD:
mean = stats.get("mean", None)
std = stats.get("std", None)
@@ -576,4 +557,4 @@ def hotswap_stats(
step.stats = stats
# Re-initialize tensor_stats on the correct device.
step._tensor_stats = to_tensor(stats, device=step.device, dtype=step.dtype) # type: ignore[assignment]
return rp
return rp