mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 04:41:24 +00:00
chore: replace hard-coded action values with constants throughout all the source code (#2055)
* chore: replace hard-coded 'action' values with constants throughout all the source code * chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
@@ -80,6 +80,7 @@ from lerobot.transport.utils import (
|
||||
state_to_bytes,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
CHECKPOINTS_DIR,
|
||||
LAST_CHECKPOINT_LINK,
|
||||
PRETRAINED_MODEL_DIR,
|
||||
@@ -402,7 +403,7 @@ def add_actor_information_and_train(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch["action"]
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
@@ -415,7 +416,7 @@ def add_actor_information_and_train(
|
||||
|
||||
# Create a batch dictionary with all required elements for the forward method
|
||||
forward_batch = {
|
||||
"action": actions,
|
||||
ACTION: actions,
|
||||
"reward": rewards,
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
@@ -460,7 +461,7 @@ def add_actor_information_and_train(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch["action"]
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
@@ -474,7 +475,7 @@ def add_actor_information_and_train(
|
||||
|
||||
# Create a batch dictionary with all required elements for the forward method
|
||||
forward_batch = {
|
||||
"action": actions,
|
||||
ACTION: actions,
|
||||
"reward": rewards,
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
@@ -1155,7 +1156,7 @@ def process_transitions(
|
||||
# Skip transitions with NaN values
|
||||
if check_nan_in_transition(
|
||||
observations=transition["state"],
|
||||
actions=transition["action"],
|
||||
actions=transition[ACTION],
|
||||
next_state=transition["next_state"],
|
||||
):
|
||||
logging.warning("[LEARNER] NaN detected in transition, skipping")
|
||||
|
||||
Reference in New Issue
Block a user