mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
simple eval
This commit is contained in:
@@ -142,7 +142,8 @@ def extract_episode_frames_and_gt(dataset, episode_idx):
|
||||
def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=64, device="cuda"):
|
||||
"""
|
||||
Sliding-window prediction: for each frame i, create a window [max(0, i-L+1) .. i],
|
||||
left-pad by repeating the first frame to length L (<= 16), and take the last-step prediction.
|
||||
left-pad by repeating the first frame to length L (<= 16), and take the prediction
|
||||
corresponding to the current frame's position in the window.
|
||||
Returns np.ndarray of shape (T,).
|
||||
"""
|
||||
T = frames.shape[0]
|
||||
@@ -153,34 +154,49 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=
|
||||
frames = frames.to(device)
|
||||
|
||||
windows = []
|
||||
frame_positions = [] # Track which temporal position each frame occupies in its window
|
||||
|
||||
for i in range(T):
|
||||
start = max(0, i - L + 1)
|
||||
window = frames[start : i + 1] # (len<=L, C, H, W)
|
||||
|
||||
|
||||
# Calculate the temporal position of the current frame within the padded window
|
||||
actual_window_length = window.shape[0]
|
||||
|
||||
if window.shape[0] < L:
|
||||
pad_needed = L - window.shape[0]
|
||||
pad = window[:1].expand(pad_needed, -1, -1, -1) # repeat first frame
|
||||
window = torch.cat([pad, window], dim=0)
|
||||
|
||||
# After padding, the current frame is at position: pad_needed + (actual_window_length - 1); since pad_needed = L - actual_window_length, this always equals L - 1 (the last slot)
|
||||
frame_pos = pad_needed + actual_window_length - 1
|
||||
else:
|
||||
# No padding needed, current frame is at the last position
|
||||
frame_pos = L - 1
|
||||
|
||||
windows.append(window)
|
||||
frame_positions.append(frame_pos)
|
||||
|
||||
preds = np.zeros(T, dtype=float)
|
||||
|
||||
for s in range(0, T, batch_size):
|
||||
e = min(s + batch_size, T)
|
||||
batch_windows = torch.stack(windows[s:e]) # (B, L, C, H, W)
|
||||
batch_positions = frame_positions[s:e]
|
||||
|
||||
batch = {OBS_IMAGES: batch_windows, OBS_LANGUAGE: [language] * (e - s)} # expects (B, L, C, H, W)
|
||||
|
||||
# Model should return (B, L) or (B,) final-step values. We take the last step.
|
||||
values = model.predict_rewards(batch) # torch.Tensor
|
||||
# Model returns (B, L) predictions for each temporal position
|
||||
values = model.predict_rewards(batch) # torch.Tensor (B, L)
|
||||
|
||||
if values.dim() == 2:
|
||||
last = values[:, -1]
|
||||
# Extract the prediction corresponding to each frame's position in its window
|
||||
batch_preds = []
|
||||
for b_idx, pos in enumerate(batch_positions):
|
||||
batch_preds.append(values[b_idx, pos].item())
|
||||
preds[s:e] = np.array(batch_preds)
|
||||
else:
|
||||
last = values.squeeze(-1)
|
||||
|
||||
preds[s:e] = last.detach().float().cpu().numpy()
|
||||
# Fallback: if model returns (B,), use as is
|
||||
preds[s:e] = values.detach().float().cpu().numpy()
|
||||
|
||||
return preds
|
||||
|
||||
|
||||
Reference in New Issue
Block a user