diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py index b0951c789..b6be411ef 100644 --- a/src/lerobot/policies/smolvla2/inference/steps.py +++ b/src/lerobot/policies/smolvla2/inference/steps.py @@ -171,17 +171,46 @@ def _build_text_batch(policy: Any, prompt_messages: list[dict[str, Any]]) -> dic tokenizer.pad_token = tokenizer.eos_token text_messages = [_strip_recipe_keys(m) for m in prompt_messages] - ids = tokenizer.apply_chat_template( + encoded = tokenizer.apply_chat_template( text_messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", ) + # ``apply_chat_template`` can return any of: + # - a Tensor of shape ``(seq,)`` or ``(1, seq)`` (older transformers), + # - a list[int] / list[list[int]] (when ``return_tensors`` is ignored), + # - a ``BatchEncoding`` dict-like with ``input_ids`` / ``attention_mask`` + # (newer transformers, especially via processor.apply_chat_template + # forwarding through here). + # Normalise to ``ids: Tensor[1, seq]`` and grab the encoder's + # attention mask when available so we don't have to re-derive it + # from ``pad_token_id`` (which can be ``None`` for SmolVLM). + attn: Any = None + if hasattr(encoded, "input_ids"): + ids = encoded.input_ids + attn = getattr(encoded, "attention_mask", None) + elif isinstance(encoded, dict) and "input_ids" in encoded: + ids = encoded["input_ids"] + attn = encoded.get("attention_mask") + else: + ids = encoded if isinstance(ids, list): - ids = ids[0] if ids else [] + if ids and isinstance(ids[0], list): + ids = ids[0] + import torch # noqa: PLC0415 + + ids = torch.tensor(ids, dtype=torch.long) if hasattr(ids, "ndim") and ids.ndim == 1: ids = ids.unsqueeze(0) - attn = (ids != tokenizer.pad_token_id) if tokenizer.pad_token_id is not None else None + if attn is None and tokenizer.pad_token_id is not None: + attn = ids != tokenizer.pad_token_id + elif isinstance(attn, list): + import torch # noqa: PLC0415 + + attn = torch.tensor(attn, dtype=torch.long) + if attn.ndim == 1: + attn = attn.unsqueeze(0) # Move tokens onto the policy's device — otherwise prefix embedding # raises a device-mismatch on every forward (CPU tensor vs MPS / CUDA # model), which the caller's broad except would swallow silently.