mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
fix stride unique sampling
This commit is contained in:
@@ -983,33 +983,44 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
ep_length = ep_end - ep_start
|
||||
episode_lengths.append(ep_length)
|
||||
|
||||
# Choose random anchor - need enough frames before for stride sampling
|
||||
# For T=16 and stride=2, we need frames [anchor-30, anchor-28, ..., anchor-2, anchor]
|
||||
# Proper window-relative stride sampling within available frames
|
||||
stride = self.config.temporal_sampling_stride
|
||||
min_anchor = (T - 1) * stride # Need (T-1)*stride frames before anchor
|
||||
max_anchor = max(min_anchor, ep_length - 1)
|
||||
anchor = torch.randint(min_anchor, max_anchor + 1, (1,)).item()
|
||||
anchor_positions.append(anchor)
|
||||
|
||||
# Build window indices with configurable stride sampling and reflection padding
|
||||
# Ensure we have room for T frames at given stride; shrink stride if needed
|
||||
if available_T <= 1:
|
||||
effective_stride = 1
|
||||
else:
|
||||
effective_stride = max(1, min(stride, (available_T - 1) // max(T - 1, 1) if (T - 1) > 0 else 1))
|
||||
min_anchor_in_window = (T - 1) * effective_stride
|
||||
max_anchor_in_window = max(min_anchor_in_window, available_T - 1)
|
||||
anchor_in_window = torch.randint(min_anchor_in_window, max_anchor_in_window + 1, (1,)).item()
|
||||
|
||||
# Convert window-anchor to episode-anchor (absolute frame index within episode)
|
||||
cur_frame_idx = frame_indices[b_idx].item()
|
||||
anchor_abs = cur_frame_idx + (anchor_in_window - (available_T - 1))
|
||||
anchor_abs = int(max(0, min(anchor_abs, ep_length - 1)))
|
||||
anchor_positions.append(anchor_abs)
|
||||
|
||||
# Build window indices with stride and reflection within [0, available_T)
|
||||
window_indices = []
|
||||
frame_indices_for_progress = [] # Track actual frame positions for progress
|
||||
frame_indices_for_progress = [] # Episode-relative absolute indices for progress
|
||||
had_oob = False
|
||||
# Sample with stride: [anchor-(T-1)*stride, anchor-(T-2)*stride, ..., anchor-stride, anchor]
|
||||
for i in range(T):
|
||||
delta = -(T - 1 - i) * stride # Work backwards from anchor with stride spacing
|
||||
idx = anchor + delta
|
||||
actual_frame_idx = idx # Store the actual frame index before reflection
|
||||
if idx < 0:
|
||||
idx = -idx # Reflect at start
|
||||
delta = -(T - 1 - i) * effective_stride
|
||||
w_idx = anchor_in_window + delta
|
||||
if w_idx < 0:
|
||||
w_idx = -w_idx
|
||||
had_oob = True
|
||||
elif idx >= ep_length:
|
||||
idx = 2 * (ep_length - 1) - idx # Reflect at end
|
||||
elif w_idx >= available_T:
|
||||
w_idx = 2 * (available_T - 1) - w_idx
|
||||
had_oob = True
|
||||
window_indices.append(min(idx, available_T - 1))
|
||||
# For reflected indices, use the reflected position for progress
|
||||
frame_indices_for_progress.append(idx)
|
||||
|
||||
w_idx = max(0, min(w_idx, available_T - 1))
|
||||
window_indices.append(w_idx)
|
||||
|
||||
# Map window index back to episode-relative absolute frame index
|
||||
abs_idx = cur_frame_idx + (w_idx - (available_T - 1))
|
||||
abs_idx = int(max(0, min(abs_idx, ep_length - 1)))
|
||||
frame_indices_for_progress.append(abs_idx)
|
||||
|
||||
if had_oob:
|
||||
oob_count += 1
|
||||
|
||||
@@ -1021,7 +1032,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# DEBUG: Check if stride sampling is producing different frames
|
||||
if torch.rand(1).item() < 0.1 and b_idx == 0: # Debug first sample occasionally
|
||||
print(f"\n🔍 STRIDE SAMPLING DEBUG (Sample {b_idx}):")
|
||||
print(f"Episode length: {ep_length}, Anchor: {anchor}")
|
||||
print(f"Episode length: {ep_length}, Anchor(abs): {anchor_abs}, Anchor(win): {anchor_in_window}, eff_stride: {effective_stride}")
|
||||
print(f"Window indices: {window_indices[:5]}...{window_indices[-5:]}") # First and last 5
|
||||
print(f"Frame indices for progress: {frame_indices_for_progress[:5]}...{frame_indices_for_progress[-5:]}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user