mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 12:21:27 +00:00
change config
This commit is contained in:
@@ -45,7 +45,7 @@ class RLearNConfig(PreTrainedConfig):
|
||||
|
||||
# Sequence length, amount of past frames including current one to use in the temporal model
|
||||
max_seq_len: int = 16
|
||||
# Temporal sampling stride (2 = skip every other frame for wider temporal coverage)
|
||||
# Temporal sampling stride
|
||||
temporal_sampling_stride: int = 3 # Open x mostly has fps 10, and rewind has seq len 16; ours is 30 fps, so a stride length of 30/10 = 3 covers the same timeframe.
|
||||
|
||||
# Model dimensions and transformer
|
||||
@@ -53,8 +53,7 @@ class RLearNConfig(PreTrainedConfig):
|
||||
num_layers: int = 4
|
||||
num_heads: int = 8
|
||||
ff_mult: int = 4 # Feed-forward multiplier, hidden = dim_model * ff_mult
|
||||
dropout: float = 0.10
|
||||
num_register_tokens: int = 4
|
||||
dropout: float = 0.05
|
||||
|
||||
# --- reward head options ---
|
||||
use_categorical_rewards: bool = False # classification over bins
|
||||
@@ -69,7 +68,7 @@ class RLearNConfig(PreTrainedConfig):
|
||||
frame_dropout_p: float = 0.10
|
||||
|
||||
# Training
|
||||
learning_rate: float = 1e-3
|
||||
learning_rate: float = 5e-4
|
||||
weight_decay: float = 0.01
|
||||
head_lr_multiplier: float = 5.0
|
||||
logit_eps: float = 1e-4
|
||||
@@ -79,9 +78,9 @@ class RLearNConfig(PreTrainedConfig):
|
||||
compile_model: bool = True
|
||||
|
||||
# ReWiND augmentation
|
||||
rewind_prob: float = 0.8
|
||||
rewind_last3_prob: float = 0.3
|
||||
mismatch_prob: float = 0.2
|
||||
rewind_prob: float = 0.3 #0.8
|
||||
rewind_last3_prob: float = 0.0 #0.3
|
||||
mismatch_prob: float = 0.0# 0.2
|
||||
|
||||
# Normalization presets
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
|
||||
@@ -139,7 +139,7 @@ def extract_episode_frames_and_gt(dataset, episode_idx):
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=64, device="cuda"):
|
||||
def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=64, device="cuda", temporal_stride: int | None = None):
|
||||
"""
|
||||
Sliding-window prediction: for each frame i, create a window [max(0, i-L+1) .. i],
|
||||
left-pad by repeating the first frame to length L (<= 16), and take the prediction
|
||||
@@ -147,8 +147,13 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=
|
||||
Returns np.ndarray of shape (T,).
|
||||
"""
|
||||
T = frames.shape[0]
|
||||
L = int(getattr(getattr(model, "config", object()), "max_seq_len", max_seq_len))
|
||||
cfg = getattr(model, "config", object())
|
||||
L = int(getattr(cfg, "max_seq_len", max_seq_len))
|
||||
L = min(L, max_seq_len) # hard-cap at the max_seq_len argument (16 by default)
|
||||
# Use the same temporal stride as training (skip s-1 frames, take 1)
|
||||
if temporal_stride is None:
|
||||
temporal_stride = int(getattr(cfg, "temporal_sampling_stride", 1))
|
||||
temporal_stride = max(1, int(temporal_stride))
|
||||
|
||||
# Preprocessed tensor on device
|
||||
frames = frames.to(device)
|
||||
@@ -158,21 +163,15 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=
|
||||
left_pad_counts = [] # Number of left-pad (OOB) frames per window
|
||||
|
||||
for i in range(T):
|
||||
start = max(0, i - L + 1)
|
||||
window = frames[start : i + 1] # (len<=L, C, H, W)
|
||||
|
||||
if window.shape[0] < L:
|
||||
pad_needed = L - window.shape[0]
|
||||
pad = window[:1].expand(pad_needed, -1, -1, -1) # repeat first frame (clamp to frame 0)
|
||||
window = torch.cat([pad, window], dim=0)
|
||||
else:
|
||||
pad_needed = 0
|
||||
|
||||
# IMPROVED FIX: Cycle through MLPs to get varied predictions throughout the episode
|
||||
# This ensures we use all 16 frame-specific MLPs and get varied outputs
|
||||
# Frames 0-15 use MLPs 0-15, frames 16-31 use MLPs 0-15 again, etc.
|
||||
frame_pos = i % L # Cycle through [0, 1, 2, ..., 15, 0, 1, 2, ..., 15, ...]
|
||||
|
||||
# Build indices with stride s: [..., i-2s, i-s, i], left-padded by clamping negative indices to 0
|
||||
idxs = [i - (L - 1 - j) * temporal_stride for j in range(L)]
|
||||
pad_needed = sum(1 for k in idxs if k < 0)
|
||||
clamped = [0 if k < 0 else (T - 1 if k >= T else k) for k in idxs]
|
||||
window = frames[clamped] # (L, C, H, W)
|
||||
|
||||
# Use the last temporal position (current frame) for reading model output
|
||||
frame_pos = L - 1
|
||||
|
||||
windows.append(window)
|
||||
frame_positions.append(frame_pos)
|
||||
left_pad_counts.append(pad_needed)
|
||||
|
||||
@@ -509,7 +509,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
L_mismatch = mismatch_loss_per_sample[mismatch_tensor].mean()
|
||||
|
||||
# Total loss
|
||||
total_loss = total_loss + L_mismatch + 0.3 * L_rank + 0.05 * L_flat
|
||||
total_loss = total_loss + L_mismatch
|
||||
loss_time = time.perf_counter() - loss_start
|
||||
|
||||
# DEBUG: Clean logit regression monitoring with full array printing
|
||||
@@ -571,8 +571,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
"loss": float(total_loss.detach().item()),
|
||||
"loss_main": float(loss.detach().item() if isinstance(loss, torch.Tensor) else 0.0),
|
||||
"loss_mismatch": float(L_mismatch.detach().item()),
|
||||
"loss_rank": float(L_rank.detach().item()),
|
||||
"loss_flat": float(L_flat.detach().item()),
|
||||
|
||||
"t_eff": float(T_eff),
|
||||
"lang_len_mean": float(mask.sum().float().mean().item()), # Use mask to get actual lengths
|
||||
# Target statistics for monitoring
|
||||
|
||||
Reference in New Issue
Block a user