chore: replace hard-coded action values with constants throughout all the source code (#2055)

* chore: replace hard-coded 'action' values with constants throughout all the source code

* chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-26 13:33:18 +02:00
committed by GitHub
parent 9627765ce2
commit d2782cf66b
47 changed files with 269 additions and 255 deletions

View File

@@ -23,7 +23,7 @@ from typing import Any
import numpy as np
import torch
from lerobot.utils.constants import OBS_PREFIX
from lerobot.utils.constants import ACTION, OBS_PREFIX
from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
@@ -344,7 +344,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
if not isinstance(batch, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
action = batch.get("action")
action = batch.get(ACTION)
if action is not None and not isinstance(action, PolicyAction):
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
@@ -354,7 +354,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
return create_transition(
observation=observation_keys if observation_keys else None,
action=batch.get("action"),
action=batch.get(ACTION),
reward=batch.get("next.reward", 0.0),
done=batch.get("next.done", False),
truncated=batch.get("next.truncated", False),
@@ -379,7 +379,7 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
batch = {
"action": transition.get(TransitionKey.ACTION),
ACTION: transition.get(TransitionKey.ACTION),
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
"next.done": transition.get(TransitionKey.DONE, False),
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),

View File

@@ -59,6 +59,7 @@ from safetensors.torch import load_file as load_safetensors
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors
from lerobot.utils.constants import ACTION
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
@@ -196,7 +197,7 @@ def detect_features_and_norm_modes(
feature_type = FeatureType.VISUAL
elif "state" in key:
feature_type = FeatureType.STATE
elif "action" in key:
elif ACTION in key:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE # Default
@@ -215,7 +216,7 @@ def detect_features_and_norm_modes(
feature_type = FeatureType.VISUAL
elif "state" in key or "joint" in key or "position" in key:
feature_type = FeatureType.STATE
elif "action" in key:
elif ACTION in key:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE
@@ -321,7 +322,7 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
feature_type = FeatureType.VISUAL
elif "state" in key:
feature_type = FeatureType.STATE
elif "action" in key:
elif ACTION in key:
feature_type = FeatureType.ACTION
else:
feature_type = FeatureType.STATE

View File

@@ -26,6 +26,7 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION
from .converters import from_tensor_to_numpy, to_tensor
from .core import EnvTransition, PolicyAction, TransitionKey
@@ -272,7 +273,7 @@ class _NormalizationMixin:
Returns:
The transformed action tensor.
"""
processed_action = self._apply_transform(action, "action", FeatureType.ACTION, inverse=inverse)
processed_action = self._apply_transform(action, ACTION, FeatureType.ACTION, inverse=inverse)
return processed_action
def _apply_transform(

View File

@@ -5,6 +5,7 @@ import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import ActionProcessorStep, PolicyAction, ProcessorStepRegistry, RobotAction
from lerobot.utils.constants import ACTION
@dataclass
@@ -23,7 +24,7 @@ class RobotActionToPolicyActionProcessorStep(ActionProcessorStep):
return asdict(self)
def transform_features(self, features):
features[PipelineFeatureType.ACTION]["action"] = PolicyFeature(
features[PipelineFeatureType.ACTION][ACTION] = PolicyFeature(
type=FeatureType.ACTION, shape=(len(self.motor_names),)
)
return features