mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
Add extensive language support
This commit is contained in:
@@ -23,6 +23,7 @@ Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
|
||||
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .recipe import MessageTurn, TrainingRecipe, load_recipe
|
||||
from .types import (
|
||||
FeatureType,
|
||||
NormalizationMode,
|
||||
@@ -41,7 +42,10 @@ __all__ = [
|
||||
# Config classes
|
||||
"DatasetConfig",
|
||||
"EvalConfig",
|
||||
"MessageTurn",
|
||||
"PeftConfig",
|
||||
"PreTrainedConfig",
|
||||
"TrainingRecipe",
|
||||
"WandBConfig",
|
||||
"load_recipe",
|
||||
]
|
||||
|
||||
167
src/lerobot/configs/recipe.py
Normal file
167
src/lerobot/configs/recipe.py
Normal file
@@ -0,0 +1,167 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
MessageRole = Literal["user", "assistant", "system", "tool"]
|
||||
MessageStream = Literal["high_level", "low_level"]
|
||||
|
||||
DEFAULT_BINDINGS = {
|
||||
"subtask": "active_at(t, style=subtask)",
|
||||
"memory": "active_at(t, style=memory)",
|
||||
"plan": "active_at(t, style=plan)",
|
||||
"speech": "emitted_at(t, role=assistant, tool_name=say)",
|
||||
"interjection": "emitted_at(t, style=interjection)",
|
||||
"vqa": "emitted_at(t, style=vqa, role=assistant)",
|
||||
"vqa_query": "emitted_at(t, style=vqa, role=user)",
|
||||
}
|
||||
|
||||
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
_VALID_ROLES = {"user", "assistant", "system", "tool"}
|
||||
_VALID_STREAMS = {"high_level", "low_level"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageTurn:
|
||||
role: MessageRole
|
||||
content: str | list[dict[str, Any]] | None = None
|
||||
stream: MessageStream | None = None
|
||||
target: bool = False
|
||||
if_present: str | None = None
|
||||
tool_calls_from: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.role not in _VALID_ROLES:
|
||||
raise ValueError(f"Unsupported message role: {self.role!r}")
|
||||
if self.stream is not None and self.stream not in _VALID_STREAMS:
|
||||
raise ValueError(f"Unsupported message stream: {self.stream!r}")
|
||||
if self.content is None and self.tool_calls_from is None:
|
||||
raise ValueError("MessageTurn.content is required unless tool_calls_from is set.")
|
||||
if self.content is not None and not isinstance(self.content, (str, list)):
|
||||
raise TypeError("MessageTurn.content must be a string, a list of HF-style blocks, or None.")
|
||||
if isinstance(self.content, list):
|
||||
for block in self.content:
|
||||
if not isinstance(block, dict) or "type" not in block:
|
||||
raise ValueError(
|
||||
"Multimodal content blocks must be HF-style dictionaries with a type key."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> MessageTurn:
|
||||
return cls(**data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingRecipe:
|
||||
messages: list[MessageTurn] | None = None
|
||||
bindings: dict[str, str] | None = None
|
||||
blend: dict[str, TrainingRecipe] | None = None
|
||||
weight: float | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.messages is not None and self.blend is not None:
|
||||
raise ValueError("TrainingRecipe must set only one of messages or blend.")
|
||||
if self.messages is None and self.blend is None:
|
||||
raise ValueError("TrainingRecipe must set one of messages or blend.")
|
||||
|
||||
if self.messages is not None:
|
||||
self._validate_message_recipe()
|
||||
if self.blend is not None:
|
||||
self._validate_blend_recipe()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> TrainingRecipe:
|
||||
data = dict(data)
|
||||
if data.get("messages") is not None:
|
||||
data["messages"] = [
|
||||
turn if isinstance(turn, MessageTurn) else MessageTurn.from_dict(turn)
|
||||
for turn in data["messages"]
|
||||
]
|
||||
if data.get("blend") is not None:
|
||||
data["blend"] = {
|
||||
name: recipe if isinstance(recipe, TrainingRecipe) else cls.from_dict(recipe)
|
||||
for name, recipe in data["blend"].items()
|
||||
}
|
||||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str | Path) -> TrainingRecipe:
|
||||
import yaml # type: ignore[import-untyped]
|
||||
|
||||
with open(path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Recipe YAML must contain a mapping at the top level: {path}")
|
||||
return cls.from_dict(data)
|
||||
|
||||
def _validate_message_recipe(self) -> None:
|
||||
assert self.messages is not None
|
||||
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
|
||||
|
||||
for turn in self.messages:
|
||||
missing = self._referenced_bindings(turn) - known_bindings
|
||||
if missing:
|
||||
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
|
||||
|
||||
if not any(turn.target for turn in self.messages):
|
||||
raise ValueError("Message recipes must contain at least one target turn.")
|
||||
|
||||
def _validate_blend_recipe(self) -> None:
|
||||
assert self.blend is not None
|
||||
if not self.blend:
|
||||
raise ValueError("Blend recipes must contain at least one component.")
|
||||
|
||||
for name, recipe in self.blend.items():
|
||||
if recipe.blend is not None:
|
||||
raise ValueError(f"Blend component {name!r} cannot itself define a blend.")
|
||||
if recipe.messages is None:
|
||||
raise ValueError(f"Blend component {name!r} must define messages.")
|
||||
if recipe.weight is None:
|
||||
raise ValueError(f"Blend component {name!r} must define weight.")
|
||||
if recipe.weight <= 0:
|
||||
raise ValueError(f"Blend component {name!r} must have a positive weight.")
|
||||
|
||||
def _referenced_bindings(self, turn: MessageTurn) -> set[str]:
|
||||
names: set[str] = set()
|
||||
if turn.if_present is not None:
|
||||
names.add(turn.if_present)
|
||||
if turn.tool_calls_from is not None:
|
||||
names.add(turn.tool_calls_from)
|
||||
names.update(_placeholders_in_content(turn.content))
|
||||
return names
|
||||
|
||||
|
||||
def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[str]:
|
||||
if content is None:
|
||||
return set()
|
||||
if isinstance(content, str):
|
||||
return set(_PLACEHOLDER_RE.findall(content))
|
||||
|
||||
names: set[str] = set()
|
||||
for block in content:
|
||||
for value in block.values():
|
||||
if isinstance(value, str):
|
||||
names.update(_PLACEHOLDER_RE.findall(value))
|
||||
return names
|
||||
|
||||
|
||||
def load_recipe(path: str | Path) -> TrainingRecipe:
|
||||
return TrainingRecipe.from_yaml(path)
|
||||
47
src/lerobot/configs/recipes/pi05_hirobot.yaml
Normal file
47
src/lerobot/configs/recipes/pi05_hirobot.yaml
Normal file
@@ -0,0 +1,47 @@
|
||||
blend:
|
||||
|
||||
memory_update:
|
||||
weight: 0.10
|
||||
bindings:
|
||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
||||
current_memory: "emitted_at(t, style=memory)"
|
||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
||||
|
||||
user_interjection_response:
|
||||
weight: 0.16
|
||||
bindings:
|
||||
prior_plan: "nth_prev(style=plan, offset=1)"
|
||||
current_plan: "emitted_at(t, style=plan)"
|
||||
interjection: "emitted_at(t, style=interjection)"
|
||||
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
||||
messages:
|
||||
- {role: user, content: "${task}", stream: high_level}
|
||||
- {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan}
|
||||
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
|
||||
- {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech}
|
||||
|
||||
high_level_subtask:
|
||||
weight: 0.15
|
||||
bindings:
|
||||
next_subtask: "nth_next(style=subtask, offset=1)"
|
||||
messages:
|
||||
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
|
||||
- {role: user, content: "Current subtask: ${subtask}", stream: high_level, if_present: subtask}
|
||||
- {role: assistant, content: "${next_subtask}", stream: high_level, target: true}
|
||||
|
||||
low_level_execution:
|
||||
weight: 0.35
|
||||
messages:
|
||||
- {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
|
||||
- {role: assistant, content: "${subtask}", stream: low_level, target: true}
|
||||
|
||||
ask_vqa:
|
||||
weight: 0.20
|
||||
messages:
|
||||
- {role: user, content: "${vqa_query}", stream: high_level, if_present: vqa_query}
|
||||
- {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
|
||||
@@ -37,6 +37,14 @@ from .dataset_tools import (
|
||||
from .factory import make_dataset, resolve_delta_timestamps
|
||||
from .image_writer import safe_stop_image_writer
|
||||
from .io_utils import load_episodes, write_stats
|
||||
from .language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
STYLE_REGISTRY,
|
||||
column_for_style,
|
||||
)
|
||||
from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
@@ -53,10 +61,15 @@ __all__ = [
|
||||
"CODEBASE_VERSION",
|
||||
"DEFAULT_EPISODES_PATH",
|
||||
"DEFAULT_QUANTILES",
|
||||
"EVENT_ONLY_STYLES",
|
||||
"EpisodeAwareSampler",
|
||||
"LANGUAGE_EVENTS",
|
||||
"LANGUAGE_PERSISTENT",
|
||||
"LeRobotDataset",
|
||||
"LeRobotDatasetMetadata",
|
||||
"MultiLeRobotDataset",
|
||||
"PERSISTENT_STYLES",
|
||||
"STYLE_REGISTRY",
|
||||
"StreamingLeRobotDataset",
|
||||
"VideoEncodingManager",
|
||||
"add_features",
|
||||
@@ -66,6 +79,7 @@ __all__ = [
|
||||
"convert_image_to_video_dataset",
|
||||
"create_initial_features",
|
||||
"create_lerobot_dataset_card",
|
||||
"column_for_style",
|
||||
"delete_episodes",
|
||||
"get_feature_stats",
|
||||
"load_episodes",
|
||||
|
||||
@@ -512,7 +512,7 @@ def compute_episode_stats(
|
||||
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
if features[key]["dtype"] in {"string", "language"}:
|
||||
continue
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
|
||||
@@ -34,7 +34,6 @@ from .io_utils import (
|
||||
load_episodes,
|
||||
load_info,
|
||||
load_stats,
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
write_info,
|
||||
write_json,
|
||||
@@ -52,7 +51,7 @@ from .utils import (
|
||||
)
|
||||
from .video_utils import get_video_info
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
CODEBASE_VERSION = "v3.1"
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
@@ -177,7 +176,6 @@ class LeRobotDatasetMetadata:
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.subtasks = load_subtasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
@@ -635,7 +633,6 @@ class LeRobotDatasetMetadata:
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
obj.subtasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(
|
||||
|
||||
@@ -295,9 +295,4 @@ class DatasetReader:
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self._meta.tasks.iloc[task_idx].name
|
||||
|
||||
# add subtask information if available
|
||||
if "subtask_index" in self._meta.features and self._meta.subtasks is not None:
|
||||
subtask_idx = item["subtask_index"].item()
|
||||
item["subtask"] = self._meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
return item
|
||||
|
||||
@@ -22,6 +22,7 @@ from PIL import Image as PILImage
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||
|
||||
from .language import is_language_column, language_column_feature
|
||||
from .utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
@@ -45,7 +46,9 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
"""
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
if is_language_column(key):
|
||||
hf_features[key] = language_column_feature()
|
||||
elif ft["dtype"] == "video":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
@@ -242,6 +245,8 @@ def validate_feature_dtype_and_shape(
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
elif expected_dtype == "language":
|
||||
return ""
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
|
||||
@@ -34,7 +34,6 @@ from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_di
|
||||
from .utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_SUBTASKS_PATH,
|
||||
DEFAULT_TASKS_PATH,
|
||||
EPISODES_DIR,
|
||||
INFO_PATH,
|
||||
@@ -189,14 +188,6 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
return tasks
|
||||
|
||||
|
||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load subtasks from subtasks.parquet if it exists."""
|
||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
return pd.read_parquet(subtasks_path)
|
||||
return None
|
||||
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||
This function writes episode-level metadata to a single parquet file.
|
||||
@@ -268,11 +259,13 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
||||
dict: The batch with items converted to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
if key in {"language_persistent", "language_events"}:
|
||||
continue
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
elif first_item is None:
|
||||
elif first_item is None or isinstance(first_item, dict):
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
|
||||
@@ -308,7 +301,11 @@ def item_to_torch(item: dict) -> dict:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
|
||||
if isinstance(val, (np.ndarray | list)) and key not in [
|
||||
"task",
|
||||
"language_persistent",
|
||||
"language_events",
|
||||
]:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
|
||||
96
src/lerobot/datasets/language.py
Normal file
96
src/lerobot/datasets/language.py
Normal file
@@ -0,0 +1,96 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import datasets
|
||||
import pyarrow as pa
|
||||
|
||||
LANGUAGE_PERSISTENT = "language_persistent"
|
||||
LANGUAGE_EVENTS = "language_events"
|
||||
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
|
||||
LANGUAGE_ROW_FIELDS = ("role", "content", "style", "timestamp", "tool_calls")
|
||||
|
||||
CORE_STYLES = {"subtask", "plan", "memory", "interjection", "vqa"}
|
||||
EXTENDED_STYLES = set()
|
||||
RESERVED_STYLES = {"motion", "trace"}
|
||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES | RESERVED_STYLES
|
||||
|
||||
PERSISTENT_STYLES = {"subtask", "plan", "memory"}
|
||||
EVENT_ONLY_STYLES = {"interjection", "vqa"}
|
||||
|
||||
LanguageColumn = Literal["language_persistent", "language_events"]
|
||||
|
||||
|
||||
def language_row_arrow_type() -> pa.StructType:
|
||||
json_type = pa.json_() if hasattr(pa, "json_") else pa.string()
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("timestamp", pa.float64(), nullable=False),
|
||||
pa.field("tool_calls", pa.list_(json_type), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_persistent_arrow_type() -> pa.ListType:
|
||||
return pa.list_(language_row_arrow_type())
|
||||
|
||||
|
||||
def language_events_arrow_type() -> pa.ListType:
|
||||
return pa.list_(language_row_arrow_type())
|
||||
|
||||
|
||||
def language_row_feature() -> dict[str, object]:
|
||||
json_feature = datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"timestamp": datasets.Value("float64"),
|
||||
"tool_calls": datasets.List(json_feature),
|
||||
}
|
||||
|
||||
|
||||
def language_column_feature() -> datasets.List:
|
||||
return datasets.List(language_row_feature())
|
||||
|
||||
|
||||
def language_feature_info() -> dict[str, dict]:
|
||||
return {
|
||||
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
|
||||
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
|
||||
def is_language_column(key: str) -> bool:
|
||||
return key in LANGUAGE_COLUMNS
|
||||
|
||||
|
||||
def column_for_style(style: str | None) -> LanguageColumn:
|
||||
if style is None:
|
||||
return LANGUAGE_EVENTS
|
||||
if style in PERSISTENT_STYLES:
|
||||
return LANGUAGE_PERSISTENT
|
||||
if style in EVENT_ONLY_STYLES:
|
||||
return LANGUAGE_EVENTS
|
||||
if style in RESERVED_STYLES:
|
||||
raise ValueError(f"Style {style!r} is registered but has no storage column yet.")
|
||||
raise ValueError(f"Unknown language style: {style!r}")
|
||||
445
src/lerobot/datasets/language_render.py
Normal file
445
src/lerobot/datasets/language_render.py
Normal file
@@ -0,0 +1,445 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.recipe import DEFAULT_BINDINGS, TrainingRecipe
|
||||
|
||||
from .language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
column_for_style,
|
||||
)
|
||||
|
||||
LanguageRow = dict[str, Any]
|
||||
RenderedMessages = dict[str, list[Any]]
|
||||
|
||||
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
|
||||
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
|
||||
|
||||
def active_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow] | None = None,
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
_validate_persistent_resolver("active_at", style)
|
||||
matches = _matching_rows(persistent, style=style, role=role, tool_name=tool_name)
|
||||
matches = [row for row in matches if _timestamp(row) <= t]
|
||||
return _select_latest(matches, style=style, role=role, tool_name=tool_name)
|
||||
|
||||
|
||||
def emitted_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
column = column_for_style(style)
|
||||
rows = persistent if column == LANGUAGE_PERSISTENT else events
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(rows, style=style, role=role, tool_name=tool_name)
|
||||
if _timestamp(row) == t
|
||||
]
|
||||
return _select_exact(matches, style=style, role=role, tool_name=tool_name)
|
||||
|
||||
|
||||
def nth_prev(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow] | None = None,
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
return _nth_relative(
|
||||
t,
|
||||
persistent=persistent,
|
||||
style=style,
|
||||
offset=-offset,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
resolver_name="nth_prev",
|
||||
)
|
||||
|
||||
|
||||
def nth_next(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow] | None = None,
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
return _nth_relative(
|
||||
t,
|
||||
persistent=persistent,
|
||||
style=style,
|
||||
offset=offset,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
resolver_name="nth_next",
|
||||
)
|
||||
|
||||
|
||||
def render_sample(
|
||||
*,
|
||||
recipe: TrainingRecipe,
|
||||
persistent: Sequence[LanguageRow] | None,
|
||||
events: Sequence[LanguageRow] | None,
|
||||
t: float,
|
||||
sample_idx: int,
|
||||
task: str | None = None,
|
||||
dataset_ctx: Any | None = None,
|
||||
) -> RenderedMessages | None:
|
||||
persistent_rows = _normalize_rows(persistent or [])
|
||||
event_rows = _normalize_rows(events or [])
|
||||
selected_recipe = _select_recipe(recipe, sample_idx)
|
||||
bindings = _resolve_bindings(
|
||||
selected_recipe,
|
||||
persistent=persistent_rows,
|
||||
events=event_rows,
|
||||
t=t,
|
||||
task=task,
|
||||
dataset_ctx=dataset_ctx,
|
||||
)
|
||||
return _render_message_recipe(selected_recipe, bindings)
|
||||
|
||||
|
||||
def _select_recipe(recipe: TrainingRecipe, sample_idx: int) -> TrainingRecipe:
|
||||
if recipe.blend is None:
|
||||
return recipe
|
||||
|
||||
total_weight = sum(component.weight or 0.0 for component in recipe.blend.values())
|
||||
if total_weight <= 0:
|
||||
raise ValueError("Blend weights must sum to a positive value.")
|
||||
|
||||
digest = hashlib.blake2b(str(sample_idx).encode(), digest_size=8).digest()
|
||||
draw = int.from_bytes(digest, "big") / 2**64 * total_weight
|
||||
cumulative = 0.0
|
||||
last_component: TrainingRecipe | None = None
|
||||
for component in recipe.blend.values():
|
||||
last_component = component
|
||||
cumulative += component.weight or 0.0
|
||||
if draw < cumulative:
|
||||
return component
|
||||
assert last_component is not None
|
||||
return last_component
|
||||
|
||||
|
||||
def _resolve_bindings(
|
||||
recipe: TrainingRecipe,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
task: str | None,
|
||||
dataset_ctx: Any | None,
|
||||
) -> dict[str, LanguageRow | str | None]:
|
||||
bindings: dict[str, LanguageRow | str | None] = {"task": _resolve_task(task, dataset_ctx)}
|
||||
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
||||
for name, spec in specs.items():
|
||||
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
|
||||
return bindings
|
||||
|
||||
|
||||
def _resolve_task(task: str | None, dataset_ctx: Any | None) -> str | None:
|
||||
if task is not None:
|
||||
return task
|
||||
if dataset_ctx is None:
|
||||
return None
|
||||
if isinstance(dataset_ctx, dict):
|
||||
return dataset_ctx.get("task")
|
||||
return getattr(dataset_ctx, "task", None)
|
||||
|
||||
|
||||
def _resolve_spec(
|
||||
spec: str,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow],
|
||||
t: float,
|
||||
) -> LanguageRow | None:
|
||||
match = _RESOLVER_RE.match(spec.strip())
|
||||
if match is None:
|
||||
raise ValueError(f"Invalid resolver expression: {spec!r}")
|
||||
name = match.group("name")
|
||||
kwargs = _parse_resolver_args(match.group("args"))
|
||||
kwargs.pop("t_arg", None)
|
||||
|
||||
resolvers = {
|
||||
"active_at": active_at,
|
||||
"emitted_at": emitted_at,
|
||||
"nth_prev": nth_prev,
|
||||
"nth_next": nth_next,
|
||||
}
|
||||
if name not in resolvers:
|
||||
raise ValueError(f"Unknown language resolver: {name!r}")
|
||||
return resolvers[name](t, persistent=persistent, events=events, **kwargs)
|
||||
|
||||
|
||||
def _parse_resolver_args(args: str) -> dict[str, Any]:
|
||||
kwargs: dict[str, Any] = {}
|
||||
if not args.strip():
|
||||
return kwargs
|
||||
|
||||
parts = [part.strip() for part in args.split(",") if part.strip()]
|
||||
for part in parts:
|
||||
if part == "t":
|
||||
kwargs["t_arg"] = True
|
||||
continue
|
||||
if "=" not in part:
|
||||
raise ValueError(f"Invalid resolver argument: {part!r}")
|
||||
key, value = (item.strip() for item in part.split("=", 1))
|
||||
if key == "offset":
|
||||
kwargs[key] = int(value)
|
||||
else:
|
||||
kwargs[key] = value.strip("\"'")
|
||||
return kwargs
|
||||
|
||||
|
||||
def _render_message_recipe(
|
||||
recipe: TrainingRecipe,
|
||||
bindings: dict[str, LanguageRow | str | None],
|
||||
) -> RenderedMessages | None:
|
||||
assert recipe.messages is not None
|
||||
messages: list[dict[str, Any]] = []
|
||||
streams: list[str | None] = []
|
||||
target_indices: list[int] = []
|
||||
|
||||
for turn in recipe.messages:
|
||||
if turn.if_present is not None and bindings.get(turn.if_present) is None:
|
||||
continue
|
||||
|
||||
message = {"role": turn.role}
|
||||
if turn.content is not None:
|
||||
message["content"] = _render_content(turn.content, bindings)
|
||||
|
||||
if turn.tool_calls_from is not None:
|
||||
row = bindings.get(turn.tool_calls_from)
|
||||
tool_calls = row.get("tool_calls") if isinstance(row, dict) else None
|
||||
if tool_calls:
|
||||
message["tool_calls"] = copy.deepcopy(tool_calls)
|
||||
|
||||
message_idx = len(messages)
|
||||
messages.append(message)
|
||||
streams.append(turn.stream)
|
||||
if turn.target:
|
||||
target_indices.append(message_idx)
|
||||
|
||||
if not target_indices:
|
||||
return None
|
||||
|
||||
rendered = {
|
||||
"messages": messages,
|
||||
"message_streams": streams,
|
||||
"target_message_indices": target_indices,
|
||||
}
|
||||
_validate_rendered(rendered)
|
||||
return rendered
|
||||
|
||||
|
||||
def _render_content(
|
||||
content: str | list[dict[str, Any]],
|
||||
bindings: dict[str, LanguageRow | str | None],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
if isinstance(content, str):
|
||||
return _substitute(content, bindings)
|
||||
|
||||
rendered_blocks = []
|
||||
for block in content:
|
||||
rendered_block = copy.deepcopy(block)
|
||||
for key, value in rendered_block.items():
|
||||
if isinstance(value, str):
|
||||
rendered_block[key] = _substitute(value, bindings)
|
||||
rendered_blocks.append(rendered_block)
|
||||
return rendered_blocks
|
||||
|
||||
|
||||
def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) -> str:
|
||||
def replace(match: re.Match[str]) -> str:
|
||||
name = match.group(1)
|
||||
if name not in bindings:
|
||||
raise ValueError(f"Unknown template binding: {name!r}")
|
||||
value = bindings[name]
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, dict):
|
||||
content = value.get("content")
|
||||
return "" if content is None else str(content)
|
||||
return str(value)
|
||||
|
||||
return _PLACEHOLDER_RE.sub(replace, template)
|
||||
|
||||
|
||||
def _validate_rendered(rendered: RenderedMessages) -> None:
|
||||
messages = rendered["messages"]
|
||||
streams = rendered["message_streams"]
|
||||
target_indices = rendered["target_message_indices"]
|
||||
|
||||
if len(streams) != len(messages):
|
||||
raise ValueError("message_streams must be aligned with messages.")
|
||||
if not target_indices:
|
||||
raise ValueError("Rendered samples must contain at least one target message.")
|
||||
for idx in target_indices:
|
||||
if idx < 0 or idx >= len(messages):
|
||||
raise ValueError(f"Target message index {idx} is out of bounds.")
|
||||
for idx, stream in enumerate(streams):
|
||||
if stream is None:
|
||||
raise ValueError(f"Rendered message {idx} has no stream.")
|
||||
|
||||
|
||||
def _nth_relative(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None,
|
||||
offset: int,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
resolver_name: str,
|
||||
) -> LanguageRow | None:
|
||||
_validate_persistent_resolver(resolver_name, style)
|
||||
if abs(offset) < 1:
|
||||
raise ValueError(f"{resolver_name} offset must be non-zero.")
|
||||
|
||||
rows = _sort_rows(_matching_rows(persistent, style=style, role=role, tool_name=tool_name))
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
anchor_idx = None
|
||||
for idx, row in enumerate(rows):
|
||||
if _timestamp(row) <= t:
|
||||
anchor_idx = idx
|
||||
else:
|
||||
break
|
||||
|
||||
target_idx = (offset - 1 if offset > 0 else None) if anchor_idx is None else anchor_idx + offset
|
||||
|
||||
if target_idx is None or target_idx < 0 or target_idx >= len(rows):
|
||||
return None
|
||||
return rows[target_idx]
|
||||
|
||||
|
||||
def _validate_persistent_resolver(resolver_name: str, style: str | None) -> None:
|
||||
if style is None:
|
||||
raise ValueError(f"{resolver_name} requires a persistent style.")
|
||||
if style in EVENT_ONLY_STYLES:
|
||||
raise ValueError(f"{resolver_name} cannot be used with event-only style {style!r}.")
|
||||
if style not in PERSISTENT_STYLES:
|
||||
column_for_style(style)
|
||||
|
||||
|
||||
def _matching_rows(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
) -> list[LanguageRow]:
|
||||
return [
|
||||
row
|
||||
for row in rows
|
||||
if (style is None or row.get("style") == style)
|
||||
and (role is None or row.get("role") == role)
|
||||
and (tool_name is None or _row_has_tool_name(row, tool_name))
|
||||
]
|
||||
|
||||
|
||||
def _select_latest(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
) -> LanguageRow | None:
|
||||
if not rows:
|
||||
return None
|
||||
rows = _sort_rows(rows)
|
||||
latest_ts = _timestamp(rows[-1])
|
||||
return _select_exact(
|
||||
[row for row in rows if _timestamp(row) == latest_ts],
|
||||
style=style,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
)
|
||||
|
||||
|
||||
def _select_exact(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
) -> LanguageRow | None:
|
||||
if not rows:
|
||||
return None
|
||||
if len(rows) > 1 and role is None and tool_name is None:
|
||||
raise ValueError(
|
||||
f"Ambiguous resolver for style={style!r}; add role=... or tool_name=... to disambiguate."
|
||||
)
|
||||
return _sort_rows(rows)[0]
|
||||
|
||||
|
||||
def _sort_rows(rows: Sequence[LanguageRow]) -> list[LanguageRow]:
|
||||
return sorted(rows, key=lambda row: (_timestamp(row), row.get("style") or "", row.get("role") or ""))
|
||||
|
||||
|
||||
def _timestamp(row: LanguageRow) -> float:
|
||||
value = row["timestamp"]
|
||||
return float(value.item() if hasattr(value, "item") else value)
|
||||
|
||||
|
||||
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
|
||||
for tool_call in row.get("tool_calls") or []:
|
||||
if isinstance(tool_call, str):
|
||||
continue
|
||||
function = tool_call.get("function") if isinstance(tool_call, dict) else None
|
||||
if isinstance(function, dict) and function.get("name") == tool_name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _normalize_rows(rows: Sequence[Any]) -> list[LanguageRow]:
|
||||
normalized = []
|
||||
for row in rows:
|
||||
if row is None:
|
||||
continue
|
||||
if hasattr(row, "as_py"):
|
||||
row = row.as_py()
|
||||
if not isinstance(row, dict):
|
||||
raise TypeError(f"Language rows must be dictionaries, got {type(row).__name__}.")
|
||||
normalized.append(dict(row))
|
||||
return normalized
|
||||
@@ -83,7 +83,6 @@ VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
|
||||
@@ -93,6 +93,7 @@ from .relative_action_processor import (
|
||||
to_relative_actions,
|
||||
)
|
||||
from .rename_processor import RenameObservationsProcessorStep, rename_stats
|
||||
from .render_messages_processor import RenderMessagesStep
|
||||
from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep
|
||||
|
||||
__all__ = [
|
||||
@@ -128,6 +129,7 @@ __all__ = [
|
||||
"make_default_robot_observation_processor",
|
||||
"AbsoluteActionsProcessorStep",
|
||||
"RelativeActionsProcessorStep",
|
||||
"RenderMessagesStep",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"NewLineTaskProcessorStep",
|
||||
|
||||
@@ -174,6 +174,24 @@ class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep):
|
||||
task_index_value = complementary_data["task_index"]
|
||||
if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0:
|
||||
complementary_data["task_index"] = task_index_value.unsqueeze(0)
|
||||
|
||||
complementary_data.pop("language_persistent", None)
|
||||
complementary_data.pop("language_events", None)
|
||||
|
||||
if "messages" in complementary_data:
|
||||
messages = complementary_data["messages"]
|
||||
if isinstance(messages, list) and (not messages or isinstance(messages[0], dict)):
|
||||
complementary_data["messages"] = [messages]
|
||||
|
||||
if "message_streams" in complementary_data:
|
||||
streams = complementary_data["message_streams"]
|
||||
if isinstance(streams, list) and (not streams or isinstance(streams[0], str)):
|
||||
complementary_data["message_streams"] = [streams]
|
||||
|
||||
if "target_message_indices" in complementary_data:
|
||||
indices = complementary_data["target_message_indices"]
|
||||
if isinstance(indices, list) and (not indices or isinstance(indices[0], int)):
|
||||
complementary_data["target_message_indices"] = [indices]
|
||||
return complementary_data
|
||||
|
||||
def transform_features(
|
||||
|
||||
@@ -171,8 +171,33 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
timestamp_key = {"timestamp": batch["timestamp"]} if "timestamp" in batch else {}
|
||||
language_persistent_key = (
|
||||
{"language_persistent": batch["language_persistent"]} if "language_persistent" in batch else {}
|
||||
)
|
||||
language_events_key = {"language_events": batch["language_events"]} if "language_events" in batch else {}
|
||||
messages_key = {"messages": batch["messages"]} if "messages" in batch else {}
|
||||
message_streams_key = {"message_streams": batch["message_streams"]} if "message_streams" in batch else {}
|
||||
target_message_indices_key = (
|
||||
{"target_message_indices": batch["target_message_indices"]}
|
||||
if "target_message_indices" in batch
|
||||
else {}
|
||||
)
|
||||
|
||||
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
|
||||
return {
|
||||
**pad_keys,
|
||||
**task_key,
|
||||
**subtask_key,
|
||||
**index_key,
|
||||
**task_index_key,
|
||||
**episode_index_key,
|
||||
**timestamp_key,
|
||||
**language_persistent_key,
|
||||
**language_events_key,
|
||||
**messages_key,
|
||||
**message_streams_key,
|
||||
**target_message_indices_key,
|
||||
}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
81
src/lerobot/processor/render_messages_processor.py
Normal file
81
src/lerobot/processor/render_messages_processor.py
Normal file
@@ -0,0 +1,81 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.configs.recipe import TrainingRecipe
|
||||
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT
|
||||
from lerobot.datasets.language_render import render_sample
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
|
||||
from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="render_messages_processor")
|
||||
class RenderMessagesStep(ProcessorStep):
|
||||
recipe: TrainingRecipe
|
||||
dataset_ctx: Any | None = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
persistent = complementary_data.get(LANGUAGE_PERSISTENT) or []
|
||||
events = complementary_data.get(LANGUAGE_EVENTS) or []
|
||||
|
||||
if not persistent and not events:
|
||||
return transition
|
||||
|
||||
timestamp = complementary_data.get("timestamp")
|
||||
if timestamp is None:
|
||||
raise KeyError("RenderMessagesStep requires sample timestamp in complementary data.")
|
||||
|
||||
sample_idx = complementary_data.get("index", 0)
|
||||
rendered = render_sample(
|
||||
recipe=self.recipe,
|
||||
persistent=persistent,
|
||||
events=events,
|
||||
t=_scalar(timestamp),
|
||||
sample_idx=int(_scalar(sample_idx)),
|
||||
task=complementary_data.get("task"),
|
||||
dataset_ctx=self.dataset_ctx,
|
||||
)
|
||||
if rendered is None:
|
||||
return None
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data.pop(LANGUAGE_PERSISTENT, None)
|
||||
new_complementary_data.pop(LANGUAGE_EVENTS, None)
|
||||
new_complementary_data.update(rendered)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
|
||||
def _scalar(value: Any) -> float | int:
|
||||
if hasattr(value, "item"):
|
||||
return value.item()
|
||||
if isinstance(value, list) and len(value) == 1:
|
||||
return _scalar(value[0])
|
||||
return value
|
||||
@@ -47,6 +47,7 @@ from lerobot.datasets import EpisodeAwareSampler, make_dataset
|
||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.utils.collate import lerobot_collate_fn
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
@@ -386,6 +387,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
sampler=sampler,
|
||||
pin_memory=device.type == "cuda",
|
||||
drop_last=False,
|
||||
collate_fn=lerobot_collate_fn,
|
||||
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
|
||||
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
|
||||
)
|
||||
|
||||
48
src/lerobot/utils/collate.py
Normal file
48
src/lerobot/utils/collate.py
Normal file
@@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from torch.utils.data._utils.collate import default_collate
|
||||
|
||||
from lerobot.datasets.language import LANGUAGE_COLUMNS
|
||||
|
||||
_PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"}
|
||||
|
||||
|
||||
def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None:
|
||||
batch = [sample for sample in batch if sample is not None]
|
||||
if not batch:
|
||||
return None
|
||||
|
||||
preserved = {
|
||||
key: [sample[key] for sample in batch if key in sample]
|
||||
for key in _PYTHON_LIST_KEYS
|
||||
if any(key in sample for sample in batch)
|
||||
}
|
||||
tensorizable = [
|
||||
{
|
||||
key: value
|
||||
for key, value in sample.items()
|
||||
if key not in _PYTHON_LIST_KEYS and key not in LANGUAGE_COLUMNS
|
||||
}
|
||||
for sample in batch
|
||||
]
|
||||
collated = default_collate(tensorizable)
|
||||
collated.update(preserved)
|
||||
return collated
|
||||
Reference in New Issue
Block a user