mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
fixup! fixup! Improve visualization: separate correction plot and fix axis scaling
This commit is contained in:
@@ -19,7 +19,6 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@@ -211,22 +210,6 @@ class Tracker:
|
||||
oldest_key = next(iter(self._steps))
|
||||
del self._steps[oldest_key]
|
||||
|
||||
def get_recent_steps(self, n: int = 1) -> list[DebugStep]:
|
||||
"""Get the n most recent debug steps.
|
||||
|
||||
Args:
|
||||
n (int): Number of recent steps to retrieve.
|
||||
|
||||
Returns:
|
||||
List of DebugStep objects (may be empty if disabled or no steps recorded).
|
||||
"""
|
||||
if not self.enabled or self._steps is None:
|
||||
return []
|
||||
|
||||
# Get all values and return the last n
|
||||
all_steps = list(self._steps.values())
|
||||
return all_steps[-n:]
|
||||
|
||||
def get_all_steps(self) -> list[DebugStep]:
|
||||
"""Get all recorded debug steps.
|
||||
|
||||
@@ -238,102 +221,8 @@ class Tracker:
|
||||
|
||||
return list(self._steps.values())
|
||||
|
||||
def get_step_stats_summary(self) -> dict[str, Any]:
|
||||
"""Get summary statistics across all recorded steps.
|
||||
|
||||
Returns:
|
||||
Dictionary containing aggregate statistics.
|
||||
"""
|
||||
if not self.enabled or self._steps is None or len(self._steps) == 0:
|
||||
return {"enabled": self.enabled, "total_steps": 0}
|
||||
|
||||
# Aggregate statistics from dictionary values
|
||||
corrections = [s.correction for s in self._steps.values() if s.correction is not None]
|
||||
errors = [s.err for s in self._steps.values() if s.err is not None]
|
||||
guidance_weights = [s.guidance_weight for s in self._steps.values() if s.guidance_weight is not None]
|
||||
|
||||
summary = {
|
||||
"enabled": self.enabled,
|
||||
"total_steps": len(self._steps),
|
||||
"step_counter": self._step_counter,
|
||||
}
|
||||
|
||||
if corrections:
|
||||
correction_norms = torch.tensor([c.norm().item() for c in corrections])
|
||||
summary["correction_norms"] = {
|
||||
"mean": correction_norms.mean().item(),
|
||||
"std": correction_norms.std().item(),
|
||||
"min": correction_norms.min().item(),
|
||||
"max": correction_norms.max().item(),
|
||||
}
|
||||
|
||||
if errors:
|
||||
error_norms = torch.tensor([e.norm().item() for e in errors])
|
||||
summary["error_norms"] = {
|
||||
"mean": error_norms.mean().item(),
|
||||
"std": error_norms.std().item(),
|
||||
"min": error_norms.min().item(),
|
||||
"max": error_norms.max().item(),
|
||||
}
|
||||
|
||||
if guidance_weights:
|
||||
gw_tensor = torch.tensor([gw.item() if isinstance(gw, Tensor) else gw for gw in guidance_weights])
|
||||
summary["guidance_weights"] = {
|
||||
"mean": gw_tensor.mean().item(),
|
||||
"std": gw_tensor.std().item(),
|
||||
"min": gw_tensor.min().item(),
|
||||
"max": gw_tensor.max().item(),
|
||||
}
|
||||
|
||||
return summary
|
||||
|
||||
def export_to_dict(self, include_tensors: bool = False) -> dict[str, Any]:
|
||||
"""Export all debug information to a dictionary.
|
||||
|
||||
Args:
|
||||
include_tensors (bool): If True, include full tensor values. If False,
|
||||
only include tensor statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary containing all debug information.
|
||||
"""
|
||||
if not self.enabled or self._steps is None:
|
||||
return {"enabled": False, "steps": []}
|
||||
|
||||
return {
|
||||
"enabled": True,
|
||||
"total_steps": len(self._steps),
|
||||
"step_counter": self._step_counter,
|
||||
"steps": [step.to_dict(include_tensors=include_tensors) for step in self._steps.values()],
|
||||
}
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of recorded debug steps."""
|
||||
if not self.enabled or self._steps is None:
|
||||
return 0
|
||||
return len(self._steps)
|
||||
|
||||
@staticmethod
|
||||
def tensor_stats(tensor: Tensor, name: str = "tensor") -> str:
|
||||
"""Generate readable statistics string for a tensor.
|
||||
|
||||
Args:
|
||||
tensor: Input tensor
|
||||
name: Name to display
|
||||
|
||||
Returns:
|
||||
Formatted string with shape and statistics
|
||||
"""
|
||||
if tensor is None:
|
||||
return f"{name}: None"
|
||||
|
||||
stats = (
|
||||
f"{name}: shape={tuple(tensor.shape)}, "
|
||||
f"dtype={tensor.dtype}, "
|
||||
f"device={tensor.device}, "
|
||||
f"min={tensor.min().item():.4f}, "
|
||||
f"max={tensor.max().item():.4f}, "
|
||||
f"mean={tensor.mean().item():.4f}, "
|
||||
f"std={tensor.std().item():.4f}"
|
||||
)
|
||||
return stats
|
||||
|
||||
@@ -421,40 +421,3 @@ class RTCDebugVisualizer:
|
||||
plt.close(fig)
|
||||
|
||||
return fig
|
||||
|
||||
@staticmethod
|
||||
def print_debug_statistics(tracker: Tracker) -> None:
|
||||
"""Print summary statistics from the tracker.
|
||||
|
||||
Args:
|
||||
tracker (Tracker): Tracker with recorded steps.
|
||||
"""
|
||||
if not tracker.enabled:
|
||||
print("Tracker is disabled.")
|
||||
return
|
||||
|
||||
stats = tracker.get_step_stats_summary()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("RTC Debug Statistics Summary")
|
||||
print("=" * 60)
|
||||
print(f"Enabled: {stats['enabled']}")
|
||||
print(f"Total steps recorded: {stats['total_steps']}")
|
||||
print(f"Step counter: {stats['step_counter']}")
|
||||
|
||||
if "correction_norms" in stats:
|
||||
print("\nCorrection Norms:")
|
||||
for key, value in stats["correction_norms"].items():
|
||||
print(f" {key}: {value:.6f}")
|
||||
|
||||
if "error_norms" in stats:
|
||||
print("\nError Norms:")
|
||||
for key, value in stats["error_norms"].items():
|
||||
print(f" {key}: {value:.6f}")
|
||||
|
||||
if "guidance_weights" in stats:
|
||||
print("\nGuidance Weights:")
|
||||
for key, value in stats["guidance_weights"].items():
|
||||
print(f" {key}: {value:.6f}")
|
||||
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
Reference in New Issue
Block a user