This commit is contained in:
Pepijn
2025-08-28 08:52:48 +02:00
parent 34ca077d78
commit a4c88d6340
3 changed files with 16 additions and 12 deletions

View File

@@ -59,7 +59,7 @@ class RLearNConfig(PreTrainedConfig):
use_tanh_head: bool = False # when True, bound outputs in [-1, 1]
# Training
learning_rate: float = 5e-5 # Reduced for stability
learning_rate: float = 1e-4
weight_decay: float = 0.01
loss_type: str = "composite" # Always use composite loss with spatial awareness
ranking_margin: float = 0.1

View File

@@ -374,7 +374,7 @@ class RLearNPolicy(PreTrainedPolicy):
# Align target with sampled timesteps
if target.dim() == 1:
target = target.unsqueeze(1) # (B, 1)
# Handle target padding to match frame sequence if needed
if target.shape[1] < self.config.max_seq_len:
# Pad targets by repeating the first value (assuming it's the earliest)
@@ -382,10 +382,13 @@ class RLearNPolicy(PreTrainedPolicy):
first_target = target[:, :1] # (B, 1)
padding = first_target.expand(target.shape[0], padding_needed)
target = torch.cat([padding, target], dim=1) # Prepend padding
import logging
logging.debug(f"Padded targets from {target.shape[1] - padding_needed} to {self.config.max_seq_len}")
logging.debug(
f"Padded targets from {target.shape[1] - padding_needed} to {self.config.max_seq_len}"
)
# Now safely index with idx
target = target[:, idx]
@@ -617,12 +620,12 @@ def generate_causal_mask(T: int, device=None) -> Tensor:
def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None) -> Tensor:
"""Extract visual sequence from batch and ensure it has the expected temporal length.
Args:
batch: Input batch containing image data
target_seq_len: Expected sequence length. If provided and the actual sequence is shorter,
it will be padded by repeating the first frame.
Returns:
Tensor of shape (B, T, C, H, W)
"""
@@ -670,7 +673,7 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None
f"Available keys with 'image': {image_like_keys}. "
f"All keys: {available_keys}"
)
# Pad sequence if needed
if target_seq_len is not None:
B, T, C, H, W = frames.shape
@@ -680,10 +683,11 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None
first_frame = frames[:, :1] # (B, 1, C, H, W)
padding = first_frame.expand(B, padding_needed, C, H, W)
frames = torch.cat([padding, frames], dim=1) # Prepend padding
import logging
logging.debug(f"Padded sequence from {T} to {target_seq_len} frames by repeating first frame")
return frames

View File

@@ -128,11 +128,11 @@ Default weights: $\lambda_{\text{prog}}=1.0$, $\lambda_{\text{spatial-nce}}=0.5$
- Only vlc loss then eval []
- Vlc + rewind loss then eval []
- Convert 1% of bc-z []
- Cleanup code []
- Cleanup code []
- Try DINO v3 as encoder Base 86 M: https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m with HuggingFaceTB/SmolLM2-135M-Instruct ? []
- Add more artificial text to dataset generated by vlm (google gemini) []
- See google gemini vlm caption [] https://gemini.google.com/app/7e332ffaf32580f2
- Multiple captions per video, create a method to generate as much data as possible etc [] https://arxiv.org/abs/2508.13446
- Multiple captions per video, create a method to generate as much data as possible etc [] https://arxiv.org/abs/2508.13446, https://arxiv.org/pdf/2412.04453
- How can we improve spatially aware learning? Co-generating captions for each frame with a language decoder?
- Add droid []
- Extend evaluation []