refactor(constants, processor): standardize action and observation keys across multiple files (#1808)

- Added new constants for truncated and done states in constants.py.
- Updated references to action and observation keys in pipeline_features.py, converters.py, hil_processor.py, tokenizer_processor.py, and robot_kinematic_processor.py to use the new constants for improved readability and maintainability.
This commit is contained in:
Adil Zouitine
2025-08-31 22:53:13 +02:00
committed by GitHub
parent 574a708950
commit 08fb310eaa
6 changed files with 123 additions and 114 deletions

View File

@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import OBS_LANGUAGE
from lerobot.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
from lerobot.processor.pipeline import (
EnvTransition,
ObservationProcessor,
@@ -156,10 +156,8 @@ class TokenizerProcessor(ObservationProcessor):
new_observation = dict(observation)
# Add tokenized data to observation
new_observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"]
new_observation[f"{OBS_LANGUAGE}.attention_mask"] = tokenized_prompt["attention_mask"].to(
dtype=torch.bool
)
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
return new_observation
@@ -239,13 +237,13 @@ class TokenizerProcessor(ObservationProcessor):
"""
# Add features for tokenized output if they don't exist
# Standard tokenizer output includes tokens and attention_mask
tokens_key = f"{OBS_LANGUAGE}.tokens"
attention_mask_key = f"{OBS_LANGUAGE}.attention_mask"
if tokens_key not in features:
features[tokens_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
if OBS_LANGUAGE_TOKENS not in features:
features[OBS_LANGUAGE_TOKENS] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
if attention_mask_key not in features:
features[attention_mask_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
if OBS_LANGUAGE_ATTENTION_MASK not in features:
features[OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
type=FeatureType.LANGUAGE, shape=(self.max_length,)
)
return features