mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
clean
This commit is contained in:
@@ -43,24 +43,28 @@ class RLearNConfig(PreTrainedConfig):
|
||||
text_model_name: str = "google/siglip2-base-patch16-224"
|
||||
freeze_backbones: bool = True
|
||||
|
||||
# Temporal aggregator
|
||||
dim_model: int = 512
|
||||
n_heads: int = 8
|
||||
n_layers: int = 4
|
||||
dim_feedforward: int = 2048
|
||||
dropout: float = 0.1
|
||||
pre_norm: bool = True
|
||||
frame_dropout_p: float = 0.0
|
||||
stride: int = 1
|
||||
|
||||
# Sequence length, amount of past frames including current one to use in the temporal model
|
||||
max_seq_len: int = 16
|
||||
# Temporal sampling stride (2 = skip every other frame for wider temporal coverage)
|
||||
temporal_sampling_stride: int = 2
|
||||
|
||||
# Model dimensions and transformer
|
||||
dim_model: int = 512
|
||||
num_layers: int = 4
|
||||
num_heads: int = 8
|
||||
ff_mult: int = 4 # Feed-forward multiplier, hidden = dim_model * ff_mult
|
||||
dropout: float = 0.10
|
||||
num_register_tokens: int = 4
|
||||
|
||||
# Inference-time subsampling and regularization
|
||||
inference_stride: int = 1
|
||||
frame_dropout_p: float = 0.10
|
||||
|
||||
# Training
|
||||
learning_rate: float = 1e-3
|
||||
weight_decay: float = 0.01
|
||||
head_lr_multiplier: float = 5.0
|
||||
logit_eps: float = 1e-4
|
||||
|
||||
# Performance optimizations
|
||||
use_amp: bool = True
|
||||
@@ -71,18 +75,6 @@ class RLearNConfig(PreTrainedConfig):
|
||||
rewind_last3_prob: float = 0.3
|
||||
mismatch_prob: float = 0.2
|
||||
|
||||
# Logit regression (only supported mode) - FIXED: Larger eps to prevent extreme targets
|
||||
logit_eps: float = 0.02 # Was 1e-6 → logit(±13.8), now 0.02 → logit(±3.9)
|
||||
head_lr_multiplier: float = 10.0
|
||||
head_weight_init_std: float = 0.05
|
||||
# Initialize head bias toward this target probability to avoid 0.5 plateau
|
||||
head_initial_bias_target: float = 0.3
|
||||
|
||||
# Reward head architecture - FIXED: Simpler architecture to prevent flat basins
|
||||
head_hidden_dim: int = 1024 # Hidden dimension for reward head
|
||||
head_num_layers: int = 2 # REDUCED: 2 layers instead of 4 to prevent over-regularization
|
||||
head_dropout: float = 0.05 # REDUCED: Less dropout to prevent conservatism
|
||||
|
||||
# Normalization presets
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
@@ -90,10 +82,6 @@ class RLearNConfig(PreTrainedConfig):
|
||||
}
|
||||
)
|
||||
|
||||
# Architecture
|
||||
num_register_tokens: int = 4
|
||||
mlp_predictor_depth: int = 3
|
||||
|
||||
# Required path to episodes.jsonl for episode boundaries
|
||||
episodes_jsonl_path: str | None = "meta/episodes.jsonl"
|
||||
|
||||
|
||||
@@ -43,8 +43,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
- Visual encoder: frozen DinoV3 encoder, returns per-frame embeddings.
|
||||
- Text encoder: frozen SigLIP2, returns a language embedding.
|
||||
- Temporal module: x_transformers Decoder with packed tokens [lang | register | video].
|
||||
- Output: per-timestep rewards via simple linear regression head.
|
||||
|
||||
"""
|
||||
|
||||
config_class = RLearNConfig
|
||||
@@ -90,71 +89,43 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
self.vision_model.eval()
|
||||
self.text_model.eval()
|
||||
|
||||
# x_transformers Decoder (matching ReWiND exactly)
|
||||
self.decoder = Decoder(
|
||||
dim=config.dim_model,
|
||||
depth=config.n_layers,
|
||||
heads=config.n_heads,
|
||||
attn_dim_head=64, # ReWiND default
|
||||
ff_mult=config.dim_feedforward // config.dim_model, # Convert to multiplier
|
||||
# Note: x_transformers uses attn_dropout and ff_dropout separately
|
||||
attn_dropout=config.dropout,
|
||||
ff_dropout=config.dropout,
|
||||
)
|
||||
|
||||
# Linear projections to the shared temporal model dimension
|
||||
self.to_lang_tokens = nn.Linear(self.text_hidden, config.dim_model)
|
||||
self.to_video_tokens = nn.Linear(self.vision_hidden, config.dim_model)
|
||||
|
||||
# Stronger temporal positional encoding
|
||||
self.temporal_pos_embedding = nn.Parameter(torch.randn(config.max_seq_len, config.dim_model) * 0.1)
|
||||
# Spatial (patch) positional encoding for patch tokens
|
||||
self.max_patch_tokens = getattr(config, 'max_patch_tokens', 256)
|
||||
self.spatial_pos_embedding = nn.Parameter(torch.randn(self.max_patch_tokens, config.dim_model) * 0.1)
|
||||
|
||||
# Single MLP processes all frames
|
||||
self.frame_mlp = nn.Linear(config.dim_model, config.dim_model)
|
||||
|
||||
# Register / memory / attention sink tokens
|
||||
self.num_register_tokens = config.num_register_tokens
|
||||
self.register_tokens = nn.Parameter(torch.randn(config.num_register_tokens, config.dim_model) * 1e-2)
|
||||
# First-frame positional embedding (only applied to the first video frame)
|
||||
self.first_frame_pos = nn.Parameter(torch.zeros(1, 1, config.dim_model))
|
||||
|
||||
# MLP predictor (matching ReWiND's Feedforwards)
|
||||
from x_mlps_pytorch import Feedforwards
|
||||
self.mlp_predictor = Feedforwards(
|
||||
# Cross-modal sequential aggregator – causal transformer over
|
||||
# [language tokens | video frame tokens]
|
||||
self.decoder = Decoder(
|
||||
dim=config.dim_model,
|
||||
dim_out=None,
|
||||
depth=config.mlp_predictor_depth
|
||||
depth=config.num_layers,
|
||||
heads=config.num_heads,
|
||||
ff_mult=config.ff_mult,
|
||||
attn_dropout=config.dropout,
|
||||
ff_dropout=config.dropout,
|
||||
cross_attend=False,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
# Per-frame predictor head
|
||||
self.frame_mlp = nn.Sequential(
|
||||
nn.LayerNorm(config.dim_model),
|
||||
nn.Linear(config.dim_model, config.dim_model),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
)
|
||||
|
||||
# FIXED: Simpler head architecture to prevent constant output pathology
|
||||
# Remove LayerNorm (causes flat basin), reduce depth, larger init, less dropout
|
||||
|
||||
# Simple 2-layer MLP with larger initialization to encourage exploration
|
||||
self.reward_head = nn.Sequential(
|
||||
nn.Linear(config.dim_model + 1, config.head_hidden_dim), # +1 for temporal position
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.05), # Reduced dropout to prevent noise-induced conservatism
|
||||
nn.Linear(config.head_hidden_dim, 1)
|
||||
nn.Linear(config.dim_model, config.dim_model),
|
||||
nn.GELU(),
|
||||
nn.Dropout(config.dropout),
|
||||
nn.Linear(config.dim_model, 1),
|
||||
)
|
||||
|
||||
# FIXED: Larger weight initialization + head bias warm-start to escape 0.5 plateau
|
||||
with torch.no_grad():
|
||||
for i, module in enumerate(self.reward_head):
|
||||
if isinstance(module, nn.Linear):
|
||||
# Use Xavier/Glorot initialization for better gradient flow
|
||||
nn.init.xavier_uniform_(module.weight, gain=1.0)
|
||||
nn.init.zeros_(module.bias)
|
||||
# Set last layer bias to logit(target0) where target0 is a prior (e.g., 0.3)
|
||||
target0 = float(getattr(self.config, 'head_initial_bias_target', 0.3))
|
||||
target0 = min(max(target0, 1e-3), 1 - 1e-3)
|
||||
initial_bias = torch.log(torch.tensor(target0) / (1 - torch.tensor(target0)))
|
||||
last_linear: nn.Linear = self.reward_head[-1] # type: ignore
|
||||
last_linear.bias.copy_(initial_bias)
|
||||
|
||||
# Simple frame dropout probability
|
||||
self.frame_dropout_p = config.frame_dropout_p
|
||||
self.stride = max(1, config.stride)
|
||||
|
||||
# Sampling and regularization knobs
|
||||
self.stride = max(1, int(config.inference_stride))
|
||||
self.frame_dropout_p = float(config.frame_dropout_p)
|
||||
|
||||
# Auto-load episode_data_index from episodes.jsonl if not provided
|
||||
if self.episode_data_index is None and getattr(config, "episodes_jsonl_path", None):
|
||||
@@ -203,102 +174,6 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor: # Required by base class
|
||||
raise NotImplementedError("RLearN is a reward model and does not select actions")
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_rewards(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict per-timestep rewards for evaluation using ReWiND architecture.
|
||||
|
||||
Args:
|
||||
batch: Input batch with OBS_IMAGES and optionally OBS_LANGUAGE
|
||||
|
||||
Returns:
|
||||
Predicted rewards tensor of shape (B, T)
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Extract frames and form (B, T, C, H, W)
|
||||
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
|
||||
B, T, C, H, W = frames.shape
|
||||
|
||||
# CRITICAL FIX: Do NOT apply stride during evaluation
|
||||
# During evaluation, we want to process all frames in the sliding window
|
||||
# Stride should only be used during training to reduce computational cost
|
||||
T_eff = T # Use all frames during evaluation
|
||||
|
||||
# Get language commands
|
||||
commands = batch.get(OBS_LANGUAGE, None)
|
||||
if commands is None:
|
||||
commands = [""] * B
|
||||
elif not isinstance(commands, list):
|
||||
commands = [str(commands)] * B
|
||||
|
||||
# Forward through ReWiND model (inference mode)
|
||||
device = next(self.parameters()).device
|
||||
frames = frames.to(device)
|
||||
|
||||
# Process video frames -> patch tokens per frame
|
||||
video_patch_embeds = self._encode_video_frames(frames).to(device) # (B, T, P, D_vision)
|
||||
|
||||
# Language embeddings + mask
|
||||
lang_embeds, mask = self._encode_language_tokens(commands, device)
|
||||
|
||||
# Register tokens
|
||||
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b=B)
|
||||
|
||||
# Project embeddings
|
||||
lang_tokens = self.to_lang_tokens(lang_embeds)
|
||||
video_tokens = self.to_video_tokens(video_patch_embeds) # (B, T, P, D)
|
||||
# Add temporal + spatial positional encoding (window-relative time + patch index)
|
||||
Bv, T_video, P_video, Dm = video_tokens.shape
|
||||
if P_video > self.spatial_pos_embedding.shape[0]:
|
||||
raise ValueError(f"Number of patch tokens {P_video} exceeds max_patch_tokens {self.spatial_pos_embedding.shape[0]}")
|
||||
t_pos = self.temporal_pos_embedding[:T_video] # (T, D)
|
||||
p_pos = self.spatial_pos_embedding[:P_video] # (P, D)
|
||||
pos = t_pos[:, None, :] + p_pos[None, :, :] # (T, P, D)
|
||||
video_tokens = video_tokens + pos # broadcast over batch
|
||||
# Flatten patch dimension for attention
|
||||
video_tokens = rearrange(video_tokens, 'b t p d -> b (t p) d')
|
||||
|
||||
# Pack all tokens for attention
|
||||
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
|
||||
|
||||
# Extend mask for register and video tokens
|
||||
mask = F.pad(mask, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
|
||||
|
||||
# Forward through decoder
|
||||
attended = self.decoder(tokens, mask=mask)
|
||||
|
||||
# Unpack and get video token features
|
||||
_, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d') # (B, T*P, D)
|
||||
# Restore (B, T, P, D) and pool patches per frame
|
||||
attended_video_tokens = rearrange(attended_video_tokens, 'b (t p) d -> b t p d', t=T_video, p=P_video)
|
||||
frame_tokens = attended_video_tokens.mean(dim=2) # (B, T, D)
|
||||
frame_tokens = self.frame_mlp(frame_tokens)
|
||||
|
||||
# MLP predictor
|
||||
video_frame_embeds = self.mlp_predictor(frame_tokens)
|
||||
|
||||
# Get rewards via temporal-aware logit regression head (no pre-normalization)
|
||||
|
||||
# Add temporal position information
|
||||
B, T_pred = video_frame_embeds.shape[:2]
|
||||
temporal_pos = torch.linspace(0, 1, T_pred, device=video_frame_embeds.device)
|
||||
temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B, T_pred, 1) # (B, T, 1)
|
||||
|
||||
# Concatenate embeddings with temporal position
|
||||
temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1) # (B, T, D+1)
|
||||
|
||||
# Forward through temporal-aware head
|
||||
raw_logits = self.reward_head(temporal_input).squeeze(-1) # (B, T)
|
||||
return torch.sigmoid(raw_logits) # Apply sigmoid for final predictions
|
||||
|
||||
def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# Initial version: no-op; rely on upstream processors if any
|
||||
return batch
|
||||
|
||||
def normalize_targets(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# Initial version: no-op
|
||||
return batch
|
||||
|
||||
def _encode_video_frames(self, frames: Tensor) -> Tensor:
|
||||
"""Encode video frames through DinoV3 to get per-frame PATCH embeddings.
|
||||
|
||||
@@ -483,11 +358,6 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
return patch_features
|
||||
|
||||
def _mask_from_lens(self, lens: Tensor) -> Tensor:
|
||||
"""Create mask from sequence lengths."""
|
||||
seq = torch.arange(lens.amax().item(), device=lens.device)
|
||||
return einx.less('n, b -> b n', seq, lens)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Compute ReWiND training loss with on-the-fly progress label generation.
|
||||
|
||||
@@ -500,9 +370,6 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
import time
|
||||
forward_start = time.perf_counter()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Always use random anchor window sampling
|
||||
frames, anchor_stats = self._sample_random_anchor_windows(batch)
|
||||
@@ -549,45 +416,29 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
lang_time = time.perf_counter() - lang_start
|
||||
|
||||
# Token preparation
|
||||
# Register tokens
|
||||
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b=B)
|
||||
|
||||
# Project embeddings
|
||||
lang_tokens = self.to_lang_tokens(lang_embeds)
|
||||
video_tokens = self.to_video_tokens(video_patch_embeds) # (B, T, P, D)
|
||||
lang_tokens = self.to_lang_tokens(lang_embeds) # (B, L, D)
|
||||
# Collapse patches to per-frame tokens then project
|
||||
video_frame_embeds = video_patch_embeds.mean(dim=2) # (B, T_eff, D_vision)
|
||||
video_tokens = self.to_video_tokens(video_frame_embeds) # (B, T_eff, D)
|
||||
# First-frame positional embedding only
|
||||
video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos
|
||||
|
||||
# Add temporal + spatial positional encoding (window-relative only)
|
||||
Bv, T_video, P_video, Dm = video_tokens.shape
|
||||
if P_video > self.spatial_pos_embedding.shape[0]:
|
||||
raise ValueError(f"Number of patch tokens {P_video} exceeds max_patch_tokens {self.spatial_pos_embedding.shape[0]}")
|
||||
t_pos = self.temporal_pos_embedding[:T_video] # (T, D)
|
||||
p_pos = self.spatial_pos_embedding[:P_video] # (P, D)
|
||||
pos = t_pos[:, None, :] + p_pos[None, :, :] # (T, P, D)
|
||||
video_tokens = video_tokens + pos
|
||||
# Flatten patches into sequence tokens
|
||||
video_tokens = rearrange(video_tokens, 'b t p d -> b (t p) d')
|
||||
|
||||
# Pack all tokens for attention [lang | register | video]
|
||||
tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d')
|
||||
|
||||
# Extend mask for register and video tokens
|
||||
mask = F.pad(mask, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
|
||||
|
||||
# Forward through x_transformers Decoder
|
||||
# Build attention mask for decoder (True = keep)
|
||||
# Language mask from tokenizer, rest are fully valid
|
||||
full_mask = F.pad(mask, (0, video_tokens.shape[1]), value=True)
|
||||
# Pack and run transformer
|
||||
transformer_start = time.perf_counter()
|
||||
attended = self.decoder(tokens, mask=mask)
|
||||
|
||||
# Unpack and get video token features
|
||||
_, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d') # (B, T*P, D)
|
||||
# Restore (B, T, P, D) and pool patches per frame
|
||||
attended_video_tokens = rearrange(attended_video_tokens, 'b (t p) d -> b t p d', t=T_video, p=P_video)
|
||||
frame_tokens = attended_video_tokens.mean(dim=2) # (B, T, D)
|
||||
frame_tokens = self.frame_mlp(frame_tokens)
|
||||
|
||||
# MLP predictor
|
||||
video_frame_embeds = self.mlp_predictor(frame_tokens)
|
||||
tokens_packed, packed_shape = pack((lang_tokens, video_tokens), 'b * d')
|
||||
attended = self.decoder(tokens_packed, mask=full_mask)
|
||||
attended_lang, attended_video = unpack(attended, packed_shape, 'b * d')
|
||||
transformer_time = time.perf_counter() - transformer_start
|
||||
|
||||
# Per-frame prediction
|
||||
frame_tokens = self.frame_mlp(attended_video) # (B, T_eff, D)
|
||||
raw_logits = self.reward_head(frame_tokens).squeeze(-1) # (B, T_eff)
|
||||
predicted_rewards = torch.sigmoid(raw_logits)
|
||||
|
||||
# Generate progress labels on-the-fly (ReWiND approach)
|
||||
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
|
||||
loss_dict: dict[str, float] = {}
|
||||
@@ -609,52 +460,18 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
)
|
||||
target = self._calculate_anchor_based_progress(T_eff, anchor_stats)
|
||||
|
||||
# During inference, we might not want to compute loss
|
||||
if not self.training and target is None:
|
||||
# Return predictions without loss using temporal-aware head
|
||||
|
||||
# Add temporal position information
|
||||
B_inf, T_inf = video_frame_embeds.shape[:2]
|
||||
temporal_pos = torch.linspace(0, 1, T_inf, device=video_frame_embeds.device)
|
||||
temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B_inf, T_inf, 1)
|
||||
|
||||
# Concatenate and forward through temporal-aware head
|
||||
temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1)
|
||||
raw_logits = self.reward_head(temporal_input).squeeze(-1)
|
||||
rewards = torch.sigmoid(raw_logits)
|
||||
return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}
|
||||
|
||||
# Calculate loss using logit regression
|
||||
# Compute main loss (or just return predictions in eval)
|
||||
loss_start = time.perf_counter()
|
||||
if target is None:
|
||||
total_loss = raw_logits.mean() * 0.0
|
||||
loss = total_loss
|
||||
else:
|
||||
target_expanded = target # (B, T_eff)
|
||||
eps = self.config.logit_eps
|
||||
target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps))
|
||||
loss = F.mse_loss(raw_logits, target_logits)
|
||||
total_loss = loss
|
||||
|
||||
# Get model outputs with temporal-aware head
|
||||
|
||||
# Add temporal position information
|
||||
temporal_pos = torch.linspace(0, 1, T_eff, device=video_frame_embeds.device)
|
||||
temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B, T_eff, 1) # (B, T_eff, 1)
|
||||
|
||||
# Concatenate embeddings with temporal position
|
||||
temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1) # (B, T_eff, D+1)
|
||||
|
||||
# Forward through temporal-aware head
|
||||
raw_logits = self.reward_head(temporal_input).squeeze(-1) # (B, T_eff)
|
||||
|
||||
# FIXED: More robust logit regression with gradient protection
|
||||
eps = self.config.logit_eps
|
||||
target_expanded = target.expand(B, -1)[:, :T_eff] # Expand and trim to T_eff
|
||||
target_clamped = torch.clamp(target_expanded, eps, 1 - eps)
|
||||
target_logits = torch.logit(target_clamped)
|
||||
|
||||
# Use Smooth L1 loss instead of MSE for better gradient stability
|
||||
loss = F.smooth_l1_loss(raw_logits, target_logits, reduction='mean', beta=1.0)
|
||||
|
||||
# Clip gradients specifically for the reward head during backward pass
|
||||
# This prevents extreme gradients from corrupting AdamW momentum
|
||||
if self.training:
|
||||
raw_logits.register_hook(lambda grad: torch.clamp(grad, -5.0, 5.0))
|
||||
|
||||
# For logging, compute sigmoid predictions
|
||||
predicted_rewards = torch.sigmoid(raw_logits)
|
||||
|
||||
# Mismatched video-language pairs loss (only when languages actually differ)
|
||||
L_mismatch = torch.zeros((), device=device)
|
||||
@@ -683,54 +500,34 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
|
||||
|
||||
# Pack and forward
|
||||
tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d')
|
||||
mask_mm = F.pad(mask_mm, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
|
||||
tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, video_tokens), 'b * d')
|
||||
mask_mm = F.pad(mask_mm, (0, video_tokens.shape[1]), value=True)
|
||||
attended_mm = self.decoder(tokens_mm, mask=mask_mm)
|
||||
_, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d')
|
||||
_, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d')
|
||||
|
||||
# Process mismatch frames with single MLP
|
||||
mismatch_tokens = self.frame_mlp(attended_video_mm) # (B, T, D)
|
||||
mismatch_embeds = self.mlp_predictor(mismatch_tokens)
|
||||
|
||||
# Predict near-zero progress for mismatched pairs with temporal awareness
|
||||
|
||||
# Add temporal position information for mismatch computation
|
||||
T_mismatch = mismatch_embeds.shape[1]
|
||||
temporal_pos_mm = torch.linspace(0, 1, T_mismatch, device=mismatch_embeds.device)
|
||||
temporal_pos_mm = temporal_pos_mm.unsqueeze(0).unsqueeze(-1).expand(B, T_mismatch, 1)
|
||||
|
||||
# Concatenate mismatch embeddings with temporal position
|
||||
temporal_input_mm = torch.cat([mismatch_embeds, temporal_pos_mm], dim=-1)
|
||||
|
||||
# Forward through temporal-aware head
|
||||
mismatch_raw_logits = self.reward_head(temporal_input_mm).squeeze(-1)
|
||||
|
||||
# Create mask tensor for loss calculation
|
||||
mismatch_raw_logits = self.reward_head(mismatch_tokens).squeeze(-1)
|
||||
|
||||
mismatch_tensor = torch.tensor(mismatch_mask, device=device, dtype=torch.bool)
|
||||
|
||||
if mismatch_tensor.any():
|
||||
# Target logit corresponding to sigmoid ≈ 0
|
||||
eps = self.config.logit_eps
|
||||
zeros_target_logits = torch.logit(torch.full_like(target_expanded[:, :T_eff], eps))
|
||||
|
||||
# Only compute loss for samples that are actually mismatched
|
||||
zeros_target_logits = torch.logit(torch.full_like(mismatch_raw_logits, eps))
|
||||
mismatch_loss_per_sample = F.mse_loss(
|
||||
mismatch_raw_logits, zeros_target_logits, reduction='none'
|
||||
).mean(dim=1) # (B,)
|
||||
|
||||
# Apply mask and average only over true mismatches
|
||||
).mean(dim=1)
|
||||
L_mismatch = mismatch_loss_per_sample[mismatch_tensor].mean()
|
||||
|
||||
# Total loss
|
||||
total_loss = loss + L_mismatch
|
||||
total_loss = total_loss + L_mismatch
|
||||
loss_time = time.perf_counter() - loss_start
|
||||
|
||||
# DEBUG: Clean logit regression monitoring with full array printing
|
||||
if self.training and torch.rand(1).item() < 0.03:
|
||||
with torch.no_grad():
|
||||
sample_idx = torch.randint(0, B, (1,)).item()
|
||||
sample_targets = target_expanded[sample_idx, :T_eff].cpu().numpy()
|
||||
sample_preds = predicted_rewards[sample_idx].cpu().numpy()
|
||||
sample_targets = target_expanded[sample_idx, :T_eff].cpu().numpy() if target is not None else np.zeros((T_eff,), dtype=np.float32)
|
||||
sample_preds = predicted_rewards[sample_idx].detach().cpu().numpy()
|
||||
|
||||
print(f"\n=== LOGIT REGRESSION DEBUG ===")
|
||||
print(f"Target: min={target_expanded.min():.3f}, max={target_expanded.max():.3f}, mean={target_expanded.mean():.3f}")
|
||||
@@ -767,7 +564,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
print(f"Sample {sample_idx}: T_eff={T_eff}, target ∈ [{sample_targets.min():.3f}, {sample_targets.max():.3f}], pred ∈ [{sample_preds.min():.3f}, {sample_preds.max():.3f}]")
|
||||
|
||||
print(f"Loss: {loss:.6f}")
|
||||
print(f"Loss: {total_loss:.6f}")
|
||||
print("=" * 60)
|
||||
|
||||
total_forward_time = time.perf_counter() - forward_start
|
||||
@@ -775,15 +572,15 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# Log individual loss components
|
||||
loss_dict.update({
|
||||
"loss": float(total_loss.detach().item()),
|
||||
"loss_main": float(loss.detach().item()),
|
||||
"loss_main": float(loss.detach().item() if isinstance(loss, torch.Tensor) else 0.0),
|
||||
"loss_mismatch": float(L_mismatch.detach().item()),
|
||||
"t_eff": float(T_eff),
|
||||
"lang_len_mean": float(mask.sum().float().mean().item()), # Use mask to get actual lengths
|
||||
# Target statistics for monitoring
|
||||
"target_min": float(target.min().item()),
|
||||
"target_max": float(target.max().item()),
|
||||
"target_mean": float(target.mean().item()),
|
||||
"target_std": float(target.std().item()),
|
||||
"target_min": float(target.min().item()) if target is not None else 0.0,
|
||||
"target_max": float(target.max().item()) if target is not None else 0.0,
|
||||
"target_mean": float(target.mean().item()) if target is not None else 0.0,
|
||||
"target_std": float(target.std().item()) if target is not None else 0.0,
|
||||
# Prediction statistics
|
||||
"pred_mean": float(predicted_rewards.mean().item()),
|
||||
"pred_std": float(predicted_rewards.std().item()),
|
||||
@@ -1122,8 +919,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
return torch.stack(all_progress) # (B, T_eff)
|
||||
|
||||
|
||||
|
||||
|
||||
def _load_episode_index_from_jsonl(self, path: str) -> dict[str, Tensor]:
|
||||
import json
|
||||
lengths: list[int] = []
|
||||
|
||||
Reference in New Issue
Block a user