fix stride unique samplin

This commit is contained in:
Pepijn
2025-08-31 19:31:27 +02:00
parent ae57fe2d33
commit 16e82fd29f

View File

@@ -983,33 +983,44 @@ class RLearNPolicy(PreTrainedPolicy):
ep_length = ep_end - ep_start
episode_lengths.append(ep_length)
# Choose random anchor - need enough frames before for stride sampling
# For T=16 and stride=2, we need frames [anchor-30, anchor-28, ..., anchor-2, anchor]
# Proper window-relative stride sampling within available frames
stride = self.config.temporal_sampling_stride
min_anchor = (T - 1) * stride # Need (T-1)*stride frames before anchor
max_anchor = max(min_anchor, ep_length - 1)
anchor = torch.randint(min_anchor, max_anchor + 1, (1,)).item()
anchor_positions.append(anchor)
# Build window indices with configurable stride sampling and reflection padding
# Ensure we have room for T frames at given stride; shrink stride if needed
if available_T <= 1:
effective_stride = 1
else:
effective_stride = max(1, min(stride, (available_T - 1) // max(T - 1, 1) if (T - 1) > 0 else 1))
min_anchor_in_window = (T - 1) * effective_stride
max_anchor_in_window = max(min_anchor_in_window, available_T - 1)
anchor_in_window = torch.randint(min_anchor_in_window, max_anchor_in_window + 1, (1,)).item()
# Convert window-anchor to episode-anchor (absolute frame index within episode)
cur_frame_idx = frame_indices[b_idx].item()
anchor_abs = cur_frame_idx + (anchor_in_window - (available_T - 1))
anchor_abs = int(max(0, min(anchor_abs, ep_length - 1)))
anchor_positions.append(anchor_abs)
# Build window indices with stride and reflection within [0, available_T)
window_indices = []
frame_indices_for_progress = [] # Track actual frame positions for progress
frame_indices_for_progress = [] # Episode-relative absolute indices for progress
had_oob = False
# Sample with stride: [anchor-(T-1)*stride, anchor-(T-2)*stride, ..., anchor-stride, anchor]
for i in range(T):
delta = -(T - 1 - i) * stride # Work backwards from anchor with stride spacing
idx = anchor + delta
actual_frame_idx = idx # Store the actual frame index before reflection
if idx < 0:
idx = -idx # Reflect at start
delta = -(T - 1 - i) * effective_stride
w_idx = anchor_in_window + delta
if w_idx < 0:
w_idx = -w_idx
had_oob = True
elif idx >= ep_length:
idx = 2 * (ep_length - 1) - idx # Reflect at end
elif w_idx >= available_T:
w_idx = 2 * (available_T - 1) - w_idx
had_oob = True
window_indices.append(min(idx, available_T - 1))
# For reflected indices, use the reflected position for progress
frame_indices_for_progress.append(idx)
w_idx = max(0, min(w_idx, available_T - 1))
window_indices.append(w_idx)
# Map window index back to episode-relative absolute frame index
abs_idx = cur_frame_idx + (w_idx - (available_T - 1))
abs_idx = int(max(0, min(abs_idx, ep_length - 1)))
frame_indices_for_progress.append(abs_idx)
if had_oob:
oob_count += 1
@@ -1021,7 +1032,7 @@ class RLearNPolicy(PreTrainedPolicy):
# DEBUG: Check if stride sampling is producing different frames
if torch.rand(1).item() < 0.1 and b_idx == 0: # Debug first sample occasionally
print(f"\n🔍 STRIDE SAMPLING DEBUG (Sample {b_idx}):")
print(f"Episode length: {ep_length}, Anchor: {anchor}")
print(f"Episode length: {ep_length}, Anchor(abs): {anchor_abs}, Anchor(win): {anchor_in_window}, eff_stride: {effective_stride}")
print(f"Window indices: {window_indices[:5]}...{window_indices[-5:]}") # First and last 5
print(f"Frame indices for progress: {frame_indices_for_progress[:5]}...{frame_indices_for_progress[-5:]}")