Add extensive language support

This commit is contained in:
Pepijn
2026-04-27 10:56:32 +02:00
parent ba27aab79c
commit 8833d735a1
29 changed files with 1445 additions and 517 deletions

View File

@@ -93,6 +93,7 @@ from .relative_action_processor import (
to_relative_actions,
)
from .rename_processor import RenameObservationsProcessorStep, rename_stats
from .render_messages_processor import RenderMessagesStep
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
__all__ = [
@@ -128,6 +129,7 @@ __all__ = [
"make_default_robot_observation_processor",
"AbsoluteActionsProcessorStep",
"RelativeActionsProcessorStep",
"RenderMessagesStep",
"MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep",
"NewLineTaskProcessorStep",

View File

@@ -174,6 +174,24 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
task_index_value = complementary_data["task_index"]
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
complementary_data["task_index"] = task_index_value.unsqueeze(0)
complementary_data.pop("language_persistent", None)
complementary_data.pop("language_events", None)
if "messages" in complementary_data:
messages = complementary_data["messages"]
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
complementary_data["messages"] = [messages]
if "message_streams" in complementary_data:
streams = complementary_data["message_streams"]
if isinstance(streams, list) and (not streams or isinstance(streams[0], str)):
complementary_data["message_streams"] = [streams]
if "target_message_indices" in complementary_data:
indices = complementary_data["target_message_indices"]
if isinstance(indices, list) and (not indices or isinstance(indices[0], int)):
complementary_data["target_message_indices"] = [indices]
return complementary_data
def transform_features(

View File

@@ -171,8 +171,33 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
timestamp_key = {"timestamp": batch["timestamp"]} if "timestamp" in batch else {}
language_persistent_key = (
{"language_persistent": batch["language_persistent"]} if "language_persistent" in batch else {}
)
language_events_key = {"language_events": batch["language_events"]} if "language_events" in batch else {}
messages_key = {"messages": batch["messages"]} if "messages" in batch else {}
message_streams_key = {"message_streams": batch["message_streams"]} if "message_streams" in batch else {}
target_message_indices_key = (
{"target_message_indices": batch["target_message_indices"]}
if "target_message_indices" in batch
else {}
)
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
return {
**pad_keys,
**task_key,
**subtask_key,
**index_key,
**task_index_key,
**episode_index_key,
**timestamp_key,
**language_persistent_key,
**language_events_key,
**messages_key,
**message_streams_key,
**target_message_indices_key,
}
def create_transition(

View File

@@ -0,0 +1,81 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from lerobot.configs import PipelineFeatureType, PolicyFeature
from lerobot.configs.recipe import TrainingRecipe
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT
from lerobot.datasets.language_render import render_sample
from lerobot.types import EnvTransition, TransitionKey
from .pipeline import ProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="render_messages_processor")
class RenderMessagesStep(ProcessorStep):
recipe: TrainingRecipe
dataset_ctx: Any | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
persistent = complementary_data.get(LANGUAGE_PERSISTENT) or []
events = complementary_data.get(LANGUAGE_EVENTS) or []
if not persistent and not events:
return transition
timestamp = complementary_data.get("timestamp")
if timestamp is None:
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
sample_idx = complementary_data.get("index", 0)
rendered = render_sample(
recipe=self.recipe,
persistent=persistent,
events=events,
t=_scalar(timestamp),
sample_idx=int(_scalar(sample_idx)),
task=complementary_data.get("task"),
dataset_ctx=self.dataset_ctx,
)
if rendered is None:
return None
new_transition = transition.copy()
new_complementary_data = dict(complementary_data)
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
new_complementary_data.pop(LANGUAGE_EVENTS, None)
new_complementary_data.update(rendered)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
def _scalar(value: Any) -> float | int:
if hasattr(value, "item"):
return value.item()
if isinstance(value, list) and len(value) == 1:
return _scalar(value[0])
return value