diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index d5225d29d..3a28fa770 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -82,6 +82,10 @@ class RLearNPolicy(PreTrainedPolicy): # First-frame positional embedding (only applied to the first video frame) self.first_frame_pos = nn.Parameter(torch.zeros(1, 1, config.dim_model)) + # Full temporal positional embeddings (length = max_seq_len) + self.max_time = config.max_seq_len + self.time_pos = nn.Parameter(torch.zeros(1, self.max_time, config.dim_model)) + nn.init.trunc_normal_(self.time_pos, std=0.02) # Cross-modal sequential aggregator – causal transformer over # [language tokens | video frame tokens] using PyTorch TransformerEncoder @@ -182,18 +186,30 @@ class RLearNPolicy(PreTrainedPolicy): if flat.max() > 1.0: flat = flat / 255.0 - # Prepare inputs for processor: feed uint8 to avoid double-rescale; let processor rescale/normalize - flat_clamped = flat.clamp(0.0, 1.0) - flat_uint8 = (flat_clamped * 255.0).round().to(torch.uint8) # (BT, C, H, W) - flat_numpy = flat_uint8.permute(0, 2, 3, 1).cpu().numpy() # (BT, H, W, C) uint8 - images_list = [flat_numpy[i] for i in range(B * T)] - - # Process images with SigLIP2 processor - inputs = self.processor(images=images_list, return_tensors="pt") - inputs = {k: v.to(device) for k, v in inputs.items()} - - # Process in batch through SigLIP2 vision tower - vision_outputs = self.vision_model(**inputs) + # GPU-friendly image preprocessing (resize + normalize) without Python loops + iproc = getattr(self.processor, 'image_processor', None) + if iproc is not None: + size_cfg = getattr(iproc, 'size', {}) + if isinstance(size_cfg, dict): + target_h = size_cfg.get('height', size_cfg.get('shortest_edge', 224)) + target_w = size_cfg.get('width', target_h) + else: + target_h = target_w = 224 + mean = torch.tensor(getattr(iproc, 'image_mean', [0.5, 0.5, 0.5]), device=flat.device, dtype=flat.dtype).view(1, 3, 1, 1) + std = torch.tensor(getattr(iproc, 'image_std', [0.5, 0.5, 0.5]), device=flat.device, dtype=flat.dtype).view(1, 3, 1, 1) + else: + target_h = target_w = 224 + mean = torch.tensor([0.5, 0.5, 0.5], device=flat.device, dtype=flat.dtype).view(1, 3, 1, 1) + std = torch.tensor([0.5, 0.5, 0.5], device=flat.device, dtype=flat.dtype).view(1, 3, 1, 1) + + flat = flat.to(device, non_blocking=True) + flat = torch.nn.functional.interpolate(flat, size=(target_h, target_w), mode='bilinear', align_corners=False) + pixel_values = (flat - mean.to(device)) / std.to(device) + pixel_values = pixel_values.contiguous(memory_format=torch.channels_last) + + use_amp = getattr(self.config, 'use_amp', False) and torch.cuda.is_available() + with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp): + vision_outputs = self.vision_model(pixel_values=pixel_values) # Prefer CLS token from last_hidden_state at index 0 if hasattr(vision_outputs, 'last_hidden_state') and vision_outputs.last_hidden_state is not None: @@ -249,7 +265,7 @@ class RLearNPolicy(PreTrainedPolicy): print(f" ✓ Input frames are different. Diff: {raw_frame_diffs:.6f}") # Check processed pixel values - first_sample_pixels = inputs['pixel_values'][:T] # First sample's pixels + first_sample_pixels = pixel_values[:T] if T > 1: # FIXED: Use proper tensor operations pixel_frame_diffs = (first_sample_pixels[1:] - first_sample_pixels[:-1]).pow(2).sum(dim=(1, 2, 3)).sqrt() @@ -399,7 +415,9 @@ class RLearNPolicy(PreTrainedPolicy): # SigLIP2 CLS per-frame already returned video_frame_embeds = video_patch_embeds # (B, T_eff, D_vision) video_tokens = self.to_video_tokens(video_frame_embeds) # (B, T_eff, D) - # First-frame positional embedding only + # Add temporal positional embeddings + video_tokens = video_tokens + self.time_pos[:, :T_eff, :] + # Optional: keep a first-frame tag video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos # Build masks for TransformerEncoder @@ -457,7 +475,8 @@ class RLearNPolicy(PreTrainedPolicy): target_expanded = target # (B, T_eff) eps = self.config.logit_eps target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps)) - loss = F.mse_loss(raw_logits, target_logits) + # Robust Huber (Smooth L1) on logits + loss = F.smooth_l1_loss(raw_logits, target_logits, beta=0.5) total_loss = loss