mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
debug frames
This commit is contained in:
@@ -451,7 +451,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
# Frame-to-frame differences
|
||||
if T > 1:
|
||||
frame_diffs = torch.norm(sample_features[1:] - sample_features[:-1], dim=-1)
|
||||
frame_diffs = (sample_features[1:] - sample_features[:-1]).pow(2).sum(dim=-1).sqrt()
|
||||
avg_frame_diff = frame_diffs.mean().item()
|
||||
max_frame_diff = frame_diffs.max().item()
|
||||
min_frame_diff = frame_diffs.min().item()
|
||||
@@ -468,10 +468,9 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
# Overall batch statistics
|
||||
if B > 1 and T > 1:
|
||||
all_diffs = torch.norm(
|
||||
vision_features[:, 1:, :] - vision_features[:, :-1, :],
|
||||
dim=-1
|
||||
).flatten()
|
||||
all_diffs = (
|
||||
vision_features[:, 1:, :] - vision_features[:, :-1, :]
|
||||
).pow(2).sum(dim=-1).sqrt().flatten()
|
||||
print(f"Batch-wide frame differences: mean={all_diffs.mean():.6f}, "
|
||||
f"std={all_diffs.std():.6f}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user