refactor(lerobot_train.py): remove rabc specific configuration and replace it with a generic samplerweight class in lerobot_train

This commit is contained in:
Michel Aractingi
2026-01-12 11:38:04 +01:00
parent 1d86c9b7f2
commit eb2e79e22d

View File

@@ -0,0 +1,309 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
RA-BC (Reward-Aligned Behavior Cloning) sample weighting implementation.
This module implements the SampleWeighter protocol for RA-BC training,
which weights training samples based on their task progress as measured
by the SARM reward model.
The weights are computed based on progress deltas:
delta = progress[t + chunk_size] - progress[t]
High-quality samples (positive progress) get higher weights, while
samples with negative progress (going backwards) get zero weight.
See: https://arxiv.org/abs/2509.25358 for the SARM paper.
"""
import logging
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
def resolve_hf_path(path: str | Path) -> Path:
"""Resolve a path that may be a HuggingFace URL (hf://datasets/...) to a local path."""
path_str = str(path)
if path_str.startswith("hf://datasets/"):
parts = path_str.replace("hf://datasets/", "").split("/")
repo_id = "/".join(parts[:2])
filename = "/".join(parts[2:])
return Path(hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset"))
return Path(path)
class RABCWeights:
"""
Load precomputed SARM progress values and compute RA-BC weights during training.
This class implements the SampleWeighter protocol for use with the generic
sample weighting infrastructure in lerobot.
Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
During training, computes:
- progress_delta = progress[t + chunk_size] - progress[t]
- rabc_weight based on the delta (paper Eq. 8-9)
Args:
progress_path: Path to parquet file with precomputed progress values.
Supports HuggingFace URLs (hf://datasets/...).
chunk_size: Number of frames ahead for computing progress delta.
head_mode: Which SARM head to use ("sparse" or "dense").
kappa: Hard threshold for high-quality samples (default: 0.01).
epsilon: Small constant for numerical stability (default: 1e-6).
fallback_weight: Weight to use for frames without valid delta (default: 1.0).
device: Device to return tensors on.
"""
def __init__(
self,
progress_path: str | Path,
chunk_size: int = 50,
head_mode: str = "sparse",
kappa: float = 0.01,
epsilon: float = 1e-6,
fallback_weight: float = 1.0,
device: torch.device | None = None,
):
self.progress_path = resolve_hf_path(progress_path)
self.chunk_size = chunk_size
self.head_mode = head_mode
self.kappa = kappa
self.epsilon = epsilon
self.fallback_weight = fallback_weight
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Determine progress column name
self.progress_column = f"progress_{head_mode}"
# Load progress values
logging.info(f"Loading SARM progress values from {self.progress_path}")
self.df = pd.read_parquet(self.progress_path)
# Check if the requested head mode column exists
if self.progress_column not in self.df.columns:
available = [c for c in self.df.columns if c.startswith("progress")]
raise ValueError(
f"Column '{self.progress_column}' not found. Available progress columns: {available}"
)
logging.info(f"Using progress column: {self.progress_column}")
self.progress_lookup: dict[int, float] = {}
self.episode_lookup: dict[int, int] = {}
for _, row in self.df.iterrows():
global_idx = int(row["index"])
progress = row[self.progress_column]
episode_idx = int(row["episode_index"])
if not np.isnan(progress):
self.progress_lookup[global_idx] = float(progress)
self.episode_lookup[global_idx] = episode_idx
# Build episode boundaries for delta computation
self.episode_boundaries: dict[int, dict[str, int]] = {}
for episode_idx in self.df["episode_index"].unique():
ep_df = self.df[self.df["episode_index"] == episode_idx]
self.episode_boundaries[int(episode_idx)] = {
"start": int(ep_df["index"].min()),
"end": int(ep_df["index"].max()) + 1,
}
logging.info(f"Loaded {len(self.progress_lookup)} frame progress values")
logging.info(f"Chunk size for delta computation: {chunk_size}")
# Compute global statistics for weight computation
self._compute_global_stats()
def _compute_global_stats(self) -> None:
"""Compute global mean and std of progress deltas for weight calculation."""
all_deltas = []
for global_idx, progress in self.progress_lookup.items():
episode_idx = self.episode_lookup.get(global_idx)
if episode_idx is None:
continue
bounds = self.episode_boundaries.get(episode_idx)
if bounds is None:
continue
future_idx = global_idx + self.chunk_size
if future_idx >= bounds["end"]:
# Near end of episode: use last frame's progress
future_idx = bounds["end"] - 1
future_progress = self.progress_lookup.get(future_idx)
if future_progress is not None:
delta = future_progress - progress
all_deltas.append(delta)
if all_deltas:
self.delta_mean = max(float(np.mean(all_deltas)), 0.0)
self.delta_std = max(float(np.std(all_deltas)), self.epsilon)
logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}")
else:
self.delta_mean = 0.0
self.delta_std = self.epsilon
logging.warning("No valid progress deltas found, using default stats")
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
"""
Compute RA-BC weights for a batch.
For each sample:
1. Get progress at current frame
2. Get progress at frame + chunk_size (within same episode)
3. Compute delta = future_progress - current_progress
4. Compute weight using paper Eq. 8-9
Args:
batch: Training batch containing "index" key with global frame indices.
Returns:
Tuple of:
- Weights tensor (batch_size,) normalized to sum to batch_size.
- Stats dict with weighting statistics for logging.
"""
indices = batch.get("index")
if indices is None:
logging.warning("RA-BC: Batch missing 'index' key, using uniform weights")
batch_size = self._get_batch_size(batch)
stats = {"mean_weight": 1.0, "num_zero_weight": 0, "num_full_weight": batch_size}
return torch.ones(batch_size, device=self.device), stats
# Convert to list of ints
if isinstance(indices, torch.Tensor):
indices = indices.cpu().numpy().tolist()
elif isinstance(indices, np.ndarray):
indices = indices.tolist()
# Compute deltas and weights for each sample
deltas = []
for idx in indices:
idx = int(idx)
delta = self._compute_delta(idx)
deltas.append(delta)
deltas_array = np.array(deltas, dtype=np.float32)
# Compute weights from deltas
weights = self._compute_weights(deltas_array)
# Compute stats before normalization for logging
raw_mean_weight = float(np.nanmean(weights))
num_zero_weight = int(np.sum(weights == 0))
num_full_weight = int(np.sum(weights == 1.0))
batch_stats = {
"mean_weight": raw_mean_weight,
"num_zero_weight": num_zero_weight,
"num_full_weight": num_full_weight,
}
weights_tensor = torch.tensor(weights, device=self.device, dtype=torch.float32)
# Normalize to sum to batch_size
batch_size = len(weights_tensor)
weight_sum = weights_tensor.sum() + self.epsilon
weights_tensor = weights_tensor * batch_size / weight_sum
return weights_tensor, batch_stats
def _compute_delta(self, global_idx: int) -> float:
"""Compute progress delta for a single frame."""
current_progress = self.progress_lookup.get(global_idx)
if current_progress is None:
return float("nan")
episode_idx = self.episode_lookup.get(global_idx)
if episode_idx is None:
return float("nan")
bounds = self.episode_boundaries.get(episode_idx)
if bounds is None:
return float("nan")
future_idx = global_idx + self.chunk_size # Δ = chunk_size
if future_idx >= bounds["end"]:
# Near end of episode: use last frame's progress instead
future_idx = bounds["end"] - 1
future_progress = self.progress_lookup.get(future_idx)
if future_progress is None:
return float("nan")
return future_progress - current_progress
def _compute_weights(self, deltas: np.ndarray) -> np.ndarray:
"""
Compute RA-BC weights from progress deltas.
Following paper Eq. 8-9:
- Soft weight: ˜wi = clip((ri 2σ)) / (4σ + ε), 0, 1)
- Final weight: wi = 1{ri > κ} + 1{0 ≤ ri ≤ κ}˜wi
Returns:
Array of weights.
"""
valid_mask = ~np.isnan(deltas)
# Compute soft weights using global statistics
lower_bound = self.delta_mean - 2 * self.delta_std
soft_weights = (deltas - lower_bound) / (4 * self.delta_std + self.epsilon)
soft_weights = np.clip(soft_weights, 0.0, 1.0)
# Apply paper's Eq. 9
weights = np.zeros_like(deltas, dtype=np.float32)
# High quality: ri > kappa → weight = 1
high_quality_mask = deltas > self.kappa
weights[high_quality_mask] = 1.0
# Moderate quality: 0 <= ri <= kappa → weight = soft_weight
moderate_mask = (deltas >= 0) & (deltas <= self.kappa)
weights[moderate_mask] = soft_weights[moderate_mask]
# Negative progress: ri < 0 → weight = 0 (already 0)
# Invalid (NaN): use fallback weight
weights[~valid_mask] = self.fallback_weight
return weights
def _get_batch_size(self, batch: dict) -> int:
"""Determine batch size from batch."""
for key in ["action", "index"]:
if key in batch:
val = batch[key]
if isinstance(val, (torch.Tensor, np.ndarray)):
return int(val.shape[0])
return 1
def get_stats(self) -> dict:
"""Get global statistics about the RA-BC weighting."""
return {
"type": "rabc",
"num_frames": len(self.progress_lookup),
"chunk_size": self.chunk_size,
"head_mode": self.head_mode,
"delta_mean": self.delta_mean,
"delta_std": self.delta_std,
"kappa": self.kappa,
}