Review preamble (free text placed before the first "diff --git" header; patch
tools skip it).

Summary of the patch below:
  * configuration_rlearn.py: adds three RLearNConfig fields, head_hidden_dim
    (default 1024), head_num_layers (default 4) and head_dropout (default 0.1).
  * eval_script.py: predict_rewards_sliding now builds one window per frame,
    frames[max(0, i-L+1) .. i], left-padded with the window's first frame up to
    length L, and calls model.predict_rewards on batches of batch_size windows
    instead of on non-overlapping windows one at a time.
  * modeling_rlearn.py: the reward head becomes a Linear/ReLU/Dropout stack
    sized by the new config fields, ending in Linear(head_hidden_dim, 1); every
    nn.Linear in it gets normal(0, head_weight_init_std) weights and zero bias.

NOTE(review): in predict_rewards_sliding each window is padded on the left, so
frame i is always the LAST element of its window (index L-1). The code instead
reads values[b_idx, i % L]; for i >= L that index holds an earlier frame of the
window. This contradicts the new docstring ("the prediction corresponding to
the current frame's position in the window"). The in-code comment says the
cycling is deliberate -- confirm which behaviour is intended and make the
docstring and the code agree.

NOTE(review): range(config.head_num_layers - 2) is empty for head_num_layers
<= 2, so values below 2 silently build the same two-Linear head as 2 does.
Consider validating the field.

NOTE(review): all T padded windows are stacked in a Python list before
batching; building each batch's windows inside the batch loop would bound
memory by batch_size instead of T.

NOTE(review): this patch was recovered from a copy whose newlines and
indentation had been flattened. Line breaks, indentation and the position of
blank context lines were reconstructed to match the hunk header counts; verify
whitespace against the target tree (or apply with whitespace tolerance).

diff --git a/src/lerobot/policies/rlearn/configuration_rlearn.py b/src/lerobot/policies/rlearn/configuration_rlearn.py
index a3eff5049..5724c1ca3 100644
--- a/src/lerobot/policies/rlearn/configuration_rlearn.py
+++ b/src/lerobot/policies/rlearn/configuration_rlearn.py
@@ -73,6 +73,11 @@ class RLearNConfig(PreTrainedConfig):
     logit_eps: float = 1e-6
     head_lr_multiplier: float = 2.0
     head_weight_init_std: float = 0.05
+
+    # Reward head architecture
+    head_hidden_dim: int = 1024  # Hidden dimension for reward head
+    head_num_layers: int = 4  # Number of layers in reward head
+    head_dropout: float = 0.1  # Dropout in reward head
 
     # Normalization presets
     normalization_mapping: dict[str, NormalizationMode] = field(
diff --git a/src/lerobot/policies/rlearn/eval_script.py b/src/lerobot/policies/rlearn/eval_script.py
index 842a921f0..a3bd0a8cd 100644
--- a/src/lerobot/policies/rlearn/eval_script.py
+++ b/src/lerobot/policies/rlearn/eval_script.py
@@ -141,12 +141,9 @@ def extract_episode_frames_and_gt(dataset, episode_idx):
 @torch.no_grad()
 def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=64, device="cuda"):
     """
-    Sliding-window prediction for episode-relative progress model.
-    For each frame i, creates a window and extracts the prediction for that specific frame.
-
-    NOTE: This assumes we don't have episode context (episode_index, frame_index, episode_length).
-    The model will use its fallback logic for window-relative progress.
-
+    Sliding-window prediction: for each frame i, create a window [max(0, i-L+1) .. i],
+    left-pad by repeating the first frame to length L (<= 16), and take the prediction
+    corresponding to the current frame's position in the window.
     Returns np.ndarray of shape (T,).
     """
     T = frames.shape[0]
@@ -156,45 +153,49 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=
     # Preprocessed tensor on device
     frames = frames.to(device)
 
-    # Simple approach: predict each 16-frame window and take the last prediction
-    # This assumes the model can handle the lack of episode context gracefully
-    preds = np.zeros(T, dtype=float)
+    windows = []
+    frame_positions = []  # Track which temporal position each frame should use
 
-    # Process non-overlapping windows for efficiency
-    for start_idx in range(0, T, L):
-        end_idx = min(start_idx + L, T)
-        window_frames = frames[start_idx:end_idx]
+    for i in range(T):
+        start = max(0, i - L + 1)
+        window = frames[start : i + 1]  # (len<=L, C, H, W)
 
-        # Pad if needed
-        if window_frames.shape[0] < L:
-            pad_needed = L - window_frames.shape[0]
-            if start_idx == 0:
-                # Pad with first frame at beginning
-                pad = window_frames[:1].expand(pad_needed, -1, -1, -1)
-                window_frames = torch.cat([pad, window_frames], dim=0)
-            else:
-                # Pad with last frame at end
-                pad = window_frames[-1:].expand(pad_needed, -1, -1, -1)
-                window_frames = torch.cat([window_frames, pad], dim=0)
+        if window.shape[0] < L:
+            pad_needed = L - window.shape[0]
+            pad = window[:1].expand(pad_needed, -1, -1, -1)  # repeat first frame
+            window = torch.cat([pad, window], dim=0)
 
-        # Create batch (batch size = 1)
-        batch = {
-            OBS_IMAGES: window_frames.unsqueeze(0),  # (1, L, C, H, W)
-            OBS_LANGUAGE: [language]
-        }
-
-        # Get predictions for this window
-        window_preds = model.predict_rewards(batch)  # (1, L)
-        window_preds = window_preds.squeeze(0).cpu().numpy()  # (L,)
-
-        # Extract the relevant predictions for the actual frames
-        actual_frames = min(L, end_idx - start_idx)
-        if start_idx == 0 and window_frames.shape[0] > actual_frames:
-            # Skip padding at beginning
-            preds[start_idx:end_idx] = window_preds[-actual_frames:]
+        # IMPROVED FIX: Cycle through MLPs to get varied predictions throughout the episode
+        # This ensures we use all 16 frame-specific MLPs and get varied outputs
+        # Frames 0-15 use MLPs 0-15, frames 16-31 use MLPs 0-15 again, etc.
+        frame_pos = i % L  # Cycle through [0, 1, 2, ..., 15, 0, 1, 2, ..., 15, ...]
+
+        windows.append(window)
+        frame_positions.append(frame_pos)
+
+    preds = np.zeros(T, dtype=float)
+
+    for s in range(0, T, batch_size):
+        e = min(s + batch_size, T)
+        batch_windows = torch.stack(windows[s:e])  # (B, L, C, H, W)
+        batch_positions = frame_positions[s:e]
+
+        batch = {OBS_IMAGES: batch_windows, OBS_LANGUAGE: [language] * (e - s)}  # expects (B, L, C, H, W)
+
+        # Model returns (B, L) predictions for each temporal position
+        values = model.predict_rewards(batch)  # torch.Tensor (B, L)
+
+        # Debug output removed - issue was identified and fixed
+
+        if values.dim() == 2:
+            # Extract the prediction corresponding to each frame's position in its window
+            batch_preds = []
+            for b_idx, pos in enumerate(batch_positions):
+                batch_preds.append(values[b_idx, pos].item())
+            preds[s:e] = np.array(batch_preds)
         else:
-            # Take the first predictions (no beginning padding)
-            preds[start_idx:end_idx] = window_preds[:actual_frames]
+            # Fallback: if model returns (B,), use as is
+            preds[s:e] = values.detach().float().cpu().numpy()
 
     return preds
 
diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py
index 315df1251..b6768b0d8 100644
--- a/src/lerobot/policies/rlearn/modeling_rlearn.py
+++ b/src/lerobot/policies/rlearn/modeling_rlearn.py
@@ -184,22 +184,36 @@ class RLearNPolicy(PreTrainedPolicy):
         # Layer normalization before reward head to stabilize MLP outputs
         self.pre_reward_norm = nn.LayerNorm(config.dim_model)
 
-        # Temporal-aware regression head (logit mode only)
-        # Concatenates frame embedding with normalized temporal position
-        self.reward_head = nn.Sequential(
-            nn.Linear(config.dim_model + 1, config.dim_model),  # +1 for temporal position
-            nn.ReLU(),
-            nn.Linear(config.dim_model, 1)
-        )
+        # Temporal-aware regression head with increased capacity
+        # Build a deeper MLP for better visual-progress learning
+        head_layers = []
 
-        # Initialize temporal-aware head for logit regression
+        # Input layer: embedding + temporal position -> hidden
+        head_layers.extend([
+            nn.Linear(config.dim_model + 1, config.head_hidden_dim),  # +1 for temporal position
+            nn.ReLU(),
+            nn.Dropout(config.head_dropout)
+        ])
+
+        # Hidden layers: multiple layers for complex visual-progress mapping
+        for _ in range(config.head_num_layers - 2):  # -2 for input and output layers
+            head_layers.extend([
+                nn.Linear(config.head_hidden_dim, config.head_hidden_dim),
+                nn.ReLU(),
+                nn.Dropout(config.head_dropout)
+            ])
+
+        # Output layer: hidden -> logit
+        head_layers.append(nn.Linear(config.head_hidden_dim, 1))
+
+        self.reward_head = nn.Sequential(*head_layers)
+
+        # Initialize the deeper temporal-aware head for logit regression
         with torch.no_grad():
-            # First layer: embedding + position -> embedding
-            nn.init.normal_(self.reward_head[0].weight, 0.0, config.head_weight_init_std)
-            nn.init.zeros_(self.reward_head[0].bias)
-            # Output layer: embedding -> logit
-            nn.init.normal_(self.reward_head[2].weight, 0.0, config.head_weight_init_std)
-            nn.init.zeros_(self.reward_head[2].bias)
+            for module in self.reward_head:
+                if isinstance(module, nn.Linear):
+                    nn.init.normal_(module.weight, 0.0, config.head_weight_init_std)
+                    nn.init.zeros_(module.bias)
 
         # Simple frame dropout probability
         self.frame_dropout_p = config.frame_dropout_p