From 3fb3edde3f4d519746bc0daf084fdee93e9b1a88 Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Tue, 4 Nov 2025 03:23:05 +0700 Subject: [PATCH] fixup! fixup! Improve visualization: separate correction plot and fix axis scaling --- src/lerobot/policies/rtc/debug_handler.py | 111 ------------------- src/lerobot/policies/rtc/debug_visualizer.py | 37 ------- 2 files changed, 148 deletions(-) diff --git a/src/lerobot/policies/rtc/debug_handler.py b/src/lerobot/policies/rtc/debug_handler.py index dd3040016..135281f9e 100644 --- a/src/lerobot/policies/rtc/debug_handler.py +++ b/src/lerobot/policies/rtc/debug_handler.py @@ -19,7 +19,6 @@ from dataclasses import dataclass, field from typing import Any -import torch from torch import Tensor @@ -211,22 +210,6 @@ class Tracker: oldest_key = next(iter(self._steps)) del self._steps[oldest_key] - def get_recent_steps(self, n: int = 1) -> list[DebugStep]: - """Get the n most recent debug steps. - - Args: - n (int): Number of recent steps to retrieve. - - Returns: - List of DebugStep objects (may be empty if disabled or no steps recorded). - """ - if not self.enabled or self._steps is None: - return [] - - # Get all values and return the last n - all_steps = list(self._steps.values()) - return all_steps[-n:] - def get_all_steps(self) -> list[DebugStep]: """Get all recorded debug steps. @@ -238,102 +221,8 @@ class Tracker: return list(self._steps.values()) - def get_step_stats_summary(self) -> dict[str, Any]: - """Get summary statistics across all recorded steps. - - Returns: - Dictionary containing aggregate statistics. - """ - if not self.enabled or self._steps is None or len(self._steps) == 0: - return {"enabled": self.enabled, "total_steps": 0} - - # Aggregate statistics from dictionary values - corrections = [s.correction for s in self._steps.values() if s.correction is not None] - errors = [s.err for s in self._steps.values() if s.err is not None] - guidance_weights = [s.guidance_weight for s in self._steps.values() if s.guidance_weight is not None] - - summary = { - "enabled": self.enabled, - "total_steps": len(self._steps), - "step_counter": self._step_counter, - } - - if corrections: - correction_norms = torch.tensor([c.norm().item() for c in corrections]) - summary["correction_norms"] = { - "mean": correction_norms.mean().item(), - "std": correction_norms.std().item(), - "min": correction_norms.min().item(), - "max": correction_norms.max().item(), - } - - if errors: - error_norms = torch.tensor([e.norm().item() for e in errors]) - summary["error_norms"] = { - "mean": error_norms.mean().item(), - "std": error_norms.std().item(), - "min": error_norms.min().item(), - "max": error_norms.max().item(), - } - - if guidance_weights: - gw_tensor = torch.tensor([gw.item() if isinstance(gw, Tensor) else gw for gw in guidance_weights]) - summary["guidance_weights"] = { - "mean": gw_tensor.mean().item(), - "std": gw_tensor.std().item(), - "min": gw_tensor.min().item(), - "max": gw_tensor.max().item(), - } - - return summary - - def export_to_dict(self, include_tensors: bool = False) -> dict[str, Any]: - """Export all debug information to a dictionary. - - Args: - include_tensors (bool): If True, include full tensor values. If False, - only include tensor statistics. - - Returns: - Dictionary containing all debug information. - """ - if not self.enabled or self._steps is None: - return {"enabled": False, "steps": []} - - return { - "enabled": True, - "total_steps": len(self._steps), - "step_counter": self._step_counter, - "steps": [step.to_dict(include_tensors=include_tensors) for step in self._steps.values()], - } - def __len__(self) -> int: """Return the number of recorded debug steps.""" if not self.enabled or self._steps is None: return 0 return len(self._steps) - - @staticmethod - def tensor_stats(tensor: Tensor, name: str = "tensor") -> str: - """Generate readable statistics string for a tensor. - - Args: - tensor: Input tensor - name: Name to display - - Returns: - Formatted string with shape and statistics - """ - if tensor is None: - return f"{name}: None" - - stats = ( - f"{name}: shape={tuple(tensor.shape)}, " - f"dtype={tensor.dtype}, " - f"device={tensor.device}, " - f"min={tensor.min().item():.4f}, " - f"max={tensor.max().item():.4f}, " - f"mean={tensor.mean().item():.4f}, " - f"std={tensor.std().item():.4f}" - ) - return stats diff --git a/src/lerobot/policies/rtc/debug_visualizer.py b/src/lerobot/policies/rtc/debug_visualizer.py index a9c5ee86c..4c7c218af 100644 --- a/src/lerobot/policies/rtc/debug_visualizer.py +++ b/src/lerobot/policies/rtc/debug_visualizer.py @@ -421,40 +421,3 @@ class RTCDebugVisualizer: plt.close(fig) return fig - - @staticmethod - def print_debug_statistics(tracker: Tracker) -> None: - """Print summary statistics from the tracker. - - Args: - tracker (Tracker): Tracker with recorded steps. - """ - if not tracker.enabled: - print("Tracker is disabled.") - return - - stats = tracker.get_step_stats_summary() - - print("\n" + "=" * 60) - print("RTC Debug Statistics Summary") - print("=" * 60) - print(f"Enabled: {stats['enabled']}") - print(f"Total steps recorded: {stats['total_steps']}") - print(f"Step counter: {stats['step_counter']}") - - if "correction_norms" in stats: - print("\nCorrection Norms:") - for key, value in stats["correction_norms"].items(): - print(f" {key}: {value:.6f}") - - if "error_norms" in stats: - print("\nError Norms:") - for key, value in stats["error_norms"].items(): - print(f" {key}: {value:.6f}") - - if "guidance_weights" in stats: - print("\nGuidance Weights:") - for key, value in stats["guidance_weights"].items(): - print(f" {key}: {value:.6f}") - - print("=" * 60 + "\n")