simple eval

This commit is contained in:
Pepijn
2025-08-31 14:11:47 +02:00
parent 28298fbe78
commit 4557655ab1

View File

@@ -165,10 +165,10 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=
pad = window[:1].expand(pad_needed, -1, -1, -1) # repeat first frame
window = torch.cat([pad, window], dim=0)
-# CRITICAL FIX: Use the MLP corresponding to the frame's temporal position
-# Frame 0 -> MLP[0], Frame 1 -> MLP[1], ..., Frame 15+ -> MLP[15]
-# This matches how the model was trained with different MLPs for different temporal positions
-frame_pos = min(i, L - 1) # Clamp to available MLP range [0, 15]
+# IMPROVED FIX: Cycle through MLPs to get varied predictions throughout the episode
+# This ensures we use all 16 frame-specific MLPs and get varied outputs
+# Frames 0-15 use MLPs 0-15, frames 16-31 use MLPs 0-15 again, etc.
+frame_pos = i % L # Cycle through [0, 1, 2, ..., 15, 0, 1, 2, ..., 15, ...]
windows.append(window)
frame_positions.append(frame_pos)