[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-07-02 15:31:15 +00:00
committed by Adil Zouitine
parent f6c7287ae7
commit 769f531603
9 changed files with 485 additions and 475 deletions

View File

@@ -16,10 +16,8 @@
import warnings
from typing import Any
import einops
import gymnasium as gym
import numpy as np
import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
@@ -29,32 +27,29 @@ from lerobot.utils.utils import get_channel_first_image_shape
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
"""Convert environment observation to LeRobot format observation.
This function uses the new pipeline system internally but maintains
backward compatibility with the original interface.
Args:
observation: Dictionary of observation batches from a Gym vector environment.
Returns:
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
"""
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
from lerobot.processor.observation_processor import ObservationProcessor
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
# Create pipeline with observation processor
pipeline = RobotPipeline([ObservationProcessor()])
# Create transition tuple and process
transition = (observations, None, None, None, None, None, None)
processed_transition = pipeline(transition)
# Return processed observations
return processed_transition[TransitionIndex.OBSERVATION]
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
# (need to externalize normalization from policies)

View File

@@ -13,16 +13,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .pipeline import RobotPipeline, PipelineStep, EnvTransition
from .observation_processor import (
ImageProcessor,
StateProcessor,
ObservationProcessor,
StateProcessor,
)
from .pipeline import EnvTransition, PipelineStep, RobotPipeline
__all__ = [
"RobotPipeline",
"PipelineStep",
"PipelineStep",
"EnvTransition",
"ImageProcessor",
"StateProcessor",

View File

@@ -15,35 +15,36 @@
# limitations under the License.
from __future__ import annotations
from typing import Any
from dataclasses import dataclass, field
from typing import Any
import einops
import numpy as np
import torch
import einops
from torch import Tensor
from lerobot.processor.pipeline import EnvTransition, PipelineStep, TransitionIndex
from lerobot.processor.pipeline import EnvTransition, TransitionIndex
@dataclass
class ImageProcessor:
"""Process image observations from environment format to policy format.
Converts images from:
- Channel-last (H, W, C) to channel-first (C, H, W)
- uint8 [0, 255] to float32 [0, 1]
- Adds batch dimension if needed
- Handles both single images and dictionaries of images
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition[TransitionIndex.OBSERVATION]
if observation is None:
return transition
processed_obs = {}
# Handle pixels key
if "pixels" in observation:
if isinstance(observation["pixels"], dict):
@@ -54,12 +55,12 @@ class ImageProcessor:
for imgkey, img in imgs.items():
processed_img = self._process_single_image(img)
processed_obs[imgkey] = processed_img
# Copy other observations unchanged
for key, value in observation.items():
if key != "pixels":
processed_obs[key] = value
# Return new transition with processed observation
return (
processed_obs,
@@ -70,44 +71,44 @@ class ImageProcessor:
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
def _process_single_image(self, img: np.ndarray) -> Tensor:
"""Process a single image array."""
# Convert to tensor
img_tensor = torch.from_numpy(img)
# Add batch dimension if needed
if img_tensor.ndim == 3:
img_tensor = img_tensor.unsqueeze(0)
# Validate image format
_, h, w, c = img_tensor.shape
if not (c < h and c < w):
raise ValueError(f"Expected channel-last images, but got shape {img_tensor.shape}")
if img_tensor.dtype != torch.uint8:
raise ValueError(f"Expected torch.uint8 images, but got {img_tensor.dtype}")
# Convert to channel-first format
img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous()
# Convert to float32 and normalize to [0, 1]
img_tensor = img_tensor.type(torch.float32) / 255.0
return img_tensor
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
@@ -116,22 +117,22 @@ class ImageProcessor:
@dataclass
class StateProcessor:
"""Process state observations from environment format to policy format.
Handles:
- environment_state -> observation.environment_state
- agent_pos -> observation.state
- Converts numpy arrays to tensors
- Adds batch dimension if needed
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition[TransitionIndex.OBSERVATION]
if observation is None:
return transition
processed_obs = dict(observation) # Copy existing observations
# Process environment_state
if "environment_state" in observation:
env_state = torch.from_numpy(observation["environment_state"]).float()
@@ -140,16 +141,16 @@ class StateProcessor:
processed_obs["observation.environment_state"] = env_state
# Remove original key
del processed_obs["environment_state"]
# Process agent_pos
if "agent_pos" in observation:
agent_pos = torch.from_numpy(observation["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
processed_obs["observation.state"] = agent_pos
# Remove original key
# Remove original key
del processed_obs["agent_pos"]
# Return new transition with processed observation
return (
processed_obs,
@@ -160,19 +161,19 @@ class StateProcessor:
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary (empty for this processor)."""
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary (no-op for this processor)."""
pass
def reset(self) -> None:
"""Reset processor state (no-op for this processor)."""
pass
@@ -181,43 +182,47 @@ class StateProcessor:
@dataclass
class ObservationProcessor:
"""Complete observation processor that combines image and state processing.
This processor replicates the functionality of the original preprocess_observation
function but in a modular, composable way that fits into the pipeline architecture.
"""
image_processor: ImageProcessor = field(default_factory=ImageProcessor)
state_processor: StateProcessor = field(default_factory=StateProcessor)
def __call__(self, transition: EnvTransition) -> EnvTransition:
# First process images
transition = self.image_processor(transition)
# Then process state
transition = self.state_processor(transition)
return transition
def get_config(self) -> dict[str, Any]:
"""Return configuration for serialization."""
return {
"image_processor": self.image_processor.get_config(),
"state_processor": self.state_processor.get_config(),
}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return state dictionary."""
state = {}
state.update({f"image_processor.{k}": v for k, v in self.image_processor.state_dict().items()})
state.update({f"state_processor.{k}": v for k, v in self.state_processor.state_dict().items()})
return state
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load state dictionary."""
image_state = {k.replace("image_processor.", ""): v for k, v in state.items() if k.startswith("image_processor.")}
state_state = {k.replace("state_processor.", ""): v for k, v in state.items() if k.startswith("state_processor.")}
image_state = {
k.replace("image_processor.", ""): v for k, v in state.items() if k.startswith("image_processor.")
}
state_state = {
k.replace("state_processor.", ""): v for k, v in state.items() if k.startswith("state_processor.")
}
self.image_processor.load_state_dict(image_state)
self.state_processor.load_state_dict(state_state)
def reset(self) -> None:
"""Reset processor state."""
self.image_processor.reset()

View File

@@ -14,19 +14,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os, json
from typing import Any, Dict, Sequence, Iterable, Protocol, Optional, Tuple, Callable, Union
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from enum import IntEnum
import numpy as np
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, Protocol, Sequence, Tuple
import torch
from huggingface_hub import hf_hub_download, ModelHubMixin
from safetensors.torch import save_file, load_file
from huggingface_hub import ModelHubMixin, hf_hub_download
from safetensors.torch import load_file, save_file
class TransitionIndex(IntEnum):
"""Explicit indices for EnvTransition tuple components."""
OBSERVATION = 0
ACTION = 1
REWARD = 2
@@ -38,29 +41,28 @@ class TransitionIndex(IntEnum):
# (observation, action, reward, done, truncated, info, complementary_data)
EnvTransition = Tuple[
Any| None, # observation
Any| None, # action
float| None, # reward
bool| None, # done
bool| None, # truncated
Dict[str, Any]| None, # info
Dict[str, Any]| None, # complementary_data
Any | None, # observation
Any | None, # action
float | None, # reward
bool | None, # done
bool | None, # truncated
Dict[str, Any] | None, # info
Dict[str, Any] | None, # complementary_data
]
class PipelineStep(Protocol):
"""Structural typing interface for a single pipeline step.
A step is any callable accepting a full `EnvTransition` tuple and
returning a (possibly modified) tuple of the same structure. Implementers
are encouraged—but not required—to expose the optional helper methods
listed below. When present, these hooks let `RobotPipeline`
automatically serialise the step's configuration and learnable state using
a safe-to-share JSON + SafeTensors format.
Optional helper protocol:
* ``get_config() -> Dict[str, Any]`` User-defined JSON-serializable
* ``get_config() -> Dict[str, Any]`` User-defined JSON-serializable
configuration and state. YOU decide what to save here. This is where all
non-tensor state goes (e.g., name, counter, threshold, window_size).
The config dict will be passed to your class constructor when loading.
@@ -70,7 +72,7 @@ class PipelineStep(Protocol):
* ``load_state_dict(state)`` Inverse of ``state_dict``. Receives a dict
containing torch tensors only.
* ``reset()`` Clear internal buffers at episode boundaries.
Example separation:
- get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10}
- state_dict(): {"weights": torch.tensor(...), "running_mean": torch.tensor(...)}
@@ -78,11 +80,11 @@ class PipelineStep(Protocol):
def __call__(self, transition: EnvTransition) -> EnvTransition: ...
def get_config(self) -> Dict[str, Any]: ...
def get_config(self) -> dict[str, Any]: ...
def state_dict(self) -> Dict[str, torch.Tensor]: ...
def state_dict(self) -> dict[str, torch.Tensor]: ...
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None: ...
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: ...
def reset(self) -> None: ...
@@ -120,24 +122,25 @@ class RobotPipeline(ModelHubMixin):
pipe.push_to_hub("my-org/cartpole_pipe")
loaded = RobotPipeline.from_pretrained("my-org/cartpole_pipe")
"""
steps: Sequence[PipelineStep] = field(default_factory=list)
name: str = "RobotPipeline"
seed: Optional[int] = None
seed: int | None = None
# Pipeline-level hooks
# A hook can optionally return a modified transition. If it returns
# ``None`` the current value is left untouched.
before_step_hooks: list[Callable[[int, EnvTransition], Optional[EnvTransition]]] = field(
before_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
default_factory=list, repr=False
)
after_step_hooks: list[Callable[[int, EnvTransition], Optional[EnvTransition]]] = field(
after_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
default_factory=list, repr=False
)
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Run *transition* through every step, firing hooks on the way."""
# Basic validation with helpful error message
if not isinstance(transition, tuple) or len(transition) != 7:
raise ValueError(
@@ -168,23 +171,23 @@ class RobotPipeline(ModelHubMixin):
yield transition
_CFG_NAME = "pipeline.json"
def _save_pretrained(self, destination_path: str, **kwargs):
"""Internal save method for ModelHubMixin compatibility."""
self.save_pretrained(destination_path)
def save_pretrained(self, destination_path: str, **kwargs):
"""Serialize the pipeline definition and parameters to *destination_path*."""
os.makedirs(destination_path, exist_ok=True)
config: Dict[str, Any] = {
config: dict[str, Any] = {
"name": self.name,
"seed": self.seed,
"steps": [],
}
for step_index, pipeline_step in enumerate(self.steps):
step_entry: Dict[str, Any] = {
step_entry: dict[str, Any] = {
"class": f"{pipeline_step.__class__.__module__}.{pipeline_step.__class__.__name__}",
}
@@ -204,20 +207,20 @@ class RobotPipeline(ModelHubMixin):
json.dump(config, file_pointer, indent=2)
@classmethod
def from_pretrained(cls, source: str) -> "RobotPipeline":
def from_pretrained(cls, source: str) -> RobotPipeline:
"""Load a serialized pipeline from *source* (local path or Hugging Face Hub identifier)."""
if Path(source).is_dir():
# Local path - use it directly
base_path = Path(source)
with open(base_path / cls._CFG_NAME) as file_pointer:
config: Dict[str, Any] = json.load(file_pointer)
config: dict[str, Any] = json.load(file_pointer)
else:
# Hugging Face Hub - download all required files
# First download the config file
config_path = hf_hub_download(source, cls._CFG_NAME, repo_type="model")
with open(config_path) as file_pointer:
config: Dict[str, Any] = json.load(file_pointer)
config: dict[str, Any] = json.load(file_pointer)
# Store downloaded files in the same directory as the config
base_path = Path(config_path).parent
@@ -234,7 +237,7 @@ class RobotPipeline(ModelHubMixin):
else:
# Hugging Face Hub - download the state file
state_path = hf_hub_download(source, step_entry["state_file"], repo_type="model")
step_instance.load_state_dict(load_file(state_path))
steps.append(step_instance)
@@ -254,11 +257,11 @@ class RobotPipeline(ModelHubMixin):
return RobotPipeline(self.steps[idx], self.name, self.seed)
return self.steps[idx]
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], Optional[EnvTransition]]):
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
"""Attach fn to be executed before every pipeline step."""
self.before_step_hooks.append(fn)
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], Optional[EnvTransition]]):
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
"""Attach fn to be executed after every pipeline step."""
self.after_step_hooks.append(fn)
@@ -274,26 +277,26 @@ class RobotPipeline(ModelHubMixin):
for fn in self.reset_hooks:
fn()
def profile_steps(self, transition: EnvTransition, num_runs: int = 100) -> Dict[str, float]:
def profile_steps(self, transition: EnvTransition, num_runs: int = 100) -> dict[str, float]:
"""Profile the execution time of each step for performance optimization."""
import time
profile_results = {}
for idx, pipeline_step in enumerate(self.steps):
step_name = f"step_{idx}_{pipeline_step.__class__.__name__}"
# Warm up
for _ in range(5):
_ = pipeline_step(transition)
# Time the step
start_time = time.perf_counter()
for _ in range(num_runs):
transition = pipeline_step(transition)
end_time = time.perf_counter()
avg_time = (end_time - start_time) / num_runs * 1000 # Convert to milliseconds
profile_results[step_name] = avg_time
return profile_results
return profile_results

View File

@@ -69,12 +69,11 @@ from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.envs.factory import make_env
from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
from lerobot.processor.observation_processor import ObservationProcessor
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.processor.observation_processor import ObservationProcessor
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
@@ -130,7 +129,7 @@ def rollout(
observation, info = env.reset(seed=seeds)
if render_callback is not None:
render_callback(env)
# Create observation processing pipeline
obs_pipeline = RobotPipeline([ObservationProcessor()])

View File

@@ -22,9 +22,8 @@ from gymnasium.utils.env_checker import check_env
import lerobot
from lerobot.envs.factory import make_env, make_env_config
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
from lerobot.processor.observation_processor import ObservationProcessor
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
from tests.utils import require_env
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
@@ -50,7 +49,7 @@ def test_factory(env_name):
cfg = make_env_config(env_name)
env = make_env(cfg, n_envs=1)
obs, _ = env.reset()
# Process observation using pipeline
obs_pipeline = RobotPipeline([ObservationProcessor()])
transition = (obs, None, None, None, None, None, None)

View File

@@ -30,8 +30,6 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.utils import cycle, dataset_to_policy_features
from lerobot.envs.factory import make_env, make_env_config
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
from lerobot.processor.observation_processor import ObservationProcessor
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.act.modeling_act import ACTTemporalEnsembler
from lerobot.policies.factory import (
@@ -41,6 +39,8 @@ from lerobot.policies.factory import (
)
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor.observation_processor import ObservationProcessor
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
from lerobot.utils.random_utils import seeded_context
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel

View File

@@ -20,447 +20,447 @@ import torch
from lerobot.processor.observation_processor import (
ImageProcessor,
StateProcessor,
ObservationProcessor,
StateProcessor,
)
from lerobot.processor.pipeline import EnvTransition
def test_process_single_image():
"""Test processing a single image."""
processor = ImageProcessor()
# Create a mock image (H, W, C) format, uint8
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that the image was processed correctly
assert "observation.image" in processed_obs
processed_img = processed_obs["observation.image"]
# Check shape: should be (1, 3, 64, 64) - batch, channels, height, width
assert processed_img.shape == (1, 3, 64, 64)
# Check dtype and range
assert processed_img.dtype == torch.float32
assert processed_img.min() >= 0.0
assert processed_img.max() <= 1.0
def test_process_image_dict():
"""Test processing multiple images in a dictionary."""
processor = ImageProcessor()
# Create mock images
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
observation = {
"pixels": {
"camera1": image1,
"camera2": image2
}
}
observation = {"pixels": {"camera1": image1, "camera2": image2}}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that both images were processed
assert "observation.images.camera1" in processed_obs
assert "observation.images.camera2" in processed_obs
# Check shapes
assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32)
assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48)
def test_process_batched_image():
"""Test processing already batched images."""
processor = ImageProcessor()
# Create a batched image (B, H, W, C)
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that batch dimension is preserved
assert processed_obs["observation.image"].shape == (2, 3, 64, 64)
def test_invalid_image_format():
"""Test error handling for invalid image formats."""
processor = ImageProcessor()
# Test wrong channel order (channels first)
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
with pytest.raises(ValueError, match="Expected channel-last images"):
processor(transition)
def test_invalid_image_dtype():
"""Test error handling for invalid image dtype."""
processor = ImageProcessor()
# Test wrong dtype
image = np.random.rand(64, 64, 3).astype(np.float32)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
with pytest.raises(ValueError, match="Expected torch.uint8 images"):
processor(transition)
def test_no_pixels_in_observation():
"""Test processor when no pixels are in observation."""
processor = ImageProcessor()
observation = {"other_data": np.array([1, 2, 3])}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Should preserve other data unchanged
assert "other_data" in processed_obs
np.testing.assert_array_equal(processed_obs["other_data"], np.array([1, 2, 3]))
def test_none_observation():
"""Test processor with None observation."""
processor = ImageProcessor()
transition = (None, None, None, None, None, None, None)
result = processor(transition)
assert result == transition
def test_serialization_methods():
"""Test serialization methods."""
processor = ImageProcessor()
# Test get_config
config = processor.get_config()
assert isinstance(config, dict)
# Test state_dict
state = processor.state_dict()
assert isinstance(state, dict)
# Test load_state_dict (should not raise)
processor.load_state_dict(state)
# Test reset (should not raise)
processor.reset()
def test_process_environment_state():
"""Test processing environment_state."""
processor = StateProcessor()
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
observation = {"environment_state": env_state}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that environment_state was renamed and processed
assert "observation.environment_state" in processed_obs
assert "environment_state" not in processed_obs
processed_state = processed_obs["observation.environment_state"]
assert processed_state.shape == (1, 3) # Batch dimension added
assert processed_state.dtype == torch.float32
torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]]))
"""Test processing environment_state."""
processor = StateProcessor()
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
observation = {"environment_state": env_state}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that environment_state was renamed and processed
assert "observation.environment_state" in processed_obs
assert "environment_state" not in processed_obs
processed_state = processed_obs["observation.environment_state"]
assert processed_state.shape == (1, 3) # Batch dimension added
assert processed_state.dtype == torch.float32
torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]]))
def test_process_agent_pos():
"""Test processing agent_pos."""
processor = StateProcessor()
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that agent_pos was renamed and processed
assert "observation.state" in processed_obs
assert "agent_pos" not in processed_obs
processed_state = processed_obs["observation.state"]
assert processed_state.shape == (1, 3) # Batch dimension added
assert processed_state.dtype == torch.float32
torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]]))
"""Test processing agent_pos."""
processor = StateProcessor()
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that agent_pos was renamed and processed
assert "observation.state" in processed_obs
assert "agent_pos" not in processed_obs
processed_state = processed_obs["observation.state"]
assert processed_state.shape == (1, 3) # Batch dimension added
assert processed_state.dtype == torch.float32
torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]]))
def test_process_batched_states():
"""Test processing already batched states."""
processor = StateProcessor()
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
observation = {
"environment_state": env_state,
"agent_pos": agent_pos
}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that batch dimensions are preserved
assert processed_obs["observation.environment_state"].shape == (2, 2)
assert processed_obs["observation.state"].shape == (2, 2)
"""Test processing already batched states."""
processor = StateProcessor()
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
observation = {"environment_state": env_state, "agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that batch dimensions are preserved
assert processed_obs["observation.environment_state"].shape == (2, 2)
assert processed_obs["observation.state"].shape == (2, 2)
def test_process_both_states():
"""Test processing both environment_state and agent_pos."""
processor = StateProcessor()
env_state = np.array([1.0, 2.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
observation = {
"environment_state": env_state,
"agent_pos": agent_pos,
"other_data": "keep_me"
}
observation = {"environment_state": env_state, "agent_pos": agent_pos, "other_data": "keep_me"}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that both states were processed
assert "observation.environment_state" in processed_obs
assert "observation.state" in processed_obs
# Check that original keys were removed
assert "environment_state" not in processed_obs
assert "agent_pos" not in processed_obs
# Check that other data was preserved
assert processed_obs["other_data"] == "keep_me"
def test_no_states_in_observation():
"""Test processor when no states are in observation."""
processor = StateProcessor()
observation = {"other_data": np.array([1, 2, 3])}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Should preserve data unchanged
assert processed_obs == observation
def test_none_observation():
"""Test processor with None observation."""
processor = StateProcessor()
transition = (None, None, None, None, None, None, None)
result = processor(transition)
assert result == transition
def test_serialization_methods():
"""Test serialization methods."""
processor = StateProcessor()
# Test get_config
config = processor.get_config()
assert isinstance(config, dict)
# Test state_dict
state = processor.state_dict()
assert isinstance(state, dict)
# Test load_state_dict (should not raise)
processor.load_state_dict(state)
# Test reset (should not raise)
processor.reset()
def test_complete_observation_processing():
"""Test processing a complete observation with both images and states."""
processor = ObservationProcessor()
# Create mock data
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {
"pixels": image,
"environment_state": env_state,
"agent_pos": agent_pos,
"other_data": "preserve_me"
}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that image was processed
assert "observation.image" in processed_obs
assert processed_obs["observation.image"].shape == (1, 3, 32, 32)
# Check that states were processed
assert "observation.environment_state" in processed_obs
assert "observation.state" in processed_obs
# Check that original keys were removed
assert "pixels" not in processed_obs
assert "environment_state" not in processed_obs
assert "agent_pos" not in processed_obs
# Check that other data was preserved
assert processed_obs["other_data"] == "preserve_me"
"""Test processing a complete observation with both images and states."""
processor = ObservationProcessor()
# Create mock data
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {
"pixels": image,
"environment_state": env_state,
"agent_pos": agent_pos,
"other_data": "preserve_me",
}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
# Check that image was processed
assert "observation.image" in processed_obs
assert processed_obs["observation.image"].shape == (1, 3, 32, 32)
# Check that states were processed
assert "observation.environment_state" in processed_obs
assert "observation.state" in processed_obs
# Check that original keys were removed
assert "pixels" not in processed_obs
assert "environment_state" not in processed_obs
assert "agent_pos" not in processed_obs
# Check that other data was preserved
assert processed_obs["other_data"] == "preserve_me"
def test_image_only_processing():
"""Test processing observation with only images."""
processor = ObservationProcessor()
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
assert "observation.image" in processed_obs
assert len(processed_obs) == 1
"""Test processing observation with only images."""
processor = ObservationProcessor()
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
observation = {"pixels": image}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
assert "observation.image" in processed_obs
assert len(processed_obs) == 1
def test_state_only_processing():
"""Test processing observation with only states."""
processor = ObservationProcessor()
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
assert "observation.state" in processed_obs
assert "agent_pos" not in processed_obs
"""Test processing observation with only states."""
processor = ObservationProcessor()
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {"agent_pos": agent_pos}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
assert "observation.state" in processed_obs
assert "agent_pos" not in processed_obs
def test_empty_observation():
"""Test processing empty observation."""
processor = ObservationProcessor()
observation = {}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
assert processed_obs == {}
"""Test processing empty observation."""
processor = ObservationProcessor()
observation = {}
transition = (observation, None, None, None, None, None, None)
result = processor(transition)
processed_obs = result[0]
assert processed_obs == {}
def test_none_observation():
"""Test processing None observation."""
processor = ObservationProcessor()
transition = (None, None, None, None, None, None, None)
result = processor(transition)
assert result == transition
"""Test processing None observation."""
processor = ObservationProcessor()
transition = (None, None, None, None, None, None, None)
result = processor(transition)
assert result == transition
def test_serialization_methods():
"""Test serialization methods."""
processor = ObservationProcessor()
# Test get_config
config = processor.get_config()
assert isinstance(config, dict)
assert "image_processor" in config
assert "state_processor" in config
# Test state_dict
state = processor.state_dict()
assert isinstance(state, dict)
# Test load_state_dict (should not raise)
processor.load_state_dict(state)
# Test reset (should not raise)
processor.reset()
"""Test serialization methods."""
processor = ObservationProcessor()
# Test get_config
config = processor.get_config()
assert isinstance(config, dict)
assert "image_processor" in config
assert "state_processor" in config
# Test state_dict
state = processor.state_dict()
assert isinstance(state, dict)
# Test load_state_dict (should not raise)
processor.load_state_dict(state)
# Test reset (should not raise)
processor.reset()
def test_custom_sub_processors():
"""Test ObservationProcessor with custom sub-processors."""
image_proc = ImageProcessor()
state_proc = StateProcessor()
processor = ObservationProcessor(image_processor=image_proc, state_processor=state_proc)
# Should use the provided processors
assert processor.image_processor is image_proc
assert processor.state_processor is state_proc
"""Test ObservationProcessor with custom sub-processors."""
image_proc = ImageProcessor()
state_proc = StateProcessor()
processor = ObservationProcessor(image_processor=image_proc, state_processor=state_proc)
# Should use the provided processors
assert processor.image_processor is image_proc
assert processor.state_processor is state_proc
def test_equivalent_to_original_function():
"""Test that ObservationProcessor produces equivalent results to preprocess_observation."""
# Import the original function for comparison
from lerobot.envs.utils import preprocess_observation
processor = ObservationProcessor()
# Create test data similar to what the original function expects
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {
"pixels": image,
"environment_state": env_state,
"agent_pos": agent_pos
}
# Process with original function
original_result = preprocess_observation(observation)
# Process with new processor
transition = (observation, None, None, None, None, None, None)
processor_result = processor(transition)[0]
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())
for key in original_result:
torch.testing.assert_close(original_result[key], processor_result[key])
"""Test that ObservationProcessor produces equivalent results to preprocess_observation."""
# Import the original function for comparison
from lerobot.envs.utils import preprocess_observation
processor = ObservationProcessor()
# Create test data similar to what the original function expects
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
observation = {"pixels": image, "environment_state": env_state, "agent_pos": agent_pos}
# Process with original function
original_result = preprocess_observation(observation)
# Process with new processor
transition = (observation, None, None, None, None, None, None)
processor_result = processor(transition)[0]
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())
for key in original_result:
torch.testing.assert_close(original_result[key], processor_result[key])
def test_equivalent_with_image_dict():
"""Test equivalence with dictionary of images."""
from lerobot.envs.utils import preprocess_observation
processor = ObservationProcessor()
# Create test data with multiple cameras
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {
"pixels": {"cam1": image1, "cam2": image2},
"agent_pos": agent_pos
}
# Process with original function
original_result = preprocess_observation(observation)
# Process with new processor
transition = (observation, None, None, None, None, None, None)
processor_result = processor(transition)[0]
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())
for key in original_result:
torch.testing.assert_close(original_result[key], processor_result[key])
"""Test equivalence with dictionary of images."""
from lerobot.envs.utils import preprocess_observation
processor = ObservationProcessor()
# Create test data with multiple cameras
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
observation = {"pixels": {"cam1": image1, "cam2": image2}, "agent_pos": agent_pos}
# Process with original function
original_result = preprocess_observation(observation)
# Process with new processor
transition = (observation, None, None, None, None, None, None)
processor_result = processor(transition)[0]
# Compare results
assert set(original_result.keys()) == set(processor_result.keys())
for key in original_result:
torch.testing.assert_close(original_result[key], processor_result[key])

View File

@@ -16,87 +16,86 @@
import json
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from dataclasses import dataclass
import numpy as np
import pytest
import torch
from lerobot.processor.pipeline import RobotPipeline, EnvTransition, PipelineStep
from lerobot.processor.pipeline import EnvTransition, RobotPipeline
@dataclass
class MockStep:
"""Mock pipeline step for testing - demonstrates best practices.
This example shows the proper separation:
- JSON-serializable attributes (name, counter) go in get_config()
- Only torch tensors go in state_dict()
Note: The counter is part of the configuration, so it will be restored
when the step is recreated from config during loading.
"""
name: str = "mock_step"
counter: int = 0
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Add a counter to the complementary_data."""
obs, action, reward, done, truncated, info, comp_data = transition
if comp_data is None:
comp_data = {}
else:
comp_data = dict(comp_data) # Make a copy
comp_data[f"{self.name}_counter"] = self.counter
self.counter += 1
return (obs, action, reward, done, truncated, info, comp_data)
def get_config(self) -> dict[str, Any]:
# Return all JSON-serializable attributes that should be persisted
# These will be passed to __init__ when loading
return {"name": self.name, "counter": self.counter}
def state_dict(self) -> dict[str, torch.Tensor]:
# Only return torch tensors (empty in this case since we have no tensor state)
return {}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
# No tensor state to load
pass
def reset(self) -> None:
self.counter = 0
@dataclass
@dataclass
class MockStepWithoutOptionalMethods:
"""Mock step that only implements the required __call__ method."""
multiplier: float = 2.0
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Multiply reward by multiplier."""
obs, action, reward, done, truncated, info, comp_data = transition
if reward is not None:
reward = reward * self.multiplier
return (obs, action, reward, done, truncated, info, comp_data)
@dataclass
class MockStepWithTensorState:
"""Mock step demonstrating mixed JSON attributes and tensor state."""
name: str = "tensor_step"
learning_rate: float = 0.01
window_size: int = 10
def __init__(self, name: str = "tensor_step", learning_rate: float = 0.01, window_size: int = 10):
self.name = name
self.learning_rate = learning_rate
@@ -104,19 +103,19 @@ class MockStepWithTensorState:
# Tensor state
self.running_mean = torch.zeros(window_size)
self.running_count = torch.tensor(0)
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Update running statistics."""
obs, action, reward, done, truncated, info, comp_data = transition
if reward is not None:
# Update running mean
idx = self.running_count % self.window_size
self.running_mean[idx] = reward
self.running_count += 1
return transition
def get_config(self) -> dict[str, Any]:
# Only JSON-serializable attributes
return {
@@ -124,18 +123,18 @@ class MockStepWithTensorState:
"learning_rate": self.learning_rate,
"window_size": self.window_size,
}
def state_dict(self) -> dict[str, torch.Tensor]:
# Only tensor state
return {
"running_mean": self.running_mean,
"running_count": self.running_count,
}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
self.running_mean = state["running_mean"]
self.running_count = state["running_count"]
def reset(self) -> None:
self.running_mean.zero_()
self.running_count.zero_()
@@ -144,265 +143,275 @@ class MockStepWithTensorState:
def test_empty_pipeline():
"""Test pipeline with no steps."""
pipeline = RobotPipeline()
transition = (None, None, 0.0, False, False, {}, {})
result = pipeline(transition)
assert result == transition
assert len(pipeline) == 0
def test_single_step_pipeline():
"""Test pipeline with a single step."""
step = MockStep("test_step")
pipeline = RobotPipeline([step])
transition = (None, None, 0.0, False, False, {}, {})
result = pipeline(transition)
assert len(pipeline) == 1
assert result[6]["test_step_counter"] == 0 # complementary_data
# Call again to test counter increment
result = pipeline(transition)
assert result[6]["test_step_counter"] == 1
def test_multiple_steps_pipeline():
"""Test pipeline with multiple steps."""
step1 = MockStep("step1")
step2 = MockStep("step2")
step2 = MockStep("step2")
pipeline = RobotPipeline([step1, step2])
transition = (None, None, 0.0, False, False, {}, {})
result = pipeline(transition)
assert len(pipeline) == 2
assert result[6]["step1_counter"] == 0
assert result[6]["step2_counter"] == 0
def test_invalid_transition_format():
"""Test pipeline with invalid transition format."""
pipeline = RobotPipeline([MockStep()])
# Test with wrong number of elements
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
pipeline((None, None, 0.0)) # Only 3 elements
# Test with wrong type
with pytest.raises(ValueError, match="EnvTransition must be a 7-tuple"):
pipeline("not a tuple")
def test_step_through():
"""Test step_through method."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotPipeline([step1, step2])
transition = (None, None, 0.0, False, False, {}, {})
results = list(pipeline.step_through(transition))
assert len(results) == 3 # Original + 2 steps
assert results[0] == transition # Original
assert "step1_counter" in results[1][6] # After step1
assert "step2_counter" in results[2][6] # After step2
def test_indexing():
"""Test pipeline indexing."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotPipeline([step1, step2])
# Test integer indexing
assert pipeline[0] is step1
assert pipeline[1] is step2
# Test slice indexing
sub_pipeline = pipeline[0:1]
assert isinstance(sub_pipeline, RobotPipeline)
assert len(sub_pipeline) == 1
assert sub_pipeline[0] is step1
def test_hooks():
"""Test before/after step hooks."""
step = MockStep("test_step")
pipeline = RobotPipeline([step])
before_calls = []
after_calls = []
def before_hook(idx: int, transition: EnvTransition):
before_calls.append(idx)
return transition
def after_hook(idx: int, transition: EnvTransition):
after_calls.append(idx)
return transition
pipeline.register_before_step_hook(before_hook)
pipeline.register_after_step_hook(after_hook)
transition = (None, None, 0.0, False, False, {}, {})
pipeline(transition)
assert before_calls == [0]
assert after_calls == [0]
def test_hook_modification():
"""Test that hooks can modify transitions."""
step = MockStep("test_step")
pipeline = RobotPipeline([step])
def modify_reward_hook(idx: int, transition: EnvTransition):
obs, action, reward, done, truncated, info, comp_data = transition
return (obs, action, 42.0, done, truncated, info, comp_data)
pipeline.register_before_step_hook(modify_reward_hook)
transition = (None, None, 0.0, False, False, {}, {})
result = pipeline(transition)
assert result[2] == 42.0 # reward modified by hook
def test_reset():
"""Test pipeline reset functionality."""
step = MockStep("test_step")
pipeline = RobotPipeline([step])
reset_called = []
def reset_hook():
reset_called.append(True)
pipeline.register_reset_hook(reset_hook)
# Make some calls to increment counter
transition = (None, None, 0.0, False, False, {}, {})
pipeline(transition)
pipeline(transition)
assert step.counter == 2
# Reset should reset step and call hook
pipeline.reset()
assert step.counter == 0
assert len(reset_called) == 1
def test_profile_steps():
"""Test step profiling functionality."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotPipeline([step1, step2])
transition = (None, None, 0.0, False, False, {}, {})
profile_results = pipeline.profile_steps(transition, num_runs=10)
assert len(profile_results) == 2
assert "step_0_MockStep" in profile_results
assert "step_1_MockStep" in profile_results
assert all(isinstance(time, float) and time >= 0 for time in profile_results.values())
def test_save_and_load_pretrained():
"""Test saving and loading pipeline.
This test demonstrates that JSON-serializable attributes (like counter)
are saved in the config and restored when the step is recreated.
"""
step1 = MockStep("step1")
step2 = MockStep("step2")
# Increment counters to have some state
step1.counter = 5
step2.counter = 10
pipeline = RobotPipeline([step1, step2], name="TestPipeline", seed=42)
with tempfile.TemporaryDirectory() as tmp_dir:
# Save pipeline
pipeline.save_pretrained(tmp_dir)
# Check files were created
config_path = Path(tmp_dir) / "pipeline.json"
assert config_path.exists()
# Check config content
with open(config_path) as f:
config = json.load(f)
assert config["name"] == "TestPipeline"
assert config["seed"] == 42
assert len(config["steps"]) == 2
# Verify counters are saved in config, not in separate state files
assert config["steps"][0]["config"]["counter"] == 5
assert config["steps"][1]["config"]["counter"] == 10
# Load pipeline
loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
assert loaded_pipeline.name == "TestPipeline"
assert loaded_pipeline.seed == 42
assert len(loaded_pipeline) == 2
# Check that counter was restored from config
assert loaded_pipeline.steps[0].counter == 5
assert loaded_pipeline.steps[1].counter == 10
def test_step_without_optional_methods():
"""Test pipeline with steps that don't implement optional methods."""
step = MockStepWithoutOptionalMethods(multiplier=3.0)
pipeline = RobotPipeline([step])
transition = (None, None, 2.0, False, False, {}, {})
result = pipeline(transition)
assert result[2] == 6.0 # 2.0 * 3.0
# Reset should work even if step doesn't implement reset
pipeline.reset()
# Save/load should work even without optional methods
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
assert len(loaded_pipeline) == 1
def test_mixed_json_and_tensor_state():
"""Test step with both JSON attributes and tensor state."""
step = MockStepWithTensorState(name="stats", learning_rate=0.05, window_size=5)
pipeline = RobotPipeline([step])
# Process some transitions with rewards
for i in range(10):
transition = (None, None, float(i), False, False, {}, {})
pipeline(transition)
# Check state
assert step.running_count.item() == 10
assert step.learning_rate == 0.05
# Save and load
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
# Check that both config and state files were created
config_path = Path(tmp_dir) / "pipeline.json"
config_path = Path(tmp_dir) / "pipeline.json"
state_path = Path(tmp_dir) / "step_0.safetensors"
assert config_path.exists()
assert state_path.exists()
# Load and verify
loaded_pipeline = RobotPipeline.from_pretrained(tmp_dir)
loaded_step = loaded_pipeline.steps[0]
# Check JSON attributes were restored
assert loaded_step.name == "stats"
assert loaded_step.learning_rate == 0.05
assert loaded_step.window_size == 5
# Check tensor state was restored
assert loaded_step.running_count.item() == 10
assert torch.allclose(loaded_step.running_mean, step.running_mean)