Files
lerobot-clone/src/lerobot/processor/observation_processor.py

158 lines
5.7 KiB
Python

#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import einops
import numpy as np
import torch
from torch import Tensor
from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from .pipeline import ObservationProcessor, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="observation_processor")
class VanillaObservationProcessor(ObservationProcessor):
    """
    Processes environment observations into the LeRobot format by handling both images and states.

    Image processing:
        - Converts channel-last (H, W, C) images to channel-first (C, H, W)
        - Normalizes uint8 images ([0, 255]) to float32 ([0, 1])
        - Adds a batch dimension if missing
        - Supports single images and image dictionaries

    State processing:
        - Maps 'environment_state' to the `OBS_ENV_STATE` key
        - Maps 'agent_pos' to the `OBS_STATE` key
        - Converts numpy arrays to float tensors
        - Adds a batch dimension if missing
    """

    def _process_single_image(self, img: np.ndarray) -> Tensor:
        """Convert one channel-last uint8 image (or batch of images) to a float32 channel-first tensor.

        Args:
            img: Array of shape (H, W, C) or (B, H, W, C) with dtype uint8.

        Returns:
            A contiguous float32 tensor of shape (B, C, H, W) with values scaled by 1 / 255.

        Raises:
            ValueError: If the input is not 3-D or 4-D, does not look channel-last
                (the last dimension must be smaller than both spatial dimensions),
                or is not uint8.
        """
        img_tensor = torch.from_numpy(img)

        # Add batch dimension if needed
        if img_tensor.ndim == 3:
            img_tensor = img_tensor.unsqueeze(0)

        # Check the rank explicitly so that e.g. a 2-D grayscale image fails with a
        # descriptive message instead of an opaque tuple-unpacking error below.
        if img_tensor.ndim != 4:
            raise ValueError(
                "Expected images with 3 (H, W, C) or 4 (B, H, W, C) dimensions, "
                f"but got shape {tuple(img.shape)}"
            )

        # Validate image format
        _, h, w, c = img_tensor.shape
        if not (c < h and c < w):
            raise ValueError(f"Expected channel-last images, but got shape {img_tensor.shape}")
        if img_tensor.dtype != torch.uint8:
            raise ValueError(f"Expected torch.uint8 images, but got {img_tensor.dtype}")

        # Convert to channel-first format
        img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous()

        # Convert to float32 and normalize to [0, 1]
        img_tensor = img_tensor.type(torch.float32) / 255.0
        return img_tensor

    @staticmethod
    def _process_state(state: np.ndarray) -> Tensor:
        """Convert a numpy state vector to a float tensor, adding a batch dimension to 1-D input."""
        state_tensor = torch.from_numpy(state).float()
        if state_tensor.dim() == 1:
            state_tensor = state_tensor.unsqueeze(0)
        return state_tensor

    def _process_observation(self, observation: dict) -> dict:
        """Process both image and state observations.

        Works on a shallow copy, so the caller's dictionary is not modified.

        Args:
            observation: Raw observation. The keys 'pixels' (a single image or a
                dict of named images), 'environment_state' and 'agent_pos' are
                converted and renamed; every other entry is passed through as-is.

        Returns:
            A new dictionary in which the recognised entries are replaced by tensors
            stored under `OBS_IMAGE` / `OBS_IMAGES.<name>`, `OBS_ENV_STATE` and `OBS_STATE`.
        """
        processed_obs = observation.copy()

        if "pixels" in processed_obs:
            pixels = processed_obs.pop("pixels")
            if isinstance(pixels, dict):
                imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()}
            else:
                imgs = {OBS_IMAGE: pixels}

            for imgkey, img in imgs.items():
                processed_obs[imgkey] = self._process_single_image(img)

        # Both state entries go through the same conversion; only the target key differs.
        for raw_key, target_key in (("environment_state", OBS_ENV_STATE), ("agent_pos", OBS_STATE)):
            if raw_key in processed_obs:
                processed_obs[target_key] = self._process_state(processed_obs.pop(raw_key))

        return processed_obs

    def observation(self, observation):
        """Return the processed form of `observation` (see `_process_observation`)."""
        return self._process_observation(observation)

    def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
        """Transforms feature keys to a standardized contract.

        This method handles several renaming patterns:
            - Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
            - Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').
            - Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1').
            - Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1').
            - environment_state -> OBS_ENV_STATE,
            - agent_pos -> OBS_STATE,
            - observation.environment_state -> OBS_ENV_STATE,
            - observation.agent_pos -> OBS_STATE

        Args:
            features: Mapping from feature name to its `PolicyFeature`. It is
                renamed in place.

        Returns:
            The same `features` dictionary, with matching keys renamed.
        """
        exact_pairs = {
            "pixels": OBS_IMAGE,
            "environment_state": OBS_ENV_STATE,
            "agent_pos": OBS_STATE,
        }
        prefix_pairs = {
            "pixels.": f"{OBS_IMAGES}.",
        }

        # Iterate over a snapshot of the keys because the dict is modified inside the loop.
        for key in list(features):
            # An optional leading "observation." namespace is ignored for matching, which
            # covers both the bare and the prefixed spelling of every pattern at once.
            bare_key = key.removeprefix("observation.")

            renamed = None
            for old_prefix, new_prefix in prefix_pairs.items():
                if bare_key.startswith(old_prefix):
                    renamed = f"{new_prefix}{bare_key[len(old_prefix):]}"
                    break
            if renamed is None:
                renamed = exact_pairs.get(bare_key)

            if renamed is not None:
                features[renamed] = features.pop(key)

        return features