fix(smolvla2): train on rendered language batches

Keep annotated language columns through collation, render batched recipe samples, and make SmolVLA2 text loss robust enough for distributed training on the steerable dataset.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-05 08:55:56 +00:00
parent 5f7c6ba61d
commit a1b8134ef1
9 changed files with 253 additions and 99 deletions

View File

@@ -175,9 +175,6 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
complementary_data.pop("language_persistent", None)
complementary_data.pop("language_events", None)
if "messages" in complementary_data:
messages = complementary_data["messages"]
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):

View File

@@ -51,6 +51,9 @@ class RenderMessagesStep(ProcessorStep):
if not persistent and not events:
return transition
if _is_batched_language(persistent) or _is_batched_language(events):
return self._call_batch(transition, complementary_data, persistent, events)
timestamp = complementary_data.get("timestamp")
if timestamp is None:
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
@@ -69,13 +72,64 @@ class RenderMessagesStep(ProcessorStep):
return None
new_transition = transition.copy()
new_complementary_data = dict(complementary_data)
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
new_complementary_data.pop(LANGUAGE_EVENTS, None)
new_complementary_data.update(rendered)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def _call_batch(
self,
transition: EnvTransition,
complementary_data: dict[str, Any],
persistent_batch: list,
events_batch: list,
) -> EnvTransition | None:
timestamp = complementary_data.get("timestamp")
if timestamp is None:
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
batch_size = max(len(persistent_batch), len(events_batch))
messages: list[list[dict[str, Any]]] = []
message_streams: list[list[str | None]] = []
target_message_indices: list[list[int]] = []
keep_indices: list[int] = []
for i in range(batch_size):
rendered = render_sample(
recipe=self.recipe,
persistent=persistent_batch[i] if i < len(persistent_batch) else [],
events=events_batch[i] if i < len(events_batch) else [],
t=_batch_value(timestamp, i),
sample_idx=int(_batch_value(complementary_data.get("index", 0), i)),
task=_batch_value(complementary_data.get("task"), i),
dataset_ctx=self.dataset_ctx,
)
if rendered is None:
continue
keep_indices.append(i)
messages.append(rendered["messages"])
message_streams.append(rendered["message_streams"])
target_message_indices.append(rendered["target_message_indices"])
if not messages:
return None
new_transition = (
_select_batch_indices(transition, keep_indices)
if len(keep_indices) != batch_size
else transition.copy()
)
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
new_complementary_data.pop(LANGUAGE_EVENTS, None)
new_complementary_data["messages"] = messages
new_complementary_data["message_streams"] = message_streams
new_complementary_data["target_message_indices"] = target_message_indices
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
@@ -90,3 +144,37 @@ def _scalar(value: Any) -> float | int:
if isinstance(value, list) and len(value) == 1:
return _scalar(value[0])
return value
def _is_batched_language(value: Any) -> bool:
return isinstance(value, list) and bool(value) and isinstance(value[0], list)
def _batch_value(value: Any, index: int) -> Any:
if value is None:
return None
if isinstance(value, list):
return value[index]
if hasattr(value, "ndim") and getattr(value, "ndim") > 0:
return _scalar(value[index])
return _scalar(value)
def _select_batch_indices(transition: EnvTransition, indices: list[int]) -> EnvTransition:
selected = transition.copy()
for key in (TransitionKey.OBSERVATION, TransitionKey.COMPLEMENTARY_DATA):
data = selected.get(key)
if isinstance(data, dict):
selected[key] = {k: _select_value(v, indices) for k, v in data.items()}
action = selected.get(TransitionKey.ACTION)
if action is not None:
selected[TransitionKey.ACTION] = _select_value(action, indices)
return selected
def _select_value(value: Any, indices: list[int]) -> Any:
if isinstance(value, list) and len(value) >= len(indices):
return [value[i] for i in indices]
if hasattr(value, "index_select") and hasattr(value, "new_tensor") and getattr(value, "ndim", 0) > 0:
return value.index_select(0, value.new_tensor(indices).long())
return value