mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 04:41:24 +00:00
change config
This commit is contained in:
@@ -139,7 +139,7 @@ def extract_episode_frames_and_gt(dataset, episode_idx):
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=64, device="cuda"):
|
||||
def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=64, device="cuda", temporal_stride: int | None = None):
|
||||
"""
|
||||
Sliding-window prediction: for each frame i, create a window [max(0, i-L+1) .. i],
|
||||
left-pad by repeating the first frame to length L (<= 16), and take the prediction
|
||||
@@ -147,8 +147,13 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=
|
||||
Returns np.ndarray of shape (T,).
|
||||
"""
|
||||
T = frames.shape[0]
|
||||
L = int(getattr(getattr(model, "config", object()), "max_seq_len", max_seq_len))
|
||||
cfg = getattr(model, "config", object())
|
||||
L = int(getattr(cfg, "max_seq_len", max_seq_len))
|
||||
L = min(L, max_seq_len) # hard-cap at 16
|
||||
# Use the same temporal stride as training (skip s-1 frames, take 1)
|
||||
if temporal_stride is None:
|
||||
temporal_stride = int(getattr(cfg, "temporal_sampling_stride", 1))
|
||||
temporal_stride = max(1, int(temporal_stride))
|
||||
|
||||
# Preprocessed tensor on device
|
||||
frames = frames.to(device)
|
||||
@@ -158,21 +163,15 @@ def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=
|
||||
left_pad_counts = [] # Number of left-pad (OOB) frames per window
|
||||
|
||||
for i in range(T):
|
||||
start = max(0, i - L + 1)
|
||||
window = frames[start : i + 1] # (len<=L, C, H, W)
|
||||
|
||||
if window.shape[0] < L:
|
||||
pad_needed = L - window.shape[0]
|
||||
pad = window[:1].expand(pad_needed, -1, -1, -1) # repeat first frame (clamp to frame 0)
|
||||
window = torch.cat([pad, window], dim=0)
|
||||
else:
|
||||
pad_needed = 0
|
||||
|
||||
# IMPROVED FIX: Cycle through MLPs to get varied predictions throughout the episode
|
||||
# This ensures we use all 16 frame-specific MLPs and get varied outputs
|
||||
# Frames 0-15 use MLPs 0-15, frames 16-31 use MLPs 0-15 again, etc.
|
||||
frame_pos = i % L # Cycle through [0, 1, 2, ..., 15, 0, 1, 2, ..., 15, ...]
|
||||
|
||||
# Build indices with stride s: [..., i-3, i] etc., left-padded by clamping to 0
|
||||
idxs = [i - (L - 1 - j) * temporal_stride for j in range(L)]
|
||||
pad_needed = sum(1 for k in idxs if k < 0)
|
||||
clamped = [0 if k < 0 else (T - 1 if k >= T else k) for k in idxs]
|
||||
window = frames[clamped] # (L, C, H, W)
|
||||
|
||||
# Use the last temporal position (current frame) for reading model output
|
||||
frame_pos = L - 1
|
||||
|
||||
windows.append(window)
|
||||
frame_positions.append(frame_pos)
|
||||
left_pad_counts.append(pad_needed)
|
||||
|
||||
Reference in New Issue
Block a user