chore (output format): improves output format

This commit is contained in:
Adil Zouitine
2025-07-06 22:03:37 +02:00
parent 730c7b2f35
commit 83a4338f8b
4 changed files with 487 additions and 63 deletions

View File

@@ -239,34 +239,26 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
class RobotProcessor(ModelHubMixin):
"""
Composable, debuggable post-processing processor for robot transitions.
The class orchestrates an ordered collection of small, functional
transforms—steps—executed left-to-right on each incoming
`EnvTransition`.
Parameters:
steps : Sequence[ProcessorStep], optional
Ordered list executed on every call
name : str, default="RobotProcessor"
Human-readable identifier that is persisted inside the JSON config.
seed : int | None, optional
Global seed forwarded to steps that choose to consume it.
Examples:
Basic usage::
env = gym.make("CartPole-v1")
proc = RobotProcessor([
ObservationNormalizer(),
IntrinsicVelocity(),
VelocityBonus(0.02),
])
obs, info = env.reset(seed=0)
tr = (obs, None, 0.0, False, False, info, {})
obs, *_ = proc(tr) # agent sees a normalised observation
Inspecting intermediate results::
for idx, step_tr in enumerate(proc.step_through(tr)):
print(idx, step_tr)
Serialization to the Hugging Face Hub::
proc.save_pretrained("chkpt")
proc.push_to_hub("my-org/cartpole_proc")
loaded = RobotProcessor.from_pretrained("my-org/cartpole_proc")
The class orchestrates an ordered collection of small, functional transforms—steps—executed
left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` tuples
and batch dictionaries, automatically converting between formats as needed.
Args:
steps: Ordered list of processing steps executed on every call. Defaults to empty list.
name: Human-readable identifier that is persisted inside the JSON config.
Defaults to "RobotProcessor".
seed: Global seed forwarded to steps that choose to consume it. Defaults to None.
to_transition: Function to convert batch dict to EnvTransition tuple.
Defaults to _default_batch_to_transition.
to_output: Function to convert EnvTransition tuple to the desired output format.
Usually it is a batch dict or EnvTransition tuple.
Defaults to _default_transition_to_batch.
before_step_hooks: List of hooks called before each step. Each hook receives the step
index and transition, and can optionally return a modified transition.
after_step_hooks: List of hooks called after each step. Each hook receives the step
index and transition, and can optionally return a modified transition.
reset_hooks: List of hooks called during processor reset.
"""
steps: Sequence[ProcessorStep] = field(default_factory=list)
@@ -276,7 +268,7 @@ class RobotProcessor(ModelHubMixin):
to_transition: Callable[[dict[str, Any]], EnvTransition] = field(
default_factory=lambda: _default_batch_to_transition, repr=False
)
to_batch: Callable[[EnvTransition], dict[str, Any]] = field(
to_output: Callable[[EnvTransition], dict[str, Any] | EnvTransition] = field(
default_factory=lambda: _default_transition_to_batch, repr=False
)
@@ -292,16 +284,22 @@ class RobotProcessor(ModelHubMixin):
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
def __call__(self, data: EnvTransition | dict[str, Any]):
"""Process *data* through all steps.
"""Process data through all steps.
The method accepts **either** the classic :pydata:`EnvTransition` tuple
**or** a *batch* dictionary (like the ones returned by
:class:`lerobot.utils.buffer.ReplayBuffer` or
:class:`lerobot.datasets.lerobot_dataset.LeRobotDataset`). If a dict is
supplied it is first converted to the internal tuple format using
:pyattr:`to_transition`; after all steps are executed the tuple is
transformed back into a dict with :pyattr:`to_batch` and the result is
returned thereby preserving the caller's original data type.
The method accepts either the classic EnvTransition tuple or a batch dictionary
(like the ones returned by ReplayBuffer or LeRobotDataset). If a dict is supplied
it is first converted to the internal tuple format using to_transition; after all
steps are executed the tuple is transformed back into a dict with to_batch and the
result is returned thereby preserving the caller's original data type.
Args:
data: Either an EnvTransition tuple or a batch dictionary to process.
Returns:
The processed data in the same format as the input (tuple or dict).
Raises:
ValueError: If the transition is not a valid 7-tuple format.
"""
called_with_batch = isinstance(data, dict)
@@ -329,7 +327,7 @@ class RobotProcessor(ModelHubMixin):
if updated is not None:
transition = updated
return self.to_batch(transition) if called_with_batch else transition
return self.to_output(transition) if called_with_batch else transition
def step_through(self, transition: EnvTransition) -> Iterable[EnvTransition]:
"""Yield the intermediate Transition instances after each processor step."""