mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 21:01:26 +00:00
refactor(observation): Streamline observation preprocessing and remove unused processor methods
- Updated the `preprocess_observation` function to enhance image handling and ensure proper tensor formatting. - Removed the `RobotProcessor` and associated transition handling from the `rollout` function, simplifying the observation processing flow. - Integrated direct calls to `preprocess_observation` for improved clarity and efficiency in the evaluation script.
This commit is contained in:
@@ -16,8 +16,10 @@
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
@@ -26,40 +28,62 @@ from lerobot.utils.utils import get_channel_first_image_shape
|
||||
|
||||
|
||||
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
|
||||
This function uses the new pipeline system internally but maintains
|
||||
backward compatibility with the original interface.
|
||||
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
Returns:
|
||||
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
||||
"""
|
||||
from lerobot.processor import RobotProcessor, TransitionKey, VanillaObservationProcessor
|
||||
# map to expected inputs for the policy
|
||||
return_observations = {}
|
||||
if "pixels" in observations:
|
||||
if isinstance(observations["pixels"], dict):
|
||||
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||
else:
|
||||
imgs = {"observation.image": observations["pixels"]}
|
||||
|
||||
# Create processor with observation processor
|
||||
processor = RobotProcessor([VanillaObservationProcessor()])
|
||||
for imgkey, img in imgs.items():
|
||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||
img = torch.from_numpy(img)
|
||||
|
||||
# Create transition dictionary and process
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: observations,
|
||||
TransitionKey.ACTION: None,
|
||||
TransitionKey.REWARD: None,
|
||||
TransitionKey.DONE: None,
|
||||
TransitionKey.TRUNCATED: None,
|
||||
TransitionKey.INFO: None,
|
||||
TransitionKey.COMPLEMENTARY_DATA: None,
|
||||
}
|
||||
result = processor(transition)
|
||||
# When preprocessing observations in a non-vectorized environment, we need to add a batch dimension.
|
||||
# This is the case for human-in-the-loop RL where there is only one environment.
|
||||
if img.ndim == 3:
|
||||
img = img.unsqueeze(0)
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
|
||||
# Extract and return the processed observation
|
||||
return result[TransitionKey.OBSERVATION]
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
return_observations[imgkey] = img
|
||||
|
||||
if "environment_state" in observations:
|
||||
env_state = torch.from_numpy(observations["environment_state"]).float()
|
||||
if env_state.dim() == 1:
|
||||
env_state = env_state.unsqueeze(0)
|
||||
|
||||
return_observations["observation.environment_state"] = env_state
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
|
||||
if agent_pos.dim() == 1:
|
||||
agent_pos = agent_pos.unsqueeze(0)
|
||||
return_observations["observation.state"] = agent_pos
|
||||
|
||||
return return_observations
|
||||
|
||||
|
||||
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
||||
# (need to externalize normalization from policies)
|
||||
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
||||
policy_features = {}
|
||||
for key, ft in env_cfg.features.items():
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
|
||||
@@ -68,11 +68,10 @@ from tqdm import trange
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types
|
||||
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.processor import RobotProcessor, TransitionKey, VanillaObservationProcessor
|
||||
from lerobot.utils.io_utils import write_video
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.utils import (
|
||||
@@ -129,16 +128,6 @@ def rollout(
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
|
||||
# Create observation processing processor
|
||||
# NOTE: During environment interaction, we skip batch dictionary conversion
|
||||
# since that format is only needed for loss computation during training.
|
||||
# Using identity functions to avoid unnecessary format transformations.
|
||||
obs_processor = RobotProcessor(
|
||||
[VanillaObservationProcessor()],
|
||||
to_transition=lambda x: x,
|
||||
to_output=lambda x: x,
|
||||
)
|
||||
|
||||
all_observations = []
|
||||
all_actions = []
|
||||
all_rewards = []
|
||||
@@ -158,13 +147,10 @@ def rollout(
|
||||
check_env_attributes_and_types(env)
|
||||
while not np.all(done):
|
||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
processed_transition = obs_processor(transition)
|
||||
observation = processed_transition[TransitionKey.OBSERVATION]
|
||||
observation = preprocess_observation(observation)
|
||||
if return_observations:
|
||||
all_observations.append(deepcopy(observation))
|
||||
|
||||
# TODO(azouitine): Move this in processor side
|
||||
observation = {
|
||||
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
|
||||
}
|
||||
@@ -209,9 +195,7 @@ def rollout(
|
||||
|
||||
# Track the final observation.
|
||||
if return_observations:
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
processed_transition = obs_processor(transition)
|
||||
observation = processed_transition[TransitionKey.OBSERVATION]
|
||||
observation = preprocess_observation(observation)
|
||||
all_observations.append(deepcopy(observation))
|
||||
|
||||
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
|
||||
@@ -517,6 +501,10 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
logging.info("End of eval")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
init_logging()
|
||||
eval_main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user