mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 21:01:26 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Adil Zouitine
parent
f6c7287ae7
commit
769f531603
@@ -14,19 +14,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
import os, json
|
||||
from typing import Any, Dict, Sequence, Iterable, Protocol, Optional, Tuple, Callable, Union
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from enum import IntEnum
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Iterable, Protocol, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, ModelHubMixin
|
||||
from safetensors.torch import save_file, load_file
|
||||
from huggingface_hub import ModelHubMixin, hf_hub_download
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
|
||||
class TransitionIndex(IntEnum):
|
||||
"""Explicit indices for EnvTransition tuple components."""
|
||||
|
||||
OBSERVATION = 0
|
||||
ACTION = 1
|
||||
REWARD = 2
|
||||
@@ -38,29 +41,28 @@ class TransitionIndex(IntEnum):
|
||||
|
||||
# (observation, action, reward, done, truncated, info, complementary_data)
|
||||
EnvTransition = Tuple[
|
||||
Any| None, # observation
|
||||
Any| None, # action
|
||||
float| None, # reward
|
||||
bool| None, # done
|
||||
bool| None, # truncated
|
||||
Dict[str, Any]| None, # info
|
||||
Dict[str, Any]| None, # complementary_data
|
||||
Any | None, # observation
|
||||
Any | None, # action
|
||||
float | None, # reward
|
||||
bool | None, # done
|
||||
bool | None, # truncated
|
||||
Dict[str, Any] | None, # info
|
||||
Dict[str, Any] | None, # complementary_data
|
||||
]
|
||||
|
||||
|
||||
|
||||
class PipelineStep(Protocol):
|
||||
"""Structural typing interface for a single pipeline step.
|
||||
|
||||
|
||||
A step is any callable accepting a full `EnvTransition` tuple and
|
||||
returning a (possibly modified) tuple of the same structure. Implementers
|
||||
are encouraged—but not required—to expose the optional helper methods
|
||||
listed below. When present, these hooks let `RobotPipeline`
|
||||
automatically serialise the step's configuration and learnable state using
|
||||
a safe-to-share JSON + SafeTensors format.
|
||||
|
||||
|
||||
Optional helper protocol:
|
||||
* ``get_config() -> Dict[str, Any]`` – User-defined JSON-serializable
|
||||
* ``get_config() -> Dict[str, Any]`` – User-defined JSON-serializable
|
||||
configuration and state. YOU decide what to save here. This is where all
|
||||
non-tensor state goes (e.g., name, counter, threshold, window_size).
|
||||
The config dict will be passed to your class constructor when loading.
|
||||
@@ -70,7 +72,7 @@ class PipelineStep(Protocol):
|
||||
* ``load_state_dict(state)`` – Inverse of ``state_dict``. Receives a dict
|
||||
containing torch tensors only.
|
||||
* ``reset()`` – Clear internal buffers at episode boundaries.
|
||||
|
||||
|
||||
Example separation:
|
||||
- get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10}
|
||||
- state_dict(): {"weights": torch.tensor(...), "running_mean": torch.tensor(...)}
|
||||
@@ -78,11 +80,11 @@ class PipelineStep(Protocol):
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition: ...
|
||||
|
||||
def get_config(self) -> Dict[str, Any]: ...
|
||||
def get_config(self) -> dict[str, Any]: ...
|
||||
|
||||
def state_dict(self) -> Dict[str, torch.Tensor]: ...
|
||||
def state_dict(self) -> dict[str, torch.Tensor]: ...
|
||||
|
||||
def load_state_dict(self, state: Dict[str, torch.Tensor]) -> None: ...
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: ...
|
||||
|
||||
def reset(self) -> None: ...
|
||||
|
||||
@@ -120,24 +122,25 @@ class RobotPipeline(ModelHubMixin):
|
||||
pipe.push_to_hub("my-org/cartpole_pipe")
|
||||
loaded = RobotPipeline.from_pretrained("my-org/cartpole_pipe")
|
||||
"""
|
||||
|
||||
steps: Sequence[PipelineStep] = field(default_factory=list)
|
||||
name: str = "RobotPipeline"
|
||||
seed: Optional[int] = None
|
||||
seed: int | None = None
|
||||
|
||||
# Pipeline-level hooks
|
||||
# A hook can optionally return a modified transition. If it returns
|
||||
# ``None`` the current value is left untouched.
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], Optional[EnvTransition]]] = field(
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
|
||||
default_factory=list, repr=False
|
||||
)
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], Optional[EnvTransition]]] = field(
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
|
||||
default_factory=list, repr=False
|
||||
)
|
||||
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Run *transition* through every step, firing hooks on the way."""
|
||||
|
||||
|
||||
# Basic validation with helpful error message
|
||||
if not isinstance(transition, tuple) or len(transition) != 7:
|
||||
raise ValueError(
|
||||
@@ -168,23 +171,23 @@ class RobotPipeline(ModelHubMixin):
|
||||
yield transition
|
||||
|
||||
_CFG_NAME = "pipeline.json"
|
||||
|
||||
|
||||
def _save_pretrained(self, destination_path: str, **kwargs):
|
||||
"""Internal save method for ModelHubMixin compatibility."""
|
||||
self.save_pretrained(destination_path)
|
||||
|
||||
|
||||
def save_pretrained(self, destination_path: str, **kwargs):
|
||||
"""Serialize the pipeline definition and parameters to *destination_path*."""
|
||||
os.makedirs(destination_path, exist_ok=True)
|
||||
|
||||
config: Dict[str, Any] = {
|
||||
config: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"seed": self.seed,
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
for step_index, pipeline_step in enumerate(self.steps):
|
||||
step_entry: Dict[str, Any] = {
|
||||
step_entry: dict[str, Any] = {
|
||||
"class": f"{pipeline_step.__class__.__module__}.{pipeline_step.__class__.__name__}",
|
||||
}
|
||||
|
||||
@@ -204,20 +207,20 @@ class RobotPipeline(ModelHubMixin):
|
||||
json.dump(config, file_pointer, indent=2)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, source: str) -> "RobotPipeline":
|
||||
def from_pretrained(cls, source: str) -> RobotPipeline:
|
||||
"""Load a serialized pipeline from *source* (local path or Hugging Face Hub identifier)."""
|
||||
if Path(source).is_dir():
|
||||
# Local path - use it directly
|
||||
base_path = Path(source)
|
||||
with open(base_path / cls._CFG_NAME) as file_pointer:
|
||||
config: Dict[str, Any] = json.load(file_pointer)
|
||||
config: dict[str, Any] = json.load(file_pointer)
|
||||
else:
|
||||
# Hugging Face Hub - download all required files
|
||||
# First download the config file
|
||||
config_path = hf_hub_download(source, cls._CFG_NAME, repo_type="model")
|
||||
with open(config_path) as file_pointer:
|
||||
config: Dict[str, Any] = json.load(file_pointer)
|
||||
|
||||
config: dict[str, Any] = json.load(file_pointer)
|
||||
|
||||
# Store downloaded files in the same directory as the config
|
||||
base_path = Path(config_path).parent
|
||||
|
||||
@@ -234,7 +237,7 @@ class RobotPipeline(ModelHubMixin):
|
||||
else:
|
||||
# Hugging Face Hub - download the state file
|
||||
state_path = hf_hub_download(source, step_entry["state_file"], repo_type="model")
|
||||
|
||||
|
||||
step_instance.load_state_dict(load_file(state_path))
|
||||
|
||||
steps.append(step_instance)
|
||||
@@ -254,11 +257,11 @@ class RobotPipeline(ModelHubMixin):
|
||||
return RobotPipeline(self.steps[idx], self.name, self.seed)
|
||||
return self.steps[idx]
|
||||
|
||||
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], Optional[EnvTransition]]):
|
||||
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
|
||||
"""Attach fn to be executed before every pipeline step."""
|
||||
self.before_step_hooks.append(fn)
|
||||
|
||||
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], Optional[EnvTransition]]):
|
||||
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
|
||||
"""Attach fn to be executed after every pipeline step."""
|
||||
self.after_step_hooks.append(fn)
|
||||
|
||||
@@ -274,26 +277,26 @@ class RobotPipeline(ModelHubMixin):
|
||||
for fn in self.reset_hooks:
|
||||
fn()
|
||||
|
||||
def profile_steps(self, transition: EnvTransition, num_runs: int = 100) -> Dict[str, float]:
|
||||
def profile_steps(self, transition: EnvTransition, num_runs: int = 100) -> dict[str, float]:
|
||||
"""Profile the execution time of each step for performance optimization."""
|
||||
import time
|
||||
|
||||
|
||||
profile_results = {}
|
||||
|
||||
|
||||
for idx, pipeline_step in enumerate(self.steps):
|
||||
step_name = f"step_{idx}_{pipeline_step.__class__.__name__}"
|
||||
|
||||
|
||||
# Warm up
|
||||
for _ in range(5):
|
||||
_ = pipeline_step(transition)
|
||||
|
||||
|
||||
# Time the step
|
||||
start_time = time.perf_counter()
|
||||
for _ in range(num_runs):
|
||||
transition = pipeline_step(transition)
|
||||
end_time = time.perf_counter()
|
||||
|
||||
|
||||
avg_time = (end_time - start_time) / num_runs * 1000 # Convert to milliseconds
|
||||
profile_results[step_name] = avg_time
|
||||
|
||||
return profile_results
|
||||
|
||||
return profile_results
|
||||
|
||||
Reference in New Issue
Block a user