mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
extend head
This commit is contained in:
@@ -73,6 +73,11 @@ class RLearNConfig(PreTrainedConfig):
|
||||
logit_eps: float = 1e-6
|
||||
head_lr_multiplier: float = 2.0
|
||||
head_weight_init_std: float = 0.05
|
||||
|
||||
# Reward head architecture
|
||||
head_hidden_dim: int = 1024 # Hidden dimension for reward head
|
||||
head_num_layers: int = 4 # Number of layers in reward head
|
||||
head_dropout: float = 0.1 # Dropout in reward head
|
||||
|
||||
# Normalization presets
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
|
||||
@@ -141,12 +141,9 @@ def extract_episode_frames_and_gt(dataset, episode_idx):
|
||||
@torch.no_grad()
|
||||
def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=64, device="cuda"):
|
||||
"""
|
||||
Sliding-window prediction for episode-relative progress model.
|
||||
For each frame i, creates a window and extracts the prediction for that specific frame.
|
||||
|
||||
NOTE: This assumes we don't have episode context (episode_index, frame_index, episode_length).
|
||||
The model will use its fallback logic for window-relative progress.
|
||||
|
||||
Sliding-window prediction: for each frame i, create a window [max(0, i-L+1) .. i],
|
||||
left-pad by repeating the first frame to length L (<= 16), and take the prediction
|
||||
corresponding to the current frame's position in the window.
|
||||
Returns np.ndarray of shape (T,).
|
||||
"""
|
||||
T = frames.shape[0]
|
||||
@@ -156,45 +153,49 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=
|
||||
# Preprocessed tensor on device
|
||||
frames = frames.to(device)
|
||||
|
||||
# Simple approach: predict each 16-frame window and take the last prediction
|
||||
# This assumes the model can handle the lack of episode context gracefully
|
||||
preds = np.zeros(T, dtype=float)
|
||||
windows = []
|
||||
frame_positions = [] # Track which temporal position each frame should use
|
||||
|
||||
# Process non-overlapping windows for efficiency
|
||||
for start_idx in range(0, T, L):
|
||||
end_idx = min(start_idx + L, T)
|
||||
window_frames = frames[start_idx:end_idx]
|
||||
for i in range(T):
|
||||
start = max(0, i - L + 1)
|
||||
window = frames[start : i + 1] # (len<=L, C, H, W)
|
||||
|
||||
# Pad if needed
|
||||
if window_frames.shape[0] < L:
|
||||
pad_needed = L - window_frames.shape[0]
|
||||
if start_idx == 0:
|
||||
# Pad with first frame at beginning
|
||||
pad = window_frames[:1].expand(pad_needed, -1, -1, -1)
|
||||
window_frames = torch.cat([pad, window_frames], dim=0)
|
||||
else:
|
||||
# Pad with last frame at end
|
||||
pad = window_frames[-1:].expand(pad_needed, -1, -1, -1)
|
||||
window_frames = torch.cat([window_frames, pad], dim=0)
|
||||
if window.shape[0] < L:
|
||||
pad_needed = L - window.shape[0]
|
||||
pad = window[:1].expand(pad_needed, -1, -1, -1) # repeat first frame
|
||||
window = torch.cat([pad, window], dim=0)
|
||||
|
||||
# Create batch (batch size = 1)
|
||||
batch = {
|
||||
OBS_IMAGES: window_frames.unsqueeze(0), # (1, L, C, H, W)
|
||||
OBS_LANGUAGE: [language]
|
||||
}
|
||||
|
||||
# Get predictions for this window
|
||||
window_preds = model.predict_rewards(batch) # (1, L)
|
||||
window_preds = window_preds.squeeze(0).cpu().numpy() # (L,)
|
||||
|
||||
# Extract the relevant predictions for the actual frames
|
||||
actual_frames = min(L, end_idx - start_idx)
|
||||
if start_idx == 0 and window_frames.shape[0] > actual_frames:
|
||||
# Skip padding at beginning
|
||||
preds[start_idx:end_idx] = window_preds[-actual_frames:]
|
||||
# IMPROVED FIX: Cycle through MLPs to get varied predictions throughout the episode
|
||||
# This ensures we use all 16 frame-specific MLPs and get varied outputs
|
||||
# Frames 0-15 use MLPs 0-15, frames 16-31 use MLPs 0-15 again, etc.
|
||||
frame_pos = i % L # Cycle through [0, 1, 2, ..., 15, 0, 1, 2, ..., 15, ...]
|
||||
|
||||
windows.append(window)
|
||||
frame_positions.append(frame_pos)
|
||||
|
||||
preds = np.zeros(T, dtype=float)
|
||||
|
||||
for s in range(0, T, batch_size):
|
||||
e = min(s + batch_size, T)
|
||||
batch_windows = torch.stack(windows[s:e]) # (B, L, C, H, W)
|
||||
batch_positions = frame_positions[s:e]
|
||||
|
||||
batch = {OBS_IMAGES: batch_windows, OBS_LANGUAGE: [language] * (e - s)} # expects (B, L, C, H, W)
|
||||
|
||||
# Model returns (B, L) predictions for each temporal position
|
||||
values = model.predict_rewards(batch) # torch.Tensor (B, L)
|
||||
|
||||
# Debug output removed - issue was identified and fixed
|
||||
|
||||
if values.dim() == 2:
|
||||
# Extract the prediction corresponding to each frame's position in its window
|
||||
batch_preds = []
|
||||
for b_idx, pos in enumerate(batch_positions):
|
||||
batch_preds.append(values[b_idx, pos].item())
|
||||
preds[s:e] = np.array(batch_preds)
|
||||
else:
|
||||
# Take the first predictions (no beginning padding)
|
||||
preds[start_idx:end_idx] = window_preds[:actual_frames]
|
||||
# Fallback: if model returns (B,), use as is
|
||||
preds[s:e] = values.detach().float().cpu().numpy()
|
||||
|
||||
return preds
|
||||
|
||||
|
||||
@@ -184,22 +184,36 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# Layer normalization before reward head to stabilize MLP outputs
|
||||
self.pre_reward_norm = nn.LayerNorm(config.dim_model)
|
||||
|
||||
# Temporal-aware regression head (logit mode only)
|
||||
# Concatenates frame embedding with normalized temporal position
|
||||
self.reward_head = nn.Sequential(
|
||||
nn.Linear(config.dim_model + 1, config.dim_model), # +1 for temporal position
|
||||
nn.ReLU(),
|
||||
nn.Linear(config.dim_model, 1)
|
||||
)
|
||||
# Temporal-aware regression head with increased capacity
|
||||
# Build a deeper MLP for better visual-progress learning
|
||||
head_layers = []
|
||||
|
||||
# Initialize temporal-aware head for logit regression
|
||||
# Input layer: embedding + temporal position -> hidden
|
||||
head_layers.extend([
|
||||
nn.Linear(config.dim_model + 1, config.head_hidden_dim), # +1 for temporal position
|
||||
nn.ReLU(),
|
||||
nn.Dropout(config.head_dropout)
|
||||
])
|
||||
|
||||
# Hidden layers: multiple layers for complex visual-progress mapping
|
||||
for _ in range(config.head_num_layers - 2): # -2 for input and output layers
|
||||
head_layers.extend([
|
||||
nn.Linear(config.head_hidden_dim, config.head_hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(config.head_dropout)
|
||||
])
|
||||
|
||||
# Output layer: hidden -> logit
|
||||
head_layers.append(nn.Linear(config.head_hidden_dim, 1))
|
||||
|
||||
self.reward_head = nn.Sequential(*head_layers)
|
||||
|
||||
# Initialize the deeper temporal-aware head for logit regression
|
||||
with torch.no_grad():
|
||||
# First layer: embedding + position -> embedding
|
||||
nn.init.normal_(self.reward_head[0].weight, 0.0, config.head_weight_init_std)
|
||||
nn.init.zeros_(self.reward_head[0].bias)
|
||||
# Output layer: embedding -> logit
|
||||
nn.init.normal_(self.reward_head[2].weight, 0.0, config.head_weight_init_std)
|
||||
nn.init.zeros_(self.reward_head[2].bias)
|
||||
for module in self.reward_head:
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, 0.0, config.head_weight_init_std)
|
||||
nn.init.zeros_(module.bias)
|
||||
|
||||
# Simple frame dropout probability
|
||||
self.frame_dropout_p = config.frame_dropout_p
|
||||
|
||||
Reference in New Issue
Block a user