mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
add training
This commit is contained in:
@@ -168,10 +168,12 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
user_prompt_key = {"user_prompt": batch["user_prompt"]} if "user_prompt" in batch else {}
|
||||
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key}
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key, **user_prompt_key, **subtask_key}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
@@ -47,7 +47,6 @@ class RenameObservationsProcessorStep(ObservationProcessorStep):
|
||||
processed_obs[self.rename_map[key]] = value
|
||||
else:
|
||||
processed_obs[key] = value
|
||||
|
||||
return processed_obs
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
|
||||
@@ -29,7 +29,14 @@ from typing import TYPE_CHECKING, Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
from lerobot.utils.constants import (
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_LANGUAGE_SUBTASK_ONLY_TOKENS,
|
||||
OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
from .core import EnvTransition, TransitionKey
|
||||
@@ -52,6 +59,9 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
tokenizes it using a Hugging Face `transformers` tokenizer, and adds the resulting
|
||||
token IDs and attention mask to the `observation` dictionary.
|
||||
|
||||
Optionally, this step can also tokenize a high-level task (e.g., user prompt) and/or
|
||||
a subtask if present in the complementary data, creating separate tokenized observations.
|
||||
|
||||
Requires the `transformers` library to be installed.
|
||||
|
||||
Attributes:
|
||||
@@ -59,6 +69,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
tokenizer: A pre-initialized tokenizer object. If provided, `tokenizer_name` is ignored.
|
||||
max_length: The maximum length to pad or truncate sequences to.
|
||||
task_key: The key in `complementary_data` where the task string is stored.
|
||||
high_level_task_key: The key in `complementary_data` where the high-level task (user prompt) is stored.
|
||||
subtask_key: The key in `complementary_data` where the subtask string is stored.
|
||||
padding_side: The side to pad on ('left' or 'right').
|
||||
padding: The padding strategy ('max_length', 'longest', etc.).
|
||||
truncation: Whether to truncate sequences longer than `max_length`.
|
||||
@@ -69,6 +81,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
tokenizer: Any | None = None # Use `Any` for compatibility without a hard dependency
|
||||
max_length: int = 512
|
||||
task_key: str = "task"
|
||||
high_level_task_key: str = "user_prompt"
|
||||
subtask_key: str = "subtask"
|
||||
padding_side: str = "right"
|
||||
padding: str = "max_length"
|
||||
truncation: bool = True
|
||||
@@ -121,6 +135,7 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
raise ValueError("Complementary data is None so no task can be extracted from it")
|
||||
|
||||
task = complementary_data[self.task_key]
|
||||
|
||||
if task is None:
|
||||
raise ValueError("Task extracted from Complementary data is None")
|
||||
|
||||
@@ -132,6 +147,60 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
|
||||
return None
|
||||
|
||||
def get_high_level_task(self, transition: EnvTransition) -> list[str] | None:
    """
    Reads the high-level task description(s) from the transition's complementary data.

    The value is looked up under `self.high_level_task_key` and normalized to a list of
    strings so it can be handed to the tokenizer.

    Args:
        transition: The environment transition.

    Returns:
        A list of high-level task strings (a lone string is wrapped in a one-element
        list), or None when the complementary data is missing, the key is absent, the
        value is None, or the value is neither a string nor a list of strings.
    """
    extras = transition.get(TransitionKey.COMPLEMENTARY_DATA)
    prompt = None if extras is None else extras.get(self.high_level_task_key)

    # A single string becomes a one-element batch.
    if isinstance(prompt, str):
        return [prompt]

    # A list is passed through unchanged, provided every entry is a string.
    if isinstance(prompt, list) and all(isinstance(entry, str) for entry in prompt):
        return prompt

    # Missing value or unsupported type.
    return None
|
||||
|
||||
def get_subtask(self, transition: EnvTransition) -> list[str] | None:
    """
    Reads the subtask description(s) from the transition's complementary data.

    The value is looked up under `self.subtask_key` and normalized to a list of strings
    so it can be handed to the tokenizer.

    Args:
        transition: The environment transition.

    Returns:
        A list of subtask strings (a lone string is wrapped in a one-element list), or
        None when the complementary data is missing, the key is absent, the value is
        None, or the value is neither a string nor a list of strings.
    """
    extras = transition.get(TransitionKey.COMPLEMENTARY_DATA)
    label = None if extras is None else extras.get(self.subtask_key)

    # A single string becomes a one-element batch.
    if isinstance(label, str):
        return [label]

    # A list is passed through unchanged, provided every entry is a string.
    if isinstance(label, list) and all(isinstance(entry, str) for entry in label):
        return label

    # Missing value or unsupported type.
    return None
|
||||
|
||||
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Tokenizes the task description and adds it to the observation dictionary.
|
||||
@@ -169,6 +238,40 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
# Also tokenize high-level task if available
|
||||
high_level_task = self.get_high_level_task(self.transition)
|
||||
if high_level_task is not None:
|
||||
# Tokenize the high-level task
|
||||
tokenized_high_level_prompt = self._tokenize_text(high_level_task)
|
||||
|
||||
# Move to the same device
|
||||
if target_device is not None:
|
||||
tokenized_high_level_prompt = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_high_level_prompt.items()
|
||||
}
|
||||
|
||||
# Add high-level tokenized data to the observation
|
||||
new_observation[OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS] = tokenized_high_level_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK] = tokenized_high_level_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
# Also tokenize subtask if available
|
||||
subtask = self.get_subtask(self.transition)
|
||||
if subtask is not None:
|
||||
# Tokenize the subtask
|
||||
tokenized_subtask_prompt = self._tokenize_text(subtask)
|
||||
|
||||
# Move to the same device
|
||||
if target_device is not None:
|
||||
tokenized_subtask_prompt = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_subtask_prompt.items()
|
||||
}
|
||||
|
||||
# Add subtask tokenized data to the observation
|
||||
new_observation[OBS_LANGUAGE_SUBTASK_ONLY_TOKENS] = tokenized_subtask_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK] = tokenized_subtask_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
return new_observation
|
||||
|
||||
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
||||
@@ -229,6 +332,7 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
config = {
|
||||
"max_length": self.max_length,
|
||||
"task_key": self.task_key,
|
||||
"high_level_task_key": self.high_level_task_key,
|
||||
"padding_side": self.padding_side,
|
||||
"padding": self.padding,
|
||||
"truncation": self.truncation,
|
||||
@@ -267,4 +371,25 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
# Add features for high-level task tokens and attention mask if they don't already exist
|
||||
if OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_HIGH_LEVEL_TASK_TOKENS] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
if OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_HIGH_LEVEL_TASK_ATTENTION_MASK] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
if OBS_LANGUAGE_SUBTASK_ONLY_TOKENS not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_ONLY_TOKENS] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
if OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]:
|
||||
features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_SUBTASK_ONLY_ATTENTION_MASK] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
Reference in New Issue
Block a user