[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-07-02 15:31:15 +00:00
committed by Adil Zouitine
parent f6c7287ae7
commit 769f531603
9 changed files with 485 additions and 475 deletions

View File

@@ -13,16 +13,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .pipeline import RobotPipeline, PipelineStep, EnvTransition
from .observation_processor import (
ImageProcessor,
StateProcessor,
ObservationProcessor,
StateProcessor,
)
from .pipeline import EnvTransition, PipelineStep, RobotPipeline
__all__ = [
"RobotPipeline",
"PipelineStep",
"PipelineStep",
"EnvTransition",
"ImageProcessor",
"StateProcessor",

View File

@@ -15,35 +15,36 @@
# limitations under the License.
from __future__ import annotations
from typing import Any
from dataclasses import dataclass, field
from typing import Any
import einops
import numpy as np
import torch
import einops
from torch import Tensor
from lerobot.processor.pipeline import EnvTransition, PipelineStep, TransitionIndex
from lerobot.processor.pipeline import EnvTransition, TransitionIndex
@dataclass
class ImageProcessor:
"""Process image observations from environment format to policy format.
Converts images from:
- Channel-last (H, W, C) to channel-first (C, H, W)
- uint8 [0, 255] to float32 [0, 1]
- Adds batch dimension if needed
- Handles both single images and dictionaries of images
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition[TransitionIndex.OBSERVATION]
if observation is None:
return transition
processed_obs = {}
# Handle pixels key
if "pixels" in observation:
if isinstance(observation["pixels"], dict):
@@ -54,12 +55,12 @@ class ImageProcessor:
for imgkey, img in imgs.items():
processed_img = self._process_single_image(img)
processed_obs[imgkey] = processed_img
# Copy other observations unchanged
for key, value in observation.items():
if key != "pixels":
processed_obs[key] = value
# Return new transition with processed observation
return (
processed_obs,
@@ -70,44 +71,44 @@ class ImageProcessor:
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
def _process_single_image(self, img: np.ndarray) -> Tensor:
    """Convert one channel-last uint8 image (or batch of images) to a float tensor.

    Args:
        img: NumPy array laid out as ``(H, W, C)`` or ``(B, H, W, C)`` with
            dtype ``uint8``.

    Returns:
        A contiguous ``float32`` tensor of shape ``(B, C, H, W)`` whose values
        are the input values divided by 255.

    Raises:
        ValueError: If the array is neither 3-D nor 4-D, does not look
            channel-last, or is not ``uint8``.
    """
    img_tensor = torch.from_numpy(img)

    # A lone (H, W, C) image becomes a batch of one.
    if img_tensor.ndim == 3:
        img_tensor = img_tensor.unsqueeze(0)

    # Check the rank explicitly: unpacking the shape below would otherwise
    # fail with an opaque "not enough values to unpack" message.
    if img_tensor.ndim != 4:
        raise ValueError(f"Expected images with 3 or 4 dimensions, but got shape {tuple(img.shape)}")

    _, height, width, channels = img_tensor.shape
    # Heuristic check: the last axis must be smaller than both spatial axes.
    if not (channels < height and channels < width):
        raise ValueError(f"Expected channel-last images, but got shape {img_tensor.shape}")

    if img_tensor.dtype != torch.uint8:
        raise ValueError(f"Expected torch.uint8 images, but got {img_tensor.dtype}")

    # (B, H, W, C) -> (B, C, H, W); permute returns a view, so materialize it.
    img_tensor = img_tensor.permute(0, 3, 1, 2).contiguous()

    # Scale uint8 [0, 255] into float32 [0, 1].
    return img_tensor.type(torch.float32) / 255.0
def get_config(self) -> dict[str, Any]:
    """Return the serializable configuration, which is always an empty dict here."""
    config: dict[str, Any] = {}
    return config
def state_dict(self) -> dict[str, torch.Tensor]:
    """Return the tensor state of this processor, which is always an empty dict."""
    tensors: dict[str, torch.Tensor] = {}
    return tensors
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
    """Accept a state dictionary and ignore it; nothing is restored."""
    return None
def reset(self) -> None:
    """Reset hook; does nothing for this processor."""
    return None
@@ -116,22 +117,22 @@ class ImageProcessor:
@dataclass
class StateProcessor:
"""Process state observations from environment format to policy format.
Handles:
- environment_state -> observation.environment_state
- agent_pos -> observation.state
- Converts numpy arrays to tensors
- Adds batch dimension if needed
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition[TransitionIndex.OBSERVATION]
if observation is None:
return transition
processed_obs = dict(observation) # Copy existing observations
# Process environment_state
if "environment_state" in observation:
env_state = torch.from_numpy(observation["environment_state"]).float()
@@ -140,16 +141,16 @@ class StateProcessor:
processed_obs["observation.environment_state"] = env_state
# Remove original key
del processed_obs["environment_state"]
# Process agent_pos
if "agent_pos" in observation:
agent_pos = torch.from_numpy(observation["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
processed_obs["observation.state"] = agent_pos
# Remove original key
# Remove original key
del processed_obs["agent_pos"]
# Return new transition with processed observation
return (
processed_obs,
@@ -160,19 +161,19 @@ class StateProcessor:
transition[TransitionIndex.INFO],
transition[TransitionIndex.COMPLEMENTARY_DATA],
)
def get_config(self) -> dict[str, Any]:
    """Return the serializable configuration; this is always an empty dict."""
    return dict()
def state_dict(self) -> dict[str, torch.Tensor]:
    """Return the tensor state; this is always an empty dict."""
    return dict()
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
    """Take a state dictionary and discard it; no state is loaded."""
    return None
def reset(self) -> None:
    """Reset hook; a no-op for this processor."""
    return None
@@ -181,43 +182,47 @@ class StateProcessor:
@dataclass
class ObservationProcessor:
"""Complete observation processor that combines image and state processing.
This processor replicates the functionality of the original preprocess_observation
function but in a modular, composable way that fits into the pipeline architecture.
"""
image_processor: ImageProcessor = field(default_factory=ImageProcessor)
state_processor: StateProcessor = field(default_factory=StateProcessor)
def __call__(self, transition: EnvTransition) -> EnvTransition:
    """Run *transition* through the image processor, then the state processor."""
    # Order matters: the state processor receives the image processor's output.
    for stage in (self.image_processor, self.state_processor):
        transition = stage(transition)
    return transition
def get_config(self) -> dict[str, Any]:
    """Collect the configuration of both sub-processors.

    Returns:
        A dict with the keys ``"image_processor"`` and ``"state_processor"``,
        each holding the value returned by that sub-processor's
        ``get_config()``.
    """
    return {
        attr_name: getattr(self, attr_name).get_config()
        for attr_name in ("image_processor", "state_processor")
    }
def state_dict(self) -> dict[str, torch.Tensor]:
    """Merge both sub-processors' state dicts into one flat mapping.

    Each key is prefixed with ``image_processor.`` or ``state_processor.``
    according to the sub-processor it came from.
    """
    merged: dict[str, torch.Tensor] = {}
    sources = (
        ("image_processor", self.image_processor),
        ("state_processor", self.state_processor),
    )
    for prefix, processor in sources:
        for key, tensor in processor.state_dict().items():
            merged[f"{prefix}.{key}"] = tensor
    return merged
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
    """Split a flat state dict by prefix and hand each part to its sub-processor.

    Keys starting with ``image_processor.`` go to ``self.image_processor`` and
    keys starting with ``state_processor.`` go to ``self.state_processor``,
    with the prefix removed. Keys with neither prefix are ignored.

    Args:
        state: Flat mapping of prefixed keys to values.
    """
    routes = (
        ("image_processor.", self.image_processor),
        ("state_processor.", self.state_processor),
    )
    for prefix, processor in routes:
        # Slice off only the leading prefix. ``str.replace`` would also delete
        # the same text if it appeared again later in the key.
        sub_state = {
            key[len(prefix) :]: value for key, value in state.items() if key.startswith(prefix)
        }
        processor.load_state_dict(sub_state)
def reset(self) -> None:
"""Reset processor state."""
self.image_processor.reset()

View File

@@ -14,19 +14,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os, json
from typing import Any, Dict, Sequence, Iterable, Protocol, Optional, Tuple, Callable, Union
import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from enum import IntEnum
import numpy as np
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, Protocol, Sequence, Tuple
import torch
from huggingface_hub import hf_hub_download, ModelHubMixin
from safetensors.torch import save_file, load_file
from huggingface_hub import ModelHubMixin, hf_hub_download
from safetensors.torch import load_file, save_file
class TransitionIndex(IntEnum):
"""Explicit indices for EnvTransition tuple components."""
OBSERVATION = 0
ACTION = 1
REWARD = 2
@@ -38,29 +41,28 @@ class TransitionIndex(IntEnum):
# (observation, action, reward, done, truncated, info, complementary_data)
EnvTransition = Tuple[
Any| None, # observation
Any| None, # action
float| None, # reward
bool| None, # done
bool| None, # truncated
Dict[str, Any]| None, # info
Dict[str, Any]| None, # complementary_data
Any | None, # observation
Any | None, # action
float | None, # reward
bool | None, # done
bool | None, # truncated
Dict[str, Any] | None, # info
Dict[str, Any] | None, # complementary_data
]
class PipelineStep(Protocol):
"""Structural typing interface for a single pipeline step.
A step is any callable accepting a full `EnvTransition` tuple and
returning a (possibly modified) tuple of the same structure. Implementers
are encouraged—but not required—to expose the optional helper methods
listed below. When present, these hooks let `RobotPipeline`
automatically serialise the step's configuration and learnable state using
a safe-to-share JSON + SafeTensors format.
Optional helper protocol:
* ``get_config() -> Dict[str, Any]`` User-defined JSON-serializable
* ``get_config() -> Dict[str, Any]`` User-defined JSON-serializable
configuration and state. YOU decide what to save here. This is where all
non-tensor state goes (e.g., name, counter, threshold, window_size).
The config dict will be passed to your class constructor when loading.
@@ -70,7 +72,7 @@ class PipelineStep(Protocol):
* ``load_state_dict(state)`` Inverse of ``state_dict``. Receives a dict
containing torch tensors only.
* ``reset()`` Clear internal buffers at episode boundaries.
Example separation:
- get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10}
- state_dict(): {"weights": torch.tensor(...), "running_mean": torch.tensor(...)}
@@ -78,11 +80,11 @@ class PipelineStep(Protocol):
def __call__(self, transition: EnvTransition) -> EnvTransition: ...
def get_config(self) -> Dict[str, Any]: ...
def get_config(self) -> dict[str, Any]: ...
def state_dict(self) -> Dict[str, torch.Tensor]: ...
def state_dict(self) -> dict[str, torch.Tensor]: ...
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None: ...
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: ...
def reset(self) -> None: ...
@@ -120,24 +122,25 @@ class RobotPipeline(ModelHubMixin):
pipe.push_to_hub("my-org/cartpole_pipe")
loaded = RobotPipeline.from_pretrained("my-org/cartpole_pipe")
"""
steps: Sequence[PipelineStep] = field(default_factory=list)
name: str = "RobotPipeline"
seed: Optional[int] = None
seed: int | None = None
# Pipeline-level hooks
# A hook can optionally return a modified transition. If it returns
# ``None`` the current value is left untouched.
before_step_hooks: list[Callable[[int, EnvTransition], Optional[EnvTransition]]] = field(
before_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
default_factory=list, repr=False
)
after_step_hooks: list[Callable[[int, EnvTransition], Optional[EnvTransition]]] = field(
after_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
default_factory=list, repr=False
)
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Run *transition* through every step, firing hooks on the way."""
# Basic validation with helpful error message
if not isinstance(transition, tuple) or len(transition) != 7:
raise ValueError(
@@ -168,23 +171,23 @@ class RobotPipeline(ModelHubMixin):
yield transition
_CFG_NAME = "pipeline.json"
def _save_pretrained(self, destination_path: str, **kwargs):
    """Internal save entry point kept for ``ModelHubMixin`` compatibility.

    Delegates to ``save_pretrained`` with *destination_path* only; any extra
    keyword arguments are accepted but not forwarded.
    """
    save = self.save_pretrained
    save(destination_path)
def save_pretrained(self, destination_path: str, **kwargs):
"""Serialize the pipeline definition and parameters to *destination_path*."""
os.makedirs(destination_path, exist_ok=True)
config: Dict[str, Any] = {
config: dict[str, Any] = {
"name": self.name,
"seed": self.seed,
"steps": [],
}
for step_index, pipeline_step in enumerate(self.steps):
step_entry: Dict[str, Any] = {
step_entry: dict[str, Any] = {
"class": f"{pipeline_step.__class__.__module__}.{pipeline_step.__class__.__name__}",
}
@@ -204,20 +207,20 @@ class RobotPipeline(ModelHubMixin):
json.dump(config, file_pointer, indent=2)
@classmethod
def from_pretrained(cls, source: str) -> "RobotPipeline":
def from_pretrained(cls, source: str) -> RobotPipeline:
"""Load a serialized pipeline from *source* (local path or Hugging Face Hub identifier)."""
if Path(source).is_dir():
# Local path - use it directly
base_path = Path(source)
with open(base_path / cls._CFG_NAME) as file_pointer:
config: Dict[str, Any] = json.load(file_pointer)
config: dict[str, Any] = json.load(file_pointer)
else:
# Hugging Face Hub - download all required files
# First download the config file
config_path = hf_hub_download(source, cls._CFG_NAME, repo_type="model")
with open(config_path) as file_pointer:
config: Dict[str, Any] = json.load(file_pointer)
config: dict[str, Any] = json.load(file_pointer)
# Store downloaded files in the same directory as the config
base_path = Path(config_path).parent
@@ -234,7 +237,7 @@ class RobotPipeline(ModelHubMixin):
else:
# Hugging Face Hub - download the state file
state_path = hf_hub_download(source, step_entry["state_file"], repo_type="model")
step_instance.load_state_dict(load_file(state_path))
steps.append(step_instance)
@@ -254,11 +257,11 @@ class RobotPipeline(ModelHubMixin):
return RobotPipeline(self.steps[idx], self.name, self.seed)
return self.steps[idx]
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
    """Attach *fn* to be executed before every pipeline step.

    Args:
        fn: Callable taking ``(int, EnvTransition)`` and returning either an
            ``EnvTransition`` or ``None``. It is appended to
            ``self.before_step_hooks``.
    """
    self.before_step_hooks.append(fn)
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
    """Attach *fn* to be executed after every pipeline step.

    Args:
        fn: Callable taking ``(int, EnvTransition)`` and returning either an
            ``EnvTransition`` or ``None``. It is appended to
            ``self.after_step_hooks``.
    """
    self.after_step_hooks.append(fn)
@@ -274,26 +277,26 @@ class RobotPipeline(ModelHubMixin):
for fn in self.reset_hooks:
fn()
def profile_steps(self, transition: EnvTransition, num_runs: int = 100) -> dict[str, float]:
    """Measure the average wall-clock time of each step, in milliseconds.

    Each step is warmed up and then timed on the same input. Only after
    timing is its output handed to the next step, so every step is measured
    on the input it would receive in a normal pipeline run. Pipeline hooks
    are not invoked.

    Args:
        transition: Input for the first step.
        num_runs: Number of timed calls per step; must be at least 1.

    Returns:
        Mapping of ``"step_{index}_{ClassName}"`` to the mean duration of one
        call in milliseconds.

    Raises:
        ValueError: If *num_runs* is smaller than 1.
    """
    import time

    if num_runs < 1:
        raise ValueError(f"num_runs must be a positive integer, got {num_runs}")

    profile_results: dict[str, float] = {}

    for idx, pipeline_step in enumerate(self.steps):
        step_name = f"step_{idx}_{pipeline_step.__class__.__name__}"

        # Warm up; results are discarded.
        for _ in range(5):
            pipeline_step(transition)

        # Always call the step with the same input. Feeding its own output
        # back in would time (or crash) it on data it never sees in practice.
        start_time = time.perf_counter()
        for _ in range(num_runs):
            step_output = pipeline_step(transition)
        elapsed = time.perf_counter() - start_time

        profile_results[step_name] = elapsed / num_runs * 1000  # seconds -> milliseconds

        # Advance the pipeline by exactly one application of this step.
        transition = step_output

    return profile_results