Files
lerobot-clone/src/lerobot/utils/rabc.py
Pepijn 60efd875fa resolve path correctly (#2710)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2025-12-26 23:57:17 +01:00

289 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
def resolve_hf_path(path: str | Path) -> Path:
"""Resolve a path that may be a HuggingFace URL (hf://datasets/...) to a local path."""
path_str = str(path)
if path_str.startswith("hf://datasets/"):
parts = path_str.replace("hf://datasets/", "").split("/")
repo_id = "/".join(parts[:2])
filename = "/".join(parts[2:])
return Path(hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset"))
return Path(path)
class RABCWeights:
"""
Load precomputed SARM progress values and compute RA-BC weights during training.
Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
During training, computes:
- progress_delta = progress[t + chunk_size] - progress[t]
- rabc_weight based on the delta (paper Eq. 8-9)
Args:
progress_path: Path to parquet file with precomputed progress values
chunk_size: Number of frames ahead for computing progress delta
head_mode: Which SARM head to use ("sparse" or "dense")
kappa: Hard threshold for high-quality samples (default: 0.01)
epsilon: Small constant for numerical stability (default: 1e-6)
fallback_weight: Weight to use for frames without valid delta (default: 1.0)
device: Device to return tensors on
"""
def __init__(
self,
progress_path: str | Path,
chunk_size: int = 50,
head_mode: str = "sparse",
kappa: float = 0.01,
epsilon: float = 1e-6,
fallback_weight: float = 1.0,
device: torch.device = None,
):
self.progress_path = resolve_hf_path(progress_path)
self.chunk_size = chunk_size
self.head_mode = head_mode
self.kappa = kappa
self.epsilon = epsilon
self.fallback_weight = fallback_weight
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Determine progress column name
self.progress_column = f"progress_{head_mode}"
# Load progress values
logging.info(f"Loading SARM progress values from {self.progress_path}")
self.df = pd.read_parquet(self.progress_path)
# Check if the requested head mode column exists
if self.progress_column not in self.df.columns:
available = [c for c in self.df.columns if c.startswith("progress")]
raise ValueError(
f"Column '{self.progress_column}' not found. Available progress columns: {available}"
)
logging.info(f"Using progress column: {self.progress_column}")
self.progress_lookup = {}
self.episode_lookup = {}
for _, row in self.df.iterrows():
global_idx = int(row["index"])
progress = row[self.progress_column]
episode_idx = int(row["episode_index"])
if not np.isnan(progress):
self.progress_lookup[global_idx] = float(progress)
self.episode_lookup[global_idx] = episode_idx
# Build episode boundaries for delta computation
self.episode_boundaries = {}
for episode_idx in self.df["episode_index"].unique():
ep_df = self.df[self.df["episode_index"] == episode_idx]
self.episode_boundaries[int(episode_idx)] = {
"start": int(ep_df["index"].min()),
"end": int(ep_df["index"].max()) + 1,
}
logging.info(f"Loaded {len(self.progress_lookup)} frame progress values")
logging.info(f"Chunk size for delta computation: {chunk_size}")
# Compute global statistics for weight computation
self._compute_global_stats()
def _compute_global_stats(self):
"""Compute global mean and std of progress deltas for weight calculation."""
all_deltas = []
for global_idx, progress in self.progress_lookup.items():
episode_idx = self.episode_lookup.get(global_idx)
if episode_idx is None:
continue
bounds = self.episode_boundaries.get(episode_idx)
if bounds is None:
continue
future_idx = global_idx + self.chunk_size
if future_idx >= bounds["end"]:
# Near end of episode: use last frame's progress
future_idx = bounds["end"] - 1
future_progress = self.progress_lookup.get(future_idx)
if future_progress is not None:
delta = future_progress - progress
all_deltas.append(delta)
if all_deltas:
self.delta_mean = max(np.mean(all_deltas), 0.0)
self.delta_std = max(np.std(all_deltas), self.epsilon)
logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}")
else:
self.delta_mean = 0.0
self.delta_std = self.epsilon
logging.warning("No valid progress deltas found, using default stats")
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
"""
Compute RA-BC weights for a batch.
For each sample:
1. Get progress at current frame
2. Get progress at frame + chunk_size (within same episode)
3. Compute delta = future_progress - current_progress
4. Compute weight using paper Eq. 8-9
Args:
batch: Training batch containing "index" key with global frame indices
Returns:
Tuple of:
- Weights tensor (batch_size,) normalized to sum to batch_size
- Stats dict with raw_mean_weight, num_zero_weight, num_full_weight
"""
indices = batch.get("index")
if indices is None:
logging.warning("RA-BC: Batch missing 'index' key, using uniform weights")
batch_size = self._get_batch_size(batch)
return torch.ones(batch_size, device=self.device), {"raw_mean_weight": 1.0}
# Convert to list of ints
if isinstance(indices, torch.Tensor):
indices = indices.cpu().numpy().tolist()
elif isinstance(indices, np.ndarray):
indices = indices.tolist()
# Compute deltas and weights for each sample
deltas = []
for idx in indices:
idx = int(idx)
delta = self._compute_delta(idx)
deltas.append(delta)
deltas = np.array(deltas, dtype=np.float32)
# Compute weights from deltas
weights = self._compute_weights(deltas)
# Compute stats before normalization for logging
raw_mean_weight = float(np.nanmean(weights))
num_zero_weight = int(np.sum(weights == 0))
num_full_weight = int(np.sum(weights == 1.0))
batch_stats = {
"raw_mean_weight": raw_mean_weight,
"num_zero_weight": num_zero_weight,
"num_full_weight": num_full_weight,
}
weights = torch.tensor(weights, device=self.device, dtype=torch.float32)
# Normalize to sum to batch_size
batch_size = len(weights)
weight_sum = weights.sum() + self.epsilon
weights = weights * batch_size / weight_sum
return weights, batch_stats
def _compute_delta(self, global_idx: int) -> float:
"""Compute progress delta for a single frame."""
current_progress = self.progress_lookup.get(global_idx)
if current_progress is None:
return np.nan
episode_idx = self.episode_lookup.get(global_idx)
if episode_idx is None:
return np.nan
bounds = self.episode_boundaries.get(episode_idx)
if bounds is None:
return np.nan
future_idx = global_idx + self.chunk_size # Δ = chunk_size
if future_idx >= bounds["end"]:
# Near end of episode: use last frame's progress instead
future_idx = bounds["end"] - 1
future_progress = self.progress_lookup.get(future_idx)
if future_progress is None:
return np.nan
return future_progress - current_progress
def _compute_weights(self, deltas: np.ndarray) -> np.ndarray:
"""
Compute RA-BC weights from progress deltas.
Following paper Eq. 8-9:
- Soft weight: ˜wi = clip((ri 2σ)) / (4σ + ε), 0, 1)
- Final weight: wi = 1{ri > κ} + 1{0 ≤ ri ≤ κ}˜wi
Returns:
Array of weights
"""
valid_mask = ~np.isnan(deltas)
# Compute soft weights using global statistics
lower_bound = self.delta_mean - 2 * self.delta_std
soft_weights = (deltas - lower_bound) / (4 * self.delta_std + self.epsilon)
soft_weights = np.clip(soft_weights, 0.0, 1.0)
# Apply paper's Eq. 9
weights = np.zeros_like(deltas, dtype=np.float32)
# High quality: ri > kappa → weight = 1
high_quality_mask = deltas > self.kappa
weights[high_quality_mask] = 1.0
# Moderate quality: 0 <= ri <= kappa → weight = soft_weight
moderate_mask = (deltas >= 0) & (deltas <= self.kappa)
weights[moderate_mask] = soft_weights[moderate_mask]
# Negative progress: ri < 0 → weight = 0 (already 0)
# Invalid (NaN): use fallback weight
weights[~valid_mask] = self.fallback_weight
return weights
def _get_batch_size(self, batch: dict) -> int:
"""Determine batch size from batch."""
for key in ["action", "index"]:
if key in batch:
val = batch[key]
if isinstance(val, (torch.Tensor, np.ndarray)):
return val.shape[0]
return 1
def get_stats(self) -> dict:
"""Get statistics."""
return {
"num_frames": len(self.progress_lookup),
"chunk_size": self.chunk_size,
"head_mode": self.head_mode,
"delta_mean": self.delta_mean,
"delta_std": self.delta_std,
"kappa": self.kappa,
}