mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
Apply suggestions from code review
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
@@ -32,6 +32,7 @@ from .pipeline import (
|
||||
ProcessorStepRegistry,
|
||||
RewardProcessor,
|
||||
RobotProcessor,
|
||||
TransitionIndex,
|
||||
TruncatedProcessor,
|
||||
)
|
||||
from .rename_processor import RenameProcessor
|
||||
@@ -53,6 +54,7 @@ __all__ = [
|
||||
"RewardProcessor",
|
||||
"RobotProcessor",
|
||||
"StateProcessor",
|
||||
"TransitionIndex",
|
||||
"TruncatedProcessor",
|
||||
"VanillaObservationProcessor",
|
||||
]
|
||||
|
||||
@@ -21,7 +21,7 @@ import os
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Iterable, Protocol, Sequence, Tuple
|
||||
from typing import Any, Callable, Iterable, Protocol, Sequence
|
||||
|
||||
import torch
|
||||
from huggingface_hub import ModelHubMixin, hf_hub_download
|
||||
@@ -41,14 +41,14 @@ class TransitionIndex(IntEnum):
|
||||
|
||||
|
||||
# (observation, action, reward, done, truncated, info, complementary_data)
|
||||
EnvTransition = Tuple[
|
||||
EnvTransition = tuple[
|
||||
dict[str, Any] | None, # observation
|
||||
Any | torch.Tensor | None, # action
|
||||
float | torch.Tensor | None, # reward
|
||||
bool | torch.Tensor | None, # done
|
||||
bool | torch.Tensor | None, # truncated
|
||||
Dict[str, Any] | None, # info
|
||||
Dict[str, Any] | None, # complementary_data
|
||||
dict[str, Any] | None, # info
|
||||
dict[str, Any] | None, # complementary_data
|
||||
]
|
||||
|
||||
|
||||
@@ -135,11 +135,11 @@ class ProcessorStep(Protocol):
|
||||
a safe-to-share JSON + SafeTensors format.
|
||||
|
||||
Optional helper protocol:
|
||||
* ``get_config() -> Dict[str, Any]`` – User-defined JSON-serializable
|
||||
* ``get_config() -> dict[str, Any]`` – User-defined JSON-serializable
|
||||
configuration and state. YOU decide what to save here. This is where all
|
||||
non-tensor state goes (e.g., name, counter, threshold, window_size).
|
||||
The config dict will be passed to your class constructor when loading.
|
||||
* ``state_dict() -> Dict[str, torch.Tensor]`` – PyTorch tensor state ONLY.
|
||||
* ``state_dict() -> dict[str, torch.Tensor]`` – PyTorch tensor state ONLY.
|
||||
This is exclusively for torch.Tensor objects (e.g., learned weights,
|
||||
running statistics as tensors). Never put simple Python types here.
|
||||
* ``load_state_dict(state)`` – Inverse of ``state_dict``. Receives a dict
|
||||
|
||||
Reference in New Issue
Block a user