chore: replace hard-coded action values with constants throughout all the source code (#2055)

* chore: replace hard-coded 'action' values with constants throughout all the source code

* chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-26 13:33:18 +02:00
committed by GitHub
parent 9627765ce2
commit d2782cf66b
47 changed files with 269 additions and 255 deletions

View File

@@ -75,7 +75,7 @@ import torch.utils.data
import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import OBS_STATE
from lerobot.utils.constants import ACTION, OBS_STATE
class EpisodeSampler(torch.utils.data.Sampler):
@@ -157,9 +157,9 @@ def visualize_dataset(
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
# display each dimension of action space (e.g. actuators command)
if "action" in batch:
for dim_idx, val in enumerate(batch["action"][i]):
rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
if ACTION in batch:
for dim_idx, val in enumerate(batch[ACTION][i]):
rr.log(f"{ACTION}/{dim_idx}", rr.Scalar(val.item()))
# display each dimension of observed state space (e.g. agent position in joint space)
if OBS_STATE in batch:

View File

@@ -81,7 +81,7 @@ from lerobot.envs.utils import (
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.utils.constants import OBS_STR
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
@@ -213,7 +213,7 @@ def rollout(
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
ret = {
"action": torch.stack(all_actions, dim=1),
ACTION: torch.stack(all_actions, dim=1),
"reward": torch.stack(all_rewards, dim=1),
"success": torch.stack(all_successes, dim=1),
"done": torch.stack(all_dones, dim=1),
@@ -440,14 +440,14 @@ def _compile_episode_data(
"""
ep_dicts = []
total_frames = 0
for ep_ix in range(rollout_data["action"].shape[0]):
for ep_ix in range(rollout_data[ACTION].shape[0]):
# + 2 to include the first done frame and the last observation frame.
num_frames = done_indices[ep_ix].item() + 2
total_frames += num_frames
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
ep_dict = {
"action": rollout_data["action"][ep_ix, : num_frames - 1],
ACTION: rollout_data[ACTION][ep_ix, : num_frames - 1],
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
"frame_index": torch.arange(0, num_frames - 1, 1),
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,

View File

@@ -109,7 +109,7 @@ from lerobot.teleoperators import ( # noqa: F401
so101_leader,
)
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
from lerobot.utils.constants import OBS_STR
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import (
init_keyboard_listener,
is_headless,
@@ -319,7 +319,7 @@ def record_loop(
robot_type=robot.robot_type,
)
action_names = dataset.features["action"]["names"]
action_names = dataset.features[ACTION]["names"]
act_processed_policy: RobotAction = {
f"{name}": float(action_values[i]) for i, name in enumerate(action_names)
}
@@ -361,7 +361,7 @@ def record_loop(
# Write to dataset
if dataset is not None:
action_frame = build_dataset_frame(dataset.features, action_values, prefix="action")
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
frame = {**observation_frame, **action_frame, "task": single_task}
dataset.add_frame(frame)

View File

@@ -60,6 +60,7 @@ from lerobot.robots import ( # noqa: F401
so100_follower,
so101_follower,
)
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import (
init_logging,
@@ -99,7 +100,7 @@ def replay(cfg: ReplayConfig):
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.episode)
actions = episode_frames.select_columns("action")
actions = episode_frames.select_columns(ACTION)
robot.connect()
@@ -107,9 +108,9 @@ def replay(cfg: ReplayConfig):
for idx in range(len(episode_frames)):
start_episode_t = time.perf_counter()
action_array = actions[idx]["action"]
action_array = actions[idx][ACTION]
action = {}
for i, name in enumerate(dataset.features["action"]["names"]):
for i, name in enumerate(dataset.features[ACTION]["names"]):
action[name] = action_array[i]
robot_obs = robot.get_observation()