From af6d8ebd5b9f5f2aaa26389f97dd4678cd18690b Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 30 Apr 2026 19:54:57 +0200 Subject: [PATCH] =?UTF-8?q?feat(smolvla2):=20dual-head=20forward=20?= =?UTF-8?q?=E2=80=94=20flow=20loss=20+=20lm=5Fhead=20text=20loss?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The third and final commit of PR 3's SmolVLA2 work. Wires the actual training signal through: * ``predict_actions[i] = True`` → sample i contributes to flow loss * ``text_labels[i, t] != -100`` → token t of sample i contributes to LM-head cross-entropy Both routing knobs come from ``SmolVLA2ChatTokenizerStep`` (previous commit on this branch), which builds them from the recipe's ``message_streams`` / ``target_message_indices``. The per-sample ``predict_actions`` mask preserves the Pi0.5 convention from the plan's Section I.7: "True iff any low_level target exists". Implementation: - ``forward`` reads ``text_labels`` and ``predict_actions`` from the batch. When neither is present (vanilla SmolVLA usage with no recipe), delegates to ``SmolVLAPolicy.forward`` so unannotated datasets keep training as before — full backward compatibility. - ``flow_loss``: super().forward(reduction="none") returns the per-sample (B,) flow loss; we mask non-action samples with the ``predict_actions`` bool and renormalize by the count of action samples. ``flow_loss_weight = 0`` in the config disables this branch entirely (text-only training). - ``text_loss``: a prefix-only forward through the VLM (no action expert / suffix), slicing the lang-token range out of the resulting hidden states (``embed_prefix`` orders the prefix as ``[image_blocks..., lang, state]`` so the slice is unambiguous). Apply ``vlm.lm_head`` to those hidden states, cross-entropy with ``text_labels`` (ignore_index=-100). ``text_loss_weight = 0`` disables this branch (reverts to flow-only behaviour, matching SmolVLA exactly). - The two losses are summed with the config-supplied weights. Mixed-stream samples (one batch containing both action targets and text-only sub-recipes) are handled correctly: each sample contributes where its labels are valid and is masked elsewhere. Limitations / known follow-ups: - Text loss runs an additional prefix-only forward separate from the flow path's prefix forward. The forwards could share their prefix computation; for clarity of this first commit they don't. Optimization is straightforward when needed. - Per-sample loss for ``reduction="none"`` is not yet meaningfully defined for the dual path — we broadcast the scalar to (B,) for caller compatibility (e.g. RA-BC weighting will need follow-up). - Inference ``select_action`` is unchanged from SmolVLA today — it predicts actions only. A separate "generate text" ``select_message`` path is the natural next step for runtime use of the LM head (memory updates, plan refreshes, VQA answers). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../policies/smolvla2/modeling_smolvla2.py | 221 ++++++++++++++---- 1 file changed, 173 insertions(+), 48 deletions(-) diff --git a/src/lerobot/policies/smolvla2/modeling_smolvla2.py b/src/lerobot/policies/smolvla2/modeling_smolvla2.py index 03961d988..7b288cefc 100644 --- a/src/lerobot/policies/smolvla2/modeling_smolvla2.py +++ b/src/lerobot/policies/smolvla2/modeling_smolvla2.py @@ -13,26 +13,25 @@ # limitations under the License. """SmolVLA2 modeling — dual-head subclass of SmolVLAPolicy. -This module defines :class:`SmolVLA2Policy`, which extends SmolVLA with: +Adds: * an unfrozen SmolVLM ``lm_head`` so language tokens can be supervised, -* a forward path that routes to the flow head, the text head, or both, - driven by ``batch["predict_actions"]`` and ``batch["text_labels"]``. +* a forward path that runs the flow head, the text head, or both, + driven by ``batch["predict_actions"]`` and ``batch["text_labels"]`` + produced by :class:`SmolVLA2ChatTokenizerStep` (the previous commit on + this branch). -The text-head computation itself is NOT wired up in this scaffold commit -(the processor doesn't yet produce ``text_labels`` either). This file is -the structural placeholder that: +Per-sample routing — within one batch: -1. registers the ``SmolVLA2Policy`` class with the right config name so - ``policies/factory.py`` can build it, -2. unfreezes ``lm_head`` at construction time when the config asks for it - (otherwise SmolVLA's ``train_expert_only`` freezes it again on every - ``train()`` call), -3. forwards to ``SmolVLAPolicy.forward`` so behaviour is identical to - SmolVLA when no text labels are present — i.e. existing SmolVLA - training scripts keep working. +* ``predict_actions[i] = True`` ⇒ sample ``i`` contributes to the flow + loss (action chunk supervision). +* ``predict_actions[i] = False`` ⇒ sample ``i`` is masked out of the + flow loss; only its text tokens (where ``text_labels[i, t] != -100``) + contribute to the LM-head cross-entropy. -The next commit on this branch fills in the actual text-loss path. +Falls back to ``SmolVLAPolicy.forward`` cleanly when neither +``text_labels`` nor ``predict_actions`` is in the batch — unannotated +datasets keep working unchanged. """ from __future__ import annotations @@ -40,33 +39,35 @@ from __future__ import annotations from typing import Any import torch +import torch.nn.functional as F from torch import Tensor -from ..smolvla.modeling_smolvla import SmolVLAPolicy +from lerobot.utils.constants import ( + ACTION, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + OBS_STATE, +) + +from ..smolvla.modeling_smolvla import SmolVLAPolicy, make_att_2d_masks from .configuration_smolvla2 import SmolVLA2Config class SmolVLA2Policy(SmolVLAPolicy): - """SmolVLA + re-enabled SmolVLM language head. - - Compatible drop-in for ``SmolVLAPolicy`` from a checkpoint or factory - perspective. Behaviourally identical to SmolVLA until the text-head - code path lands in the next commit on this branch. - """ + """SmolVLA + re-enabled SmolVLM language head.""" config_class = SmolVLA2Config name = "smolvla2" def __init__(self, config: SmolVLA2Config, dataset_stats: dict[str, dict[str, Tensor]] | None = None): if not isinstance(config, SmolVLA2Config): - # Allow loading a SmolVLA checkpoint into a SmolVLA2 model by - # widening the config type — the new fields fall back to their - # defaults, which preserves the existing SmolVLA behaviour. - config = SmolVLA2Config(**{ - f.name: getattr(config, f.name) - for f in config.__dataclass_fields__.values() - if hasattr(config, f.name) - }) + config = SmolVLA2Config( + **{ + f.name: getattr(config, f.name) + for f in config.__dataclass_fields__.values() + if hasattr(config, f.name) + } + ) super().__init__(config, dataset_stats=dataset_stats) if config.unfreeze_lm_head and config.text_loss_weight > 0: self._unfreeze_lm_head() @@ -76,13 +77,8 @@ class SmolVLA2Policy(SmolVLAPolicy): # ------------------------------------------------------------------ def _unfreeze_lm_head(self) -> None: - """Re-enable gradients on the SmolVLM ``lm_head`` (and the bits of - the text path SmolVLA freezes) so the text-loss can flow back. - - SmolVLA's ``SmolVLMWithExpertModel.set_requires_grad`` freezes - ``lm_head``, ``text_model.model.norm.weight``, and the last - ``text_model.layers.`` block. We undo that selectively when - text training is enabled. + """Re-enable gradients on the SmolVLM ``lm_head`` (and the bits + of the text path SmolVLA freezes) so the text-loss can flow back. """ vlm_with_expert = getattr(self.model, "vlm_with_expert", None) if vlm_with_expert is None: @@ -91,10 +87,7 @@ class SmolVLA2Policy(SmolVLAPolicy): if vlm is None: return for name, param in vlm.named_parameters(): - if ( - "lm_head" in name - or "text_model.model.norm.weight" in name - ): + if "lm_head" in name or "text_model.model.norm.weight" in name: param.requires_grad = True # ------------------------------------------------------------------ @@ -108,12 +101,144 @@ class SmolVLA2Policy(SmolVLAPolicy): time: Tensor | None = None, reduction: str = "mean", ) -> tuple[Tensor, dict[str, Any]]: - """Forward pass with optional text-head loss. + """Forward pass with optional dual-head loss. - SCAFFOLD: forwards directly to ``SmolVLAPolicy.forward``. The - actual text-loss / dual-head routing lands in the next commit on - this branch — it will read ``batch["text_labels"]`` and - ``batch["predict_actions"]`` (both produced by the SmolVLA2 - processor) to decide which head(s) to run. + Two routing knobs from the batch (produced by + :class:`SmolVLA2ChatTokenizerStep`): + + * ``text_labels`` — per-token labels with ``-100`` for non-target + positions. Triggers the text-loss path through ``lm_head``. + * ``predict_actions`` — per-sample bool tensor. ``True`` ⇒ + include this sample's action chunk in the flow loss. + + When neither is present, delegate to ``SmolVLAPolicy.forward``. """ - return super().forward(batch, noise=noise, time=time, reduction=reduction) + text_labels = batch.get("text_labels") + predict_actions_t = batch.get("predict_actions") + + has_text_data = ( + text_labels is not None + and isinstance(text_labels, Tensor) + and self.config.text_loss_weight > 0 + ) + has_per_sample_routing = ( + predict_actions_t is not None and isinstance(predict_actions_t, Tensor) + ) + + if not has_text_data and not has_per_sample_routing: + return super().forward(batch, noise=noise, time=time, reduction=reduction) + + loss_dict: dict[str, Any] = {} + device = batch[OBS_STATE].device + total = torch.zeros((), device=device, dtype=torch.float32) + + # ------------------------------------------------------------ + # Flow loss path — only when at least one sample wants actions. + # ------------------------------------------------------------ + run_flow = self.config.flow_loss_weight > 0 and ( + not has_per_sample_routing or bool(predict_actions_t.any().item()) + ) + if run_flow and ACTION in batch: + per_sample_flow, flow_diag = super().forward( + batch, noise=noise, time=time, reduction="none" + ) + # ``per_sample_flow`` has shape (B,) from the SmolVLA + # reduction="none" branch. + if has_per_sample_routing: + mask = predict_actions_t.to(per_sample_flow.dtype) + masked = per_sample_flow * mask + denom = mask.sum().clamp(min=1.0) + flow_loss = masked.sum() / denom + else: + flow_loss = per_sample_flow.mean() + total = total + self.config.flow_loss_weight * flow_loss + loss_dict["flow_loss"] = float(flow_loss.detach().item()) + for k, v in flow_diag.items(): + loss_dict[f"flow_{k}"] = v + + # ------------------------------------------------------------ + # Text loss path — prefix-only forward → lm_head → CE. + # ------------------------------------------------------------ + if has_text_data: + text_loss = self._compute_text_loss(batch, text_labels) + total = total + self.config.text_loss_weight * text_loss + loss_dict["text_loss"] = float(text_loss.detach().item()) + + loss_dict["loss"] = float(total.detach().item()) + + if reduction == "none": + # Per-sample loss isn't meaningfully defined for the dual + # path; broadcast the scalar to (B,) for caller compat. + return total.expand(batch[OBS_STATE].shape[0]), loss_dict + return total, loss_dict + + # ------------------------------------------------------------------ + # Text-loss internals + # ------------------------------------------------------------------ + + def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor: + """Cross-entropy on the SmolVLM ``lm_head`` over target tokens.""" + if self.config.adapt_to_pi_aloha: + batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) + + images, img_masks = self.prepare_images(batch) + state = self.prepare_state(batch) + lang_tokens = batch[OBS_LANGUAGE_TOKENS] + lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK] + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.model.embed_prefix( + images, img_masks, lang_tokens, lang_masks, state=state + ) + prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + + # Prefix-only forward. + out_pair, _ = self.model.vlm_with_expert.forward( + attention_mask=prefix_att_2d_masks, + position_ids=prefix_position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + use_cache=False, + fill_kv_cache=False, + ) + prefix_out = out_pair[0] if isinstance(out_pair, (tuple, list)) else out_pair + if prefix_out is None: + raise RuntimeError( + "SmolVLA2: vlm_with_expert.forward returned no prefix hidden " + "states — text-loss path needs them." + ) + + # Lang token positions inside the prefix. ``embed_prefix`` lays + # out the prefix as ``[image_blocks..., lang, state]`` so the + # lang range is identifiable from the trailing state size and + # the known lang length. + num_lang = lang_tokens.shape[1] + state_for_dim = state if state.ndim >= 2 else state[:, None] + num_state = state_for_dim.shape[1] if state_for_dim.ndim >= 2 else 1 + if num_state < 1: + num_state = 1 + prefix_len = prefix_out.shape[1] + lang_end = prefix_len - num_state + lang_start = lang_end - num_lang + if lang_start < 0 or lang_end > prefix_len: + raise RuntimeError( + f"SmolVLA2: could not locate lang token range in prefix " + f"(prefix_len={prefix_len}, num_lang={num_lang}, " + f"num_state={num_state})." + ) + + lang_hidden = prefix_out[:, lang_start:lang_end] + vlm = self.model.vlm_with_expert.vlm + logits = vlm.lm_head(lang_hidden) # (B, num_lang, vocab) + + if text_labels.shape[1] != num_lang: + common = min(text_labels.shape[1], num_lang) + logits = logits[:, :common] + text_labels = text_labels[:, :common] + + loss = F.cross_entropy( + logits.reshape(-1, logits.shape[-1]), + text_labels.reshape(-1).long(), + ignore_index=-100, + ) + return loss