mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
289 lines
10 KiB
Python
289 lines
10 KiB
Python
#!/usr/bin/env python
|
||
|
||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
import logging
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import torch
|
||
from huggingface_hub import hf_hub_download
|
||
|
||
|
||
def resolve_hf_path(path: str | Path) -> Path:
|
||
"""Resolve a path that may be a HuggingFace URL (hf://datasets/...) to a local path."""
|
||
path_str = str(path)
|
||
if path_str.startswith("hf://datasets/"):
|
||
parts = path_str.replace("hf://datasets/", "").split("/")
|
||
repo_id = "/".join(parts[:2])
|
||
filename = "/".join(parts[2:])
|
||
return Path(hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset"))
|
||
return Path(path)
|
||
|
||
|
||
class RABCWeights:
|
||
"""
|
||
Load precomputed SARM progress values and compute RA-BC weights during training.
|
||
|
||
Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
|
||
During training, computes:
|
||
- progress_delta = progress[t + chunk_size] - progress[t]
|
||
- rabc_weight based on the delta (paper Eq. 8-9)
|
||
|
||
Args:
|
||
progress_path: Path to parquet file with precomputed progress values
|
||
chunk_size: Number of frames ahead for computing progress delta
|
||
head_mode: Which SARM head to use ("sparse" or "dense")
|
||
kappa: Hard threshold for high-quality samples (default: 0.01)
|
||
epsilon: Small constant for numerical stability (default: 1e-6)
|
||
fallback_weight: Weight to use for frames without valid delta (default: 1.0)
|
||
device: Device to return tensors on
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
progress_path: str | Path,
|
||
chunk_size: int = 50,
|
||
head_mode: str = "sparse",
|
||
kappa: float = 0.01,
|
||
epsilon: float = 1e-6,
|
||
fallback_weight: float = 1.0,
|
||
device: torch.device = None,
|
||
):
|
||
self.progress_path = resolve_hf_path(progress_path)
|
||
self.chunk_size = chunk_size
|
||
self.head_mode = head_mode
|
||
self.kappa = kappa
|
||
self.epsilon = epsilon
|
||
self.fallback_weight = fallback_weight
|
||
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
|
||
# Determine progress column name
|
||
self.progress_column = f"progress_{head_mode}"
|
||
|
||
# Load progress values
|
||
logging.info(f"Loading SARM progress values from {self.progress_path}")
|
||
self.df = pd.read_parquet(self.progress_path)
|
||
|
||
# Check if the requested head mode column exists
|
||
if self.progress_column not in self.df.columns:
|
||
available = [c for c in self.df.columns if c.startswith("progress")]
|
||
raise ValueError(
|
||
f"Column '{self.progress_column}' not found. Available progress columns: {available}"
|
||
)
|
||
|
||
logging.info(f"Using progress column: {self.progress_column}")
|
||
|
||
self.progress_lookup = {}
|
||
self.episode_lookup = {}
|
||
|
||
for _, row in self.df.iterrows():
|
||
global_idx = int(row["index"])
|
||
progress = row[self.progress_column]
|
||
episode_idx = int(row["episode_index"])
|
||
|
||
if not np.isnan(progress):
|
||
self.progress_lookup[global_idx] = float(progress)
|
||
self.episode_lookup[global_idx] = episode_idx
|
||
|
||
# Build episode boundaries for delta computation
|
||
self.episode_boundaries = {}
|
||
for episode_idx in self.df["episode_index"].unique():
|
||
ep_df = self.df[self.df["episode_index"] == episode_idx]
|
||
self.episode_boundaries[int(episode_idx)] = {
|
||
"start": int(ep_df["index"].min()),
|
||
"end": int(ep_df["index"].max()) + 1,
|
||
}
|
||
|
||
logging.info(f"Loaded {len(self.progress_lookup)} frame progress values")
|
||
logging.info(f"Chunk size for delta computation: {chunk_size}")
|
||
|
||
# Compute global statistics for weight computation
|
||
self._compute_global_stats()
|
||
|
||
def _compute_global_stats(self):
|
||
"""Compute global mean and std of progress deltas for weight calculation."""
|
||
all_deltas = []
|
||
|
||
for global_idx, progress in self.progress_lookup.items():
|
||
episode_idx = self.episode_lookup.get(global_idx)
|
||
if episode_idx is None:
|
||
continue
|
||
|
||
bounds = self.episode_boundaries.get(episode_idx)
|
||
if bounds is None:
|
||
continue
|
||
|
||
future_idx = global_idx + self.chunk_size
|
||
if future_idx >= bounds["end"]:
|
||
# Near end of episode: use last frame's progress
|
||
future_idx = bounds["end"] - 1
|
||
|
||
future_progress = self.progress_lookup.get(future_idx)
|
||
if future_progress is not None:
|
||
delta = future_progress - progress
|
||
all_deltas.append(delta)
|
||
|
||
if all_deltas:
|
||
self.delta_mean = max(np.mean(all_deltas), 0.0)
|
||
self.delta_std = max(np.std(all_deltas), self.epsilon)
|
||
logging.info(f"Progress delta stats: mean={self.delta_mean:.4f}, std={self.delta_std:.4f}")
|
||
else:
|
||
self.delta_mean = 0.0
|
||
self.delta_std = self.epsilon
|
||
logging.warning("No valid progress deltas found, using default stats")
|
||
|
||
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
|
||
"""
|
||
Compute RA-BC weights for a batch.
|
||
|
||
For each sample:
|
||
1. Get progress at current frame
|
||
2. Get progress at frame + chunk_size (within same episode)
|
||
3. Compute delta = future_progress - current_progress
|
||
4. Compute weight using paper Eq. 8-9
|
||
|
||
Args:
|
||
batch: Training batch containing "index" key with global frame indices
|
||
|
||
Returns:
|
||
Tuple of:
|
||
- Weights tensor (batch_size,) normalized to sum to batch_size
|
||
- Stats dict with raw_mean_weight, num_zero_weight, num_full_weight
|
||
"""
|
||
indices = batch.get("index")
|
||
if indices is None:
|
||
logging.warning("RA-BC: Batch missing 'index' key, using uniform weights")
|
||
batch_size = self._get_batch_size(batch)
|
||
return torch.ones(batch_size, device=self.device), {"raw_mean_weight": 1.0}
|
||
|
||
# Convert to list of ints
|
||
if isinstance(indices, torch.Tensor):
|
||
indices = indices.cpu().numpy().tolist()
|
||
elif isinstance(indices, np.ndarray):
|
||
indices = indices.tolist()
|
||
|
||
# Compute deltas and weights for each sample
|
||
deltas = []
|
||
for idx in indices:
|
||
idx = int(idx)
|
||
delta = self._compute_delta(idx)
|
||
deltas.append(delta)
|
||
|
||
deltas = np.array(deltas, dtype=np.float32)
|
||
|
||
# Compute weights from deltas
|
||
weights = self._compute_weights(deltas)
|
||
|
||
# Compute stats before normalization for logging
|
||
raw_mean_weight = float(np.nanmean(weights))
|
||
num_zero_weight = int(np.sum(weights == 0))
|
||
num_full_weight = int(np.sum(weights == 1.0))
|
||
batch_stats = {
|
||
"raw_mean_weight": raw_mean_weight,
|
||
"num_zero_weight": num_zero_weight,
|
||
"num_full_weight": num_full_weight,
|
||
}
|
||
|
||
weights = torch.tensor(weights, device=self.device, dtype=torch.float32)
|
||
|
||
# Normalize to sum to batch_size
|
||
batch_size = len(weights)
|
||
weight_sum = weights.sum() + self.epsilon
|
||
weights = weights * batch_size / weight_sum
|
||
|
||
return weights, batch_stats
|
||
|
||
def _compute_delta(self, global_idx: int) -> float:
|
||
"""Compute progress delta for a single frame."""
|
||
current_progress = self.progress_lookup.get(global_idx)
|
||
if current_progress is None:
|
||
return np.nan
|
||
|
||
episode_idx = self.episode_lookup.get(global_idx)
|
||
if episode_idx is None:
|
||
return np.nan
|
||
|
||
bounds = self.episode_boundaries.get(episode_idx)
|
||
if bounds is None:
|
||
return np.nan
|
||
|
||
future_idx = global_idx + self.chunk_size # Δ = chunk_size
|
||
if future_idx >= bounds["end"]:
|
||
# Near end of episode: use last frame's progress instead
|
||
future_idx = bounds["end"] - 1
|
||
|
||
future_progress = self.progress_lookup.get(future_idx)
|
||
if future_progress is None:
|
||
return np.nan
|
||
|
||
return future_progress - current_progress
|
||
|
||
def _compute_weights(self, deltas: np.ndarray) -> np.ndarray:
|
||
"""
|
||
Compute RA-BC weights from progress deltas.
|
||
|
||
Following paper Eq. 8-9:
|
||
- Soft weight: ˜wi = clip((ri − (µ − 2σ)) / (4σ + ε), 0, 1)
|
||
- Final weight: wi = 1{ri > κ} + 1{0 ≤ ri ≤ κ}˜wi
|
||
|
||
Returns:
|
||
Array of weights
|
||
"""
|
||
valid_mask = ~np.isnan(deltas)
|
||
|
||
# Compute soft weights using global statistics
|
||
lower_bound = self.delta_mean - 2 * self.delta_std
|
||
soft_weights = (deltas - lower_bound) / (4 * self.delta_std + self.epsilon)
|
||
soft_weights = np.clip(soft_weights, 0.0, 1.0)
|
||
|
||
# Apply paper's Eq. 9
|
||
weights = np.zeros_like(deltas, dtype=np.float32)
|
||
|
||
# High quality: ri > kappa → weight = 1
|
||
high_quality_mask = deltas > self.kappa
|
||
weights[high_quality_mask] = 1.0
|
||
|
||
# Moderate quality: 0 <= ri <= kappa → weight = soft_weight
|
||
moderate_mask = (deltas >= 0) & (deltas <= self.kappa)
|
||
weights[moderate_mask] = soft_weights[moderate_mask]
|
||
|
||
# Negative progress: ri < 0 → weight = 0 (already 0)
|
||
# Invalid (NaN): use fallback weight
|
||
weights[~valid_mask] = self.fallback_weight
|
||
|
||
return weights
|
||
|
||
def _get_batch_size(self, batch: dict) -> int:
|
||
"""Determine batch size from batch."""
|
||
for key in ["action", "index"]:
|
||
if key in batch:
|
||
val = batch[key]
|
||
if isinstance(val, (torch.Tensor, np.ndarray)):
|
||
return val.shape[0]
|
||
return 1
|
||
|
||
def get_stats(self) -> dict:
|
||
"""Get statistics."""
|
||
return {
|
||
"num_frames": len(self.progress_lookup),
|
||
"chunk_size": self.chunk_size,
|
||
"head_mode": self.head_mode,
|
||
"delta_mean": self.delta_mean,
|
||
"delta_std": self.delta_std,
|
||
"kappa": self.kappa,
|
||
}
|