mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
Compare commits
6 Commits
feat/relat
...
feat/inter
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a178ddb240 | ||
|
|
498e215444 | ||
|
|
c85f1692d6 | ||
|
|
9fd329713a | ||
|
|
97d068e5a2 | ||
|
|
e5bea36387 |
@@ -33,6 +33,11 @@ Example usage:
|
||||
python examples/openarms/evaluate_with_rtc.py \
|
||||
--rtc.execution_horizon=12 \
|
||||
--rtc.max_guidance_weight=10.0
|
||||
|
||||
# With action interpolation (policy at 30Hz, robot at 50Hz)
|
||||
python examples/openarms/evaluate_with_rtc.py \
|
||||
--action_interpolation_enabled=true \
|
||||
--control_hz=50
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -82,6 +87,8 @@ DEFAULT_FPS = 30
|
||||
DEFAULT_EPISODE_TIME_SEC = 300
|
||||
DEFAULT_RESET_TIME_SEC = 60
|
||||
|
||||
DEFAULT_CONTROL_HZ = 50
|
||||
|
||||
DEFAULT_FOLLOWER_LEFT_PORT = "can0"
|
||||
DEFAULT_FOLLOWER_RIGHT_PORT = "can1"
|
||||
|
||||
@@ -167,6 +174,9 @@ class OpenArmsRTCEvalConfig(HubMixin):
|
||||
record_dataset: bool = True
|
||||
push_to_hub: bool = True
|
||||
|
||||
action_interpolation_enabled: bool = False
|
||||
control_hz: float = DEFAULT_CONTROL_HZ
|
||||
|
||||
use_torch_compile: bool = False
|
||||
torch_compile_backend: str = "inductor"
|
||||
torch_compile_mode: str = "default"
|
||||
@@ -309,6 +319,11 @@ def get_actions_thread(
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _interpolate_actions(prev_action: Tensor, next_action: Tensor, alpha: float) -> Tensor:
|
||||
"""Linear interpolation between two action tensors."""
|
||||
return prev_action + alpha * (next_action - prev_action)
|
||||
|
||||
|
||||
def actor_thread(
|
||||
robot: RobotWrapper,
|
||||
robot_action_processor,
|
||||
@@ -324,49 +339,101 @@ def actor_thread(
|
||||
"""Thread function to execute actions on the robot."""
|
||||
try:
|
||||
logger.info("[ACTOR] Starting actor thread")
|
||||
logger.info(f"[ACTOR] interpolation={cfg.action_interpolation_enabled}, control_hz={cfg.control_hz}")
|
||||
|
||||
action_count = 0
|
||||
action_interval = 1.0 / cfg.fps
|
||||
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
|
||||
|
||||
if cfg.action_interpolation_enabled:
|
||||
control_interval = 1.0 / cfg.control_hz
|
||||
interp_steps = int(cfg.control_hz / cfg.fps)
|
||||
else:
|
||||
control_interval = 1.0 / cfg.fps
|
||||
interp_steps = 1
|
||||
|
||||
prev_action: Tensor | None = None
|
||||
current_action: Tensor | None = None
|
||||
interp_step = 0
|
||||
last_dataset_frame_time = 0.0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if not episode_active.is_set():
|
||||
prev_action = None
|
||||
current_action = None
|
||||
interp_step = 0
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
start_time = time.perf_counter()
|
||||
action = action_queue.get()
|
||||
|
||||
if action is not None:
|
||||
action = action.cpu()
|
||||
if cfg.action_interpolation_enabled:
|
||||
if interp_step == 0 or current_action is None:
|
||||
new_action = action_queue.get()
|
||||
if new_action is not None:
|
||||
prev_action = current_action if current_action is not None else new_action.cpu()
|
||||
current_action = new_action.cpu()
|
||||
interp_step = 0
|
||||
|
||||
action_dict = {}
|
||||
for i, key in enumerate(action_keys):
|
||||
if i < len(action):
|
||||
action_dict[key] = action[i].item()
|
||||
if current_action is not None:
|
||||
if prev_action is not None and interp_steps > 1:
|
||||
alpha = (interp_step + 1) / interp_steps
|
||||
action_to_send = _interpolate_actions(prev_action, current_action, alpha)
|
||||
else:
|
||||
action_to_send = current_action
|
||||
|
||||
action_processed = robot_action_processor((action_dict, None))
|
||||
robot.send_action(action_processed)
|
||||
action_dict = {}
|
||||
for i, key in enumerate(action_keys):
|
||||
if i < len(action_to_send):
|
||||
action_dict[key] = action_to_send[i].item()
|
||||
|
||||
if cfg.record_dataset and dataset is not None:
|
||||
with dataset_lock:
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
action_for_dataset = teleop_action_processor((action_dict, None))
|
||||
action_processed = robot_action_processor((action_dict, None))
|
||||
robot.send_action(action_processed)
|
||||
action_count += 1
|
||||
|
||||
frame = {}
|
||||
for key, value in obs_processed.items():
|
||||
frame[f"observation.{key}"] = value
|
||||
for key, value in action_for_dataset.items():
|
||||
frame[f"action.{key}"] = value
|
||||
frame["task"] = cfg.task
|
||||
interp_step = (interp_step + 1) % interp_steps
|
||||
|
||||
dataset.add_frame(frame)
|
||||
if cfg.record_dataset and dataset is not None:
|
||||
if time.perf_counter() - last_dataset_frame_time >= (1.0 / cfg.fps):
|
||||
last_dataset_frame_time = time.perf_counter()
|
||||
with dataset_lock:
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
action_for_dataset = teleop_action_processor((action_dict, None))
|
||||
frame = {}
|
||||
for key, value in obs_processed.items():
|
||||
frame[f"observation.{key}"] = value
|
||||
for key, value in action_for_dataset.items():
|
||||
frame[f"action.{key}"] = value
|
||||
frame["task"] = cfg.task
|
||||
dataset.add_frame(frame)
|
||||
else:
|
||||
action = action_queue.get()
|
||||
if action is not None:
|
||||
action = action.cpu()
|
||||
action_dict = {}
|
||||
for i, key in enumerate(action_keys):
|
||||
if i < len(action):
|
||||
action_dict[key] = action[i].item()
|
||||
|
||||
action_count += 1
|
||||
action_processed = robot_action_processor((action_dict, None))
|
||||
robot.send_action(action_processed)
|
||||
action_count += 1
|
||||
|
||||
if cfg.record_dataset and dataset is not None:
|
||||
with dataset_lock:
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
action_for_dataset = teleop_action_processor((action_dict, None))
|
||||
frame = {}
|
||||
for key, value in obs_processed.items():
|
||||
frame[f"observation.{key}"] = value
|
||||
for key, value in action_for_dataset.items():
|
||||
frame[f"action.{key}"] = value
|
||||
frame["task"] = cfg.task
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dt_s = time.perf_counter() - start_time
|
||||
sleep_time = max(0, action_interval - dt_s - 0.001)
|
||||
sleep_time = max(0, control_interval - dt_s - 0.001)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
@@ -434,6 +501,9 @@ def main(cfg: OpenArmsRTCEvalConfig):
|
||||
print(f"RTC Enabled: {cfg.rtc.enabled}")
|
||||
print(f"RTC Execution Horizon: {cfg.rtc.execution_horizon}")
|
||||
print(f"RTC Max Guidance Weight: {cfg.rtc.max_guidance_weight}")
|
||||
print(f"Action Interpolation: {cfg.action_interpolation_enabled}")
|
||||
if cfg.action_interpolation_enabled:
|
||||
print(f"Control Hz: {cfg.control_hz}")
|
||||
print(f"Device: {cfg.device}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
152
examples/openarms/unify_task.py
Normal file
152
examples/openarms/unify_task.py
Normal file
@@ -0,0 +1,152 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Unify all tasks in a dataset to a single task (modifies in-place).
|
||||
|
||||
This script:
|
||||
1. Loads a dataset
|
||||
2. Sets all task_index to 0 and task description to "fold"
|
||||
3. Updates tasks.parquet and task_index in data files (in-place, no copying)
|
||||
|
||||
Usage:
|
||||
python examples/openarms/unify_task.py --repo-id lerobot-data-collection/level1_rac1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DATA_DIR,
|
||||
write_info,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
|
||||
# Single unified task
|
||||
UNIFIED_TASK = "fold"
|
||||
|
||||
|
||||
def unify_dataset_tasks(
|
||||
repo_id: str,
|
||||
root: Path | None = None,
|
||||
push_to_hub: bool = False,
|
||||
) -> None:
|
||||
"""Unify all tasks in a dataset to a single task (modifies in-place).
|
||||
|
||||
Args:
|
||||
repo_id: Dataset repository ID.
|
||||
root: Optional root path for dataset.
|
||||
push_to_hub: Whether to push the result to HuggingFace Hub.
|
||||
"""
|
||||
input_root = root if root else HF_LEROBOT_HOME / repo_id
|
||||
input_repo_id = repo_id
|
||||
|
||||
logging.info(f"Loading metadata from {repo_id}")
|
||||
|
||||
# Load source metadata
|
||||
src_meta = LeRobotDatasetMetadata(repo_id, root=input_root)
|
||||
|
||||
logging.info(f"Source dataset: {src_meta.total_episodes} episodes, {src_meta.total_frames} frames")
|
||||
logging.info(f"Original tasks: {len(src_meta.tasks)}")
|
||||
|
||||
# Modify in-place (input_root == output_root supported)
|
||||
data_dir = input_root / DATA_DIR
|
||||
|
||||
# Process data files - set all task_index to 0
|
||||
logging.info("Processing data files (in-place)...")
|
||||
for parquet_file in tqdm(sorted(data_dir.rglob("*.parquet")), desc="Processing data"):
|
||||
df = pd.read_parquet(parquet_file)
|
||||
df["task_index"] = 0 # All tasks unified to index 0
|
||||
df.to_parquet(parquet_file)
|
||||
|
||||
# Process episodes metadata - set all tasks to unified task
|
||||
logging.info("Processing episodes metadata (in-place)...")
|
||||
episodes_dir = input_root / "meta" / "episodes"
|
||||
if episodes_dir.exists():
|
||||
for parquet_file in tqdm(sorted(episodes_dir.rglob("*.parquet")), desc="Processing episodes"):
|
||||
df = pd.read_parquet(parquet_file)
|
||||
df["tasks"] = [[UNIFIED_TASK]] * len(df) # All episodes get the unified task
|
||||
df.to_parquet(parquet_file)
|
||||
else:
|
||||
logging.warning(f"No episodes directory found at {episodes_dir}, skipping")
|
||||
|
||||
# Update tasks.parquet with single task
|
||||
logging.info(f"Creating single task: {UNIFIED_TASK}")
|
||||
new_tasks = pd.DataFrame({"task_index": [0]}, index=[UNIFIED_TASK])
|
||||
write_tasks(new_tasks, input_root)
|
||||
|
||||
# Update info.json
|
||||
new_info = src_meta.info.copy()
|
||||
new_info["total_tasks"] = 1
|
||||
write_info(new_info, input_root)
|
||||
|
||||
logging.info(f"Dataset modified in-place at {input_root}")
|
||||
logging.info(f"Task: {UNIFIED_TASK}")
|
||||
|
||||
if push_to_hub:
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
logging.info(f"Pushing {input_repo_id} to hub")
|
||||
dataset = LeRobotDataset(input_repo_id, root=input_root)
|
||||
dataset.push_to_hub(private=True)
|
||||
logging.info("Push complete!")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Unify all tasks in a dataset to a single task 'fold' (modifies in-place)."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Dataset repository ID",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Optional root path (defaults to HF_LEROBOT_HOME/repo_id)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Push result to HuggingFace Hub",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
|
||||
unify_dataset_tasks(
|
||||
repo_id=args.repo_id,
|
||||
root=args.root,
|
||||
push_to_hub=args.push_to_hub,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user