mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 11:21:27 +00:00
refactor(constants, processor): standardize action and observation keys across multiple files (#1808)
- Added new constants for truncated and done states in constants.py. - Updated references to action and observation keys in pipeline_features.py, converters.py, hil_processor.py, tokenizer_processor.py, and robot_kinematic_processor.py to use the new constants for improved readability and maintainability.
This commit is contained in:
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_LANGUAGE
|
||||
from lerobot.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
from lerobot.processor.pipeline import (
|
||||
EnvTransition,
|
||||
ObservationProcessor,
|
||||
@@ -156,10 +156,8 @@ class TokenizerProcessor(ObservationProcessor):
|
||||
new_observation = dict(observation)
|
||||
|
||||
# Add tokenized data to observation
|
||||
new_observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"]
|
||||
new_observation[f"{OBS_LANGUAGE}.attention_mask"] = tokenized_prompt["attention_mask"].to(
|
||||
dtype=torch.bool
|
||||
)
|
||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
return new_observation
|
||||
|
||||
@@ -239,13 +237,13 @@ class TokenizerProcessor(ObservationProcessor):
|
||||
"""
|
||||
# Add features for tokenized output if they don't exist
|
||||
# Standard tokenizer output includes tokens and attention_mask
|
||||
tokens_key = f"{OBS_LANGUAGE}.tokens"
|
||||
attention_mask_key = f"{OBS_LANGUAGE}.attention_mask"
|
||||
|
||||
if tokens_key not in features:
|
||||
features[tokens_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
|
||||
if OBS_LANGUAGE_TOKENS not in features:
|
||||
features[OBS_LANGUAGE_TOKENS] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
|
||||
|
||||
if attention_mask_key not in features:
|
||||
features[attention_mask_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
|
||||
if OBS_LANGUAGE_ATTENTION_MASK not in features:
|
||||
features[OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
|
||||
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
Reference in New Issue
Block a user