mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 21:01:26 +00:00
use patch tokens
This commit is contained in:
@@ -155,6 +155,7 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=
|
||||
|
||||
windows = []
|
||||
frame_positions = [] # Track which temporal position each frame should use
|
||||
left_pad_counts = [] # Number of left-pad (OOB) frames per window
|
||||
|
||||
for i in range(T):
|
||||
start = max(0, i - L + 1)
|
||||
@@ -162,8 +163,10 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=
|
||||
|
||||
if window.shape[0] < L:
|
||||
pad_needed = L - window.shape[0]
|
||||
pad = window[:1].expand(pad_needed, -1, -1, -1) # repeat first frame
|
||||
pad = window[:1].expand(pad_needed, -1, -1, -1) # repeat first frame (clamp to frame 0)
|
||||
window = torch.cat([pad, window], dim=0)
|
||||
else:
|
||||
pad_needed = 0
|
||||
|
||||
# IMPROVED FIX: Cycle through MLPs to get varied predictions throughout the episode
|
||||
# This ensures we use all 16 frame-specific MLPs and get varied outputs
|
||||
@@ -172,6 +175,7 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=
|
||||
|
||||
windows.append(window)
|
||||
frame_positions.append(frame_pos)
|
||||
left_pad_counts.append(pad_needed)
|
||||
|
||||
preds = np.zeros(T, dtype=float)
|
||||
|
||||
@@ -185,6 +189,13 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=
|
||||
# Model returns (B, L) predictions for each temporal position
|
||||
values = model.predict_rewards(batch) # torch.Tensor (B, L)
|
||||
|
||||
# Apply eval-time padding rule: predictions for left-padded (OOB) frames are zero
|
||||
if values.dim() == 2 and len(left_pad_counts) >= (e - s):
|
||||
for b_idx in range(e - s):
|
||||
pad_n = left_pad_counts[s + b_idx]
|
||||
if pad_n > 0:
|
||||
values[b_idx, :pad_n] = 0.0
|
||||
|
||||
# Debug output removed - issue was identified and fixed
|
||||
|
||||
if values.dim() == 2:
|
||||
|
||||
Reference in New Issue
Block a user