From d8c875e069d6bb94acbd9ac2500da11e80df8954 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sun, 31 Aug 2025 20:52:00 +0200 Subject: [PATCH] use patch tokens --- src/lerobot/policies/rlearn/eval_script.py | 13 ++- .../policies/rlearn/modeling_rlearn.py | 103 +++++++++++------- 2 files changed, 73 insertions(+), 43 deletions(-) diff --git a/src/lerobot/policies/rlearn/eval_script.py b/src/lerobot/policies/rlearn/eval_script.py index a3bd0a8cd..05b44903c 100644 --- a/src/lerobot/policies/rlearn/eval_script.py +++ b/src/lerobot/policies/rlearn/eval_script.py @@ -155,6 +155,7 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size= windows = [] frame_positions = [] # Track which temporal position each frame should use + left_pad_counts = [] # Number of left-pad (OOB) frames per window for i in range(T): start = max(0, i - L + 1) @@ -162,8 +163,10 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size= if window.shape[0] < L: pad_needed = L - window.shape[0] - pad = window[:1].expand(pad_needed, -1, -1, -1) # repeat first frame + pad = window[:1].expand(pad_needed, -1, -1, -1) # repeat first frame (clamp to frame 0) window = torch.cat([pad, window], dim=0) + else: + pad_needed = 0 # IMPROVED FIX: Cycle through MLPs to get varied predictions throughout the episode # This ensures we use all 16 frame-specific MLPs and get varied outputs @@ -172,6 +175,7 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size= windows.append(window) frame_positions.append(frame_pos) + left_pad_counts.append(pad_needed) preds = np.zeros(T, dtype=float) @@ -185,6 +189,13 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size= # Model returns (B, L) predictions for each temporal position values = model.predict_rewards(batch) # torch.Tensor (B, L) + # Apply eval-time padding rule: predictions for left-padded (OOB) frames are zero + if values.dim() == 2 and len(left_pad_counts) >= (e - s): + for b_idx in range(e - s): + pad_n = left_pad_counts[s + b_idx] + if pad_n > 0: + values[b_idx, :pad_n] = 0.0 + # Debug output removed - issue was identified and fixed if values.dim() == 2: diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 4600e8cc0..963267b04 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -108,6 +108,9 @@ class RLearNPolicy(PreTrainedPolicy): # Stronger temporal positional encoding self.temporal_pos_embedding = nn.Parameter(torch.randn(config.max_seq_len, config.dim_model) * 0.1) + # Spatial (patch) positional encoding for patch tokens + self.max_patch_tokens = getattr(config, 'max_patch_tokens', 256) + self.spatial_pos_embedding = nn.Parameter(torch.randn(self.max_patch_tokens, config.dim_model) * 0.1) # Single MLP processes all frames self.frame_mlp = nn.Linear(config.dim_model, config.dim_model) @@ -226,8 +229,8 @@ class RLearNPolicy(PreTrainedPolicy): device = next(self.parameters()).device frames = frames.to(device) - # Process video frames - video_embeds = self._encode_video_frames(frames).to(device) # (B, T, D_vision) + # Process video frames -> patch tokens per frame + video_patch_embeds = self._encode_video_frames(frames).to(device) # (B, T, P, D_vision) # Language embeddings + mask lang_embeds, mask = self._encode_language_tokens(commands, device) @@ -237,10 +240,17 @@ class RLearNPolicy(PreTrainedPolicy): # Project embeddings lang_tokens = self.to_lang_tokens(lang_embeds) - video_tokens = self.to_video_tokens(video_embeds) - # Add temporal positional encoding (window-relative only) - T_video = video_tokens.shape[1] - video_tokens = video_tokens + self.temporal_pos_embedding[:T_video] + video_tokens = self.to_video_tokens(video_patch_embeds) # (B, T, P, D) + # Add temporal + spatial positional encoding (window-relative time + patch index) + Bv, T_video, P_video, Dm = video_tokens.shape + if P_video > self.spatial_pos_embedding.shape[0]: + raise ValueError(f"Number of patch tokens {P_video} exceeds max_patch_tokens {self.spatial_pos_embedding.shape[0]}") + t_pos = self.temporal_pos_embedding[:T_video] # (T, D) + p_pos = self.spatial_pos_embedding[:P_video] # (P, D) + pos = t_pos[:, None, :] + p_pos[None, :, :] # (T, P, D) + video_tokens = video_tokens + pos # broadcast over batch + # Flatten patch dimension for attention + video_tokens = rearrange(video_tokens, 'b t p d -> b (t p) d') # Pack all tokens for attention tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d') @@ -252,10 +262,11 @@ class RLearNPolicy(PreTrainedPolicy): attended = self.decoder(tokens, mask=mask) # Unpack and get video token features - _, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d') - - # Process all frames with single MLP - frame_tokens = self.frame_mlp(attended_video_tokens) # (B, T, D) + _, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d') # (B, T*P, D) + # Restore (B, T, P, D) and pool patches per frame + attended_video_tokens = rearrange(attended_video_tokens, 'b (t p) d -> b t p d', t=T_video, p=P_video) + frame_tokens = attended_video_tokens.mean(dim=2) # (B, T, D) + frame_tokens = self.frame_mlp(frame_tokens) # MLP predictor video_frame_embeds = self.mlp_predictor(frame_tokens) @@ -283,13 +294,13 @@ class RLearNPolicy(PreTrainedPolicy): return batch def _encode_video_frames(self, frames: Tensor) -> Tensor: - """Encode video frames through DinoV3 to get per-frame embeddings. + """Encode video frames through DinoV3 to get per-frame PATCH embeddings. Args: frames: (B, T, C, H, W) Returns: - (B, T, D_vision) + (B, T, P, D_vision) where P is number of patch tokens per frame (excludes CLS) """ B, T, C, H, W = frames.shape flat = rearrange(frames, 'b t c h w -> (b t) c h w') @@ -315,40 +326,40 @@ class RLearNPolicy(PreTrainedPolicy): # Process in batch through DINOv3 model vision_outputs = self.vision_model(**inputs) - # Prefer mean-pooled patch tokens over pooler/CLS to ensure input-dependent variation + # Prefer patch tokens from last_hidden_state (exclude CLS at index 0) if hasattr(vision_outputs, 'last_hidden_state') and vision_outputs.last_hidden_state is not None: tokens = vision_outputs.last_hidden_state # (BT, N_tokens, D) if tokens.dim() == 3 and tokens.shape[1] > 1: - # Exclude CLS/reg token at index 0, average over patch tokens - vision_features_flat = tokens[:, 1:, :].mean(dim=1) + patch_tokens_flat = tokens[:, 1:, :] # (BT, P, D) else: - # Fallback to first token if only one token is present - vision_features_flat = tokens[:, 0] + # Only one token available → treat as single patch + patch_tokens_flat = tokens[:, :1, :] elif hasattr(vision_outputs, 'pooler_output') and vision_outputs.pooler_output is not None: - vision_features_flat = vision_outputs.pooler_output # (BT, D) + # No per-patch tokens available, synthesize single patch from pooler + patch_tokens_flat = vision_outputs.pooler_output[:, None, :] # (BT, 1, D) else: raise RuntimeError("DINOv3 outputs do not contain last_hidden_state or pooler_output") - # Robustly reshape to (B, T, D): detect correct flatten order by maximizing temporal variance + # Robustly reshape to (B, T, P, D): detect correct flatten order by maximizing temporal variance (on patch-mean) try: - cand1 = rearrange(vision_features_flat, '(b t) d -> b t d', b=B, t=T) - cand2 = rearrange(vision_features_flat, '(t b) d -> b t d', t=T, b=B) - # Compute mean temporal difference per sample - def mean_time_diff(x): + cand1 = rearrange(patch_tokens_flat, '(b t) p d -> b t p d', b=B, t=T) + cand2 = rearrange(patch_tokens_flat, '(t b) p d -> b t p d', t=T, b=B) + def mean_time_diff_4d(x): if T <= 1: return torch.tensor(0.0, device=x.device) - diffs = (x[:, 1:, :] - x[:, :-1, :]).pow(2).sum(dim=-1).sqrt() + x_mean = x.mean(dim=2) # (B, T, D) + diffs = (x_mean[:, 1:, :] - x_mean[:, :-1, :]).pow(2).sum(dim=-1).sqrt() return diffs.mean() - diff1 = mean_time_diff(cand1) - diff2 = mean_time_diff(cand2) - vision_features = cand1 if diff1 >= diff2 else cand2 + diff1 = mean_time_diff_4d(cand1) + diff2 = mean_time_diff_4d(cand2) + patch_features = cand1 if diff1 >= diff2 else cand2 if self.training and torch.rand(1).item() < 0.05: print(f"DINO reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}") except Exception: # Fallback to default - vision_features = rearrange(vision_features_flat, '(b t) d -> b t d', b=B, t=T) + patch_features = rearrange(patch_tokens_flat, '(b t) p d -> b t p d', b=B, t=T) - # DEBUG: Analyze vision feature variability + # DEBUG: Analyze vision feature variability (use per-frame pooled features for readability) if self.training and torch.rand(1).item() < 0.1: # 10% of training steps for more frequent debugging with torch.no_grad(): print(f"\nšŸ” DINOv3 VISION FEATURE DEBUG (B={B}, T={T}):") @@ -395,7 +406,8 @@ class RLearNPolicy(PreTrainedPolicy): else: print(f" āœ“ Batch samples have different first frames. Diff: {batch_first_frame_diff:.6f}") - # Check feature statistics + # Check feature statistics (pooled over patches) + vision_features = patch_features.mean(dim=2) # (B, T, D) feature_mean = vision_features.mean().item() feature_std = vision_features.std().item() print(f"Feature stats: mean={feature_mean:.4f}, std={feature_std:.4f}") @@ -440,7 +452,7 @@ class RLearNPolicy(PreTrainedPolicy): print("=" * 50) - return vision_features + return patch_features def _mask_from_lens(self, lens: Tensor) -> Tensor: """Create mask from sequence lengths.""" @@ -497,9 +509,9 @@ class RLearNPolicy(PreTrainedPolicy): elif not isinstance(commands, list): commands = [str(commands)] * B - # Process video frames through SigLIP2 + # Process video frames through vision encoder (returns patch tokens) vision_start = time.perf_counter() - video_embeds = self._encode_video_frames(frames).to(device) # (B, T_eff, D_vision) + video_patch_embeds = self._encode_video_frames(frames).to(device) # (B, T_eff, P, D_vision) vision_time = time.perf_counter() - vision_start # Language embeddings + mask @@ -513,12 +525,18 @@ class RLearNPolicy(PreTrainedPolicy): # Project embeddings lang_tokens = self.to_lang_tokens(lang_embeds) - video_tokens = self.to_video_tokens(video_embeds) - + video_tokens = self.to_video_tokens(video_patch_embeds) # (B, T, P, D) - # Add temporal positional encoding (window-relative only) - T_video = video_tokens.shape[1] - video_tokens = video_tokens + self.temporal_pos_embedding[:T_video] + # Add temporal + spatial positional encoding (window-relative only) + Bv, T_video, P_video, Dm = video_tokens.shape + if P_video > self.spatial_pos_embedding.shape[0]: + raise ValueError(f"Number of patch tokens {P_video} exceeds max_patch_tokens {self.spatial_pos_embedding.shape[0]}") + t_pos = self.temporal_pos_embedding[:T_video] # (T, D) + p_pos = self.spatial_pos_embedding[:P_video] # (P, D) + pos = t_pos[:, None, :] + p_pos[None, :, :] # (T, P, D) + video_tokens = video_tokens + pos + # Flatten patches into sequence tokens + video_tokens = rearrange(video_tokens, 'b t p d -> b (t p) d') # Pack all tokens for attention [lang | register | video] tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d') @@ -531,10 +549,11 @@ class RLearNPolicy(PreTrainedPolicy): attended = self.decoder(tokens, mask=mask) # Unpack and get video token features - _, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d') - - # Process all frames with single MLP - frame_tokens = self.frame_mlp(attended_video_tokens) # (B, T, D) + _, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d') # (B, T*P, D) + # Restore (B, T, P, D) and pool patches per frame + attended_video_tokens = rearrange(attended_video_tokens, 'b (t p) d -> b t p d', t=T_video, p=P_video) + frame_tokens = attended_video_tokens.mean(dim=2) # (B, T, D) + frame_tokens = self.frame_mlp(frame_tokens) # MLP predictor video_frame_embeds = self.mlp_predictor(frame_tokens)