refactor(observation): Streamline observation preprocessing and remove unused processor methods

- Updated the `preprocess_observation` function to enhance image handling and ensure proper tensor formatting.
- Removed the `RobotProcessor` and associated transition handling from the `rollout` function, simplifying the observation processing flow.
- Integrated direct calls to `preprocess_observation` for improved clarity and efficiency in the evaluation script.
This commit is contained in:
Adil Zouitine
2025-08-05 10:32:56 +02:00
parent 8077456c00
commit 05bd18f453
2 changed files with 53 additions and 41 deletions

View File

@@ -16,8 +16,10 @@
import warnings
from typing import Any
import einops
import gymnasium as gym
import numpy as np
import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
@@ -26,40 +28,62 @@ from lerobot.utils.utils import get_channel_first_image_shape
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
"""Convert environment observation to LeRobot format observation.
This function uses the new pipeline system internally but maintains
backward compatibility with the original interface.
Args:
observations: Dictionary of observation batches from a Gym vector environment.
Returns:
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
"""
from lerobot.processor import RobotProcessor, TransitionKey, VanillaObservationProcessor
# map to expected inputs for the policy
return_observations = {}
if "pixels" in observations:
if isinstance(observations["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
else:
imgs = {"observation.image": observations["pixels"]}
# Create processor with observation processor
processor = RobotProcessor([VanillaObservationProcessor()])
for imgkey, img in imgs.items():
# TODO(aliberts, rcadene): use transforms.ToTensor()?
img = torch.from_numpy(img)
# Create transition dictionary and process
transition = {
TransitionKey.OBSERVATION: observations,
TransitionKey.ACTION: None,
TransitionKey.REWARD: None,
TransitionKey.DONE: None,
TransitionKey.TRUNCATED: None,
TransitionKey.INFO: None,
TransitionKey.COMPLEMENTARY_DATA: None,
}
result = processor(transition)
# When preprocessing observations in a non-vectorized environment, we need to add a batch dimension.
# This is the case for human-in-the-loop RL where there is only one environment.
if img.ndim == 3:
img = img.unsqueeze(0)
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
# Extract and return the processed observation
return result[TransitionKey.OBSERVATION]
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
return_observations[imgkey] = img
if "environment_state" in observations:
env_state = torch.from_numpy(observations["environment_state"]).float()
if env_state.dim() == 1:
env_state = env_state.unsqueeze(0)
return_observations["observation.environment_state"] = env_state
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing the requirement for "agent_pos"
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations["observation.state"] = agent_pos
return return_observations
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
# (need to externalize normalization from policies)
# (need to also refactor preprocess_observation and externalize normalization from policies)
policy_features = {}
for key, ft in env_cfg.features.items():
if ft.type is FeatureType.VISUAL:

View File

@@ -68,11 +68,10 @@ from tqdm import trange
from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.processor import RobotProcessor, TransitionKey, VanillaObservationProcessor
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
@@ -129,16 +128,6 @@ def rollout(
if render_callback is not None:
render_callback(env)
# Create observation processing processor
# NOTE: During environment interaction, we skip batch dictionary conversion
# since that format is only needed for loss computation during training.
# Using identity functions to avoid unnecessary format transformations.
obs_processor = RobotProcessor(
[VanillaObservationProcessor()],
to_transition=lambda x: x,
to_output=lambda x: x,
)
all_observations = []
all_actions = []
all_rewards = []
@@ -158,13 +147,10 @@ def rollout(
check_env_attributes_and_types(env)
while not np.all(done):
# Convert numpy arrays to tensors and rename dictionary keys to the LeRobot policy format.
transition = (observation, None, None, None, None, None, None)
processed_transition = obs_processor(transition)
observation = processed_transition[TransitionKey.OBSERVATION]
observation = preprocess_observation(observation)
if return_observations:
all_observations.append(deepcopy(observation))
# TODO(azouitine): Move this in processor side
observation = {
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
}
@@ -209,9 +195,7 @@ def rollout(
# Track the final observation.
if return_observations:
transition = (observation, None, None, None, None, None, None)
processed_transition = obs_processor(transition)
observation = processed_transition[TransitionKey.OBSERVATION]
observation = preprocess_observation(observation)
all_observations.append(deepcopy(observation))
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
@@ -517,6 +501,10 @@ def eval_main(cfg: EvalPipelineConfig):
logging.info("End of eval")
if __name__ == "__main__":
def main():
init_logging()
eval_main()
if __name__ == "__main__":
main()