mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
chore: replace hard-coded action values with constants throughout all the source code (#2055)
* chore: replace hard-coded 'action' values with constants throughout all the source code * chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
@@ -75,7 +75,7 @@ import torch.utils.data
|
||||
import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
|
||||
class EpisodeSampler(torch.utils.data.Sampler):
|
||||
@@ -157,9 +157,9 @@ def visualize_dataset(
|
||||
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
|
||||
|
||||
# display each dimension of action space (e.g. actuators command)
|
||||
if "action" in batch:
|
||||
for dim_idx, val in enumerate(batch["action"][i]):
|
||||
rr.log(f"action/{dim_idx}", rr.Scalar(val.item()))
|
||||
if ACTION in batch:
|
||||
for dim_idx, val in enumerate(batch[ACTION][i]):
|
||||
rr.log(f"{ACTION}/{dim_idx}", rr.Scalar(val.item()))
|
||||
|
||||
# display each dimension of observed state space (e.g. agent position in joint space)
|
||||
if OBS_STATE in batch:
|
||||
|
||||
@@ -81,7 +81,7 @@ from lerobot.envs.utils import (
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.io_utils import write_video
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.utils import (
|
||||
@@ -213,7 +213,7 @@ def rollout(
|
||||
|
||||
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
|
||||
ret = {
|
||||
"action": torch.stack(all_actions, dim=1),
|
||||
ACTION: torch.stack(all_actions, dim=1),
|
||||
"reward": torch.stack(all_rewards, dim=1),
|
||||
"success": torch.stack(all_successes, dim=1),
|
||||
"done": torch.stack(all_dones, dim=1),
|
||||
@@ -440,14 +440,14 @@ def _compile_episode_data(
|
||||
"""
|
||||
ep_dicts = []
|
||||
total_frames = 0
|
||||
for ep_ix in range(rollout_data["action"].shape[0]):
|
||||
for ep_ix in range(rollout_data[ACTION].shape[0]):
|
||||
# + 2 to include the first done frame and the last observation frame.
|
||||
num_frames = done_indices[ep_ix].item() + 2
|
||||
total_frames += num_frames
|
||||
|
||||
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
|
||||
ep_dict = {
|
||||
"action": rollout_data["action"][ep_ix, : num_frames - 1],
|
||||
ACTION: rollout_data[ACTION][ep_ix, : num_frames - 1],
|
||||
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
|
||||
"frame_index": torch.arange(0, num_frames - 1, 1),
|
||||
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,
|
||||
|
||||
@@ -109,7 +109,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
so101_leader,
|
||||
)
|
||||
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import (
|
||||
init_keyboard_listener,
|
||||
is_headless,
|
||||
@@ -319,7 +319,7 @@ def record_loop(
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
|
||||
action_names = dataset.features["action"]["names"]
|
||||
action_names = dataset.features[ACTION]["names"]
|
||||
act_processed_policy: RobotAction = {
|
||||
f"{name}": float(action_values[i]) for i, name in enumerate(action_names)
|
||||
}
|
||||
@@ -361,7 +361,7 @@ def record_loop(
|
||||
|
||||
# Write to dataset
|
||||
if dataset is not None:
|
||||
action_frame = build_dataset_frame(dataset.features, action_values, prefix="action")
|
||||
action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION)
|
||||
frame = {**observation_frame, **action_frame, "task": single_task}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
|
||||
@@ -60,6 +60,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
@@ -99,7 +100,7 @@ def replay(cfg: ReplayConfig):
|
||||
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.episode)
|
||||
actions = episode_frames.select_columns("action")
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
robot.connect()
|
||||
|
||||
@@ -107,9 +108,9 @@ def replay(cfg: ReplayConfig):
|
||||
for idx in range(len(episode_frames)):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action_array = actions[idx]["action"]
|
||||
action_array = actions[idx][ACTION]
|
||||
action = {}
|
||||
for i, name in enumerate(dataset.features["action"]["names"]):
|
||||
for i, name in enumerate(dataset.features[ACTION]["names"]):
|
||||
action[name] = action_array[i]
|
||||
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
Reference in New Issue
Block a user