Apply suggestions from code review

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Adil Zouitine
2025-07-09 18:20:43 +02:00
parent 33969a0337
commit 1e0d667a22
9 changed files with 15 additions and 18 deletions

View File

@@ -32,6 +32,7 @@ from .pipeline import (
ProcessorStepRegistry,
RewardProcessor,
RobotProcessor,
TransitionIndex,
TruncatedProcessor,
)
from .rename_processor import RenameProcessor
@@ -53,6 +54,7 @@ __all__ = [
"RewardProcessor",
"RobotProcessor",
"StateProcessor",
"TransitionIndex",
"TruncatedProcessor",
"VanillaObservationProcessor",
]

View File

@@ -21,7 +21,7 @@ import os
from dataclasses import dataclass, field
from enum import IntEnum
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, Protocol, Sequence, Tuple
from typing import Any, Callable, Iterable, Protocol, Sequence
import torch
from huggingface_hub import ModelHubMixin, hf_hub_download
@@ -41,14 +41,14 @@ class TransitionIndex(IntEnum):
# (observation, action, reward, done, truncated, info, complementary_data)
EnvTransition = Tuple[
EnvTransition = tuple[
dict[str, Any] | None, # observation
Any | torch.Tensor | None, # action
float | torch.Tensor | None, # reward
bool | torch.Tensor | None, # done
bool | torch.Tensor | None, # truncated
Dict[str, Any] | None, # info
Dict[str, Any] | None, # complementary_data
dict[str, Any] | None, # info
dict[str, Any] | None, # complementary_data
]
@@ -135,11 +135,11 @@ class ProcessorStep(Protocol):
a safe-to-share JSON + SafeTensors format.
Optional helper protocol:
* ``get_config() -> Dict[str, Any]`` User-defined JSON-serializable
* ``get_config() -> dict[str, Any]`` User-defined JSON-serializable
configuration and state. YOU decide what to save here. This is where all
non-tensor state goes (e.g., name, counter, threshold, window_size).
The config dict will be passed to your class constructor when loading.
* ``state_dict() -> Dict[str, torch.Tensor]`` PyTorch tensor state ONLY.
* ``state_dict() -> dict[str, torch.Tensor]`` PyTorch tensor state ONLY.
This is exclusively for torch.Tensor objects (e.g., learned weights,
running statistics as tensors). Never put simple Python types here.
* ``load_state_dict(state)`` Inverse of ``state_dict``. Receives a dict