mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
Resolves conflicts from 66 commits on the base branch: * pyproject.toml — keep base's transformers>=5.4.0,<5.6.0; add the sentencepiece-dep entry pi052 (FAST action tokenizer) needs. * policies/__init__.py — keep pi052 export; drop the RewardClassifierConfig export that base removed. * policies/factory.py — docstring list resolution (keep pi052; drop reward_classifier, removed by base). * annotations/steerable_pipeline/executor.py — adopt base's renamed _ensure_annotation_metadata_in_info (it already advertises the say tool); drop pi052's older _ensure_tools_in_info call. * configs/train.py — keep pi052's vqa_target_fraction; adopt base's SampleWeightingConfig (legacy RA-BC inline params already covered by the migration shim base added). * scripts/lerobot_train.py — merge pi052's per-policy processor rebuild + dataset_repo_id pass-through with base's active_cfg / is_reward_model_training tightening, and re-route vqa-weighted sampler to active_cfg.drop_n_last_frames. * datasets/language_render.py — adopt base's _select_one + timestamp tolerance (drops pi052's stale _select_latest / per-style sort_key). * tests — adopt base's parametrized per-camera blend + tolerance test; drop pi052 tests that overlap with base's tighter rewrites; keep pi052's flow-only / VQA-blend coverage; add a test_canonical_recipe_loads check on subtask_mem_vqa_speech.yaml. * policies/pi052/processor_pi052.py — import RenderMessagesStep directly from render_messages_processor (base intentionally dropped it from lerobot.processor's re-exports). * uv.lock — regenerated cleanly from base + pi052's pocket-tts / beartype. All 67 touched tests pass (30 pi052 + 37 recipe / language-render / pipeline / render-messages). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
66 lines
2.5 KiB
Python
66 lines
2.5 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
from torch.utils.data._utils.collate import default_collate
|
|
|
|
from lerobot.datasets.language import LANGUAGE_COLUMNS
|
|
|
|
# Sample keys that the collate function keeps as plain per-sample Python lists
# instead of handing them to ``default_collate``: the rendered-message fields
# plus every column named in ``LANGUAGE_COLUMNS``.
_PYTHON_LIST_KEYS = {
    "messages",
    "message_streams",
    "target_message_indices",
    *LANGUAGE_COLUMNS,
}
|
|
|
|
|
|
def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None:
    """Collate samples, keeping rendered-message and language fields as Python lists.

    ``None`` samples (e.g. recipes that yielded no target message) are dropped
    first. Keys listed in ``_PYTHON_LIST_KEYS`` are gathered into plain
    per-sample lists; every other key is delegated to PyTorch's
    ``default_collate``.

    Args:
        batch: Per-sample dicts; individual entries may be ``None``.

    Returns:
        The collated batch with the preserved list-valued keys merged in, or
        ``None`` when no non-``None`` sample remains.

    Raises:
        ValueError: If a preserved key is present in some, but not all, of the
            remaining samples.
    """
    samples = [sample for sample in batch if sample is not None]
    if not samples:
        return None

    # All-or-nothing per key: a partial-presence batch (e.g. half the samples
    # carry `messages` and half don't) is a real bug in the upstream
    # rendering step — silently filtering would hand downstream consumers a
    # preserved list shorter than the tensor batch. Raise instead so the
    # mismatch surfaces at the boundary.
    preserved: dict[str, list[Any]] = {}
    for key in _PYTHON_LIST_KEYS:
        presence = [key in sample for sample in samples]
        if not any(presence):
            continue
        if not all(presence):
            raise ValueError(
                f"Inconsistent batch: {sum(presence)}/{len(samples)} samples carry {key!r}; "
                "every sample in a batch must agree."
            )
        preserved[key] = [sample[key] for sample in samples]

    # `_PYTHON_LIST_KEYS` already includes every entry of `LANGUAGE_COLUMNS`,
    # so a single membership test is enough to strip the list-valued fields.
    tensorizable = [
        {key: value for key, value in sample.items() if key not in _PYTHON_LIST_KEYS}
        for sample in samples
    ]
    collated = default_collate(tensorizable)
    # Re-attach the preserved fields untouched, one list entry per sample.
    collated.update(preserved)
    return collated
|