mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 04:11:24 +00:00
chore (output format): improves output format
This commit is contained in:
@@ -239,34 +239,26 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
class RobotProcessor(ModelHubMixin):
|
||||
"""
|
||||
Composable, debuggable post-processing processor for robot transitions.
|
||||
The class orchestrates an ordered collection of small, functional
|
||||
transforms—steps—executed left-to-right on each incoming
|
||||
`EnvTransition`.
|
||||
Parameters:
|
||||
steps : Sequence[ProcessorStep], optional
|
||||
Ordered list executed on every call
|
||||
name : str, default="RobotProcessor"
|
||||
Human-readable identifier that is persisted inside the JSON config.
|
||||
seed : int | None, optional
|
||||
Global seed forwarded to steps that choose to consume it.
|
||||
Examples:
|
||||
Basic usage::
|
||||
env = gym.make("CartPole-v1")
|
||||
proc = RobotProcessor([
|
||||
ObservationNormalizer(),
|
||||
IntrinsicVelocity(),
|
||||
VelocityBonus(0.02),
|
||||
])
|
||||
obs, info = env.reset(seed=0)
|
||||
tr = (obs, None, 0.0, False, False, info, {})
|
||||
obs, *_ = proc(tr) # agent sees a normalised observation
|
||||
Inspecting intermediate results::
|
||||
for idx, step_tr in enumerate(proc.step_through(tr)):
|
||||
print(idx, step_tr)
|
||||
Serialization to the Hugging Face Hub::
|
||||
proc.save_pretrained("chkpt")
|
||||
proc.push_to_hub("my-org/cartpole_proc")
|
||||
loaded = RobotProcessor.from_pretrained("my-org/cartpole_proc")
|
||||
|
||||
The class orchestrates an ordered collection of small, functional transforms—steps—executed
|
||||
left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` tuples
|
||||
and batch dictionaries, automatically converting between formats as needed.
|
||||
|
||||
Args:
|
||||
steps: Ordered list of processing steps executed on every call. Defaults to empty list.
|
||||
name: Human-readable identifier that is persisted inside the JSON config.
|
||||
Defaults to "RobotProcessor".
|
||||
seed: Global seed forwarded to steps that choose to consume it. Defaults to None.
|
||||
to_transition: Function to convert batch dict to EnvTransition tuple.
|
||||
Defaults to _default_batch_to_transition.
|
||||
to_output: Function to convert EnvTransition tuple to the desired output format.
|
||||
Usually it is a batch dict or EnvTransition tuple.
|
||||
Defaults to _default_transition_to_batch.
|
||||
before_step_hooks: List of hooks called before each step. Each hook receives the step
|
||||
index and transition, and can optionally return a modified transition.
|
||||
after_step_hooks: List of hooks called after each step. Each hook receives the step
|
||||
index and transition, and can optionally return a modified transition.
|
||||
reset_hooks: List of hooks called during processor reset.
|
||||
"""
|
||||
|
||||
steps: Sequence[ProcessorStep] = field(default_factory=list)
|
||||
@@ -276,7 +268,7 @@ class RobotProcessor(ModelHubMixin):
|
||||
to_transition: Callable[[dict[str, Any]], EnvTransition] = field(
|
||||
default_factory=lambda: _default_batch_to_transition, repr=False
|
||||
)
|
||||
to_batch: Callable[[EnvTransition], dict[str, Any]] = field(
|
||||
to_output: Callable[[EnvTransition], dict[str, Any] | EnvTransition] = field(
|
||||
default_factory=lambda: _default_transition_to_batch, repr=False
|
||||
)
|
||||
|
||||
@@ -292,16 +284,22 @@ class RobotProcessor(ModelHubMixin):
|
||||
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
|
||||
|
||||
def __call__(self, data: EnvTransition | dict[str, Any]):
|
||||
"""Process *data* through all steps.
|
||||
"""Process data through all steps.
|
||||
|
||||
The method accepts **either** the classic :pydata:`EnvTransition` tuple
|
||||
**or** a *batch* dictionary (like the ones returned by
|
||||
:class:`lerobot.utils.buffer.ReplayBuffer` or
|
||||
:class:`lerobot.datasets.lerobot_dataset.LeRobotDataset`). If a dict is
|
||||
supplied it is first converted to the internal tuple format using
|
||||
:pyattr:`to_transition`; after all steps are executed the tuple is
|
||||
transformed back into a dict with :pyattr:`to_batch` and the result is
|
||||
returned – thereby preserving the caller's original data type.
|
||||
The method accepts either the classic EnvTransition tuple or a batch dictionary
|
||||
(like the ones returned by ReplayBuffer or LeRobotDataset). If a dict is supplied
|
||||
it is first converted to the internal tuple format using to_transition; after all
|
||||
steps are executed the tuple is transformed back into a dict with to_batch and the
|
||||
result is returned – thereby preserving the caller's original data type.
|
||||
|
||||
Args:
|
||||
data: Either an EnvTransition tuple or a batch dictionary to process.
|
||||
|
||||
Returns:
|
||||
The processed data in the same format as the input (tuple or dict).
|
||||
|
||||
Raises:
|
||||
ValueError: If the transition is not a valid 7-tuple format.
|
||||
"""
|
||||
|
||||
called_with_batch = isinstance(data, dict)
|
||||
@@ -329,7 +327,7 @@ class RobotProcessor(ModelHubMixin):
|
||||
if updated is not None:
|
||||
transition = updated
|
||||
|
||||
return self.to_batch(transition) if called_with_batch else transition
|
||||
return self.to_output(transition) if called_with_batch else transition
|
||||
|
||||
def step_through(self, transition: EnvTransition) -> Iterable[EnvTransition]:
|
||||
"""Yield the intermediate Transition instances after each processor step."""
|
||||
|
||||
Reference in New Issue
Block a user