diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index 244bee241..bee33f434 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -21,6 +21,7 @@ import numpy as np import torch from torch import Tensor +from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey @@ -43,22 +44,26 @@ class ImageProcessor: processed_obs = {} - # Handle pixels key - if "pixels" in observation: - if isinstance(observation["pixels"], dict): - imgs = {f"observation.images.{key}": img for key, img in observation["pixels"].items()} - else: - imgs = {"observation.image": observation["pixels"]} + # Copy all observations first + for key, value in observation.items(): + processed_obs[key] = value + # Handle pixels key if present + pixels = observation.get("pixels") + if pixels is not None: + # Remove pixels from processed_obs since we'll replace it with processed images + processed_obs.pop("pixels", None) + # Determine image mapping + if isinstance(pixels, dict): + imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()} + else: + imgs = {OBS_IMAGE: pixels} + + # Process each image for imgkey, img in imgs.items(): processed_img = self._process_single_image(img) processed_obs[imgkey] = processed_img - # Copy other observations unchanged - for key, value in observation.items(): - if key != "pixels": - processed_obs[key] = value - # Return new transition with processed observation new_transition = transition.copy() new_transition[TransitionKey.OBSERVATION] = processed_obs @@ -130,7 +135,7 @@ class StateProcessor: env_state = torch.from_numpy(observation["environment_state"]).float() if env_state.dim() == 1: env_state = env_state.unsqueeze(0) - processed_obs["observation.environment_state"] = env_state + processed_obs[OBS_ENV_STATE] = env_state # Remove original key del processed_obs["environment_state"] @@ -139,7 +144,7 @@ class StateProcessor: agent_pos = torch.from_numpy(observation["agent_pos"]).float() if agent_pos.dim() == 1: agent_pos = agent_pos.unsqueeze(0) - processed_obs["observation.state"] = agent_pos + processed_obs[OBS_STATE] = agent_pos # Remove original key del processed_obs["agent_pos"]