fixup! fixup! Improve visualization: separate correction plot and fix axis scaling

This commit is contained in:
Eugene Mironov
2025-11-04 03:23:05 +07:00
parent 7dae02cec1
commit b5ff2b38df
2 changed files with 0 additions and 148 deletions

View File

@@ -19,7 +19,6 @@
from dataclasses import dataclass, field
from typing import Any
import torch
from torch import Tensor
@@ -211,22 +210,6 @@ class Tracker:
oldest_key = next(iter(self._steps))
del self._steps[oldest_key]
def get_recent_steps(self, n: int = 1) -> list[DebugStep]:
"""Get the n most recent debug steps.
Args:
n (int): Number of recent steps to retrieve.
Returns:
List of DebugStep objects (may be empty if disabled or no steps recorded).
"""
if not self.enabled or self._steps is None:
return []
# Get all values and return the last n
all_steps = list(self._steps.values())
return all_steps[-n:]
def get_all_steps(self) -> list[DebugStep]:
"""Get all recorded debug steps.
@@ -238,102 +221,8 @@ class Tracker:
return list(self._steps.values())
def get_step_stats_summary(self) -> dict[str, Any]:
"""Get summary statistics across all recorded steps.
Returns:
Dictionary containing aggregate statistics.
"""
if not self.enabled or self._steps is None or len(self._steps) == 0:
return {"enabled": self.enabled, "total_steps": 0}
# Aggregate statistics from dictionary values
corrections = [s.correction for s in self._steps.values() if s.correction is not None]
errors = [s.err for s in self._steps.values() if s.err is not None]
guidance_weights = [s.guidance_weight for s in self._steps.values() if s.guidance_weight is not None]
summary = {
"enabled": self.enabled,
"total_steps": len(self._steps),
"step_counter": self._step_counter,
}
if corrections:
correction_norms = torch.tensor([c.norm().item() for c in corrections])
summary["correction_norms"] = {
"mean": correction_norms.mean().item(),
"std": correction_norms.std().item(),
"min": correction_norms.min().item(),
"max": correction_norms.max().item(),
}
if errors:
error_norms = torch.tensor([e.norm().item() for e in errors])
summary["error_norms"] = {
"mean": error_norms.mean().item(),
"std": error_norms.std().item(),
"min": error_norms.min().item(),
"max": error_norms.max().item(),
}
if guidance_weights:
gw_tensor = torch.tensor([gw.item() if isinstance(gw, Tensor) else gw for gw in guidance_weights])
summary["guidance_weights"] = {
"mean": gw_tensor.mean().item(),
"std": gw_tensor.std().item(),
"min": gw_tensor.min().item(),
"max": gw_tensor.max().item(),
}
return summary
def export_to_dict(self, include_tensors: bool = False) -> dict[str, Any]:
"""Export all debug information to a dictionary.
Args:
include_tensors (bool): If True, include full tensor values. If False,
only include tensor statistics.
Returns:
Dictionary containing all debug information.
"""
if not self.enabled or self._steps is None:
return {"enabled": False, "steps": []}
return {
"enabled": True,
"total_steps": len(self._steps),
"step_counter": self._step_counter,
"steps": [step.to_dict(include_tensors=include_tensors) for step in self._steps.values()],
}
def __len__(self) -> int:
"""Return the number of recorded debug steps."""
if not self.enabled or self._steps is None:
return 0
return len(self._steps)
@staticmethod
def tensor_stats(tensor: Tensor, name: str = "tensor") -> str:
"""Generate readable statistics string for a tensor.
Args:
tensor: Input tensor
name: Name to display
Returns:
Formatted string with shape and statistics
"""
if tensor is None:
return f"{name}: None"
stats = (
f"{name}: shape={tuple(tensor.shape)}, "
f"dtype={tensor.dtype}, "
f"device={tensor.device}, "
f"min={tensor.min().item():.4f}, "
f"max={tensor.max().item():.4f}, "
f"mean={tensor.mean().item():.4f}, "
f"std={tensor.std().item():.4f}"
)
return stats

View File

@@ -421,40 +421,3 @@ class RTCDebugVisualizer:
plt.close(fig)
return fig
@staticmethod
def print_debug_statistics(tracker: Tracker) -> None:
"""Print summary statistics from the tracker.
Args:
tracker (Tracker): Tracker with recorded steps.
"""
if not tracker.enabled:
print("Tracker is disabled.")
return
stats = tracker.get_step_stats_summary()
print("\n" + "=" * 60)
print("RTC Debug Statistics Summary")
print("=" * 60)
print(f"Enabled: {stats['enabled']}")
print(f"Total steps recorded: {stats['total_steps']}")
print(f"Step counter: {stats['step_counter']}")
if "correction_norms" in stats:
print("\nCorrection Norms:")
for key, value in stats["correction_norms"].items():
print(f" {key}: {value:.6f}")
if "error_norms" in stats:
print("\nError Norms:")
for key, value in stats["error_norms"].items():
print(f" {key}: {value:.6f}")
if "guidance_weights" in stats:
print("\nGuidance Weights:")
for key, value in stats["guidance_weights"].items():
print(f" {key}: {value:.6f}")
print("=" * 60 + "\n")