change config

This commit is contained in:
Pepijn
2025-09-01 14:37:15 +02:00
parent ee48a80e4d
commit cf0c3f0a9a
3 changed files with 24 additions and 27 deletions

View File

@@ -139,7 +139,7 @@ def extract_episode_frames_and_gt(dataset, episode_idx):
@torch.no_grad()
def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=64, device="cuda"):
def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=64, device="cuda", temporal_stride: int | None = None):
"""
Sliding-window prediction: for each frame i, create a window [max(0, i-L+1) .. i],
left-pad by repeating the first frame to length L (<= 16), and take the prediction
@@ -147,8 +147,13 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=
Returns np.ndarray of shape (T,).
"""
T = frames.shape[0]
L = int(getattr(getattr(model, "config", object()), "max_seq_len", max_seq_len))
cfg = getattr(model, "config", object())
L = int(getattr(cfg, "max_seq_len", max_seq_len))
L = min(L, max_seq_len) # hard-cap at 16
# Use the same temporal stride as training (skip s-1 frames, take 1)
if temporal_stride is None:
temporal_stride = int(getattr(cfg, "temporal_sampling_stride", 1))
temporal_stride = max(1, int(temporal_stride))
# Preprocessed tensor on device
frames = frames.to(device)
@@ -158,21 +163,15 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=
left_pad_counts = [] # Number of left-pad (OOB) frames per window
for i in range(T):
start = max(0, i - L + 1)
window = frames[start : i + 1] # (len<=L, C, H, W)
if window.shape[0] < L:
pad_needed = L - window.shape[0]
pad = window[:1].expand(pad_needed, -1, -1, -1) # repeat first frame (clamp to frame 0)
window = torch.cat([pad, window], dim=0)
else:
pad_needed = 0
# IMPROVED FIX: Cycle through MLPs to get varied predictions throughout the episode
# This ensures we use all 16 frame-specific MLPs and get varied outputs
# Frames 0-15 use MLPs 0-15, frames 16-31 use MLPs 0-15 again, etc.
frame_pos = i % L # Cycle through [0, 1, 2, ..., 15, 0, 1, 2, ..., 15, ...]
# Build indices with stride s: [..., i-3, i] etc., left-padded by clamping to 0
idxs = [i - (L - 1 - j) * temporal_stride for j in range(L)]
pad_needed = sum(1 for k in idxs if k < 0)
clamped = [0 if k < 0 else (T - 1 if k >= T else k) for k in idxs]
window = frames[clamped] # (L, C, H, W)
# Use the last temporal position (current frame) for reading model output
frame_pos = L - 1
windows.append(window)
frame_positions.append(frame_pos)
left_pad_counts.append(pad_needed)