From 5eb5bf71642e1f44aeae9f5af2e99c3b7fc6481b Mon Sep 17 00:00:00 2001 From: Pepijn Date: Mon, 1 Sep 2025 10:14:43 +0200 Subject: [PATCH] clean --- .../policies/rlearn/configuration_rlearn.py | 40 +- .../policies/rlearn/modeling_rlearn.py | 352 ++++-------------- 2 files changed, 88 insertions(+), 304 deletions(-) diff --git a/src/lerobot/policies/rlearn/configuration_rlearn.py b/src/lerobot/policies/rlearn/configuration_rlearn.py index 6fa7d40a5..0989c0c89 100644 --- a/src/lerobot/policies/rlearn/configuration_rlearn.py +++ b/src/lerobot/policies/rlearn/configuration_rlearn.py @@ -43,24 +43,28 @@ class RLearNConfig(PreTrainedConfig): text_model_name: str = "google/siglip2-base-patch16-224" freeze_backbones: bool = True - # Temporal aggregator - dim_model: int = 512 - n_heads: int = 8 - n_layers: int = 4 - dim_feedforward: int = 2048 - dropout: float = 0.1 - pre_norm: bool = True - frame_dropout_p: float = 0.0 - stride: int = 1 - # Sequence length, amount of past frames including current one to use in the temporal model max_seq_len: int = 16 # Temporal sampling stride (2 = skip every other frame for wider temporal coverage) temporal_sampling_stride: int = 2 + # Model dimensions and transformer + dim_model: int = 512 + num_layers: int = 4 + num_heads: int = 8 + ff_mult: int = 4 # Feed-forward multiplier, hidden = dim_model * ff_mult + dropout: float = 0.10 + num_register_tokens: int = 4 + + # Inference-time subsampling and regularization + inference_stride: int = 1 + frame_dropout_p: float = 0.10 + # Training learning_rate: float = 1e-3 weight_decay: float = 0.01 + head_lr_multiplier: float = 5.0 + logit_eps: float = 1e-4 # Performance optimizations use_amp: bool = True @@ -71,18 +75,6 @@ class RLearNConfig(PreTrainedConfig): rewind_last3_prob: float = 0.3 mismatch_prob: float = 0.2 - # Logit regression (only supported mode) - FIXED: Larger eps to prevent extreme targets - logit_eps: float = 0.02 # Was 1e-6 → logit(±13.8), now 0.02 → logit(±3.9) - head_lr_multiplier: float = 10.0 - head_weight_init_std: float = 0.05 - # Initialize head bias toward this target probability to avoid 0.5 plateau - head_initial_bias_target: float = 0.3 - - # Reward head architecture - FIXED: Simpler architecture to prevent flat basins - head_hidden_dim: int = 1024 # Hidden dimension for reward head - head_num_layers: int = 2 # REDUCED: 2 layers instead of 4 to prevent over-regularization - head_dropout: float = 0.05 # REDUCED: Less dropout to prevent conservatism - # Normalization presets normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { @@ -90,10 +82,6 @@ class RLearNConfig(PreTrainedConfig): } ) - # Architecture - num_register_tokens: int = 4 - mlp_predictor_depth: int = 3 - # Required path to episodes.jsonl for episode boundaries episodes_jsonl_path: str | None = "meta/episodes.jsonl" diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index eacbcc2f2..5e5c583c5 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -43,8 +43,7 @@ class RLearNPolicy(PreTrainedPolicy): - Visual encoder: frozen DinoV3 encoder, returns per-frame embeddings. - Text encoder: frozen SigLIP2, returns a language embedding. - - Temporal module: x_transformers Decoder with packed tokens [lang | register | video]. - - Output: per-timestep rewards via simple linear regression head. + """ config_class = RLearNConfig @@ -90,71 +89,43 @@ class RLearNPolicy(PreTrainedPolicy): self.vision_model.eval() self.text_model.eval() - # x_transformers Decoder (matching ReWiND exactly) - self.decoder = Decoder( - dim=config.dim_model, - depth=config.n_layers, - heads=config.n_heads, - attn_dim_head=64, # ReWiND default - ff_mult=config.dim_feedforward // config.dim_model, # Convert to multiplier - # Note: x_transformers uses attn_dropout and ff_dropout separately - attn_dropout=config.dropout, - ff_dropout=config.dropout, - ) - # Linear projections to the shared temporal model dimension self.to_lang_tokens = nn.Linear(self.text_hidden, config.dim_model) self.to_video_tokens = nn.Linear(self.vision_hidden, config.dim_model) - # Stronger temporal positional encoding - self.temporal_pos_embedding = nn.Parameter(torch.randn(config.max_seq_len, config.dim_model) * 0.1) - # Spatial (patch) positional encoding for patch tokens - self.max_patch_tokens = getattr(config, 'max_patch_tokens', 256) - self.spatial_pos_embedding = nn.Parameter(torch.randn(self.max_patch_tokens, config.dim_model) * 0.1) - - # Single MLP processes all frames - self.frame_mlp = nn.Linear(config.dim_model, config.dim_model) - - # Register / memory / attention sink tokens - self.num_register_tokens = config.num_register_tokens - self.register_tokens = nn.Parameter(torch.randn(config.num_register_tokens, config.dim_model) * 1e-2) + # First-frame positional embedding (only applied to the first video frame) + self.first_frame_pos = nn.Parameter(torch.zeros(1, 1, config.dim_model)) - # MLP predictor (matching ReWiND's Feedforwards) - from x_mlps_pytorch import Feedforwards - self.mlp_predictor = Feedforwards( + # Cross-modal sequential aggregator – causal transformer over + # [language tokens | video frame tokens] + self.decoder = Decoder( dim=config.dim_model, - dim_out=None, - depth=config.mlp_predictor_depth + depth=config.num_layers, + heads=config.num_heads, + ff_mult=config.ff_mult, + attn_dropout=config.dropout, + ff_dropout=config.dropout, + cross_attend=False, + causal=True, + ) + + # Per-frame predictor head + self.frame_mlp = nn.Sequential( + nn.LayerNorm(config.dim_model), + nn.Linear(config.dim_model, config.dim_model), + nn.GELU(), + nn.Dropout(config.dropout), ) - - # FIXED: Simpler head architecture to prevent constant output pathology - # Remove LayerNorm (causes flat basin), reduce depth, larger init, less dropout - - # Simple 2-layer MLP with larger initialization to encourage exploration self.reward_head = nn.Sequential( - nn.Linear(config.dim_model + 1, config.head_hidden_dim), # +1 for temporal position - nn.ReLU(), - nn.Dropout(0.05), # Reduced dropout to prevent noise-induced conservatism - nn.Linear(config.head_hidden_dim, 1) + nn.Linear(config.dim_model, config.dim_model), + nn.GELU(), + nn.Dropout(config.dropout), + nn.Linear(config.dim_model, 1), ) - - # FIXED: Larger weight initialization + head bias warm-start to escape 0.5 plateau - with torch.no_grad(): - for i, module in enumerate(self.reward_head): - if isinstance(module, nn.Linear): - # Use Xavier/Glorot initialization for better gradient flow - nn.init.xavier_uniform_(module.weight, gain=1.0) - nn.init.zeros_(module.bias) - # Set last layer bias to logit(target0) where target0 is a prior (e.g., 0.3) - target0 = float(getattr(self.config, 'head_initial_bias_target', 0.3)) - target0 = min(max(target0, 1e-3), 1 - 1e-3) - initial_bias = torch.log(torch.tensor(target0) / (1 - torch.tensor(target0))) - last_linear: nn.Linear = self.reward_head[-1] # type: ignore - last_linear.bias.copy_(initial_bias) - - # Simple frame dropout probability - self.frame_dropout_p = config.frame_dropout_p - self.stride = max(1, config.stride) + + # Sampling and regularization knobs + self.stride = max(1, int(config.inference_stride)) + self.frame_dropout_p = float(config.frame_dropout_p) # Auto-load episode_data_index from episodes.jsonl if not provided if self.episode_data_index is None and getattr(config, "episodes_jsonl_path", None): @@ -203,102 +174,6 @@ class RLearNPolicy(PreTrainedPolicy): def select_action(self, batch: dict[str, Tensor]) -> Tensor: # Required by base class raise NotImplementedError("RLearN is a reward model and does not select actions") - @torch.no_grad() - def predict_rewards(self, batch: dict[str, Tensor]) -> Tensor: - """Predict per-timestep rewards for evaluation using ReWiND architecture. - - Args: - batch: Input batch with OBS_IMAGES and optionally OBS_LANGUAGE - - Returns: - Predicted rewards tensor of shape (B, T) - """ - batch = self.normalize_inputs(batch) - - # Extract frames and form (B, T, C, H, W) - frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len) - B, T, C, H, W = frames.shape - - # CRITICAL FIX: Do NOT apply stride during evaluation - # During evaluation, we want to process all frames in the sliding window - # Stride should only be used during training to reduce computational cost - T_eff = T # Use all frames during evaluation - - # Get language commands - commands = batch.get(OBS_LANGUAGE, None) - if commands is None: - commands = [""] * B - elif not isinstance(commands, list): - commands = [str(commands)] * B - - # Forward through ReWiND model (inference mode) - device = next(self.parameters()).device - frames = frames.to(device) - - # Process video frames -> patch tokens per frame - video_patch_embeds = self._encode_video_frames(frames).to(device) # (B, T, P, D_vision) - - # Language embeddings + mask - lang_embeds, mask = self._encode_language_tokens(commands, device) - - # Register tokens - register_tokens = repeat(self.register_tokens, 'n d -> b n d', b=B) - - # Project embeddings - lang_tokens = self.to_lang_tokens(lang_embeds) - video_tokens = self.to_video_tokens(video_patch_embeds) # (B, T, P, D) - # Add temporal + spatial positional encoding (window-relative time + patch index) - Bv, T_video, P_video, Dm = video_tokens.shape - if P_video > self.spatial_pos_embedding.shape[0]: - raise ValueError(f"Number of patch tokens {P_video} exceeds max_patch_tokens {self.spatial_pos_embedding.shape[0]}") - t_pos = self.temporal_pos_embedding[:T_video] # (T, D) - p_pos = self.spatial_pos_embedding[:P_video] # (P, D) - pos = t_pos[:, None, :] + p_pos[None, :, :] # (T, P, D) - video_tokens = video_tokens + pos # broadcast over batch - # Flatten patch dimension for attention - video_tokens = rearrange(video_tokens, 'b t p d -> b (t p) d') - - # Pack all tokens for attention - tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d') - - # Extend mask for register and video tokens - mask = F.pad(mask, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True) - - # Forward through decoder - attended = self.decoder(tokens, mask=mask) - - # Unpack and get video token features - _, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d') # (B, T*P, D) - # Restore (B, T, P, D) and pool patches per frame - attended_video_tokens = rearrange(attended_video_tokens, 'b (t p) d -> b t p d', t=T_video, p=P_video) - frame_tokens = attended_video_tokens.mean(dim=2) # (B, T, D) - frame_tokens = self.frame_mlp(frame_tokens) - - # MLP predictor - video_frame_embeds = self.mlp_predictor(frame_tokens) - - # Get rewards via temporal-aware logit regression head (no pre-normalization) - - # Add temporal position information - B, T_pred = video_frame_embeds.shape[:2] - temporal_pos = torch.linspace(0, 1, T_pred, device=video_frame_embeds.device) - temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B, T_pred, 1) # (B, T, 1) - - # Concatenate embeddings with temporal position - temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1) # (B, T, D+1) - - # Forward through temporal-aware head - raw_logits = self.reward_head(temporal_input).squeeze(-1) # (B, T) - return torch.sigmoid(raw_logits) # Apply sigmoid for final predictions - - def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - # Initial version: no-op; rely on upstream processors if any - return batch - - def normalize_targets(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - # Initial version: no-op - return batch - def _encode_video_frames(self, frames: Tensor) -> Tensor: """Encode video frames through DinoV3 to get per-frame PATCH embeddings. @@ -483,11 +358,6 @@ class RLearNPolicy(PreTrainedPolicy): return patch_features - def _mask_from_lens(self, lens: Tensor) -> Tensor: - """Create mask from sequence lengths.""" - seq = torch.arange(lens.amax().item(), device=lens.device) - return einx.less('n, b -> b n', seq, lens) - def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """Compute ReWiND training loss with on-the-fly progress label generation. @@ -500,9 +370,6 @@ class RLearNPolicy(PreTrainedPolicy): """ import time forward_start = time.perf_counter() - - batch = self.normalize_inputs(batch) - batch = self.normalize_targets(batch) # Always use random anchor window sampling frames, anchor_stats = self._sample_random_anchor_windows(batch) @@ -549,45 +416,29 @@ class RLearNPolicy(PreTrainedPolicy): lang_time = time.perf_counter() - lang_start # Token preparation - # Register tokens - register_tokens = repeat(self.register_tokens, 'n d -> b n d', b=B) - # Project embeddings - lang_tokens = self.to_lang_tokens(lang_embeds) - video_tokens = self.to_video_tokens(video_patch_embeds) # (B, T, P, D) + lang_tokens = self.to_lang_tokens(lang_embeds) # (B, L, D) + # Collapse patches to per-frame tokens then project + video_frame_embeds = video_patch_embeds.mean(dim=2) # (B, T_eff, D_vision) + video_tokens = self.to_video_tokens(video_frame_embeds) # (B, T_eff, D) + # First-frame positional embedding only + video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos - # Add temporal + spatial positional encoding (window-relative only) - Bv, T_video, P_video, Dm = video_tokens.shape - if P_video > self.spatial_pos_embedding.shape[0]: - raise ValueError(f"Number of patch tokens {P_video} exceeds max_patch_tokens {self.spatial_pos_embedding.shape[0]}") - t_pos = self.temporal_pos_embedding[:T_video] # (T, D) - p_pos = self.spatial_pos_embedding[:P_video] # (P, D) - pos = t_pos[:, None, :] + p_pos[None, :, :] # (T, P, D) - video_tokens = video_tokens + pos - # Flatten patches into sequence tokens - video_tokens = rearrange(video_tokens, 'b t p d -> b (t p) d') - - # Pack all tokens for attention [lang | register | video] - tokens, lang_video_packed_shape = pack((lang_tokens, register_tokens, video_tokens), 'b * d') - - # Extend mask for register and video tokens - mask = F.pad(mask, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True) - - # Forward through x_transformers Decoder + # Build attention mask for decoder (True = keep) + # Language mask from tokenizer, rest are fully valid + full_mask = F.pad(mask, (0, video_tokens.shape[1]), value=True) + # Pack and run transformer transformer_start = time.perf_counter() - attended = self.decoder(tokens, mask=mask) - - # Unpack and get video token features - _, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d') # (B, T*P, D) - # Restore (B, T, P, D) and pool patches per frame - attended_video_tokens = rearrange(attended_video_tokens, 'b (t p) d -> b t p d', t=T_video, p=P_video) - frame_tokens = attended_video_tokens.mean(dim=2) # (B, T, D) - frame_tokens = self.frame_mlp(frame_tokens) - - # MLP predictor - video_frame_embeds = self.mlp_predictor(frame_tokens) + tokens_packed, packed_shape = pack((lang_tokens, video_tokens), 'b * d') + attended = self.decoder(tokens_packed, mask=full_mask) + attended_lang, attended_video = unpack(attended, packed_shape, 'b * d') transformer_time = time.perf_counter() - transformer_start + # Per-frame prediction + frame_tokens = self.frame_mlp(attended_video) # (B, T_eff, D) + raw_logits = self.reward_head(frame_tokens).squeeze(-1) # (B, T_eff) + predicted_rewards = torch.sigmoid(raw_logits) + # Generate progress labels on-the-fly (ReWiND approach) # IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window loss_dict: dict[str, float] = {} @@ -609,52 +460,18 @@ class RLearNPolicy(PreTrainedPolicy): ) target = self._calculate_anchor_based_progress(T_eff, anchor_stats) - # During inference, we might not want to compute loss - if not self.training and target is None: - # Return predictions without loss using temporal-aware head - - # Add temporal position information - B_inf, T_inf = video_frame_embeds.shape[:2] - temporal_pos = torch.linspace(0, 1, T_inf, device=video_frame_embeds.device) - temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B_inf, T_inf, 1) - - # Concatenate and forward through temporal-aware head - temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1) - raw_logits = self.reward_head(temporal_input).squeeze(-1) - rewards = torch.sigmoid(raw_logits) - return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()} - - # Calculate loss using logit regression + # Compute main loss (or just return predictions in eval) loss_start = time.perf_counter() + if target is None: + total_loss = raw_logits.mean() * 0.0 + loss = total_loss + else: + target_expanded = target # (B, T_eff) + eps = self.config.logit_eps + target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps)) + loss = F.mse_loss(raw_logits, target_logits) + total_loss = loss - # Get model outputs with temporal-aware head - - # Add temporal position information - temporal_pos = torch.linspace(0, 1, T_eff, device=video_frame_embeds.device) - temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B, T_eff, 1) # (B, T_eff, 1) - - # Concatenate embeddings with temporal position - temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1) # (B, T_eff, D+1) - - # Forward through temporal-aware head - raw_logits = self.reward_head(temporal_input).squeeze(-1) # (B, T_eff) - - # FIXED: More robust logit regression with gradient protection - eps = self.config.logit_eps - target_expanded = target.expand(B, -1)[:, :T_eff] # Expand and trim to T_eff - target_clamped = torch.clamp(target_expanded, eps, 1 - eps) - target_logits = torch.logit(target_clamped) - - # Use Smooth L1 loss instead of MSE for better gradient stability - loss = F.smooth_l1_loss(raw_logits, target_logits, reduction='mean', beta=1.0) - - # Clip gradients specifically for the reward head during backward pass - # This prevents extreme gradients from corrupting AdamW momentum - if self.training: - raw_logits.register_hook(lambda grad: torch.clamp(grad, -5.0, 5.0)) - - # For logging, compute sigmoid predictions - predicted_rewards = torch.sigmoid(raw_logits) # Mismatched video-language pairs loss (only when languages actually differ) L_mismatch = torch.zeros((), device=device) @@ -683,54 +500,34 @@ class RLearNPolicy(PreTrainedPolicy): lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm) # Pack and forward - tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d') - mask_mm = F.pad(mask_mm, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True) + tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, video_tokens), 'b * d') + mask_mm = F.pad(mask_mm, (0, video_tokens.shape[1]), value=True) attended_mm = self.decoder(tokens_mm, mask=mask_mm) - _, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d') + _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d') # Process mismatch frames with single MLP mismatch_tokens = self.frame_mlp(attended_video_mm) # (B, T, D) - mismatch_embeds = self.mlp_predictor(mismatch_tokens) - - # Predict near-zero progress for mismatched pairs with temporal awareness - - # Add temporal position information for mismatch computation - T_mismatch = mismatch_embeds.shape[1] - temporal_pos_mm = torch.linspace(0, 1, T_mismatch, device=mismatch_embeds.device) - temporal_pos_mm = temporal_pos_mm.unsqueeze(0).unsqueeze(-1).expand(B, T_mismatch, 1) - - # Concatenate mismatch embeddings with temporal position - temporal_input_mm = torch.cat([mismatch_embeds, temporal_pos_mm], dim=-1) - - # Forward through temporal-aware head - mismatch_raw_logits = self.reward_head(temporal_input_mm).squeeze(-1) - - # Create mask tensor for loss calculation + mismatch_raw_logits = self.reward_head(mismatch_tokens).squeeze(-1) + mismatch_tensor = torch.tensor(mismatch_mask, device=device, dtype=torch.bool) - if mismatch_tensor.any(): - # Target logit corresponding to sigmoid ≈ 0 eps = self.config.logit_eps - zeros_target_logits = torch.logit(torch.full_like(target_expanded[:, :T_eff], eps)) - - # Only compute loss for samples that are actually mismatched + zeros_target_logits = torch.logit(torch.full_like(mismatch_raw_logits, eps)) mismatch_loss_per_sample = F.mse_loss( mismatch_raw_logits, zeros_target_logits, reduction='none' - ).mean(dim=1) # (B,) - - # Apply mask and average only over true mismatches + ).mean(dim=1) L_mismatch = mismatch_loss_per_sample[mismatch_tensor].mean() # Total loss - total_loss = loss + L_mismatch + total_loss = total_loss + L_mismatch loss_time = time.perf_counter() - loss_start # DEBUG: Clean logit regression monitoring with full array printing if self.training and torch.rand(1).item() < 0.03: with torch.no_grad(): sample_idx = torch.randint(0, B, (1,)).item() - sample_targets = target_expanded[sample_idx, :T_eff].cpu().numpy() - sample_preds = predicted_rewards[sample_idx].cpu().numpy() + sample_targets = target_expanded[sample_idx, :T_eff].cpu().numpy() if target is not None else np.zeros((T_eff,), dtype=np.float32) + sample_preds = predicted_rewards[sample_idx].detach().cpu().numpy() print(f"\n=== LOGIT REGRESSION DEBUG ===") print(f"Target: min={target_expanded.min():.3f}, max={target_expanded.max():.3f}, mean={target_expanded.mean():.3f}") @@ -767,7 +564,7 @@ class RLearNPolicy(PreTrainedPolicy): print(f"Sample {sample_idx}: T_eff={T_eff}, target ∈ [{sample_targets.min():.3f}, {sample_targets.max():.3f}], pred ∈ [{sample_preds.min():.3f}, {sample_preds.max():.3f}]") - print(f"Loss: {loss:.6f}") + print(f"Loss: {total_loss:.6f}") print("=" * 60) total_forward_time = time.perf_counter() - forward_start @@ -775,15 +572,15 @@ class RLearNPolicy(PreTrainedPolicy): # Log individual loss components loss_dict.update({ "loss": float(total_loss.detach().item()), - "loss_main": float(loss.detach().item()), + "loss_main": float(loss.detach().item() if isinstance(loss, torch.Tensor) else 0.0), "loss_mismatch": float(L_mismatch.detach().item()), "t_eff": float(T_eff), "lang_len_mean": float(mask.sum().float().mean().item()), # Use mask to get actual lengths # Target statistics for monitoring - "target_min": float(target.min().item()), - "target_max": float(target.max().item()), - "target_mean": float(target.mean().item()), - "target_std": float(target.std().item()), + "target_min": float(target.min().item()) if target is not None else 0.0, + "target_max": float(target.max().item()) if target is not None else 0.0, + "target_mean": float(target.mean().item()) if target is not None else 0.0, + "target_std": float(target.std().item()) if target is not None else 0.0, # Prediction statistics "pred_mean": float(predicted_rewards.mean().item()), "pred_std": float(predicted_rewards.std().item()), @@ -1122,8 +919,7 @@ class RLearNPolicy(PreTrainedPolicy): return torch.stack(all_progress) # (B, T_eff) - - + def _load_episode_index_from_jsonl(self, path: str) -> dict[str, Tensor]: import json lengths: list[int] = []