From eb2e79e22db3b799047e35db5198cf7e0e911fff Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 12 Jan 2026 11:38:04 +0100 Subject: [PATCH] refactor(lerobot_train.py): remove rabc specific configuration and replace it with a generic samplerweight class in lerobot_train --- src/lerobot/policies/sarm/rabc.py | 309 ++++++++++++++++++++++++++++++ 1 file changed, 309 insertions(+) create mode 100644 src/lerobot/policies/sarm/rabc.py diff --git a/src/lerobot/policies/sarm/rabc.py b/src/lerobot/policies/sarm/rabc.py new file mode 100644 index 000000000..42b11c70d --- /dev/null +++ b/src/lerobot/policies/sarm/rabc.py @@ -0,0 +1,309 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +RA-BC (Reward-Aligned Behavior Cloning) sample weighting implementation. + +This module implements the SampleWeighter protocol for RA-BC training, +which weights training samples based on their task progress as measured +by the SARM reward model. + +The weights are computed based on progress deltas: + delta = progress[t + chunk_size] - progress[t] + +High-quality samples (positive progress) get higher weights, while +samples with negative progress (going backwards) get zero weight. + +See: https://arxiv.org/abs/2509.25358 for the SARM paper. +""" + +import logging +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from huggingface_hub import hf_hub_download + + +def resolve_hf_path(path: str | Path) -> Path: + """Resolve a path that may be a HuggingFace URL (hf://datasets/...) to a local path.""" + path_str = str(path) + if path_str.startswith("hf://datasets/"): + parts = path_str.replace("hf://datasets/", "").split("/") + repo_id = "/".join(parts[:2]) + filename = "/".join(parts[2:]) + return Path(hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")) + return Path(path) + + +class RABCWeights: + """ + Load precomputed SARM progress values and compute RA-BC weights during training. + + This class implements the SampleWeighter protocol for use with the generic + sample weighting infrastructure in lerobot. + + Progress values are loaded from a parquet file (generated by compute_rabc_weights.py). + During training, computes: + - progress_delta = progress[t + chunk_size] - progress[t] + - rabc_weight based on the delta (paper Eq. 8-9) + + Args: + progress_path: Path to parquet file with precomputed progress values. + Supports HuggingFace URLs (hf://datasets/...). + chunk_size: Number of frames ahead for computing progress delta. + head_mode: Which SARM head to use ("sparse" or "dense"). + kappa: Hard threshold for high-quality samples (default: 0.01). + epsilon: Small constant for numerical stability (default: 1e-6). + fallback_weight: Weight to use for frames without valid delta (default: 1.0). + device: Device to return tensors on. + """ + + def __init__( + self, + progress_path: str | Path, + chunk_size: int = 50, + head_mode: str = "sparse", + kappa: float = 0.01, + epsilon: float = 1e-6, + fallback_weight: float = 1.0, + device: torch.device | None = None, + ): + self.progress_path = resolve_hf_path(progress_path) + self.chunk_size = chunk_size + self.head_mode = head_mode + self.kappa = kappa + self.epsilon = epsilon + self.fallback_weight = fallback_weight + self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Determine progress column name + self.progress_column = f"progress_{head_mode}" + + # Load progress values + logging.info(f"Loading SARM progress values from {self.progress_path}") + self.df = pd.read_parquet(self.progress_path) + + # Check if the requested head mode column exists + if self.progress_column not in self.df.columns: + available = [c for c in self.df.columns if c.startswith("progress")] + raise ValueError( + f"Column '{self.progress_column}' not found. Available progress columns: {available}" + ) + + logging.info(f"Using progress column: {self.progress_column}") + + self.progress_lookup: dict[int, float] = {} + self.episode_lookup: dict[int, int] = {} + + for _, row in self.df.iterrows(): + global_idx = int(row["index"]) + progress = row[self.progress_column] + episode_idx = int(row["episode_index"]) + + if not np.isnan(progress): + self.progress_lookup[global_idx] = float(progress) + self.episode_lookup[global_idx] = episode_idx + + # Build episode boundaries for delta computation + self.episode_boundaries: dict[int, dict[str, int]] = {} + for episode_idx in self.df["episode_index"].unique(): + ep_df = self.df[self.df["episode_index"] == episode_idx] + self.episode_boundaries[int(episode_idx)] = { + "start": int(ep_df["index"].min()), + "end": int(ep_df["index"].max()) + 1, + } + + logging.info(f"Loaded {len(self.progress_lookup)} frame progress values") + logging.info(f"Chunk size for delta computation: {chunk_size}") + + # Compute global statistics for weight computation + self._compute_global_stats() + + def _compute_global_stats(self) -> None: + """Compute global mean and std of progress deltas for weight calculation.""" + all_deltas = [] + + for global_idx, progress in self.progress_lookup.items(): + episode_idx = self.episode_lookup.get(global_idx) + if episode_idx is None: + continue + + bounds = self.episode_boundaries.get(episode_idx) + if bounds is None: + continue + + future_idx = global_idx + self.chunk_size + if future_idx >= bounds["end"]: + # Near end of episode: use last frame's progress + future_idx = bounds["end"] - 1 + + future_progress = self.progress_lookup.get(future_idx) + if future_progress is not None: + delta = future_progress - progress + all_deltas.append(delta) + + if all_deltas: + self.delta_mean = max(float(np.mean(all_deltas)), 0.0) + self.delta_std = max(float(np.std(all_deltas)), self.epsilon) + logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}") + else: + self.delta_mean = 0.0 + self.delta_std = self.epsilon + logging.warning("No valid progress deltas found, using default stats") + + def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]: + """ + Compute RA-BC weights for a batch. + + For each sample: + 1. Get progress at current frame + 2. Get progress at frame + chunk_size (within same episode) + 3. Compute delta = future_progress - current_progress + 4. Compute weight using paper Eq. 8-9 + + Args: + batch: Training batch containing "index" key with global frame indices. + + Returns: + Tuple of: + - Weights tensor (batch_size,) normalized to sum to batch_size. + - Stats dict with weighting statistics for logging. + """ + indices = batch.get("index") + if indices is None: + logging.warning("RA-BC: Batch missing 'index' key, using uniform weights") + batch_size = self._get_batch_size(batch) + stats = {"mean_weight": 1.0, "num_zero_weight": 0, "num_full_weight": batch_size} + return torch.ones(batch_size, device=self.device), stats + + # Convert to list of ints + if isinstance(indices, torch.Tensor): + indices = indices.cpu().numpy().tolist() + elif isinstance(indices, np.ndarray): + indices = indices.tolist() + + # Compute deltas and weights for each sample + deltas = [] + for idx in indices: + idx = int(idx) + delta = self._compute_delta(idx) + deltas.append(delta) + + deltas_array = np.array(deltas, dtype=np.float32) + + # Compute weights from deltas + weights = self._compute_weights(deltas_array) + + # Compute stats before normalization for logging + raw_mean_weight = float(np.nanmean(weights)) + num_zero_weight = int(np.sum(weights == 0)) + num_full_weight = int(np.sum(weights == 1.0)) + batch_stats = { + "mean_weight": raw_mean_weight, + "num_zero_weight": num_zero_weight, + "num_full_weight": num_full_weight, + } + + weights_tensor = torch.tensor(weights, device=self.device, dtype=torch.float32) + + # Normalize to sum to batch_size + batch_size = len(weights_tensor) + weight_sum = weights_tensor.sum() + self.epsilon + weights_tensor = weights_tensor * batch_size / weight_sum + + return weights_tensor, batch_stats + + def _compute_delta(self, global_idx: int) -> float: + """Compute progress delta for a single frame.""" + current_progress = self.progress_lookup.get(global_idx) + if current_progress is None: + return float("nan") + + episode_idx = self.episode_lookup.get(global_idx) + if episode_idx is None: + return float("nan") + + bounds = self.episode_boundaries.get(episode_idx) + if bounds is None: + return float("nan") + + future_idx = global_idx + self.chunk_size # Δ = chunk_size + if future_idx >= bounds["end"]: + # Near end of episode: use last frame's progress instead + future_idx = bounds["end"] - 1 + + future_progress = self.progress_lookup.get(future_idx) + if future_progress is None: + return float("nan") + + return future_progress - current_progress + + def _compute_weights(self, deltas: np.ndarray) -> np.ndarray: + """ + Compute RA-BC weights from progress deltas. + + Following paper Eq. 8-9: + - Soft weight: ˜wi = clip((ri − (µ − 2σ)) / (4σ + ε), 0, 1) + - Final weight: wi = 1{ri > κ} + 1{0 ≤ ri ≤ κ}˜wi + + Returns: + Array of weights. + """ + valid_mask = ~np.isnan(deltas) + + # Compute soft weights using global statistics + lower_bound = self.delta_mean - 2 * self.delta_std + soft_weights = (deltas - lower_bound) / (4 * self.delta_std + self.epsilon) + soft_weights = np.clip(soft_weights, 0.0, 1.0) + + # Apply paper's Eq. 9 + weights = np.zeros_like(deltas, dtype=np.float32) + + # High quality: ri > kappa → weight = 1 + high_quality_mask = deltas > self.kappa + weights[high_quality_mask] = 1.0 + + # Moderate quality: 0 <= ri <= kappa → weight = soft_weight + moderate_mask = (deltas >= 0) & (deltas <= self.kappa) + weights[moderate_mask] = soft_weights[moderate_mask] + + # Negative progress: ri < 0 → weight = 0 (already 0) + # Invalid (NaN): use fallback weight + weights[~valid_mask] = self.fallback_weight + + return weights + + def _get_batch_size(self, batch: dict) -> int: + """Determine batch size from batch.""" + for key in ["action", "index"]: + if key in batch: + val = batch[key] + if isinstance(val, (torch.Tensor, np.ndarray)): + return int(val.shape[0]) + return 1 + + def get_stats(self) -> dict: + """Get global statistics about the RA-BC weighting.""" + return { + "type": "rabc", + "num_frames": len(self.progress_lookup), + "chunk_size": self.chunk_size, + "head_mode": self.head_mode, + "delta_mean": self.delta_mean, + "delta_std": self.delta_std, + "kappa": self.kappa, + } +