mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
refactor(lerobot_train.py): remove rabc specific configuration and replace it with a generic samplerweight class in lerobot_train
This commit is contained in:
309
src/lerobot/policies/sarm/rabc.py
Normal file
309
src/lerobot/policies/sarm/rabc.py
Normal file
@@ -0,0 +1,309 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
RA-BC (Reward-Aligned Behavior Cloning) sample weighting implementation.
|
||||
|
||||
This module implements the SampleWeighter protocol for RA-BC training,
|
||||
which weights training samples based on their task progress as measured
|
||||
by the SARM reward model.
|
||||
|
||||
The weights are computed based on progress deltas:
|
||||
delta = progress[t + chunk_size] - progress[t]
|
||||
|
||||
High-quality samples (positive progress) get higher weights, while
|
||||
samples with negative progress (going backwards) get zero weight.
|
||||
|
||||
See: https://arxiv.org/abs/2509.25358 for the SARM paper.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
|
||||
def resolve_hf_path(path: str | Path) -> Path:
|
||||
"""Resolve a path that may be a HuggingFace URL (hf://datasets/...) to a local path."""
|
||||
path_str = str(path)
|
||||
if path_str.startswith("hf://datasets/"):
|
||||
parts = path_str.replace("hf://datasets/", "").split("/")
|
||||
repo_id = "/".join(parts[:2])
|
||||
filename = "/".join(parts[2:])
|
||||
return Path(hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset"))
|
||||
return Path(path)
|
||||
|
||||
|
||||
class RABCWeights:
|
||||
"""
|
||||
Load precomputed SARM progress values and compute RA-BC weights during training.
|
||||
|
||||
This class implements the SampleWeighter protocol for use with the generic
|
||||
sample weighting infrastructure in lerobot.
|
||||
|
||||
Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
|
||||
During training, computes:
|
||||
- progress_delta = progress[t + chunk_size] - progress[t]
|
||||
- rabc_weight based on the delta (paper Eq. 8-9)
|
||||
|
||||
Args:
|
||||
progress_path: Path to parquet file with precomputed progress values.
|
||||
Supports HuggingFace URLs (hf://datasets/...).
|
||||
chunk_size: Number of frames ahead for computing progress delta.
|
||||
head_mode: Which SARM head to use ("sparse" or "dense").
|
||||
kappa: Hard threshold for high-quality samples (default: 0.01).
|
||||
epsilon: Small constant for numerical stability (default: 1e-6).
|
||||
fallback_weight: Weight to use for frames without valid delta (default: 1.0).
|
||||
device: Device to return tensors on.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
progress_path: str | Path,
|
||||
chunk_size: int = 50,
|
||||
head_mode: str = "sparse",
|
||||
kappa: float = 0.01,
|
||||
epsilon: float = 1e-6,
|
||||
fallback_weight: float = 1.0,
|
||||
device: torch.device | None = None,
|
||||
):
|
||||
self.progress_path = resolve_hf_path(progress_path)
|
||||
self.chunk_size = chunk_size
|
||||
self.head_mode = head_mode
|
||||
self.kappa = kappa
|
||||
self.epsilon = epsilon
|
||||
self.fallback_weight = fallback_weight
|
||||
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Determine progress column name
|
||||
self.progress_column = f"progress_{head_mode}"
|
||||
|
||||
# Load progress values
|
||||
logging.info(f"Loading SARM progress values from {self.progress_path}")
|
||||
self.df = pd.read_parquet(self.progress_path)
|
||||
|
||||
# Check if the requested head mode column exists
|
||||
if self.progress_column not in self.df.columns:
|
||||
available = [c for c in self.df.columns if c.startswith("progress")]
|
||||
raise ValueError(
|
||||
f"Column '{self.progress_column}' not found. Available progress columns: {available}"
|
||||
)
|
||||
|
||||
logging.info(f"Using progress column: {self.progress_column}")
|
||||
|
||||
self.progress_lookup: dict[int, float] = {}
|
||||
self.episode_lookup: dict[int, int] = {}
|
||||
|
||||
for _, row in self.df.iterrows():
|
||||
global_idx = int(row["index"])
|
||||
progress = row[self.progress_column]
|
||||
episode_idx = int(row["episode_index"])
|
||||
|
||||
if not np.isnan(progress):
|
||||
self.progress_lookup[global_idx] = float(progress)
|
||||
self.episode_lookup[global_idx] = episode_idx
|
||||
|
||||
# Build episode boundaries for delta computation
|
||||
self.episode_boundaries: dict[int, dict[str, int]] = {}
|
||||
for episode_idx in self.df["episode_index"].unique():
|
||||
ep_df = self.df[self.df["episode_index"] == episode_idx]
|
||||
self.episode_boundaries[int(episode_idx)] = {
|
||||
"start": int(ep_df["index"].min()),
|
||||
"end": int(ep_df["index"].max()) + 1,
|
||||
}
|
||||
|
||||
logging.info(f"Loaded {len(self.progress_lookup)} frame progress values")
|
||||
logging.info(f"Chunk size for delta computation: {chunk_size}")
|
||||
|
||||
# Compute global statistics for weight computation
|
||||
self._compute_global_stats()
|
||||
|
||||
def _compute_global_stats(self) -> None:
|
||||
"""Compute global mean and std of progress deltas for weight calculation."""
|
||||
all_deltas = []
|
||||
|
||||
for global_idx, progress in self.progress_lookup.items():
|
||||
episode_idx = self.episode_lookup.get(global_idx)
|
||||
if episode_idx is None:
|
||||
continue
|
||||
|
||||
bounds = self.episode_boundaries.get(episode_idx)
|
||||
if bounds is None:
|
||||
continue
|
||||
|
||||
future_idx = global_idx + self.chunk_size
|
||||
if future_idx >= bounds["end"]:
|
||||
# Near end of episode: use last frame's progress
|
||||
future_idx = bounds["end"] - 1
|
||||
|
||||
future_progress = self.progress_lookup.get(future_idx)
|
||||
if future_progress is not None:
|
||||
delta = future_progress - progress
|
||||
all_deltas.append(delta)
|
||||
|
||||
if all_deltas:
|
||||
self.delta_mean = max(float(np.mean(all_deltas)), 0.0)
|
||||
self.delta_std = max(float(np.std(all_deltas)), self.epsilon)
|
||||
logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}")
|
||||
else:
|
||||
self.delta_mean = 0.0
|
||||
self.delta_std = self.epsilon
|
||||
logging.warning("No valid progress deltas found, using default stats")
|
||||
|
||||
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
|
||||
"""
|
||||
Compute RA-BC weights for a batch.
|
||||
|
||||
For each sample:
|
||||
1. Get progress at current frame
|
||||
2. Get progress at frame + chunk_size (within same episode)
|
||||
3. Compute delta = future_progress - current_progress
|
||||
4. Compute weight using paper Eq. 8-9
|
||||
|
||||
Args:
|
||||
batch: Training batch containing "index" key with global frame indices.
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- Weights tensor (batch_size,) normalized to sum to batch_size.
|
||||
- Stats dict with weighting statistics for logging.
|
||||
"""
|
||||
indices = batch.get("index")
|
||||
if indices is None:
|
||||
logging.warning("RA-BC: Batch missing 'index' key, using uniform weights")
|
||||
batch_size = self._get_batch_size(batch)
|
||||
stats = {"mean_weight": 1.0, "num_zero_weight": 0, "num_full_weight": batch_size}
|
||||
return torch.ones(batch_size, device=self.device), stats
|
||||
|
||||
# Convert to list of ints
|
||||
if isinstance(indices, torch.Tensor):
|
||||
indices = indices.cpu().numpy().tolist()
|
||||
elif isinstance(indices, np.ndarray):
|
||||
indices = indices.tolist()
|
||||
|
||||
# Compute deltas and weights for each sample
|
||||
deltas = []
|
||||
for idx in indices:
|
||||
idx = int(idx)
|
||||
delta = self._compute_delta(idx)
|
||||
deltas.append(delta)
|
||||
|
||||
deltas_array = np.array(deltas, dtype=np.float32)
|
||||
|
||||
# Compute weights from deltas
|
||||
weights = self._compute_weights(deltas_array)
|
||||
|
||||
# Compute stats before normalization for logging
|
||||
raw_mean_weight = float(np.nanmean(weights))
|
||||
num_zero_weight = int(np.sum(weights == 0))
|
||||
num_full_weight = int(np.sum(weights == 1.0))
|
||||
batch_stats = {
|
||||
"mean_weight": raw_mean_weight,
|
||||
"num_zero_weight": num_zero_weight,
|
||||
"num_full_weight": num_full_weight,
|
||||
}
|
||||
|
||||
weights_tensor = torch.tensor(weights, device=self.device, dtype=torch.float32)
|
||||
|
||||
# Normalize to sum to batch_size
|
||||
batch_size = len(weights_tensor)
|
||||
weight_sum = weights_tensor.sum() + self.epsilon
|
||||
weights_tensor = weights_tensor * batch_size / weight_sum
|
||||
|
||||
return weights_tensor, batch_stats
|
||||
|
||||
def _compute_delta(self, global_idx: int) -> float:
|
||||
"""Compute progress delta for a single frame."""
|
||||
current_progress = self.progress_lookup.get(global_idx)
|
||||
if current_progress is None:
|
||||
return float("nan")
|
||||
|
||||
episode_idx = self.episode_lookup.get(global_idx)
|
||||
if episode_idx is None:
|
||||
return float("nan")
|
||||
|
||||
bounds = self.episode_boundaries.get(episode_idx)
|
||||
if bounds is None:
|
||||
return float("nan")
|
||||
|
||||
future_idx = global_idx + self.chunk_size # Δ = chunk_size
|
||||
if future_idx >= bounds["end"]:
|
||||
# Near end of episode: use last frame's progress instead
|
||||
future_idx = bounds["end"] - 1
|
||||
|
||||
future_progress = self.progress_lookup.get(future_idx)
|
||||
if future_progress is None:
|
||||
return float("nan")
|
||||
|
||||
return future_progress - current_progress
|
||||
|
||||
def _compute_weights(self, deltas: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Compute RA-BC weights from progress deltas.
|
||||
|
||||
Following paper Eq. 8-9:
|
||||
- Soft weight: ˜wi = clip((ri − (µ − 2σ)) / (4σ + ε), 0, 1)
|
||||
- Final weight: wi = 1{ri > κ} + 1{0 ≤ ri ≤ κ}˜wi
|
||||
|
||||
Returns:
|
||||
Array of weights.
|
||||
"""
|
||||
valid_mask = ~np.isnan(deltas)
|
||||
|
||||
# Compute soft weights using global statistics
|
||||
lower_bound = self.delta_mean - 2 * self.delta_std
|
||||
soft_weights = (deltas - lower_bound) / (4 * self.delta_std + self.epsilon)
|
||||
soft_weights = np.clip(soft_weights, 0.0, 1.0)
|
||||
|
||||
# Apply paper's Eq. 9
|
||||
weights = np.zeros_like(deltas, dtype=np.float32)
|
||||
|
||||
# High quality: ri > kappa → weight = 1
|
||||
high_quality_mask = deltas > self.kappa
|
||||
weights[high_quality_mask] = 1.0
|
||||
|
||||
# Moderate quality: 0 <= ri <= kappa → weight = soft_weight
|
||||
moderate_mask = (deltas >= 0) & (deltas <= self.kappa)
|
||||
weights[moderate_mask] = soft_weights[moderate_mask]
|
||||
|
||||
# Negative progress: ri < 0 → weight = 0 (already 0)
|
||||
# Invalid (NaN): use fallback weight
|
||||
weights[~valid_mask] = self.fallback_weight
|
||||
|
||||
return weights
|
||||
|
||||
def _get_batch_size(self, batch: dict) -> int:
|
||||
"""Determine batch size from batch."""
|
||||
for key in ["action", "index"]:
|
||||
if key in batch:
|
||||
val = batch[key]
|
||||
if isinstance(val, (torch.Tensor, np.ndarray)):
|
||||
return int(val.shape[0])
|
||||
return 1
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get global statistics about the RA-BC weighting."""
|
||||
return {
|
||||
"type": "rabc",
|
||||
"num_frames": len(self.progress_lookup),
|
||||
"chunk_size": self.chunk_size,
|
||||
"head_mode": self.head_mode,
|
||||
"delta_mean": self.delta_mean,
|
||||
"delta_std": self.delta_std,
|
||||
"kappa": self.kappa,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user