mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 11:51:25 +00:00
refactor(pipeline): minor improvements (#1684)
* chore(pipeline): remove unused features + device torch + envtransition keys * refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationPRocessor * refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code * test(pipeline): fix broken test after refactors * docs(pipeline): update docstrings VanillaObservationProcessor * chore(pipeline): move None check to base pipeline classes
This commit is contained in:
@@ -13,8 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
@@ -23,52 +22,27 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageProcessor:
|
||||
"""Process image observations from environment format to policy format.
|
||||
|
||||
Converts images from:
|
||||
- Channel-last (H, W, C) to channel-first (C, H, W)
|
||||
- uint8 [0, 255] to float32 [0, 1]
|
||||
- Adds batch dimension if needed
|
||||
- Handles both single images and dictionaries of images
|
||||
@ProcessorStepRegistry.register(name="observation_processor")
|
||||
class VanillaObservationProcessor(ObservationProcessor):
|
||||
"""
|
||||
Processes environment observations into the LeRobot format by handling both images and states.
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
Image processing:
|
||||
- Converts channel-last (H, W, C) images to channel-first (C, H, W)
|
||||
- Normalizes uint8 images ([0, 255]) to float32 ([0, 1])
|
||||
- Adds a batch dimension if missing
|
||||
- Supports single images and image dictionaries
|
||||
|
||||
if observation is None:
|
||||
return transition
|
||||
|
||||
processed_obs = {}
|
||||
|
||||
# Copy all observations first
|
||||
for key, value in observation.items():
|
||||
processed_obs[key] = value
|
||||
|
||||
# Handle pixels key if present
|
||||
pixels = observation.get("pixels")
|
||||
if pixels is not None:
|
||||
# Remove pixels from processed_obs since we'll replace it with processed images
|
||||
processed_obs.pop("pixels", None)
|
||||
# Determine image mapping
|
||||
if isinstance(pixels, dict):
|
||||
imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()}
|
||||
else:
|
||||
imgs = {OBS_IMAGE: pixels}
|
||||
|
||||
# Process each image
|
||||
for imgkey, img in imgs.items():
|
||||
processed_img = self._process_single_image(img)
|
||||
processed_obs[imgkey] = processed_img
|
||||
|
||||
# Return new transition with processed observation
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
State processing:
|
||||
- Maps 'environment_state' to observation.environment_state
|
||||
- Maps 'agent_pos' to observation.state
|
||||
- Converts numpy arrays to tensors
|
||||
- Adds a batch dimension if missing
|
||||
"""
|
||||
|
||||
def _process_single_image(self, img: np.ndarray) -> Tensor:
|
||||
"""Process a single image array."""
|
||||
@@ -95,173 +69,89 @@ class ImageProcessor:
|
||||
|
||||
return img_tensor
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return state dictionary (empty for this processor)."""
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load state dictionary (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset processor state (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transforms:
|
||||
pixels -> OBS_IMAGE,
|
||||
observation.pixels -> OBS_IMAGE,
|
||||
pixels.<cam> -> OBS_IMAGES.<cam>,
|
||||
observation.pixels.<cam> -> OBS_IMAGES.<cam>
|
||||
def _process_observation(self, observation):
|
||||
"""
|
||||
Processes both image and state observations.
|
||||
"""
|
||||
if "pixels" in features:
|
||||
features[OBS_IMAGE] = features.pop("pixels")
|
||||
if "observation.pixels" in features:
|
||||
features[OBS_IMAGE] = features.pop("observation.pixels")
|
||||
|
||||
prefixes = ("pixels.", "observation.pixels.")
|
||||
for key in list(features.keys()):
|
||||
for p in prefixes:
|
||||
if key.startswith(p):
|
||||
suffix = key[len(p) :]
|
||||
features[f"{OBS_IMAGES}.{suffix}"] = features.pop(key)
|
||||
break
|
||||
return features
|
||||
processed_obs = observation.copy()
|
||||
|
||||
if "pixels" in processed_obs:
|
||||
pixels = processed_obs.pop("pixels")
|
||||
|
||||
@dataclass
|
||||
class StateProcessor:
|
||||
"""Process state observations from environment format to policy format.
|
||||
if isinstance(pixels, dict):
|
||||
imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()}
|
||||
else:
|
||||
imgs = {OBS_IMAGE: pixels}
|
||||
|
||||
Handles:
|
||||
- environment_state -> observation.environment_state
|
||||
- agent_pos -> observation.state
|
||||
- Converts numpy arrays to tensors
|
||||
- Adds batch dimension if needed
|
||||
"""
|
||||
for imgkey, img in imgs.items():
|
||||
processed_obs[imgkey] = self._process_single_image(img)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
|
||||
if observation is None:
|
||||
return transition
|
||||
|
||||
processed_obs = dict(observation) # Copy existing observations
|
||||
|
||||
# Process environment_state
|
||||
if "environment_state" in observation:
|
||||
env_state = torch.from_numpy(observation["environment_state"]).float()
|
||||
if "environment_state" in processed_obs:
|
||||
env_state_np = processed_obs.pop("environment_state")
|
||||
env_state = torch.from_numpy(env_state_np).float()
|
||||
if env_state.dim() == 1:
|
||||
env_state = env_state.unsqueeze(0)
|
||||
processed_obs[OBS_ENV_STATE] = env_state
|
||||
# Remove original key
|
||||
del processed_obs["environment_state"]
|
||||
|
||||
# Process agent_pos
|
||||
if "agent_pos" in observation:
|
||||
agent_pos = torch.from_numpy(observation["agent_pos"]).float()
|
||||
if "agent_pos" in processed_obs:
|
||||
agent_pos_np = processed_obs.pop("agent_pos")
|
||||
agent_pos = torch.from_numpy(agent_pos_np).float()
|
||||
if agent_pos.dim() == 1:
|
||||
agent_pos = agent_pos.unsqueeze(0)
|
||||
processed_obs[OBS_STATE] = agent_pos
|
||||
# Remove original key
|
||||
del processed_obs["agent_pos"]
|
||||
|
||||
# Return new transition with processed observation
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = processed_obs
|
||||
return new_transition
|
||||
return processed_obs
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return state dictionary (empty for this processor)."""
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load state dictionary (no-op for this processor)."""
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset processor state (no-op for this processor)."""
|
||||
pass
|
||||
def observation(self, observation):
|
||||
return self._process_observation(observation)
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transforms:
|
||||
environment_state -> OBS_ENV_STATE,
|
||||
agent_pos -> OBS_STATE,
|
||||
observation.environment_state -> OBS_ENV_STATE,
|
||||
observation.agent_pos -> OBS_STATE
|
||||
"""Transforms feature keys to a standardized contract.
|
||||
|
||||
This method handles several renaming patterns:
|
||||
- Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
|
||||
- Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').
|
||||
- Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1').
|
||||
- Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1').
|
||||
- environment_state -> OBS_ENV_STATE,
|
||||
- agent_pos -> OBS_STATE,
|
||||
- observation.environment_state -> OBS_ENV_STATE,
|
||||
- observation.agent_pos -> OBS_STATE
|
||||
"""
|
||||
pairs = (
|
||||
("environment_state", OBS_ENV_STATE),
|
||||
("agent_pos", OBS_STATE),
|
||||
)
|
||||
for old, new in pairs:
|
||||
if old in features:
|
||||
features[new] = features.pop(old)
|
||||
prefixed = f"observation.{old}"
|
||||
if prefixed in features:
|
||||
features[new] = features.pop(prefixed)
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="observation_processor")
|
||||
class VanillaObservationProcessor:
|
||||
"""Complete observation processor that combines image and state processing.
|
||||
|
||||
This processor replicates the functionality of the original preprocess_observation
|
||||
function but in a modular, composable way that fits into the pipeline architecture.
|
||||
"""
|
||||
|
||||
image_processor: ImageProcessor = field(default_factory=ImageProcessor)
|
||||
state_processor: StateProcessor = field(default_factory=StateProcessor)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# First process images
|
||||
transition = self.image_processor(transition)
|
||||
# Then process state
|
||||
transition = self.state_processor(transition)
|
||||
return transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {
|
||||
"image_processor": self.image_processor.get_config(),
|
||||
"state_processor": self.state_processor.get_config(),
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return state dictionary."""
|
||||
state = {}
|
||||
state.update({f"image_processor.{k}": v for k, v in self.image_processor.state_dict().items()})
|
||||
state.update({f"state_processor.{k}": v for k, v in self.state_processor.state_dict().items()})
|
||||
return state
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load state dictionary."""
|
||||
image_state = {
|
||||
k.replace("image_processor.", ""): v for k, v in state.items() if k.startswith("image_processor.")
|
||||
}
|
||||
state_state = {
|
||||
k.replace("state_processor.", ""): v for k, v in state.items() if k.startswith("state_processor.")
|
||||
}
|
||||
|
||||
self.image_processor.load_state_dict(image_state)
|
||||
self.state_processor.load_state_dict(state_state)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset processor state."""
|
||||
self.image_processor.reset()
|
||||
self.state_processor.reset()
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features = self.image_processor.feature_contract(features)
|
||||
features = self.state_processor.feature_contract(features)
|
||||
exact_pairs = {
|
||||
"pixels": OBS_IMAGE,
|
||||
"environment_state": OBS_ENV_STATE,
|
||||
"agent_pos": OBS_STATE,
|
||||
}
|
||||
|
||||
prefix_pairs = {
|
||||
"pixels.": f"{OBS_IMAGES}.",
|
||||
}
|
||||
|
||||
for key in list(features.keys()):
|
||||
matched_prefix = False
|
||||
for old_prefix, new_prefix in prefix_pairs.items():
|
||||
prefixed_old = f"observation.{old_prefix}"
|
||||
if key.startswith(prefixed_old):
|
||||
suffix = key[len(prefixed_old) :]
|
||||
features[f"{new_prefix}{suffix}"] = features.pop(key)
|
||||
matched_prefix = True
|
||||
break
|
||||
|
||||
if key.startswith(old_prefix):
|
||||
suffix = key[len(old_prefix) :]
|
||||
features[f"{new_prefix}{suffix}"] = features.pop(key)
|
||||
matched_prefix = True
|
||||
break
|
||||
|
||||
if matched_prefix:
|
||||
continue
|
||||
|
||||
for old, new in exact_pairs.items():
|
||||
if key == old or key == f"observation.{old}":
|
||||
if key in features:
|
||||
features[new] = features.pop(key)
|
||||
break
|
||||
|
||||
return features
|
||||
|
||||
Reference in New Issue
Block a user