change config

This commit is contained in:
Pepijn
2025-09-01 14:37:15 +02:00
parent ee48a80e4d
commit cf0c3f0a9a
3 changed files with 24 additions and 27 deletions

View File

@@ -45,7 +45,7 @@ class RLearNConfig(PreTrainedConfig):
# Sequence length, amount of past frames including current one to use in the temporal model
max_seq_len: int = 16
# Temporal sampling stride (2 = skip every other frame for wider temporal coverage)
# Temporal sampling stride
temporal_sampling_stride: int = 3 # Open x mostly has fps 10, and rewind has seq len 16, ours is 30fps so 30/10 = 3 stride lenght to have same timeframe!
# Model dimensions and transformer
@@ -53,8 +53,7 @@ class RLearNConfig(PreTrainedConfig):
num_layers: int = 4
num_heads: int = 8
ff_mult: int = 4 # Feed-forward multiplier, hidden = dim_model * ff_mult
dropout: float = 0.10
num_register_tokens: int = 4
dropout: float = 0.05
# --- reward head options ---
use_categorical_rewards: bool = False # classification over bins
@@ -69,7 +68,7 @@ class RLearNConfig(PreTrainedConfig):
frame_dropout_p: float = 0.10
# Training
learning_rate: float = 1e-3
learning_rate: float = 5e-4
weight_decay: float = 0.01
head_lr_multiplier: float = 5.0
logit_eps: float = 1e-4
@@ -79,9 +78,9 @@ class RLearNConfig(PreTrainedConfig):
compile_model: bool = True
# ReWiND augmentation
rewind_prob: float = 0.8
rewind_last3_prob: float = 0.3
mismatch_prob: float = 0.2
rewind_prob: float = 0.3 #0.8
rewind_last3_prob: float = 0.0 #0.3
mismatch_prob: float = 0.0# 0.2
# Normalization presets
normalization_mapping: dict[str, NormalizationMode] = field(

View File

@@ -139,7 +139,7 @@ def extract_episode_frames_and_gt(dataset, episode_idx):
@torch.no_grad()
def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=64, device="cuda"):
def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=64, device="cuda", temporal_stride: int | None = None):
"""
Sliding-window prediction: for each frame i, create a window [max(0, i-L+1) .. i],
left-pad by repeating the first frame to length L (<= 16), and take the prediction
@@ -147,8 +147,13 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=
Returns np.ndarray of shape (T,).
"""
T = frames.shape[0]
L = int(getattr(getattr(model, "config", object()), "max_seq_len", max_seq_len))
cfg = getattr(model, "config", object())
L = int(getattr(cfg, "max_seq_len", max_seq_len))
L = min(L, max_seq_len) # hard-cap at 16
# Use the same temporal stride as training (skip s-1 frames, take 1)
if temporal_stride is None:
temporal_stride = int(getattr(cfg, "temporal_sampling_stride", 1))
temporal_stride = max(1, int(temporal_stride))
# Preprocessed tensor on device
frames = frames.to(device)
@@ -158,21 +163,15 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=
left_pad_counts = [] # Number of left-pad (OOB) frames per window
for i in range(T):
start = max(0, i - L + 1)
window = frames[start : i + 1] # (len<=L, C, H, W)
if window.shape[0] < L:
pad_needed = L - window.shape[0]
pad = window[:1].expand(pad_needed, -1, -1, -1) # repeat first frame (clamp to frame 0)
window = torch.cat([pad, window], dim=0)
else:
pad_needed = 0
# IMPROVED FIX: Cycle through MLPs to get varied predictions throughout the episode
# This ensures we use all 16 frame-specific MLPs and get varied outputs
# Frames 0-15 use MLPs 0-15, frames 16-31 use MLPs 0-15 again, etc.
frame_pos = i % L # Cycle through [0, 1, 2, ..., 15, 0, 1, 2, ..., 15, ...]
# Build indices with stride s: [..., i-3, i] etc., left-padded by clamping to 0
idxs = [i - (L - 1 - j) * temporal_stride for j in range(L)]
pad_needed = sum(1 for k in idxs if k < 0)
clamped = [0 if k < 0 else (T - 1 if k >= T else k) for k in idxs]
window = frames[clamped] # (L, C, H, W)
# Use the last temporal position (current frame) for reading model output
frame_pos = L - 1
windows.append(window)
frame_positions.append(frame_pos)
left_pad_counts.append(pad_needed)

View File

@@ -509,7 +509,7 @@ class RLearNPolicy(PreTrainedPolicy):
L_mismatch = mismatch_loss_per_sample[mismatch_tensor].mean()
# Total loss
total_loss = total_loss + L_mismatch + 0.3 * L_rank + 0.05 * L_flat
total_loss = total_loss + L_mismatch
loss_time = time.perf_counter() - loss_start
# DEBUG: Clean logit regression monitoring with full array printing
@@ -571,8 +571,7 @@ class RLearNPolicy(PreTrainedPolicy):
"loss": float(total_loss.detach().item()),
"loss_main": float(loss.detach().item() if isinstance(loss, torch.Tensor) else 0.0),
"loss_mismatch": float(L_mismatch.detach().item()),
"loss_rank": float(L_rank.detach().item()),
"loss_flat": float(L_flat.detach().item()),
"t_eff": float(T_eff),
"lang_len_mean": float(mask.sum().float().mean().item()), # Use mask to get actual lengths
# Target statistics for monitoring