huberman loss

This commit is contained in:
Pepijn
2025-09-01 11:53:30 +02:00
parent 9a19f8f6f4
commit 0b710932e2

View File

@@ -82,6 +82,10 @@ class RLearNPolicy(PreTrainedPolicy):
# First-frame positional embedding (only applied to the first video frame)
self.first_frame_pos = nn.Parameter(torch.zeros(1, 1, config.dim_model))
# Full temporal positional embeddings (length = max_seq_len)
self.max_time = config.max_seq_len
self.time_pos = nn.Parameter(torch.zeros(1, self.max_time, config.dim_model))
nn.init.trunc_normal_(self.time_pos, std=0.02)
# Cross-modal sequential aggregator causal transformer over
# [language tokens | video frame tokens] using PyTorch TransformerEncoder
@@ -182,18 +186,30 @@ class RLearNPolicy(PreTrainedPolicy):
if flat.max() > 1.0:
flat = flat / 255.0
# Prepare inputs for processor: feed uint8 to avoid double-rescale; let processor rescale/normalize
flat_clamped = flat.clamp(0.0, 1.0)
flat_uint8 = (flat_clamped * 255.0).round().to(torch.uint8) # (BT, C, H, W)
flat_numpy = flat_uint8.permute(0, 2, 3, 1).cpu().numpy() # (BT, H, W, C) uint8
images_list = [flat_numpy[i] for i in range(B * T)]
# Process images with SigLIP2 processor
inputs = self.processor(images=images_list, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
# Process in batch through SigLIP2 vision tower
vision_outputs = self.vision_model(**inputs)
# GPU-friendly image preprocessing (resize + normalize) without Python loops
iproc = getattr(self.processor, 'image_processor', None)
if iproc is not None:
size_cfg = getattr(iproc, 'size', {})
if isinstance(size_cfg, dict):
target_h = size_cfg.get('height', size_cfg.get('shortest_edge', 224))
target_w = size_cfg.get('width', target_h)
else:
target_h = target_w = 224
mean = torch.tensor(getattr(iproc, 'image_mean', [0.5, 0.5, 0.5]), device=flat.device, dtype=flat.dtype).view(1, 3, 1, 1)
std = torch.tensor(getattr(iproc, 'image_std', [0.5, 0.5, 0.5]), device=flat.device, dtype=flat.dtype).view(1, 3, 1, 1)
else:
target_h = target_w = 224
mean = torch.tensor([0.5, 0.5, 0.5], device=flat.device, dtype=flat.dtype).view(1, 3, 1, 1)
std = torch.tensor([0.5, 0.5, 0.5], device=flat.device, dtype=flat.dtype).view(1, 3, 1, 1)
flat = flat.to(device, non_blocking=True)
flat = torch.nn.functional.interpolate(flat, size=(target_h, target_w), mode='bilinear', align_corners=False)
pixel_values = (flat - mean.to(device)) / std.to(device)
pixel_values = pixel_values.contiguous(memory_format=torch.channels_last)
use_amp = getattr(self.config, 'use_amp', False) and torch.cuda.is_available()
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
vision_outputs = self.vision_model(pixel_values=pixel_values)
# Prefer CLS token from last_hidden_state at index 0
if hasattr(vision_outputs, 'last_hidden_state') and vision_outputs.last_hidden_state is not None:
@@ -249,7 +265,7 @@ class RLearNPolicy(PreTrainedPolicy):
print(f" ✓ Input frames are different. Diff: {raw_frame_diffs:.6f}")
# Check processed pixel values
first_sample_pixels = inputs['pixel_values'][:T] # First sample's pixels
first_sample_pixels = pixel_values[:T]
if T > 1:
# FIXED: Use proper tensor operations
pixel_frame_diffs = (first_sample_pixels[1:] - first_sample_pixels[:-1]).pow(2).sum(dim=(1, 2, 3)).sqrt()
@@ -399,7 +415,9 @@ class RLearNPolicy(PreTrainedPolicy):
# SigLIP2 CLS per-frame already returned
video_frame_embeds = video_patch_embeds # (B, T_eff, D_vision)
video_tokens = self.to_video_tokens(video_frame_embeds) # (B, T_eff, D)
# First-frame positional embedding only
# Add temporal positional embeddings
video_tokens = video_tokens + self.time_pos[:, :T_eff, :]
# Optional: keep a first-frame tag
video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos
# Build masks for TransformerEncoder
@@ -457,7 +475,8 @@ class RLearNPolicy(PreTrainedPolicy):
target_expanded = target # (B, T_eff)
eps = self.config.logit_eps
target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps))
loss = F.mse_loss(raw_logits, target_logits)
# Robust Huber (Smooth L1) on logits
loss = F.smooth_l1_loss(raw_logits, target_logits, beta=0.5)
total_loss = loss