debug frames

This commit is contained in:
Pepijn
2025-08-31 19:20:18 +02:00
parent e3306951c0
commit ae57fe2d33

View File

@@ -451,7 +451,7 @@ class RLearNPolicy(PreTrainedPolicy):
# Frame-to-frame differences
if T > 1:
frame_diffs = torch.norm(sample_features[1:] - sample_features[:-1], dim=-1)
frame_diffs = (sample_features[1:] - sample_features[:-1]).pow(2).sum(dim=-1).sqrt()
avg_frame_diff = frame_diffs.mean().item()
max_frame_diff = frame_diffs.max().item()
min_frame_diff = frame_diffs.min().item()
@@ -468,10 +468,9 @@ class RLearNPolicy(PreTrainedPolicy):
# Overall batch statistics
if B > 1 and T > 1:
all_diffs = torch.norm(
vision_features[:, 1:, :] - vision_features[:, :-1, :],
dim=-1
).flatten()
all_diffs = (
vision_features[:, 1:, :] - vision_features[:, :-1, :]
).pow(2).sum(dim=-1).sqrt().flatten()
print(f"Batch-wide frame differences: mean={all_diffs.mean():.6f}, "
f"std={all_diffs.std():.6f}")