mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Adil Zouitine
parent
f6c7287ae7
commit
769f531603
@@ -13,16 +13,16 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .pipeline import RobotPipeline, PipelineStep, EnvTransition
|
||||
from .observation_processor import (
|
||||
ImageProcessor,
|
||||
StateProcessor,
|
||||
ObservationProcessor,
|
||||
StateProcessor,
|
||||
)
|
||||
from .pipeline import EnvTransition, PipelineStep, RobotPipeline
|
||||
|
||||
__all__ = [
|
||||
"RobotPipeline",
|
||||
"PipelineStep",
|
||||
"PipelineStep",
|
||||
"EnvTransition",
|
||||
"ImageProcessor",
|
||||
"StateProcessor",
|
||||
|
||||
@@ -15,35 +15,36 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import einops
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.processor.pipeline import EnvTransition, PipelineStep, TransitionIndex
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionIndex
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageProcessor:
|
||||
"""Process image observations from environment format to policy format.
|
||||
|
||||
|
||||
Converts images from:
|
||||
- Channel-last (H, W, C) to channel-first (C, H, W)
|
||||
- uint8 [0, 255] to float32 [0, 1]
|
||||
- Adds batch dimension if needed
|
||||
- Handles both single images and dictionaries of images
|
||||
"""
|
||||
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition[TransitionIndex.OBSERVATION]
|
||||
|
||||
|
||||
if observation is None:
|
||||
return transition
|
||||
|
||||
|
||||
processed_obs = {}
|
||||
|
||||
|
||||
# Handle pixels key
|
||||
if "pixels" in observation:
|
||||
if isinstance(observation["pixels"], dict):
|
||||
@@ -54,12 +55,12 @@ class ImageProcessor:
|
||||
for imgkey, img in imgs.items():
|
||||
processed_img = self._process_single_image(img)
|
||||
processed_obs[imgkey] = processed_img
|
||||
|
||||
|
||||
# Copy other observations unchanged
|
||||
for key, value in observation.items():
|
||||
if key != "pixels":
|
||||
processed_obs[key] = value
|
||||
|
||||
|
||||
# Return new transition with processed observation
|
||||
return (
|
||||
processed_obs,
|
||||
@@ -70,44 +71,44 @@ class ImageProcessor:
|
||||
transition[TransitionIndex.INFO],
|
||||
transition[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
|
||||
|
||||
def _process_single_image(self, img: np.ndarray) -> Tensor:
    """Convert a channel-last uint8 image (or batch) to a float32 channel-first tensor.

    Args:
        img: NumPy array shaped ``(H, W, C)`` or ``(B, H, W, C)`` with dtype uint8.

    Returns:
        Contiguous float32 tensor shaped ``(B, C, H, W)`` scaled to ``[0, 1]``.
        A single ``(H, W, C)`` image is returned with a batch dimension of 1.

    Raises:
        ValueError: If the array is not 3-D or 4-D, does not look channel-last
            (the last dimension must be smaller than both spatial dimensions),
            or is not uint8.
    """
    img_tensor = torch.from_numpy(img)

    # A single (H, W, C) image gets a leading batch dimension.
    if img_tensor.ndim == 3:
        img_tensor = img_tensor.unsqueeze(0)

    # Reject any other rank explicitly; otherwise the shape unpacking below
    # fails with an opaque "not enough/too many values to unpack" message.
    if img_tensor.ndim != 4:
        raise ValueError(
            f"Expected images with shape (H, W, C) or (B, H, W, C), but got shape {tuple(img_tensor.shape)}"
        )

    # Heuristic channel-last check: the channel count must be the smallest
    # of the three trailing dimensions.
    _, h, w, c = img_tensor.shape
    if not (c < h and c < w):
        raise ValueError(f"Expected channel-last images, but got shape {img_tensor.shape}")

    if img_tensor.dtype != torch.uint8:
        raise ValueError(f"Expected torch.uint8 images, but got {img_tensor.dtype}")

    # (B, H, W, C) -> (B, C, H, W); same result as the einops pattern
    # "b h w c -> b c h w", expressed with a torch-native permute.
    img_tensor = img_tensor.permute(0, 3, 1, 2).contiguous()

    # uint8 [0, 255] -> float32 [0, 1]
    return img_tensor.to(torch.float32) / 255.0
|
||||
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
    """Return the serializable configuration, which is always an empty dict here."""
    config: dict[str, Any] = dict()
    return config
|
||||
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
    """Return the tensor state of this processor, which is always an empty dict."""
    tensors: dict[str, torch.Tensor] = dict()
    return tensors
|
||||
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
    """Accept a state dictionary and ignore it: this processor stores no state."""
    return None
|
||||
|
||||
|
||||
def reset(self) -> None:
    """Reset hook; does nothing because this processor stores no state."""
    return None
|
||||
@@ -116,22 +117,22 @@ class ImageProcessor:
|
||||
@dataclass
|
||||
class StateProcessor:
|
||||
"""Process state observations from environment format to policy format.
|
||||
|
||||
|
||||
Handles:
|
||||
- environment_state -> observation.environment_state
|
||||
- agent_pos -> observation.state
|
||||
- Converts numpy arrays to tensors
|
||||
- Adds batch dimension if needed
|
||||
"""
|
||||
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = transition[TransitionIndex.OBSERVATION]
|
||||
|
||||
|
||||
if observation is None:
|
||||
return transition
|
||||
|
||||
|
||||
processed_obs = dict(observation) # Copy existing observations
|
||||
|
||||
|
||||
# Process environment_state
|
||||
if "environment_state" in observation:
|
||||
env_state = torch.from_numpy(observation["environment_state"]).float()
|
||||
@@ -140,16 +141,16 @@ class StateProcessor:
|
||||
processed_obs["observation.environment_state"] = env_state
|
||||
# Remove original key
|
||||
del processed_obs["environment_state"]
|
||||
|
||||
|
||||
# Process agent_pos
|
||||
if "agent_pos" in observation:
|
||||
agent_pos = torch.from_numpy(observation["agent_pos"]).float()
|
||||
if agent_pos.dim() == 1:
|
||||
agent_pos = agent_pos.unsqueeze(0)
|
||||
processed_obs["observation.state"] = agent_pos
|
||||
# Remove original key
|
||||
# Remove original key
|
||||
del processed_obs["agent_pos"]
|
||||
|
||||
|
||||
# Return new transition with processed observation
|
||||
return (
|
||||
processed_obs,
|
||||
@@ -160,19 +161,19 @@ class StateProcessor:
|
||||
transition[TransitionIndex.INFO],
|
||||
transition[TransitionIndex.COMPLEMENTARY_DATA],
|
||||
)
|
||||
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
    """Return the serializable configuration, which is always an empty dict here."""
    settings: dict[str, Any] = dict()
    return settings
|
||||
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
    """Return the tensor state of this processor, which is always an empty dict."""
    snapshot: dict[str, torch.Tensor] = dict()
    return snapshot
|
||||
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
    """Accept a state dictionary and ignore it: this processor stores no state."""
    return None
|
||||
|
||||
|
||||
def reset(self) -> None:
    """Reset hook; does nothing because this processor stores no state."""
    return None
|
||||
@@ -181,43 +182,47 @@ class StateProcessor:
|
||||
@dataclass
|
||||
class ObservationProcessor:
|
||||
"""Complete observation processor that combines image and state processing.
|
||||
|
||||
|
||||
This processor replicates the functionality of the original preprocess_observation
|
||||
function but in a modular, composable way that fits into the pipeline architecture.
|
||||
"""
|
||||
|
||||
|
||||
image_processor: ImageProcessor = field(default_factory=ImageProcessor)
|
||||
state_processor: StateProcessor = field(default_factory=StateProcessor)
|
||||
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
    """Apply the image processor and then the state processor to *transition*.

    The output of ``self.image_processor`` is passed directly to
    ``self.state_processor``, whose result is returned.
    """
    after_images = self.image_processor(transition)
    return self.state_processor(after_images)
|
||||
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
    """Collect the configuration of both sub-processors.

    Returns:
        A dict with keys ``"image_processor"`` and ``"state_processor"``, each
        holding the result of that sub-processor's ``get_config()``.
    """
    children = (
        ("image_processor", self.image_processor),
        ("state_processor", self.state_processor),
    )
    # Image processor is queried first, matching the key order of the result.
    return {label: child.get_config() for label, child in children}
|
||||
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
    """Merge both sub-processors' state dicts into one flat, prefixed dict.

    Keys from the image processor are prefixed with ``"image_processor."`` and
    keys from the state processor with ``"state_processor."``.
    """
    merged = {}
    children = (
        ("image_processor", self.image_processor),
        ("state_processor", self.state_processor),
    )
    for prefix, child in children:
        for key, tensor in child.state_dict().items():
            merged[f"{prefix}.{key}"] = tensor
    return merged
|
||||
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
    """Split a flat, prefixed state dict and load each part into its sub-processor.

    Entries whose key starts with ``"image_processor."`` go to
    ``self.image_processor`` and entries starting with ``"state_processor."``
    go to ``self.state_processor``, with that leading prefix removed. Entries
    with any other key are ignored.

    Args:
        state: Flat mapping of prefixed keys to values.
    """
    image_prefix = "image_processor."
    state_prefix = "state_processor."

    # Strip only the leading prefix. ``str.replace`` would also delete later
    # occurrences of the prefix text inside the key and corrupt nested names.
    image_state = {
        key[len(image_prefix):]: value for key, value in state.items() if key.startswith(image_prefix)
    }
    state_state = {
        key[len(state_prefix):]: value for key, value in state.items() if key.startswith(state_prefix)
    }

    self.image_processor.load_state_dict(image_state)
    self.state_processor.load_state_dict(state_state)
|
||||
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset processor state."""
|
||||
self.image_processor.reset()
|
||||
|
||||
@@ -14,19 +14,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
import os, json
|
||||
from typing import Any, Dict, Sequence, Iterable, Protocol, Optional, Tuple, Callable, Union
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from enum import IntEnum
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Iterable, Protocol, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, ModelHubMixin
|
||||
from safetensors.torch import save_file, load_file
|
||||
from huggingface_hub import ModelHubMixin, hf_hub_download
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
|
||||
class TransitionIndex(IntEnum):
|
||||
"""Explicit indices for EnvTransition tuple components."""
|
||||
|
||||
OBSERVATION = 0
|
||||
ACTION = 1
|
||||
REWARD = 2
|
||||
@@ -38,29 +41,28 @@ class TransitionIndex(IntEnum):
|
||||
|
||||
# (observation, action, reward, done, truncated, info, complementary_data)
|
||||
EnvTransition = Tuple[
|
||||
Any| None, # observation
|
||||
Any| None, # action
|
||||
float| None, # reward
|
||||
bool| None, # done
|
||||
bool| None, # truncated
|
||||
Dict[str, Any]| None, # info
|
||||
Dict[str, Any]| None, # complementary_data
|
||||
Any | None, # observation
|
||||
Any | None, # action
|
||||
float | None, # reward
|
||||
bool | None, # done
|
||||
bool | None, # truncated
|
||||
Dict[str, Any] | None, # info
|
||||
Dict[str, Any] | None, # complementary_data
|
||||
]
|
||||
|
||||
|
||||
|
||||
class PipelineStep(Protocol):
|
||||
"""Structural typing interface for a single pipeline step.
|
||||
|
||||
|
||||
A step is any callable accepting a full `EnvTransition` tuple and
|
||||
returning a (possibly modified) tuple of the same structure. Implementers
|
||||
are encouraged—but not required—to expose the optional helper methods
|
||||
listed below. When present, these hooks let `RobotPipeline`
|
||||
automatically serialise the step's configuration and learnable state using
|
||||
a safe-to-share JSON + SafeTensors format.
|
||||
|
||||
|
||||
Optional helper protocol:
|
||||
* ``get_config() -> Dict[str, Any]`` – User-defined JSON-serializable
|
||||
* ``get_config() -> Dict[str, Any]`` – User-defined JSON-serializable
|
||||
configuration and state. YOU decide what to save here. This is where all
|
||||
non-tensor state goes (e.g., name, counter, threshold, window_size).
|
||||
The config dict will be passed to your class constructor when loading.
|
||||
@@ -70,7 +72,7 @@ class PipelineStep(Protocol):
|
||||
* ``load_state_dict(state)`` – Inverse of ``state_dict``. Receives a dict
|
||||
containing torch tensors only.
|
||||
* ``reset()`` – Clear internal buffers at episode boundaries.
|
||||
|
||||
|
||||
Example separation:
|
||||
- get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10}
|
||||
- state_dict(): {"weights": torch.tensor(...), "running_mean": torch.tensor(...)}
|
||||
@@ -78,11 +80,11 @@ class PipelineStep(Protocol):
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition: ...
|
||||
|
||||
def get_config(self) -> Dict[str, Any]: ...
|
||||
def get_config(self) -> dict[str, Any]: ...
|
||||
|
||||
def state_dict(self) -> Dict[str, torch.Tensor]: ...
|
||||
def state_dict(self) -> dict[str, torch.Tensor]: ...
|
||||
|
||||
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None: ...
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: ...
|
||||
|
||||
def reset(self) -> None: ...
|
||||
|
||||
@@ -120,24 +122,25 @@ class RobotPipeline(ModelHubMixin):
|
||||
pipe.push_to_hub("my-org/cartpole_pipe")
|
||||
loaded = RobotPipeline.from_pretrained("my-org/cartpole_pipe")
|
||||
"""
|
||||
|
||||
steps: Sequence[PipelineStep] = field(default_factory=list)
|
||||
name: str = "RobotPipeline"
|
||||
seed: Optional[int] = None
|
||||
seed: int | None = None
|
||||
|
||||
# Pipeline-level hooks
|
||||
# A hook can optionally return a modified transition. If it returns
|
||||
# ``None`` the current value is left untouched.
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], Optional[EnvTransition]]] = field(
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
|
||||
default_factory=list, repr=False
|
||||
)
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], Optional[EnvTransition]]] = field(
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
|
||||
default_factory=list, repr=False
|
||||
)
|
||||
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Run *transition* through every step, firing hooks on the way."""
|
||||
|
||||
|
||||
# Basic validation with helpful error message
|
||||
if not isinstance(transition, tuple) or len(transition) != 7:
|
||||
raise ValueError(
|
||||
@@ -168,23 +171,23 @@ class RobotPipeline(ModelHubMixin):
|
||||
yield transition
|
||||
|
||||
_CFG_NAME = "pipeline.json"
|
||||
|
||||
|
||||
def _save_pretrained(self, destination_path: str, **kwargs):
    """Internal save hook kept for ModelHubMixin compatibility.

    Delegates to ``self.save_pretrained(destination_path)``; extra keyword
    arguments are accepted but not forwarded.
    """
    save = self.save_pretrained
    save(destination_path)
|
||||
|
||||
|
||||
def save_pretrained(self, destination_path: str, **kwargs):
|
||||
"""Serialize the pipeline definition and parameters to *destination_path*."""
|
||||
os.makedirs(destination_path, exist_ok=True)
|
||||
|
||||
config: Dict[str, Any] = {
|
||||
config: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"seed": self.seed,
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
for step_index, pipeline_step in enumerate(self.steps):
|
||||
step_entry: Dict[str, Any] = {
|
||||
step_entry: dict[str, Any] = {
|
||||
"class": f"{pipeline_step.__class__.__module__}.{pipeline_step.__class__.__name__}",
|
||||
}
|
||||
|
||||
@@ -204,20 +207,20 @@ class RobotPipeline(ModelHubMixin):
|
||||
json.dump(config, file_pointer, indent=2)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, source: str) -> "RobotPipeline":
|
||||
def from_pretrained(cls, source: str) -> RobotPipeline:
|
||||
"""Load a serialized pipeline from *source* (local path or Hugging Face Hub identifier)."""
|
||||
if Path(source).is_dir():
|
||||
# Local path - use it directly
|
||||
base_path = Path(source)
|
||||
with open(base_path / cls._CFG_NAME) as file_pointer:
|
||||
config: Dict[str, Any] = json.load(file_pointer)
|
||||
config: dict[str, Any] = json.load(file_pointer)
|
||||
else:
|
||||
# Hugging Face Hub - download all required files
|
||||
# First download the config file
|
||||
config_path = hf_hub_download(source, cls._CFG_NAME, repo_type="model")
|
||||
with open(config_path) as file_pointer:
|
||||
config: Dict[str, Any] = json.load(file_pointer)
|
||||
|
||||
config: dict[str, Any] = json.load(file_pointer)
|
||||
|
||||
# Store downloaded files in the same directory as the config
|
||||
base_path = Path(config_path).parent
|
||||
|
||||
@@ -234,7 +237,7 @@ class RobotPipeline(ModelHubMixin):
|
||||
else:
|
||||
# Hugging Face Hub - download the state file
|
||||
state_path = hf_hub_download(source, step_entry["state_file"], repo_type="model")
|
||||
|
||||
|
||||
step_instance.load_state_dict(load_file(state_path))
|
||||
|
||||
steps.append(step_instance)
|
||||
@@ -254,11 +257,11 @@ class RobotPipeline(ModelHubMixin):
|
||||
return RobotPipeline(self.steps[idx], self.name, self.seed)
|
||||
return self.steps[idx]
|
||||
|
||||
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
    """Append *fn* to the hooks executed before every pipeline step."""
    hooks = self.before_step_hooks
    hooks.append(fn)
|
||||
|
||||
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
    """Append *fn* to the hooks executed after every pipeline step."""
    hooks = self.after_step_hooks
    hooks.append(fn)
|
||||
|
||||
@@ -274,26 +277,26 @@ class RobotPipeline(ModelHubMixin):
|
||||
for fn in self.reset_hooks:
|
||||
fn()
|
||||
|
||||
def profile_steps(self, transition: EnvTransition, num_runs: int = 100) -> dict[str, float]:
    """Measure the average wall-clock time of each step, in milliseconds.

    Every step in ``self.steps`` is warmed up with five untimed calls and then
    timed over ``num_runs`` calls, all on the same input. The step's output is
    then used as the input for the next step, so each step is profiled on the
    data it would receive from the steps before it.

    Args:
        transition: Transition passed to the first step.
        num_runs: Number of timed calls per step; must be at least 1.

    Returns:
        Mapping from ``"step_{index}_{ClassName}"`` to the mean time per call
        in milliseconds.

    Raises:
        ValueError: If ``num_runs`` is less than 1.
    """
    import time

    # Guard against a division by zero (or negative averages) below.
    if num_runs < 1:
        raise ValueError(f"num_runs must be at least 1, but got {num_runs}")

    profile_results: dict[str, float] = {}

    for idx, pipeline_step in enumerate(self.steps):
        step_name = f"step_{idx}_{pipeline_step.__class__.__name__}"

        # Warm up
        for _ in range(5):
            pipeline_step(transition)

        # Time the step on a fixed input. Feeding the output back into the
        # same step would time it on already-processed data and hand a
        # repeatedly transformed transition to the next step.
        start_time = time.perf_counter()
        for _ in range(num_runs):
            step_output = pipeline_step(transition)
        end_time = time.perf_counter()

        avg_time = (end_time - start_time) / num_runs * 1000  # Convert to milliseconds
        profile_results[step_name] = avg_time

        # Advance exactly once so the next step sees this step's output.
        transition = step_output

    return profile_results
|
||||
|
||||
Reference in New Issue
Block a user