refactor(pipeline): Enhance state filename generation and profiling method

- Updated state filename generation to use the registry name when available, improving clarity in saved files.
- Added a warmup_runs parameter (default 5) to the profile_steps method, replacing the previously hard-coded five warm-up iterations and allowing for more controlled performance profiling.
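- Ensured consistent conditions during profiling: each step now receives the intermediate transition it would see in the pipeline (obtained via step_through), deep-copied before every warm-up and timed run, instead of being fed its own output from the previous run. This makes the timing results more accurate.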
This commit is contained in:
Adil Zouitine
2025-07-22 11:28:30 +02:00
parent ae7a54de57
commit 699363f9fc

View File

@@ -448,7 +448,12 @@ class RobotProcessor(ModelHubMixin):
for key, tensor in state.items():
cloned_state[key] = tensor.clone()
state_filename = f"step_{step_index}.safetensors"
# Use registry name for more meaningful filenames when available
if registry_name:
state_filename = f"{registry_name}.safetensors"
else:
state_filename = f"step_{step_index}.safetensors"
save_file(cloned_state, os.path.join(destination_path, state_filename))
step_entry["state_file"] = state_filename
@@ -707,23 +712,37 @@ class RobotProcessor(ModelHubMixin):
for fn in self.reset_hooks:
fn()
def profile_steps(self, transition: EnvTransition, num_runs: int = 100) -> dict[str, float]:
def profile_steps(
self, transition: EnvTransition, num_runs: int = 100, warmup_runs: int = 5
) -> dict[str, float]:
"""Profile the execution time of each step for performance optimization."""
import copy
import time
profile_results = {}
# Make a copy to avoid altering the original transition
transition_copy = copy.deepcopy(transition)
# Get intermediate transitions for each step using step_through
intermediate_transitions = list(self.step_through(transition_copy))
for idx, processor_step in enumerate(self.steps):
step_name = f"step_{idx}_{processor_step.__class__.__name__}"
# Warm up
for _ in range(5):
_ = processor_step(transition)
# Use the appropriate input transition for this step
input_transition = intermediate_transitions[idx]
# Time the step
# Warm up - copy transition for each run to ensure consistent conditions
for _ in range(warmup_runs):
transition_copy = copy.deepcopy(input_transition)
_ = processor_step(transition_copy)
# Time the step - copy transition for each run to ensure consistent conditions
start_time = time.perf_counter()
for _ in range(num_runs):
transition = processor_step(transition)
transition_copy = copy.deepcopy(input_transition)
_ = processor_step(transition_copy)
end_time = time.perf_counter()
avg_time = (end_time - start_time) / num_runs * 1000 # Convert to milliseconds