add testing for sampl weighter

This commit is contained in:
Michel Aractingi
2026-01-12 14:10:03 +01:00
parent 9ebc144b30
commit ded91ca866
3 changed files with 396 additions and 21 deletions

View File

@@ -36,6 +36,8 @@ import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from lerobot.utils.sample_weighting import SampleWeighter
def resolve_hf_path(path: str | Path) -> Path:
"""Resolve a path that may be a HuggingFace URL (hf://datasets/...) to a local path."""
@@ -48,11 +50,11 @@ def resolve_hf_path(path: str | Path) -> Path:
return Path(path)
class RABCWeights:
class RABCWeights(SampleWeighter):
"""
Load precomputed SARM progress values and compute RA-BC weights during training.
This class implements the SampleWeighter protocol for use with the generic
This class implements the SampleWeighter ABC for use with the generic
sample weighting infrastructure in lerobot.
Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).

View File

@@ -15,7 +15,7 @@
"""
Sample weighting abstraction for training.
This module provides a generic protocol for sample weighting strategies (e.g., RA-BC)
This module provides an abstract base class for sample weighting strategies (e.g., RA-BC)
that can be used during training without polluting the training script with
policy-specific code.
@@ -35,8 +35,9 @@ Example usage:
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Protocol, runtime_checkable
from typing import TYPE_CHECKING
import torch
@@ -44,11 +45,8 @@ if TYPE_CHECKING:
from lerobot.policies.pretrained import PreTrainedPolicy
@runtime_checkable
class SampleWeighter(Protocol):
class SampleWeighter(ABC):
"""
Protocol for sample weighting strategies during training.
Implementations compute per-sample weights that can be used to weight
the loss during training. This enables techniques like:
- RA-BC (Reward-Aligned Behavior Cloning)
@@ -57,6 +55,7 @@ class SampleWeighter(Protocol):
- Quality-based filtering
"""
@abstractmethod
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
"""
Compute per-sample weights for a training batch.
@@ -71,8 +70,8 @@ class SampleWeighter(Protocol):
normalized to sum to batch_size for stable gradients.
- stats: Dictionary with logging-friendly statistics about the weights.
"""
...
@abstractmethod
def get_stats(self) -> dict:
"""
Get global statistics about the weighting strategy.
@@ -80,7 +79,6 @@ class SampleWeighter(Protocol):
Returns:
Dictionary with statistics for logging (e.g., mean delta, coverage).
"""
...
@dataclass
@@ -89,8 +87,8 @@ class SampleWeightingConfig:
Configuration for sample weighting during training.
This is a generic config that supports multiple weighting strategies.
The `type` field determines which implementation to use, and `params`
contains type-specific parameters.
The `type` field determines which implementation to use, and `extra_params`
contains additional type-specific parameters.
Attributes:
type: Weighting strategy type ("rabc", "uniform", etc.)
@@ -98,6 +96,7 @@ class SampleWeightingConfig:
head_mode: Which model head to use for progress ("sparse" or "dense")
kappa: Hard threshold for high-quality samples (RABC-specific)
epsilon: Small constant for numerical stability
extra_params: Additional type-specific parameters passed to the weighter
"""
type: str = "rabc"
@@ -178,12 +177,17 @@ def _make_rabc_weighter(
)
class UniformWeighter:
class UniformWeighter(SampleWeighter):
"""
No-op sample weighter that returns uniform weights.
Useful as a baseline or when you want to disable weighting without
changing the training code structure.
Note:
Batch size is determined by looking for tensor values in the batch
dictionary. The method checks common keys like "action", "index",
and "observation.state" first, then falls back to scanning all values.
"""
def __init__(self, device: torch.device):
@@ -191,19 +195,43 @@ class UniformWeighter:
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
"""Return uniform weights (all ones)."""
# Determine batch size from batch
batch_size = 1
for key in ["action", "index"]:
if key in batch:
val = batch[key]
if isinstance(val, torch.Tensor):
batch_size = val.shape[0]
break
batch_size = self._determine_batch_size(batch)
weights = torch.ones(batch_size, device=self.device)
stats = {"mean_weight": 1.0, "type": "uniform"}
return weights, stats
def _determine_batch_size(self, batch: dict) -> int:
"""
Determine batch size from the batch dictionary.
Checks common keys first, then scans all values for tensors.
Args:
batch: Training batch dictionary.
Returns:
Batch size, or 1 if it cannot be determined.
Raises:
ValueError: If batch is empty.
"""
if not batch:
raise ValueError("Cannot determine batch size from empty batch")
# Check common keys first
for key in ["action", "index", "observation.state"]:
if key in batch and isinstance(batch[key], torch.Tensor):
return batch[key].shape[0]
# Scan all values for any tensor
for value in batch.values():
if isinstance(value, torch.Tensor) and value.ndim >= 1:
return value.shape[0]
# Last resort: return 1 (this handles non-tensor batches)
return 1
def get_stats(self) -> dict:
"""Return empty stats for uniform weighting."""
return {"type": "uniform"}

View File

@@ -0,0 +1,345 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the sample weighting infrastructure."""
from unittest.mock import Mock
import pytest
import torch
from lerobot.utils.sample_weighting import (
SampleWeighter,
SampleWeightingConfig,
UniformWeighter,
make_sample_weighter,
)
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def sample_progress_parquet(tmp_path):
"""Create a sample progress parquet file for testing."""
import pandas as pd
# Create sample progress data for 2 episodes with 10 frames each
data = {
"index": list(range(20)),
"episode_index": [0] * 10 + [1] * 10,
"frame_index": list(range(10)) * 2,
"progress_sparse": [i / 10.0 for i in range(10)] * 2,
}
df = pd.DataFrame(data)
parquet_path = tmp_path / "sarm_progress.parquet"
df.to_parquet(parquet_path)
return parquet_path
# =============================================================================
# SampleWeightingConfig Tests
# =============================================================================
def test_config_default_values():
"""Test default configuration values."""
config = SampleWeightingConfig()
assert config.type == "rabc"
assert config.progress_path is None
assert config.head_mode == "sparse"
assert config.kappa == 0.01
assert config.epsilon == 1e-6
assert config.extra_params == {}
def test_config_custom_values():
"""Test configuration with custom values."""
config = SampleWeightingConfig(
type="rabc",
progress_path="/path/to/progress.parquet",
head_mode="dense",
kappa=0.05,
epsilon=1e-8,
extra_params={"fallback_weight": 0.5},
)
assert config.type == "rabc"
assert config.progress_path == "/path/to/progress.parquet"
assert config.head_mode == "dense"
assert config.kappa == 0.05
assert config.epsilon == 1e-8
assert config.extra_params == {"fallback_weight": 0.5}
def test_config_uniform_type():
"""Test configuration for uniform weighting."""
config = SampleWeightingConfig(type="uniform")
assert config.type == "uniform"
# =============================================================================
# UniformWeighter Tests
# =============================================================================
def test_uniform_weighter_inherits_from_sample_weighter():
"""Test that UniformWeighter is a SampleWeighter."""
weighter = UniformWeighter(device=torch.device("cpu"))
assert isinstance(weighter, SampleWeighter)
def test_uniform_weighter_compute_batch_weights_with_action_key():
"""Test weight computation with 'action' key in batch."""
weighter = UniformWeighter(device=torch.device("cpu"))
batch = {"action": torch.randn(8, 10)}
weights, stats = weighter.compute_batch_weights(batch)
assert weights.shape == (8,)
assert torch.allclose(weights, torch.ones(8))
assert stats["mean_weight"] == 1.0
assert stats["type"] == "uniform"
def test_uniform_weighter_compute_batch_weights_with_index_key():
"""Test weight computation with 'index' key in batch."""
weighter = UniformWeighter(device=torch.device("cpu"))
batch = {"index": torch.arange(16)}
weights, stats = weighter.compute_batch_weights(batch)
assert weights.shape == (16,)
assert torch.allclose(weights, torch.ones(16))
def test_uniform_weighter_compute_batch_weights_no_tensor_keys():
"""Test weight computation with no tensor keys (fallback to size 1)."""
weighter = UniformWeighter(device=torch.device("cpu"))
batch = {"other_key": "some_value"}
weights, stats = weighter.compute_batch_weights(batch)
assert weights.shape == (1,)
assert torch.allclose(weights, torch.ones(1))
def test_uniform_weighter_compute_batch_weights_empty_batch_raises():
"""Test that empty batch raises ValueError."""
weighter = UniformWeighter(device=torch.device("cpu"))
batch = {}
with pytest.raises(ValueError, match="empty batch"):
weighter.compute_batch_weights(batch)
def test_uniform_weighter_compute_batch_weights_scans_all_keys():
"""Test that batch size is determined by scanning all tensor values."""
weighter = UniformWeighter(device=torch.device("cpu"))
# Batch with non-standard key containing a tensor
batch = {"custom_tensor": torch.randn(7, 3)}
weights, stats = weighter.compute_batch_weights(batch)
assert weights.shape == (7,)
assert torch.allclose(weights, torch.ones(7))
def test_uniform_weighter_compute_batch_weights_on_cuda():
"""Test that weights are placed on the correct device."""
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
weighter = UniformWeighter(device=torch.device("cuda"))
batch = {"action": torch.randn(4, 10)}
weights, _ = weighter.compute_batch_weights(batch)
assert weights.device.type == "cuda"
def test_uniform_weighter_get_stats():
"""Test get_stats returns expected structure."""
weighter = UniformWeighter(device=torch.device("cpu"))
stats = weighter.get_stats()
assert stats == {"type": "uniform"}
# =============================================================================
# make_sample_weighter Factory Tests
# =============================================================================
def test_factory_returns_none_for_none_config():
"""Test that None config returns None weighter."""
policy = Mock()
device = torch.device("cpu")
result = make_sample_weighter(None, policy, device)
assert result is None
def test_factory_creates_uniform_weighter():
"""Test creation of UniformWeighter."""
config = SampleWeightingConfig(type="uniform")
policy = Mock()
device = torch.device("cpu")
weighter = make_sample_weighter(config, policy, device)
assert isinstance(weighter, UniformWeighter)
assert isinstance(weighter, SampleWeighter)
def test_factory_raises_for_unknown_type():
"""Test that unknown type raises ValueError."""
config = SampleWeightingConfig(type="unknown_type")
policy = Mock()
device = torch.device("cpu")
with pytest.raises(ValueError, match="Unknown sample weighting type"):
make_sample_weighter(config, policy, device)
def test_factory_rabc_requires_chunk_size():
"""Test that RABC weighter requires chunk_size in policy config."""
config = SampleWeightingConfig(
type="rabc",
progress_path="/path/to/progress.parquet",
)
policy = Mock()
policy.config = Mock()
policy.config.chunk_size = None # No chunk_size
device = torch.device("cpu")
with pytest.raises(ValueError, match="chunk_size"):
make_sample_weighter(config, policy, device)
def test_factory_rabc_requires_progress_path():
"""Test that RABC weighter requires progress_path."""
config = SampleWeightingConfig(
type="rabc",
progress_path=None, # No progress path
)
policy = Mock()
policy.config = Mock()
policy.config.chunk_size = 50
device = torch.device("cpu")
with pytest.raises(ValueError, match="progress_path"):
make_sample_weighter(config, policy, device)
# =============================================================================
# Integration Tests with RABCWeights
# =============================================================================
def test_rabc_weights_is_sample_weighter(sample_progress_parquet):
"""Test that RABCWeights inherits from SampleWeighter."""
from lerobot.policies.sarm.rabc import RABCWeights
weighter = RABCWeights(
progress_path=sample_progress_parquet,
chunk_size=5,
head_mode="sparse",
)
assert isinstance(weighter, SampleWeighter)
def test_rabc_compute_batch_weights(sample_progress_parquet):
"""Test RABCWeights.compute_batch_weights returns correct structure."""
from lerobot.policies.sarm.rabc import RABCWeights
weighter = RABCWeights(
progress_path=sample_progress_parquet,
chunk_size=5,
head_mode="sparse",
device=torch.device("cpu"),
)
batch = {"index": torch.tensor([0, 1, 2, 3])}
weights, stats = weighter.compute_batch_weights(batch)
assert isinstance(weights, torch.Tensor)
assert weights.shape == (4,)
assert isinstance(stats, dict)
assert "mean_weight" in stats
def test_rabc_get_stats(sample_progress_parquet):
"""Test RABCWeights.get_stats returns expected structure."""
from lerobot.policies.sarm.rabc import RABCWeights
weighter = RABCWeights(
progress_path=sample_progress_parquet,
chunk_size=5,
head_mode="sparse",
)
stats = weighter.get_stats()
assert stats["type"] == "rabc"
assert "num_frames" in stats
assert "chunk_size" in stats
assert stats["chunk_size"] == 5
assert "head_mode" in stats
assert stats["head_mode"] == "sparse"
assert "delta_mean" in stats
assert "delta_std" in stats
def test_factory_creates_rabc_weighter(sample_progress_parquet):
"""Test factory creates RABCWeights with valid config."""
from lerobot.policies.sarm.rabc import RABCWeights
config = SampleWeightingConfig(
type="rabc",
progress_path=str(sample_progress_parquet),
head_mode="sparse",
kappa=0.01,
)
policy = Mock()
policy.config = Mock()
policy.config.chunk_size = 5
device = torch.device("cpu")
weighter = make_sample_weighter(config, policy, device)
assert isinstance(weighter, RABCWeights)
assert isinstance(weighter, SampleWeighter)
def test_rabc_weights_normalization(sample_progress_parquet):
"""Test that RABCWeights normalizes weights to sum to batch_size."""
from lerobot.policies.sarm.rabc import RABCWeights
weighter = RABCWeights(
progress_path=sample_progress_parquet,
chunk_size=5,
head_mode="sparse",
device=torch.device("cpu"),
)
batch = {"index": torch.tensor([0, 1, 2, 3])}
weights, _ = weighter.compute_batch_weights(batch)
# Weights should be normalized to sum approximately to batch_size
batch_size = 4
assert abs(weights.sum().item() - batch_size) < 0.1