chore: replace hard-coded action values with constants throughout all the source code (#2055)

* chore: replace hard-coded 'action' values with constants throughout all the source code

* chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-26 13:33:18 +02:00
committed by GitHub
parent 9627765ce2
commit d2782cf66b
47 changed files with 269 additions and 255 deletions

View File

@@ -24,7 +24,7 @@ import torch.nn.functional as F # noqa: N812
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import OBS_IMAGE
from lerobot.utils.constants import ACTION, OBS_IMAGE
from lerobot.utils.transition import Transition
@@ -467,7 +467,7 @@ class ReplayBuffer:
if list_transition:
first_transition = list_transition[0]
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
first_action = first_transition["action"].to(device)
first_action = first_transition[ACTION].to(device)
# Get complementary info if available
first_complementary_info = None
@@ -492,7 +492,7 @@ class ReplayBuffer:
elif isinstance(v, torch.Tensor):
data[k] = v.to(storage_device)
action = data["action"]
action = data[ACTION]
replay_buffer.add(
state=data["state"],
@@ -530,8 +530,8 @@ class ReplayBuffer:
# Add "action"
sample_action = self.actions[0]
act_info = guess_feature_info(t=sample_action, name="action")
features["action"] = act_info
act_info = guess_feature_info(t=sample_action, name=ACTION)
features[ACTION] = act_info
# Add "reward" and "done"
features["next.reward"] = {"dtype": "float32", "shape": (1,)}
@@ -577,7 +577,7 @@ class ReplayBuffer:
frame_dict[key] = self.states[key][actual_idx].cpu()
# Fill action, reward, done
frame_dict["action"] = self.actions[actual_idx].cpu()
frame_dict[ACTION] = self.actions[actual_idx].cpu()
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
frame_dict["task"] = task_name
@@ -668,7 +668,7 @@ class ReplayBuffer:
current_state[key] = val.unsqueeze(0) # Add batch dimension
# ----- 2) Action -----
action = current_sample["action"].unsqueeze(0) # Add batch dimension
action = current_sample[ACTION].unsqueeze(0) # Add batch dimension
# ----- 3) Reward and done -----
reward = float(current_sample["next.reward"].item()) # ensure float
@@ -788,8 +788,8 @@ def concatenate_batch_transitions(
}
# Concatenate basic fields
left_batch_transitions["action"] = torch.cat(
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
left_batch_transitions[ACTION] = torch.cat(
[left_batch_transitions[ACTION], right_batch_transition[ACTION]], dim=0
)
left_batch_transitions["reward"] = torch.cat(
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0