From 0fb5f049657ffadd8452c30a5aa9ac26e0af8b68 Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Tue, 5 May 2026 11:59:57 +0200
Subject: [PATCH] fix(smolvla2): handle BatchEncoding return from apply_chat_template
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

``tokenizer.apply_chat_template(..., tokenize=True, return_tensors='pt')``
on newer transformers returns a ``BatchEncoding`` (dict-like) rather
than a raw ``Tensor`` — particularly when the underlying call routes
through a processor. ``_build_text_batch`` only handled the ``Tensor``
and ``list`` shapes, so the encoding object reached SmolVLA's
``embed_language_tokens`` and ``F.embedding`` blew up with
``argument 'indices' must be Tensor, not BatchEncoding`` on every
high-level forward.

Normalise the return:

* ``BatchEncoding`` / ``dict`` → take ``input_ids`` (and the encoder's
  ``attention_mask`` when present, since ``pad_token_id`` can be
  ``None`` for SmolVLM and the fall-back ``ids != pad_token_id``
  breaks then),
* ``list[int]`` / ``list[list[int]]`` → wrap in a long tensor,
* ``Tensor`` → keep as-is.

After unwrapping, ensure shape ``(1, seq)`` and that
``attention_mask`` is a tensor on the same device as ``ids``.

Co-Authored-By: Claude Opus 4.7 (1M context) --- .../policies/smolvla2/inference/steps.py | 35 +++++++++++++++++-- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py index b0951c789..b6be411ef 100644 --- a/src/lerobot/policies/smolvla2/inference/steps.py +++ b/src/lerobot/policies/smolvla2/inference/steps.py @@ -171,17 +171,46 @@ def _build_text_batch(policy: Any, prompt_messages: list[dict[str, Any]]) -> dic tokenizer.pad_token = tokenizer.eos_token text_messages = [_strip_recipe_keys(m) for m in prompt_messages] - ids = tokenizer.apply_chat_template( + encoded = tokenizer.apply_chat_template( text_messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", ) + # ``apply_chat_template`` can return any of: + # - a Tensor of shape ``(seq,)`` or ``(1, seq)`` (older transformers), + # - a list[int] / list[list[int]] (when ``return_tensors`` is ignored), + # - a ``BatchEncoding`` dict-like with ``input_ids`` / ``attention_mask`` + # (newer transformers, especially via processor.apply_chat_template + # forwarding through here). + # Normalise to ``ids: Tensor[1, seq]`` and grab the encoder's + # attention mask when available so we don't have to re-derive it + # from ``pad_token_id`` (which can be ``None`` for SmolVLM). 
+ attn: Any = None + if hasattr(encoded, "input_ids"): + ids = encoded.input_ids + attn = getattr(encoded, "attention_mask", None) + elif isinstance(encoded, dict) and "input_ids" in encoded: + ids = encoded["input_ids"] + attn = encoded.get("attention_mask") + else: + ids = encoded if isinstance(ids, list): - ids = ids[0] if ids else [] + if ids and isinstance(ids[0], list): + ids = ids[0] + import torch # noqa: PLC0415 + + ids = torch.tensor(ids, dtype=torch.long) if hasattr(ids, "ndim") and ids.ndim == 1: ids = ids.unsqueeze(0) - attn = (ids != tokenizer.pad_token_id) if tokenizer.pad_token_id is not None else None + if attn is None and tokenizer.pad_token_id is not None: + attn = ids != tokenizer.pad_token_id + elif isinstance(attn, list): + import torch # noqa: PLC0415 + + attn = torch.tensor(attn, dtype=torch.long) + if attn.ndim == 1: + attn = attn.unsqueeze(0) # Move tokens onto the policy's device — otherwise prefix embedding # raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA # model), which the caller's broad except would swallow silently.