mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 21:01:26 +00:00
feat(pipeline): Enhance step_through method to support both tuple and dict inputs
This commit is contained in:
@@ -329,12 +329,46 @@ class RobotProcessor(ModelHubMixin):
|
||||
|
||||
return self.to_output(transition) if called_with_batch else transition
|
||||
|
||||
def step_through(self, transition: EnvTransition) -> Iterable[EnvTransition]:
|
||||
"""Yield the intermediate Transition instances after each processor step."""
|
||||
yield transition
|
||||
for processor_step in self.steps:
|
||||
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]:
|
||||
"""Yield the intermediate results after each processor step.
|
||||
|
||||
Like __call__, this method accepts either EnvTransition tuples or batch dictionaries
|
||||
and preserves the input format in the yielded results.
|
||||
|
||||
Args:
|
||||
data: Either an EnvTransition tuple or a batch dictionary to process.
|
||||
|
||||
Yields:
|
||||
The intermediate results after each step, in the same format as the input.
|
||||
"""
|
||||
called_with_batch = isinstance(data, dict)
|
||||
transition = self.to_transition(data) if called_with_batch else data
|
||||
|
||||
# Basic validation with helpful error message for tuple input
|
||||
if not isinstance(transition, tuple) or len(transition) != 7:
|
||||
raise ValueError(
|
||||
"EnvTransition must be a 7-tuple of (observation, action, reward, done, "
|
||||
"truncated, info, complementary_data). "
|
||||
f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}."
|
||||
)
|
||||
|
||||
# Yield initial state
|
||||
yield self.to_output(transition) if called_with_batch else transition
|
||||
|
||||
for idx, processor_step in enumerate(self.steps):
|
||||
for hook in self.before_step_hooks:
|
||||
updated = hook(idx, transition)
|
||||
if updated is not None:
|
||||
transition = updated
|
||||
|
||||
transition = processor_step(transition)
|
||||
yield transition
|
||||
|
||||
for hook in self.after_step_hooks:
|
||||
updated = hook(idx, transition)
|
||||
if updated is not None:
|
||||
transition = updated
|
||||
|
||||
yield self.to_output(transition) if called_with_batch else transition
|
||||
|
||||
_CFG_NAME = "processor.json"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user