mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
chore: replace hard-coded action values with constants throughout all the source code (#2055)
* chore: replace hard-coded 'action' values with constants throughout all the source code * chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
@@ -23,7 +23,7 @@ from typing import Any
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.utils.constants import OBS_PREFIX
|
||||
from lerobot.utils.constants import ACTION, OBS_PREFIX
|
||||
|
||||
from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey
|
||||
|
||||
@@ -344,7 +344,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
|
||||
if not isinstance(batch, dict):
|
||||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
|
||||
|
||||
action = batch.get("action")
|
||||
action = batch.get(ACTION)
|
||||
if action is not None and not isinstance(action, PolicyAction):
|
||||
raise ValueError(f"Action should be a PolicyAction type got {type(action)}")
|
||||
|
||||
@@ -354,7 +354,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition:
|
||||
|
||||
return create_transition(
|
||||
observation=observation_keys if observation_keys else None,
|
||||
action=batch.get("action"),
|
||||
action=batch.get(ACTION),
|
||||
reward=batch.get("next.reward", 0.0),
|
||||
done=batch.get("next.done", False),
|
||||
truncated=batch.get("next.truncated", False),
|
||||
@@ -379,7 +379,7 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}")
|
||||
|
||||
batch = {
|
||||
"action": transition.get(TransitionKey.ACTION),
|
||||
ACTION: transition.get(TransitionKey.ACTION),
|
||||
"next.reward": transition.get(TransitionKey.REWARD, 0.0),
|
||||
"next.done": transition.get(TransitionKey.DONE, False),
|
||||
"next.truncated": transition.get(TransitionKey.TRUNCATED, False),
|
||||
|
||||
Reference in New Issue
Block a user