mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 12:21:27 +00:00
huberman loss
This commit is contained in:
@@ -82,6 +82,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
# First-frame positional embedding (only applied to the first video frame)
|
||||
self.first_frame_pos = nn.Parameter(torch.zeros(1, 1, config.dim_model))
|
||||
# Full temporal positional embeddings (length = max_seq_len)
|
||||
self.max_time = config.max_seq_len
|
||||
self.time_pos = nn.Parameter(torch.zeros(1, self.max_time, config.dim_model))
|
||||
nn.init.trunc_normal_(self.time_pos, std=0.02)
|
||||
|
||||
# Cross-modal sequential aggregator – causal transformer over
|
||||
# [language tokens | video frame tokens] using PyTorch TransformerEncoder
|
||||
@@ -182,18 +186,30 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
if flat.max() > 1.0:
|
||||
flat = flat / 255.0
|
||||
|
||||
# Prepare inputs for processor: feed uint8 to avoid double-rescale; let processor rescale/normalize
|
||||
flat_clamped = flat.clamp(0.0, 1.0)
|
||||
flat_uint8 = (flat_clamped * 255.0).round().to(torch.uint8) # (BT, C, H, W)
|
||||
flat_numpy = flat_uint8.permute(0, 2, 3, 1).cpu().numpy() # (BT, H, W, C) uint8
|
||||
images_list = [flat_numpy[i] for i in range(B * T)]
|
||||
|
||||
# Process images with SigLIP2 processor
|
||||
inputs = self.processor(images=images_list, return_tensors="pt")
|
||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||
|
||||
# Process in batch through SigLIP2 vision tower
|
||||
vision_outputs = self.vision_model(**inputs)
|
||||
# GPU-friendly image preprocessing (resize + normalize) without Python loops
|
||||
iproc = getattr(self.processor, 'image_processor', None)
|
||||
if iproc is not None:
|
||||
size_cfg = getattr(iproc, 'size', {})
|
||||
if isinstance(size_cfg, dict):
|
||||
target_h = size_cfg.get('height', size_cfg.get('shortest_edge', 224))
|
||||
target_w = size_cfg.get('width', target_h)
|
||||
else:
|
||||
target_h = target_w = 224
|
||||
mean = torch.tensor(getattr(iproc, 'image_mean', [0.5, 0.5, 0.5]), device=flat.device, dtype=flat.dtype).view(1, 3, 1, 1)
|
||||
std = torch.tensor(getattr(iproc, 'image_std', [0.5, 0.5, 0.5]), device=flat.device, dtype=flat.dtype).view(1, 3, 1, 1)
|
||||
else:
|
||||
target_h = target_w = 224
|
||||
mean = torch.tensor([0.5, 0.5, 0.5], device=flat.device, dtype=flat.dtype).view(1, 3, 1, 1)
|
||||
std = torch.tensor([0.5, 0.5, 0.5], device=flat.device, dtype=flat.dtype).view(1, 3, 1, 1)
|
||||
|
||||
flat = flat.to(device, non_blocking=True)
|
||||
flat = torch.nn.functional.interpolate(flat, size=(target_h, target_w), mode='bilinear', align_corners=False)
|
||||
pixel_values = (flat - mean.to(device)) / std.to(device)
|
||||
pixel_values = pixel_values.contiguous(memory_format=torch.channels_last)
|
||||
|
||||
use_amp = getattr(self.config, 'use_amp', False) and torch.cuda.is_available()
|
||||
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
|
||||
vision_outputs = self.vision_model(pixel_values=pixel_values)
|
||||
|
||||
# Prefer CLS token from last_hidden_state at index 0
|
||||
if hasattr(vision_outputs, 'last_hidden_state') and vision_outputs.last_hidden_state is not None:
|
||||
@@ -249,7 +265,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
print(f" ✓ Input frames are different. Diff: {raw_frame_diffs:.6f}")
|
||||
|
||||
# Check processed pixel values
|
||||
first_sample_pixels = inputs['pixel_values'][:T] # First sample's pixels
|
||||
first_sample_pixels = pixel_values[:T]
|
||||
if T > 1:
|
||||
# FIXED: Use proper tensor operations
|
||||
pixel_frame_diffs = (first_sample_pixels[1:] - first_sample_pixels[:-1]).pow(2).sum(dim=(1, 2, 3)).sqrt()
|
||||
@@ -399,7 +415,9 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# SigLIP2 CLS per-frame already returned
|
||||
video_frame_embeds = video_patch_embeds # (B, T_eff, D_vision)
|
||||
video_tokens = self.to_video_tokens(video_frame_embeds) # (B, T_eff, D)
|
||||
# First-frame positional embedding only
|
||||
# Add temporal positional embeddings
|
||||
video_tokens = video_tokens + self.time_pos[:, :T_eff, :]
|
||||
# Optional: keep a first-frame tag
|
||||
video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos
|
||||
|
||||
# Build masks for TransformerEncoder
|
||||
@@ -457,7 +475,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
target_expanded = target # (B, T_eff)
|
||||
eps = self.config.logit_eps
|
||||
target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps))
|
||||
loss = F.mse_loss(raw_logits, target_logits)
|
||||
# Robust Huber (Smooth L1) on logits
|
||||
loss = F.smooth_l1_loss(raw_logits, target_logits, beta=0.5)
|
||||
total_loss = loss
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user