From ae57fe2d339ec284c5a9e7fca1b1f353e4d38da6 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sun, 31 Aug 2025 19:20:18 +0200 Subject: [PATCH] debug frames --- src/lerobot/policies/rlearn/modeling_rlearn.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index dfd0b961f..80e72b465 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -451,7 +451,7 @@ class RLearNPolicy(PreTrainedPolicy): # Frame-to-frame differences if T > 1: - frame_diffs = torch.norm(sample_features[1:] - sample_features[:-1], dim=-1) + frame_diffs = (sample_features[1:] - sample_features[:-1]).pow(2).sum(dim=-1).sqrt() avg_frame_diff = frame_diffs.mean().item() max_frame_diff = frame_diffs.max().item() min_frame_diff = frame_diffs.min().item() @@ -468,10 +468,9 @@ class RLearNPolicy(PreTrainedPolicy): # Overall batch statistics if B > 1 and T > 1: - all_diffs = torch.norm( - vision_features[:, 1:, :] - vision_features[:, :-1, :], - dim=-1 - ).flatten() + all_diffs = ( + vision_features[:, 1:, :] - vision_features[:, :-1, :] + ).pow(2).sum(dim=-1).sqrt().flatten() print(f"Batch-wide frame differences: mean={all_diffs.mean():.6f}, " f"std={all_diffs.std():.6f}")