feat(pipeline): Enhance step_through method to support both tuple and dict inputs

This commit is contained in:
Adil Zouitine
2025-07-08 13:14:58 +02:00
parent e9f7f5127b
commit fa26290e8c
2 changed files with 74 additions and 6 deletions

View File

@@ -329,12 +329,46 @@ class RobotProcessor(ModelHubMixin):
return self.to_output(transition) if called_with_batch else transition
def step_through(self, transition: EnvTransition) -> Iterable[EnvTransition]:
"""Yield the intermediate Transition instances after each processor step."""
yield transition
for processor_step in self.steps:
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]:
"""Yield the intermediate results after each processor step.
Like __call__, this method accepts either EnvTransition tuples or batch dictionaries
and preserves the input format in the yielded results.
Args:
data: Either an EnvTransition tuple or a batch dictionary to process.
Yields:
The intermediate results after each step, in the same format as the input.
"""
called_with_batch = isinstance(data, dict)
transition = self.to_transition(data) if called_with_batch else data
# Basic validation with helpful error message for tuple input
if not isinstance(transition, tuple) or len(transition) != 7:
raise ValueError(
"EnvTransition must be a 7-tuple of (observation, action, reward, done, "
"truncated, info, complementary_data). "
f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}."
)
# Yield initial state
yield self.to_output(transition) if called_with_batch else transition
for idx, processor_step in enumerate(self.steps):
for hook in self.before_step_hooks:
updated = hook(idx, transition)
if updated is not None:
transition = updated
transition = processor_step(transition)
yield transition
for hook in self.after_step_hooks:
updated = hook(idx, transition)
if updated is not None:
transition = updated
yield self.to_output(transition) if called_with_batch else transition
_CFG_NAME = "processor.json"