From dad0babbf52db0d0c67d216896c01c5d212c4f43 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sun, 31 Aug 2025 13:54:03 +0200 Subject: [PATCH] simple eval --- src/lerobot/policies/rlearn/eval_script.py | 34 ++++++++++++++++------ 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/src/lerobot/policies/rlearn/eval_script.py b/src/lerobot/policies/rlearn/eval_script.py index 3500ce96d..30fa38312 100644 --- a/src/lerobot/policies/rlearn/eval_script.py +++ b/src/lerobot/policies/rlearn/eval_script.py @@ -142,7 +142,8 @@ def extract_episode_frames_and_gt(dataset, episode_idx): def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=64, device="cuda"): """ Sliding-window prediction: for each frame i, create a window [max(0, i-L+1) .. i], - left-pad by repeating the first frame to length L (<= 16), and take the last-step prediction. + left-pad by repeating the first frame to length L (<= 16), and take the prediction + corresponding to the current frame's position in the window. Returns np.ndarray of shape (T,). """ T = frames.shape[0] @@ -153,34 +154,49 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size= frames = frames.to(device) windows = [] + frame_positions = [] # Track which temporal position each frame occupies in its window + for i in range(T): start = max(0, i - L + 1) window = frames[start : i + 1] # (len<=L, C, H, W) - + + # Calculate the temporal position of the current frame within the padded window + actual_window_length = window.shape[0] + if window.shape[0] < L: pad_needed = L - window.shape[0] pad = window[:1].expand(pad_needed, -1, -1, -1) # repeat first frame window = torch.cat([pad, window], dim=0) - + # After padding, the current frame is at position: pad_needed + (actual_window_length - 1) + frame_pos = pad_needed + actual_window_length - 1 + else: + # No padding needed, current frame is at the last position + frame_pos = L - 1 + windows.append(window) + frame_positions.append(frame_pos) preds = np.zeros(T, dtype=float) for s in range(0, T, batch_size): e = min(s + batch_size, T) batch_windows = torch.stack(windows[s:e]) # (B, L, C, H, W) + batch_positions = frame_positions[s:e] batch = {OBS_IMAGES: batch_windows, OBS_LANGUAGE: [language] * (e - s)} # expects (B, L, C, H, W) - # Model should return (B, L) or (B,) final-step values. We take the last step. - values = model.predict_rewards(batch) # torch.Tensor + # Model returns (B, L) predictions for each temporal position + values = model.predict_rewards(batch) # torch.Tensor (B, L) if values.dim() == 2: - last = values[:, -1] + # Extract the prediction corresponding to each frame's position in its window + batch_preds = [] + for b_idx, pos in enumerate(batch_positions): + batch_preds.append(values[b_idx, pos].item()) + preds[s:e] = np.array(batch_preds) else: - last = values.squeeze(-1) - - preds[s:e] = last.detach().float().cpu().numpy() + # Fallback: if model returns (B,), use as is + preds[s:e] = values.detach().float().cpu().numpy() return preds