mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
add imagenet as a norm type
This commit is contained in:
@@ -37,6 +37,7 @@ class NormalizationMode(str, Enum):
|
||||
IDENTITY = "IDENTITY"
|
||||
QUANTILES = "QUANTILES"
|
||||
QUANTILE10 = "QUANTILE10"
|
||||
IMAGENET = "IMAGENET"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -319,7 +319,9 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
return total_loss, log_dict
|
||||
|
||||
def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
print("get_action_chunk")
|
||||
inputs = self._build_model_inputs(batch)
|
||||
breakpoint()
|
||||
actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps)
|
||||
actions = self._trim_action_dim(actions)
|
||||
return actions
|
||||
|
||||
@@ -27,7 +27,7 @@ from torch import Tensor
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
from lerobot.datasets.factory import IMAGENET_STATS
|
||||
from .converters import from_tensor_to_numpy, to_tensor
|
||||
from .core import EnvTransition, PolicyAction, TransitionKey
|
||||
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry
|
||||
@@ -303,14 +303,15 @@ class _NormalizationMixin:
|
||||
ValueError: If an unsupported normalization mode is encountered.
|
||||
"""
|
||||
norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY)
|
||||
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats:
|
||||
breakpoint()
|
||||
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats and norm_mode != NormalizationMode.IMAGENET:
|
||||
return tensor
|
||||
|
||||
if norm_mode not in (
|
||||
NormalizationMode.MEAN_STD,
|
||||
NormalizationMode.MIN_MAX,
|
||||
NormalizationMode.QUANTILES,
|
||||
NormalizationMode.QUANTILE10,
|
||||
NormalizationMode.IMAGENET,
|
||||
):
|
||||
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
|
||||
|
||||
@@ -320,6 +321,22 @@ class _NormalizationMixin:
|
||||
if first_stat.device != tensor.device or first_stat.dtype != tensor.dtype:
|
||||
self.to(device=tensor.device, dtype=tensor.dtype)
|
||||
|
||||
if norm_mode == NormalizationMode.IMAGENET:
|
||||
mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype)
|
||||
std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype)
|
||||
|
||||
# Expand mean/std to match tensor dims (e.g., BCHW or BNCHW)
|
||||
while mean.dim() < tensor.dim():
|
||||
mean = mean.unsqueeze(0)
|
||||
std = std.unsqueeze(0)
|
||||
|
||||
if inverse:
|
||||
# De-normalize
|
||||
return (tensor * std + mean) * 255.0
|
||||
|
||||
# Normalize
|
||||
return (tensor / 255.0 - mean) / std
|
||||
|
||||
stats = self._tensor_stats[key]
|
||||
if norm_mode == NormalizationMode.MEAN_STD:
|
||||
mean = stats.get("mean", None)
|
||||
|
||||
89
test_3.py
Normal file
89
test_3.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
# from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
torch.manual_seed(42)
|
||||
random.seed(42)
|
||||
np.random.seed(42)
|
||||
observation_height: int = 224
|
||||
observation_width: int = 224 # todo: jadechoghari, image size is different for the two models
|
||||
# create an observation dict
|
||||
OBS = {
|
||||
f"{OBS_IMAGES}.image": torch.randn(1, 3, observation_height, observation_width),
|
||||
f"{OBS_IMAGES}.image2": torch.randn(1, 3, observation_height, observation_width),
|
||||
OBS_STATE: torch.randn(1, 9), # ONLY if OBS_STATE is already a string
|
||||
"task": "put the object in the box",
|
||||
}
|
||||
|
||||
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1)
|
||||
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1)
|
||||
def fake_rgb(H, W):
|
||||
arr = np.random.randint(0, 255, (H, W, 3), dtype=np.uint8)
|
||||
t = torch.from_numpy(arr).permute(2, 0, 1) # CHW
|
||||
t = t.unsqueeze(0).float()
|
||||
# normalize pixel to imagenet
|
||||
return t
|
||||
|
||||
|
||||
OBS[f"{OBS_IMAGES}.image"] = fake_rgb(observation_height, observation_width)
|
||||
OBS[f"{OBS_IMAGES}.image2"] = fake_rgb(observation_height, observation_width)
|
||||
|
||||
cfg = PreTrainedConfig.from_pretrained("/raid/jade/models/xvla-libero-og_migrated")
|
||||
cfg.pretrained_path = "/raid/jade/models/xvla-libero-og_migrated"
|
||||
env_cfg = make_env_config("libero", task="libero_spatial")
|
||||
policy = make_policy(
|
||||
cfg=cfg,
|
||||
env_cfg=env_cfg,
|
||||
)
|
||||
|
||||
policy.eval()
|
||||
|
||||
preprocessor_overrides = {
|
||||
"device_processor": {"device": str(cfg.device)},
|
||||
}
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg,
|
||||
pretrained_path=cfg.pretrained_path,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
)
|
||||
|
||||
observation = preprocessor(OBS)
|
||||
inputs = policy._build_model_inputs(observation)
|
||||
breakpoint()
|
||||
|
||||
#### now the og model ###########################################################
|
||||
from xvla.models.processing_xvla import XVLAProcessor
|
||||
|
||||
processor = XVLAProcessor.from_pretrained("/raid/jade/models/xvla-libero", num_views=2)
|
||||
inputs_1 = processor([OBS[f"{OBS_IMAGES}.image"], OBS[f"{OBS_IMAGES}.image2"]], OBS["task"])
|
||||
|
||||
for k in inputs.keys() & inputs_1.keys(): # intersection of keys
|
||||
a = inputs[k]
|
||||
b = inputs_1[k].to("cuda")
|
||||
|
||||
print(f"\n🔎 Key: {k}")
|
||||
|
||||
# Check shape
|
||||
print(" shape:", a.shape, b.shape)
|
||||
|
||||
# Check if close
|
||||
if torch.allclose(a, b, atol=1e-5, rtol=1e-5):
|
||||
print(" ✔️ tensors are equal (allclose)")
|
||||
else:
|
||||
diff = torch.abs(a - b)
|
||||
print(" ❌ tensors differ")
|
||||
print(" max diff:", diff.max().item())
|
||||
print(" mean diff:", diff.mean().item())
|
||||
breakpoint()
|
||||
|
||||
|
||||
# (Pdb) inputs['input_ids'].shape
|
||||
# torch.Size([1, 64])
|
||||
# (Pdb) inputs_1['input_ids'].shape
|
||||
# torch.Size([1, 50])
|
||||
# (Pdb) [0, 0, :, :4, 0]
|
||||
Reference in New Issue
Block a user