refactor(observation_processor): Improve observation processing by using constants and simplifying pixel handling

- Introduced constants for observation keys to enhance readability.
- Streamlined the handling of the "pixels" key: all observations are copied first, then the "pixels" entry is removed and replaced with the processed images.
- Updated the environment state and agent position assignments to use the new constants, improving maintainability.
This commit is contained in:
Adil Zouitine
2025-07-21 18:13:40 +02:00
parent f2b79656eb
commit 75bc44c166

View File

@@ -21,6 +21,7 @@ import numpy as np
import torch
from torch import Tensor
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
@@ -43,22 +44,26 @@ class ImageProcessor:
processed_obs = {}
# Handle pixels key
if "pixels" in observation:
if isinstance(observation["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observation["pixels"].items()}
else:
imgs = {"observation.image": observation["pixels"]}
# Copy all observations first
for key, value in observation.items():
processed_obs[key] = value
# Handle pixels key if present
pixels = observation.get("pixels")
if pixels is not None:
# Remove pixels from processed_obs since we'll replace it with processed images
processed_obs.pop("pixels", None)
# Determine image mapping
if isinstance(pixels, dict):
imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()}
else:
imgs = {OBS_IMAGE: pixels}
# Process each image
for imgkey, img in imgs.items():
processed_img = self._process_single_image(img)
processed_obs[imgkey] = processed_img
# Copy other observations unchanged
for key, value in observation.items():
if key != "pixels":
processed_obs[key] = value
# Return new transition with processed observation
new_transition = transition.copy()
new_transition[TransitionKey.OBSERVATION] = processed_obs
@@ -130,7 +135,7 @@ class StateProcessor:
env_state = torch.from_numpy(observation["environment_state"]).float()
if env_state.dim() == 1:
env_state = env_state.unsqueeze(0)
processed_obs["observation.environment_state"] = env_state
processed_obs[OBS_ENV_STATE] = env_state
# Remove original key
del processed_obs["environment_state"]
@@ -139,7 +144,7 @@ class StateProcessor:
agent_pos = torch.from_numpy(observation["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
processed_obs["observation.state"] = agent_pos
processed_obs[OBS_STATE] = agent_pos
# Remove original key
del processed_obs["agent_pos"]