mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
feat(rewards): add TOPReward reward model (#3629)
* feat(rewards): add TOPReward reward model * refactor(rewards): clean up TOPReward processor/model * fix(rewards/topreward): add missing input keys mm_token_type_ids * fix(rewards/topreward): fix pyproject extra typo and simplify processor (#3653) Add lerobot[topreward] extra to all in pyproject.toml, drop the redundant labels arg in scoring, and collapse the dead-branch shape check in the encoder processor. * optmize topreward input processing (#3660) --------- Co-authored-by: Cole <91766445+jcoleharrison@users.noreply.github.com> Co-authored-by: Haoming Song <haomingsong24@gmail.com>
This commit is contained in:
296
tests/rewards/test_modeling_topreward.py
Normal file
296
tests/rewards/test_modeling_topreward.py
Normal file
@@ -0,0 +1,296 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for the TOPReward reward model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.rewards.factory import get_reward_model_class, make_reward_model_config
|
||||
from lerobot.rewards.topreward import TOPRewardConfig
|
||||
from lerobot.rewards.topreward.processor_topreward import TOPREWARD_FEATURE_PREFIX, TOPREWARD_INPUT_KEYS
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
|
||||
class _FakeQwenModel(torch.nn.Module):
|
||||
"""Stand-in for ``Qwen3VLForConditionalGeneration``.
|
||||
|
||||
Returns a ``SimpleNamespace`` with ``logits`` of a controlled shape so
|
||||
the log-prob extraction path in ``compute_reward`` can be exercised
|
||||
without downloading real VLM weights.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._param = torch.nn.Parameter(torch.zeros(1))
|
||||
self._reward_value: float = -1.5
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs): # noqa: ARG003
|
||||
return cls()
|
||||
|
||||
def forward( # noqa: ARG002
|
||||
self, input_ids, attention_mask=None, labels=None, logits_to_keep=0, **kwargs
|
||||
):
|
||||
batch_size, seq_len = input_ids.shape
|
||||
vocab_size = 1000
|
||||
logits = torch.zeros(batch_size, seq_len, vocab_size)
|
||||
# Place a controlled log-prob at the target token position so the
|
||||
# model returns a predictable reward value.
|
||||
# The label-masked suffix is the last token.
|
||||
# After the causal-LM shift (logits[:, :-1], labels[:, 1:]) the scored
|
||||
# position is logits[:, -2, :] predicting labels[:, -1].
|
||||
# We set logits so that log_softmax at the target token ≈ _reward_value.
|
||||
for i in range(batch_size):
|
||||
target_idx = int(input_ids[i, -1].item())
|
||||
logits[i, -2, target_idx] = self._reward_value * -10 # high logit -> high log-prob
|
||||
if logits_to_keep:
|
||||
logits = logits[:, -logits_to_keep:, :]
|
||||
return SimpleNamespace(logits=logits)
|
||||
|
||||
|
||||
def _patch_build(monkeypatch) -> None:
|
||||
"""Stub out HF AutoX so TOPReward construction is cheap and offline."""
|
||||
from lerobot.rewards.topreward import modeling_topreward
|
||||
|
||||
monkeypatch.setattr(modeling_topreward, "Qwen3VLForConditionalGeneration", _FakeQwenModel)
|
||||
|
||||
|
||||
def _make_batch(
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
labels: torch.Tensor | None = None,
|
||||
*,
|
||||
omit: str | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Build a ``compute_reward``-ready batch using TOPReward's namespaced keys."""
|
||||
batch_size, seq_len = input_ids.shape
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
|
||||
batch: dict[str, torch.Tensor] = {}
|
||||
if labels is not None:
|
||||
batch[f"{TOPREWARD_FEATURE_PREFIX}labels"] = labels
|
||||
batch.update(
|
||||
{
|
||||
f"{TOPREWARD_FEATURE_PREFIX}input_ids": input_ids,
|
||||
f"{TOPREWARD_FEATURE_PREFIX}attention_mask": attention_mask,
|
||||
f"{TOPREWARD_FEATURE_PREFIX}pixel_values_videos": torch.zeros(
|
||||
batch_size, 1536, dtype=torch.float32
|
||||
),
|
||||
f"{TOPREWARD_FEATURE_PREFIX}video_grid_thw": torch.ones(batch_size, 3, dtype=torch.long),
|
||||
f"{TOPREWARD_FEATURE_PREFIX}mm_token_type_ids": torch.zeros_like(input_ids),
|
||||
}
|
||||
)
|
||||
if omit is not None:
|
||||
batch.pop(f"{TOPREWARD_FEATURE_PREFIX}{omit}", None)
|
||||
return batch
|
||||
|
||||
|
||||
def _terminal_labels(input_ids: torch.Tensor) -> torch.Tensor:
|
||||
labels = torch.full_like(input_ids, -100)
|
||||
labels[:, -1] = input_ids[:, -1]
|
||||
return labels
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry + factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_topreward_config_registered():
|
||||
assert "topreward" in RewardModelConfig.get_known_choices()
|
||||
assert RewardModelConfig.get_choice_class("topreward") is TOPRewardConfig
|
||||
assert isinstance(make_reward_model_config("topreward", device="cpu"), TOPRewardConfig)
|
||||
|
||||
|
||||
def test_topreward_factory_returns_in_tree_class():
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
assert get_reward_model_class("topreward") is TOPRewardModel
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_topreward_config_rejects_zero_max_frames():
|
||||
with pytest.raises(ValueError, match="max_frames must be >= 1"):
|
||||
TOPRewardConfig(device="cpu", max_frames=0)
|
||||
|
||||
|
||||
def test_topreward_config_rejects_non_positive_fps():
|
||||
with pytest.raises(ValueError, match="fps must be > 0"):
|
||||
TOPRewardConfig(device="cpu", fps=0.0)
|
||||
|
||||
|
||||
def test_topreward_config_rejects_suffix_without_instruction_placeholder():
|
||||
with pytest.raises(ValueError, match=r"\{instruction\}"):
|
||||
TOPRewardConfig(device="cpu", prompt_suffix_template="no placeholder here")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compute_reward
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_compute_reward_returns_one_scalar_per_sample(monkeypatch):
|
||||
"""``compute_reward`` must return a ``(B,)`` float32 tensor with one
|
||||
log-prob reward per sample, consuming pre-encoded Qwen-VL tensors."""
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu")
|
||||
model = TOPRewardModel(cfg)
|
||||
|
||||
input_ids = torch.randint(0, 100, (2, 10))
|
||||
attention_mask = torch.ones(2, 10, dtype=torch.long)
|
||||
labels = _terminal_labels(input_ids)
|
||||
|
||||
batch = _make_batch(input_ids, attention_mask, labels)
|
||||
rewards = model.compute_reward(batch)
|
||||
|
||||
assert rewards.shape == (2,)
|
||||
assert rewards.dtype == torch.float32
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_compute_reward_applies_success_threshold(monkeypatch):
|
||||
"""When ``success_threshold`` is finite, the model returns binary success."""
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu", success_threshold=0.0)
|
||||
model = TOPRewardModel(cfg)
|
||||
|
||||
input_ids = torch.randint(0, 100, (2, 10))
|
||||
attention_mask = torch.ones(2, 10, dtype=torch.long)
|
||||
labels = _terminal_labels(input_ids)
|
||||
|
||||
batch = _make_batch(input_ids, attention_mask, labels)
|
||||
rewards = model.compute_reward(batch)
|
||||
|
||||
assert rewards.shape == (2,)
|
||||
assert set(rewards.tolist()).issubset({0.0, 1.0})
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_compute_reward_errors_when_inputs_missing(monkeypatch):
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu")
|
||||
model = TOPRewardModel(cfg)
|
||||
|
||||
with pytest.raises(KeyError, match=r"observation\.topreward\.input_ids"):
|
||||
model.compute_reward(_make_batch(torch.randint(0, 100, (1, 10)), omit="input_ids"))
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_compute_reward_errors_when_labels_missing(monkeypatch):
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu")
|
||||
model = TOPRewardModel(cfg)
|
||||
|
||||
input_ids = torch.randint(0, 100, (1, 10))
|
||||
with pytest.raises(KeyError, match=r"observation\.topreward\.labels"):
|
||||
model.compute_reward(_make_batch(input_ids, labels=None))
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_compute_reward_requires_all_encoder_keys(monkeypatch):
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu")
|
||||
model = TOPRewardModel(cfg)
|
||||
|
||||
input_ids = torch.randint(0, 100, (1, 10))
|
||||
labels = _terminal_labels(input_ids)
|
||||
required_encoder_keys = set(TOPREWARD_INPUT_KEYS) - {"input_ids", "labels"}
|
||||
|
||||
for key in required_encoder_keys:
|
||||
with pytest.raises(KeyError, match=rf"observation\.topreward\.{key}"):
|
||||
model.compute_reward(_make_batch(input_ids, labels=labels, omit=key))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Save / load — config-only checkpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_save_pretrained_writes_only_config_json(monkeypatch, tmp_path):
|
||||
from huggingface_hub.constants import CONFIG_NAME, SAFETENSORS_SINGLE_FILE
|
||||
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(
|
||||
device="cpu",
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
fps=4.0,
|
||||
image_key="observation.images.front",
|
||||
)
|
||||
model = TOPRewardModel(cfg)
|
||||
model.save_pretrained(str(tmp_path))
|
||||
|
||||
assert (tmp_path / CONFIG_NAME).exists()
|
||||
assert not (tmp_path / SAFETENSORS_SINGLE_FILE).exists()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_from_pretrained_local_dir_roundtrips_config(monkeypatch, tmp_path):
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(
|
||||
device="cpu",
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
fps=4.0,
|
||||
image_key="observation.images.front",
|
||||
add_chat_template=True,
|
||||
success_threshold=-1.5,
|
||||
)
|
||||
TOPRewardModel(cfg).save_pretrained(str(tmp_path))
|
||||
|
||||
reloaded = TOPRewardModel.from_pretrained(str(tmp_path))
|
||||
|
||||
assert isinstance(reloaded.config, TOPRewardConfig)
|
||||
assert reloaded.config.vlm_name == "Qwen/Qwen3-VL-8B-Instruct"
|
||||
assert reloaded.config.fps == 4.0
|
||||
assert reloaded.config.image_key == "observation.images.front"
|
||||
assert reloaded.config.add_chat_template is True
|
||||
assert reloaded.config.success_threshold == -1.5
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_topreward_is_not_trainable(monkeypatch):
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
_patch_build(monkeypatch)
|
||||
cfg = TOPRewardConfig(device="cpu")
|
||||
model = TOPRewardModel(cfg)
|
||||
|
||||
assert model.is_trainable is False
|
||||
with pytest.raises(NotImplementedError, match="not trainable"):
|
||||
model.forward({"x": torch.zeros(1)})
|
||||
80
tests/rewards/test_topreward.py
Normal file
80
tests/rewards/test_topreward.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""End-to-end TOPReward smoke test with the real Qwen3-VL model."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
pytest.importorskip("transformers")
|
||||
|
||||
from lerobot.rewards.topreward.configuration_topreward import TOPRewardConfig # noqa: E402
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel # noqa: E402
|
||||
from lerobot.rewards.topreward.processor_topreward import ( # noqa: E402
|
||||
TOPREWARD_FEATURE_PREFIX,
|
||||
TOPREWARD_INPUT_KEYS,
|
||||
make_topreward_pre_post_processors,
|
||||
)
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires downloading and loading Qwen3-VL and is not meant for CI",
|
||||
)
|
||||
|
||||
|
||||
def _make_dummy_topreward_batch(image_key: str, task_key: str) -> dict[str, object]:
|
||||
num_frames = 4
|
||||
image_size = 64
|
||||
frames = torch.zeros(1, num_frames, 3, image_size, image_size, dtype=torch.uint8)
|
||||
for frame_idx in range(num_frames):
|
||||
frames[0, frame_idx, 0].fill_(min(frame_idx * 48, 255))
|
||||
frames[0, frame_idx, 1].fill_(96)
|
||||
frames[0, frame_idx, 2].fill_(192)
|
||||
|
||||
return {
|
||||
image_key: frames,
|
||||
task_key: ["pick up the red cube"],
|
||||
}
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_topreward_full_qwen3vl_preprocessor_to_compute_reward():
|
||||
cfg = TOPRewardConfig(
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
device="cuda",
|
||||
max_frames=4,
|
||||
fps=2.0,
|
||||
max_input_length=4096,
|
||||
)
|
||||
|
||||
preprocessor, _ = make_topreward_pre_post_processors(cfg)
|
||||
encoded_batch = preprocessor(_make_dummy_topreward_batch(cfg.image_key, cfg.task_key))
|
||||
for key in TOPREWARD_INPUT_KEYS:
|
||||
assert f"{TOPREWARD_FEATURE_PREFIX}{key}" in encoded_batch
|
||||
|
||||
model = TOPRewardModel(cfg)
|
||||
try:
|
||||
model.to(cfg.device)
|
||||
model.eval()
|
||||
rewards = model.compute_reward(encoded_batch)
|
||||
finally:
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert rewards.shape == (1,)
|
||||
assert rewards.dtype == torch.float32
|
||||
assert torch.isfinite(rewards).all()
|
||||
246
tests/rewards/test_topreward_processor.py
Normal file
246
tests/rewards/test_topreward_processor.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for TOPReward's pre-processing helpers and encoder step."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.rewards.topreward.processor_topreward import (
|
||||
TOPREWARD_FEATURE_PREFIX,
|
||||
TOPREWARD_INPUT_KEYS,
|
||||
_expand_tasks,
|
||||
_prepare_video_batch,
|
||||
)
|
||||
from lerobot.types import TransitionKey
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _prepare_video_batch — raw image/video batch -> (B, T, C, H, W) uint8
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_prepare_video_batch_batched_chw_float_is_converted_to_uint8():
|
||||
video = torch.rand(2, 4, 3, 8, 8)
|
||||
tensor = _prepare_video_batch(video, max_frames=None)
|
||||
|
||||
assert tensor.shape == (2, 4, 3, 8, 8)
|
||||
assert tensor.dtype == torch.uint8
|
||||
assert tensor.min() >= 0 and tensor.max() <= 255
|
||||
|
||||
|
||||
def test_prepare_video_batch_batched_thwc_uint8_is_permuted_to_channel_first():
|
||||
video = torch.randint(0, 256, (2, 3, 8, 8, 3), dtype=torch.uint8)
|
||||
tensor = _prepare_video_batch(video, max_frames=None)
|
||||
|
||||
assert tensor.shape == (2, 3, 3, 8, 8)
|
||||
assert tensor.dtype == torch.uint8
|
||||
|
||||
|
||||
def test_prepare_video_batch_max_frames_tail_crops_recent_frames():
|
||||
video = torch.zeros(1, 10, 3, 4, 4)
|
||||
for t in range(10):
|
||||
video[:, t] = t / 9.0
|
||||
|
||||
tensor = _prepare_video_batch(video, max_frames=3)
|
||||
|
||||
assert tensor.shape == (1, 3, 3, 4, 4)
|
||||
assert int(tensor[0, 0, 0, 0, 0]) == int(7 / 9 * 255)
|
||||
assert int(tensor[0, -1, 0, 0, 0]) == 255
|
||||
|
||||
|
||||
def test_prepare_video_batch_rejects_3d_input():
|
||||
with pytest.raises(ValueError, match="Expected TOPReward frames"):
|
||||
_prepare_video_batch(torch.zeros(4, 8, 8), max_frames=None)
|
||||
|
||||
|
||||
def test_prepare_video_batch_floats_above_one_are_rescaled_and_clipped():
|
||||
video = torch.full((1, 1, 3, 2, 2), 5.0)
|
||||
tensor = _prepare_video_batch(video, max_frames=None)
|
||||
|
||||
assert tensor.shape == (1, 1, 3, 2, 2)
|
||||
assert int(tensor.max()) == 255
|
||||
|
||||
|
||||
def test_prepare_video_batch_clips_very_large_floats_to_uint8_max():
|
||||
video = torch.full((1, 1, 3, 2, 2), 300.0)
|
||||
tensor = _prepare_video_batch(video, max_frames=None)
|
||||
|
||||
assert int(tensor.max()) == 255
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _expand_tasks — string / list / tuple broadcasting to batch size
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_expand_tasks_string_is_broadcast_to_batch_size():
|
||||
assert _expand_tasks("pick up", batch_size=3, default=None) == ["pick up", "pick up", "pick up"]
|
||||
|
||||
|
||||
def test_expand_tasks_list_of_matching_size_passes_through():
|
||||
assert _expand_tasks(["a", "b", "c"], batch_size=3, default=None) == ["a", "b", "c"]
|
||||
|
||||
|
||||
def test_expand_tasks_tuple_is_normalised_to_list():
|
||||
assert _expand_tasks(("a", "b"), batch_size=2, default=None) == ["a", "b"]
|
||||
|
||||
|
||||
def test_expand_tasks_single_element_list_is_broadcast():
|
||||
assert _expand_tasks(["only one"], batch_size=3, default=None) == ["only one"] * 3
|
||||
|
||||
|
||||
def test_expand_tasks_size_mismatch_raises():
|
||||
with pytest.raises(ValueError, match="Expected 3 tasks"):
|
||||
_expand_tasks(["a", "b"], batch_size=3, default=None)
|
||||
|
||||
|
||||
def test_expand_tasks_missing_uses_default():
|
||||
assert _expand_tasks(None, batch_size=2, default="fallback") == ["fallback", "fallback"]
|
||||
|
||||
|
||||
def test_expand_tasks_missing_without_default_raises():
|
||||
with pytest.raises(KeyError, match="task description"):
|
||||
_expand_tasks(None, batch_size=1, default=None)
|
||||
|
||||
|
||||
def test_expand_tasks_wrong_type_raises():
|
||||
with pytest.raises(TypeError, match="must be a string or list"):
|
||||
_expand_tasks(42, batch_size=1, default=None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Encoder step — stubbed AutoProcessor
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _skip_if_topreward_extras_missing(func):
|
||||
func = skip_if_package_missing("transformers")(func)
|
||||
return func
|
||||
|
||||
|
||||
class _FakeTokenizer:
|
||||
eos_token = "<|endoftext|>"
|
||||
pad_token = "<|endoftext|>"
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return {"input_ids": torch.zeros(1, 10, dtype=torch.long)}
|
||||
|
||||
|
||||
class _FakeAutoProcessor:
|
||||
def __init__(self) -> None:
|
||||
self.tokenizer = _FakeTokenizer()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs): # noqa: ARG003
|
||||
return cls()
|
||||
|
||||
def apply_chat_template(self, messages, **kwargs): # noqa: ARG002
|
||||
return "fake_prompt_text"
|
||||
|
||||
def __call__(self, text=None, images=None, videos=None, **kwargs): # noqa: ARG002
|
||||
seq_len = 10
|
||||
batch_size = len(text) if isinstance(text, list) else 1
|
||||
return {
|
||||
"input_ids": torch.randint(0, 100, (batch_size, seq_len)),
|
||||
"attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
|
||||
"pixel_values_videos": torch.zeros(batch_size, 1536, dtype=torch.float32),
|
||||
"video_grid_thw": torch.ones(batch_size, 3, dtype=torch.long),
|
||||
"mm_token_type_ids": torch.zeros(batch_size, seq_len, dtype=torch.long),
|
||||
}
|
||||
|
||||
|
||||
def _build_step(monkeypatch, **overrides):
|
||||
from lerobot.rewards.topreward import processor_topreward
|
||||
|
||||
monkeypatch.setattr(processor_topreward, "AutoProcessor", _FakeAutoProcessor)
|
||||
return processor_topreward.TOPRewardEncoderProcessorStep(**overrides)
|
||||
|
||||
|
||||
def _make_transition(observation: dict, complementary: dict | None = None) -> dict:
|
||||
transition: dict = {TransitionKey.OBSERVATION: observation}
|
||||
if complementary is not None:
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = complementary
|
||||
return transition
|
||||
|
||||
|
||||
@_skip_if_topreward_extras_missing
|
||||
def test_encoder_step_emits_input_ids_and_labels(monkeypatch):
|
||||
"""The processor must emit Qwen-VL tensors including ``input_ids`` and
|
||||
``labels`` under the ``observation.topreward.*`` namespace."""
|
||||
step = _build_step(monkeypatch)
|
||||
|
||||
frames_batch = torch.zeros(2, 4, 3, 8, 8)
|
||||
out = step(
|
||||
_make_transition(
|
||||
observation={"observation.images.top": frames_batch},
|
||||
complementary={"task": ["pick", "place"]},
|
||||
)
|
||||
)
|
||||
|
||||
obs_out = out[TransitionKey.OBSERVATION]
|
||||
for key in TOPREWARD_INPUT_KEYS:
|
||||
assert f"{TOPREWARD_FEATURE_PREFIX}{key}" in obs_out
|
||||
|
||||
input_ids = obs_out[f"{TOPREWARD_FEATURE_PREFIX}input_ids"]
|
||||
labels = obs_out[f"{TOPREWARD_FEATURE_PREFIX}labels"]
|
||||
assert labels.dtype == torch.long
|
||||
assert labels.shape == (2, 10)
|
||||
assert labels[:, :-1].eq(-100).all()
|
||||
assert labels[:, -1].equal(input_ids[:, -1])
|
||||
|
||||
|
||||
@_skip_if_topreward_extras_missing
|
||||
def test_encoder_step_get_config_roundtrips_user_fields(monkeypatch):
|
||||
step = _build_step(
|
||||
monkeypatch,
|
||||
vlm_name="Qwen/Qwen3-VL-8B-Instruct",
|
||||
image_key="observation.images.cam_top",
|
||||
task_key="task",
|
||||
default_task="do the thing",
|
||||
max_frames=8,
|
||||
fps=4.0,
|
||||
add_chat_template=True,
|
||||
max_length=2048,
|
||||
)
|
||||
|
||||
cfg = step.get_config()
|
||||
assert cfg["vlm_name"] == "Qwen/Qwen3-VL-8B-Instruct"
|
||||
assert cfg["image_key"] == "observation.images.cam_top"
|
||||
assert cfg["default_task"] == "do the thing"
|
||||
assert cfg["max_frames"] == 8
|
||||
assert cfg["fps"] == 4.0
|
||||
assert cfg["add_chat_template"] is True
|
||||
assert cfg["max_length"] == 2048
|
||||
|
||||
|
||||
@_skip_if_topreward_extras_missing
|
||||
def test_encoder_step_transform_features_is_identity(monkeypatch):
|
||||
step = _build_step(monkeypatch)
|
||||
features = {
|
||||
PipelineFeatureType.OBSERVATION: {
|
||||
"observation.images.top": PolicyFeature(shape=(3, 224, 224), type=FeatureType.VISUAL),
|
||||
}
|
||||
}
|
||||
assert step.transform_features(features) == features
|
||||
|
||||
|
||||
@_skip_if_topreward_extras_missing
|
||||
def test_encoder_step_rejects_missing_image_key(monkeypatch):
|
||||
step = _build_step(monkeypatch, image_key="observation.images.top")
|
||||
with pytest.raises(KeyError, match="image key"):
|
||||
step(_make_transition(observation={}, complementary={"task": "pick"}))
|
||||
Reference in New Issue
Block a user