chore: replace hard-coded action values with constants throughout all the source code (#2055)

* chore: replace hard-coded 'action' values with constants throughout all the source code

* chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-26 13:33:18 +02:00
committed by GitHub
parent 9627765ce2
commit d2782cf66b
47 changed files with 269 additions and 255 deletions

View File

@@ -23,7 +23,7 @@ from typing import Any
import numpy as np
import torch
from lerobot.utils.constants import OBS_PREFIX
from lerobot.utils.constants import ACTION, OBS_PREFIX
from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
@@ -344,7 +344,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
if not isinstance(batch, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
action = batch.get("action")
action = batch.get(ACTION)
if action is not None and not isinstance(action, PolicyAction):
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
@@ -354,7 +354,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
return create_transition(
observation=observation_keys if observation_keys else None,
action=batch.get("action"),
action=batch.get(ACTION),
reward=batch.get("next.reward", 0.0),
done=batch.get("next.done", False),
truncated=batch.get("next.truncated", False),
@@ -379,7 +379,7 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
batch = {
"action": transition.get(TransitionKey.ACTION),
ACTION: transition.get(TransitionKey.ACTION),
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
"next.done": transition.get(TransitionKey.DONE, False),
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),