mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
remove imagenet dependency
This commit is contained in:
@@ -21,6 +21,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.datasets.factory import IMAGENET_STATS
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
|
||||
from lerobot.processor import (
|
||||
@@ -265,6 +266,83 @@ class XVLAImageScaleProcessorStep(ProcessorStep):
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="xvla_imagenet_normalize")
|
||||
class XVLAImageNetNormalizeProcessorStep(ProcessorStep):
|
||||
"""Normalize image observations using ImageNet statistics.
|
||||
|
||||
This processor step applies ImageNet normalization (mean and std) to image observations.
|
||||
It validates that input values are in the [0, 1] range before normalizing.
|
||||
|
||||
The normalization formula is: (image - mean) / std
|
||||
|
||||
Args:
|
||||
image_keys: List of observation keys that contain images to normalize.
|
||||
If None, will automatically detect keys starting with "observation.images."
|
||||
|
||||
Raises:
|
||||
ValueError: If image values are not in the [0, 1] range.
|
||||
"""
|
||||
|
||||
image_keys: list[str] | None = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Normalize image observations using ImageNet statistics."""
|
||||
new_transition = transition.copy()
|
||||
obs = new_transition.get(TransitionKey.OBSERVATION, {})
|
||||
if obs is None:
|
||||
return new_transition
|
||||
|
||||
# Make a copy of observations to avoid modifying the original
|
||||
obs = obs.copy()
|
||||
|
||||
# Determine which keys to normalize
|
||||
keys_to_normalize = self.image_keys
|
||||
if keys_to_normalize is None:
|
||||
# Auto-detect image keys
|
||||
keys_to_normalize = [k for k in obs if k.startswith("observation.images.")]
|
||||
|
||||
# Normalize each image
|
||||
for key in keys_to_normalize:
|
||||
if key in obs and isinstance(obs[key], torch.Tensor):
|
||||
tensor = obs[key]
|
||||
|
||||
# Validate that values are in [0, 1] range
|
||||
min_val = tensor.min().item()
|
||||
max_val = tensor.max().item()
|
||||
if min_val < 0.0 or max_val > 1.0:
|
||||
raise ValueError(
|
||||
f"Image '{key}' has values outside [0, 1] range: "
|
||||
f"min={min_val:.4f}, max={max_val:.4f}. "
|
||||
f"ImageNet normalization requires input values in [0, 1]."
|
||||
)
|
||||
|
||||
# Apply ImageNet normalization
|
||||
mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype)
|
||||
std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype)
|
||||
|
||||
# Expand mean/std to match tensor dims (e.g., BCHW or BNCHW)
|
||||
while mean.dim() < tensor.dim():
|
||||
mean = mean.unsqueeze(0)
|
||||
std = std.unsqueeze(0)
|
||||
|
||||
# Normalize: (image - mean) / std
|
||||
obs[key] = (tensor - mean) / std
|
||||
|
||||
new_transition[TransitionKey.OBSERVATION] = obs
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features):
|
||||
"""ImageNet normalization doesn't change feature structure."""
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return serializable configuration."""
|
||||
return {
|
||||
"image_keys": self.image_keys,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="xvla_add_domain_id")
|
||||
class XVLAAddDomainIdProcessorStep(ProcessorStep):
|
||||
@@ -389,7 +467,7 @@ def make_xvla_libero_pre_post_processors() -> tuple[
|
||||
"""
|
||||
pre_processor_steps: list[ProcessorStep] = []
|
||||
post_processor_steps: list[ProcessorStep] = []
|
||||
pre_processor_steps.extend([LiberoProcessorStep(), XVLAAddDomainIdProcessorStep()])
|
||||
pre_processor_steps.extend([LiberoProcessorStep(), XVLAImageNetNormalizeProcessorStep(), XVLAAddDomainIdProcessorStep()])
|
||||
post_processor_steps.extend([XVLARotation6DToAxisAngleProcessorStep()])
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
@@ -399,7 +477,3 @@ def make_xvla_libero_pre_post_processors() -> tuple[
|
||||
steps=post_processor_steps,
|
||||
),
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"XVLAAddDomainIdProcessorStep",
|
||||
]
|
||||
@@ -25,7 +25,6 @@ import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.datasets.factory import IMAGENET_STATS
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
@@ -304,18 +303,14 @@ class _NormalizationMixin:
|
||||
ValueError: If an unsupported normalization mode is encountered.
|
||||
"""
|
||||
norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
|
||||
if (
|
||||
norm_mode == NormalizationMode.IDENTITY
|
||||
or key not in self._tensor_stats
|
||||
and norm_mode != NormalizationMode.IMAGENET
|
||||
):
|
||||
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats:
|
||||
return tensor
|
||||
|
||||
if norm_mode not in (
|
||||
NormalizationMode.MEAN_STD,
|
||||
NormalizationMode.MIN_MAX,
|
||||
NormalizationMode.QUANTILES,
|
||||
NormalizationMode.QUANTILE10,
|
||||
NormalizationMode.IMAGENET,
|
||||
):
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
@@ -325,22 +320,8 @@ class _NormalizationMixin:
|
||||
if first_stat.device != tensor.device or first_stat.dtype != tensor.dtype:
|
||||
self.to(device=tensor.device, dtype=tensor.dtype)
|
||||
|
||||
if norm_mode == NormalizationMode.IMAGENET:
|
||||
mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype)
|
||||
std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype)
|
||||
# Expand mean/std to match tensor dims (e.g., BCHW or BNCHW)
|
||||
while mean.dim() < tensor.dim():
|
||||
mean = mean.unsqueeze(0)
|
||||
std = std.unsqueeze(0)
|
||||
|
||||
if inverse:
|
||||
# De-normalize
|
||||
return (tensor * std + mean) * 255.0
|
||||
|
||||
# Normalize
|
||||
return (tensor - mean) / std
|
||||
|
||||
stats = self._tensor_stats[key]
|
||||
|
||||
if norm_mode == NormalizationMode.MEAN_STD:
|
||||
mean = stats.get("mean", None)
|
||||
std = stats.get("std", None)
|
||||
@@ -576,4 +557,4 @@ def hotswap_stats(
|
||||
step.stats = stats
|
||||
# Re-initialize tensor_stats on the correct device.
|
||||
step._tensor_stats = to_tensor(stats, device=step.device, dtype=step.dtype) # type: ignore[assignment]
|
||||
return rp
|
||||
return rp
|
||||
Reference in New Issue
Block a user