mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
add testing for sampl weighter
This commit is contained in:
@@ -36,6 +36,8 @@ import pandas as pd
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.utils.sample_weighting import SampleWeighter
|
||||
|
||||
|
||||
def resolve_hf_path(path: str | Path) -> Path:
|
||||
"""Resolve a path that may be a HuggingFace URL (hf://datasets/...) to a local path."""
|
||||
@@ -48,11 +50,11 @@ def resolve_hf_path(path: str | Path) -> Path:
|
||||
return Path(path)
|
||||
|
||||
|
||||
class RABCWeights:
|
||||
class RABCWeights(SampleWeighter):
|
||||
"""
|
||||
Load precomputed SARM progress values and compute RA-BC weights during training.
|
||||
|
||||
This class implements the SampleWeighter protocol for use with the generic
|
||||
This class implements the SampleWeighter ABC for use with the generic
|
||||
sample weighting infrastructure in lerobot.
|
||||
|
||||
Progress values are loaded from a parquet file (generated by compute_rabc_weights.py).
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
"""
|
||||
Sample weighting abstraction for training.
|
||||
|
||||
This module provides a generic protocol for sample weighting strategies (e.g., RA-BC)
|
||||
This module provides an abstract base class for sample weighting strategies (e.g., RA-BC)
|
||||
that can be used during training without polluting the training script with
|
||||
policy-specific code.
|
||||
|
||||
@@ -35,8 +35,9 @@ Example usage:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
@@ -44,11 +45,8 @@ if TYPE_CHECKING:
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SampleWeighter(Protocol):
|
||||
class SampleWeighter(ABC):
|
||||
"""
|
||||
Protocol for sample weighting strategies during training.
|
||||
|
||||
Implementations compute per-sample weights that can be used to weight
|
||||
the loss during training. This enables techniques like:
|
||||
- RA-BC (Reward-Aligned Behavior Cloning)
|
||||
@@ -57,6 +55,7 @@ class SampleWeighter(Protocol):
|
||||
- Quality-based filtering
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
|
||||
"""
|
||||
Compute per-sample weights for a training batch.
|
||||
@@ -71,8 +70,8 @@ class SampleWeighter(Protocol):
|
||||
normalized to sum to batch_size for stable gradients.
|
||||
- stats: Dictionary with logging-friendly statistics about the weights.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_stats(self) -> dict:
|
||||
"""
|
||||
Get global statistics about the weighting strategy.
|
||||
@@ -80,7 +79,6 @@ class SampleWeighter(Protocol):
|
||||
Returns:
|
||||
Dictionary with statistics for logging (e.g., mean delta, coverage).
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -89,8 +87,8 @@ class SampleWeightingConfig:
|
||||
Configuration for sample weighting during training.
|
||||
|
||||
This is a generic config that supports multiple weighting strategies.
|
||||
The `type` field determines which implementation to use, and `params`
|
||||
contains type-specific parameters.
|
||||
The `type` field determines which implementation to use, and `extra_params`
|
||||
contains additional type-specific parameters.
|
||||
|
||||
Attributes:
|
||||
type: Weighting strategy type ("rabc", "uniform", etc.)
|
||||
@@ -98,6 +96,7 @@ class SampleWeightingConfig:
|
||||
head_mode: Which model head to use for progress ("sparse" or "dense")
|
||||
kappa: Hard threshold for high-quality samples (RABC-specific)
|
||||
epsilon: Small constant for numerical stability
|
||||
extra_params: Additional type-specific parameters passed to the weighter
|
||||
"""
|
||||
|
||||
type: str = "rabc"
|
||||
@@ -178,12 +177,17 @@ def _make_rabc_weighter(
|
||||
)
|
||||
|
||||
|
||||
class UniformWeighter:
|
||||
class UniformWeighter(SampleWeighter):
|
||||
"""
|
||||
No-op sample weighter that returns uniform weights.
|
||||
|
||||
Useful as a baseline or when you want to disable weighting without
|
||||
changing the training code structure.
|
||||
|
||||
Note:
|
||||
Batch size is determined by looking for tensor values in the batch
|
||||
dictionary. The method checks common keys like "action", "index",
|
||||
and "observation.state" first, then falls back to scanning all values.
|
||||
"""
|
||||
|
||||
def __init__(self, device: torch.device):
|
||||
@@ -191,19 +195,43 @@ class UniformWeighter:
|
||||
|
||||
def compute_batch_weights(self, batch: dict) -> tuple[torch.Tensor, dict]:
|
||||
"""Return uniform weights (all ones)."""
|
||||
# Determine batch size from batch
|
||||
batch_size = 1
|
||||
for key in ["action", "index"]:
|
||||
if key in batch:
|
||||
val = batch[key]
|
||||
if isinstance(val, torch.Tensor):
|
||||
batch_size = val.shape[0]
|
||||
break
|
||||
batch_size = self._determine_batch_size(batch)
|
||||
|
||||
weights = torch.ones(batch_size, device=self.device)
|
||||
stats = {"mean_weight": 1.0, "type": "uniform"}
|
||||
return weights, stats
|
||||
|
||||
def _determine_batch_size(self, batch: dict) -> int:
|
||||
"""
|
||||
Determine batch size from the batch dictionary.
|
||||
|
||||
Checks common keys first, then scans all values for tensors.
|
||||
|
||||
Args:
|
||||
batch: Training batch dictionary.
|
||||
|
||||
Returns:
|
||||
Batch size, or 1 if it cannot be determined.
|
||||
|
||||
Raises:
|
||||
ValueError: If batch is empty.
|
||||
"""
|
||||
if not batch:
|
||||
raise ValueError("Cannot determine batch size from empty batch")
|
||||
|
||||
# Check common keys first
|
||||
for key in ["action", "index", "observation.state"]:
|
||||
if key in batch and isinstance(batch[key], torch.Tensor):
|
||||
return batch[key].shape[0]
|
||||
|
||||
# Scan all values for any tensor
|
||||
for value in batch.values():
|
||||
if isinstance(value, torch.Tensor) and value.ndim >= 1:
|
||||
return value.shape[0]
|
||||
|
||||
# Last resort: return 1 (this handles non-tensor batches)
|
||||
return 1
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Return empty stats for uniform weighting."""
|
||||
return {"type": "uniform"}
|
||||
|
||||
345
tests/utils/test_sample_weighting.py
Normal file
345
tests/utils/test_sample_weighting.py
Normal file
@@ -0,0 +1,345 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the sample weighting infrastructure."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.utils.sample_weighting import (
|
||||
SampleWeighter,
|
||||
SampleWeightingConfig,
|
||||
UniformWeighter,
|
||||
make_sample_weighter,
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_progress_parquet(tmp_path):
|
||||
"""Create a sample progress parquet file for testing."""
|
||||
import pandas as pd
|
||||
|
||||
# Create sample progress data for 2 episodes with 10 frames each
|
||||
data = {
|
||||
"index": list(range(20)),
|
||||
"episode_index": [0] * 10 + [1] * 10,
|
||||
"frame_index": list(range(10)) * 2,
|
||||
"progress_sparse": [i / 10.0 for i in range(10)] * 2,
|
||||
}
|
||||
df = pd.DataFrame(data)
|
||||
parquet_path = tmp_path / "sarm_progress.parquet"
|
||||
df.to_parquet(parquet_path)
|
||||
return parquet_path
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SampleWeightingConfig Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_config_default_values():
|
||||
"""Test default configuration values."""
|
||||
config = SampleWeightingConfig()
|
||||
assert config.type == "rabc"
|
||||
assert config.progress_path is None
|
||||
assert config.head_mode == "sparse"
|
||||
assert config.kappa == 0.01
|
||||
assert config.epsilon == 1e-6
|
||||
assert config.extra_params == {}
|
||||
|
||||
|
||||
def test_config_custom_values():
|
||||
"""Test configuration with custom values."""
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path="/path/to/progress.parquet",
|
||||
head_mode="dense",
|
||||
kappa=0.05,
|
||||
epsilon=1e-8,
|
||||
extra_params={"fallback_weight": 0.5},
|
||||
)
|
||||
assert config.type == "rabc"
|
||||
assert config.progress_path == "/path/to/progress.parquet"
|
||||
assert config.head_mode == "dense"
|
||||
assert config.kappa == 0.05
|
||||
assert config.epsilon == 1e-8
|
||||
assert config.extra_params == {"fallback_weight": 0.5}
|
||||
|
||||
|
||||
def test_config_uniform_type():
|
||||
"""Test configuration for uniform weighting."""
|
||||
config = SampleWeightingConfig(type="uniform")
|
||||
assert config.type == "uniform"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# UniformWeighter Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_uniform_weighter_inherits_from_sample_weighter():
|
||||
"""Test that UniformWeighter is a SampleWeighter."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
assert isinstance(weighter, SampleWeighter)
|
||||
|
||||
|
||||
def test_uniform_weighter_compute_batch_weights_with_action_key():
|
||||
"""Test weight computation with 'action' key in batch."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
batch = {"action": torch.randn(8, 10)}
|
||||
|
||||
weights, stats = weighter.compute_batch_weights(batch)
|
||||
|
||||
assert weights.shape == (8,)
|
||||
assert torch.allclose(weights, torch.ones(8))
|
||||
assert stats["mean_weight"] == 1.0
|
||||
assert stats["type"] == "uniform"
|
||||
|
||||
|
||||
def test_uniform_weighter_compute_batch_weights_with_index_key():
|
||||
"""Test weight computation with 'index' key in batch."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
batch = {"index": torch.arange(16)}
|
||||
|
||||
weights, stats = weighter.compute_batch_weights(batch)
|
||||
|
||||
assert weights.shape == (16,)
|
||||
assert torch.allclose(weights, torch.ones(16))
|
||||
|
||||
|
||||
def test_uniform_weighter_compute_batch_weights_no_tensor_keys():
|
||||
"""Test weight computation with no tensor keys (fallback to size 1)."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
batch = {"other_key": "some_value"}
|
||||
|
||||
weights, stats = weighter.compute_batch_weights(batch)
|
||||
|
||||
assert weights.shape == (1,)
|
||||
assert torch.allclose(weights, torch.ones(1))
|
||||
|
||||
|
||||
def test_uniform_weighter_compute_batch_weights_empty_batch_raises():
|
||||
"""Test that empty batch raises ValueError."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
batch = {}
|
||||
|
||||
with pytest.raises(ValueError, match="empty batch"):
|
||||
weighter.compute_batch_weights(batch)
|
||||
|
||||
|
||||
def test_uniform_weighter_compute_batch_weights_scans_all_keys():
|
||||
"""Test that batch size is determined by scanning all tensor values."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
# Batch with non-standard key containing a tensor
|
||||
batch = {"custom_tensor": torch.randn(7, 3)}
|
||||
|
||||
weights, stats = weighter.compute_batch_weights(batch)
|
||||
|
||||
assert weights.shape == (7,)
|
||||
assert torch.allclose(weights, torch.ones(7))
|
||||
|
||||
|
||||
def test_uniform_weighter_compute_batch_weights_on_cuda():
|
||||
"""Test that weights are placed on the correct device."""
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA not available")
|
||||
|
||||
weighter = UniformWeighter(device=torch.device("cuda"))
|
||||
batch = {"action": torch.randn(4, 10)}
|
||||
|
||||
weights, _ = weighter.compute_batch_weights(batch)
|
||||
|
||||
assert weights.device.type == "cuda"
|
||||
|
||||
|
||||
def test_uniform_weighter_get_stats():
|
||||
"""Test get_stats returns expected structure."""
|
||||
weighter = UniformWeighter(device=torch.device("cpu"))
|
||||
stats = weighter.get_stats()
|
||||
|
||||
assert stats == {"type": "uniform"}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# make_sample_weighter Factory Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_factory_returns_none_for_none_config():
|
||||
"""Test that None config returns None weighter."""
|
||||
policy = Mock()
|
||||
device = torch.device("cpu")
|
||||
|
||||
result = make_sample_weighter(None, policy, device)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_factory_creates_uniform_weighter():
|
||||
"""Test creation of UniformWeighter."""
|
||||
config = SampleWeightingConfig(type="uniform")
|
||||
policy = Mock()
|
||||
device = torch.device("cpu")
|
||||
|
||||
weighter = make_sample_weighter(config, policy, device)
|
||||
|
||||
assert isinstance(weighter, UniformWeighter)
|
||||
assert isinstance(weighter, SampleWeighter)
|
||||
|
||||
|
||||
def test_factory_raises_for_unknown_type():
|
||||
"""Test that unknown type raises ValueError."""
|
||||
config = SampleWeightingConfig(type="unknown_type")
|
||||
policy = Mock()
|
||||
device = torch.device("cpu")
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown sample weighting type"):
|
||||
make_sample_weighter(config, policy, device)
|
||||
|
||||
|
||||
def test_factory_rabc_requires_chunk_size():
|
||||
"""Test that RABC weighter requires chunk_size in policy config."""
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path="/path/to/progress.parquet",
|
||||
)
|
||||
policy = Mock()
|
||||
policy.config = Mock()
|
||||
policy.config.chunk_size = None # No chunk_size
|
||||
device = torch.device("cpu")
|
||||
|
||||
with pytest.raises(ValueError, match="chunk_size"):
|
||||
make_sample_weighter(config, policy, device)
|
||||
|
||||
|
||||
def test_factory_rabc_requires_progress_path():
|
||||
"""Test that RABC weighter requires progress_path."""
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path=None, # No progress path
|
||||
)
|
||||
policy = Mock()
|
||||
policy.config = Mock()
|
||||
policy.config.chunk_size = 50
|
||||
device = torch.device("cpu")
|
||||
|
||||
with pytest.raises(ValueError, match="progress_path"):
|
||||
make_sample_weighter(config, policy, device)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration Tests with RABCWeights
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_rabc_weights_is_sample_weighter(sample_progress_parquet):
|
||||
"""Test that RABCWeights inherits from SampleWeighter."""
|
||||
from lerobot.policies.sarm.rabc import RABCWeights
|
||||
|
||||
weighter = RABCWeights(
|
||||
progress_path=sample_progress_parquet,
|
||||
chunk_size=5,
|
||||
head_mode="sparse",
|
||||
)
|
||||
assert isinstance(weighter, SampleWeighter)
|
||||
|
||||
|
||||
def test_rabc_compute_batch_weights(sample_progress_parquet):
|
||||
"""Test RABCWeights.compute_batch_weights returns correct structure."""
|
||||
from lerobot.policies.sarm.rabc import RABCWeights
|
||||
|
||||
weighter = RABCWeights(
|
||||
progress_path=sample_progress_parquet,
|
||||
chunk_size=5,
|
||||
head_mode="sparse",
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
|
||||
batch = {"index": torch.tensor([0, 1, 2, 3])}
|
||||
weights, stats = weighter.compute_batch_weights(batch)
|
||||
|
||||
assert isinstance(weights, torch.Tensor)
|
||||
assert weights.shape == (4,)
|
||||
assert isinstance(stats, dict)
|
||||
assert "mean_weight" in stats
|
||||
|
||||
|
||||
def test_rabc_get_stats(sample_progress_parquet):
|
||||
"""Test RABCWeights.get_stats returns expected structure."""
|
||||
from lerobot.policies.sarm.rabc import RABCWeights
|
||||
|
||||
weighter = RABCWeights(
|
||||
progress_path=sample_progress_parquet,
|
||||
chunk_size=5,
|
||||
head_mode="sparse",
|
||||
)
|
||||
|
||||
stats = weighter.get_stats()
|
||||
|
||||
assert stats["type"] == "rabc"
|
||||
assert "num_frames" in stats
|
||||
assert "chunk_size" in stats
|
||||
assert stats["chunk_size"] == 5
|
||||
assert "head_mode" in stats
|
||||
assert stats["head_mode"] == "sparse"
|
||||
assert "delta_mean" in stats
|
||||
assert "delta_std" in stats
|
||||
|
||||
|
||||
def test_factory_creates_rabc_weighter(sample_progress_parquet):
|
||||
"""Test factory creates RABCWeights with valid config."""
|
||||
from lerobot.policies.sarm.rabc import RABCWeights
|
||||
|
||||
config = SampleWeightingConfig(
|
||||
type="rabc",
|
||||
progress_path=str(sample_progress_parquet),
|
||||
head_mode="sparse",
|
||||
kappa=0.01,
|
||||
)
|
||||
policy = Mock()
|
||||
policy.config = Mock()
|
||||
policy.config.chunk_size = 5
|
||||
device = torch.device("cpu")
|
||||
|
||||
weighter = make_sample_weighter(config, policy, device)
|
||||
|
||||
assert isinstance(weighter, RABCWeights)
|
||||
assert isinstance(weighter, SampleWeighter)
|
||||
|
||||
|
||||
def test_rabc_weights_normalization(sample_progress_parquet):
|
||||
"""Test that RABCWeights normalizes weights to sum to batch_size."""
|
||||
from lerobot.policies.sarm.rabc import RABCWeights
|
||||
|
||||
weighter = RABCWeights(
|
||||
progress_path=sample_progress_parquet,
|
||||
chunk_size=5,
|
||||
head_mode="sparse",
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
|
||||
batch = {"index": torch.tensor([0, 1, 2, 3])}
|
||||
weights, _ = weighter.compute_batch_weights(batch)
|
||||
|
||||
# Weights should be normalized to sum approximately to batch_size
|
||||
batch_size = 4
|
||||
assert abs(weights.sum().item() - batch_size) < 0.1
|
||||
Reference in New Issue
Block a user