mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 19:31:25 +00:00
chore: replace hard-coded action values with constants throughout all the source code (#2055)
* chore: replace hard-coded 'action' values with constants throughout all the source code * chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
@@ -59,6 +59,7 @@ from safetensors.torch import load_file as load_safetensors
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
|
||||
@@ -196,7 +197,7 @@ def detect_features_and_norm_modes(
|
||||
feature_type = FeatureType.VISUAL
|
||||
elif "state" in key:
|
||||
feature_type = FeatureType.STATE
|
||||
elif "action" in key:
|
||||
elif ACTION in key:
|
||||
feature_type = FeatureType.ACTION
|
||||
else:
|
||||
feature_type = FeatureType.STATE # Default
|
||||
@@ -215,7 +216,7 @@ def detect_features_and_norm_modes(
|
||||
feature_type = FeatureType.VISUAL
|
||||
elif "state" in key or "joint" in key or "position" in key:
|
||||
feature_type = FeatureType.STATE
|
||||
elif "action" in key:
|
||||
elif ACTION in key:
|
||||
feature_type = FeatureType.ACTION
|
||||
else:
|
||||
feature_type = FeatureType.STATE
|
||||
@@ -321,7 +322,7 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
|
||||
feature_type = FeatureType.VISUAL
|
||||
elif "state" in key:
|
||||
feature_type = FeatureType.STATE
|
||||
elif "action" in key:
|
||||
elif ACTION in key:
|
||||
feature_type = FeatureType.ACTION
|
||||
else:
|
||||
feature_type = FeatureType.STATE
|
||||
|
||||
Reference in New Issue
Block a user