mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 11:21:27 +00:00
refactor(observation_processor): Improve observation processing by using constants and simplifying pixel handling
- Introduced constants for observation keys to enhance readability. - Streamlined the handling of the "pixels" key by copying observations first and processing images more clearly. - Updated the environment state and agent position assignments to use the new constants, improving maintainability.
This commit is contained in:
@@ -21,6 +21,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
@@ -43,22 +44,26 @@ class ImageProcessor:
|
||||
|
||||
processed_obs = {}
|
||||
|
||||
# Handle pixels key
|
||||
if "pixels" in observation:
|
||||
if isinstance(observation["pixels"], dict):
|
||||
imgs = {f"observation.images.{key}": img for key, img in observation["pixels"].items()}
|
||||
else:
|
||||
imgs = {"observation.image": observation["pixels"]}
|
||||
# Copy all observations first
|
||||
for key, value in observation.items():
|
||||
processed_obs[key] = value
|
||||
|
||||
# Handle pixels key if present
|
||||
pixels = observation.get("pixels")
|
||||
if pixels is not None:
|
||||
# Remove pixels from processed_obs since we'll replace it with processed images
|
||||
processed_obs.pop("pixels", None)
|
||||
# Determine image mapping
|
||||
if isinstance(pixels, dict):
|
||||
imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()}
|
||||
else:
|
||||
imgs = {OBS_IMAGE: pixels}
|
||||
|
||||
# Process each image
|
||||
for imgkey, img in imgs.items():
|
||||
processed_img = self._process_single_image(img)
|
||||
processed_obs[imgkey] = processed_img
|
||||
|
||||
# Copy other observations unchanged
|
||||
for key, value in observation.items():
|
||||
if key != "pixels":
|
||||
processed_obs[key] = value
|
||||
|
||||
# Return new transition with processed observation
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
@@ -130,7 +135,7 @@ class StateProcessor:
|
||||
env_state = torch.from_numpy(observation["environment_state"]).float()
|
||||
if env_state.dim() == 1:
|
||||
env_state = env_state.unsqueeze(0)
|
||||
processed_obs["observation.environment_state"] = env_state
|
||||
processed_obs[OBS_ENV_STATE] = env_state
|
||||
# Remove original key
|
||||
del processed_obs["environment_state"]
|
||||
|
||||
@@ -139,7 +144,7 @@ class StateProcessor:
|
||||
agent_pos = torch.from_numpy(observation["agent_pos"]).float()
|
||||
if agent_pos.dim() == 1:
|
||||
agent_pos = agent_pos.unsqueeze(0)
|
||||
processed_obs["observation.state"] = agent_pos
|
||||
processed_obs[OBS_STATE] = agent_pos
|
||||
# Remove original key
|
||||
del processed_obs["agent_pos"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user