From 16e82fd29f47e4a990caa1a2e2c085eb9203e016 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sun, 31 Aug 2025 19:31:27 +0200 Subject: [PATCH] fix stride unique sampling --- .../policies/rlearn/modeling_rlearn.py | 55 +++++++++++-------- 1 file changed, 33 insertions(+), 22 deletions(-) diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 80e72b465..25f832147 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -983,33 +983,44 @@ class RLearNPolicy(PreTrainedPolicy): ep_length = ep_end - ep_start episode_lengths.append(ep_length) - # Choose random anchor - need enough frames before for stride sampling - # For T=16 and stride=2, we need frames [anchor-30, anchor-28, ..., anchor-2, anchor] + # Proper window-relative stride sampling within available frames stride = self.config.temporal_sampling_stride - min_anchor = (T - 1) * stride # Need (T-1)*stride frames before anchor - max_anchor = max(min_anchor, ep_length - 1) - anchor = torch.randint(min_anchor, max_anchor + 1, (1,)).item() - anchor_positions.append(anchor) - - # Build window indices with configurable stride sampling and reflection padding + # Ensure we have room for T frames at given stride; shrink stride if needed + if available_T <= 1: + effective_stride = 1 + else: + effective_stride = max(1, min(stride, (available_T - 1) // max(T - 1, 1) if (T - 1) > 0 else 1)) + min_anchor_in_window = (T - 1) * effective_stride + max_anchor_in_window = max(min_anchor_in_window, available_T - 1) + anchor_in_window = torch.randint(min_anchor_in_window, max_anchor_in_window + 1, (1,)).item() + + # Convert window-anchor to episode-anchor (absolute frame index within episode) + cur_frame_idx = frame_indices[b_idx].item() + anchor_abs = cur_frame_idx + (anchor_in_window - (available_T - 1)) + anchor_abs = int(max(0, min(anchor_abs, ep_length - 1))) + anchor_positions.append(anchor_abs) + + # Build window 
indices with stride and reflection within [0, available_T) window_indices = [] - frame_indices_for_progress = [] # Track actual frame positions for progress + frame_indices_for_progress = [] # Episode-relative absolute indices for progress had_oob = False - # Sample with stride: [anchor-(T-1)*stride, anchor-(T-2)*stride, ..., anchor-stride, anchor] for i in range(T): - delta = -(T - 1 - i) * stride # Work backwards from anchor with stride spacing - idx = anchor + delta - actual_frame_idx = idx # Store the actual frame index before reflection - if idx < 0: - idx = -idx # Reflect at start + delta = -(T - 1 - i) * effective_stride + w_idx = anchor_in_window + delta + if w_idx < 0: + w_idx = -w_idx had_oob = True - elif idx >= ep_length: - idx = 2 * (ep_length - 1) - idx # Reflect at end + elif w_idx >= available_T: + w_idx = 2 * (available_T - 1) - w_idx had_oob = True - window_indices.append(min(idx, available_T - 1)) - # For reflected indices, use the reflected position for progress - frame_indices_for_progress.append(idx) - + w_idx = max(0, min(w_idx, available_T - 1)) + window_indices.append(w_idx) + + # Map window index back to episode-relative absolute frame index + abs_idx = cur_frame_idx + (w_idx - (available_T - 1)) + abs_idx = int(max(0, min(abs_idx, ep_length - 1))) + frame_indices_for_progress.append(abs_idx) + if had_oob: oob_count += 1 @@ -1021,7 +1032,7 @@ class RLearNPolicy(PreTrainedPolicy): # DEBUG: Check if stride sampling is producing different frames if torch.rand(1).item() < 0.1 and b_idx == 0: # Debug first sample occasionally print(f"\nšŸ” STRIDE SAMPLING DEBUG (Sample {b_idx}):") - print(f"Episode length: {ep_length}, Anchor: {anchor}") + print(f"Episode length: {ep_length}, Anchor(abs): {anchor_abs}, Anchor(win): {anchor_in_window}, eff_stride: {effective_stride}") print(f"Window indices: {window_indices[:5]}...{window_indices[-5:]}") # First and last 5 print(f"Frame indices for progress: 
{frame_indices_for_progress[:5]}...{frame_indices_for_progress[-5:]}")